├── .github └── workflows │ ├── integ-test.yml │ ├── quality.yml │ └── unit-test.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.md ├── makefile ├── pyproject.toml ├── setup.cfg ├── setup.py ├── src └── sagemaker_huggingface_inference_toolkit │ ├── __init__.py │ ├── content_types.py │ ├── decoder_encoder.py │ ├── diffusers_utils.py │ ├── handler_service.py │ ├── mms_model_server.py │ ├── optimum_utils.py │ ├── serving.py │ └── transformers_utils.py └── tests ├── integ ├── __init__.py ├── config.py ├── test_diffusers.py ├── test_models_from_hub.py └── utils.py ├── resources ├── audio │ ├── sample1.flac │ ├── sample1.mp3 │ ├── sample1.ogg │ └── sample1.wav ├── image │ ├── tiger.bmp │ ├── tiger.gif │ ├── tiger.jpeg │ ├── tiger.png │ ├── tiger.tiff │ └── tiger.webp ├── model_input_predict_output_fn_with_context │ └── code │ │ └── inference.py ├── model_input_predict_output_fn_without_context │ └── code │ │ └── inference.py ├── model_transform_fn_with_context │ └── code │ │ └── inference_tranform_fn.py └── model_transform_fn_without_context │ └── code │ └── inference_tranform_fn.py └── unit ├── __init__.py ├── test_decoder_encoder.py ├── test_diffusers_utils.py ├── test_handler_service_with_context.py ├── test_handler_service_without_context.py ├── test_mms_model_server.py ├── test_optimum_utils.py ├── test_serving.py └── test_transformers_utils.py /.github/workflows/integ-test.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | # pull_request: 5 | workflow_dispatch: 6 | 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 3.8 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.8 17 | - name: Install Python dependencies 18 | run: pip install -e .[test,dev] 19 | - name: Run Integration Tests 20 | env: 21 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} 22 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 23 | AWS_DEFAULT_REGION: us-east-1 24 | run: make integ-test -------------------------------------------------------------------------------- /.github/workflows/quality.yml: -------------------------------------------------------------------------------- 1 | name: Quality Check 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | quality: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Python 3.8 11 | uses: actions/setup-python@v2 12 | with: 13 | python-version: 3.8 14 | - name: Install Python dependencies 15 | run: pip install -e .[quality] 16 | - name: Run Quality check 17 | run: make quality -------------------------------------------------------------------------------- /.github/workflows/unit-test.yml: -------------------------------------------------------------------------------- 1 | name: Run Unit-Tests 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | test: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Python 3.8 11 | uses: actions/setup-python@v2 12 | with: 13 | python-version: 3.8 14 | - name: Install Python dependencies 15 | run: pip install -e .[test,dev] 16 | - name: Run Unit Tests 17 | run: make unit-test 18 | # - name: Run Integration Tests 19 | # run: make integ-test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Docker project generated 
files to ignore 2 | # if you want to ignore files created by your editor/tools, 3 | # please consider a global .gitignore https://help.github.com/articles/ignoring-files 4 | .vagrant* 5 | bin 6 | docker/docker 7 | .*.swp 8 | a.out 9 | *.orig 10 | build_src 11 | .flymake* 12 | .idea 13 | .DS_Store 14 | docs/_build 15 | docs/_static 16 | docs/_templates 17 | .gopath/ 18 | .dotcloud 19 | *.test 20 | bundles/ 21 | .hg/ 22 | .git/ 23 | vendor/pkg/ 24 | pyenv 25 | Vagrantfile 26 | # Byte-compiled / optimized / DLL files 27 | __pycache__/ 28 | *.py[cod] 29 | *$py.class 30 | 31 | # C extensions 32 | *.so 33 | 34 | # Distribution / packaging 35 | .Python 36 | build/ 37 | develop-eggs/ 38 | dist/ 39 | downloads/ 40 | eggs/ 41 | .eggs/ 42 | lib/ 43 | lib64/ 44 | parts/ 45 | sdist/ 46 | var/ 47 | wheels/ 48 | share/python-wheels/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 | MANIFEST 53 | 54 | # PyInstaller 55 | # Usually these files are written by a python script from a template 56 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 57 | *.manifest 58 | *.spec 59 | 60 | # Installer logs 61 | pip-log.txt 62 | pip-delete-this-directory.txt 63 | 64 | # Unit test / coverage reports 65 | htmlcov/ 66 | .tox/ 67 | .nox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | *.py,cover 75 | .hypothesis/ 76 | .pytest_cache/ 77 | cover/ 78 | 79 | # Translations 80 | *.mo 81 | *.pot 82 | 83 | # Django stuff: 84 | *.log 85 | local_settings.py 86 | db.sqlite3 87 | db.sqlite3-journal 88 | 89 | # Flask stuff: 90 | instance/ 91 | .webassets-cache 92 | 93 | # Scrapy stuff: 94 | .scrapy 95 | 96 | # Sphinx documentation 97 | docs/_build/ 98 | 99 | # PyBuilder 100 | .pybuilder/ 101 | target/ 102 | 103 | # Jupyter Notebook 104 | .ipynb_checkpoints 105 | 106 | # IPython 107 | profile_default/ 108 | ipython_config.py 109 | 110 | # pyenv 111 | # For a library or package, you might want to ignore these files since the code is 112 | # intended to run in multiple environments; otherwise, check them in: 113 | # .python-version 114 | 115 | # pipenv 116 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 117 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 118 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 119 | # install all needed dependencies. 120 | #Pipfile.lock 121 | 122 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | .vscode/settings.json 166 | .sagemaker 167 | model 168 | tests/tmp -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # how to contribute to huggingface inference toolkit? 2 | 3 | Everyone is welcome to contribute, and we value everybody's contribution. Code 4 | is thus not the only way to help the community. Answering questions, helping 5 | others, reaching out and improving the documentations are immensely valuable to 6 | the community. 7 | 8 | It also helps us if you spread the word: reference the library from blog posts 9 | on the awesome projects it made possible, shout out on Twitter every time it has 10 | helped you, or simply star the repo to say "thank you". 11 | 12 | Whichever way you choose to contribute, please be mindful to respect our 13 | [code of conduct](https://github.com/huggingface/transformers/blob/master/CODE_OF_CONDUCT.md). 14 | 15 | 16 | ## Reporting Bugs/Feature Requests 17 | 18 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 19 | 20 | When filing an issue, please check [existing open](https://github.com/aws/sagemaker_huggingface_inference_toolkit/issues), or [recently closed](https://github.com/aws/sagemaker_huggingface_inference_toolkit/issues?utf8=%E2%9C%93&q=is%3Aissue%20is%3Aclosed%20), issues to make sure somebody else hasn't already 21 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 22 | 23 | * A reproducible test case or series of steps 24 | * The version of our code being used 25 | * Any modifications you've made relevant to the bug 26 | * Anything unusual about your environment or deployment 27 | 28 | 29 | ## Contributing via Pull Requests 30 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 31 | 32 | 1. You are working against the latest source on the *master* branch. 33 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 34 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 35 | 36 | To send us a pull request, please: 37 | 38 | 1. 
Fork the repository. 39 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 40 | 3. Ensure local tests pass. 41 | 4. Commit to your fork using clear commit messages. 42 | 5. Send us a pull request, answering any default questions in the pull request interface. 43 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 44 | 45 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 46 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 47 | 48 | 49 | ## Finding contributions to work on 50 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/huggingface/sagemaker_huggingface_inference_toolkit/labels/help%20wanted) issues is a great place to start. 51 | 52 | 53 | ## Code of Conduct 54 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 55 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 56 | opensource-codeofconduct@amazon.com with any additional questions or comments. 57 | 58 | 59 | ## Security issue notifications 60 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 61 | 62 | 63 | ## Licensing 64 | 65 | See the [LICENSE](https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/main/LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 66 | 67 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | 2 | include LICENSE 3 | include README.md 4 | 5 | recursive-exclude * __pycache__ 6 | recursive-exclude * *.py[co] -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 |
5 | 
 6 | 
 7 | 
 8 | 
 9 | # SageMaker Hugging Face Inference Toolkit 
 10 | 
 11 | [![Latest Version](https://img.shields.io/pypi/v/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Supported Python Versions](https://img.shields.io/pypi/pyversions/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Code Style: Black](https://img.shields.io/badge/code_style-black-000000.svg)](https://github.com/python/black) 
 12 | 
 13 | 
 14 | SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers and Diffusers models on Amazon SageMaker. This library provides default pre-processing, prediction, and post-processing for certain 🤗 Transformers and Diffusers models and tasks. It utilizes the [SageMaker Inference Toolkit](https://github.com/aws/sagemaker-inference-toolkit) for starting up the model server, which is responsible for handling inference requests. 
 15 | 
 16 | For training, see [Run training on Amazon SageMaker](https://huggingface.co/docs/sagemaker/train). 
 17 | 
 18 | For the Dockerfiles used for building SageMaker Hugging Face Containers, see [AWS Deep Learning Containers](https://github.com/aws/deep-learning-containers/tree/master/huggingface). 
 19 | 
 20 | For information on running Hugging Face jobs on Amazon SageMaker, please refer to the [🤗 Transformers documentation](https://huggingface.co/docs/sagemaker). 
 21 | 
 22 | For notebook examples, see [SageMaker Notebook Examples](https://github.com/huggingface/notebooks/tree/master/sagemaker). 
 23 | 
 24 | --- 
 25 | ## 💻 Getting Started with 🤗 Inference Toolkit 
 26 | 
 27 | 
 28 | 
 29 | **Install the Amazon SageMaker Python SDK** 
 30 | 
 31 | ```bash 
 32 | pip install sagemaker --upgrade 
 33 | ``` 
 34 | 
 35 | **Create an Amazon SageMaker endpoint with a trained model.** 
 36 | 
 37 | ```python 
 38 | from sagemaker.huggingface import HuggingFaceModel 
 39 | 
 40 | # create Hugging Face Model Class 
 41 | huggingface_model = HuggingFaceModel( 
 42 |     transformers_version='4.6', 
 43 |     pytorch_version='1.7', 
 44 |     py_version='py36', 
 45 |     model_data='s3://my-trained-model/artifacts/model.tar.gz', 
 46 |     role=role, 
 47 | ) 
 48 | # deploy model to SageMaker Inference 
 49 | huggingface_model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge") 
 50 | ``` 
 51 | 
 52 | 
 53 | **Create an Amazon SageMaker endpoint with a model from the [🤗 Hub](https://huggingface.co/models).** 
 54 | _Note: This is an experimental feature, where the model is loaded after the endpoint is created. Not all SageMaker features are supported, e.g. MME._ 
 55 | ```python 
 56 | from sagemaker.huggingface import HuggingFaceModel 
 57 | # Hub Model configuration. https://huggingface.co/models 
 58 | hub = { 
 59 |     'HF_MODEL_ID': 'distilbert-base-uncased-distilled-squad', 
 60 |     'HF_TASK': 'question-answering' 
 61 | } 
 62 | # create Hugging Face Model Class 
 63 | huggingface_model = HuggingFaceModel( 
 64 |     transformers_version='4.6', 
 65 |     pytorch_version='1.7', 
 66 |     py_version='py36', 
 67 |     env=hub, 
 68 |     role=role, 
 69 | ) 
 70 | # deploy model to SageMaker Inference 
 71 | huggingface_model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge") 
 72 | ``` 
 73 | 
 74 | --- 
 75 | 
 76 | ## 🛠️ Environment variables 
 77 | 
 78 | The SageMaker Hugging Face Inference Toolkit implements various additional environment variables to simplify your deployment experience. A full list of environment variables is given below. 
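These variables are read by the toolkit inside the serving container. When deploying with the SageMaker Python SDK, you would typically pass them through the `env` argument of `HuggingFaceModel`, as in the short sketch below (the model id, task, and `role` value are placeholders, not required values):

```python
from sagemaker.huggingface import HuggingFaceModel

# Any of the toolkit environment variables documented below can be passed
# to the container through the `env` dictionary.
env = {
    "HF_MODEL_ID": "distilbert-base-uncased-finetuned-sst-2-english",  # placeholder model id
    "HF_TASK": "text-classification",                                  # placeholder task
}

huggingface_model = HuggingFaceModel(
    transformers_version="4.6",
    pytorch_version="1.7",
    py_version="py36",
    env=env,
    role=role,  # assumes an IAM role is already defined, as in the examples above
)
```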
79 | 
 80 | #### `HF_TASK` 
 81 | 
 82 | The `HF_TASK` environment variable defines the task for the 🤗 Transformers pipeline. A full list of tasks can be found [here](https://huggingface.co/transformers/main_classes/pipelines.html). 
 83 | 
 84 | ```bash 
 85 | HF_TASK="question-answering" 
 86 | ``` 
 87 | 
 88 | #### `HF_MODEL_ID` 
 89 | 
 90 | The `HF_MODEL_ID` environment variable defines the model id, which will be automatically loaded from [huggingface.co/models](https://huggingface.co/models) when creating your SageMaker Endpoint. The 🤗 Hub provides more than 10,000 models, all available through this environment variable. 
 91 | 
 92 | ```bash 
 93 | HF_MODEL_ID="distilbert-base-uncased-finetuned-sst-2-english" 
 94 | ``` 
 95 | 
 96 | #### `HF_MODEL_REVISION` 
 97 | 
 98 | The `HF_MODEL_REVISION` environment variable is an extension to `HF_MODEL_ID` and allows you to define/pin a revision of the model to make sure you always load the same model on your SageMaker Endpoint. 
 99 | 
 100 | ```bash 
 101 | HF_MODEL_REVISION="03b4d196c19d0a73c7e0322684e97db1ec397613" 
 102 | ``` 
 103 | 
 104 | #### `HF_API_TOKEN` 
 105 | 
 106 | The `HF_API_TOKEN` environment variable defines your Hugging Face authorization token. The `HF_API_TOKEN` is used as an HTTP bearer authorization for remote files, like private models. You can find your token at your [settings page](https://huggingface.co/settings/token). 
 107 | 
 108 | ```bash 
 109 | HF_API_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX" 
 110 | ``` 
 111 | 
 112 | #### `HF_TRUST_REMOTE_CODE` 
 113 | 
 114 | The `HF_TRUST_REMOTE_CODE` environment variable defines whether or not to allow custom models defined on the Hub in their own modeling files. Allowed values are `"True"` and `"False"`. 
 115 | 
 116 | ```bash 
 117 | HF_TRUST_REMOTE_CODE="True" 
 118 | ``` 
 119 | 
 120 | #### `HF_OPTIMUM_BATCH_SIZE` 
 121 | 
 122 | The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size, which is used when compiling the model to Neuron. The default value is `1`. Not required when the model is already converted. 
 123 | 
 124 | ```bash 
 125 | HF_OPTIMUM_BATCH_SIZE="1" 
 126 | ``` 
 127 | 
 128 | #### `HF_OPTIMUM_SEQUENCE_LENGTH` 
 129 | 
 130 | The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length, which is used when compiling the model to Neuron. There is no default value. Not required when the model is already converted. 
 131 | 
 132 | ```bash 
 133 | HF_OPTIMUM_SEQUENCE_LENGTH="128" 
 134 | ``` 
 135 | 
 136 | --- 
 137 | 
 138 | ## 🧑🏻‍💻 User defined code/modules 
 139 | 
 140 | The Hugging Face Inference Toolkit allows users to override the default methods of the `HuggingFaceHandlerService`. To do so, create a folder named `code/` with an `inference.py` file in it. You can find an example in [sagemaker/17_custom_inference_script](https://github.com/huggingface/notebooks/blob/master/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb). 
 141 | For example: 
 142 | ```bash 
 143 | model.tar.gz/ 
 144 | |- pytorch_model.bin 
 145 | |- .... 
 146 | |- code/ 
 147 |   |- inference.py 
 148 |   |- requirements.txt 
 149 | ``` 
 150 | In this example, `pytorch_model.bin` is the model file saved from training, `inference.py` is the custom inference module, and `requirements.txt` is a requirements file to add additional dependencies. 
 151 | The custom module can override the following methods: 
 152 | 
 153 | * `model_fn(model_dir, context=None)`: overrides the default method for loading the model. The return value `model` will be used in `predict()` for predictions. It receives the argument `model_dir`, the path to your unzipped `model.tar.gz`. 
154 | * `transform_fn(model, data, content_type, accept_type)`: overrides the default transform function with a custom implementation. Customers using this would have to implement the `preprocess`, `predict` and `postprocess` steps in the `transform_fn`. **NOTE: This method can't be combined with `input_fn`, `predict_fn` or `output_fn` mentioned below.** 
 155 | * `input_fn(input_data, content_type)`: overrides the default method for preprocessing. The return value `data` will be used in the `predict()` method for predictions. The inputs are `input_data`, the raw body of your request, and `content_type`, the content type from the request header. 
 156 | * `predict_fn(processed_data, model)`: overrides the default method for predictions. The return value `predictions` will be used in the `postprocess()` method. The input is `processed_data`, the result of the `preprocess()` method. 
 157 | * `output_fn(prediction, accept)`: overrides the default method for postprocessing. The return value `result` will be the response of your request (e.g. `JSON`). The inputs are `predictions`, the result of the `predict()` method, and `accept`, the accept type from the HTTP request, e.g. `application/json`. 
 158 | 
 159 | 
 160 | ## 🏎️ Deploy Models on AWS Inferentia2 
 161 | 
 162 | The SageMaker Hugging Face Inference Toolkit provides support for deploying Hugging Face models on AWS Inferentia2. To deploy a model on Inferentia2 you have 3 options: 
 163 | * Provide `HF_MODEL_ID`, the model repo id on huggingface.co, which contains the compiled model in `.neuron` format, e.g. `optimum/bge-base-en-v1.5-neuronx` 
 164 | * Provide the `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` environment variables to compile the model on the fly, e.g. `HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128` 
 165 | * Include a `neuron` dictionary in the [config.json](https://huggingface.co/optimum/tiny_random_bert_neuron/blob/main/config.json) file in the model archive, e.g. `neuron: {"static_batch_size": 1, "static_sequence_length": 128}` 
 166 | 
 167 | The currently supported tasks can be found [here](https://huggingface.co/docs/optimum-neuron/en/package_reference/supported_models). If you plan to deploy an LLM, we recommend taking a look at [Neuronx TGI](https://huggingface.co/blog/text-generation-inference-on-inferentia2), which is purpose-built for LLMs. 
 168 | 
 169 | --- 
 170 | ## 🤝 Contributing 
 171 | 
 172 | Please read [CONTRIBUTING.md](https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/main/CONTRIBUTING.md) 
 173 | for details on our code of conduct, and the process for submitting pull 
 174 | requests to us. 
 175 | 
 176 | --- 
 177 | ## 📜 License 
 178 | 
 179 | SageMaker Hugging Face Inference Toolkit is licensed under the Apache 2.0 License. 
 180 | 
 181 | --- 
 182 | 
 183 | ## 🧑🏻‍💻 Development Environment 
 184 | 
 185 | Install all test and development packages with: 
 186 | 
 187 | ```bash 
 188 | pip3 install -e ".[test,dev]" 
 189 | ``` 
 190 | ## Run Model Locally 
 191 | 
 192 | 1. Manually change `MMS_CONFIG_FILE`: 
 193 | ``` 
 194 | wget -O sagemaker-mms.properties https://raw.githubusercontent.com/aws/deep-learning-containers/master/huggingface/build_artifacts/inference/config.properties 
 195 | ``` 
 196 | 
 197 | 2. Run the container, e.g. `text-to-image`: 
 198 | ``` 
 199 | HF_MODEL_ID="stabilityai/stable-diffusion-xl-base-1.0" HF_TASK="text-to-image" python src/sagemaker_huggingface_inference_toolkit/serving.py 
 200 | ``` 
 201 | 3. 
Adjust `handler_service.py` and comment out `if content_type in content_types.UTF8_TYPES:`; that's needed for SageMaker but cannot be used locally. 
 202 | 
 203 | 4. Send request: 
 204 | ``` 
 205 | curl --request POST \ 
 206 | --url http://localhost:8080/invocations \ 
 207 | --header 'Accept: image/png' \ 
 208 | --header 'Content-Type: application/json' \ 
 209 | --data '{"inputs": "Camera"}' \ 
 210 | --output image.png 
 211 | ``` 
 212 | 
 213 | 
 214 | ## Run Inferentia2 Model Locally 
 215 | 
 216 | _Note: You need to run this on an Inferentia2 instance._ 
 217 | 
 218 | 1. Manually change `MMS_CONFIG_FILE`: 
 219 | ``` 
 220 | wget -O sagemaker-mms.properties https://raw.githubusercontent.com/aws/deep-learning-containers/master/huggingface/build_artifacts/inference/config.properties 
 221 | ``` 
 222 | 
 223 | 2. Adjust `handler_service.py` and comment out `if content_type in content_types.UTF8_TYPES:`; that's needed for SageMaker but cannot be used locally. 
 224 | 
 225 | 3. Run the container: 
 226 | 
 227 | - transformers `text-classification` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` 
 228 | ``` 
 229 | HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" HF_TASK="text-classification" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 python src/sagemaker_huggingface_inference_toolkit/serving.py 
 230 | ``` 
 231 | - sentence transformers `feature-extraction` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` 
 232 | ``` 
 233 | HF_MODEL_ID="sentence-transformers/all-MiniLM-L6-v2" HF_TASK="feature-extraction" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 python src/sagemaker_huggingface_inference_toolkit/serving.py 
 234 | ``` 
 235 | 
 236 | 4. Send request: 
 237 | ``` 
 238 | curl --request POST \ 
 239 | --url http://localhost:8080/invocations \ 
 240 | --header 'Content-Type: application/json' \ 
 241 | --data "{\"inputs\": \"I like you.\"}" 
 242 | ``` 
 243 | 
-------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style unit-test integ-test 
2 | 
3 | check_dirs := src tests 
4 | 
5 | # run tests 
6 | 
7 | unit-test: 
8 | python -m pytest -v -s ./tests/unit/ 
9 | 
10 | integ-test: 
11 | python -m pytest -n 2 -s -v ./tests/integ/ 
12 | # python -m pytest -n auto -s -v ./tests/integ/ 
13 | 
14 | 
15 | # Check that source code meets quality standards 
16 | 
17 | quality: 
18 | black --check --line-length 119 --target-version py36 $(check_dirs) 
19 | isort --check-only $(check_dirs) 
20 | flake8 $(check_dirs) 
21 | 
22 | # Format source code automatically 
23 | 
24 | style: 
25 | # black --line-length 119 --target-version py36 tests src benchmarks datasets metrics 
26 | black --line-length 119 --target-version py36 $(check_dirs) 
27 | isort $(check_dirs) 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 
2 | line-length = 119 
3 | target-version = ['py36'] 
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 
2 | default_section = FIRSTPARTY 
3 | ensure_newline_before_comments = True 
4 | force_grid_wrap = 0 
5 | include_trailing_comma = True 
6 | known_first_party = sagemaker_huggingface_inference_toolkit 
7 | known_third_party = 
8 | transformers 
9 | sagemaker_inference 
10 | huggingface_hub 
11 | datasets 
12 | pytest 
13 
| sklearn 14 | tensorflow 15 | torch 16 | retrying 17 | numpy 18 | 19 | 20 | line_length = 119 21 | lines_after_imports = 2 22 | multi_line_output = 3 23 | use_parentheses = True 24 | 25 | [flake8] 26 | ignore = E203, E501, E741, W503, W605 27 | max-line-length = 119 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Build release wheels as follows 16 | # $ SM_HF_TOOLKIT_RELEASE=1 python setup.py bdist_wheel build 17 | # $ twine upload --repository-url https://test.pypi.org/legacy/ dist/* # upload to test.pypi 18 | # Test the wheel by downloading from test.pypi 19 | # $ pip install -i https://test.pypi.org/simple/ sagemaker-huggingface-inference-toolkit== 20 | # Once test is complete 21 | # Upload the wheel to pypi 22 | # $ twine upload dist/* 23 | 24 | 25 | from __future__ import absolute_import 26 | import os 27 | from datetime import date 28 | from setuptools import find_packages, setup 29 | 30 | # We don't declare our dependency on transformers here because we build with 31 | # different packages for different variants 32 | 33 | VERSION = "2.6.0" 34 | 35 | 36 | # Ubuntu packages 37 | # libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev 38 | # ffmpeg: ffmpeg is required for audio processing. 
On Ubuntu it can be installed as follows: apt install ffmpeg 39 | # libavcodec-extra : libavcodec-extra inculdes additional codecs for ffmpeg 40 | 41 | install_requires = [ 42 | "sagemaker-inference>=1.8.0", 43 | "huggingface_hub>=0.0.8", 44 | "retrying", 45 | "numpy", 46 | # vision 47 | "Pillow", 48 | # speech + torchaudio 49 | "librosa", 50 | "pyctcdecode>=0.3.0", 51 | "phonemizer", 52 | ] 53 | 54 | extras = {} 55 | 56 | # Hugging Face specific dependencies 57 | extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.17.0"] 58 | extras["diffusers"] = ["diffusers>=0.23.0"] 59 | 60 | # framework specific dependencies 61 | extras["torch"] = ["torch>=1.8.0", "torchaudio"] 62 | 63 | # TODO: Remove upper bound of TF 2.11 once transformers release contains this fix: https://github.com/huggingface/evaluate/pull/372 64 | extras["tensorflow"] = ["tensorflow>=2.4.0,<2.11"] 65 | 66 | # MMS Server dependencies 67 | extras["mms"] = ["multi-model-server>=1.1.4", "retrying"] 68 | 69 | 70 | extras["test"] = [ 71 | "pytest<8", 72 | "pytest-xdist", 73 | "parameterized", 74 | "psutil", 75 | "datasets", 76 | "pytest-sugar", 77 | "black==21.4b0", 78 | "sagemaker", 79 | "boto3", 80 | "mock==2.0.0", 81 | ] 82 | 83 | extras["benchmark"] = ["boto3", "locust"] 84 | 85 | extras["quality"] = [ 86 | "black>=21.10", 87 | "isort>=5.5.4", 88 | "flake8>=3.8.3", 89 | ] 90 | 91 | extras["dev"] = extras["transformers"] + extras["mms"] + extras["torch"] + extras["tensorflow"] + extras["diffusers"] 92 | setup( 93 | name="sagemaker-huggingface-inference-toolkit", 94 | version=VERSION, 95 | # if os.getenv("SM_HF_TOOLKIT_RELEASE") is not None 96 | # else VERSION + "b" + str(date.today()).replace("-", ""), 97 | author="HuggingFace and Amazon Web Services", 98 | description="Open source library for running inference workload with Hugging Face Deep Learning Containers on " 99 | "Amazon SageMaker.", 100 | long_description=open("README.md", "r", encoding="utf-8").read(), 101 | long_description_content_type="text/markdown", 102 | keywords="NLP deep-learning transformer pytorch tensorflow BERT GPT GPT-2 AWS Amazon SageMaker Cloud", 103 | url="https://github.com/aws/sagemaker-huggingface-inference-toolkit", 104 | package_dir={"": "src"}, 105 | packages=find_packages(where="src"), 106 | install_requires=install_requires, 107 | extras_require=extras, 108 | entry_points={"console_scripts": "serve=sagemaker_huggingface_inference_toolkit.serving:main"}, 109 | python_requires=">=3.6.0", 110 | license="Apache License 2.0", 111 | classifiers=[ 112 | "Development Status :: 5 - Production/Stable", 113 | "Intended Audience :: Developers", 114 | "Intended Audience :: Education", 115 | "Intended Audience :: Science/Research", 116 | "License :: OSI Approved :: Apache Software License", 117 | "Operating System :: OS Independent", 118 | "Programming Language :: Python :: 3", 119 | "Programming Language :: Python :: 3.6", 120 | "Programming Language :: Python :: 3.7", 121 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 122 | ], 123 | ) 124 | -------------------------------------------------------------------------------- /src/sagemaker_huggingface_inference_toolkit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import absolute_import 15 | -------------------------------------------------------------------------------- /src/sagemaker_huggingface_inference_toolkit/content_types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This module contains constants that define MIME content types.""" 15 | # Default Mime-Types 16 | JSON = "application/json" 17 | CSV = "text/csv" 18 | OCTET_STREAM = "application/octet-stream" 19 | ANY = "*/*" 20 | NPY = "application/x-npy" 21 | UTF8_TYPES = [JSON, CSV] 22 | # Vision Mime-Types 23 | JPEG = "image/jpeg" 24 | PNG = "image/png" 25 | TIFF = "image/tiff" 26 | BMP = "image/bmp" 27 | GIF = "image/gif" 28 | WEBP = "image/webp" 29 | X_IMAGE = "image/x-image" 30 | VISION_TYPES = [JPEG, PNG, TIFF, BMP, GIF, WEBP, X_IMAGE] 31 | # Speech Mime-Types 32 | FLAC = "audio/x-flac" 33 | MP3 = "audio/mpeg" 34 | WAV = "audio/wave" 35 | OGG = "audio/ogg" 36 | X_AUDIO = "audio/x-audio" 37 | AUDIO_TYPES = [FLAC, MP3, WAV, OGG, X_AUDIO] 38 | -------------------------------------------------------------------------------- /src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
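# Overview (added comment): this module converts payloads between their HTTP
# representations and the objects the handler service works with. `decode()` maps an
# incoming body (JSON, CSV, image or audio mime-types) to the dictionary consumed by
# the 🤗 pipeline, and `encode()` serializes a prediction into the requested accept
# type (JSON, CSV, NPY or an image) via the `_decoder_map` / `_encoder_map` dispatch
# tables defined below.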
14 | import base64 15 | import csv 16 | import datetime 17 | import json 18 | from io import BytesIO, StringIO 19 | 20 | import numpy as np 21 | from sagemaker_inference import errors 22 | from sagemaker_inference.decoder import _npy_to_numpy 23 | 24 | from mms.service import PredictionException 25 | from PIL import Image 26 | from sagemaker_huggingface_inference_toolkit import content_types 27 | 28 | 29 | def decode_json(content): 30 | return json.loads(content) 31 | 32 | 33 | def decode_csv(string_like): # type: (str) -> np.array 34 | """Convert a CSV object to a dictonary with list attributes. 35 | 36 | Args: 37 | string_like (str): CSV string. 38 | Returns: 39 | (dict): dictonatry for input 40 | """ 41 | stream = StringIO(string_like) 42 | # detects if the incoming csv has headers 43 | if not any(header in string_like.splitlines()[0].lower() for header in ["question", "context", "inputs"]): 44 | raise PredictionException( 45 | "You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.", 46 | 400, 47 | ) 48 | # reads csv as io 49 | request_list = list(csv.DictReader(stream)) 50 | if "inputs" in request_list[0].keys(): 51 | return {"inputs": [entry["inputs"] for entry in request_list]} 52 | else: 53 | return {"inputs": request_list} 54 | 55 | 56 | def decode_image(bpayload: bytearray): 57 | """Convert a .jpeg / .png / .tiff... object to a proper inputs dict. 58 | Args: 59 | bpayload (bytes): byte stream. 60 | Returns: 61 | (dict): dictonatry for input 62 | """ 63 | image = Image.open(BytesIO(bpayload)).convert("RGB") 64 | return {"inputs": image} 65 | 66 | 67 | def decode_audio(bpayload: bytearray): 68 | """Convert a .wav / .flac / .mp3 object to a proper inputs dict. 69 | Args: 70 | bpayload (bytes): byte stream. 71 | Returns: 72 | (dict): dictonatry for input 73 | """ 74 | 75 | return {"inputs": bytes(bpayload)} 76 | 77 | 78 | # https://github.com/automl/SMAC3/issues/453 79 | class _JSONEncoder(json.JSONEncoder): 80 | """ 81 | custom `JSONEncoder` to make sure float and int64 ar converted 82 | """ 83 | 84 | def default(self, obj): 85 | if isinstance(obj, np.integer): 86 | return int(obj) 87 | elif isinstance(obj, np.floating): 88 | return float(obj) 89 | elif hasattr(obj, "tolist"): 90 | return obj.tolist() 91 | elif isinstance(obj, datetime.datetime): 92 | return obj.__str__() 93 | elif isinstance(obj, Image.Image): 94 | with BytesIO() as out: 95 | obj.save(out, format="PNG") 96 | png_string = out.getvalue() 97 | return base64.b64encode(png_string).decode("utf-8") 98 | else: 99 | return super(_JSONEncoder, self).default(obj) 100 | 101 | 102 | def encode_json(content, accept_type=None): 103 | """ 104 | encodes json with custom `JSONEncoder` 105 | """ 106 | return json.dumps( 107 | content, 108 | ensure_ascii=False, 109 | allow_nan=False, 110 | indent=None, 111 | cls=_JSONEncoder, 112 | separators=(",", ":"), 113 | ) 114 | 115 | 116 | def _array_to_npy(array_like, accept_type=None): 117 | """Convert an array-like object to the NPY format. 118 | 119 | To understand better what an array-like object is see: 120 | https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays 121 | 122 | Args: 123 | array_like (np.array or Iterable or int or float): array-like object 124 | to be converted to NPY. 125 | 126 | Returns: 127 | (obj): NPY array. 
128 | """ 129 | buffer = BytesIO() 130 | np.save(buffer, array_like) 131 | return buffer.getvalue() 132 | 133 | 134 | def encode_csv(content, accept_type=None): 135 | """Convert the result of a transformers pipeline to CSV. 136 | Args: 137 | content (dict | list): result of transformers pipeline. 138 | Returns: 139 | (str): object serialized to CSV 140 | """ 141 | stream = StringIO() 142 | if not isinstance(content, list): 143 | content = list(content) 144 | 145 | column_header = content[0].keys() 146 | writer = csv.DictWriter(stream, column_header) 147 | 148 | writer.writeheader() 149 | writer.writerows(content) 150 | return stream.getvalue() 151 | 152 | 153 | def encode_image(image, accept_type=content_types.PNG): 154 | """Convert a PIL.Image object to a byte stream. 155 | Args: 156 | image (PIL.Image): image to be converted. 157 | accept_type (str): content type of the image. 158 | Returns: 159 | (bytes): byte stream of the image. 160 | """ 161 | accept_type = "PNG" if content_types.X_IMAGE == accept_type else accept_type.split("/")[-1].upper() 162 | 163 | with BytesIO() as out: 164 | image.save(out, format=accept_type) 165 | return out.getvalue() 166 | 167 | 168 | _encoder_map = { 169 | content_types.NPY: _array_to_npy, 170 | content_types.CSV: encode_csv, 171 | content_types.JSON: encode_json, 172 | content_types.JPEG: encode_image, 173 | content_types.PNG: encode_image, 174 | content_types.TIFF: encode_image, 175 | content_types.BMP: encode_image, 176 | content_types.GIF: encode_image, 177 | content_types.WEBP: encode_image, 178 | content_types.X_IMAGE: encode_image, 179 | } 180 | _decoder_map = { 181 | content_types.NPY: _npy_to_numpy, 182 | content_types.CSV: decode_csv, 183 | content_types.JSON: decode_json, 184 | # image mime-types 185 | content_types.JPEG: decode_image, 186 | content_types.PNG: decode_image, 187 | content_types.TIFF: decode_image, 188 | content_types.BMP: decode_image, 189 | content_types.GIF: decode_image, 190 | content_types.WEBP: decode_image, 191 | content_types.X_IMAGE: decode_image, 192 | # audio mime-types 193 | content_types.FLAC: decode_audio, 194 | content_types.MP3: decode_audio, 195 | content_types.WAV: decode_audio, 196 | content_types.OGG: decode_audio, 197 | content_types.X_AUDIO: decode_audio, 198 | } 199 | 200 | 201 | def decode(content, content_type=content_types.JSON): 202 | """ 203 | Decodes a specific content_type into an 🤗 Transformers object. 204 | """ 205 | try: 206 | decoder = _decoder_map[content_type] 207 | return decoder(content) 208 | except KeyError: 209 | raise errors.UnsupportedFormatError(content_type) 210 | except PredictionException as pred_err: 211 | raise pred_err 212 | 213 | 214 | def encode(content, accept_type=content_types.JSON): 215 | """ 216 | Encode an 🤗 Transformers object in a specific content_type. 217 | """ 218 | try: 219 | encoder = _encoder_map[accept_type] 220 | return encoder(content, accept_type) 221 | except KeyError: 222 | raise errors.UnsupportedFormatError(accept_type) 223 | -------------------------------------------------------------------------------- /src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import importlib.util 15 | import logging 16 | 17 | from transformers.utils.import_utils import is_torch_bf16_gpu_available 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | _diffusers = importlib.util.find_spec("diffusers") is not None 23 | 24 | 25 | def is_diffusers_available(): 26 | return _diffusers 27 | 28 | 29 | if is_diffusers_available(): 30 | import torch 31 | 32 | from diffusers import DiffusionPipeline 33 | 34 | 35 | class SMDiffusionPipelineForText2Image: 36 | 37 | def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU 38 | self.pipeline = None 39 | dtype = torch.float32 40 | if device == "cuda": 41 | dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16 42 | if torch.cuda.device_count() > 1: 43 | device_map = "balanced" 44 | self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map) 45 | 46 | if not self.pipeline: 47 | self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype).to(device) 48 | 49 | def __call__( 50 | self, 51 | prompt, 52 | **kwargs, 53 | ): 54 | # TODO: add support for more images (Reason is correct output) 55 | if "num_images_per_prompt" in kwargs: 56 | kwargs.pop("num_images_per_prompt") 57 | logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.") 58 | 59 | # Call pipeline with parameters 60 | out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs) 61 | return out.images[0] 62 | 63 | 64 | DIFFUSERS_TASKS = { 65 | "text-to-image": SMDiffusionPipelineForText2Image, 66 | } 67 | 68 | 69 | def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **kwargs): 70 | """Get a pipeline for Diffusers models.""" 71 | device = "cuda" if device == 0 else "cpu" 72 | pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device) 73 | return pipeline 74 | -------------------------------------------------------------------------------- /src/sagemaker_huggingface_inference_toolkit/handler_service.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
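# Overview (added comment): the model server calls `initialize()` once at startup,
# which loads the 🤗 pipeline via `load()` and applies any user-provided
# `code/inference.py` overrides, then calls `handle()` for every request. `handle()`
# delegates to `transform_fn()`, which chains `preprocess()` -> `predict()` ->
# `postprocess()` and returns the response serialized in the requested accept type.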
14 | 15 | import importlib 16 | import logging 17 | import os 18 | import sys 19 | import time 20 | from abc import ABC 21 | from inspect import signature 22 | 23 | from sagemaker_inference import environment, utils 24 | from transformers.pipelines import SUPPORTED_TASKS 25 | 26 | from mms.service import PredictionException 27 | from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder 28 | from sagemaker_huggingface_inference_toolkit.transformers_utils import ( 29 | _is_gpu_available, 30 | get_pipeline, 31 | infer_task_from_model_architecture, 32 | ) 33 | 34 | 35 | ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true" 36 | PYTHON_PATH_ENV = "PYTHONPATH" 37 | MODEL_FN = "model_fn" 38 | INPUT_FN = "input_fn" 39 | PREDICT_FN = "predict_fn" 40 | OUTPUT_FN = "output_fn" 41 | TRANSFORM_FN = "transform_fn" 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | 46 | class HuggingFaceHandlerService(ABC): 47 | """Default handler service that is executed by the model server. 48 | 49 | The handler service is responsible for defining our InferenceHandler. 50 | - The ``handle`` method is invoked for all incoming inference requests to the model server. 51 | - The ``initialize`` method is invoked at model server start up. 52 | 53 | Implementation of: https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md 54 | """ 55 | 56 | def __init__(self): 57 | self.error = None 58 | self.batch_size = 1 59 | self.model_dir = None 60 | self.model = None 61 | self.device = -1 62 | self.initialized = False 63 | self.attempted_init = False 64 | self.context = None 65 | self.manifest = None 66 | self.environment = environment.Environment() 67 | self.load_extra_arg = [] 68 | self.preprocess_extra_arg = [] 69 | self.predict_extra_arg = [] 70 | self.postprocess_extra_arg = [] 71 | self.transform_extra_arg = [] 72 | 73 | def initialize(self, context): 74 | """ 75 | Initialize model. This will be called during model loading time 76 | :param context: Initial context contains model server system properties. 77 | :return: 78 | """ 79 | self.attempted_init = True 80 | self.context = context 81 | properties = context.system_properties 82 | self.model_dir = properties.get("model_dir") 83 | self.batch_size = context.system_properties["batch_size"] 84 | 85 | code_dir_path = os.path.join(self.model_dir, "code") 86 | sys.path.insert(0, code_dir_path) 87 | self.validate_and_initialize_user_module() 88 | 89 | self.device = self.get_device() 90 | self.model = self.load(*([self.model_dir] + self.load_extra_arg)) 91 | self.initialized = True 92 | # # Load methods from file 93 | # if (not self._initialized) and ENABLE_MULTI_MODEL: 94 | # code_dir = os.path.join(context.system_properties.get("model_dir"), "code") 95 | # sys.path.append(code_dir) 96 | # self._initialized = True 97 | # # add model_dir/code to python path 98 | 99 | def get_device(self): 100 | """ 101 | The get device function will return the device for the DL Framework. 102 | """ 103 | if _is_gpu_available(): 104 | return int(self.context.system_properties.get("gpu_id")) 105 | else: 106 | return -1 107 | 108 | def load(self, model_dir, context=None): 109 | """ 110 | The Load handler is responsible for loading the Hugging Face transformer model. 111 | It can be overridden to load the model from storage. 112 | 113 | Args: 114 | model_dir (str): The directory where model files are stored. 115 | context (obj): metadata on the incoming request data (default: None). 
116 | 117 | Returns: 118 | hf_pipeline (Pipeline): A Hugging Face Transformer pipeline. 119 | """ 120 | # gets pipeline from task tag 121 | if "HF_TASK" in os.environ: 122 | hf_pipeline = get_pipeline(task=os.environ["HF_TASK"], model_dir=model_dir, device=self.device) 123 | elif "config.json" in os.listdir(model_dir): 124 | task = infer_task_from_model_architecture(f"{model_dir}/config.json") 125 | hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device) 126 | elif "model_index.json" in os.listdir(model_dir): 127 | task = "text-to-image" 128 | hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device) 129 | else: 130 | raise ValueError( 131 | f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} or text-to-image as env 'HF_TASK'.", 132 | 403, 133 | ) 134 | return hf_pipeline 135 | 136 | def preprocess(self, input_data, content_type, context=None): 137 | """ 138 | The preprocess handler is responsible for deserializing the input data into 139 | an object for prediction, can handle JSON. 140 | The preprocess handler can be overridden for data or feature transformation. 141 | 142 | Args: 143 | input_data: the request payload serialized in the content_type format. 144 | content_type: the request content_type. 145 | context (obj): metadata on the incoming request data (default: None). 146 | 147 | Returns: 148 | decoded_input_data (dict): deserialized input_data into a Python dictonary. 149 | """ 150 | # raises en error when using zero-shot-classification or table-question-answering, not possible due to nested properties 151 | if ( 152 | os.environ.get("HF_TASK", None) == "zero-shot-classification" 153 | or os.environ.get("HF_TASK", None) == "table-question-answering" 154 | ) and content_type == content_types.CSV: 155 | raise PredictionException( 156 | f"content type {content_type} not support with {os.environ.get('HF_TASK', 'unknown task')}, use different content_type", 157 | 400, 158 | ) 159 | 160 | decoded_input_data = decoder_encoder.decode(input_data, content_type) 161 | return decoded_input_data 162 | 163 | def predict(self, data, model, context=None): 164 | """The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline` 165 | on decoded_input_data deserialized in input_fn. Runs prediction on GPU if is available. 166 | The predict handler can be overridden to implement the model inference. 167 | 168 | Args: 169 | data (dict): deserialized decoded_input_data returned by the input_fn 170 | model : Model returned by the `load` method or if it is a custom module `model_fn`. 171 | context (obj): metadata on the incoming request data (default: None). 172 | 173 | Returns: 174 | obj (dict): prediction result. 175 | """ 176 | 177 | # pop inputs for pipeline 178 | inputs = data.pop("inputs", data) 179 | parameters = data.pop("parameters", None) 180 | 181 | # pass inputs with all kwargs in data 182 | if parameters is not None: 183 | prediction = model(inputs, **parameters) 184 | else: 185 | prediction = model(inputs) 186 | return prediction 187 | 188 | def postprocess(self, prediction, accept, context=None): 189 | """ 190 | The postprocess handler is responsible for serializing the prediction result to 191 | the desired accept type, can handle JSON. 192 | The postprocess handler can be overridden for inference response transformation. 193 | 194 | Args: 195 | prediction (dict): a prediction result from predict. 196 | accept (str): type which the output data needs to be serialized. 
197 | context (obj): metadata on the incoming request data (default: None). 198 | Returns: output data serialized 199 | """ 200 | return decoder_encoder.encode(prediction, accept) 201 | 202 | def transform_fn(self, model, input_data, content_type, accept, context=None): 203 | """ 204 | Transform function ("transform_fn") can be used to write one function with pre/post-processing steps and predict step in it. 205 | This fuction can't be mixed with "input_fn", "output_fn" or "predict_fn". 206 | 207 | Args: 208 | model: Model returned by the model_fn above 209 | input_data: Data received for inference 210 | content_type: The content type of the inference data 211 | accept: The response accept type. 212 | context (obj): metadata on the incoming request data (default: None). 213 | 214 | Returns: Response in the "accept" format type. 215 | 216 | """ 217 | # run pipeline 218 | start_time = time.time() 219 | processed_data = self.preprocess(*([input_data, content_type] + self.preprocess_extra_arg)) 220 | preprocess_time = time.time() - start_time 221 | predictions = self.predict(*([processed_data, model] + self.predict_extra_arg)) 222 | predict_time = time.time() - preprocess_time - start_time 223 | response = self.postprocess(*([predictions, accept] + self.postprocess_extra_arg)) 224 | postprocess_time = time.time() - predict_time - preprocess_time - start_time 225 | 226 | logger.info( 227 | f"Preprocess time - {preprocess_time * 1000} ms\n" 228 | f"Predict time - {predict_time * 1000} ms\n" 229 | f"Postprocess time - {postprocess_time * 1000} ms" 230 | ) 231 | 232 | return response 233 | 234 | def handle(self, data, context): 235 | """Handles an inference request with input data and makes a prediction. 236 | 237 | Args: 238 | data (obj): the request data. 239 | context (obj): metadata on the incoming request data. 240 | 241 | Returns: 242 | list[obj]: The return value from the Transformer.transform method, 243 | which is a serialized prediction result wrapped in a list if 244 | inference is successful. Otherwise returns an error message 245 | with the context set appropriately. 246 | 247 | """ 248 | try: 249 | if not self.initialized: 250 | if self.attempted_init: 251 | logger.warn( 252 | "Model is not initialized, will try to load model again.\n" 253 | "Please consider increase wait time for model loading.\n" 254 | ) 255 | self.initialize(context) 256 | 257 | input_data = data[0].get("body") 258 | 259 | request_property = context.request_processor[0].get_request_properties() 260 | content_type = utils.retrieve_content_type_header(request_property) 261 | accept = request_property.get("Accept") or request_property.get("accept") 262 | 263 | if not accept or accept == content_types.ANY: 264 | accept = content_types.JSON 265 | 266 | if content_type in content_types.UTF8_TYPES: 267 | input_data = input_data.decode("utf-8") 268 | 269 | predict_start = time.time() 270 | response = self.transform_fn(*([self.model, input_data, content_type, accept] + self.transform_extra_arg)) 271 | predict_end = time.time() 272 | 273 | context.metrics.add_time("Transform Fn", round((predict_end - predict_start) * 1000, 2)) 274 | 275 | context.set_response_content_type(0, accept) 276 | return [response] 277 | 278 | except Exception as e: 279 | raise PredictionException(str(e), 400) 280 | 281 | def validate_and_initialize_user_module(self): 282 | """Retrieves and validates the inference handlers provided within the user module. 283 | Can override load, preprocess, predict and post process function. 
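        For example, an ``inference.py`` placed in the model archive's ``code/`` directory can
        replace the default pre/predict/post steps with a single function; the signatures below
        mirror tests/resources/model_transform_fn_with_context/code/inference_tranform_fn.py:

            def model_fn(model_dir, context=None):
                return "model"

            def transform_fn(model, input_data, content_type, accept, context=None):
                return f"output {input_data}"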
284 | """ 285 | user_module_name = self.environment.module_name 286 | if importlib.util.find_spec(user_module_name) is not None: 287 | logger.info("Inference script implementation found at `{}`.".format(user_module_name)) 288 | user_module = importlib.import_module(user_module_name) 289 | 290 | load_fn = getattr(user_module, MODEL_FN, None) 291 | preprocess_fn = getattr(user_module, INPUT_FN, None) 292 | predict_fn = getattr(user_module, PREDICT_FN, None) 293 | postprocess_fn = getattr(user_module, OUTPUT_FN, None) 294 | transform_fn = getattr(user_module, TRANSFORM_FN, None) 295 | 296 | if transform_fn and (preprocess_fn or predict_fn or postprocess_fn): 297 | raise ValueError( 298 | "Cannot use {} implementation in conjunction with {}, {}, and/or {} implementation".format( 299 | TRANSFORM_FN, INPUT_FN, PREDICT_FN, OUTPUT_FN 300 | ) 301 | ) 302 | self.log_func_implementation_found_or_not(load_fn, MODEL_FN) 303 | if load_fn is not None: 304 | self.load_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.load, load_fn) 305 | self.load = load_fn 306 | self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN) 307 | if preprocess_fn is not None: 308 | self.preprocess_extra_arg = self.function_extra_arg( 309 | HuggingFaceHandlerService.preprocess, preprocess_fn 310 | ) 311 | self.preprocess = preprocess_fn 312 | self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN) 313 | if predict_fn is not None: 314 | self.predict_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.predict, predict_fn) 315 | self.predict = predict_fn 316 | self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN) 317 | if postprocess_fn is not None: 318 | self.postprocess_extra_arg = self.function_extra_arg( 319 | HuggingFaceHandlerService.postprocess, postprocess_fn 320 | ) 321 | self.postprocess = postprocess_fn 322 | self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN) 323 | if transform_fn is not None: 324 | self.transform_extra_arg = self.function_extra_arg( 325 | HuggingFaceHandlerService.transform_fn, transform_fn 326 | ) 327 | self.transform_fn = transform_fn 328 | else: 329 | logger.info( 330 | "No inference script implementation was found at `{}`. Default implementation of all functions will be used.".format( 331 | user_module_name 332 | ) 333 | ) 334 | 335 | @staticmethod 336 | def log_func_implementation_found_or_not(func, func_name): 337 | if func is not None: 338 | logger.info("`{}` implementation found. It will be used in place of the default one.".format(func_name)) 339 | else: 340 | logger.info( 341 | "No `{}` implementation was found. The default one from the handler service will be used.".format( 342 | func_name 343 | ) 344 | ) 345 | 346 | def function_extra_arg(self, default_func, func): 347 | """Helper to call the handler function which covers 2 cases: 348 | 1. the handle function takes context 349 | 2. 
the handle function does not take context 350 | """ 351 | default_params = signature(default_func).parameters 352 | func_params = signature(func).parameters 353 | 354 | if "self" in default_params: 355 | num_default_func_input = len(default_params) - 1 356 | else: 357 | num_default_func_input = len(default_params) 358 | 359 | num_func_input = len(func_params) 360 | if num_default_func_input == num_func_input: 361 | # function takes context 362 | extra_args = [self.context] 363 | elif num_default_func_input == num_func_input + 1: 364 | # function does not take context 365 | extra_args = [] 366 | else: 367 | raise TypeError( 368 | "{} definition takes {} or {} arguments but {} were given.".format( 369 | func.__name__, num_default_func_input - 1, num_default_func_input, num_func_input 370 | ) 371 | ) 372 | return extra_args 373 | -------------------------------------------------------------------------------- /src/sagemaker_huggingface_inference_toolkit/mms_model_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import absolute_import 15 | 16 | import os 17 | import pathlib 18 | import subprocess 19 | 20 | from sagemaker_inference import environment, logging 21 | from sagemaker_inference.model_server import ( 22 | DEFAULT_MMS_LOG_FILE, 23 | DEFAULT_MMS_MODEL_NAME, 24 | ENABLE_MULTI_MODEL, 25 | MMS_CONFIG_FILE, 26 | REQUIREMENTS_PATH, 27 | _add_sigchild_handler, 28 | _add_sigterm_handler, 29 | _create_model_server_config_file, 30 | _install_requirements, 31 | _retry_retrieve_mms_server_process, 32 | _set_python_path, 33 | ) 34 | 35 | from sagemaker_huggingface_inference_toolkit import handler_service 36 | from sagemaker_huggingface_inference_toolkit.optimum_utils import is_optimum_neuron_available 37 | from sagemaker_huggingface_inference_toolkit.transformers_utils import ( 38 | HF_API_TOKEN, 39 | HF_MODEL_REVISION, 40 | _load_model_from_hub, 41 | ) 42 | 43 | 44 | logger = logging.get_logger() 45 | 46 | DEFAULT_HANDLER_SERVICE = handler_service.__name__ 47 | 48 | DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models") 49 | DEFAULT_MODEL_STORE = "/" 50 | 51 | 52 | def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE): 53 | """Configure and start the model server. 54 | 55 | Args: 56 | handler_service (str): python path pointing to a module that defines 57 | a class with the following: 58 | 59 | - A ``handle`` method, which is invoked for all incoming inference 60 | requests to the model server. 61 | - A ``initialize`` method, which is invoked at model server start up 62 | for loading the model. 63 | 64 | Defaults to ``sagemaker_huggingface_inference_toolkit.handler_service``. 
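    Example (illustrative invocation from inside the inference container; setting
    ``HF_MODEL_ID`` makes the server pull the model from the Hugging Face Hub at start up):

        HF_MODEL_ID=distilbert-base-uncased-finetuned-sst-2-english \
        HF_TASK=text-classification \
        python -c "from sagemaker_huggingface_inference_toolkit import serving; serving.main()"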
65 | 66 | """ 67 | use_hf_hub = "HF_MODEL_ID" in os.environ 68 | model_store = DEFAULT_MODEL_STORE 69 | if ENABLE_MULTI_MODEL: 70 | if not os.getenv("SAGEMAKER_HANDLER"): 71 | os.environ["SAGEMAKER_HANDLER"] = handler_service 72 | _set_python_path() 73 | elif use_hf_hub: 74 | # Use different model store directory 75 | model_store = DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY 76 | storage_dir = _load_model_from_hub( 77 | model_id=os.environ["HF_MODEL_ID"], 78 | model_dir=DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY, 79 | revision=HF_MODEL_REVISION, 80 | use_auth_token=HF_API_TOKEN, 81 | ) 82 | _adapt_to_mms_format(handler_service, storage_dir) 83 | else: 84 | _set_python_path() 85 | 86 | env = environment.Environment() 87 | 88 | # Set the number of workers to available number if optimum neuron is available and not already set 89 | if is_optimum_neuron_available() and os.environ.get("SAGEMAKER_MODEL_SERVER_WORKERS", None) is None: 90 | from optimum.neuron.utils.cache_utils import get_num_neuron_cores 91 | 92 | try: 93 | env._model_server_workers = str(get_num_neuron_cores()) 94 | except Exception: 95 | env._model_server_workers = "1" 96 | 97 | # Note: multi-model default config already sets default_service_handler 98 | handler_service_for_config = None if ENABLE_MULTI_MODEL else handler_service 99 | _create_model_server_config_file(env, handler_service_for_config) 100 | 101 | if os.path.exists(REQUIREMENTS_PATH): 102 | _install_requirements() 103 | 104 | multi_model_server_cmd = [ 105 | "multi-model-server", 106 | "--start", 107 | "--model-store", 108 | model_store, 109 | "--mms-config", 110 | MMS_CONFIG_FILE, 111 | "--log-config", 112 | DEFAULT_MMS_LOG_FILE, 113 | ] 114 | if not ENABLE_MULTI_MODEL and not use_hf_hub: 115 | multi_model_server_cmd += ["--models", DEFAULT_MMS_MODEL_NAME + "=" + environment.model_dir] 116 | 117 | logger.info(multi_model_server_cmd) 118 | subprocess.Popen(multi_model_server_cmd) 119 | # retry for configured timeout 120 | mms_process = _retry_retrieve_mms_server_process(env.startup_timeout) 121 | 122 | _add_sigterm_handler(mms_process) 123 | _add_sigchild_handler() 124 | 125 | mms_process.wait() 126 | 127 | 128 | def _adapt_to_mms_format(handler_service, model_path): 129 | os.makedirs(DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY, exist_ok=True) 130 | 131 | # gets the model from the path, default is model/ 132 | model = pathlib.PurePath(model_path) 133 | 134 | # This is archiving or cp /opt/ml/model to /opt/ml (MODEL_STORE) into model (MODEL_NAME) 135 | model_archiver_cmd = [ 136 | "model-archiver", 137 | "--model-name", 138 | model.name, 139 | "--handler", 140 | handler_service, 141 | "--model-path", 142 | model_path, 143 | "--export-path", 144 | DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY, 145 | "--archive-format", 146 | "no-archive", 147 | "--f", 148 | ] 149 | 150 | logger.info(model_archiver_cmd) 151 | subprocess.check_call(model_archiver_cmd) 152 | 153 | _set_python_path() 154 | -------------------------------------------------------------------------------- /src/sagemaker_huggingface_inference_toolkit/optimum_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import importlib.util 16 | import logging 17 | import os 18 | 19 | 20 | _optimum_neuron = False 21 | if importlib.util.find_spec("optimum") is not None: 22 | if importlib.util.find_spec("optimum.neuron") is not None: 23 | _optimum_neuron = True 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def is_optimum_neuron_available(): 29 | return _optimum_neuron 30 | 31 | 32 | def get_input_shapes(model_dir): 33 | """Method to get input shapes from model config file. If config file is not present, default values are returned.""" 34 | from transformers import AutoConfig 35 | 36 | input_shapes = {} 37 | input_shapes_available = False 38 | # try to get input shapes from config file 39 | try: 40 | config = AutoConfig.from_pretrained(model_dir) 41 | if hasattr(config, "neuron"): 42 | # check if static batch size and sequence length are available 43 | if config.neuron.get("static_batch_size", None) and config.neuron.get("static_sequence_length", None): 44 | input_shapes["batch_size"] = config.neuron["static_batch_size"] 45 | input_shapes["sequence_length"] = config.neuron["static_sequence_length"] 46 | input_shapes_available = True 47 | logger.info( 48 | f"Input shapes found in config file. Using input shapes from config with batch size {input_shapes['batch_size']} and sequence length {input_shapes['sequence_length']}" 49 | ) 50 | else: 51 | # Add warning if environment variables are set but will be ignored 52 | if os.environ.get("HF_OPTIMUM_BATCH_SIZE", None) is not None: 53 | logger.warning( 54 | "HF_OPTIMUM_BATCH_SIZE environment variable is set. Environment variable will be ignored and input shapes from config file will be used." 55 | ) 56 | if os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None) is not None: 57 | logger.warning( 58 | "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is set. Environment variable will be ignored and input shapes from config file will be used." 59 | ) 60 | except Exception: 61 | input_shapes_available = False 62 | 63 | # return input shapes if available 64 | if input_shapes_available: 65 | return input_shapes 66 | 67 | # extract input shapes from environment variables 68 | sequence_length = os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None) 69 | if sequence_length is None: 70 | raise ValueError( 71 | "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is not set. Please set HF_OPTIMUM_SEQUENCE_LENGTH to a positive integer." 72 | ) 73 | 74 | if not int(sequence_length) > 0: 75 | raise ValueError( 76 | f"HF_OPTIMUM_SEQUENCE_LENGTH must be set to a positive integer. Current value is {sequence_length}" 77 | ) 78 | batch_size = os.environ.get("HF_OPTIMUM_BATCH_SIZE", 1) 79 | logger.info( 80 | f"Using input shapes from environment variables with batch size {batch_size} and sequence length {sequence_length}" 81 | ) 82 | return {"batch_size": int(batch_size), "sequence_length": int(sequence_length)} 83 | 84 | 85 | def get_optimum_neuron_pipeline(task, model_dir): 86 | """Method to get optimum neuron pipeline for a given task. 
Method checks if task is supported by optimum neuron and if required environment variables are set, in case model is not converted. If all checks pass, optimum neuron pipeline is returned. If checks fail, an error is raised.""" 87 | from optimum.neuron.pipelines.transformers.base import NEURONX_SUPPORTED_TASKS, pipeline 88 | from optimum.neuron.utils import NEURON_FILE_NAME 89 | 90 | # check task support 91 | if task not in NEURONX_SUPPORTED_TASKS: 92 | raise ValueError( 93 | f"Task {task} is not supported by optimum neuron and inf2. Supported tasks are: {list(NEURONX_SUPPORTED_TASKS.keys())}" 94 | ) 95 | 96 | # check if model is already converted and has input shapes available 97 | export = True 98 | if NEURON_FILE_NAME in os.listdir(model_dir): 99 | export = False 100 | if export: 101 | logger.info("Model is not converted. Checking if required environment variables are set and converting model.") 102 | 103 | # get static input shapes to run inference 104 | input_shapes = get_input_shapes(model_dir) 105 | # set NEURON_RT_NUM_CORES to 1 to avoid conflicts with multiple HTTP workers 106 | os.environ["NEURON_RT_NUM_CORES"] = "1" 107 | # get optimum neuron pipeline 108 | neuron_pipe = pipeline(task, model=model_dir, export=export, input_shapes=input_shapes) 109 | 110 | return neuron_pipe 111 | -------------------------------------------------------------------------------- /src/sagemaker_huggingface_inference_toolkit/serving.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from subprocess import CalledProcessError 15 | 16 | from retrying import retry 17 | 18 | from sagemaker_huggingface_inference_toolkit import handler_service, mms_model_server 19 | 20 | 21 | HANDLER_SERVICE = handler_service.__name__ 22 | 23 | 24 | def _retry_if_error(exception): 25 | return isinstance(exception, CalledProcessError or OSError) 26 | 27 | 28 | @retry(stop_max_delay=1000 * 50, retry_on_exception=_retry_if_error) 29 | def _start_mms(): 30 | mms_model_server.start_model_server(handler_service=HANDLER_SERVICE) 31 | 32 | 33 | def main(): 34 | _start_mms() 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /src/sagemaker_huggingface_inference_toolkit/transformers_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
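# Configuration sketch for the optimum-neuron path above: when the model in model_dir has not
# been pre-compiled for Neuron (no `neuron` entry in its config), get_input_shapes() falls back
# to these environment variables (the values shown here are illustrative):
#
#   export HF_OPTIMUM_SEQUENCE_LENGTH=128   # required; must be a positive integer
#   export HF_OPTIMUM_BATCH_SIZE=4          # optional; defaults to 1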
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import importlib.util 15 | import json 16 | import logging 17 | import os 18 | from pathlib import Path 19 | from typing import Optional 20 | 21 | from huggingface_hub import HfApi, login, snapshot_download 22 | from transformers import AutoTokenizer, pipeline 23 | from transformers.file_utils import is_tf_available, is_torch_available 24 | from transformers.pipelines import Pipeline 25 | 26 | from sagemaker_huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, is_diffusers_available 27 | from sagemaker_huggingface_inference_toolkit.optimum_utils import ( 28 | get_optimum_neuron_pipeline, 29 | is_optimum_neuron_available, 30 | ) 31 | 32 | 33 | if is_tf_available(): 34 | import tensorflow as tf 35 | 36 | if is_torch_available(): 37 | import torch 38 | 39 | _aws_neuron_available = importlib.util.find_spec("torch_neuron") is not None 40 | 41 | 42 | def is_aws_neuron_available(): 43 | return _aws_neuron_available 44 | 45 | 46 | def strtobool(val): 47 | """Convert a string representation of truth to True or False. 48 | True values are 'y', 'yes', 't', 'true', 'on', '1', 'TRUE', or 'True'; false values 49 | are 'n', 'no', 'f', 'false', 'off', '0', 'FALSE' or 'False. Raises ValueError if 50 | 'val' is anything else. 51 | """ 52 | val = val.lower() 53 | if val in ("y", "yes", "t", "true", "on", "1", "TRUE", "True"): 54 | return True 55 | elif val in ("n", "no", "f", "false", "off", "0", "FALSE", "False"): 56 | return False 57 | else: 58 | raise ValueError("invalid truth value %r" % (val,)) 59 | 60 | 61 | logger = logging.getLogger(__name__) 62 | 63 | 64 | FRAMEWORK_MAPPING = { 65 | "pytorch": "pytorch*", 66 | "tensorflow": "tf*", 67 | "tf": "tf*", 68 | "pt": "pytorch*", 69 | "flax": "flax*", 70 | "rust": "rust*", 71 | "onnx": "*onnx*", 72 | "safetensors": "*safetensors", 73 | "coreml": "*mlmodel", 74 | "tflite": "*tflite", 75 | "savedmodel": "*tar.gz", 76 | "openvino": "*openvino*", 77 | "ckpt": "*ckpt", 78 | "neuronx": "*neuron", 79 | } 80 | 81 | 82 | REPO_ID_SEPARATOR = "__" 83 | 84 | ARCHITECTURES_2_TASK = { 85 | "TapasForQuestionAnswering": "table-question-answering", 86 | "ForQuestionAnswering": "question-answering", 87 | "ForTokenClassification": "token-classification", 88 | "ForSequenceClassification": "text-classification", 89 | "ForMultipleChoice": "multiple-choice", 90 | "ForMaskedLM": "fill-mask", 91 | "ForCausalLM": "text-generation", 92 | "ForConditionalGeneration": "text2text-generation", 93 | "MTModel": "text2text-generation", 94 | "EncoderDecoderModel": "text2text-generation", 95 | # Model specific task for backward comp 96 | "GPT2LMHeadModel": "text-generation", 97 | "T5WithLMHeadModel": "text2text-generation", 98 | } 99 | 100 | 101 | HF_API_TOKEN = os.environ.get("HF_API_TOKEN", None) 102 | HF_MODEL_REVISION = os.environ.get("HF_MODEL_REVISION", None) 103 | TRUST_REMOTE_CODE = strtobool(os.environ.get("HF_TRUST_REMOTE_CODE", "False")) 104 | 105 | 106 | def create_artifact_filter(framework): 107 | """ 108 | Returns a list of regex pattern based on the DL Framework. 
which will be to used to ignore files when downloading 109 | """ 110 | ignore_regex_list = list(set(FRAMEWORK_MAPPING.values())) 111 | 112 | pattern = FRAMEWORK_MAPPING.get(framework, None) 113 | if pattern in ignore_regex_list: 114 | ignore_regex_list.remove(pattern) 115 | return ignore_regex_list 116 | else: 117 | return [] 118 | 119 | 120 | def _is_gpu_available(): 121 | """ 122 | checks if a gpu is available. 123 | """ 124 | if is_tf_available(): 125 | return True if len(tf.config.list_physical_devices("GPU")) > 0 else False 126 | elif is_torch_available(): 127 | return torch.cuda.is_available() 128 | else: 129 | raise RuntimeError( 130 | "At least one of TensorFlow 2.0 or PyTorch should be installed. " 131 | "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " 132 | "To install PyTorch, read the instructions at https://pytorch.org/." 133 | ) 134 | 135 | 136 | def _get_framework(): 137 | """ 138 | extracts which DL framework is used for inference, if both are installed use pytorch 139 | """ 140 | if is_torch_available(): 141 | return "pytorch" 142 | elif is_tf_available(): 143 | return "tensorflow" 144 | else: 145 | raise RuntimeError( 146 | "At least one of TensorFlow 2.0 or PyTorch should be installed. " 147 | "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " 148 | "To install PyTorch, read the instructions at https://pytorch.org/." 149 | ) 150 | 151 | 152 | def _build_storage_path(model_id: str, model_dir: Path, revision: Optional[str] = None): 153 | """ 154 | creates storage path for hub model based on model_id and revision 155 | """ 156 | if "/" and revision is None: 157 | storage_path = os.path.join(model_dir, model_id.replace("/", REPO_ID_SEPARATOR)) 158 | elif "/" and revision is not None: 159 | storage_path = os.path.join(model_dir, model_id.replace("/", REPO_ID_SEPARATOR) + "." + revision) 160 | elif revision is not None: 161 | storage_path = os.path.join(model_dir, model_id + "." + revision) 162 | else: 163 | storage_path = os.path.join(model_dir, model_id) 164 | return storage_path 165 | 166 | 167 | def _load_model_from_hub( 168 | model_id: str, model_dir: Path, revision: Optional[str] = None, use_auth_token: Optional[str] = None 169 | ): 170 | """ 171 | Downloads a model repository at the specified revision from the Hugging Face Hub. 172 | All files are nested inside a folder in order to keep their actual filename 173 | relative to that folder. `org__model.revision` 174 | """ 175 | logger.warn( 176 | "This is an experimental beta features, which allows downloading model from the Hugging Face Hub on start up. 
" 177 | "It loads the model defined in the env var `HF_MODEL_ID`" 178 | ) 179 | if use_auth_token is not None: 180 | login(token=use_auth_token) 181 | # extracts base framework 182 | framework = _get_framework() 183 | 184 | # creates directory for saved model based on revision and model 185 | storage_folder = _build_storage_path(model_id, model_dir, revision) 186 | os.makedirs(storage_folder, exist_ok=True) 187 | 188 | # check if safetensors weights are available 189 | if framework == "pytorch": 190 | files = HfApi().model_info(model_id).siblings 191 | if is_optimum_neuron_available() and any(f.rfilename.endswith("neuron") for f in files): 192 | framework = "neuronx" 193 | elif any(f.rfilename.endswith("safetensors") for f in files): 194 | framework = "safetensors" 195 | 196 | # create regex to only include the framework specific weights 197 | ignore_regex = create_artifact_filter(framework) 198 | 199 | # Download the repository to the workdir and filter out non-framework specific weights 200 | snapshot_download( 201 | model_id, 202 | revision=revision, 203 | local_dir=str(storage_folder), 204 | local_dir_use_symlinks=False, 205 | ignore_patterns=ignore_regex, 206 | ) 207 | 208 | return storage_folder 209 | 210 | 211 | def infer_task_from_model_architecture(model_config_path: str, architecture_index=0) -> str: 212 | """ 213 | Infer task from `config.json` of trained model. It is not guaranteed to the detect, e.g. some models implement multiple architectures or 214 | trainend on different tasks https://huggingface.co/facebook/bart-large/blob/main/config.json. Should work for every on Amazon SageMaker fine-tuned model. 215 | It is always recommended to set the task through the env var `TASK`. 216 | """ 217 | with open(model_config_path, "r") as config_file: 218 | config = json.loads(config_file.read()) 219 | architecture = config.get("architectures", [None])[architecture_index] 220 | 221 | task = None 222 | for arch_options in ARCHITECTURES_2_TASK: 223 | if architecture.endswith(arch_options): 224 | task = ARCHITECTURES_2_TASK[arch_options] 225 | 226 | if task is None: 227 | raise ValueError( 228 | f"Task couldn't be inferenced from {architecture}." 229 | f"Inference Toolkit can only inference tasks from architectures ending with {list(ARCHITECTURES_2_TASK.keys())}." 230 | "Use env `HF_TASK` to define your task." 231 | ) 232 | # set env to work with 233 | os.environ["HF_TASK"] = task 234 | return task 235 | 236 | 237 | def infer_task_from_hub(model_id: str, revision: Optional[str] = None, use_auth_token: Optional[str] = None) -> str: 238 | """ 239 | Infer task from Hub by extracting `pipeline_tag` for model_info. 240 | """ 241 | _api = HfApi() 242 | model_info = _api.model_info(repo_id=model_id, revision=revision, token=use_auth_token) 243 | if model_info.pipeline_tag is not None: 244 | # set env to work with 245 | os.environ["HF_TASK"] = model_info.pipeline_tag 246 | return model_info.pipeline_tag 247 | else: 248 | raise ValueError( 249 | f"Task couldn't be inferenced from {model_info.pipeline_tag}." "Use env `HF_TASK` to define your task." 
250 | ) 251 | 252 | 253 | def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline: 254 | """ 255 | create pipeline class for a specific task based on local saved model 256 | """ 257 | if task is None: 258 | raise EnvironmentError( 259 | "The task for this model is not set: Please set one: https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined" 260 | ) 261 | # define tokenizer or feature extractor as kwargs to load it the pipeline correctly 262 | if task in { 263 | "automatic-speech-recognition", 264 | "image-segmentation", 265 | "image-classification", 266 | "audio-classification", 267 | "object-detection", 268 | "zero-shot-image-classification", 269 | }: 270 | kwargs["feature_extractor"] = model_dir 271 | else: 272 | kwargs["tokenizer"] = model_dir 273 | # check if optimum neuron is available and tries to load it 274 | if is_optimum_neuron_available(): 275 | hf_pipeline = get_optimum_neuron_pipeline(task=task, model_dir=model_dir) 276 | elif TRUST_REMOTE_CODE and os.environ.get("HF_MODEL_ID", None) is not None and device == 0: 277 | tokenizer = AutoTokenizer.from_pretrained(os.environ["HF_MODEL_ID"]) 278 | 279 | hf_pipeline = pipeline( 280 | task=task, 281 | model=os.environ["HF_MODEL_ID"], 282 | tokenizer=tokenizer, 283 | trust_remote_code=TRUST_REMOTE_CODE, 284 | model_kwargs={"device_map": "auto", "torch_dtype": "auto"}, 285 | ) 286 | elif is_diffusers_available() and task == "text-to-image": 287 | hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs) 288 | else: 289 | # load pipeline 290 | hf_pipeline = pipeline( 291 | task=task, model=model_dir, device=device, trust_remote_code=TRUST_REMOTE_CODE, **kwargs 292 | ) 293 | 294 | return hf_pipeline 295 | -------------------------------------------------------------------------------- /tests/integ/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/integ/__init__.py -------------------------------------------------------------------------------- /tests/integ/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from integ.utils import ( 4 | validate_automatic_speech_recognition, 5 | validate_classification, 6 | validate_feature_extraction, 7 | validate_fill_mask, 8 | validate_ner, 9 | validate_question_answering, 10 | validate_summarization, 11 | validate_text2text_generation, 12 | validate_text_generation, 13 | validate_translation, 14 | validate_zero_shot_classification, 15 | ) 16 | 17 | 18 | task2model = { 19 | "text-classification": { 20 | "pytorch": "distilbert-base-uncased-finetuned-sst-2-english", 21 | "tensorflow": "distilbert-base-uncased-finetuned-sst-2-english", 22 | }, 23 | "zero-shot-classification": { 24 | "pytorch": "joeddav/xlm-roberta-large-xnli", 25 | "tensorflow": None, 26 | }, 27 | "feature-extraction": { 28 | "pytorch": "bert-base-uncased", 29 | "tensorflow": "bert-base-uncased", 30 | }, 31 | "ner": { 32 | "pytorch": "dbmdz/bert-large-cased-finetuned-conll03-english", 33 | "tensorflow": "dbmdz/bert-large-cased-finetuned-conll03-english", 34 | }, 35 | "question-answering": { 36 | "pytorch": "distilbert-base-uncased-distilled-squad", 37 | "tensorflow": "distilbert-base-uncased-distilled-squad", 38 | }, 39 | "fill-mask": { 40 | "pytorch": "albert-base-v2", 41 | "tensorflow": "albert-base-v2", 42 | }, 43 
| "summarization": { 44 | "pytorch": "sshleifer/distilbart-xsum-1-1", 45 | "tensorflow": "sshleifer/distilbart-xsum-1-1", 46 | }, 47 | "translation_xx_to_yy": { 48 | "pytorch": "Helsinki-NLP/opus-mt-en-de", 49 | "tensorflow": "Helsinki-NLP/opus-mt-en-de", 50 | }, 51 | "text2text-generation": { 52 | "pytorch": "t5-small", 53 | "tensorflow": "t5-small", 54 | }, 55 | "text-generation": { 56 | "pytorch": "gpt2", 57 | "tensorflow": "gpt2", 58 | }, 59 | "image-classification": { 60 | "pytorch": "google/vit-base-patch16-224", 61 | "tensorflow": "google/vit-base-patch16-224", 62 | }, 63 | "automatic-speech-recognition": { 64 | "pytorch": "facebook/wav2vec2-base-100h", 65 | "tensorflow": "facebook/wav2vec2-base-960h", 66 | }, 67 | } 68 | 69 | task2input = { 70 | "text-classification": {"inputs": "I love you. I like you"}, 71 | "zero-shot-classification": { 72 | "inputs": "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!", 73 | "parameters": {"candidate_labels": ["refund", "legal", "faq"]}, 74 | }, 75 | "feature-extraction": {"inputs": "What is the best book."}, 76 | "ner": {"inputs": "My name is Wolfgang and I live in Berlin"}, 77 | "question-answering": { 78 | "inputs": { 79 | "question": "What is used for inference?", 80 | "context": "My Name is Philipp and I live in Nuremberg. This model is used with sagemaker for inference.", 81 | } 82 | }, 83 | "fill-mask": {"inputs": "Paris is the [MASK] of France."}, 84 | "summarization": { 85 | "inputs": "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct." 86 | }, 87 | "translation_xx_to_yy": {"inputs": "My name is Sarah and I live in London"}, 88 | "text2text-generation": { 89 | "inputs": "question: What is 42 context: 42 is the answer to life, the universe and everything." 
90 | }, 91 | "text-generation": {"inputs": "My name is philipp and I am"}, 92 | "image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(), 93 | "automatic-speech-recognition": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(), 94 | } 95 | 96 | task2output = { 97 | "text-classification": [{"label": "POSITIVE", "score": 0.99}], 98 | "zero-shot-classification": { 99 | "sequence": "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!", 100 | "labels": ["refund", "faq", "legal"], 101 | "scores": [0.96, 0.027, 0.008], 102 | }, 103 | "ner": [ 104 | {"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19}, 105 | {"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40}, 106 | ], 107 | "question-answering": {"score": 0.99, "start": 68, "end": 77, "answer": "sagemaker"}, 108 | "summarization": [{"summary_text": " The A The The ANew York City has been installed in the US."}], 109 | "translation_xx_to_yy": [{"translation_text": "Mein Name ist Sarah und ich lebe in London"}], 110 | "text2text-generation": [{"generated_text": "42 is the answer to life, the universe and everything"}], 111 | "feature-extraction": None, 112 | "fill-mask": None, 113 | "text-generation": None, 114 | "image-classification": [ 115 | {"score": 0.8858247399330139, "label": "tiger, Panthera tigris"}, 116 | {"score": 0.10940514504909515, "label": "tiger cat"}, 117 | {"score": 0.0006216464680619538, "label": "jaguar, panther, Panthera onca, Felis onca"}, 118 | {"score": 0.0004262699221726507, "label": "dhole, Cuon alpinus"}, 119 | {"score": 0.00030842673731967807, "label": "lion, king of beasts, Panthera leo"}, 120 | ], 121 | "automatic-speech-recognition": { 122 | "text": "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP OAUDIENCES IN DROFTY SCHOOL ROOMS DAY AFTER DAY FOR A FORT NIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS" 123 | }, 124 | } 125 | 126 | task2performance = { 127 | "text-classification": { 128 | "cpu": { 129 | "average_request_time": 4, 130 | }, 131 | "gpu": { 132 | "average_request_time": 1, 133 | }, 134 | }, 135 | "zero-shot-classification": { 136 | "cpu": { 137 | "average_request_time": 4, 138 | }, 139 | "gpu": { 140 | "average_request_time": 4, 141 | }, 142 | }, 143 | "feature-extraction": { 144 | "cpu": { 145 | "average_request_time": 4, 146 | }, 147 | "gpu": { 148 | "average_request_time": 1, 149 | }, 150 | }, 151 | "ner": { 152 | "cpu": { 153 | "average_request_time": 4, 154 | }, 155 | "gpu": { 156 | "average_request_time": 1, 157 | }, 158 | }, 159 | "question-answering": { 160 | "cpu": { 161 | "average_request_time": 4, 162 | }, 163 | "gpu": { 164 | "average_request_time": 4, 165 | }, 166 | }, 167 | "fill-mask": { 168 | "cpu": { 169 | "average_request_time": 4, 170 | }, 171 | "gpu": { 172 | "average_request_time": 3, 173 | }, 174 | }, 175 | "summarization": { 176 | "cpu": { 177 | "average_request_time": 26, 178 | }, 179 | "gpu": { 180 | "average_request_time": 3, 181 | }, 182 | }, 183 | "translation_xx_to_yy": { 184 | "cpu": { 185 | "average_request_time": 8, 186 | }, 187 | "gpu": { 188 | "average_request_time": 3, 189 | }, 190 | }, 191 | "text2text-generation": { 192 | "cpu": { 193 | "average_request_time": 4, 194 | }, 195 | "gpu": { 196 | "average_request_time": 3, 197 | }, 198 | }, 199 | "text-generation": { 
200 | "cpu": { 201 | "average_request_time": 15, 202 | }, 203 | "gpu": { 204 | "average_request_time": 3, 205 | }, 206 | }, 207 | "image-classification": { 208 | "cpu": { 209 | "average_request_time": 4, 210 | }, 211 | "gpu": { 212 | "average_request_time": 1, 213 | }, 214 | }, 215 | "automatic-speech-recognition": { 216 | "cpu": { 217 | "average_request_time": 6, 218 | }, 219 | "gpu": { 220 | "average_request_time": 6, 221 | }, 222 | }, 223 | } 224 | 225 | task2validation = { 226 | "text-classification": validate_classification, 227 | "zero-shot-classification": validate_zero_shot_classification, 228 | "feature-extraction": validate_feature_extraction, 229 | "ner": validate_ner, 230 | "question-answering": validate_question_answering, 231 | "fill-mask": validate_fill_mask, 232 | "summarization": validate_summarization, 233 | "translation_xx_to_yy": validate_translation, 234 | "text2text-generation": validate_text2text_generation, 235 | "text-generation": validate_text_generation, 236 | "image-classification": validate_classification, 237 | "automatic-speech-recognition": validate_automatic_speech_recognition, 238 | } 239 | -------------------------------------------------------------------------------- /tests/integ/test_diffusers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from io import BytesIO 4 | 5 | import boto3 6 | from integ.utils import clean_up, timeout_and_delete_by_name 7 | from PIL import Image 8 | from sagemaker import Session 9 | from sagemaker.model import Model 10 | 11 | 12 | os.environ["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION", "us-east-1") 13 | SAGEMAKER_EXECUTION_ROLE = os.environ.get("SAGEMAKER_EXECUTION_ROLE", "sagemaker_execution_role") 14 | 15 | 16 | def get_framework_ecr_image(registry_id="763104351884", repository_name="huggingface-pytorch-inference", device="cpu"): 17 | client = boto3.client("ecr") 18 | 19 | def get_all_ecr_images(registry_id, repository_name, result_key): 20 | response = client.list_images( 21 | registryId=registry_id, 22 | repositoryName=repository_name, 23 | ) 24 | results = response[result_key] 25 | while "nextToken" in response: 26 | response = client.list_images( 27 | registryId=registry_id, 28 | nextToken=response["nextToken"], 29 | repositoryName=repository_name, 30 | ) 31 | results.extend(response[result_key]) 32 | return results 33 | 34 | images = get_all_ecr_images(registry_id=registry_id, repository_name=repository_name, result_key="imageIds") 35 | image_tags = [image["imageTag"] for image in images] 36 | image_regex = re.compile("\d\.\d\.\d-" + device + "-.{4}$") 37 | tag = sorted(list(filter(image_regex.match, image_tags)), reverse=True)[0] 38 | return f"{registry_id}.dkr.ecr.{os.environ.get('AWS_DEFAULT_REGION','us-east-1')}.amazonaws.com/{repository_name}:{tag}" 39 | 40 | 41 | # TODO: needs existing container 42 | def test_text_to_image_model(): 43 | image_uri = get_framework_ecr_image(repository_name="huggingface-pytorch-inference", device="gpu") 44 | 45 | name = "hf-test-text-to-image" 46 | task = "text-to-image" 47 | model = "echarlaix/tiny-random-stable-diffusion-xl" 48 | # instance_type = "ml.m5.large" if device == "cpu" else "ml.g4dn.xlarge" 49 | instance_type = "local_gpu" 50 | env = {"HF_MODEL_ID": model, "HF_TASK": task} 51 | 52 | sagemaker_session = Session() 53 | client = boto3.client("sagemaker-runtime") 54 | 55 | hf_model = Model( 56 | image_uri=image_uri, # A Docker image URI. 
57 | model_data=None, # The S3 location of a SageMaker model data .tar.gz 58 | env=env, # Environment variables to run with image_uri when hosted in SageMaker (default: None). 59 | role=SAGEMAKER_EXECUTION_ROLE, # An AWS IAM role (either name or full ARN). 60 | name=name, # The model name 61 | sagemaker_session=sagemaker_session, 62 | ) 63 | 64 | with timeout_and_delete_by_name(name, sagemaker_session, minutes=59): 65 | # Use accelerator type to differentiate EI vs. CPU and GPU. Don't use processor value 66 | hf_model.deploy( 67 | initial_instance_count=1, 68 | instance_type=instance_type, 69 | endpoint_name=name, 70 | ) 71 | response = client.invoke_endpoint( 72 | EndpointName=name, 73 | Body={"inputs": "a yellow lemon tree"}, 74 | ContentType="application/json", 75 | Accept="image/png", 76 | ) 77 | 78 | # validate response 79 | response_body = response["Body"].read().decode("utf-8") 80 | 81 | img = Image.open(BytesIO(response_body)) 82 | assert isinstance(img, Image.Image) 83 | 84 | clean_up(endpoint_name=name, sagemaker_session=sagemaker_session) 85 | -------------------------------------------------------------------------------- /tests/integ/test_models_from_hub.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | import boto3 9 | from integ.config import task2input, task2model, task2output, task2performance, task2validation 10 | from integ.utils import clean_up, timeout_and_delete_by_name, track_infer_time 11 | from sagemaker import Session 12 | from sagemaker.model import Model 13 | 14 | 15 | os.environ["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION", "us-east-1") 16 | SAGEMAKER_EXECUTION_ROLE = os.environ.get("SAGEMAKER_EXECUTION_ROLE", "sagemaker_execution_role") 17 | 18 | 19 | def get_framework_ecr_image(registry_id="763104351884", repository_name="huggingface-pytorch-inference", device="cpu"): 20 | client = boto3.client("ecr") 21 | 22 | def get_all_ecr_images(registry_id, repository_name, result_key): 23 | response = client.list_images( 24 | registryId=registry_id, 25 | repositoryName=repository_name, 26 | ) 27 | results = response[result_key] 28 | while "nextToken" in response: 29 | response = client.list_images( 30 | registryId=registry_id, 31 | nextToken=response["nextToken"], 32 | repositoryName=repository_name, 33 | ) 34 | results.extend(response[result_key]) 35 | return results 36 | 37 | images = get_all_ecr_images(registry_id=registry_id, repository_name=repository_name, result_key="imageIds") 38 | image_tags = [image["imageTag"] for image in images] 39 | image_regex = re.compile("\d\.\d\.\d-" + device + "-.{4}$") 40 | tag = sorted(list(filter(image_regex.match, image_tags)), reverse=True)[0] 41 | return f"{registry_id}.dkr.ecr.{os.environ.get('AWS_DEFAULT_REGION','us-east-1')}.amazonaws.com/{repository_name}:{tag}" 42 | 43 | 44 | @pytest.mark.parametrize( 45 | "task", 46 | [ 47 | "text-classification", 48 | "zero-shot-classification", 49 | "ner", 50 | "question-answering", 51 | "fill-mask", 52 | "summarization", 53 | "translation_xx_to_yy", 54 | "text2text-generation", 55 | "text-generation", 56 | "feature-extraction", 57 | "image-classification", 58 | "automatic-speech-recognition", 59 | ], 60 | ) 61 | @pytest.mark.parametrize( 62 | "framework", 63 | ["pytorch", "tensorflow"], 64 | ) 65 | @pytest.mark.parametrize( 66 | "device", 67 | [ 68 | "gpu", 69 | "cpu", 70 | ], 71 | ) 72 | def test_deployment_from_hub(task, device, 
framework): 73 | image_uri = get_framework_ecr_image(repository_name=f"huggingface-{framework}-inference", device=device) 74 | name = f"hf-test-{framework}-{device}-{task}".replace("_", "-") 75 | model = task2model[task][framework] 76 | # instance_type = "ml.m5.large" if device == "cpu" else "ml.g4dn.xlarge" 77 | instance_type = "local" if device == "cpu" else "local_gpu" 78 | number_of_requests = 100 79 | if model is None: 80 | return 81 | 82 | env = {"HF_MODEL_ID": model, "HF_TASK": task} 83 | 84 | sagemaker_session = Session() 85 | client = boto3.client("sagemaker-runtime") 86 | 87 | hf_model = Model( 88 | image_uri=image_uri, # A Docker image URI. 89 | model_data=None, # The S3 location of a SageMaker model data .tar.gz 90 | env=env, # Environment variables to run with image_uri when hosted in SageMaker (default: None). 91 | role=SAGEMAKER_EXECUTION_ROLE, # An AWS IAM role (either name or full ARN). 92 | name=name, # The model name 93 | sagemaker_session=sagemaker_session, 94 | ) 95 | 96 | with timeout_and_delete_by_name(name, sagemaker_session, minutes=59): 97 | # Use accelerator type to differentiate EI vs. CPU and GPU. Don't use processor value 98 | hf_model.deploy( 99 | initial_instance_count=1, 100 | instance_type=instance_type, 101 | endpoint_name=name, 102 | ) 103 | 104 | # Keep track of the inference time 105 | time_buffer = [] 106 | 107 | # Warm up the model 108 | if task == "image-classification": 109 | response = client.invoke_endpoint( 110 | EndpointName=name, 111 | Body=task2input[task], 112 | ContentType="image/jpeg", 113 | Accept="application/json", 114 | ) 115 | elif task == "automatic-speech-recognition": 116 | response = client.invoke_endpoint( 117 | EndpointName=name, 118 | Body=task2input[task], 119 | ContentType="audio/x-flac", 120 | Accept="application/json", 121 | ) 122 | else: 123 | response = client.invoke_endpoint( 124 | EndpointName=name, 125 | Body=json.dumps(task2input[task]), 126 | ContentType="application/json", 127 | Accept="application/json", 128 | ) 129 | 130 | # validate response 131 | response_body = response["Body"].read().decode("utf-8") 132 | 133 | assert True is task2validation[task](result=json.loads(response_body), snapshot=task2output[task]) 134 | 135 | for _ in range(number_of_requests): 136 | with track_infer_time(time_buffer): 137 | if task == "image-classification": 138 | response = client.invoke_endpoint( 139 | EndpointName=name, 140 | Body=task2input[task], 141 | ContentType="image/jpeg", 142 | Accept="application/json", 143 | ) 144 | elif task == "automatic-speech-recognition": 145 | response = client.invoke_endpoint( 146 | EndpointName=name, 147 | Body=task2input[task], 148 | ContentType="audio/x-flac", 149 | Accept="application/json", 150 | ) 151 | else: 152 | response = client.invoke_endpoint( 153 | EndpointName=name, 154 | Body=json.dumps(task2input[task]), 155 | ContentType="application/json", 156 | Accept="application/json", 157 | ) 158 | with open(f"{name}.json", "w") as outfile: 159 | data = { 160 | "index": name, 161 | "framework": framework, 162 | "device": device, 163 | "model": model, 164 | "number_of_requests": number_of_requests, 165 | "average_request_time": np.mean(time_buffer), 166 | "max_request_time": max(time_buffer), 167 | "min_request_time": min(time_buffer), 168 | "p95_request_time": np.percentile(time_buffer, 95), 169 | "body": json.loads(response_body), 170 | } 171 | json.dump(data, outfile) 172 | 173 | assert task2performance[task][device]["average_request_time"] >= np.mean(time_buffer) 174 | 175 | 
clean_up(endpoint_name=name, sagemaker_session=sagemaker_session) 176 | -------------------------------------------------------------------------------- /tests/integ/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import logging 4 | import re 5 | import signal 6 | from contextlib import contextmanager 7 | from time import time 8 | 9 | from botocore.exceptions import ClientError 10 | 11 | 12 | LOGGER = logging.getLogger("timeout") 13 | 14 | 15 | class TimeoutError(Exception): 16 | pass 17 | 18 | 19 | def clean_up(endpoint_name, sagemaker_session): 20 | try: 21 | sagemaker_session.delete_endpoint(endpoint_name) 22 | sagemaker_session.delete_endpoint_config(endpoint_name) 23 | sagemaker_session.delete_model(endpoint_name) 24 | LOGGER.info("deleted endpoint {}".format(endpoint_name)) 25 | except ClientError as ce: 26 | if ce.response["Error"]["Code"] == "ValidationException": 27 | # avoids the inner exception to be overwritten 28 | pass 29 | 30 | 31 | @contextmanager 32 | def timeout(seconds=0, minutes=0, hours=0): 33 | """Add a signal-based timeout to any block of code. 34 | If multiple time units are specified, they will be added together to determine time limit. 35 | Usage: 36 | with timeout(seconds=5): 37 | my_slow_function(...) 38 | Args: 39 | - seconds: The time limit, in seconds. 40 | - minutes: The time limit, in minutes. 41 | - hours: The time limit, in hours. 42 | """ 43 | 44 | limit = seconds + 60 * minutes + 3600 * hours 45 | 46 | def handler(signum, frame): 47 | raise TimeoutError("timed out after {} seconds".format(limit)) 48 | 49 | try: 50 | signal.signal(signal.SIGALRM, handler) 51 | signal.alarm(limit) 52 | 53 | yield 54 | finally: 55 | signal.alarm(0) 56 | 57 | 58 | @contextmanager 59 | def timeout_and_delete_by_name(endpoint_name, sagemaker_session, seconds=0, minutes=0, hours=0): 60 | with timeout(seconds=seconds, minutes=minutes, hours=hours) as t: 61 | try: 62 | yield [t] 63 | finally: 64 | clean_up(endpoint_name, sagemaker_session) 65 | 66 | 67 | @contextmanager 68 | def track_infer_time(buffer=[]): 69 | start = time() 70 | yield 71 | end = time() 72 | 73 | buffer.append(end - start) 74 | 75 | 76 | _re_word_boundaries = re.compile(r"\b") 77 | 78 | 79 | def count_tokens(inputs: dict, task: str) -> int: 80 | if task == "question-answering": 81 | context_len = len(_re_word_boundaries.findall(inputs["context"])) >> 1 82 | question_len = len(_re_word_boundaries.findall(inputs["question"])) >> 1 83 | return question_len + context_len 84 | else: 85 | return len(_re_word_boundaries.findall(inputs)) >> 1 86 | 87 | 88 | def validate_classification(result=None, snapshot=None): 89 | for idx, _ in enumerate(result): 90 | assert result[idx].keys() == snapshot[idx].keys() 91 | assert result[idx]["score"] >= snapshot[idx]["score"] 92 | return True 93 | 94 | 95 | def validate_zero_shot_classification(result=None, snapshot=None): 96 | assert result.keys() == snapshot.keys() 97 | assert result["labels"] == snapshot["labels"] 98 | assert result["sequence"] == snapshot["sequence"] 99 | for idx in range(len(result["scores"])): 100 | assert result["scores"][idx] >= snapshot["scores"][idx] 101 | return True 102 | 103 | 104 | def validate_ner(result=None, snapshot=None): 105 | for idx, _ in enumerate(result): 106 | assert result[idx].keys() == snapshot[idx].keys() 107 | assert result[idx]["score"] >= snapshot[idx]["score"] 108 | assert result[idx]["entity"] == snapshot[idx]["entity"] 109 | assert 
result[idx]["entity"] == snapshot[idx]["entity"] 110 | return True 111 | 112 | 113 | def validate_question_answering(result=None, snapshot=None): 114 | assert result.keys() == snapshot.keys() 115 | assert result["answer"] == snapshot["answer"] 116 | assert result["score"] >= snapshot["score"] 117 | return True 118 | 119 | 120 | def validate_summarization(result=None, snapshot=None): 121 | assert result == snapshot 122 | return True 123 | 124 | 125 | def validate_text2text_generation(result=None, snapshot=None): 126 | assert result == snapshot 127 | return True 128 | 129 | 130 | def validate_translation(result=None, snapshot=None): 131 | assert result == snapshot 132 | return True 133 | 134 | 135 | def validate_text_generation(result=None, snapshot=None): 136 | assert result is not None 137 | return True 138 | 139 | 140 | def validate_feature_extraction(result=None, snapshot=None): 141 | assert result is not None 142 | return True 143 | 144 | 145 | def validate_fill_mask(result=None, snapshot=None): 146 | assert result is not None 147 | return True 148 | 149 | 150 | def validate_automatic_speech_recognition(result=None, snapshot=None): 151 | assert result is not None 152 | assert "text" in result 153 | return True 154 | -------------------------------------------------------------------------------- /tests/resources/audio/sample1.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/audio/sample1.flac -------------------------------------------------------------------------------- /tests/resources/audio/sample1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/audio/sample1.mp3 -------------------------------------------------------------------------------- /tests/resources/audio/sample1.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/audio/sample1.ogg -------------------------------------------------------------------------------- /tests/resources/audio/sample1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/audio/sample1.wav -------------------------------------------------------------------------------- /tests/resources/image/tiger.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.bmp -------------------------------------------------------------------------------- /tests/resources/image/tiger.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.gif -------------------------------------------------------------------------------- /tests/resources/image/tiger.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.jpeg -------------------------------------------------------------------------------- /tests/resources/image/tiger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.png -------------------------------------------------------------------------------- /tests/resources/image/tiger.tiff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.tiff -------------------------------------------------------------------------------- /tests/resources/image/tiger.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.webp -------------------------------------------------------------------------------- /tests/resources/model_input_predict_output_fn_with_context/code/inference.py: -------------------------------------------------------------------------------- 1 | def model_fn(model_dir, context=None): 2 | return "model" 3 | 4 | 5 | def input_fn(data, content_type, context=None): 6 | return "data" 7 | 8 | 9 | def predict_fn(data, model, context=None): 10 | return "output" 11 | 12 | 13 | def output_fn(prediction, accept, context=None): 14 | return prediction 15 | -------------------------------------------------------------------------------- /tests/resources/model_input_predict_output_fn_without_context/code/inference.py: -------------------------------------------------------------------------------- 1 | def model_fn(model_dir): 2 | return "model" 3 | 4 | 5 | def input_fn(data, content_type): 6 | return "data" 7 | 8 | 9 | def predict_fn(data, model): 10 | return "output" 11 | 12 | 13 | def output_fn(prediction, accept): 14 | return prediction 15 | -------------------------------------------------------------------------------- /tests/resources/model_transform_fn_with_context/code/inference_tranform_fn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def model_fn(model_dir, context=None): 5 | return f"Loading {os.path.basename(__file__)}" 6 | 7 | 8 | def transform_fn(a, b, c, d, context=None): 9 | return f"output {b}" 10 | -------------------------------------------------------------------------------- /tests/resources/model_transform_fn_without_context/code/inference_tranform_fn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def model_fn(model_dir): 5 | return f"Loading {os.path.basename(__file__)}" 6 | 7 | 8 | def transform_fn(a, b, c, d): 9 | return f"output {b}" 10 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/test_decoder_encoder.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import json 15 | import os 16 | 17 | import numpy as np 18 | import pytest 19 | from transformers.testing_utils import require_torch 20 | 21 | from mms.service import PredictionException 22 | from PIL import Image 23 | from sagemaker_huggingface_inference_toolkit import decoder_encoder 24 | 25 | 26 | ENCODE_JSON_INPUT = {"upper": [1425], "lower": [576], "level": [2], "datetime": ["2012-08-08 15:30"]} 27 | ENCODE_CSV_INPUT = [ 28 | {"answer": "Nuremberg", "end": 42, "score": 0.9926825761795044, "start": 33}, 29 | {"answer": "Berlin is the capital of Germany", "end": 32, "score": 0.26097726821899414, "start": 0}, 30 | ] 31 | ENCODE_TOLOIST_INPUT = [1, 0.5, 5.0] 32 | 33 | DECODE_JSON_INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"} 34 | DECODE_CSV_INPUT = "question,context\r\nwhere do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany" 35 | 36 | CONTENT_TYPE = "application/json" 37 | 38 | 39 | def test_decode_json(): 40 | decoded_data = decoder_encoder.decode_json(json.dumps(DECODE_JSON_INPUT)) 41 | assert decoded_data == DECODE_JSON_INPUT 42 | 43 | 44 | def test_decode_csv(): 45 | decoded_data = decoder_encoder.decode_csv(DECODE_CSV_INPUT) 46 | assert decoded_data == { 47 | "inputs": [ 48 | {"question": "where do i live?", "context": "My name is Philipp and I live in Nuremberg"}, 49 | {"question": "where is Berlin?", "context": "Berlin is the capital of Germany"}, 50 | ] 51 | } 52 | text_classification_input = "inputs\r\nI love you\r\nI like you" 53 | decoded_data = decoder_encoder.decode_csv(text_classification_input) 54 | assert decoded_data == {"inputs": ["I love you", "I like you"]} 55 | 56 | 57 | def test_decode_image(): 58 | image_files_path = os.path.join(os.getcwd(), "tests/resources/image") 59 | 60 | for image_file in os.listdir(image_files_path): 61 | image_bytes = open(os.path.join(image_files_path, image_file), "rb").read() 62 | decoded_data = decoder_encoder.decode_image(bytearray(image_bytes)) 63 | 64 | assert isinstance(decoded_data, dict) 65 | assert isinstance(decoded_data["inputs"], Image.Image) 66 | 67 | 68 | def test_decode_audio(): 69 | audio_files_path = os.path.join(os.getcwd(), "tests/resources/audio") 70 | 71 | for audio_file in os.listdir(audio_files_path): 72 | audio_bytes = open(os.path.join(audio_files_path, audio_file), "rb").read() 73 | decoded_data = decoder_encoder.decode_audio(bytearray(audio_bytes)) 74 | 75 | assert {"inputs": audio_bytes} == decoded_data 76 | 77 | 78 | def test_decode_csv_without_header(): 79 | with pytest.raises(PredictionException): 80 | decoder_encoder.decode_csv( 81 | "where do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany" 82 | ) 83 | 84 | 85 | 
def test_encode_json(): 86 | encoded_data = decoder_encoder.encode_json(ENCODE_JSON_INPUT) 87 | assert json.loads(encoded_data) == ENCODE_JSON_INPUT 88 | 89 | 90 | @require_torch 91 | def test_encode_json_torch(): 92 | import torch 93 | 94 | encoded_data = decoder_encoder.encode_json({"data": torch.tensor(ENCODE_TOLOIST_INPUT)}) 95 | assert json.loads(encoded_data) == {"data": ENCODE_TOLOIST_INPUT} 96 | 97 | 98 | def test_encode_json_numpy(): 99 | encoded_data = decoder_encoder.encode_json({"data": np.array(ENCODE_TOLOIST_INPUT)}) 100 | assert json.loads(encoded_data) == {"data": ENCODE_TOLOIST_INPUT} 101 | 102 | 103 | def test_encode_csv(): 104 | decoded_data = decoder_encoder.encode_csv(ENCODE_CSV_INPUT) 105 | assert ( 106 | decoded_data 107 | == "answer,end,score,start\r\nNuremberg,42,0.9926825761795044,33\r\nBerlin is the capital of Germany,32,0.26097726821899414,0\r\n" 108 | ) 109 | 110 | 111 | def test_decode_content_type(): 112 | decoded_content_type = decoder_encoder._decoder_map[CONTENT_TYPE] 113 | assert decoded_content_type == decoder_encoder.decode_json 114 | 115 | 116 | def test_encode_content_type(): 117 | encoded_content_type = decoder_encoder._encoder_map[CONTENT_TYPE] 118 | assert encoded_content_type == decoder_encoder.encode_json 119 | 120 | 121 | def test_decode(): 122 | decode = decoder_encoder.decode(json.dumps(DECODE_JSON_INPUT), CONTENT_TYPE) 123 | assert decode == DECODE_JSON_INPUT 124 | 125 | 126 | def test_encode(): 127 | encode = decoder_encoder.encode(ENCODE_JSON_INPUT, CONTENT_TYPE) 128 | assert json.loads(encode) == ENCODE_JSON_INPUT 129 | -------------------------------------------------------------------------------- /tests/unit/test_diffusers_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import tempfile 15 | 16 | from transformers.testing_utils import require_torch, slow 17 | 18 | from PIL import Image 19 | from sagemaker_huggingface_inference_toolkit.diffusers_utils import SMDiffusionPipelineForText2Image 20 | from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline 21 | 22 | 23 | @require_torch 24 | def test_get_diffusers_pipeline(): 25 | with tempfile.TemporaryDirectory() as tmpdirname: 26 | storage_dir = _load_model_from_hub( 27 | "hf-internal-testing/tiny-stable-diffusion-torch", 28 | tmpdirname, 29 | ) 30 | pipe = get_pipeline("text-to-image", -1, storage_dir) 31 | assert isinstance(pipe, SMDiffusionPipelineForText2Image) 32 | 33 | 34 | @slow 35 | @require_torch 36 | def test_pipe_on_gpu(): 37 | with tempfile.TemporaryDirectory() as tmpdirname: 38 | storage_dir = _load_model_from_hub( 39 | "hf-internal-testing/tiny-stable-diffusion-torch", 40 | tmpdirname, 41 | ) 42 | pipe = get_pipeline("text-to-image", 0, storage_dir) 43 | assert pipe.device.type == "cuda" 44 | 45 | 46 | @require_torch 47 | def test_text_to_image_task(): 48 | with tempfile.TemporaryDirectory() as tmpdirname: 49 | storage_dir = _load_model_from_hub("hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname) 50 | pipe = get_pipeline("text-to-image", -1, storage_dir) 51 | res = pipe("Lets create an embedding") 52 | assert isinstance(res, Image.Image) 53 | -------------------------------------------------------------------------------- /tests/unit/test_handler_service_with_context.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import json 15 | import os 16 | import tempfile 17 | 18 | import pytest 19 | from sagemaker_inference import content_types 20 | from transformers.testing_utils import require_torch, slow 21 | 22 | from mms.context import Context, RequestProcessor 23 | from mms.metrics.metrics_store import MetricsStore 24 | from mock import Mock 25 | from sagemaker_huggingface_inference_toolkit import handler_service 26 | from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline 27 | 28 | 29 | TASK = "text-classification" 30 | MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" 31 | INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"} 32 | OUTPUT = [ 33 | {"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19}, 34 | {"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40}, 35 | ] 36 | 37 | 38 | @pytest.fixture() 39 | def inference_handler(): 40 | return handler_service.HuggingFaceHandlerService() 41 | 42 | 43 | def test_get_device_cpu(inference_handler): 44 | device = inference_handler.get_device() 45 | assert device == -1 46 | 47 | 48 | @slow 49 | def test_get_device_gpu(inference_handler): 50 | device = inference_handler.get_device() 51 | assert device > -1 52 | 53 | 54 | @require_torch 55 | def test_test_initialize(inference_handler): 56 | with tempfile.TemporaryDirectory() as tmpdirname: 57 | storage_folder = _load_model_from_hub( 58 | model_id=MODEL, 59 | model_dir=tmpdirname, 60 | ) 61 | CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4") 62 | 63 | inference_handler.initialize(CONTEXT) 64 | assert inference_handler.initialized is True 65 | 66 | 67 | @require_torch 68 | def test_handle(inference_handler): 69 | with tempfile.TemporaryDirectory() as tmpdirname: 70 | storage_folder = _load_model_from_hub( 71 | model_id=MODEL, 72 | model_dir=tmpdirname, 73 | ) 74 | CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4") 75 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] 76 | CONTEXT.metrics = MetricsStore(1, MODEL) 77 | 78 | inference_handler.initialize(CONTEXT) 79 | json_data = json.dumps(INPUT) 80 | prediction = inference_handler.handle([{"body": json_data.encode()}], CONTEXT) 81 | loaded_response = json.loads(prediction[0]) 82 | assert "entity" in loaded_response[0] 83 | assert "score" in loaded_response[0] 84 | 85 | 86 | @require_torch 87 | def test_load(inference_handler): 88 | context = Mock() 89 | with tempfile.TemporaryDirectory() as tmpdirname: 90 | storage_folder = _load_model_from_hub( 91 | model_id=MODEL, 92 | model_dir=tmpdirname, 93 | ) 94 | # test with automatic task inference 95 | hf_pipeline_without_task = inference_handler.load(storage_folder, context) 96 | assert hf_pipeline_without_task.task == "token-classification" 97 | 98 | # test with task explicitly set via the HF_TASK environment variable 99 | os.environ["HF_TASK"] = TASK 100 | hf_pipeline_with_task = inference_handler.load(storage_folder, context) 101 | assert hf_pipeline_with_task.task == TASK 102 | 103 | 104 | def test_preprocess(inference_handler): 105 | context = Mock() 106 | json_data = json.dumps(INPUT) 107 | decoded_input_data = inference_handler.preprocess(json_data, content_types.JSON, context) 108 | assert "inputs" in decoded_input_data 109 | 110 | 111 | def test_preprocess_bad_content_type(inference_handler): 112 | context = Mock() 113 | with pytest.raises(json.decoder.JSONDecodeError): 114 | inference_handler.preprocess(b"", content_types.JSON, context) 115 | 116 |
117 | @require_torch 118 | def test_predict(inference_handler): 119 | context = Mock() 120 | with tempfile.TemporaryDirectory() as tmpdirname: 121 | storage_folder = _load_model_from_hub( 122 | model_id=MODEL, 123 | model_dir=tmpdirname, 124 | ) 125 | inference_handler.model = get_pipeline(task=TASK, device=-1, model_dir=storage_folder) 126 | prediction = inference_handler.predict(INPUT, inference_handler.model, context) 127 | assert "label" in prediction[0] 128 | assert "score" in prediction[0] 129 | 130 | 131 | def test_postprocess(inference_handler): 132 | context = Mock() 133 | output = inference_handler.postprocess(OUTPUT, content_types.JSON, context) 134 | assert isinstance(output, str) 135 | 136 | 137 | def test_validate_and_initialize_user_module(inference_handler): 138 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_with_context") 139 | CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4") 140 | 141 | inference_handler.initialize(CONTEXT) 142 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] 143 | CONTEXT.metrics = MetricsStore(1, MODEL) 144 | 145 | prediction = inference_handler.handle([{"body": b""}], CONTEXT) 146 | assert "output" in prediction[0] 147 | 148 | assert inference_handler.load({}, CONTEXT) == "model" 149 | assert inference_handler.preprocess({}, "", CONTEXT) == "data" 150 | assert inference_handler.predict({}, "model", CONTEXT) == "output" 151 | assert inference_handler.postprocess("output", "", CONTEXT) == "output" 152 | 153 | 154 | def test_validate_and_initialize_user_module_transform_fn(): 155 | os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" 156 | inference_handler = handler_service.HuggingFaceHandlerService() 157 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context") 158 | CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4") 159 | 160 | inference_handler.initialize(CONTEXT) 161 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] 162 | CONTEXT.metrics = MetricsStore(1, MODEL) 163 | assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] 164 | assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py" 165 | assert ( 166 | inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT) 167 | == "output dummy" 168 | ) 169 | 170 | 171 | def test_validate_and_initialize_user_module_transform_fn_race_condition(): 172 | os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" 173 | inference_handler = handler_service.HuggingFaceHandlerService() 174 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context") 175 | CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4") 176 | 177 | # Simulate 2 threads bypassing the check in handle() - calling initialize twice 178 | inference_handler.initialize(CONTEXT) 179 | inference_handler.initialize(CONTEXT) 180 | 181 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] 182 | CONTEXT.metrics = MetricsStore(1, MODEL) 183 | assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] 184 | assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py" 185 | assert ( 186 | inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT) 187 | == "output dummy" 188 | ) 189 | --------------------------------------------------------------------------------
/tests/unit/test_handler_service_without_context.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import json 15 | import os 16 | import tempfile 17 | 18 | import pytest 19 | from sagemaker_inference import content_types 20 | from transformers.testing_utils import require_torch, slow 21 | 22 | from mms.context import Context, RequestProcessor 23 | from mms.metrics.metrics_store import MetricsStore 24 | from sagemaker_huggingface_inference_toolkit import handler_service 25 | from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline 26 | 27 | 28 | TASK = "text-classification" 29 | MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" 30 | INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"} 31 | OUTPUT = [ 32 | {"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19}, 33 | {"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40}, 34 | ] 35 | 36 | 37 | @pytest.fixture() 38 | def inference_handler(): 39 | return handler_service.HuggingFaceHandlerService() 40 | 41 | 42 | def test_get_device_cpu(inference_handler): 43 | device = inference_handler.get_device() 44 | assert device == -1 45 | 46 | 47 | @slow 48 | def test_get_device_gpu(inference_handler): 49 | device = inference_handler.get_device() 50 | assert device > -1 51 | 52 | 53 | @require_torch 54 | def test_test_initialize(inference_handler): 55 | with tempfile.TemporaryDirectory() as tmpdirname: 56 | storage_folder = _load_model_from_hub( 57 | model_id=MODEL, 58 | model_dir=tmpdirname, 59 | ) 60 | CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4") 61 | 62 | inference_handler.initialize(CONTEXT) 63 | assert inference_handler.initialized is True 64 | 65 | 66 | @require_torch 67 | def test_handle(inference_handler): 68 | with tempfile.TemporaryDirectory() as tmpdirname: 69 | storage_folder = _load_model_from_hub( 70 | model_id=MODEL, 71 | model_dir=tmpdirname, 72 | ) 73 | CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4") 74 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] 75 | CONTEXT.metrics = MetricsStore(1, MODEL) 76 | 77 | inference_handler.initialize(CONTEXT) 78 | json_data = json.dumps(INPUT) 79 | prediction = inference_handler.handle([{"body": json_data.encode()}], CONTEXT) 80 | assert "output" in prediction[0] 81 | 82 | 83 | @require_torch 84 | def test_load(inference_handler): 85 | with tempfile.TemporaryDirectory() as tmpdirname: 86 | storage_folder = _load_model_from_hub( 87 | model_id=MODEL, 88 | model_dir=tmpdirname, 89 | ) 90 | # test with automatic infer 91 | if "HF_TASK" in os.environ: 92 | del os.environ["HF_TASK"] 93 | hf_pipeline_without_task = inference_handler.load(storage_folder) 94 | assert 
hf_pipeline_without_task.task == "token-classification" 95 | 96 | # test with task explicitly set via the HF_TASK environment variable 97 | os.environ["HF_TASK"] = "text-classification" 98 | hf_pipeline_with_task = inference_handler.load(storage_folder) 99 | assert hf_pipeline_with_task.task == "text-classification" 100 | 101 | 102 | def test_preprocess(inference_handler): 103 | json_data = json.dumps(INPUT) 104 | decoded_input_data = inference_handler.preprocess(json_data, content_types.JSON) 105 | assert "inputs" in decoded_input_data 106 | 107 | 108 | def test_preprocess_bad_content_type(inference_handler): 109 | with pytest.raises(json.decoder.JSONDecodeError): 110 | inference_handler.preprocess(b"", content_types.JSON) 111 | 112 | 113 | @require_torch 114 | def test_predict(inference_handler): 115 | with tempfile.TemporaryDirectory() as tmpdirname: 116 | storage_folder = _load_model_from_hub( 117 | model_id=MODEL, 118 | model_dir=tmpdirname, 119 | ) 120 | inference_handler.model = get_pipeline(task=TASK, device=-1, model_dir=storage_folder) 121 | prediction = inference_handler.predict(INPUT, inference_handler.model) 122 | assert "label" in prediction[0] 123 | assert "score" in prediction[0] 124 | 125 | 126 | def test_postprocess(inference_handler): 127 | output = inference_handler.postprocess(OUTPUT, content_types.JSON) 128 | assert isinstance(output, str) 129 | 130 | 131 | def test_validate_and_initialize_user_module(inference_handler): 132 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_without_context") 133 | CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4") 134 | 135 | inference_handler.initialize(CONTEXT) 136 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] 137 | CONTEXT.metrics = MetricsStore(1, MODEL) 138 | 139 | prediction = inference_handler.handle([{"body": b""}], CONTEXT) 140 | assert "output" in prediction[0] 141 | 142 | assert inference_handler.load({}) == "Loading inference_tranform_fn.py" 143 | 144 | 145 | def test_validate_and_initialize_user_module_transform_fn(): 146 | os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" 147 | inference_handler = handler_service.HuggingFaceHandlerService() 148 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context") 149 | CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4") 150 | 151 | inference_handler.initialize(CONTEXT) 152 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] 153 | CONTEXT.metrics = MetricsStore(1, MODEL) 154 | assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] 155 | assert inference_handler.load({}) == "Loading inference_tranform_fn.py" 156 | assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy" 157 | 158 | 159 | def test_validate_and_initialize_user_module_transform_fn_race_condition(): 160 | os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" 161 | inference_handler = handler_service.HuggingFaceHandlerService() 162 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context") 163 | CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4") 164 | 165 | # Simulate 2 threads bypassing the check in handle() - calling initialize twice 166 | inference_handler.initialize(CONTEXT) 167 | inference_handler.initialize(CONTEXT) 168 | 169 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] 170 | CONTEXT.metrics = MetricsStore(1, MODEL) 171 |
assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] 172 | assert inference_handler.load({}) == "Loading inference_tranform_fn.py" 173 | assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy" 174 | -------------------------------------------------------------------------------- /tests/unit/test_mms_model_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.import os 14 | import os 15 | 16 | from sagemaker_inference.environment import model_dir 17 | 18 | from mock import patch 19 | from sagemaker_huggingface_inference_toolkit import mms_model_server, transformers_utils 20 | 21 | 22 | PYTHON_PATH = "python_path" 23 | DEFAULT_CONFIGURATION = "default_configuration" 24 | 25 | 26 | @patch("subprocess.call") 27 | @patch("subprocess.Popen") 28 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") 29 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") 30 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") 31 | @patch("os.path.exists", return_value=True) 32 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file") 33 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format") 34 | @patch("sagemaker_inference.environment.Environment") 35 | def test_start_mms_default_service_handler( 36 | env, 37 | adapt, 38 | create_config, 39 | exists, 40 | install_requirements, 41 | sigterm, 42 | retrieve, 43 | subprocess_popen, 44 | subprocess_call, 45 | ): 46 | env.return_value.startup_timeout = 10000 47 | mms_model_server.start_model_server() 48 | 49 | # In this case, we should not rearchive the model 50 | adapt.assert_not_called() 51 | 52 | create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE) 53 | exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH) 54 | install_requirements.assert_called_once_with() 55 | 56 | multi_model_server_cmd = [ 57 | "multi-model-server", 58 | "--start", 59 | "--model-store", 60 | mms_model_server.DEFAULT_MODEL_STORE, 61 | "--mms-config", 62 | mms_model_server.MMS_CONFIG_FILE, 63 | "--log-config", 64 | mms_model_server.DEFAULT_MMS_LOG_FILE, 65 | "--models", 66 | "{}={}".format(mms_model_server.DEFAULT_MMS_MODEL_NAME, model_dir), 67 | ] 68 | 69 | subprocess_popen.assert_called_once_with(multi_model_server_cmd) 70 | sigterm.assert_called_once_with(retrieve.return_value) 71 | 72 | 73 | @patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True) 74 | @patch("subprocess.call") 75 | @patch("subprocess.Popen") 76 | 
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") 77 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub") 78 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") 79 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") 80 | @patch("os.makedirs", return_value=True) 81 | @patch("os.remove", return_value=True) 82 | @patch("os.path.exists", return_value=True) 83 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file") 84 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format") 85 | @patch("sagemaker_inference.environment.Environment") 86 | def test_start_mms_neuron( 87 | env, 88 | adapt, 89 | create_config, 90 | exists, 91 | remove, 92 | dir, 93 | install_requirements, 94 | sigterm, 95 | load_model_from_hub, 96 | retrieve, 97 | subprocess_popen, 98 | subprocess_call, 99 | is_aws_neuron_available, 100 | ): 101 | env.return_value.startup_timeout = 10000 102 | mms_model_server.start_model_server() 103 | 104 | # In this case, we should not call model archiver 105 | adapt.assert_not_called() 106 | 107 | create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE) 108 | exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH) 109 | install_requirements.assert_called_once_with() 110 | 111 | multi_model_server_cmd = [ 112 | "multi-model-server", 113 | "--start", 114 | "--model-store", 115 | mms_model_server.DEFAULT_MODEL_STORE, 116 | "--mms-config", 117 | mms_model_server.MMS_CONFIG_FILE, 118 | "--log-config", 119 | mms_model_server.DEFAULT_MMS_LOG_FILE, 120 | "--models", 121 | "{}={}".format(mms_model_server.DEFAULT_MMS_MODEL_NAME, model_dir), 122 | ] 123 | 124 | subprocess_popen.assert_called_once_with(multi_model_server_cmd) 125 | sigterm.assert_called_once_with(retrieve.return_value) 126 | 127 | 128 | @patch("subprocess.call") 129 | @patch("subprocess.Popen") 130 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") 131 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub") 132 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") 133 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") 134 | @patch("os.makedirs", return_value=True) 135 | @patch("os.remove", return_value=True) 136 | @patch("os.path.exists", return_value=True) 137 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file") 138 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format") 139 | @patch("sagemaker_inference.environment.Environment") 140 | def test_start_mms_with_model_from_hub( 141 | env, 142 | adapt, 143 | create_config, 144 | exists, 145 | remove, 146 | dir, 147 | install_requirements, 148 | sigterm, 149 | load_model_from_hub, 150 | retrieve, 151 | subprocess_popen, 152 | subprocess_call, 153 | ): 154 | env.return_value.startup_timeout = 10000 155 | 156 | os.environ["HF_MODEL_ID"] = "lysandre/tiny-bert-random" 157 | 158 | mms_model_server.start_model_server() 159 | 160 | load_model_from_hub.assert_called_once_with( 161 | model_id=os.environ["HF_MODEL_ID"], 162 | model_dir=mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY, 163 | revision=transformers_utils.HF_MODEL_REVISION, 164 | 
use_auth_token=transformers_utils.HF_API_TOKEN, 165 | ) 166 | 167 | # When loading model from hub, we do call model archiver 168 | adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, load_model_from_hub()) 169 | 170 | create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE) 171 | exists.assert_called_with(mms_model_server.REQUIREMENTS_PATH) 172 | install_requirements.assert_called_once_with() 173 | 174 | multi_model_server_cmd = [ 175 | "multi-model-server", 176 | "--start", 177 | "--model-store", 178 | mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY, 179 | "--mms-config", 180 | mms_model_server.MMS_CONFIG_FILE, 181 | "--log-config", 182 | mms_model_server.DEFAULT_MMS_LOG_FILE, 183 | ] 184 | 185 | subprocess_popen.assert_called_once_with(multi_model_server_cmd) 186 | sigterm.assert_called_once_with(retrieve.return_value) 187 | os.remove(mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY) 188 | -------------------------------------------------------------------------------- /tests/unit/test_optimum_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | import tempfile 16 | 17 | import pytest 18 | from transformers.testing_utils import require_torch 19 | 20 | from sagemaker_huggingface_inference_toolkit.optimum_utils import ( 21 | get_input_shapes, 22 | get_optimum_neuron_pipeline, 23 | is_optimum_neuron_available, 24 | ) 25 | from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub 26 | 27 | 28 | require_inferentia = pytest.mark.skipif( 29 | not is_optimum_neuron_available(), 30 | reason="Skipping tests, since optimum neuron is not available or not running on inf2 instances.", 31 | ) 32 | 33 | 34 | REMOTE_NOT_CONVERTED_MODEL = "hf-internal-testing/tiny-random-BertModel" 35 | REMOTE_CONVERTED_MODEL = "optimum/tiny_random_bert_neuron" 36 | TASK = "text-classification" 37 | 38 | 39 | @require_torch 40 | @require_inferentia 41 | def test_not_supported_task(): 42 | os.environ["HF_TASK"] = "not-supported-task" 43 | with pytest.raises(Exception): 44 | get_optimum_neuron_pipeline(task=TASK, model_dir=os.getcwd()) 45 | 46 | 47 | @require_torch 48 | @require_inferentia 49 | def test_get_input_shapes_from_file(): 50 | with tempfile.TemporaryDirectory() as tmpdirname: 51 | storage_folder = _load_model_from_hub( 52 | model_id=REMOTE_CONVERTED_MODEL, 53 | model_dir=tmpdirname, 54 | ) 55 | input_shapes = get_input_shapes(model_dir=storage_folder) 56 | assert input_shapes["batch_size"] == 1 57 | assert input_shapes["sequence_length"] == 32 58 | 59 | 60 | @require_torch 61 | @require_inferentia 62 | def test_get_input_shapes_from_env(): 63 | os.environ["HF_OPTIMUM_BATCH_SIZE"] = "4" 64 | os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32" 65 | with tempfile.TemporaryDirectory() as tmpdirname: 66 | storage_folder = _load_model_from_hub( 67 | model_id=REMOTE_NOT_CONVERTED_MODEL, 68 | model_dir=tmpdirname, 69 | ) 70 | input_shapes = get_input_shapes(model_dir=storage_folder) 71 | assert input_shapes["batch_size"] == 4 72 | assert input_shapes["sequence_length"] == 32 73 | 74 | 75 | @require_torch 76 | @require_inferentia 77 | def test_get_optimum_neuron_pipeline_from_converted_model(): 78 | with tempfile.TemporaryDirectory() as tmpdirname: 79 | os.system( 80 | f"optimum-cli export neuron --model philschmid/tiny-distilbert-classification --sequence_length 32 --batch_size 1 {tmpdirname}" 81 | ) 82 | pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=tmpdirname) 83 | r = pipe("This is a test") 84 | 85 | assert r[0]["score"] > 0.0 86 | assert isinstance(r[0]["label"], str) 87 | 88 | 89 | @require_torch 90 | @require_inferentia 91 | def test_get_optimum_neuron_pipeline_from_non_converted_model(): 92 | os.environ["OPTIMUM_NEURON_SEQUENCE_LENGTH"] = "32" 93 | with tempfile.TemporaryDirectory() as tmpdirname: 94 | storage_folder = _load_model_from_hub( 95 | model_id=REMOTE_NOT_CONVERTED_MODEL, 96 | model_dir=tmpdirname, 97 | ) 98 | pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=storage_folder) 99 | r = pipe("This is a test") 100 | 101 | assert r[0]["score"] > 0.0 102 | assert isinstance(r[0]["label"], str) 103 | -------------------------------------------------------------------------------- /tests/unit/test_serving.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from mock import patch 16 | 17 | 18 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server.start_model_server") 19 | def test_hosting_start(start_model_server): 20 | from sagemaker_huggingface_inference_toolkit import serving 21 | 22 | serving.main() 23 | start_model_server.assert_called_with(handler_service="sagemaker_huggingface_inference_toolkit.handler_service") 24 | 25 | 26 | def test_retry_if_error(): 27 | from sagemaker_huggingface_inference_toolkit import serving 28 | 29 | serving._retry_if_error(Exception) 30 | -------------------------------------------------------------------------------- /tests/unit/test_transformers_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | import tempfile 16 | 17 | from transformers.file_utils import is_torch_available 18 | from transformers.testing_utils import require_tf, require_torch, slow 19 | 20 | from sagemaker_huggingface_inference_toolkit.transformers_utils import ( 21 | _build_storage_path, 22 | _get_framework, 23 | _is_gpu_available, 24 | _load_model_from_hub, 25 | get_pipeline, 26 | infer_task_from_hub, 27 | infer_task_from_model_architecture, 28 | ) 29 | 30 | 31 | MODEL = "lysandre/tiny-bert-random" 32 | TASK = "text-classification" 33 | TASK_MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" 34 | 35 | MAIN_REVISION = "main" 36 | REVISION = "eb4c77816edd604d0318f8e748a1c606a2888493" 37 | 38 | 39 | @require_torch 40 | def test_loading_model_from_hub(): 41 | with tempfile.TemporaryDirectory() as tmpdirname: 42 | storage_folder = _load_model_from_hub( 43 | model_id=MODEL, 44 | model_dir=tmpdirname, 45 | ) 46 | 47 | # folder contains all config files and pytorch_model.bin 48 | folder_contents = os.listdir(storage_folder) 49 | assert "config.json" in folder_contents 50 | 51 | 52 | @require_torch 53 | def test_loading_model_from_hub_with_revision(): 54 | with tempfile.TemporaryDirectory() as tmpdirname: 55 | storage_folder = _load_model_from_hub(model_id=MODEL, model_dir=tmpdirname, revision=REVISION) 56 | 57 | # folder contains all config files and pytorch_model.bin 58 | assert REVISION in storage_folder 59 | folder_contents = os.listdir(storage_folder) 60 | assert "config.json" in folder_contents 61 | assert "tokenizer_config.json" not in folder_contents 62 | 63 | 64 | @require_torch 65 | def test_loading_model_safetensor_from_hub_with_revision(): 66 | with tempfile.TemporaryDirectory() as tmpdirname: 67 | storage_folder = _load_model_from_hub( 68 | model_id="hf-internal-testing/tiny-random-bert-safetensors", model_dir=tmpdirname 69 | ) 70 | 71 | folder_contents = os.listdir(storage_folder) 72 | assert "model.safetensors" in folder_contents 73 | 74 | 75 | def test_gpu_is_not_available(): 76 | device = _is_gpu_available() 77 | assert device is False 78 | 79 | 80 | def test_build_storage_path(): 81 | storage_path = _build_storage_path(model_id=MODEL, model_dir="x") 82 | assert "__" in storage_path 83 | 84 | storage_path = _build_storage_path(model_id="bert-base-uncased", model_dir="x") 85 | assert "__" not in storage_path 86 | 87 | storage_path = _build_storage_path(model_id=MODEL, model_dir="x", revision=REVISION) 88 | assert "__" in storage_path and "." 
in storage_path 89 | 90 | 91 | @slow 92 | def test_gpu_available(): 93 | device = _is_gpu_available() 94 | assert device is True 95 | 96 | 97 | @require_torch 98 | def test_get_framework_pytorch(): 99 | framework = _get_framework() 100 | assert framework == "pytorch" 101 | 102 | 103 | @require_tf 104 | def test_get_framework_tensorflow(): 105 | framework = _get_framework() 106 | if is_torch_available(): 107 | assert framework == "pytorch" 108 | else: 109 | assert framework == "tensorflow" 110 | 111 | 112 | def test_get_pipeline(): 113 | with tempfile.TemporaryDirectory() as tmpdirname: 114 | storage_dir = _load_model_from_hub(MODEL, tmpdirname) 115 | pipe = get_pipeline(TASK, -1, storage_dir) 116 | res = pipe("Life is good, Life is bad") 117 | assert "score" in res[0] 118 | 119 | 120 | def test_infer_task_from_hub(): 121 | task = infer_task_from_hub(TASK_MODEL) 122 | assert task == "token-classification" 123 | 124 | 125 | def test_infer_task_from_model_architecture(): 126 | with tempfile.TemporaryDirectory() as tmpdirname: 127 | storage_dir = _load_model_from_hub(TASK_MODEL, tmpdirname) 128 | task = infer_task_from_model_architecture(f"{storage_dir}/config.json") 129 | assert task == "token-classification" 130 | --------------------------------------------------------------------------------