├── .gitattributes
├── .github
│   ├── CODEOWNERS
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── code-quality.yml
│       ├── sast-workflows.yml
│       └── test-release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .prettierignore
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── noxfile.py
├── poetry.lock
├── pyproject.toml
├── src
│   └── transformer_embeddings
│       ├── __init__.py
│       ├── helpers.py
│       ├── model.py
│       └── poolers.py
└── tests
    ├── __init__.py
    ├── test_model.py
    └── test_poolers.py

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | * text=auto eol=lf

--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # This file is managed by terraform in the github-infra repo.
2 | # Changes to this CODEOWNERS file should be done in github-infra.
3 |
4 | * @HeadspaceMeditation/machine-learning @HeadspaceMeditation/AutoApprove

--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Description
2 |
3 | ## Changes
4 |
5 | -
6 |
7 | ## Tests
8 |
9 | - [ ] Local.
10 | - [ ] CI.

--------------------------------------------------------------------------------
/.github/workflows/code-quality.yml:
--------------------------------------------------------------------------------
1 | name: Code Quality
2 |
3 | on:
4 |   pull_request:
5 |   push:
6 |     branches: [main]
7 |
8 | jobs:
9 |   pre-commit:
10 |     name: Code Quality
11 |     runs-on: ubuntu-latest
12 |     steps:
13 |       - uses: actions/checkout@v2
14 |       - uses: actions/setup-python@v2
15 |         with:
16 |           python-version: 3.8
17 |       - uses: pre-commit/action@v2.0.0

--------------------------------------------------------------------------------
/.github/workflows/sast-workflows.yml:
--------------------------------------------------------------------------------
1 | name: 'SAST Scans'
2 |
3 | on:
4 |   pull_request:
5 |   merge_group:
6 |
7 | jobs:
8 |   secret-scanning-review:
9 |     if: ${{ github.actor != 'dependabot[bot]' }}
10 |     runs-on: ubuntu-latest
11 |     env:
12 |       SAST_GH_APP_ID: ${{ secrets.SAST_GH_APP_ID }}
13 |       SAST_GH_APP_PRIVATE_KEY: ${{ secrets.SAST_GH_APP_PRIVATE_KEY }}
14 |     steps:
15 |       - name: Generate GitHub App Token
16 |         id: app-token
17 |         uses: actions/create-github-app-token@v2
18 |         with:
19 |           app-id: ${{ env.SAST_GH_APP_ID }}
20 |           private-key: ${{ env.SAST_GH_APP_PRIVATE_KEY }}
21 |       - name: Secret Scanning Review Action
22 |         uses: advanced-security/secret-scanning-review-action@v2.1.0
23 |         with:
24 |           token: ${{ steps.app-token.outputs.token }}
25 |           fail-on-alert: true
26 |           fail-on-alert-exclude-closed: true
27 |
28 |   dependency-review:
29 |     runs-on: ubuntu-latest
30 |     steps:
31 |       - name: 'Checkout Repository'
32 |         uses: actions/checkout@v4
33 |       - name: 'Dependency Review'
34 |         uses: actions/dependency-review-action@v4
35 |         with:
36 |           comment-summary-in-pr: true
37 |           fail-on-severity: high

--------------------------------------------------------------------------------
/.github/workflows/test-release.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 |   pull_request:
5 |   push:
6 |     branches: [main]
7 |
8 | jobs:
9 |   test:
10 |     name: ${{ matrix.python }} / ${{ matrix.os }}
11 |     runs-on: ${{ matrix.os }}
12 |     strategy:
13 |
fail-fast: true 14 | matrix: 15 | os: [ubuntu-latest, macos-latest] 16 | python: ["3.8", "3.9", "3.10"] 17 | 18 | steps: 19 | - name: Checkout the repository 20 | uses: actions/checkout@v3 21 | with: 22 | fetch-depth: 0 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v4 26 | with: 27 | python-version: ${{ matrix.python }} 28 | 29 | - name: Install Poetry 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install "poetry==1.3.1" 33 | 34 | - name: Install test dependencies with poetry 35 | env: 36 | # https://github.com/python-poetry/poetry/issues/5250#issuecomment-1067193647 37 | PYTHON_KEYRING_BACKEND: "keyring.backends.fail.Keyring" 38 | run: | 39 | poetry install --no-interaction --no-ansi --only test 40 | 41 | - name: Run tests with nox 42 | run: | 43 | poetry run nox --python ${{ matrix.python }} 44 | 45 | release: 46 | runs-on: ubuntu-latest 47 | needs: test 48 | steps: 49 | - name: Check out the repository 50 | uses: actions/checkout@v3 51 | with: 52 | fetch-depth: 0 53 | 54 | - name: Set up Python 55 | uses: actions/setup-python@v4 56 | with: 57 | python-version: 3.8 58 | 59 | - name: Setup git 60 | run: | 61 | git config user.name release-transformer-embeddings 62 | git config user.email transformer-embeddings@headspace.com 63 | 64 | - name: Install Poetry, dependencies, release 65 | if: ${{ github.ref == 'refs/heads/main' }} 66 | env: 67 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 68 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 69 | run: | 70 | python -m pip install --upgrade pip 71 | pip install poetry 72 | poetry install --no-interaction --no-ansi --only release 73 | poetry run semantic-release publish 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .mypy_cache/ 2 | /.coverage 3 | /.coverage.* 4 | /.nox/ 5 | /.python-version 6 | /.pytype/ 7 | /dist/ 8 | /docs/_build/ 9 | /src/*.egg-info/ 10 | __pycache__/ 11 | 12 | # Any models 13 | *.tar.gz 14 | model/ 15 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-added-large-files 8 | - id: check-ast 9 | - id: check-merge-conflict 10 | - id: check-toml 11 | - id: check-yaml 12 | - id: end-of-file-fixer 13 | exclude: ".min.js" 14 | - id: requirements-txt-fixer 15 | - id: trailing-whitespace 16 | - repo: https://github.com/psf/black 17 | rev: 23.1.0 18 | hooks: 19 | - id: black-jupyter 20 | - repo: https://github.com/pycqa/isort 21 | rev: 5.12.0 22 | hooks: 23 | - id: isort 24 | - repo: https://github.com/myint/docformatter/ 25 | rev: v1.5.1 26 | hooks: 27 | - id: docformatter 28 | args: 29 | [ 30 | "--in-place", 31 | "--wrap-summaries=88", 32 | "--wrap-descriptions=88", 33 | "--pre-summary-newline", 34 | ] 35 | - repo: https://github.com/pre-commit/mirrors-prettier 36 | rev: v2.7.1 37 | hooks: 38 | - id: prettier 39 | exclude: ".min.js" 40 | - repo: https://github.com/pre-commit/pygrep-hooks 41 | rev: v1.10.0 42 | hooks: 43 | - id: python-no-log-warn 44 | -------------------------------------------------------------------------------- /.prettierignore: 
-------------------------------------------------------------------------------- 1 | CODEOWNERS 2 | CHANGELOG.md 3 | PULL_REQUEST_TEMPLATE.md 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | 4 | 5 | ## v4.0.14 (2024-04-15) 6 | 7 | 8 | ## v4.0.13 (2024-02-22) 9 | 10 | 11 | ## v4.0.12 (2024-02-20) 12 | 13 | 14 | ## v4.0.11 (2024-02-07) 15 | 16 | 17 | ## v4.0.10 (2024-01-30) 18 | 19 | 20 | ## v4.0.9 (2024-01-11) 21 | 22 | 23 | ## v4.0.8 (2024-01-11) 24 | ### Fix 25 | * Updating transformers version to resolve critical dependabot security alerts ([#39](https://github.com/HeadspaceMeditation/transformer-embeddings/issues/39)) ([`4a952ac`](https://github.com/HeadspaceMeditation/transformer-embeddings/commit/4a952acf64bb91e03fd4ee73b0fc5cfa2c032e1d)) 26 | 27 | ## v4.0.7 (2024-01-11) 28 | 29 | 30 | ## v4.0.6 (2023-10-03) 31 | ### Fix 32 | * Fix ModelOutput subclass by adding dataclass decorator ([#32](https://github.com/HeadspaceMeditation/transformer-embeddings/issues/32)) ([`5cfd22c`](https://github.com/HeadspaceMeditation/transformer-embeddings/commit/5cfd22ce50196ec8a4e9436f40dc6a20960fe4d1)) 33 | 34 | ## v4.0.5 (2023-10-03) 35 | 36 | 37 | ## v4.0.4 (2023-08-24) 38 | 39 | 40 | ## v4.0.3 (2023-06-12) 41 | 42 | 43 | ## v4.0.2 (2023-06-02) 44 | 45 | 46 | ## v4.0.1 (2023-05-23) 47 | 48 | 49 | ## v4.0.0 (2023-03-28) 50 | ### Breaking 51 | * Dropping py37 support ([`9c6c81e`](https://github.com/HeadspaceMeditation/transformer-embeddings/commit/9c6c81e5258ec4a72d97da60223ddd1694ea02fe)) 52 | 53 | ### Documentation 54 | * Updating docs to communicate breaking python support changes ([#18](https://github.com/HeadspaceMeditation/transformer-embeddings/issues/18)) ([`9c6c81e`](https://github.com/HeadspaceMeditation/transformer-embeddings/commit/9c6c81e5258ec4a72d97da60223ddd1694ea02fe)) 55 | 56 | ## v3.1.1 (2023-03-27) 57 | 58 | 59 | ## v3.1.0 (2023-02-15) 60 | ### Feature 61 | * Separate S3 dependencies to an optional ([#16](https://github.com/HeadspaceMeditation/transformer-embeddings/issues/16)) ([`036cdc8`](https://github.com/HeadspaceMeditation/transformer-embeddings/commit/036cdc887ed091460921ed3edb314b71455df221)) 62 | 63 | ## v3.0.6 (2023-02-15) 64 | 65 | 66 | ## v3.0.5 (2023-02-10) 67 | 68 | 69 | ## v3.0.4 (2023-01-04) 70 | 71 | 72 | ## v3.0.3 (2023-01-03) 73 | 74 | 75 | ## v3.0.2 (2022-12-14) 76 | 77 | 78 | ## v3.0.1 (2022-12-14) 79 | 80 | 81 | ## v3.0.0 (2022-10-25) 82 | ### Fix 83 | * Update commit version in config ([#6](https://github.com/HeadspaceMeditation/transformer-embeddings/issues/6)) ([`aa08d46`](https://github.com/HeadspaceMeditation/transformer-embeddings/commit/aa08d46e7e1dd5bd65fd05d5cf88a5b9febaa5c3)) 84 | 85 | ### Breaking 86 | * Public release ([`aa08d46`](https://github.com/HeadspaceMeditation/transformer-embeddings/commit/aa08d46e7e1dd5bd65fd05d5cf88a5b9febaa5c3)) 87 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, 
socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | - Demonstrating empathy and kindness toward other people 21 | - Being respectful of differing opinions, viewpoints, and experiences 22 | - Giving and gracefully accepting constructive feedback 23 | - Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | - Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | - The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | - Trolling, insulting or derogatory comments, and personal or political attacks 33 | - Public or private harassment 34 | - Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | - Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | [transformer-embeddings@headspace.com](mailto:transformer-embeddings@headspace.com). 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series of 86 | actions. 
87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or permanent 93 | ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within the 113 | community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.1, available at 119 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][mozilla coc]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][faq]. Translations are available at 126 | [https://www.contributor-covenant.org/translations][translations]. 127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 130 | [mozilla coc]: https://github.com/mozilla/diversity 131 | [faq]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributor Guide 2 | 3 | Thank you for your interest in improving this project. 4 | This project is open-source under the [Apache 2.0 license] and 5 | welcomes contributions in the form of bug reports, feature requests, and pull requests. 6 | 7 | Here is a list of important resources for contributors: 8 | 9 | - [Source Code] 10 | - Documentation 11 | - [Issue Tracker] 12 | - [Code of Conduct] 13 | 14 | [apache 2.0 license]: https://opensource.org/licenses/Apache-2.0 15 | [source code]: https://github.com/HeadspaceMeditation/transformer-embeddings 16 | [issue tracker]: https://github.com/HeadspaceMeditation/transformer-embeddings/issues 17 | 18 | ## How to report a bug 19 | 20 | Report bugs on the [Issue Tracker]. 21 | 22 | When filing an issue, make sure to answer these questions: 23 | 24 | - Which operating system and Python version are you using? 25 | - Which version of this project are you using? 26 | - What did you do? 27 | - What did you expect to see? 28 | - What did you see instead? 
29 |
30 | The best way to get your bug fixed is to provide a test case,
31 | and/or steps to reproduce the issue.
32 |
33 | ## How to request a feature
34 |
35 | Request features on the [Issue Tracker].
36 |
37 | ## How to set up your development environment
38 |
39 | You need Python 3.8+ and [Poetry].
40 |
41 | Install the package with development and test requirements:
42 |
43 | ```console
44 | $ poetry install
45 | ```
46 |
47 | You can now run an interactive Python session:
48 |
49 |
50 | ```console
51 | $ poetry run python
52 | ```
53 |
54 | [poetry]: https://python-poetry.org/
55 | [nox]: https://nox.thea.codes/
56 | [nox-poetry]: https://nox-poetry.readthedocs.io/
57 |
58 | ## How to test the project
59 |
60 | Run the full test suite:
61 |
62 | ```console
63 | $ nox
64 | ```
65 |
66 | List the available Nox sessions:
67 |
68 | ```console
69 | $ nox --list-sessions
70 | ```
71 |
72 | You can also run a specific Nox session.
73 | For example, invoke the unit test suite like this:
74 |
75 | ```console
76 | $ nox --session=tests
77 | ```
78 |
79 | Unit tests are located in the _tests_ directory,
80 | and are written using the [pytest] testing framework.
81 |
82 | [pytest]: https://pytest.readthedocs.io/
83 |
84 | ## How to submit changes
85 |
86 | Open a [pull request] to submit changes to this project.
87 |
88 | Your pull request needs to meet the following guidelines for acceptance:
89 |
90 | - The Nox test suite must pass without errors or warnings.
91 | - Include unit tests.
92 | - If your changes add functionality, update the documentation accordingly.
93 |
94 | Feel free to submit early, though—we can always iterate on this.
95 |
96 | To run linting and code formatting checks before committing your change, you can install pre-commit as a Git hook by running the following command:
97 |
98 | ```console
99 | $ poetry run pre-commit install
100 | ```
101 |
102 | It is recommended to open an issue before starting work on anything. This will allow a chance to talk it over with the owners and validate your approach.
103 |
104 | [pull request]: https://github.com/HeadspaceMeditation/transformer-embeddings/pulls
105 |
106 |
107 |
108 | [code of conduct]: CODE_OF_CONDUCT.md

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2022 Headspace Health
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
15 | ------------------------------------------------------------------------
16 |
17 | Apache License
18 | Version 2.0, January 2004
19 | http://www.apache.org/licenses/
20 |
21 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
22 |
23 | 1. Definitions.
24 |
25 | "License" shall mean the terms and conditions for use, reproduction,
26 | and distribution as defined by Sections 1 through 9 of this document.
27 | 28 | "Licensor" shall mean the copyright owner or entity authorized by 29 | the copyright owner that is granting the License. 30 | 31 | "Legal Entity" shall mean the union of the acting entity and all 32 | other entities that control, are controlled by, or are under common 33 | control with that entity. For the purposes of this definition, 34 | "control" means (i) the power, direct or indirect, to cause the 35 | direction or management of such entity, whether by contract or 36 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 37 | outstanding shares, or (iii) beneficial ownership of such entity. 38 | 39 | "You" (or "Your") shall mean an individual or Legal Entity 40 | exercising permissions granted by this License. 41 | 42 | "Source" form shall mean the preferred form for making modifications, 43 | including but not limited to software source code, documentation 44 | source, and configuration files. 45 | 46 | "Object" form shall mean any form resulting from mechanical 47 | transformation or translation of a Source form, including but 48 | not limited to compiled object code, generated documentation, 49 | and conversions to other media types. 50 | 51 | "Work" shall mean the work of authorship, whether in Source or 52 | Object form, made available under the License, as indicated by a 53 | copyright notice that is included in or attached to the work 54 | (an example is provided in the Appendix below). 55 | 56 | "Derivative Works" shall mean any work, whether in Source or Object 57 | form, that is based on (or derived from) the Work and for which the 58 | editorial revisions, annotations, elaborations, or other modifications 59 | represent, as a whole, an original work of authorship. For the purposes 60 | of this License, Derivative Works shall not include works that remain 61 | separable from, or merely link (or bind by name) to the interfaces of, 62 | the Work and Derivative Works thereof. 63 | 64 | "Contribution" shall mean any work of authorship, including 65 | the original version of the Work and any modifications or additions 66 | to that Work or Derivative Works thereof, that is intentionally 67 | submitted to Licensor for inclusion in the Work by the copyright owner 68 | or by an individual or Legal Entity authorized to submit on behalf of 69 | the copyright owner. For the purposes of this definition, "submitted" 70 | means any form of electronic, verbal, or written communication sent 71 | to the Licensor or its representatives, including but not limited to 72 | communication on electronic mailing lists, source code control systems, 73 | and issue tracking systems that are managed by, or on behalf of, the 74 | Licensor for the purpose of discussing and improving the Work, but 75 | excluding communication that is conspicuously marked or otherwise 76 | designated in writing by the copyright owner as "Not a Contribution." 77 | 78 | "Contributor" shall mean Licensor and any individual or Legal Entity 79 | on behalf of whom a Contribution has been received by Licensor and 80 | subsequently incorporated within the Work. 81 | 82 | 2. Grant of Copyright License. Subject to the terms and conditions of 83 | this License, each Contributor hereby grants to You a perpetual, 84 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 85 | copyright license to reproduce, prepare Derivative Works of, 86 | publicly display, publicly perform, sublicense, and distribute the 87 | Work and such Derivative Works in Source or Object form. 88 | 89 | 3. Grant of Patent License. 
Subject to the terms and conditions of 90 | this License, each Contributor hereby grants to You a perpetual, 91 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 92 | (except as stated in this section) patent license to make, have made, 93 | use, offer to sell, sell, import, and otherwise transfer the Work, 94 | where such license applies only to those patent claims licensable 95 | by such Contributor that are necessarily infringed by their 96 | Contribution(s) alone or by combination of their Contribution(s) 97 | with the Work to which such Contribution(s) was submitted. If You 98 | institute patent litigation against any entity (including a 99 | cross-claim or counterclaim in a lawsuit) alleging that the Work 100 | or a Contribution incorporated within the Work constitutes direct 101 | or contributory patent infringement, then any patent licenses 102 | granted to You under this License for that Work shall terminate 103 | as of the date such litigation is filed. 104 | 105 | 4. Redistribution. You may reproduce and distribute copies of the 106 | Work or Derivative Works thereof in any medium, with or without 107 | modifications, and in Source or Object form, provided that You 108 | meet the following conditions: 109 | 110 | (a) You must give any other recipients of the Work or 111 | Derivative Works a copy of this License; and 112 | 113 | (b) You must cause any modified files to carry prominent notices 114 | stating that You changed the files; and 115 | 116 | (c) You must retain, in the Source form of any Derivative Works 117 | that You distribute, all copyright, patent, trademark, and 118 | attribution notices from the Source form of the Work, 119 | excluding those notices that do not pertain to any part of 120 | the Derivative Works; and 121 | 122 | (d) If the Work includes a "NOTICE" text file as part of its 123 | distribution, then any Derivative Works that You distribute must 124 | include a readable copy of the attribution notices contained 125 | within such NOTICE file, excluding those notices that do not 126 | pertain to any part of the Derivative Works, in at least one 127 | of the following places: within a NOTICE text file distributed 128 | as part of the Derivative Works; within the Source form or 129 | documentation, if provided along with the Derivative Works; or, 130 | within a display generated by the Derivative Works, if and 131 | wherever such third-party notices normally appear. The contents 132 | of the NOTICE file are for informational purposes only and 133 | do not modify the License. You may add Your own attribution 134 | notices within Derivative Works that You distribute, alongside 135 | or as an addendum to the NOTICE text from the Work, provided 136 | that such additional attribution notices cannot be construed 137 | as modifying the License. 138 | 139 | You may add Your own copyright statement to Your modifications and 140 | may provide additional or different license terms and conditions 141 | for use, reproduction, or distribution of Your modifications, or 142 | for any such Derivative Works as a whole, provided Your use, 143 | reproduction, and distribution of the Work otherwise complies with 144 | the conditions stated in this License. 145 | 146 | 5. Submission of Contributions. Unless You explicitly state otherwise, 147 | any Contribution intentionally submitted for inclusion in the Work 148 | by You to the Licensor shall be under the terms and conditions of 149 | this License, without any additional terms or conditions. 
150 | Notwithstanding the above, nothing herein shall supersede or modify 151 | the terms of any separate license agreement you may have executed 152 | with Licensor regarding such Contributions. 153 | 154 | 6. Trademarks. This License does not grant permission to use the trade 155 | names, trademarks, service marks, or product names of the Licensor, 156 | except as required for reasonable and customary use in describing the 157 | origin of the Work and reproducing the content of the NOTICE file. 158 | 159 | 7. Disclaimer of Warranty. Unless required by applicable law or 160 | agreed to in writing, Licensor provides the Work (and each 161 | Contributor provides its Contributions) on an "AS IS" BASIS, 162 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 163 | implied, including, without limitation, any warranties or conditions 164 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 165 | PARTICULAR PURPOSE. You are solely responsible for determining the 166 | appropriateness of using or redistributing the Work and assume any 167 | risks associated with Your exercise of permissions under this License. 168 | 169 | 8. Limitation of Liability. In no event and under no legal theory, 170 | whether in tort (including negligence), contract, or otherwise, 171 | unless required by applicable law (such as deliberate and grossly 172 | negligent acts) or agreed to in writing, shall any Contributor be 173 | liable to You for damages, including any direct, indirect, special, 174 | incidental, or consequential damages of any character arising as a 175 | result of this License or out of the use or inability to use the 176 | Work (including but not limited to damages for loss of goodwill, 177 | work stoppage, computer failure or malfunction, or any and all 178 | other commercial damages or losses), even if such Contributor 179 | has been advised of the possibility of such damages. 180 | 181 | 9. Accepting Warranty or Additional Liability. While redistributing 182 | the Work or Derivative Works thereof, You may choose to offer, 183 | and charge a fee for, acceptance of support, warranty, indemnity, 184 | or other liability obligations and/or rights consistent with this 185 | License. However, in accepting such obligations, You may act only 186 | on Your own behalf and on Your sole responsibility, not on behalf 187 | of any other Contributor, and only if You agree to indemnify, 188 | defend, and hold each Contributor harmless for any liability 189 | incurred by, or claims asserted against, such Contributor by reason 190 | of your accepting any such warranty or additional liability. 191 | 192 | END OF TERMS AND CONDITIONS 193 | 194 | APPENDIX: How to apply the Apache License to your work. 195 | 196 | To apply the Apache License to your work, attach the following 197 | boilerplate notice, with the fields enclosed by brackets "[]" 198 | replaced with your own identifying information. (Don't include 199 | the brackets!) The text should be enclosed in the appropriate 200 | comment syntax for the file format. We also recommend that a 201 | file or class name and description of purpose be included on the 202 | same "printed page" as the copyright notice for easier 203 | identification within third-party archives. 204 | 205 | Copyright [yyyy] [name of copyright owner] 206 | 207 | Licensed under the Apache License, Version 2.0 (the "License"); 208 | you may not use this file except in compliance with the License. 
209 | You may obtain a copy of the License at
210 |
211 | http://www.apache.org/licenses/LICENSE-2.0
212 |
213 | Unless required by applicable law or agreed to in writing, software
214 | distributed under the License is distributed on an "AS IS" BASIS,
215 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
216 | See the License for the specific language governing permissions and
217 | limitations under the License.
218 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer Embeddings
2 |
3 | [![PyPI](https://img.shields.io/pypi/v/transformer-embeddings.svg)][pypi_]
4 | [![Status](https://img.shields.io/pypi/status/transformer-embeddings.svg)][status]
5 | [![Python Version](https://img.shields.io/pypi/pyversions/transformer-embeddings)][python version]
6 | [![License](https://img.shields.io/pypi/l/transformer-embeddings)][license]
7 |
8 | [![Tests](https://github.com/HeadspaceMeditation/transformer-embeddings/workflows/Tests/badge.svg?branch=main)][tests]
9 |
10 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)][pre-commit]
11 | [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)][black]
12 |
13 | [pypi_]: https://pypi.org/project/transformer-embeddings/
14 | [status]: https://pypi.org/project/transformer-embeddings/
15 | [python version]: https://pypi.org/project/transformer-embeddings
16 | [read the docs]: https://transformer-embeddings.readthedocs.io/
17 | [tests]: https://github.com/HeadspaceMeditation/transformer-embeddings/actions?workflow=Tests
18 | [codecov]: https://app.codecov.io/gh/HeadspaceMeditation/transformer-embeddings
19 | [pre-commit]: https://github.com/pre-commit/pre-commit
20 | [black]: https://github.com/psf/black
21 |
22 | This library simplifies and streamlines the usage of encoder transformer models supported by [HuggingFace's `transformers` library](https://github.com/huggingface/transformers/) ([model hub](https://huggingface.co/models) or local) to generate embeddings for string inputs, similar to the way `sentence-transformers` does.
23 |
24 | Please note that starting with v4, we have dropped support for Python 3.7. If you need to use this library with Python 3.7, the latest compatible release is [`version 3.1.0`](https://pypi.org/project/transformer-embeddings/3.1.0/).
25 |
26 | ## Why use this over HuggingFace's `transformers` or `sentence-transformers`?
27 |
28 | Under the hood, we take care of:
29 |
30 | 1. Supporting any model on the HF model hub, with sensible defaults for inference.
31 | 2. Setting the PyTorch model to `eval` mode.
32 | 3. Using `no_grad()` when doing the forward pass.
33 | 4. Batching, and returning output in the format produced by HF transformers.
34 | 5. Padding / truncating to model defaults.
35 | 6. Moving to and from GPUs if available.
36 |
37 | ## Installation
38 |
39 | You can install _Transformer Embeddings_ via [pip] from [PyPI]:
40 |
41 | ```console
42 | $ pip install transformer-embeddings
43 | ```
44 |
45 | ## Usage
46 |
47 | ```python
48 | from transformer_embeddings import TransformerEmbeddings
49 |
50 | transformer = TransformerEmbeddings("model_name")
51 | ```
52 |
53 | If you have a previously instantiated `model` and / or `tokenizer`, you can pass them in.
54 |
55 | ```python
56 | transformer = TransformerEmbeddings(model=model, tokenizer=tokenizer)
57 | ```
58 | or
59 | ```python
60 | transformer = TransformerEmbeddings(model_name="model_name", model=model)
61 | ```
62 |
63 | or
64 |
65 | ```python
66 | transformer = TransformerEmbeddings(model_name="model_name", tokenizer=tokenizer)
67 | ```
68 |
69 | **Note:** `model_name` should be included if only one of the model and tokenizer is passed in.
70 |
71 | ### Embeddings
72 |
73 | To get output embeddings:
74 |
75 | ```python
76 | embeddings = transformer.encode(["Lorem ipsum dolor sit amet",
77 |                                  "consectetur adipiscing elit",
78 |                                  "sed do eiusmod tempor incididunt",
79 |                                  "ut labore et dolore magna aliqua."])
80 | embeddings.output
81 | ```
82 |
83 | ### Pooled Output
84 |
85 | To get pooled outputs:
86 |
87 | ```python
88 | from transformer_embeddings import TransformerEmbeddings, mean_pooling
89 |
90 | transformer = TransformerEmbeddings("model_name", return_output=False, pooling_fn=mean_pooling)
91 |
92 | embeddings = transformer.encode(["Lorem ipsum dolor sit amet",
93 |                                  "consectetur adipiscing elit",
94 |                                  "sed do eiusmod tempor incididunt",
95 |                                  "ut labore et dolore magna aliqua."])
96 |
97 | embeddings.pooled
98 | ```
99 |
100 | ### Exporting the Model
101 |
102 | Once you are done training and testing the model, you can export it into a single tarball:
103 |
104 | ```python
105 | from transformer_embeddings import TransformerEmbeddings
106 |
107 | transformer = TransformerEmbeddings("model_name")
108 | transformer.export(additional_files=["/path/to/other/files/to/include/in/tarball.pickle"])
109 | ```
110 |
111 | The tarball can also be uploaded to S3; this requires the S3 extras (`pip install transformer-embeddings[s3]`):
112 |
113 | ```python
114 | from transformer_embeddings import TransformerEmbeddings
115 |
116 | transformer = TransformerEmbeddings("model_name")
117 | transformer.export(
118 |     additional_files=["/path/to/other/files/to/include/in/tarball.pickle"],
119 |     s3_path="s3://bucket/models/model-name/date-version/",
120 | )
121 | ```
122 |
123 | ## Contributing
124 |
125 | Contributions are very welcome. To learn more, see the [Contributor Guide].
126 |
127 | ## License
128 |
129 | Distributed under the terms of the [Apache 2.0 license][license], _Transformer Embeddings_ is free and open source software.
130 |
131 | ## Issues
132 |
133 | If you encounter any problems, please [file an issue] along with a detailed description.
134 |
135 | ## Credits
136 |
137 | This project was partly generated from [@cjolowicz]'s [Hypermodern Python Cookiecutter] template.
138 | 139 | [@cjolowicz]: https://github.com/cjolowicz 140 | [pypi]: https://pypi.org/ 141 | [hypermodern python cookiecutter]: https://github.com/cjolowicz/cookiecutter-hypermodern-python 142 | [file an issue]: https://github.com/HeadspaceMeditation/transformer-embeddings/issues 143 | [pip]: https://pip.pypa.io/ 144 | 145 | 146 | 147 | [license]: https://github.com/HeadspaceMeditation/transformer-embeddings/blob/main/LICENSE 148 | [contributor guide]: https://github.com/HeadspaceMeditation/transformer-embeddings/blob/main/CONTRIBUTING.md 149 | [command-line reference]: https://transformer-embeddings.readthedocs.io/en/latest/usage.html 150 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | from nox_poetry import Session, session 2 | 3 | 4 | @session(python=["3.8", "3.9", "3.10"]) 5 | def tests(session: Session): 6 | """Run the tests and generate reports.""" 7 | session.run_always("poetry", "install", external=True) 8 | session.run( 9 | "pytest", 10 | "--log-cli-level=20", 11 | ) 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "transformer-embeddings" 3 | version = "4.0.14" 4 | description = "Transformer Embeddings" 5 | authors = ["Headspace Health "] 6 | license = "Apache-2.0" 7 | readme = "README.md" 8 | homepage = "https://github.com/HeadspaceMeditation/transformer-embeddings" 9 | repository = "https://github.com/HeadspaceMeditation/transformer-embeddings" 10 | documentation = "https://github.com/HeadspaceMeditation/transformer-embeddings" 11 | classifiers = [ 12 | "Development Status :: 5 - Production/Stable", 13 | ] 14 | 15 | [tool.poetry.urls] 16 | Changelog = "https://github.com/HeadspaceMeditation/transformer-embeddings/releases" 17 | 18 | [tool.poetry.dependencies] 19 | python = "^3.8,<3.11" 20 | transformers = "^4.36.0" 21 | torch = "^1.9.1" 22 | s3fs = { version = "^2023.1.0", optional = true } 23 | 24 | [tool.poetry.group.test.dependencies] 25 | pytest = "^7.1.3" 26 | pytest-repeat = "^0.9.1" 27 | nox = "^2022.8.7" 28 | nox-poetry = "^1.0.1" 29 | 30 | [tool.poetry.group.dev.dependencies] 31 | pre-commit = "^2.20.0" 32 | 33 | [tool.poetry.group.release.dependencies] 34 | python-semantic-release = "^7.32.1" 35 | 36 | [tool.poetry.extras] 37 | s3 = ["s3fs"] 38 | 39 | [tool.isort] 40 | profile = "black" 41 | lines_after_imports = 2 42 | 43 | [build-system] 44 | requires = ["poetry-core>=1.0.0"] 45 | build-backend = "poetry.core.masonry.api" 46 | 47 | [tool.semantic_release] 48 | version_source = "tag" # Resolution for https://github.com/relekang/python-semantic-release/issues/460#issuecomment-1192261285 49 | commit_version_number = true 50 | version_variable = "src/transformer_embeddings/__init__.py:__version__" 51 | version_toml = "pyproject.toml:tool.poetry.version" 52 | branch = "main" 53 | patch_without_tag = true # Create a patch release on every commit merged to `main`. 
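# Note (added): the settings below control publishing via python-semantic-release:
# upload the built package to PyPI, attach the build artifacts to the GitHub
# release, and produce the distribution with Poetry.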
54 | upload_to_pypi = true
55 | upload_to_release = true
56 | build_command = "pip install poetry && poetry build"
57 | tag_format = "v{version}"
58 | commit_subject = "chore(release): v{version}"
59 | changelog_sections = "feature,fix,breaking,documentation,performance,refactor,test"

--------------------------------------------------------------------------------
/src/transformer_embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | """Transformer Embeddings."""
2 | from transformer_embeddings.model import TransformerEmbeddings
3 | from transformer_embeddings.poolers import get_pooler_output, mean_pooling
4 |
5 |
6 | __version__ = "4.0.14"

--------------------------------------------------------------------------------
/src/transformer_embeddings/helpers.py:
--------------------------------------------------------------------------------
1 | from os.path import commonpath
2 | from pathlib import Path
3 | from tarfile import open as tarfile_open
4 | from typing import List, Optional, Union
5 |
6 |
7 | def compress_files(
8 |     filenames: List[Union[str, Path]],
9 |     compressed_file: Union[str, Path],
10 |     arcname: Optional[Union[str, Path]] = None,
11 | ) -> bool:
12 |     """
13 |     Given a list of files or directories, compress them into a tarball at the
14 |     compressed_file location.
15 |
16 |     Parameters
17 |     ----------
18 |     filenames : List[Union[str, Path]]
19 |         List of files or directories.
20 |     compressed_file : Union[str, Path]
21 |         Destination compressed file.
22 |     arcname : Optional[Union[str, Path]], optional
23 |         Base directory for the compressed file. Default: the longest common path of
24 |         all the passed-in filenames.
25 |
26 |     Returns
27 |     -------
28 |     bool
29 |         The status of the compress operation. Checks if the output file was created.
30 |
31 |     Raises
32 |     ------
33 |     NameError
34 |         Raised when the name of the compressed file doesn't end in .tar or .tar.gz.
35 |     """
36 |     # If the file to be written is provided as an str object, convert it to a Path object.
37 |     if isinstance(compressed_file, str):
38 |         compressed_file = Path(compressed_file)
39 |
40 |     # Ensure the file to be written ends with .tar or .tar.gz, and only
41 |     # gzip-compress when the name asks for it; a plain .tar stays uncompressed.
42 |     if compressed_file.name.endswith(".tar.gz") or compressed_file.name.endswith(
43 |         ".tar"
44 |     ):
45 |         mode = "w:gz" if compressed_file.name.endswith(".tar.gz") else "w"
46 |         tar_file = tarfile_open(compressed_file, mode)
47 |     else:
48 |         raise NameError(
49 |             f"Name of compressed file ({compressed_file}) does not end in tar or tar.gz."
50 |         )
51 |
52 |     base_path = Path(commonpath(filenames))
53 |
54 |     for filename in filenames:
55 |         if isinstance(filename, str):
56 |             filename = Path(filename)
57 |
58 |         # Compute the archive name per file when none was passed in; reusing a
59 |         # single variable here would apply the first file's relative path to
60 |         # every subsequent file.
61 |         if arcname is None:
62 |             member_name = filename.relative_to(base_path)
63 |             # If the common path is "", the arcname becomes a directory `.` in which
64 |             # all files are then stored. This line avoids that scenario.
65 |             member_name = "" if member_name.as_posix() == "." else member_name
66 |         else:
67 |             member_name = arcname
68 |         tar_file.add(filename, arcname=member_name)
69 |     tar_file.close()
70 |     return compressed_file.exists()
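A minimal usage sketch for `compress_files` (editorial addition; the paths below are hypothetical, for illustration only):

```python
from pathlib import Path

from transformer_embeddings.helpers import compress_files

# Compress two sibling directories into one tarball. With arcname left as None,
# entries are stored relative to their common parent, i.e. as "config/..." and
# "weights/...".
created = compress_files(
    filenames=[Path("artifacts/config"), Path("artifacts/weights")],
    compressed_file="artifacts/model.tar.gz",
)
assert created  # True once artifacts/model.tar.gz exists
```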
--------------------------------------------------------------------------------
/src/transformer_embeddings/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from logging import getLogger
3 | from pathlib import Path
4 | from shutil import copy2
5 | from tempfile import TemporaryDirectory
6 | from typing import Callable, List, Optional, Tuple, Union
7 |
8 | from torch import Tensor, cat, cuda, device
9 | from torch.autograd.grad_mode import no_grad
10 | from torch.nn.functional import pad
11 | from tqdm.auto import trange
12 | from transformers import (
13 |     AutoModel,
14 |     AutoTokenizer,
15 |     BatchEncoding,
16 |     PreTrainedModel,
17 |     PreTrainedTokenizer,
18 |     PreTrainedTokenizerFast,
19 | )
20 | from transformers.file_utils import ModelOutput
21 |
22 | from transformer_embeddings.helpers import compress_files
23 |
24 |
25 | logger = getLogger(__name__)
26 |
27 | MODEL_TARBALL = "model.tar.gz"
28 |
29 | DEVICE_CUDA = device("cuda")
30 | DEVICE_CPU = device("cpu")
31 | DEVICE = DEVICE_CUDA if cuda.is_available() else DEVICE_CPU
32 |
33 | TransformerInputOutput = Union[BatchEncoding, ModelOutput]
34 |
35 |
36 | @dataclass
37 | class TransformerEmbeddingsOutput(ModelOutput):
38 |     output: Optional[ModelOutput] = None
39 |     input: Optional[BatchEncoding] = None
40 |     pooled: Optional[Tensor] = None
41 |
42 |
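An editorial sketch of how the pieces below fit together; `prajjwal1/bert-tiny` is simply a small hub model, the same one this repository's test suite uses:

```python
from transformer_embeddings import TransformerEmbeddings, mean_pooling

# Wrap a small hub model; skip stacking raw outputs, return mean-pooled vectors.
transformer = TransformerEmbeddings(
    "prajjwal1/bert-tiny",
    return_output=False,
    pooling_fn=mean_pooling,
)
result = transformer.encode(["Lorem ipsum dolor sit amet", "consectetur adipiscing elit"])
print(result.pooled.shape)  # torch.Size([2, hidden_size]): one vector per input
```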
43 | class TransformerEmbeddings:
44 |     """Thin wrapper on top of HuggingFace's transformers library to simplify
45 |     generating embeddings."""
46 |
47 |     def __init__(
48 |         self,
49 |         model_name: Optional[Union[str, Path]] = None,
50 |         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
51 |         model: PreTrainedModel = None,
52 |         batch_size: int = 16,
53 |         tensors_to_cpu: bool = True,
54 |         return_output: bool = True,
55 |         return_input: bool = False,
56 |         pooling_fn: Callable[[ModelOutput, BatchEncoding], Tensor] = None,
57 |     ):
58 |         """
59 |         Create a TransformerEmbeddings object.
60 |
61 |         Parameters
62 |         ----------
63 |         model_name : Union[str, Path], optional
64 |             Name of the model. Anything supported by HF's `from_pretrained()` method
65 |             (HF model hub model name, local path). Default: None.
66 |         tokenizer : Union[PreTrainedTokenizer, PreTrainedTokenizerFast], optional
67 |             Tokenizer object. Default: None.
68 |         model : PreTrainedModel, optional
69 |             Model object. Default: None.
70 |         batch_size : int, optional
71 |             Batch size for the forward pass. Default: 16.
72 |         tensors_to_cpu : bool, optional
73 |             Move output tensors back to the CPU when running on a GPU. Default: True.
74 |         return_output : bool, optional
75 |             Return all outputs from the model's forward pass. Default: True.
76 |         return_input : bool, optional
77 |             Return the tokenized inputs. Default: False.
78 |         pooling_fn : Callable[[ModelOutput, BatchEncoding], Tensor], optional
79 |             Function applied to pool the model output into a single tensor. If
80 |             provided, self.return_pooled is set to True and the pooled tensor is returned.
81 |         """
82 |         self.load_model(model_name=model_name, tokenizer=tokenizer, model=model)
83 |
84 |         self.batch_size = batch_size
85 |         self.tensors_to_cpu = tensors_to_cpu
86 |         self.pooling_fn = pooling_fn
87 |
88 |         self.return_output = return_output
89 |         self.return_input = return_input
90 |         self.return_pooled = self.pooling_fn is not None
91 |
92 |         logger.info("TransformerEmbeddings model initialized.")
93 |
94 |     def load_model(
95 |         self,
96 |         model_name: Optional[str],
97 |         tokenizer: PreTrainedTokenizer,
98 |         model: PreTrainedModel,
99 |     ) -> None:
100 |         """
101 |         Load the model.
102 |
103 |         Raises
104 |         ------
105 |         ValueError
106 |             If the model or the tokenizer was neither passed in nor created from model_name.
107 |         """
108 |
109 |         if model_name and tokenizer is None:
110 |             logger.info(f"Loading tokenizer from {model_name}")
111 |             tokenizer = AutoTokenizer.from_pretrained(model_name)
112 |
113 |         if model_name and model is None:
114 |             logger.info(f"Loading model from {model_name}")
115 |             model = AutoModel.from_pretrained(model_name)
116 |
117 |         if tokenizer is None:
118 |             raise ValueError("Tokenizer was not passed or created.")
119 |         else:
120 |             self.tokenizer = tokenizer
121 |
122 |         if model is None:
123 |             raise ValueError("Model was not passed or created.")
124 |         else:
125 |             self.model = model.to(DEVICE).eval()
126 |         logger.info(f"Model and tokenizer loaded, on device {DEVICE}, set to eval().")
127 |         if (
128 |             model.config
129 |             and model.config.max_position_embeddings
130 |             and tokenizer.model_max_length
131 |             and model.config.max_position_embeddings != tokenizer.model_max_length
132 |         ):
133 |             logger.warning(
134 |                 f"Model's maximum position embeddings ({model.config.max_position_embeddings}) do not match tokenizer's maximum length ({tokenizer.model_max_length})."
135 |             )
136 |
137 |     def tokenize(self, input_strings: List[str]) -> BatchEncoding:
138 |         """
139 |         Tokenize the input.
140 |
141 |         Parameters
142 |         ----------
143 |         input_strings : List[str]
144 |             Input strings for the batch.
145 |
146 |         Returns
147 |         -------
148 |         BatchEncoding
149 |             Tokenizer output.
150 |         """
151 |         return self.tokenizer(
152 |             input_strings,
153 |             return_tensors="pt",
154 |             padding=True,
155 |             truncation=True,
156 |         )
157 |
158 |     def _pad_tensors(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
159 |         """Pad tensors to equal length."""
160 |         max_length = max(x.shape[1], y.shape[1])
161 |         # The padding passed to torch.nn.functional.pad is ordered backwards from
162 |         # the last axis and is specified twice for each dimension (front and back).
163 |         # input tensors are 2D, output tensors are 3D.
164 |         # input and output tensors both have to be padded at dimension 1.
165 |         # padding for input tensors becomes: (0, max_length - x.shape[1])
166 |         # padding for output tensors becomes: (0, 0, 0, max_length - x.shape[1])
167 |         # See: https://pytorch.org/docs/master/generated/torch.nn.functional.pad.html
168 |         # We don't pad along the dimension we stack (dim=0; batch dimension), thus
169 |         # the -2. For the final dimension, we only pad at the back, hence the -1.
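        # Worked example of the arithmetic above (added note): a 2D input tensor
        # of shape (batch, 7) padded to max_length 10 gives
        # prefix_length = 2*2 - 2 - 1 = 1, pad_prefix = (0,), pad = (0, 3);
        # a 3D output tensor of shape (batch, 7, hidden) gives
        # prefix_length = 2*3 - 2 - 1 = 3, pad_prefix = (0, 0, 0),
        # pad = (0, 0, 0, 3). Both pad dimension 1 at the back by 3 positions.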
170 | prefix_length = 2 * len(x.shape) - 2 - 1 171 | pad_prefix = (0,) * prefix_length 172 | x = pad(x, pad=pad_prefix + (max_length - x.shape[1],)) 173 | y = pad(y, pad=pad_prefix + (max_length - y.shape[1],)) 174 | return x, y 175 | 176 | def _change_device(self, data: TransformerInputOutput) -> TransformerInputOutput: 177 | """Move tensors to CPU if specified.""" 178 | if self.tensors_to_cpu: 179 | for key, value in data.items(): 180 | data[key] = value.to(DEVICE_CPU) 181 | return data 182 | 183 | def _stack_batch( 184 | self, 185 | data: Optional[TransformerInputOutput], 186 | batch_data: TransformerInputOutput, 187 | ) -> TransformerInputOutput: 188 | """Stack batch data with data from previous batches.""" 189 | # Stack. 190 | if data is None: 191 | data = batch_data 192 | else: 193 | for key, value in batch_data.items(): 194 | previous = data[key] 195 | if previous.shape != value.shape: 196 | previous, value = self._pad_tensors(previous, value) 197 | data[key] = cat((previous, value), dim=0) 198 | return data 199 | 200 | def _stack_pooled(self, pooled: Optional[Tensor], batch_pooled: Tensor) -> Tensor: 201 | # Pooling always occurs on tensors on the same device, so we do not need to 202 | # move them. The pooled tensor would already be on CPU if tensors_to_cpu is 203 | # True, else they'd be on GPU. Either way, they'd be on the device we want it on. 204 | return batch_pooled if pooled is None else cat((pooled, batch_pooled), dim=0) 205 | 206 | def encode(self, input_strings: List[str]) -> TransformerEmbeddingsOutput: 207 | """ 208 | Generate embeddings for the given input. 209 | 210 | Parameters 211 | ---------- 212 | input_strings : List[str] 213 | String input for which embeddings should be generated. 214 | 215 | Returns 216 | ------- 217 | TransformerEmbeddingsOutput 218 | Model output, inputs and / or pooled output. 219 | """ 220 | logger.info(f"Generating embeddings for {len(input_strings)} input strings.") 221 | output = None 222 | input = None 223 | pooled = None 224 | for i in trange(0, len(input_strings), self.batch_size): 225 | batch_tokenized_input = self.tokenize( 226 | input_strings[i : i + self.batch_size] 227 | ) 228 | 229 | with no_grad(): 230 | batch_outputs = self.model(**batch_tokenized_input.to(DEVICE)) 231 | 232 | # Move all tensors to CPU. 233 | batch_tokenized_input = self._change_device(batch_tokenized_input) 234 | batch_outputs = self._change_device(batch_outputs) 235 | 236 | # Stack tensors from batch to output for all inputs. 237 | if self.return_output: 238 | output = self._stack_batch(output, batch_outputs) 239 | if self.return_input: 240 | input = self._stack_batch(input, batch_tokenized_input) 241 | if self.return_pooled: 242 | batch_pooled = self.pooling_fn(batch_outputs, batch_tokenized_input) 243 | pooled = self._stack_pooled(pooled, batch_pooled) 244 | return TransformerEmbeddingsOutput(output=output, input=input, pooled=pooled) 245 | 246 | def export( 247 | self, 248 | output_dir: Optional[Union[str, Path]] = None, 249 | additional_files: Optional[List[Union[str, Path]]] = None, 250 | s3_path: Optional[str] = None, 251 | ) -> Path: 252 | """ 253 | Export the model and tokenizer to a directory and compress it into a tarball. If 254 | an S3 path is provided, also upload the tarball to S3. 255 | 256 | Parameters 257 | ---------- 258 | output_dir : Optional[Union[str, Path]], optional 259 | Output directory. Default: A temporary directory is created that is cleaned up when the function returns. 
260 |         additional_files : Optional[List[Union[str, Path]]], optional
261 |             Additional files to include in the exported tarball. Default: None.
262 |         s3_path : Optional[str], optional
263 |             S3 path at which to upload the file. Default: None, which means that the file is not uploaded to S3.
264 |
265 |         Returns
266 |         -------
267 |         Path
268 |             Path to the tarball. If output_dir was a temporary directory, the tarball is deleted along with it when the function returns.
269 |         """
270 |         # Set directory.
271 |         temporary_directory = TemporaryDirectory()
272 |         if output_dir is None:
273 |             output_dir = temporary_directory.name
274 |         output_dir = Path(output_dir) if isinstance(output_dir, str) else output_dir
275 |         # mkdir with exist_ok=True is a no-op when the directory already exists.
276 |         output_dir.mkdir(parents=True, exist_ok=True)
277 |
278 |         # Copy additional files into output directory before compressing.
279 |         for file in additional_files or []:
280 |             copy2(file, output_dir)
281 |
282 |         # Save model, tokenizer.
283 |         self.model.save_pretrained(output_dir)
284 |         self.tokenizer.save_pretrained(output_dir)
285 |
286 |         # Create a tarball.
287 |         logger.debug(f"Folder being added to the tarball is {output_dir}.")
288 |         compressed_file = output_dir.joinpath(MODEL_TARBALL)
289 |         compress_files([output_dir], compressed_file)
290 |         logger.info(f"Tarball {compressed_file} created.")
291 |
292 |         if s3_path:
293 |             try:
294 |                 from s3fs import S3FileSystem
295 |             except ImportError:
296 |                 raise ImportError(
297 |                     "Please install the s3 extras of the package to upload to S3."
298 |                 )
299 |
300 |             s3_fs = S3FileSystem()
301 |             logger.info(f"Tarball {compressed_file} being uploaded to S3 at {s3_path}.")
302 |             # Use a context manager so the buffered upload is flushed and closed.
303 |             with s3_fs.open(s3_path, "wb") as s3_file:
304 |                 s3_file.write(compressed_file.read_bytes())
305 |
306 |         return compressed_file

--------------------------------------------------------------------------------
/src/transformer_embeddings/poolers.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor, clamp
2 | from torch import sum as torch_sum
3 | from transformers import BatchEncoding
4 | from transformers.file_utils import ModelOutput
5 |
6 |
7 | def mean_pooling(model_output: ModelOutput, model_inputs: BatchEncoding) -> Tensor:
8 |     """
9 |     Mean pooling for the model output.
10 |
11 |     This is the unweighted average of the token embeddings, while ignoring
12 |     padded positions.
13 |     Copied from: https://huggingface.co/sentence-transformers/msmarco-distilroberta-base-v2
14 |     Comments ours :).
15 |
16 |     Parameters
17 |     ----------
18 |     model_output : ModelOutput
19 |         Output from the model.
20 |     model_inputs : BatchEncoding
21 |         Encoded, tokenized input to the model.
22 |
23 |     Returns
24 |     -------
25 |     Tensor
26 |         Mean pooled output.
27 |     """
28 |     # last_hidden_state is the per token embedding.
29 |     last_hidden_state = model_output.last_hidden_state
30 |     attention_mask = model_inputs.attention_mask
31 |     # Expand attention_mask (2D) to the shape of the last_hidden_state (3D).
32 |     attention_mask_expanded = (
33 |         attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
34 |     )
35 |     # Multiply by the expanded attention mask to zero out the
36 |     # last_hidden_state at padding positions.
37 |     sum_embeddings = torch_sum(last_hidden_state * attention_mask_expanded, dim=1)
38 |     sum_mask = clamp(attention_mask_expanded.sum(1), min=1e-9)
39 |     return sum_embeddings / sum_mask
40 |
41 |
42 | def get_pooler_output(model_output: ModelOutput, model_inputs: BatchEncoding) -> Tensor:
43 |     """
44 |     Return the pooler output.
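    Note (added): only models whose architecture includes a pooling layer
    (BERT-style models, for example) populate `pooler_output`; models without
    one do not have this attribute, and this function will fail for them.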
45 | 46 | Parameters 47 | ---------- 48 | model_output : ModelOutput 49 | Output from the model. 50 | model_inputs : BatchEncoding 51 | Encoded, tokenized input to the model. Not used in this function. 52 | 53 | Returns 54 | ------- 55 | Tensor 56 | Pooler output. 57 | """ 58 | return model_output.pooler_output 59 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Test suite for the transformer_embeddings package.""" 2 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tarfile import is_tarfile 3 | from tarfile import open as tarfile_open 4 | from typing import Callable 5 | 6 | from pytest import fixture, raises 7 | from torch import Tensor 8 | from transformers import ( 9 | AutoModel, 10 | AutoTokenizer, 11 | BatchEncoding, 12 | PreTrainedModel, 13 | PreTrainedTokenizer, 14 | PreTrainedTokenizerFast, 15 | ) 16 | from transformers.file_utils import ModelOutput 17 | 18 | from transformer_embeddings import TransformerEmbeddings 19 | from transformer_embeddings.model import ( 20 | TransformerEmbeddingsOutput, 21 | TransformerInputOutput, 22 | ) 23 | from transformer_embeddings.poolers import mean_pooling 24 | 25 | 26 | MESSAGES = [ 27 | "Lorem ipsum dolor sit amet", 28 | "consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", 29 | ] 30 | 31 | 32 | @fixture( 33 | params=[ 34 | "prajjwal1/bert-tiny", 35 | "sshleifer/tiny-distilroberta-base", 36 | # "patrickvonplaten/longformer-random-tiny": We cannot test this since this model doesn't have a tokenizer. 37 | # https://huggingface.co/patrickvonplaten/longformer-random-tiny/ 38 | ] 39 | ) 40 | def model_name(request) -> str: 41 | return request.param 42 | 43 | 44 | @fixture(params=[True, False]) 45 | def return_input(request) -> bool: 46 | return request.param 47 | 48 | 49 | @fixture(params=[True, False]) 50 | def return_output(request) -> bool: 51 | return request.param 52 | 53 | 54 | @fixture(params=[len(MESSAGES), 1]) 55 | def batch_size(request) -> int: 56 | return request.param 57 | 58 | 59 | # 2nd pooling_fn simply returns the pooler_output from the model. 60 | @fixture(params=[mean_pooling, lambda x, y: x.pooler_output]) 61 | def pooling_fn(request) -> Callable: 62 | return request.param 63 | 64 | 65 | def test_transformer_embeddings_model_name(model_name): 66 | # Test transformer object init from model_name. 67 | transformer = TransformerEmbeddings(model_name=model_name) 68 | 69 | assert transformer.model is not None 70 | assert transformer.tokenizer is not None 71 | 72 | assert isinstance(transformer.model, PreTrainedModel) 73 | assert isinstance( 74 | transformer.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) 75 | ) 76 | 77 | 78 | def test_transformer_embeddings_model_tokenizer(model_name): 79 | # Test transformer object init from model and tokenizer. 
80 |     model = AutoModel.from_pretrained(model_name)
81 |     tokenizer = AutoTokenizer.from_pretrained(model_name)
82 |     transformer = TransformerEmbeddings(model=model, tokenizer=tokenizer)
83 |
84 |     assert transformer.model is not None
85 |     assert transformer.tokenizer is not None
86 |
87 |     assert isinstance(transformer.model, PreTrainedModel)
88 |     assert isinstance(
89 |         transformer.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
90 |     )
91 |
92 |
93 | def test_transformer_embeddings_model(model_name):
94 |     # Should raise ValueError if we pass in a model but no model name or tokenizer.
95 |     model = AutoModel.from_pretrained(model_name)
96 |     with raises(ValueError):
97 |         TransformerEmbeddings(model=model)
98 |
99 |
100 | def test_transformer_embeddings_tokenizer(model_name):
101 |     # Should raise ValueError if we pass in a tokenizer but no model name or model.
102 |     tokenizer = AutoTokenizer.from_pretrained(model_name)
103 |     with raises(ValueError):
104 |         TransformerEmbeddings(tokenizer=tokenizer)
105 |
106 |
107 | def test_transformer_embeddings_tokenize(model_name):
108 |     # Test tokenization.
109 |     transformer = TransformerEmbeddings(model_name=model_name)
110 |     tokenized_input = transformer.tokenize(MESSAGES)
111 |
112 |     assert isinstance(tokenized_input, BatchEncoding)
113 |
114 |     for key, value in tokenized_input.items():
115 |         assert isinstance(key, str)
116 |         assert isinstance(value, Tensor)
117 |
118 |
119 | def assert_transformer_input_output(transformer_input_output: TransformerInputOutput):
120 |     for key, value in transformer_input_output.items():
121 |         assert isinstance(key, str)
122 |         assert isinstance(value, Tensor)
123 |         # Forward pass with no_grad() doesn't set requires_grad.
124 |         assert not value.requires_grad
125 |         assert value.size()[0] == len(MESSAGES)
126 |
127 |
128 | def test_transformer_embeddings_encode(
129 |     model_name, batch_size, return_input, return_output, pooling_fn
130 | ):
131 |     # Test embedding generation.
132 |     transformer = TransformerEmbeddings(
133 |         model_name=model_name,
134 |         batch_size=batch_size,
135 |         return_input=return_input,
136 |         return_output=return_output,
137 |         pooling_fn=pooling_fn,
138 |     )
139 |     embeddings_output = transformer.encode(MESSAGES)
140 |
141 |     assert isinstance(embeddings_output, TransformerEmbeddingsOutput)
142 |
143 |     output, input, pooled = (
144 |         embeddings_output.output,
145 |         embeddings_output.input,
146 |         embeddings_output.pooled,
147 |     )
148 |
149 |     assert (output is not None) == transformer.return_output
150 |     assert (input is not None) == transformer.return_input
151 |     assert (pooled is not None) == transformer.return_pooled
152 |
153 |     if output is not None:
154 |         assert isinstance(output, ModelOutput)
155 |         assert_transformer_input_output(output)
156 |
157 |     if input is not None:
158 |         assert isinstance(input, BatchEncoding)
159 |         assert_transformer_input_output(input)
160 |
161 |     if pooled is not None:
162 |         assert isinstance(pooled, Tensor)
163 |         assert pooled.size()[0] == len(MESSAGES)
164 |
165 |
166 | def test_transformer_embeddings_export(model_name, tmp_path):
167 |     transformer = TransformerEmbeddings(model_name=model_name)
168 |     compressed_file = transformer.export(output_dir=tmp_path)
169 |
170 |     assert isinstance(compressed_file, Path)
171 |     assert compressed_file.exists()
172 |     assert is_tarfile(compressed_file)
173 |
174 |
175 | def test_transformer_embeddings_export_additional_files(model_name, tmp_path):
176 |     # Additional file.
177 |     single_file = tmp_path.joinpath("latest")
178 |     single_file.write_text("this is a random text file.")
179 |
180 |     transformer = TransformerEmbeddings(model_name=model_name)
181 |     compressed_file = transformer.export(
182 |         output_dir=tmp_path.joinpath("compressed"),
183 |         additional_files=[single_file.as_posix()],
184 |     )
185 |
186 |     assert isinstance(compressed_file, Path)
187 |     assert compressed_file.exists()
188 |     assert is_tarfile(compressed_file)
189 |
190 |     with tarfile_open(compressed_file) as tar_file:
191 |         assert single_file.name in tar_file.getnames()
192 |
--------------------------------------------------------------------------------
/tests/test_poolers.py:
--------------------------------------------------------------------------------
1 | from random import randint
2 |
3 | from pytest import mark
4 | from torch import Tensor, equal, mean, ones, rand
5 | from transformers import BatchEncoding
6 | from transformers.file_utils import ModelOutput
7 |
8 | from transformer_embeddings.poolers import get_pooler_output, mean_pooling
9 |
10 |
11 | @mark.repeat(10)
12 | def test_get_pooler_output():
13 |     # Use a random vector with a random batch size but 768 dimensions.
14 |     batch_size = randint(1, 100)
15 |     model_output = ModelOutput(pooler_output=rand(batch_size, 768))
16 |     result = get_pooler_output(model_output=model_output, model_inputs=None)
17 |     assert isinstance(result, Tensor)
18 |     assert equal(result, model_output.pooler_output)
19 |
20 |
21 | @mark.repeat(10)
22 | def test_mean_pooling():
23 |     tokens = randint(1, 100)
24 |     # Single batch.
25 |     # `BatchEncoding` objects are created with dicts as the first param.
26 |     model_input = BatchEncoding({"attention_mask": ones(1, tokens)})
27 |     model_output = ModelOutput(last_hidden_state=rand(1, tokens, 768))
28 |
29 |     mean_pooled = mean_pooling(model_output, model_input)
30 |     assert isinstance(mean_pooled, Tensor)
31 |     # We take mean on the sequence dimension (1).
32 |     assert equal(mean_pooled, mean(model_output.last_hidden_state, dim=1))
33 |
--------------------------------------------------------------------------------
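
Taken together, model.py, poolers.py, and the tests above imply the following end-to-end usage. This is an illustrative sketch, not a file from the repository: the checkpoint name and keyword arguments are borrowed from the test fixtures, and the output directory and S3 path are hypothetical placeholders.

from transformer_embeddings import TransformerEmbeddings
from transformer_embeddings.poolers import mean_pooling

# Tiny checkpoint used by the test suite; any Hugging Face model should work.
transformer = TransformerEmbeddings(
    model_name="prajjwal1/bert-tiny",
    batch_size=2,
    return_input=False,
    return_output=False,
    pooling_fn=mean_pooling,
)

# encode() tokenizes in batches, runs the model under no_grad(), pools each
# batch, and stacks the results across batches.
output = transformer.encode(
    ["Lorem ipsum dolor sit amet", "consectetur adipiscing elit"]
)
pooled = output.pooled  # Tensor of shape (num_inputs, hidden_size).

# export() saves the model and tokenizer, copies in any additional files,
# creates a tarball, and optionally uploads it to S3 (requires the s3 extras).
# The bucket below is a hypothetical placeholder.
tarball = transformer.export(
    output_dir="model",
    s3_path="s3://example-bucket/model.tar.gz",
)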