├── .github ├── banner.jpg └── workflows │ ├── classification.yml │ ├── sentence_encoding.yml │ ├── text-generation.yml │ ├── token_classification.yml │ └── zero-shot-classification.yml ├── .gitignore ├── LICENSE ├── README.md ├── classification ├── .gitignore ├── Dockerfile ├── README.MD ├── main.py ├── requirements.txt ├── src │ ├── __init__.py │ ├── classifier.py │ ├── config.py │ └── utils.py └── tests │ ├── __init__.py │ ├── conftest.py │ └── test_classifier.py ├── sentence-encoding ├── .gitignore ├── Dockerfile ├── README.MD ├── main.py ├── requirements.txt ├── src │ ├── __init__.py │ ├── config.py │ ├── sentence_encoder.py │ └── utils.py └── tests │ ├── __init__.py │ ├── conftest.py │ └── test_sentence_encoder.py ├── text-generation ├── .gitignore ├── Dockerfile ├── README.MD ├── main.py ├── requirements.txt ├── src │ ├── __init__.py │ ├── config.py │ ├── text_generator.py │ └── utils.py └── tests │ ├── __init__.py │ ├── conftest.py │ └── test_text_generator.py ├── token-classification ├── .gitignore ├── Dockerfile ├── README.MD ├── main.py ├── requirements.txt ├── src │ ├── __init__.py │ ├── config.py │ ├── token_classifier.py │ └── utils.py └── tests │ ├── __init__.py │ ├── conftest.py │ └── test_token_classifier.py └── zero-shot-classification ├── .gitignore ├── Dockerfile ├── README.MD ├── main.py ├── requirements.txt ├── src ├── __init__.py ├── classifier.py ├── config.py └── utils.py └── tests ├── __init__.py ├── conftest.py └── test_classifier.py /.github/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/.github/banner.jpg -------------------------------------------------------------------------------- /.github/workflows/classification.yml: -------------------------------------------------------------------------------- 1 | name: classification 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master # this can be main 7 | paths: 8 | - "classification/**" 9 | 10 | jobs: 11 | classification: 12 | runs-on: ubuntu-latest 13 | defaults: 14 | run: 15 | working-directory: ./classification 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v2 19 | with: 20 | ref: ${{ github.ref }} 21 | - name: Build container 22 | run: | 23 | docker build --tag classification:latest . 
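      # Note: the `docker build` above also runs the pytest suite inside the image
      # (see classification/Dockerfile), so a failing test fails this job before
      # anything is pushed. The remaining steps authenticate to AWS, push the image
      # to ECR, and point the Lambda function at the new image tag.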
24 | - name: Configure AWS Credentials 25 | uses: aws-actions/configure-aws-credentials@v1 26 | with: 27 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 28 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 29 | aws-region: us-east-1 30 | - name: Push2ECR 31 | id: ecr 32 | uses: jwalton/gh-ecr-push@v1 33 | with: 34 | access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 35 | secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 36 | region: us-east-1 37 | image: classification:latest 38 | - name: Update lambda with image 39 | run: aws lambda update-function-code --function-name classification --image-uri 968911158010.dkr.ecr.us-east-1.amazonaws.com/classification:latest 40 | -------------------------------------------------------------------------------- /.github/workflows/sentence_encoding.yml: -------------------------------------------------------------------------------- 1 | name: sentence-encoding 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master # this can be main 7 | paths: 8 | - "sentence-encoding/**" 9 | 10 | jobs: 11 | sentence-encoding: 12 | runs-on: ubuntu-latest 13 | defaults: 14 | run: 15 | working-directory: ./sentence-encoding 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v2 19 | with: 20 | ref: ${{ github.ref }} 21 | - name: Build container 22 | run: | 23 | docker build --tag sentence-encoding:latest . 24 | - name: Configure AWS Credentials 25 | uses: aws-actions/configure-aws-credentials@v1 26 | with: 27 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 28 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 29 | aws-region: us-east-1 30 | - name: Push2ECR 31 | id: ecr 32 | uses: jwalton/gh-ecr-push@v1 33 | with: 34 | access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 35 | secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 36 | region: us-east-1 37 | image: sentence-encoding:latest 38 | - name: Update lambda with image 39 | run: aws lambda update-function-code --function-name sentence-encoding --image-uri 968911158010.dkr.ecr.us-east-1.amazonaws.com/sentence-encoding:latest 40 | -------------------------------------------------------------------------------- /.github/workflows/text-generation.yml: -------------------------------------------------------------------------------- 1 | name: text-generation 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master # this can be main 7 | paths: 8 | - "text-generation/**" 9 | 10 | jobs: 11 | text-generation: 12 | runs-on: ubuntu-latest 13 | defaults: 14 | run: 15 | working-directory: ./text-generation 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v2 19 | with: 20 | ref: ${{ github.ref }} 21 | - name: Build container 22 | run: | 23 | docker build --tag text-generation:latest . 
24 | - name: Configure AWS Credentials 25 | uses: aws-actions/configure-aws-credentials@v1 26 | with: 27 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 28 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 29 | aws-region: us-east-1 30 | - name: Push2ECR 31 | id: ecr 32 | uses: jwalton/gh-ecr-push@v1 33 | with: 34 | access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 35 | secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 36 | region: us-east-1 37 | image: text-generation:latest 38 | - name: Update lambda with image 39 | run: aws lambda update-function-code --function-name text-generation --image-uri 968911158010.dkr.ecr.us-east-1.amazonaws.com/text-generation:latest 40 | -------------------------------------------------------------------------------- /.github/workflows/token_classification.yml: -------------------------------------------------------------------------------- 1 | name: token-classification 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master # this can be main 7 | paths: 8 | - "token-classification/**" 9 | 10 | jobs: 11 | token-classification: 12 | runs-on: ubuntu-latest 13 | defaults: 14 | run: 15 | working-directory: ./token-classification 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v2 19 | with: 20 | ref: ${{ github.ref }} 21 | - name: Build container 22 | run: | 23 | docker build --tag token-classification:latest . 24 | - name: Configure AWS Credentials 25 | uses: aws-actions/configure-aws-credentials@v1 26 | with: 27 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 28 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 29 | aws-region: us-east-1 30 | - name: Push2ECR 31 | id: ecr 32 | uses: jwalton/gh-ecr-push@v1 33 | with: 34 | access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 35 | secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 36 | region: us-east-1 37 | image: token-classification:latest 38 | - name: Update lambda with image 39 | run: aws lambda update-function-code --function-name token-classification --image-uri 968911158010.dkr.ecr.us-east-1.amazonaws.com/token-classification:latest 40 | -------------------------------------------------------------------------------- /.github/workflows/zero-shot-classification.yml: -------------------------------------------------------------------------------- 1 | name: zero-shot-classification 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master # this can be main 7 | paths: 8 | - "zero-shot-classification/**" 9 | 10 | jobs: 11 | zero-shot-classification: 12 | runs-on: ubuntu-latest 13 | defaults: 14 | run: 15 | working-directory: ./zero-shot-classification 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v2 19 | with: 20 | ref: ${{ github.ref }} 21 | - name: Build container 22 | run: | 23 | docker build --tag zero-shot-classification:latest . 
24 | - name: Configure AWS Credentials 25 | uses: aws-actions/configure-aws-credentials@v1 26 | with: 27 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 28 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 29 | aws-region: us-east-1 30 | - name: Push2ECR 31 | id: ecr 32 | uses: jwalton/gh-ecr-push@v1 33 | with: 34 | access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 35 | secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 36 | region: us-east-1 37 | image: zero-shot-classification:latest 38 | - name: Update lambda with image 39 | run: aws lambda update-function-code --function-name zero-shot-classification --image-uri 968911158010.dkr.ecr.us-east-1.amazonaws.com/zero-shot-classification:latest 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # End2End Serverless Transformers On AWS Lambda for NLP 🚀 2 | 3 |
4 | ![You need no servers](.github/banner.jpg)
5 |
6 |
7 | 8 | ## Deploy transformers with ease 💆‍♂️ 9 | 10 | Go through this [video](https://www.youtube.com/watch?v=EoazSUJyGbs) and [slide deck](https://bit.ly/serverless-transformers) for full info. 11 | 12 | Current available pipelines 13 | 1. classification 14 | 2. sentence encoding 15 | 3. translation **(coming soon)** 16 | 4. token classification 17 | 5. text generation 18 | 6. zero shot classification 19 | 20 | ## What you get with this? 21 | - ability to run transformers without servers 22 | - complete CI/CD 23 | - concurrency upto 1000 (default AWS limit) 24 | ## How to use this? 25 | - clone the repo 26 | - keep the pipeline folder you want to use 27 | - modify the source and tests 28 | - keep the corresponding github action in `.github/workflows` 29 | - modify directory, registry and lambda function name in workflow 30 | - create repository in AWS ECR 31 | - update ECR path in the workflow 32 | - set up secrets in repo (needed for access to AWS; this creds should have access to ECR and Lambda) 33 | - AWS_ACCESS_KEY_ID 34 | - AWS_SECRET_ACCESS_KEY 35 | - push the code 36 | - create PR 37 | - this will build the container 38 | - run all the tests 39 | - push container to ECR registry 40 | - update lambda with the new container (this will not happen when you push the first time) 41 | - create lambda function if it does not exist 42 | - give appropriate IAM role 43 | - set timeout and RAM 44 | - create API in API gateway and link to lambda 45 | 46 | Done! Now you can call the lambda using the API 47 | -------------------------------------------------------------------------------- /classification/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .aws-sam 132 | *.pyc 133 | .vscode 134 | .DS_store 135 | **.bin 136 | **.ipynb_checkpoints -------------------------------------------------------------------------------- /classification/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM amazon/aws-lambda-python 2 | 3 | ARG MODEL_DIR=./models 4 | 5 | ENV TRANSFORMERS_CACHE=$MODEL_DIR 6 | ENV TRANSFORMERS_VERBOSITY=error 7 | 8 | RUN yum -y install gcc-c++ 9 | 10 | COPY requirements.txt requirements.txt 11 | RUN pip install torch==1.8+cpu -f https://download.pytorch.org/whl/torch_stable.html --no-cache-dir 12 | RUN pip install -r requirements.txt --no-cache-dir 13 | 14 | COPY ./ ./ 15 | 16 | # Run test cases and this saves the transformer model in the container 17 | RUN pip install pytest --no-cache-dir && pytest tests -s -vv 18 | 19 | RUN chmod -R 0777 $MODEL_DIR 20 | 21 | CMD [ "main.lambda_handler"] -------------------------------------------------------------------------------- /classification/README.MD: -------------------------------------------------------------------------------- 1 | ## Classification service 2 | Classification using Transformers on AWS Lambda. Check root readme for complete setup info. 
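A quick way to exercise the deployed endpoint from Python is sketched below. The URL is a hypothetical API Gateway stage for the `classification` Lambda (replace it with your own), and the sketch assumes the integration forwards the JSON body unchanged as the Lambda event. The body follows the request format described in the next section.

```
import requests  # third-party HTTP client

# Hypothetical API Gateway endpoint for the `classification` Lambda; replace with yours.
API_URL = "https://<api-id>.execute-api.us-east-1.amazonaws.com/prod/classification"

payload = {
    "texts": ["food was great", "food was bad", "i am going out for food"],
    # "model_name" and "tokenizer_name" are optional; defaults live in src/config.py
}

response = requests.post(API_URL, json=payload, timeout=60)
print(response.json())  # -> {"predictions": [{"label": "...", "score": ...}, ...]}
```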
3 | 4 | ## Request format 5 | ``` 6 | { 7 | "texts": ["food was great", "food was bad", "i am going out for food"], 8 | "model_name": "cardiffnlp/twitter-roberta-base-sentiment", # optional 9 | "tokenizer_name": "roberta-base" # optional 10 | } 11 | ``` 12 | 13 | ## Response format 14 | ``` 15 | { 16 | 'predictions': [{ 17 | 'label': 'POSITIVE', 18 | 'score': 0.97 19 | }, { 20 | 'label': 'NEGATIVE', 21 | 'score': 0.95 22 | }, { 23 | 'label': 'NEUTRAL', 24 | 'score': 0.69 25 | }] 26 | } 27 | ``` -------------------------------------------------------------------------------- /classification/main.py: -------------------------------------------------------------------------------- 1 | from sklearn import pipeline 2 | from src.classifier import Classifier 3 | 4 | pipeline = Classifier() 5 | 6 | 7 | def lambda_handler(event, context): 8 | try: 9 | return pipeline(event) 10 | except Exception as e: 11 | raise 12 | -------------------------------------------------------------------------------- /classification/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.* 2 | scikit-learn==0.24.* -------------------------------------------------------------------------------- /classification/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/classification/src/__init__.py -------------------------------------------------------------------------------- /classification/src/classifier.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from functools import lru_cache 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | from transformers import (AutoConfig, AutoModelForSequenceClassification, 7 | AutoTokenizer, pipeline) 8 | 9 | from src import config, utils 10 | 11 | logger = utils.create_logger(project_name=config.PREDICTION_TYPE, level="INFO") 12 | 13 | class Classifier: 14 | def __init__(self): 15 | _ = self.get_sentiment_pipeline(model_name=config.DEFAULT_MODEL_NAME, tokenizer_name=config.DEFAULT_TOKENIZER_NAME) #warm up 16 | 17 | @staticmethod 18 | @lru_cache(maxsize=config.CACHE_MAXSIZE) 19 | def get_sentiment_pipeline(model_name: str, tokenizer_name: str) -> pipeline: 20 | """Sentiment pipeline for the given model and tokenizer 21 | 22 | Args: 23 | model_name (str): Indicating the name of the model 24 | tokenizer_name (str): Indicating the name of the tokenizer 25 | 26 | Returns: 27 | pipeline: sentiment pipeline 28 | """ 29 | logger.info(f"Loading model: {model_name}") 30 | id2label = config.ID_SENTIMENT_MAPPING[model_name] 31 | label2id = {label: idx for idx, label in id2label.items()} 32 | 33 | model_config = AutoConfig.from_pretrained(model_name) 34 | model_config.label2id = label2id 35 | model_config.id2label = id2label 36 | model = AutoModelForSequenceClassification.from_pretrained( 37 | model_name, config=model_config 38 | ) 39 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 40 | classification_pipeline = pipeline( 41 | "sentiment-analysis", model=model, tokenizer=tokenizer 42 | ) 43 | return classification_pipeline 44 | 45 | def get_clean_text(self, text: str) -> str: 46 | """Clean the text 47 | 48 | Args: 49 | text (str): text 50 | 51 | Returns: 52 | str: clean text 53 | """ 54 | return text.strip().lower() 55 | 56 | def __call__(self, request: dict)-> dict: 57 | """Predict the sentiment of 
the given texts 58 | 59 | Args: 60 | request (dict): request containing the list of text to predict the sentiment 61 | 62 | Returns: 63 | dict: classes of the given text 64 | """ 65 | texts = [self.get_clean_text(text) for text in request["texts"]] 66 | 67 | model_name = request.get("model_name", config.DEFAULT_MODEL_NAME) 68 | tokenizer_name = request.get("tokenizer_name", config.DEFAULT_TOKENIZER_NAME) 69 | 70 | logger.info(f"Predicting sentiment for {len(texts)} texts") 71 | classification_pipeline = self.get_sentiment_pipeline(model_name, tokenizer_name) 72 | 73 | predictions = classification_pipeline(texts) 74 | for i, pred in enumerate(predictions): 75 | predictions[i]["score"] = round(pred["score"], 2) 76 | 77 | return { 78 | "predictions": predictions 79 | } 80 | 81 | -------------------------------------------------------------------------------- /classification/src/config.py: -------------------------------------------------------------------------------- 1 | PREDICTION_TYPE = 'classification' 2 | 3 | DEFAULT_MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment" 4 | DEFAULT_TOKENIZER_NAME = "roberta-base" 5 | ID_SENTIMENT_MAPPING = { # add for all models to be supported 6 | "cardiffnlp/twitter-roberta-base-sentiment": { 7 | 0: "NEGATIVE", 8 | 1: "NEUTRAL", 9 | 2: "POSITIVE" 10 | }, 11 | "cardiffnlp/twitter-roberta-base-emotion": { 12 | 0: "ANGER", 13 | 1: "JOY", 14 | 2: "OPTIMISM", 15 | 3: "SADNESS" 16 | } 17 | } 18 | 19 | # cache 20 | CACHE_MAXSIZE = 4 21 | -------------------------------------------------------------------------------- /classification/src/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import ntpath 4 | import os 5 | from typing import Optional 6 | 7 | 8 | def create_folder(directory): 9 | if not os.path.exists(directory): 10 | os.makedirs(directory) 11 | print("Directory created: " + directory) 12 | else: 13 | print("Directory exists: " + directory) 14 | 15 | 16 | def create_logger( 17 | project_name: str, 18 | level: str = "INFO", 19 | log_dir: str = "/tmp/logs", 20 | file_name: Optional[str] = None, 21 | do_print: bool = True, 22 | simple_logging: bool = False, 23 | log_to_file: bool = False, 24 | rich_logging: bool = False, 25 | time_zone: Optional[str] = None, 26 | ): 27 | """Creates a logger of given level and saves logs to a file 28 | 29 | :param project_name: project name for which we are logging 30 | :param level: logging level 31 | LEVELS available 32 | DEBUG: Detailed information, typically of interest only when diagnosing problems. 33 | INFO: Confirmation that things are working as expected. 34 | WARNING: An indication that something unexpected happened, or indicative of some problem in the near future (e.g. 'disk space low'). The software is still working as expected. 35 | ERROR: Due to a more serious problem, the software has not been able to perform some function. 36 | CRITICAL: A serious error, indicating that the program itself may be unable to continue running. 
37 | :param log_dir: directory when log files are created 38 | :param file_name: name of the log file 39 | :param do_print: whether to print the logs 40 | :param simple_logging: sets formatter to only message 41 | :param log_to_file: whether to save logs on disk 42 | :param rich_logging: colorful logging using rich 43 | :param time_zone: timezone to be used for time in logging such as Asia/Kolkata 44 | https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568 45 | """ 46 | import __main__ 47 | 48 | if file_name is None: 49 | try: 50 | file_name = ntpath.basename(__main__.__file__).split(".")[0] 51 | except: 52 | file_name = "logs" 53 | 54 | logger = logging.getLogger(file_name) 55 | logger.handlers.clear() 56 | logger.setLevel(getattr(logging, level)) 57 | 58 | if time_zone: 59 | from pytz import timezone, utc 60 | def time_formatter(*args): 61 | # TODO: Doesnt work with rich formatter 62 | utc_dt = utc.localize(datetime.datetime.utcnow()) 63 | my_tz = timezone(time_zone) 64 | converted = utc_dt.astimezone(my_tz) 65 | return converted.timetuple() 66 | 67 | logging.Formatter.converter = time_formatter 68 | 69 | if rich_logging: 70 | from rich.logging import RichHandler 71 | stream_format = f"{project_name}:%(module)s:%(funcName)s: %(message)s" 72 | stream_handler = RichHandler(omit_repeated_times=False) 73 | else: 74 | stream_format = f"%(asctime)s:%(levelname)s:{project_name}:%(module)s:%(funcName)s: %(message)s" 75 | stream_handler = logging.StreamHandler() 76 | 77 | file_formatter = stream_formatter = logging.Formatter( 78 | stream_format, "%Y-%m-%d %H:%M:%S" 79 | ) 80 | 81 | if simple_logging: 82 | file_formatter = logging.Formatter("%(message)s") 83 | stream_formatter = logging.Formatter("%(message)s") 84 | 85 | if log_to_file: 86 | date = datetime.date.today() 87 | date = "%s-%s-%s" % (date.day, date.month, date.year) 88 | log_file_path = os.path.join(log_dir, "%s-%s.log" % (file_name, date)) 89 | 90 | create_folder(log_dir) 91 | file_handler = logging.FileHandler(log_file_path) 92 | file_handler.setFormatter(file_formatter) 93 | logger.addHandler(file_handler) 94 | 95 | if do_print: 96 | stream_handler.setFormatter(stream_formatter) 97 | logger.addHandler(stream_handler) 98 | 99 | logger.propagate = False 100 | 101 | return logger 102 | -------------------------------------------------------------------------------- /classification/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/classification/tests/__init__.py -------------------------------------------------------------------------------- /classification/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src import config 3 | 4 | 5 | @pytest.fixture 6 | def requests(): 7 | return { 8 | "texts": ["food was great", "food was bad", "i am going out for food"], 9 | "model_name": config.DEFAULT_MODEL_NAME, 10 | "tokenizer_name": config.DEFAULT_TOKENIZER_NAME 11 | } 12 | 13 | 14 | @pytest.fixture 15 | def response(): 16 | return { 17 | 'predictions': [{ 18 | 'label': 'POSITIVE', 19 | 'score': 0.97 20 | }, { 21 | 'label': 'NEGATIVE', 22 | 'score': 0.95 23 | }, { 24 | 'label': 'NEUTRAL', 25 | 'score': 0.69 26 | }] 27 | } 28 | -------------------------------------------------------------------------------- /classification/tests/test_classifier.py: 
-------------------------------------------------------------------------------- 1 | from src.classifier import Classifier 2 | 3 | pipeline = Classifier() 4 | 5 | def test_response(requests, response): 6 | assert response == pipeline(requests) 7 | -------------------------------------------------------------------------------- /sentence-encoding/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .aws-sam 132 | *.pyc 133 | .vscode 134 | .DS_store 135 | **.bin 136 | **.ipynb_checkpoints -------------------------------------------------------------------------------- /sentence-encoding/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM amazon/aws-lambda-python 2 | 3 | ARG MODEL_DIR=./models 4 | 5 | ENV SENTENCE_TRANSFORMERS_HOME=$MODEL_DIR 6 | ENV TRANSFORMERS_CACHE=$MODEL_DIR 7 | ENV TRANSFORMERS_VERBOSITY=error 8 | 9 | RUN yum -y install gcc-c++ 10 | 11 | COPY requirements.txt requirements.txt 12 | RUN pip install torch==1.8+cpu -f https://download.pytorch.org/whl/torch_stable.html --no-cache-dir 13 | RUN pip install -r requirements.txt --no-cache-dir 14 | 15 | COPY ./ ./ 16 | 17 | # Run test cases and this saves the transformer model in the container 18 | RUN pip install pytest --no-cache-dir && pytest tests -s -vv 19 | 20 | RUN chmod -R 0777 $MODEL_DIR 21 | 22 | CMD [ "main.lambda_handler"] -------------------------------------------------------------------------------- /sentence-encoding/README.MD: -------------------------------------------------------------------------------- 1 | ## Sentence Encoding service 2 | 3 | Sentence Encoding using Sentence Transformers on AWS Lambda. Check root readme for complete setup info. 4 | 5 | ## Request format 6 | 7 | ``` 8 | { 9 | "texts": ['This framework generates embeddings for each input sentence', 10 | 'Sentences are passed as a list of string.', 11 | 'The quick brown fox jumps over the lazy dog.'], 12 | "model_name":'paraphrase-MiniLM-L6-v2' (optional) 13 | } 14 | ``` 15 | 16 | ## Response format 17 | 18 | ``` 19 | # returns a list of the shape (348,) for each sentence 20 | 21 | { 22 | "vectors": [[-1.76214233e-01 1.20601095e-01 -2.93623984e-01 -2.29858086e-01 ... 
-0.3186724 0.41656044 -0.05431654 0.14036156 1.0559164 0.53018135]] 23 | } 24 | ``` 25 | -------------------------------------------------------------------------------- /sentence-encoding/main.py: -------------------------------------------------------------------------------- 1 | from sklearn import pipeline 2 | from src.sentence_encoder import SentenceEncoder 3 | 4 | pipeline = SentenceEncoder() 5 | 6 | 7 | def lambda_handler(event, context): 8 | try: 9 | return pipeline(event) 10 | except Exception as e: 11 | raise 12 | -------------------------------------------------------------------------------- /sentence-encoding/requirements.txt: -------------------------------------------------------------------------------- 1 | sentence-transformers==2.* 2 | tqdm==4.* 3 | scikit-learn==0.24.* -------------------------------------------------------------------------------- /sentence-encoding/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/sentence-encoding/src/__init__.py -------------------------------------------------------------------------------- /sentence-encoding/src/config.py: -------------------------------------------------------------------------------- 1 | PREDICTION_TYPE = "sentence-encoding" 2 | 3 | DEFAULT_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L6-v2" 4 | 5 | # cache 6 | CACHE_MAXSIZE = 4 7 | -------------------------------------------------------------------------------- /sentence-encoding/src/sentence_encoder.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | from functools import lru_cache 6 | 7 | from sentence_transformers import SentenceTransformer 8 | 9 | from src import config, utils 10 | 11 | logger = utils.create_logger(project_name=config.PREDICTION_TYPE, level="INFO") 12 | 13 | class SentenceEncoder: 14 | def __init__(self): 15 | _ = self.get_sent_encoder(model_name=config.DEFAULT_MODEL_NAME) #warm up 16 | 17 | @staticmethod 18 | @lru_cache(maxsize=config.CACHE_MAXSIZE) 19 | def get_sent_encoder(model_name: str) -> SentenceTransformer: 20 | """loads and returns a SentenceTransformer model specified by model_name argument 21 | Args: 22 | model_name (str): Indicating the name of the model 23 | 24 | Returns: 25 | SentenceTransformer model 26 | """ 27 | logger.info(f"Loading model: {model_name}") 28 | 29 | model = SentenceTransformer(model_name) 30 | return model 31 | 32 | def get_clean_text(self, text: str) -> str: 33 | """Clean the text 34 | 35 | Args: 36 | text (str): text 37 | 38 | Returns: 39 | str: clean text 40 | """ 41 | return text.strip().lower() 42 | 43 | def __call__(self, request: dict)-> dict: 44 | """ embeddings of the given list of sentences 45 | 46 | Args: 47 | request (dict): request containing the list of snetences for encoding 48 | 49 | Returns: 50 | dict: list of embeddings for each sentence embedding dimension = (384,) 51 | """ 52 | texts = [self.get_clean_text(text) for text in request["texts"]] 53 | 54 | logger.info(f"Generating embeddings for {len(texts)} sentences") 55 | 56 | model_name = request.get('model_name', config.DEFAULT_MODEL_NAME) 57 | 58 | sentence_encoder = self.get_sent_encoder(model_name) 59 | 60 | embeddings = sentence_encoder.encode(texts) 61 | 62 | return { 63 | "vectors": embeddings 64 | } 65 | 66 | 
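For local experimentation, the `SentenceEncoder` pipeline object can be called directly with the same request dict the Lambda receives. A minimal sketch follows (the cosine-similarity step is illustrative only and not part of the service; the model downloads on first use):

```
import numpy as np

from src.sentence_encoder import SentenceEncoder

pipeline = SentenceEncoder()

result = pipeline({
    "texts": [
        "This framework generates embeddings for each input sentence",
        "Sentences are passed as a list of string.",
    ]
})

# Embeddings from paraphrase-MiniLM-L6-v2 are 384-dimensional, so shape is (2, 384).
vectors = np.asarray(result["vectors"])

# Illustrative: cosine similarity between the two sentences.
a, b = vectors
print(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))
```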
-------------------------------------------------------------------------------- /sentence-encoding/src/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import ntpath 4 | import os 5 | from typing import Optional 6 | 7 | 8 | def create_folder(directory): 9 | if not os.path.exists(directory): 10 | os.makedirs(directory) 11 | print("Directory created: " + directory) 12 | else: 13 | print("Directory exists: " + directory) 14 | 15 | 16 | def create_logger( 17 | project_name: str, 18 | level: str = "INFO", 19 | log_dir: str = "/tmp/logs", 20 | file_name: Optional[str] = None, 21 | do_print: bool = True, 22 | simple_logging: bool = False, 23 | log_to_file: bool = False, 24 | rich_logging: bool = False, 25 | time_zone: Optional[str] = None, 26 | ): 27 | """Creates a logger of given level and saves logs to a file 28 | 29 | :param project_name: project name for which we are logging 30 | :param level: logging level 31 | LEVELS available 32 | DEBUG: Detailed information, typically of interest only when diagnosing problems. 33 | INFO: Confirmation that things are working as expected. 34 | WARNING: An indication that something unexpected happened, or indicative of some problem in the near future (e.g. 'disk space low'). The software is still working as expected. 35 | ERROR: Due to a more serious problem, the software has not been able to perform some function. 36 | CRITICAL: A serious error, indicating that the program itself may be unable to continue running. 37 | :param log_dir: directory when log files are created 38 | :param file_name: name of the log file 39 | :param do_print: whether to print the logs 40 | :param simple_logging: sets formatter to only message 41 | :param log_to_file: whether to save logs on disk 42 | :param rich_logging: colorful logging using rich 43 | :param time_zone: timezone to be used for time in logging such as Asia/Kolkata 44 | https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568 45 | """ 46 | import __main__ 47 | 48 | if file_name is None: 49 | try: 50 | file_name = ntpath.basename(__main__.__file__).split(".")[0] 51 | except: 52 | file_name = "logs" 53 | 54 | logger = logging.getLogger(file_name) 55 | logger.handlers.clear() 56 | logger.setLevel(getattr(logging, level)) 57 | 58 | if time_zone: 59 | from pytz import timezone, utc 60 | def time_formatter(*args): 61 | # TODO: Doesnt work with rich formatter 62 | utc_dt = utc.localize(datetime.datetime.utcnow()) 63 | my_tz = timezone(time_zone) 64 | converted = utc_dt.astimezone(my_tz) 65 | return converted.timetuple() 66 | 67 | logging.Formatter.converter = time_formatter 68 | 69 | if rich_logging: 70 | from rich.logging import RichHandler 71 | stream_format = f"{project_name}:%(module)s:%(funcName)s: %(message)s" 72 | stream_handler = RichHandler(omit_repeated_times=False) 73 | else: 74 | stream_format = f"%(asctime)s:%(levelname)s:{project_name}:%(module)s:%(funcName)s: %(message)s" 75 | stream_handler = logging.StreamHandler() 76 | 77 | file_formatter = stream_formatter = logging.Formatter( 78 | stream_format, "%Y-%m-%d %H:%M:%S" 79 | ) 80 | 81 | if simple_logging: 82 | file_formatter = logging.Formatter("%(message)s") 83 | stream_formatter = logging.Formatter("%(message)s") 84 | 85 | if log_to_file: 86 | date = datetime.date.today() 87 | date = "%s-%s-%s" % (date.day, date.month, date.year) 88 | log_file_path = os.path.join(log_dir, "%s-%s.log" % (file_name, date)) 89 | 90 | create_folder(log_dir) 91 | file_handler = 
logging.FileHandler(log_file_path) 92 | file_handler.setFormatter(file_formatter) 93 | logger.addHandler(file_handler) 94 | 95 | if do_print: 96 | stream_handler.setFormatter(stream_formatter) 97 | logger.addHandler(stream_handler) 98 | 99 | logger.propagate = False 100 | 101 | return logger 102 | -------------------------------------------------------------------------------- /sentence-encoding/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/sentence-encoding/tests/__init__.py -------------------------------------------------------------------------------- /sentence-encoding/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src import config 3 | 4 | 5 | @pytest.fixture 6 | def requests(): 7 | return { 8 | "texts": ['This framework generates embeddings for each input sentence', 9 | 'Sentences are passed as a list of string.', 10 | 'The quick brown fox jumps over the lazy dog.'] 11 | } 12 | 13 | 14 | @pytest.fixture 15 | def response(): 16 | return { 17 | 'vectors': [[-1.762142330408096313e-01, 1.206010952591896057e-01, -2.936239838600158691e-01, -2.298580855131149292e-01, -8.229275792837142944e-02, 2.377093583345413208e-01, 3.399852216243743896e-01, -7.809641361236572266e-01, 1.181278675794601440e-01, 1.633740067481994629e-01, -1.377150118350982666e-01, 2.402825355529785156e-01, 4.251254796981811523e-01, 1.724179536104202271e-01, 1.052796989679336548e-01, 5.181643962860107422e-01, 6.222178414463996887e-02, 3.992859423160552979e-01, -1.816524118185043335e-01, -5.855786800384521484e-01, 4.497172683477401733e-02, -1.727504432201385498e-01, -2.684433758258819580e-01, -1.473859846591949463e-01, -1.892180442810058594e-01, 1.921506971120834351e-01, -3.838426470756530762e-01, -3.960070312023162842e-01, 4.306489229202270508e-01, -3.153194189071655273e-01, 3.659493327140808105e-01, 6.051584333181381226e-02, 3.573259115219116211e-01, 1.597363203763961792e-01, -3.009840846061706543e-01, 2.632503807544708252e-01, -3.943109810352325439e-01, 1.848553270101547241e-01, -3.995491266250610352e-01, -2.678897678852081299e-01, -5.451174378395080566e-01, -3.134060651063919067e-02, -4.306440055370330811e-01, 1.332782506942749023e-01, -1.747937947511672974e-01, -4.354656040668487549e-01, -4.773788452148437500e-01, 7.125558704137802124e-02, -7.370020449161529541e-02, 5.691370368003845215e-01, -2.825797796249389648e-01, 5.249750241637229919e-02, -8.200080394744873047e-01, 1.982968002557754517e-01, 1.695117801427841187e-01, 2.717802226543426514e-01, 2.646110653877258301e-01, -2.557394839823246002e-02, -1.740963160991668701e-01, 1.633144766092300415e-01, -3.952611088752746582e-01, -3.175597637891769409e-02, -2.625561356544494629e-01, 3.527543842792510986e-01, 3.014345467090606689e-01, -1.471974998712539673e-01, 2.100759595632553101e-01, -1.840105354785919189e-01, -4.128958582878112793e-01, 4.147755205631256104e-01, -1.897691041231155396e-01, -1.354821026325225830e-01, -3.792723119258880615e-01, -4.680226370692253113e-02, -3.336004167795181274e-02, 9.003944694995880127e-02, -3.301327526569366455e-01, -3.873167559504508972e-02, 3.750822842121124268e-01, -1.469965875148773193e-01, 4.349598586559295654e-01, 5.383257269859313965e-01, -2.654454112052917480e-01, 1.644457727670669556e-01, 4.170784354209899902e-01, -4.725078865885734558e-02, -7.487314194440841675e-02, 
-4.262605905532836914e-01, -1.969945430755615234e-01, 6.103179603815078735e-02, -4.742631316184997559e-01, -6.483347415924072266e-01, 3.714625239372253418e-01, 2.509568929672241211e-01, 1.225298494100570679e-01, 8.887653797864913940e-02, -1.067240089178085327e-01, 5.339851230382919312e-02, 9.745053946971893311e-02, -3.466571122407913208e-02, -1.028828322887420654e-01, 2.322889268398284912e-01, -2.537398934364318848e-01, -5.131123065948486328e-01, 1.852160990238189697e-01, -3.043577075004577637e-01, -3.552089631557464600e-02, -1.269750446081161499e-01, -7.716330140829086304e-02, -5.153301954269409180e-01, -2.280719429254531860e-01, 2.033434994518756866e-02, 7.381757348775863647e-02, -1.525584012269973755e-01, -4.008376300334930420e-01, -2.477492988109588623e-01, 3.974704742431640625e-01, -2.602606117725372314e-01, 2.509060800075531006e-01, 1.682287901639938354e-01, 1.339003145694732666e-01, -2.108344808220863342e-02, -4.700354337692260742e-01, 4.788500964641571045e-01, 2.803455591201782227e-01, -4.645468294620513916e-01, 3.217469453811645508e-01, 2.342072576284408569e-01, 2.457724511623382568e-01, -4.714821279048919678e-01, 5.004014968872070312e-01, 4.101898968219757080e-01, 5.152167677879333496e-01, 2.625494301319122314e-01, 2.115933038294315338e-02, -3.896875083446502686e-01, -2.417427897453308105e-01, -2.148346602916717529e-01, -8.626504242420196533e-02, -1.653234362602233887e-01, -5.218933522701263428e-02, 3.418747484683990479e-01, 4.503143727779388428e-01, -3.069733977317810059e-01, -2.022944092750549316e-01, 6.855217218399047852e-01, -5.338923335075378418e-01, 3.584715127944946289e-01, 1.452866494655609131e-01, -7.070577144622802734e-02, -1.505292356014251709e-01, -8.562794327735900879e-02, -7.678513973951339722e-02, 1.895446479320526123e-01, -1.040674969553947449e-01, 5.335438251495361328e-01, -5.278869867324829102e-01, 2.423308603465557098e-02, -2.643479108810424805e-01, -2.231865972280502319e-01, -3.812087774276733398e-01, 7.599139958620071411e-02, -4.644850194454193115e-01, -3.365494608879089355e-01, 4.212298393249511719e-01, 1.074794009327888489e-01, 1.904578208923339844e-01, 2.894973149523139000e-03, -1.085136681795120239e-01, 1.535454690456390381e-01, 3.160231411457061768e-01, -2.708382718265056610e-02, -5.405944585800170898e-01, 8.972890675067901611e-02, -1.155498400330543518e-01, 3.978038132190704346e-01, -4.976834058761596680e-01, -2.848933637142181396e-01, 4.998634755611419678e-02, 3.612796962261199951e-01, 6.905353069305419922e-01, 1.468216478824615479e-01, 1.733965277671813965e-01, -1.745821088552474976e-01, -3.157024085521697998e-01, 6.730007380247116089e-02, 2.172500193119049072e-01, 9.785352647304534912e-02, -1.294724792242050171e-01, -1.869295537471771240e-01, 1.348781287670135498e-01, 18 | -1.538850963115692139e-01, 7.447167485952377319e-02, -1.855361759662628174e-01, -2.806282639503479004e-01, -1.141442880034446716e-01, 4.122495949268341064e-01, 6.394940614700317383e-02, -1.457153260707855225e-01, -9.820619970560073853e-02, -1.330819427967071533e-01, -1.884108334779739380e-01, -2.848415635526180267e-02, -3.495096787810325623e-02, 3.342579305171966553e-02, 6.988976150751113892e-02, 1.903543323278427124e-01, -2.967241406440734863e-01, 2.646947046741843224e-03, 1.091409251093864441e-01, 1.708933338522911072e-02, 2.605892121791839600e-01, 3.290384411811828613e-01, -6.615598499774932861e-02, 2.396654188632965088e-01, -2.261948734521865845e-01, -3.368677198886871338e-02, 1.494003832340240479e-01, -3.212654292583465576e-01, -2.685779333114624023e-01, 
5.726315975189208984e-01, -4.923083186149597168e-01, 2.006668746471405029e-01, -3.492617011070251465e-01, -2.898878604173660278e-02, 6.090103387832641602e-01, -5.723330974578857422e-01, 2.350005209445953369e-01, 6.471671629697084427e-03, -3.149505332112312317e-02, 2.781099453568458557e-02, -3.903406262397766113e-01, -2.089497298002243042e-01, -3.044527471065521240e-01, -7.201950252056121826e-02, -8.298408985137939453e-02, 3.737927079200744629e-01, 7.389393448829650879e-02, -2.210726961493492126e-02, 9.881400316953659058e-02, -1.514267623424530029e-01, -1.404307633638381958e-01, 2.260178029537200928e-01, 2.760902643203735352e-01, -8.877505362033843994e-02, -1.128159686923027039e-01, -2.662860751152038574e-01, 2.778344750404357910e-01, -4.756131395697593689e-02, 6.710069626569747925e-02, -2.785863541066646576e-02, -2.399928681552410126e-02, 2.517086863517761230e-01, 4.687937498092651367e-01, -5.393254756927490234e-01, 1.105985268950462341e-01, -3.449472188949584961e-01, 4.159896969795227051e-01, 7.284827530384063721e-02, -3.196474611759185791e-01, 4.903741478919982910e-01, -7.303585298359394073e-03, -2.642561215907335281e-03, 9.637111425399780273e-01, 3.238849043846130371e-01, -7.796172052621841431e-02, -2.375894188880920410e-01, 2.340381741523742676e-01, -3.160540759563446045e-01, -1.656696782447397709e-03, -1.090706586837768555e+00, 3.384092152118682861e-01, 4.706051200628280640e-02, 1.074354574084281921e-01, -2.066721320152282715e-01, 4.264194052666425705e-03, -1.384751172736287117e-03, -5.314556956291198730e-01, -2.756482958793640137e-01, -1.646486371755599976e-01, -3.429162204265594482e-01, -4.261189103126525879e-01, 6.018121242523193359e-01, 4.559717774391174316e-01, -2.727017998695373535e-01, -3.458075225353240967e-02, 2.627523541450500488e-01, -6.341892760246992111e-03, 2.796310782432556152e-01, -2.535589039325714111e-01, -1.686263829469680786e-01, 3.829357400536537170e-02, 2.077634185552597046e-01, -4.315258562564849854e-01, -7.239980250597000122e-02, -1.268542557954788208e-01, 2.070293202996253967e-02, 5.744414329528808594e-01, 3.546725511550903320e-01, 9.283009171485900879e-02, 6.705061346292495728e-02, 1.115204244852066040e-01, -1.865123026072978973e-02, 4.623519778251647949e-01, 2.725045979022979736e-01, -3.604742288589477539e-01, 5.294152498245239258e-01, -1.003212062641978264e-03, -8.813624083995819092e-02, 1.499755829572677612e-01, 5.258591473102569580e-02, 4.635176062583923340e-01, -3.968313336372375488e-01, 2.426403164863586426e-01, -2.089123725891113281e-01, 3.656721115112304688e-01, -4.735508700832724571e-04, 5.339631438255310059e-01, -1.978794932365417480e-01, 3.115830123424530029e-01, -6.967152953147888184e-01, -4.295006096363067627e-01, -4.493593871593475342e-01, -2.713678032159805298e-02, -6.987117230892181396e-02, 2.061746418476104736e-01, -1.571076959371566772e-01, 4.435211718082427979e-01, -6.742671132087707520e-02, -3.009242713451385498e-01, 5.148594379425048828e-01, 3.360294699668884277e-01, 6.633765250444412231e-02, -1.152352169156074524e-01, -2.959796600043773651e-02, 2.794718444347381592e-01, -3.481988236308097839e-02, -7.293223589658737183e-02, -4.584729671478271484e-02, 1.542629599571228027e-01, 8.093562722206115723e-01, 5.203280448913574219e-01, -4.021146893501281738e-01, -3.231527656316757202e-02, -1.103639751672744751e-01, 7.505026459693908691e-02, -1.510986685752868652e-01, 8.457400202751159668e-01, -1.808441281318664551e-01, 3.225733637809753418e-01, 1.047080457210540771e-01, 3.196638524532318115e-01, -1.550855338573455811e-01, 1.692366600036621094e-01, 
-2.569966018199920654e-01, 2.012090384960174561e-01, 1.773930937051773071e-01, -2.743332386016845703e-01, -3.369442522525787354e-01, 5.023568868637084961e-01, -1.183573007583618164e-01, -2.011669576168060303e-01, -5.364859104156494141e-01, -7.698090374469757080e-02, 1.153794955462217331e-02, -2.364642918109893799e-01, -2.987704798579216003e-02, 1.313665360212326050e-01, 2.941844761371612549e-01, 9.909154474735260010e-02, -5.438975691795349121e-01, 1.408130675554275513e-01, 3.669986426830291748e-01, 5.048625171184539795e-02, 1.991224288940429688e-01, -2.806745469570159912e-01, 4.341922402381896973e-01, -1.402750909328460693e-01, 5.780487656593322754e-01, 1.777157634496688843e-01, 8.983647078275680542e-02, 3.296516537666320801e-01, 6.130084767937660217e-02, -3.249333500862121582e-01], 19 | [3.220873475074768066e-01, -1.239313278347253799e-03, 1.793740540742874146e-01, -3.691916763782501221e-01, -6.460254639387130737e-02, 9.153671562671661377e-02, 2.411911040544509888e-01, -2.949422895908355713e-01, 7.728967815637588501e-02, 1.157702207565307617e-01, -4.479986801743507385e-02, 1.792827546596527100e-01, 1.475359350442886353e-01, 2.151165902614593506e-01, 3.681079149246215820e-01, 2.091092020273208618e-01, 2.719422578811645508e-01, 3.488005101680755615e-01, -5.725189447402954102e-01, -1.825320869684219360e-01, 4.448955357074737549e-01, 2.745294272899627686e-01, 4.266280680894851685e-02, -7.683566957712173462e-02, 1.868916153907775879e-01, 4.496503174304962158e-01, -1.693259030580520630e-01, -2.489633411169052124e-01, -2.047924995422363281e-01, 4.028501510620117188e-01, -2.101926207542419434e-01, 3.775680437684059143e-02, 7.848539203405380249e-02, 1.284843087196350098e-01, 2.593061327934265137e-02, 4.715599715709686279e-01, 1.785378158092498779e-01, -7.379743456840515137e-02, 8.130741119384765625e-02, -2.332873344421386719e-01, -4.980126619338989258e-01, -4.135714471340179443e-02, -1.209461018443107605e-01, 1.702897548675537109e-01, -1.915408521890640259e-01, -3.845985531806945801e-01, -7.747915387153625488e-01, -1.062273234128952026e-01, -2.304487377405166626e-01, 4.024147689342498779e-01, -8.745088577270507812e-01, 2.385370880365371704e-01, -4.712987244129180908e-01, 2.126221209764480591e-01, 3.340933322906494141e-01, -2.415401339530944824e-01, -1.483509540557861328e-01, -1.451357752084732056e-01, -3.483093380928039551e-01, -8.349202573299407959e-02, -6.909726858139038086e-01, -2.984526455402374268e-01, -1.223049387335777283e-01, 7.482658326625823975e-02, -1.877558678388595581e-01, -3.754654824733734131e-01, 2.136952728033065796e-01, -1.009642779827117920e-01, -1.223443225026130676e-01, 3.143149614334106445e-01, -2.398996949195861816e-01, 2.246076464653015137e-01, 3.996007889509201050e-02, 3.603481054306030273e-01, -5.663804411888122559e-01, 2.188350707292556763e-01, 1.102031245827674866e-01, -1.087081879377365112e-01, 7.084076106548309326e-02, -2.608161605894565582e-02, 1.837034225463867188e-01, 8.465936034917831421e-02, -2.047823071479797363e-01, -2.443560957908630371e-01, -8.180584013462066650e-02, -1.903086900711059570e-02, -3.591400757431983948e-02, 2.398450858891010284e-02, -2.855857014656066895e-01, 7.374799251556396484e-02, -2.974424064159393311e-01, -8.771786093711853027e-01, 4.710197150707244873e-01, -4.940467700362205505e-02, 3.639449179172515869e-01, 4.826440811157226562e-01, 1.564634218811988831e-02, 3.558915480971336365e-02, -2.620298564434051514e-01, -1.121847182512283325e-01, 2.411031164228916168e-02, 3.747780621051788330e-01, -9.897301346063613892e-02, 
-9.851840138435363770e-02, 1.500086337327957153e-01, 6.895607803016901016e-03, -1.265247017145156860e-01, -3.159892857074737549e-01, 3.144952654838562012e-01, -2.942563295364379883e-01, -2.694102823734283447e-01, 2.022118419408798218e-01, 1.432989835739135742e-01, -1.958461403846740723e-01, -3.410446345806121826e-01, -3.172732889652252197e-02, 7.365031838417053223e-01, 3.192348778247833252e-01, 2.438131272792816162e-01, 3.073262274265289307e-01, 9.933231025934219360e-02, 1.901090294122695923e-01, -1.069451048970222473e-01, 5.178655311465263367e-02, 3.233431279659271240e-02, -1.031463071703910828e-01, 2.649920284748077393e-01, 3.120644390583038330e-01, 4.315259754657745361e-01, -6.426120996475219727e-01, 8.409559726715087891e-02, -4.327352344989776611e-02, -4.991196468472480774e-02, -1.271858215332031250e-01, 1.378916352987289429e-01, 1.306246872991323471e-02, 3.438322246074676514e-01, 9.234275668859481812e-02, -9.922754019498825073e-02, -5.215995907783508301e-01, 2.584226131439208984e-01, -1.057167630642652512e-02, -4.781659226864576340e-03, 3.938866034150123596e-02, 1.908608973026275635e-01, 3.293388485908508301e-01, -2.434513717889785767e-01, -7.328316569328308105e-02, -3.928004205226898193e-01, 1.454180181026458740e-01, 3.283953368663787842e-01, -4.184615612030029297e-02, 7.407122105360031128e-02, -7.386051416397094727e-01, -9.075982123613357544e-02, 1.580233424901962280e-01, -9.780049324035644531e-02, -2.160598188638687134e-01, -3.002744615077972412e-01, 2.323656082153320312e-01, 1.072475314140319824e-02, 4.957046508789062500e-01, 4.974841699004173279e-02, 2.993140518665313721e-01, -5.382250621914863586e-02, 3.532811999320983887e-01, 3.419177234172821045e-01, 4.966728389263153076e-01, -4.860525131225585938e-01, -1.909886747598648071e-01, 8.154573440551757812e-01, 2.296263128519058228e-01, -3.207779824733734131e-01, -3.272672891616821289e-01, -3.677168786525726318e-01, 3.452117443084716797e-01, -2.620116993784904480e-02, -1.431507021188735962e-01, 1.064844354987144470e-01, -2.463801801204681396e-01, -9.366600215435028076e-02, 1.719863712787628174e-01, -8.508807420730590820e-02, 2.012029290199279785e-01, -5.879214033484458923e-02, -3.402101099491119385e-01, -1.956532895565032959e-01, 2.828088104724884033e-01, 2.012428641319274902e-01, -8.207251131534576416e-02, 9.779152274131774902e-02, -2.637499868869781494e-01, 20 | 1.217653453350067139e-01, -1.041472330689430237e-02, -4.385982751846313477e-01, 1.105825379490852356e-01, 4.801034629344940186e-01, -1.098195090889930725e-01, -6.375459432601928711e-01, 2.933678328990936279e-01, -1.920764893293380737e-01, 4.653698205947875977e-01, 2.704200744628906250e-01, 1.938849985599517822e-01, 1.737902462482452393e-01, -3.007701635360717773e-01, -2.751181833446025848e-02, -2.291271276772022247e-02, 3.678463995456695557e-01, 2.492175623774528503e-02, 5.370550751686096191e-01, 1.885121315717697144e-01, -1.334442049264907837e-01, 8.917348831892013550e-02, 5.542919784784317017e-02, -2.481835782527923584e-01, -4.199777916073799133e-02, 5.767389386892318726e-02, -1.827879101037979126e-01, -4.168645739555358887e-01, 1.607060581445693970e-01, -4.636249840259552002e-01, 1.176923438906669617e-01, -3.770689666271209717e-01, 2.960372157394886017e-02, 6.925609111785888672e-01, -4.830894172191619873e-01, 2.112836986780166626e-01, 1.821453720331192017e-01, -1.842960864305496216e-01, 6.817662715911865234e-02, -2.460883930325508118e-02, -1.907362788915634155e-01, -6.736993044614791870e-02, -5.670071244239807129e-01, -2.392930537462234497e-01, -8.497231453657150269e-02, 
3.093983046710491180e-02, 3.107990920543670654e-01, 1.291628479957580566e-01, 5.248226970434188843e-02, -3.344979882240295410e-01, 1.881012320518493652e-01, 2.354713529348373413e-01, -1.834763679653406143e-03, 4.536155760288238525e-01, 2.488507777452468872e-01, -5.641095340251922607e-02, -2.977458834648132324e-01, -4.351175129413604736e-01, -7.969439774751663208e-02, -1.767016202211380005e-01, -1.334709078073501587e-01, 1.938273310661315918e-01, 2.200260609388351440e-01, -1.105752289295196533e-01, 2.647372484207153320e-01, -2.717908024787902832e-01, 3.410907089710235596e-02, -4.771438837051391602e-01, 4.471905529499053955e-01, -5.570406094193458557e-02, 3.964374661445617676e-01, 2.748323678970336914e-01, 3.330560624599456787e-01, -1.089023798704147339e-01, 2.788817584514617920e-01, 2.159695178270339966e-01, -5.252279341220855713e-02, -3.586752712726593018e-01, -6.906290650367736816e-01, 3.960205614566802979e-02, 6.527704186737537384e-03, -1.095331553369760513e-02, -1.002768501639366150e-01, 4.770006984472274780e-02, -3.414691686630249023e-01, -1.671415418386459351e-01, 7.136443257331848145e-02, -1.807848662137985229e-01, -3.024845421314239502e-01, -6.842871308326721191e-01, -9.592869877815246582e-02, -2.141110748052597046e-01, -6.552440524101257324e-01, 5.675646066665649414e-01, 2.694673240184783936e-01, -1.900700037367641926e-03, 8.618065118789672852e-01, 1.677159219980239868e-01, 3.102787584066390991e-02, -2.677303850650787354e-01, -7.830285280942916870e-02, -4.851088523864746094e-01, -2.673721611499786377e-01, -3.335425853729248047e-01, -5.738252401351928711e-01, 3.567825555801391602e-01, 8.993595838546752930e-02, -1.305717229843139648e-01, -1.513648182153701782e-01, -6.124149262905120850e-02, -1.303707212209701538e-01, 5.585607290267944336e-01, 6.141750812530517578e-01, -4.804052039980888367e-02, -6.388589739799499512e-02, 8.390594273805618286e-02, -2.514369189739227295e-01, -4.359832778573036194e-02, -1.852578520774841309e-01, 4.693385586142539978e-02, -3.438087105751037598e-01, -9.738498181104660034e-02, 1.683363318443298340e-01, 7.526867836713790894e-02, 1.769449561834335327e-01, 1.772719025611877441e-01, -3.423422947525978088e-02, 1.499355733394622803e-01, -1.377317905426025391e-01, -2.094969302415847778e-01, -6.127283573150634766e-01, 3.781396746635437012e-01, 3.901829719543457031e-01, -8.359315991401672363e-02, 3.152117878198623657e-02, 1.312240660190582275e-01, 3.882605433464050293e-01, 2.184422910213470459e-01, 9.724286198616027832e-02, 4.208936393260955811e-01, -3.264123499393463135e-01, -2.693341970443725586e-01, -3.909511864185333252e-01, -2.264865487813949585e-01, -3.202073872089385986e-01, -1.628742963075637817e-01, -3.581614047288894653e-02, 3.637387156486511230e-01, 1.858329921960830688e-01, -2.914021164178848267e-02, -4.657791256904602051e-01, 2.916888296604156494e-01, 3.725127875804901123e-01, -2.372660040855407715e-01, 3.386467695236206055e-03, 4.154096841812133789e-01, 3.300432115793228149e-02, 4.500397443771362305e-01, -8.159277588129043579e-02, 3.399033248424530029e-01, 2.449786067008972168e-01, 2.352412417531013489e-02, -1.464306265115737915e-01, -1.264452338218688965e-01, 3.112861812114715576e-01, -1.518262922763824463e-01, 1.009387709200382233e-02, 4.910853505134582520e-01, 1.436241269111633301e-01, 1.158903315663337708e-01, -2.323697209358215332e-01, 2.475174367427825928e-01, 1.836448311805725098e-01, -2.483685612678527832e-01, -1.122094392776489258e-01, -2.311335057020187378e-01, 8.428957313299179077e-02, -2.437865734100341797e-01, 1.330728679895401001e-01, 
4.235569238662719727e-01, 3.334836065769195557e-01, -3.437012135982513428e-01, 3.443647548556327820e-02, 1.879549026489257812e-01, 2.003718763589859009e-01, -5.355946719646453857e-02, 2.848529219627380371e-01, 7.176562398672103882e-02, 5.487144365906715393e-02, -8.103788644075393677e-02, 2.707686424255371094e-01, 1.170023232698440552e-01], 21 | [5.897930860519409180e-01, -2.359832376241683960e-01, -2.541170120239257812e-01, 3.116633975878357887e-03, -8.485706895589828491e-02, -2.679976820945739746e-01, -7.506710290908813477e-02, -3.002136647701263428e-01, 5.151664093136787415e-02, 1.658532172441482544e-01, 2.607674896717071533e-01, 3.825633525848388672e-01, 4.373287260532379150e-01, -9.301974624395370483e-02, -2.656879723072052002e-01, -9.716296195983886719e-02, -4.809605777263641357e-01, 1.187827885150909424e-01, 1.367550343275070190e-01, 4.712080582976341248e-02, -2.369651645421981812e-01, -5.233234763145446777e-01, -1.631881482899188995e-02, 6.127280369400978088e-02, -7.433295845985412598e-01, -1.189892366528511047e-01, -7.886531949043273926e-01, -4.810884296894073486e-01, 1.031494066119194031e-01, -3.237242400646209717e-01, 8.144375681877136230e-01, -3.977453410625457764e-01, -5.031558275222778320e-01, -7.972458004951477051e-01, -6.324826478958129883e-01, 3.232096731662750244e-01, -3.841938972473144531e-01, -1.118668839335441589e-01, -1.324360221624374390e-01, 2.069724537432193756e-02, -1.430956125259399414e-01, -3.701156005263328552e-02, 6.116622313857078552e-02, 1.633288562297821045e-01, -1.117428839206695557e-01, 2.523421943187713623e-01, -1.046407103538513184e+00, -3.725237548351287842e-01, 1.560198068618774414e-01, -2.999159693717956543e-01, 1.988387256860733032e-01, 2.343344241380691528e-01, -3.702580928802490234e-01, 3.173359930515289307e-01, 8.442859649658203125e-01, 6.977686285972595215e-02, 3.273648396134376526e-02, 9.948327392339706421e-02, -3.114135563373565674e-01, 5.051773190498352051e-01, 3.092621685937047005e-03, 3.801366388797760010e-01, 4.582764208316802979e-02, 6.333882454782724380e-03, -1.429482013918459415e-03, -1.356867849826812744e-01, -7.611397653818130493e-02, -2.584426701068878174e-01, -8.022130131721496582e-01, 5.508586764335632324e-01, -9.124376624822616577e-02, -2.178205102682113647e-01, -7.881092429161071777e-01, -5.118381381034851074e-01, 4.667254686355590820e-01, 5.527477860450744629e-01, -3.712476193904876709e-01, -1.864536553621292114e-01, 3.585702180862426758e-01, -1.958634704351425171e-01, 1.804253458976745605e-01, -4.254889488220214844e-01, -9.681382030248641968e-02, -5.536802113056182861e-02, 5.248928666114807129e-01, 2.448114603757858276e-01, 1.934617571532726288e-02, -2.963794171810150146e-01, -1.277786940336227417e-01, -3.053497672080993652e-01, 4.534939825534820557e-01, 7.469131797552108765e-02, -7.061695307493209839e-02, 2.624299228191375732e-01, 3.738394975662231445e-01, 1.430637687444686890e-01, 1.278782845474779606e-03, -4.177604615688323975e-01, -2.401406615972518921e-01, -2.509351670742034912e-01, 3.484380245208740234e-01, 3.114404380321502686e-01, 8.087333291769027710e-02, -5.764053463935852051e-01, 5.408530831336975098e-01, -1.802206784486770630e-02, -1.295981407165527344e-01, -7.399659603834152222e-02, 3.936978280544281006e-01, 6.488384604454040527e-01, -2.029980346560478210e-02, -5.665556192398071289e-01, 2.967598140239715576e-01, 5.200023055076599121e-01, 2.153875231742858887e-01, 1.036966815590858459e-01, 6.199258565902709961e-02, 1.896279491484165192e-02, -1.526914983987808228e-01, -1.064266324043273926e+00, 
7.614958882331848145e-01, 2.073440998792648315e-01, 4.471894502639770508e-01, 1.449393630027770996e-01, 6.580228805541992188e-01, -9.440919756889343262e-02, -2.331637144088745117e-01, 4.215706884860992432e-01, 1.195765361189842224e-01, -3.257109224796295166e-01, 1.642551422119140625e-01, -4.950869977474212646e-01, -1.951612085103988647e-01, -5.618322491645812988e-01, -1.493323594331741333e-01, 6.109409928321838379e-01, -1.789793968200683594e-01, -1.805550791323184967e-02, -5.964048504829406738e-01, 4.918630048632621765e-02, 1.534783542156219482e-01, -4.282945692539215088e-01, 7.329528927803039551e-01, -3.529110848903656006e-01, -1.115965247154235840e-01, 6.127825006842613220e-02, -2.970442771911621094e-01, 4.396659433841705322e-01, -9.660355001688003540e-02, 6.557945609092712402e-01, -6.140334606170654297e-01, 2.576654590666294098e-02, 4.382747411727905273e-01, 1.733195222914218903e-02, -4.000226557254791260e-01, -8.178367465734481812e-02, -3.712696731090545654e-01, 8.230254054069519043e-02, -1.310442984104156494e-01, -5.326109528541564941e-01, -2.992832958698272705e-01, 6.993656754493713379e-01, -4.398765042424201965e-02, -1.570302397012710571e-01, 9.794142097234725952e-02, -3.017469309270381927e-02, -1.000270023941993713e-01, 1.999655365943908691e-01, -4.818853437900543213e-01, 1.794915199279785156e-01, 5.656598806381225586e-01, -1.195481792092323303e-01, -6.963727474212646484e-01, 5.259707570075988770e-02, -5.496182013303041458e-03, 1.673939079046249390e-01, -3.169286251068115234e-01, -9.747564792633056641e-02, 3.319365680217742920e-01, 4.719963073730468750e-01, 1.265397518873214722e-01, 1.913098245859146118e-01, 4.294907152652740479e-01, 5.529124140739440918e-01, 3.146334886550903320e-01, -3.143309056758880615e-01, -4.150866568088531494e-01, 3.289771378040313721e-01, 3.570270836353302002e-01, -1.920964270830154419e-01, 2.223942279815673828e-01, -4.871788918972015381e-01, 22 | 3.409155905246734619e-01, -2.213744521141052246e-01, -1.266756504774093628e-01, 2.112082839012145996e-01, -3.134789764881134033e-01, 8.468940854072570801e-01, 2.011267691850662231e-01, -4.259879291057586670e-01, 5.131570100784301758e-01, -1.235141396522521973e+00, 7.697179317474365234e-01, -1.741422861814498901e-01, -2.181101031601428986e-02, -3.568635880947113037e-02, -1.105949044227600098e+00, -5.720656514167785645e-01, 5.585205927491188049e-02, 1.246149912476539612e-01, -4.506591260433197021e-01, 6.429003924131393433e-02, -1.603386402130126953e-01, 3.993293344974517822e-01, -1.032289788126945496e-01, -2.025506459176540375e-02, -1.801048517227172852e-01, 6.234775856137275696e-02, -2.188893593847751617e-02, -1.579539030790328979e-01, 2.831696271896362305e-01, 2.385283075273036957e-02, 3.098120354115962982e-02, -7.853289693593978882e-02, 2.989652752876281738e-01, -6.237306818366050720e-02, 5.498673915863037109e-01, 1.786233633756637573e-01, 2.116472125053405762e-01, 4.448334872722625732e-01, 4.890750348567962646e-02, -1.623808294534683228e-01, -2.266986370086669922e-01, 1.887196898460388184e-01, 7.943605631589889526e-02, 1.359758377075195312e-01, -1.848446130752563477e-01, 1.113551020622253418e+00, 8.280951976776123047e-01, -3.120273053646087646e-01, 9.506001323461532593e-02, 5.096073076128959656e-02, 3.880488872528076172e-01, 2.500048279762268066e-01, 5.584861040115356445e-01, 3.108873963356018066e-01, -5.318577215075492859e-02, -7.675344496965408325e-02, 1.528232544660568237e-01, 9.189971536397933960e-02, -1.429156679660081863e-02, 6.657540202140808105e-01, -3.346028923988342285e-02, -4.470352232456207275e-01, 
8.006748557090759277e-01, -4.799281060695648193e-01, 1.747820973396301270e-01, -3.056386411190032959e-01, 5.536521077156066895e-01, 4.238092899322509766e-01, 4.867432117462158203e-01, -4.967796802520751953e-01, -4.519478082656860352e-01, -9.556307792663574219e-01, -2.070993185043334961e-01, -2.260573953390121460e-01, -9.991831146180629730e-03, 9.879770278930664062e-01, 5.880774259567260742e-01, 8.305435627698898315e-02, -5.578136444091796875e-01, 2.113683819770812988e-01, -3.607222735881805420e-01, 5.266848206520080566e-01, 3.398357927799224854e-01, -1.575620472431182861e-01, 4.237726330757141113e-03, -5.354529991745948792e-02, -5.777673125267028809e-01, 5.595101118087768555e-01, -5.747127532958984375e-02, 1.683770418167114258e-01, 3.794685304164886475e-01, -2.577640712261199951e-01, 8.421474695205688477e-02, -1.522992998361587524e-01, -3.280776739120483398e-02, 1.008386313915252686e-01, -4.185833036899566650e-01, -4.449901580810546875e-01, -2.930993139743804932e-01, 6.144204735755920410e-01, 8.548156172037124634e-02, -6.349527090787887573e-02, -6.152555346488952637e-01, 7.954411506652832031e-01, -2.405838221311569214e-01, 2.063887566328048706e-01, -5.125258564949035645e-01, 6.312013268470764160e-01, 3.674431145191192627e-01, -4.400988817214965820e-01, 4.691397249698638916e-01, 2.308772653341293335e-01, -1.373799592256546021e-01, 2.169690132141113281e-01, 4.004325866699218750e-01, -2.490628510713577271e-02, -1.139675617218017578e+00, 2.653856761753559113e-02, -3.273023366928100586e-01, 9.984182566404342651e-02, 5.725682899355888367e-02, -8.472219109535217285e-01, 6.451983749866485596e-02, 4.569802582263946533e-01, 6.356298327445983887e-01, 4.518562257289886475e-01, -2.751905620098114014e-01, 2.134616822004318237e-01, 1.737425774335861206e-01, 4.282205998897552490e-01, -6.584535241127014160e-01, 4.000254571437835693e-01, -2.035577781498432159e-02, -6.730788350105285645e-01, -1.026923418045043945e+00, 1.687724739313125610e-01, -9.248701483011245728e-02, -7.997762560844421387e-01, 3.809339106082916260e-01, 5.171234011650085449e-01, 4.200939834117889404e-02, -4.867553710937500000e-02, -1.877224296331405640e-01, 1.633953303098678589e-01, -2.197492122650146484e-01, 2.193927615880966187e-01, 3.676529601216316223e-02, -2.975027859210968018e-01, -3.740966320037841797e-01, -5.209508538246154785e-01, -4.131466150283813477e-01, -4.894771575927734375e-01, -8.189660906791687012e-01, 8.531483262777328491e-02, 3.457696437835693359e-01, 1.250599324703216553e-01, 2.494525909423828125e-01, -2.525470256805419922e-01, -3.156108409166336060e-02, 2.757310569286346436e-01, -6.085720658302307129e-01, 3.357000648975372314e-01, 2.291317582130432129e-01, 6.607081294059753418e-01, -3.021580278873443604e-01, -5.315314233303070068e-02, 2.224752455949783325e-01, 6.138700246810913086e-02, 3.355513811111450195e-01, -8.485180139541625977e-02, 8.764573931694030762e-02, 1.087202429771423340e-01, -4.038927853107452393e-01, -1.494978666305541992e-01, 1.945850253105163574e-01, -8.106063008308410645e-01, 7.973096370697021484e-01, -4.116255939006805420e-01, 1.364154089242219925e-02, 2.347293049097061157e-01, -9.732253104448318481e-02, -2.904408872127532959e-01, 3.843209147453308105e-02, -7.090485841035842896e-02, -1.740449666976928711e-01, -4.485938549041748047e-01, -3.186723887920379639e-01, 4.165604412555694580e-01, -5.431653931736946106e-02, 1.403615623712539673e-01, 1.055916428565979004e+00, 5.301813483238220215e-01] 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- 
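The fixture above pins the expected embeddings as raw float values, and the test that follows compares them against the live pipeline output with `np.allclose` and an absolute tolerance. If exact values ever drift across torch or hardware versions, a direction-based check is a looser alternative; the helper below is only a sketch (the helper name and the `0.999` threshold are illustrative, not part of this repo):

```
import numpy as np


def assert_embeddings_close(expected, actual, min_cosine=0.999):
    """Assert that each expected/actual embedding pair points in (almost) the same direction."""
    for exp, act in zip(expected, actual):
        exp = np.asarray(exp, dtype=float)
        act = np.asarray(act, dtype=float)
        # Cosine similarity is robust to small elementwise drift in the raw floats.
        cosine = float(exp @ act / (np.linalg.norm(exp) * np.linalg.norm(act)))
        assert cosine >= min_cosine, f"cosine similarity {cosine:.4f} is below {min_cosine}"
```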
/sentence-encoding/tests/test_sentence_encoder.py: -------------------------------------------------------------------------------- 1 | from src.sentence_encoder import SentenceEncoder 2 | import numpy as np 3 | 4 | pipeline = SentenceEncoder() 5 | 6 | def test_response(requests, response): 7 | assert np.allclose(response['predictions'], pipeline(requests)['predictions'], atol=1e-3) 8 | -------------------------------------------------------------------------------- /text-generation/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .aws-sam 132 | *.pyc 133 | .vscode 134 | .DS_store 135 | **.bin 136 | **.ipynb_checkpoints -------------------------------------------------------------------------------- /text-generation/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM amazon/aws-lambda-python 2 | 3 | ARG MODEL_DIR=./models 4 | 5 | ENV TRANSFORMERS_CACHE=$MODEL_DIR 6 | ENV TRANSFORMERS_VERBOSITY=error 7 | 8 | RUN yum -y install gcc-c++ 9 | 10 | COPY requirements.txt requirements.txt 11 | RUN pip install torch==1.8+cpu -f https://download.pytorch.org/whl/torch_stable.html --no-cache-dir 12 | RUN pip install -r requirements.txt --no-cache-dir 13 | 14 | COPY ./ ./ 15 | 16 | # Run test cases and this saves the transformer model in the container 17 | RUN pip install pytest --no-cache-dir && pytest tests -s -vv 18 | 19 | RUN chmod -R 0777 $MODEL_DIR 20 | 21 | CMD [ "main.lambda_handler"] -------------------------------------------------------------------------------- /text-generation/README.MD: -------------------------------------------------------------------------------- 1 | ## Text Generation service 2 | 3 | Text Generation using Transformers on AWS Lambda. Check root readme for complete setup info. 4 | 5 | ## Request format 6 | 7 | ``` 8 | { 9 | "texts": ["India is", "AI will"], 10 | "model_name": "distilgpt2", (optional argument) 11 | "tokenizer_name": "distilgpt2", (optional argument) 12 | "max_len": 10, # maximum no. of tokens (words) to be generated using the given context (optional argument) 13 | "num_return_sequences": 1 # no. of sequences (sentences) to be generated using the given context (optional argument) 14 | } 15 | ``` 16 | 17 | ## Response format 18 | 19 | ``` 20 | { 21 | 'predictions': [ 22 | [{'generated_text': 'India is a great country for international investors.
It also has the support of'}], 23 | [{'generated_text': 'AI will rule out that she cannot be allowed to wear a hijab but will'}] 24 | ] 25 | } 26 | ``` -------------------------------------------------------------------------------- /text-generation/main.py: -------------------------------------------------------------------------------- 1 | from src.text_generator import TextGenerator 2 | 3 | pipeline = TextGenerator() 4 | 5 | 6 | def lambda_handler(event, context): 7 | try: 8 | return pipeline(event) 9 | except Exception as e: 10 | raise 11 | -------------------------------------------------------------------------------- /text-generation/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.* 2 | tqdm==4.* 3 | scikit-learn==0.24.* -------------------------------------------------------------------------------- /text-generation/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/text-generation/src/__init__.py -------------------------------------------------------------------------------- /text-generation/src/config.py: -------------------------------------------------------------------------------- 1 | PREDICTION_TYPE = 'text_generation' 2 | 3 | DEFAULT_MODEL_NAME = "distilgpt2" 4 | DEFAULT_TOKENIZER_NAME = "distilgpt2" 5 | DEFAULT_MAX_LEN = 10 6 | DEFAULT_NUM_SEQ = 1 7 | # cache 8 | CACHE_MAXSIZE = 4 9 | -------------------------------------------------------------------------------- /text-generation/src/text_generator.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | 4 | from functools import lru_cache 5 | 6 | from transformers import (AutoConfig, AutoModelForCausalLM, 7 | AutoTokenizer, pipeline, set_seed) 8 | 9 | from src import config, utils 10 | 11 | set_seed(10) 12 | logger = utils.create_logger(project_name=config.PREDICTION_TYPE, level="INFO") 13 | 14 | class TextGenerator: 15 | def __init__(self): 16 | _ = self.get_text_generator(model_name=config.DEFAULT_MODEL_NAME, tokenizer_name=config.DEFAULT_TOKENIZER_NAME) #warm up 17 | 18 | @staticmethod 19 | @lru_cache(maxsize=config.CACHE_MAXSIZE) 20 | def get_text_generator(model_name: str, tokenizer_name: str) -> pipeline: 21 | """Text generation pipeline for the given model and tokenizer 22 | 23 | Args: 24 | model_name (str): Indicating the name of the model 25 | tokenizer_name (str): Indicating the name of the tokenizer 26 | 27 | Returns: 28 | pipeline: text generation pipeline 29 | """ 30 | logger.info(f"Loading model: {model_name}") 31 | 32 | model_config = AutoConfig.from_pretrained(model_name) 33 | model = AutoModelForCausalLM.from_pretrained( 34 | model_name, config=model_config 35 | ) 36 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 37 | 38 | set_seed(10) # for reproducibility of results 39 | 40 | text_generator = pipeline( 41 | "text-generation", model=model, tokenizer=tokenizer 42 | ) 43 | return text_generator 44 | 45 | def get_clean_text(self, text: str) -> str: 46 | """Clean the text 47 | 48 | Args: 49 | text (str): text 50 | 51 | Returns: 52 | str: clean text 53 | """ 54 | return text.strip() 55 | 56 | def __call__(self, request: dict) -> dict: 57 | """Generate text for the given sentences 58 | 59 | Args: 60 | request (dict): request
containing the list of sentences for text generation 61 | 62 | Returns: 63 | dict: generated text for the given sentences 64 | """ 65 | texts = [self.get_clean_text(text) for text in request["texts"]] 66 | 67 | model_name = request.get("model_name", config.DEFAULT_MODEL_NAME) 68 | tokenizer_name = request.get("tokenizer_name", config.DEFAULT_TOKENIZER_NAME) 69 | 70 | logger.info(f"Generating text for {len(texts)} sentences") 71 | 72 | text_generator = self.get_text_generator(model_name, tokenizer_name) 73 | 74 | max_len = request.get("max_len", config.DEFAULT_MAX_LEN) 75 | num_seq = request.get("num_return_sequences", config.DEFAULT_NUM_SEQ) 76 | generated_text = text_generator(texts, max_length=max_len, num_return_sequences=num_seq) 77 | 78 | return { 79 | "predictions": generated_text 80 | } 81 | 82 | 83 | -------------------------------------------------------------------------------- /text-generation/src/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import ntpath 4 | import os 5 | from typing import Optional 6 | 7 | 8 | def create_folder(directory): 9 | if not os.path.exists(directory): 10 | os.makedirs(directory) 11 | print("Directory created: " + directory) 12 | else: 13 | print("Directory exists: " + directory) 14 | 15 | 16 | def create_logger( 17 | project_name: str, 18 | level: str = "INFO", 19 | log_dir: str = "/tmp/logs", 20 | file_name: Optional[str] = None, 21 | do_print: bool = True, 22 | simple_logging: bool = False, 23 | log_to_file: bool = False, 24 | rich_logging: bool = False, 25 | time_zone: Optional[str] = None, 26 | ): 27 | """Creates a logger of given level and saves logs to a file 28 | 29 | :param project_name: project name for which we are logging 30 | :param level: logging level 31 | LEVELS available 32 | DEBUG: Detailed information, typically of interest only when diagnosing problems. 33 | INFO: Confirmation that things are working as expected. 34 | WARNING: An indication that something unexpected happened, or indicative of some problem in the near future (e.g. 'disk space low'). The software is still working as expected. 35 | ERROR: Due to a more serious problem, the software has not been able to perform some function. 36 | CRITICAL: A serious error, indicating that the program itself may be unable to continue running.
37 | :param log_dir: directory when log files are created 38 | :param file_name: name of the log file 39 | :param do_print: whether to print the logs 40 | :param simple_logging: sets formatter to only message 41 | :param log_to_file: whether to save logs on disk 42 | :param rich_logging: colorful logging using rich 43 | :param time_zone: timezone to be used for time in logging such as Asia/Kolkata 44 | https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568 45 | """ 46 | import __main__ 47 | 48 | if file_name is None: 49 | try: 50 | file_name = ntpath.basename(__main__.__file__).split(".")[0] 51 | except: 52 | file_name = "logs" 53 | 54 | logger = logging.getLogger(file_name) 55 | logger.handlers.clear() 56 | logger.setLevel(getattr(logging, level)) 57 | 58 | if time_zone: 59 | from pytz import timezone, utc 60 | def time_formatter(*args): 61 | # TODO: Doesnt work with rich formatter 62 | utc_dt = utc.localize(datetime.datetime.utcnow()) 63 | my_tz = timezone(time_zone) 64 | converted = utc_dt.astimezone(my_tz) 65 | return converted.timetuple() 66 | 67 | logging.Formatter.converter = time_formatter 68 | 69 | if rich_logging: 70 | from rich.logging import RichHandler 71 | stream_format = f"{project_name}:%(module)s:%(funcName)s: %(message)s" 72 | stream_handler = RichHandler(omit_repeated_times=False) 73 | else: 74 | stream_format = f"%(asctime)s:%(levelname)s:{project_name}:%(module)s:%(funcName)s: %(message)s" 75 | stream_handler = logging.StreamHandler() 76 | 77 | file_formatter = stream_formatter = logging.Formatter( 78 | stream_format, "%Y-%m-%d %H:%M:%S" 79 | ) 80 | 81 | if simple_logging: 82 | file_formatter = logging.Formatter("%(message)s") 83 | stream_formatter = logging.Formatter("%(message)s") 84 | 85 | if log_to_file: 86 | date = datetime.date.today() 87 | date = "%s-%s-%s" % (date.day, date.month, date.year) 88 | log_file_path = os.path.join(log_dir, "%s-%s.log" % (file_name, date)) 89 | 90 | create_folder(log_dir) 91 | file_handler = logging.FileHandler(log_file_path) 92 | file_handler.setFormatter(file_formatter) 93 | logger.addHandler(file_handler) 94 | 95 | if do_print: 96 | stream_handler.setFormatter(stream_formatter) 97 | logger.addHandler(stream_handler) 98 | 99 | logger.propagate = False 100 | 101 | return logger 102 | -------------------------------------------------------------------------------- /text-generation/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/text-generation/tests/__init__.py -------------------------------------------------------------------------------- /text-generation/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src import config 3 | 4 | 5 | @pytest.fixture 6 | def requests(): 7 | return { 8 | "texts": ["India is a great", "AI will rule"], 9 | "model_name": "distilgpt2", # optional 10 | "tokenizer_name": "distilgpt2", # optional 11 | # maximum no. of token(words) to be genrated using the given context optional 12 | "max_len": 15, 13 | # no. of sequences(sentences) to be genrated using the given context optional 14 | "num_return_sequences": 1 15 | } 16 | 17 | 18 | @pytest.fixture 19 | def response(): 20 | return { 21 | 'predictions': [ 22 | [{'generated_text': 'India is a great country for international investors. 
It also has the support of'}], 23 | [{'generated_text': 'AI will rule out that she cannot be allowed to wear a hijab but will'}] 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /text-generation/tests/test_text_generator.py: -------------------------------------------------------------------------------- 1 | from src.text_generator import TextGenerator 2 | 3 | pipeline = TextGenerator() 4 | 5 | def test_response(requests, response): 6 | assert response == pipeline(requests) 7 | -------------------------------------------------------------------------------- /token-classification/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .aws-sam 132 | *.pyc 133 | .vscode 134 | .DS_store 135 | **.bin 136 | **.ipynb_checkpoints -------------------------------------------------------------------------------- /token-classification/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM amazon/aws-lambda-python 2 | 3 | ARG MODEL_DIR=./models 4 | 5 | ENV TRANSFORMERS_CACHE=$MODEL_DIR 6 | ENV TRANSFORMERS_VERBOSITY=error 7 | 8 | RUN yum -y install gcc-c++ 9 | 10 | COPY requirements.txt requirements.txt 11 | RUN pip install torch==1.8+cpu -f https://download.pytorch.org/whl/torch_stable.html --no-cache-dir 12 | RUN pip install -r requirements.txt --no-cache-dir 13 | 14 | COPY ./ ./ 15 | 16 | # Run test cases and this saves the transformer model in the container 17 | RUN pip install pytest --no-cache-dir && pytest tests -s -vv 18 | 19 | RUN chmod -R 0777 $MODEL_DIR 20 | 21 | CMD [ "main.lambda_handler"] -------------------------------------------------------------------------------- /token-classification/README.MD: -------------------------------------------------------------------------------- 1 | ## Token Classification service 2 | 3 | Token Classification using Transformers on AWS Lambda. Check root readme for complete setup info. 
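Once the image is deployed (the CI workflow pushes it to ECR and updates the Lambda function in `us-east-1`), the service can be called with the AWS SDK using the request format described below. The snippet is a rough sketch: the function name `token-classification` follows the repo's naming convention but is an assumption, so match it to your own deployment.

```
import json

import boto3  # assumes AWS credentials are configured locally

client = boto3.client("lambda", region_name="us-east-1")

payload = {"texts": ["Mark is going back to Germany from South Africa"]}

result = client.invoke(
    FunctionName="token-classification",  # assumed name; use your deployed function
    Payload=json.dumps(payload).encode("utf-8"),
)
print(json.loads(result["Payload"].read()))
```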
4 | 5 | ## Request format 6 | 7 | ``` 8 | { 9 | "texts": ["Mark is going back to Germany from South Africa", "John Adams is performing live in Venezuela"], 10 | "model_name": "dslim/bert-base-NER", # optional 11 | "tokenizer_name": "dslim/bert-base-NER" # optional 12 | } 13 | ``` 14 | 15 | ## Response format 16 | 17 | ``` 18 | { 19 | 'predictions':[ 20 | [ 21 | {'word': 'Mark', 'score': 1.0, 'entity': 'B-PER', 'index': 1, 'start': 0, 'end': 4 22 | }, 23 | {'word': 'Germany', 'score': 1.0, 'entity': 'B-LOC', 'index': 6, 'start': 22, 'end': 29 24 | }, 25 | {'word': 'South', 'score': 1.0, 'entity': 'B-LOC', 'index': 8, 'start': 35, 'end': 40 26 | }, 27 | {'word': 'Africa', 'score': 1.0, 'entity': 'I-LOC', 'index': 9, 'start': 41, 'end': 47 28 | } 29 | ], 30 | [ 31 | {'word': 'John', 'score': 1.0, 'entity': 'B-PER', 'index': 1, 'start': 0, 'end': 4 32 | }, 33 | {'word': 'Adams', 'score': 1.0, 'entity': 'I-PER', 'index': 2, 'start': 5, 'end': 10 34 | }, 35 | {'word': 'Venezuela', 'score': 1.0, 'entity': 'B-LOC', 'index': 7, 'start': 33, 'end': 42 36 | } 37 | ] 38 | ] 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /token-classification/main.py: -------------------------------------------------------------------------------- 1 | from src.token_classifier import TokenClassifier 2 | 3 | pipeline = TokenClassifier() 4 | 5 | 6 | def lambda_handler(event, context): 7 | try: 8 | return pipeline(event) 9 | except Exception as e: 10 | raise 11 | -------------------------------------------------------------------------------- /token-classification/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.* 2 | tqdm==4.* 3 | scikit-learn==0.24.* -------------------------------------------------------------------------------- /token-classification/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/token-classification/src/__init__.py -------------------------------------------------------------------------------- /token-classification/src/config.py: -------------------------------------------------------------------------------- 1 | PREDICTION_TYPE = 'token_classification' 2 | 3 | DEFAULT_MODEL_NAME = "dslim/bert-base-NER" 4 | DEFAULT_TOKENIZER_NAME = "dslim/bert-base-NER" 5 | ID_TAG_MAPPING = { # add for all models to be supported 6 | "dslim/bert-base-NER": { 7 | 0: "O", 8 | 1: "B-MISC", 9 | 2: "I-MISC", 10 | 3: "B-PER", 11 | 4: "I-PER", 12 | 5: "B-ORG", 13 | 6: "I-ORG", 14 | 7: "B-LOC", 15 | 8: "I-LOC" 16 | }, 17 | } 18 | # cache 19 | CACHE_MAXSIZE = 4 20 | -------------------------------------------------------------------------------- /token-classification/src/token_classifier.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | from functools import lru_cache 6 | 7 | from transformers import (AutoConfig, AutoModelForTokenClassification, 8 | AutoTokenizer, pipeline) 9 | 10 | from src import config, utils 11 | 12 | logger = utils.create_logger(project_name=config.PREDICTION_TYPE, level="INFO") 13 | 14 | 15 | class TokenClassifier: 16 | def __init__(self): 17 | _ = self.get_ner_pipeline(model_name=config.DEFAULT_MODEL_NAME, 18 | tokenizer_name=config.DEFAULT_TOKENIZER_NAME) # warm up 19
| 20 | @staticmethod 21 | @lru_cache(maxsize=config.CACHE_MAXSIZE) 22 | def get_ner_pipeline(model_name: str, tokenizer_name: str) -> pipeline: 23 | """NER pipeline for the given model and tokenizer 24 | 25 | Args: 26 | model_name (str): Indicating the name of the model 27 | tokenizer_name (str): Indicating the name of the tokenizer 28 | 29 | Returns: 30 | pipeline: ner pipeline 31 | """ 32 | logger.info(f"Loading model: {model_name}") 33 | id2label = config.ID_TAG_MAPPING[model_name] 34 | label2id = {label: idx for idx, label in id2label.items()} 35 | 36 | model_config = AutoConfig.from_pretrained(model_name) 37 | model_config.label2id = label2id 38 | model_config.id2label = id2label 39 | model = AutoModelForTokenClassification.from_pretrained( 40 | model_name, config=model_config 41 | ) 42 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 43 | ner_pipeline = pipeline( 44 | "ner", model=model, tokenizer=tokenizer 45 | ) 46 | return ner_pipeline 47 | 48 | def get_clean_text(self, text: str) -> str: 49 | """Clean the text 50 | 51 | Args: 52 | text (str): text 53 | 54 | Returns: 55 | str: clean text 56 | """ 57 | return text.strip() 58 | 59 | def __call__(self, request: dict) -> dict: 60 | """Predict tags of the given tokens 61 | 62 | Args: 63 | request (dict): request containing the list of text to predict entities 64 | 65 | Returns: 66 | dict: classes of the given text 67 | """ 68 | texts = [self.get_clean_text(text) for text in request["texts"]] 69 | model_name = request.get("model_name", config.DEFAULT_MODEL_NAME) 70 | tokenizer_name = request.get("tokenizer_name", config.DEFAULT_TOKENIZER_NAME) 71 | 72 | logger.info(f"Predicting tags for {len(texts)} texts") 73 | ner_pipeline = self.get_ner_pipeline(model_name, tokenizer_name) 74 | 75 | predictions = ner_pipeline(texts) 76 | 77 | for i, pred in enumerate(predictions): 78 | for dct in pred: 79 | dct["score"] = round(dct["score"], 2) 80 | predictions[i] = pred 81 | 82 | return { 83 | "predictions": predictions 84 | } 85 | -------------------------------------------------------------------------------- /token-classification/src/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import ntpath 4 | import os 5 | from typing import Optional 6 | 7 | 8 | def create_folder(directory): 9 | if not os.path.exists(directory): 10 | os.makedirs(directory) 11 | print("Directory created: " + directory) 12 | else: 13 | print("Directory exists: " + directory) 14 | 15 | 16 | def create_logger( 17 | project_name: str, 18 | level: str = "INFO", 19 | log_dir: str = "/tmp/logs", 20 | file_name: Optional[str] = None, 21 | do_print: bool = True, 22 | simple_logging: bool = False, 23 | log_to_file: bool = False, 24 | rich_logging: bool = False, 25 | time_zone: Optional[str] = None, 26 | ): 27 | """Creates a logger of given level and saves logs to a file 28 | 29 | :param project_name: project name for which we are logging 30 | :param level: logging level 31 | LEVELS available 32 | DEBUG: Detailed information, typically of interest only when diagnosing problems. 33 | INFO: Confirmation that things are working as expected. 34 | WARNING: An indication that something unexpected happened, or indicative of some problem in the near future (e.g. 'disk space low'). The software is still working as expected. 35 | ERROR: Due to a more serious problem, the software has not been able to perform some function. 
36 | CRITICAL: A serious error, indicating that the program itself may be unable to continue running. 37 | :param log_dir: directory when log files are created 38 | :param file_name: name of the log file 39 | :param do_print: whether to print the logs 40 | :param simple_logging: sets formatter to only message 41 | :param log_to_file: whether to save logs on disk 42 | :param rich_logging: colorful logging using rich 43 | :param time_zone: timezone to be used for time in logging such as Asia/Kolkata 44 | https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568 45 | """ 46 | import __main__ 47 | 48 | if file_name is None: 49 | try: 50 | file_name = ntpath.basename(__main__.__file__).split(".")[0] 51 | except: 52 | file_name = "logs" 53 | 54 | logger = logging.getLogger(file_name) 55 | logger.handlers.clear() 56 | logger.setLevel(getattr(logging, level)) 57 | 58 | if time_zone: 59 | from pytz import timezone, utc 60 | def time_formatter(*args): 61 | # TODO: Doesnt work with rich formatter 62 | utc_dt = utc.localize(datetime.datetime.utcnow()) 63 | my_tz = timezone(time_zone) 64 | converted = utc_dt.astimezone(my_tz) 65 | return converted.timetuple() 66 | 67 | logging.Formatter.converter = time_formatter 68 | 69 | if rich_logging: 70 | from rich.logging import RichHandler 71 | stream_format = f"{project_name}:%(module)s:%(funcName)s: %(message)s" 72 | stream_handler = RichHandler(omit_repeated_times=False) 73 | else: 74 | stream_format = f"%(asctime)s:%(levelname)s:{project_name}:%(module)s:%(funcName)s: %(message)s" 75 | stream_handler = logging.StreamHandler() 76 | 77 | file_formatter = stream_formatter = logging.Formatter( 78 | stream_format, "%Y-%m-%d %H:%M:%S" 79 | ) 80 | 81 | if simple_logging: 82 | file_formatter = logging.Formatter("%(message)s") 83 | stream_formatter = logging.Formatter("%(message)s") 84 | 85 | if log_to_file: 86 | date = datetime.date.today() 87 | date = "%s-%s-%s" % (date.day, date.month, date.year) 88 | log_file_path = os.path.join(log_dir, "%s-%s.log" % (file_name, date)) 89 | 90 | create_folder(log_dir) 91 | file_handler = logging.FileHandler(log_file_path) 92 | file_handler.setFormatter(file_formatter) 93 | logger.addHandler(file_handler) 94 | 95 | if do_print: 96 | stream_handler.setFormatter(stream_formatter) 97 | logger.addHandler(stream_handler) 98 | 99 | logger.propagate = False 100 | 101 | return logger 102 | -------------------------------------------------------------------------------- /token-classification/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/token-classification/tests/__init__.py -------------------------------------------------------------------------------- /token-classification/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src import config 3 | 4 | 5 | @pytest.fixture 6 | def requests(): 7 | return { 8 | "texts": ["Mark is going back to Germany from South Africa", "John Adams is performing live in Venezuela"], 9 | "model_name": config.DEFAULT_MODEL_NAME, 10 | "tokenizer_name": config.DEFAULT_TOKENIZER_NAME 11 | } 12 | 13 | @pytest.fixture 14 | def requests_default(): 15 | return { 16 | "texts": ["Mark is going back to Germany from South Africa", "John Adams is performing live in Venezuela"], 17 | } 18 | 19 | 20 | @pytest.fixture 21 | def response(): 22 | return { 23 | 
'predictions': [ 24 | [ 25 | {'word': 'Mark', 'score': 1.0, 'entity': 'B-PER', 'index': 1, 'start': 0, 'end': 4 26 | }, 27 | {'word': 'Germany', 'score': 1.0, 'entity': 'B-LOC', 'index': 6, 'start': 22, 'end': 29 28 | }, 29 | {'word': 'South', 'score': 1.0, 'entity': 'B-LOC', 'index': 8, 'start': 35, 'end': 40 30 | }, 31 | {'word': 'Africa', 'score': 1.0, 'entity': 'I-LOC', 'index': 9, 'start': 41, 'end': 47 32 | } 33 | ], 34 | [ 35 | {'word': 'John', 'score': 1.0, 'entity': 'B-PER', 'index': 1, 'start': 0, 'end': 4 36 | }, 37 | {'word': 'Adams', 'score': 1.0, 'entity': 'I-PER', 'index': 2, 'start': 5, 'end': 10 38 | }, 39 | {'word': 'Venezuela', 'score': 1.0, 'entity': 'B-LOC', 'index': 7, 'start': 33, 'end': 42 40 | } 41 | ] 42 | ] 43 | } 44 | -------------------------------------------------------------------------------- /token-classification/tests/test_token_classifier.py: -------------------------------------------------------------------------------- 1 | from src.token_classifier import TokenClassifier 2 | 3 | pipeline = TokenClassifier() 4 | 5 | def test_response(requests, response): 6 | assert response == pipeline(requests) 7 | 8 | def test_response_default(requests_default, response): 9 | assert response == pipeline(requests_default) 10 | -------------------------------------------------------------------------------- /zero-shot-classification/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .aws-sam 132 | *.pyc 133 | .vscode 134 | .DS_store 135 | **.bin 136 | **.ipynb_checkpoints -------------------------------------------------------------------------------- /zero-shot-classification/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM amazon/aws-lambda-python 2 | 3 | ARG MODEL_DIR=./models 4 | 5 | ENV TRANSFORMERS_CACHE=$MODEL_DIR 6 | ENV TRANSFORMERS_VERBOSITY=error 7 | 8 | RUN yum -y install gcc-c++ 9 | 10 | COPY requirements.txt requirements.txt 11 | RUN pip install torch==1.8+cpu -f https://download.pytorch.org/whl/torch_stable.html --no-cache-dir 12 | RUN pip install -r requirements.txt --no-cache-dir 13 | 14 | COPY ./ ./ 15 | 16 | # Run test cases and this saves the transformer model in the container 17 | RUN pip install pytest --no-cache-dir && pytest tests -s -vv 18 | 19 | RUN chmod -R 0777 $MODEL_DIR 20 | 21 | CMD [ "main.lambda_handler"] -------------------------------------------------------------------------------- /zero-shot-classification/README.MD: -------------------------------------------------------------------------------- 1 | ## Zero-Shot Classification service 2 | Zero-Shot using Transformers on AWS Lambda. Check root readme for complete setup info. 
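## Invoking the deployed function

A minimal sketch of calling the service once the container image has been deployed to Lambda, using boto3. The function name `zero-shot-classification` and the region `us-east-1` are assumptions (mirroring the CI workflows of the other services in this repo), not values defined in this folder; adjust them to your deployment. The payload follows the request formats documented below.

```
import json
import boto3

client = boto3.client("lambda", region_name="us-east-1")  # assumed region

payload = {
    "texts": ["food was great", "food was bad"],
    "labels": ["negative", "postive", "neutral"],  # defaults from src/config.py
}

# Synchronous invocation; the handler returns the predictions dict as the payload
result = client.invoke(
    FunctionName="zero-shot-classification",  # assumed function name
    Payload=json.dumps(payload).encode("utf-8"),
)
print(json.loads(result["Payload"].read()))
```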
3 | 4 | ## Request format for Multi-Class 5 | ``` 6 | { 7 | "texts": ["food was great", "food was bad", "i am going out for food"], 8 | "labels": ["negative", "postive", "neutral"], # optional 9 | "hypothesis": "The sentiment of the review is {}.", # optional 10 | "model_name": "typeform/mobilebert-uncased-mnli", # optional 11 | "tokenizer_name": "typeform/mobilebert-uncased-mnli", # optional 12 | "multi_label": False # optional 13 | } 14 | ``` 15 | 16 | ## Response format for Multi-Class 17 | ``` 18 | {'predictions':[ 19 | {'label': 'postive', 'score': 0.8}, 20 | {'label': 'negative', 'score': 0.87}, 21 | {'label': 'postive', 'score': 0.58} 22 | ] 23 | 24 | } 25 | ``` 26 | 27 | ## Request format for Multi-Label 28 | ``` 29 | { 30 | "texts": ["food was great", "food was bad", "i am going out for food"], 31 | "multi_label": True, 32 | } 33 | ``` 34 | 35 | ## Response format for Multi-Label 36 | ``` 37 | {'predictions': [ 38 | {'label': ['postive', 'neutral', 'negative'], 'score': [0.87, 0.33, 0.0]}, 39 | {'label': ['negative', 'neutral', 'postive'], 'score': [1.0, 0.85, 0.83]}, 40 | {'label': ['postive', 'negative', 'neutral'], 'score': [0.67, 0.34, 0.14]} 41 | ] 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /zero-shot-classification/main.py: -------------------------------------------------------------------------------- 1 | from src.classifier import Classifier 2 | 3 | pipeline = Classifier() 4 | 5 | 6 | def lambda_handler(event, context): 7 | try: 8 | return pipeline(event) 9 | except Exception as e: 10 | raise 11 | -------------------------------------------------------------------------------- /zero-shot-classification/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.* 2 | scikit-learn==0.24.* -------------------------------------------------------------------------------- /zero-shot-classification/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/zero-shot-classification/src/__init__.py -------------------------------------------------------------------------------- /zero-shot-classification/src/classifier.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from functools import lru_cache 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | from transformers import pipeline 7 | 8 | from src import config, utils 9 | 10 | logger = utils.create_logger(project_name=config.PREDICTION_TYPE, level="INFO") 11 | 12 | class Classifier: 13 | def __init__(self): 14 | _ = self.get_zero_shot_classification_pipeline(model_name=config.DEFAULT_MODEL_NAME, tokenizer_name=config.DEFAULT_TOKENIZER_NAME) #warm up 15 | 16 | @staticmethod 17 | @lru_cache(maxsize=config.CACHE_MAXSIZE) 18 | def get_zero_shot_classification_pipeline(model_name: str, tokenizer_name: str) -> pipeline: 19 | """Zero Shot pipeline for the given model and tokenizer 20 | 21 | Args: 22 | model_name (str): Indicating the name of the model 23 | tokenizer_name (str): Indicating the name of the tokenizer 24 | 25 | Returns: 26 | pipeline: Zero-Shot Classification Pipeline 27 | """ 28 | logger.info(f"Loading model: {model_name}") 29 | classification_pipeline = pipeline("zero-shot-classification", model=model_name, tokenizer=tokenizer_name) 30 | return classification_pipeline 31 | 32 | def 
get_clean_text(self, text: str) -> str: 33 | """Clean the text 34 | 35 | Args: 36 | text (str): text to clean 37 | 38 | Returns: 39 | str: cleaned text 40 | """ 41 | return text.strip().lower() 42 | 43 | def __call__(self, request: dict) -> dict: 44 | """Classify the given texts against the candidate labels 45 | 46 | Args: 47 | request (dict): request containing the list of texts to classify 48 | 49 | Returns: 50 | dict: predicted labels and scores for the given texts 51 | """ 52 | texts = [self.get_clean_text(text) for text in request["texts"]] 53 | 54 | labels = request.get("labels", config.DEFAULT_CANDIDATE_LABELS) 55 | hypothesis = request.get("hypothesis", config.DEFAULT_HYPOTHESIS_TEMPLATE) 56 | model_name = request.get("model_name", config.DEFAULT_MODEL_NAME) 57 | tokenizer_name = request.get("tokenizer_name", config.DEFAULT_TOKENIZER_NAME) 58 | multi_label = request.get("multi_label", config.DEFAULT_MULTI_LABEL) 59 | 60 | logger.info(f"Classifying {len(texts)} texts") 61 | classification_pipeline = self.get_zero_shot_classification_pipeline(model_name, tokenizer_name) 62 | 63 | predictions = classification_pipeline(texts, labels, hypothesis, multi_label=multi_label) 64 | 65 | if not multi_label: 66 | output = [] 67 | for pred in predictions: 68 | output.append({"label": pred["labels"][0], "score": round(pred["scores"][0], 2)}) 69 | 70 | return {"predictions": output} 71 | 72 | else: 73 | output = [] 74 | for pred in predictions: 75 | output.append({"label": pred["labels"], "score": [round(score, 2) for score in pred["scores"]]}) 76 | 77 | return {"predictions": output} 78 | -------------------------------------------------------------------------------- /zero-shot-classification/src/config.py: -------------------------------------------------------------------------------- 1 | PREDICTION_TYPE = 'zero-shot-classification' 2 | 3 | DEFAULT_MODEL_NAME = "typeform/mobilebert-uncased-mnli" 4 | DEFAULT_TOKENIZER_NAME = "typeform/mobilebert-uncased-mnli" 5 | DEFAULT_HYPOTHESIS_TEMPLATE = "The sentiment of the review is {}." 6 | DEFAULT_CANDIDATE_LABELS = ["negative", "postive", "neutral"] 7 | DEFAULT_MULTI_LABEL = False 8 | 9 | # cache 10 | CACHE_MAXSIZE = 4 11 | -------------------------------------------------------------------------------- /zero-shot-classification/src/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import ntpath 4 | import os 5 | from typing import Optional 6 | 7 | 8 | def create_folder(directory): 9 | if not os.path.exists(directory): 10 | os.makedirs(directory) 11 | print("Directory created: " + directory) 12 | else: 13 | print("Directory exists: " + directory) 14 | 15 | 16 | def create_logger( 17 | project_name: str, 18 | level: str = "INFO", 19 | log_dir: str = "/tmp/logs", 20 | file_name: Optional[str] = None, 21 | do_print: bool = True, 22 | simple_logging: bool = False, 23 | log_to_file: bool = False, 24 | rich_logging: bool = False, 25 | time_zone: Optional[str] = None, 26 | ): 27 | """Creates a logger of given level and saves logs to a file 28 | 29 | :param project_name: project name for which we are logging 30 | :param level: logging level 31 | LEVELS available 32 | DEBUG: Detailed information, typically of interest only when diagnosing problems. 33 | INFO: Confirmation that things are working as expected. 34 | WARNING: An indication that something unexpected happened, or indicative of some problem in the near future (e.g. 'disk space low'). The software is still working as expected.
35 | ERROR: Due to a more serious problem, the software has not been able to perform some function. 36 | CRITICAL: A serious error, indicating that the program itself may be unable to continue running. 37 | :param log_dir: directory where log files are created 38 | :param file_name: name of the log file 39 | :param do_print: whether to print the logs 40 | :param simple_logging: sets the formatter to message-only 41 | :param log_to_file: whether to save logs on disk 42 | :param rich_logging: colorful logging using rich 43 | :param time_zone: timezone to use for log timestamps, e.g. Asia/Kolkata 44 | https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568 45 | """ 46 | import __main__ 47 | 48 | if file_name is None: 49 | try: 50 | file_name = ntpath.basename(__main__.__file__).split(".")[0] 51 | except Exception: 52 | file_name = "logs" 53 | 54 | logger = logging.getLogger(file_name) 55 | logger.handlers.clear() 56 | logger.setLevel(getattr(logging, level)) 57 | 58 | if time_zone: 59 | from pytz import timezone, utc 60 | def time_formatter(*args): 61 | # TODO: Doesn't work with the rich formatter 62 | utc_dt = utc.localize(datetime.datetime.utcnow()) 63 | my_tz = timezone(time_zone) 64 | converted = utc_dt.astimezone(my_tz) 65 | return converted.timetuple() 66 | 67 | logging.Formatter.converter = time_formatter 68 | 69 | if rich_logging: 70 | from rich.logging import RichHandler 71 | stream_format = f"{project_name}:%(module)s:%(funcName)s: %(message)s" 72 | stream_handler = RichHandler(omit_repeated_times=False) 73 | else: 74 | stream_format = f"%(asctime)s:%(levelname)s:{project_name}:%(module)s:%(funcName)s: %(message)s" 75 | stream_handler = logging.StreamHandler() 76 | 77 | file_formatter = stream_formatter = logging.Formatter( 78 | stream_format, "%Y-%m-%d %H:%M:%S" 79 | ) 80 | 81 | if simple_logging: 82 | file_formatter = logging.Formatter("%(message)s") 83 | stream_formatter = logging.Formatter("%(message)s") 84 | 85 | if log_to_file: 86 | date = datetime.date.today() 87 | date = "%s-%s-%s" % (date.day, date.month, date.year) 88 | log_file_path = os.path.join(log_dir, "%s-%s.log" % (file_name, date)) 89 | 90 | create_folder(log_dir) 91 | file_handler = logging.FileHandler(log_file_path) 92 | file_handler.setFormatter(file_formatter) 93 | logger.addHandler(file_handler) 94 | 95 | if do_print: 96 | stream_handler.setFormatter(stream_formatter) 97 | logger.addHandler(stream_handler) 98 | 99 | logger.propagate = False 100 | 101 | return logger 102 | -------------------------------------------------------------------------------- /zero-shot-classification/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhavsarpratik/serverless-transformers-on-aws-lambda/d48caab0e07ae8326d4b37ab730faf2cffd02f7d/zero-shot-classification/tests/__init__.py -------------------------------------------------------------------------------- /zero-shot-classification/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src import config 3 | 4 | 5 | @pytest.fixture 6 | def requests(): 7 | request_complete = { 8 | "texts": ["food was great", "food was bad", "i am going out for food"], 9 | "labels": config.DEFAULT_CANDIDATE_LABELS, 10 | "hypothesis": config.DEFAULT_HYPOTHESIS_TEMPLATE, 11 | "model_name": config.DEFAULT_MODEL_NAME, 12 | "tokenizer_name": config.DEFAULT_TOKENIZER_NAME, 13 | "multi_label": config.DEFAULT_MULTI_LABEL 14 | } 15 | request_default = {
16 | "texts": ["food was great", "food was bad", "i am going out for food"], 17 | } 18 | 19 | request_multi_label = { 20 | "texts": ["food was great", "food was bad", "i am going out for food"], 21 | "multi_label": True, 22 | } 23 | 24 | return (request_complete, request_default, request_multi_label) 25 | 26 | 27 | @pytest.fixture 28 | def response(): 29 | response_complete = {'predictions':[ 30 | {'label': 'postive', 'score': 0.8}, 31 | {'label': 'negative', 'score': 0.87}, 32 | {'label': 'postive', 'score': 0.58} 33 | ] 34 | } 35 | 36 | response_default = {'predictions':[ 37 | {'label': 'postive', 'score': 0.8}, 38 | {'label': 'negative', 'score': 0.87}, 39 | {'label': 'postive', 'score': 0.58} 40 | ] 41 | } 42 | 43 | response_multi_label = {'predictions': [ 44 | {'label': ['postive', 'neutral', 'negative'], 'score': [0.87, 0.33, 0.0]}, 45 | {'label': ['negative', 'neutral', 'postive'], 'score': [1.0, 0.85, 0.83]}, 46 | {'label': ['postive', 'negative', 'neutral'], 'score': [0.67, 0.34, 0.14]} 47 | ] 48 | } 49 | 50 | return (response_complete, response_default, response_multi_label) 51 | -------------------------------------------------------------------------------- /zero-shot-classification/tests/test_classifier.py: -------------------------------------------------------------------------------- 1 | from src.classifier import Classifier 2 | 3 | pipeline = Classifier() 4 | 5 | def test_complete_response(requests, response): 6 | assert response[0] == pipeline(requests[0]) 7 | 8 | def test_default_response(requests, response): 9 | assert response[1] == pipeline(requests[1]) 10 | 11 | def test_multi_label_response(requests, response): 12 | assert response[2] == pipeline(requests[2]) 13 | --------------------------------------------------------------------------------
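The test suite above doubles as a usage reference: `Classifier.__call__` accepts the same request dict that `main.lambda_handler` receives. A minimal local sketch, assuming it is run from the `zero-shot-classification` folder with the packages from `requirements.txt` plus a CPU build of torch installed (the first call downloads the default `typeform/mobilebert-uncased-mnli` model):

```
# local_demo.py (illustrative name, not part of the repo)
from src.classifier import Classifier

pipeline = Classifier()  # warms up the default model via the cached pipeline

request = {
    "texts": ["food was great", "food was bad"],
    "multi_label": True,  # score every candidate label for each text
}

# Prints {'predictions': [{'label': [...], 'score': [...]}, ...]}
print(pipeline(request))
```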