├── .github
│   └── workflows
│       ├── integ-test.yml
│       ├── quality.yml
│       └── unit-test.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── NOTICE
├── README.md
├── makefile
├── pyproject.toml
├── setup.cfg
├── setup.py
├── src
│   └── sagemaker_huggingface_inference_toolkit
│       ├── __init__.py
│       ├── content_types.py
│       ├── decoder_encoder.py
│       ├── diffusers_utils.py
│       ├── handler_service.py
│       ├── mms_model_server.py
│       ├── optimum_utils.py
│       ├── serving.py
│       └── transformers_utils.py
└── tests
    ├── integ
    │   ├── __init__.py
    │   ├── config.py
    │   ├── test_diffusers.py
    │   ├── test_models_from_hub.py
    │   └── utils.py
    ├── resources
    │   ├── audio
    │   │   ├── sample1.flac
    │   │   ├── sample1.mp3
    │   │   ├── sample1.ogg
    │   │   └── sample1.wav
    │   ├── image
    │   │   ├── tiger.bmp
    │   │   ├── tiger.gif
    │   │   ├── tiger.jpeg
    │   │   ├── tiger.png
    │   │   ├── tiger.tiff
    │   │   └── tiger.webp
    │   ├── model_input_predict_output_fn_with_context
    │   │   └── code
    │   │       └── inference.py
    │   ├── model_input_predict_output_fn_without_context
    │   │   └── code
    │   │       └── inference.py
    │   ├── model_transform_fn_with_context
    │   │   └── code
    │   │       └── inference_tranform_fn.py
    │   └── model_transform_fn_without_context
    │       └── code
    │           └── inference_tranform_fn.py
    └── unit
        ├── __init__.py
        ├── test_decoder_encoder.py
        ├── test_diffusers_utils.py
        ├── test_handler_service_with_context.py
        ├── test_handler_service_without_context.py
        ├── test_mms_model_server.py
        ├── test_optimum_utils.py
        ├── test_serving.py
        └── test_transformers_utils.py
/.github/workflows/integ-test.yml:
--------------------------------------------------------------------------------
1 | name: Run Tests
2 |
3 | on:
4 | # pull_request:
5 | workflow_dispatch:
6 |
7 |
8 | jobs:
9 | test:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v2
13 | - name: Set up Python 3.8
14 | uses: actions/setup-python@v2
15 | with:
16 | python-version: 3.8
17 | - name: Install Python dependencies
18 | run: pip install -e .[test,dev]
19 | - name: Run Integration Tests
20 | env:
21 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
22 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
23 | AWS_DEFAULT_REGION: us-east-1
24 | run: make integ-test
--------------------------------------------------------------------------------
/.github/workflows/quality.yml:
--------------------------------------------------------------------------------
1 | name: Quality Check
2 |
3 | on: [pull_request]
4 |
5 | jobs:
6 | quality:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v2
10 | - name: Set up Python 3.8
11 | uses: actions/setup-python@v2
12 | with:
13 | python-version: 3.8
14 | - name: Install Python dependencies
15 | run: pip install -e .[quality]
16 | - name: Run Quality check
17 | run: make quality
--------------------------------------------------------------------------------
/.github/workflows/unit-test.yml:
--------------------------------------------------------------------------------
1 | name: Run Unit-Tests
2 |
3 | on: [pull_request]
4 |
5 | jobs:
6 | test:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v2
10 | - name: Set up Python 3.8
11 | uses: actions/setup-python@v2
12 | with:
13 | python-version: 3.8
14 | - name: Install Python dependencies
15 | run: pip install -e .[test,dev]
16 | - name: Run Unit Tests
17 | run: make unit-test
18 | # - name: Run Integration Tests
19 | # run: make integ-test
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Docker project generated files to ignore
2 | # if you want to ignore files created by your editor/tools,
3 | # please consider a global .gitignore https://help.github.com/articles/ignoring-files
4 | .vagrant*
5 | bin
6 | docker/docker
7 | .*.swp
8 | a.out
9 | *.orig
10 | build_src
11 | .flymake*
12 | .idea
13 | .DS_Store
14 | docs/_build
15 | docs/_static
16 | docs/_templates
17 | .gopath/
18 | .dotcloud
19 | *.test
20 | bundles/
21 | .hg/
22 | .git/
23 | vendor/pkg/
24 | pyenv
25 | Vagrantfile
26 | # Byte-compiled / optimized / DLL files
27 | __pycache__/
28 | *.py[cod]
29 | *$py.class
30 |
31 | # C extensions
32 | *.so
33 |
34 | # Distribution / packaging
35 | .Python
36 | build/
37 | develop-eggs/
38 | dist/
39 | downloads/
40 | eggs/
41 | .eggs/
42 | lib/
43 | lib64/
44 | parts/
45 | sdist/
46 | var/
47 | wheels/
48 | share/python-wheels/
49 | *.egg-info/
50 | .installed.cfg
51 | *.egg
52 | MANIFEST
53 |
54 | # PyInstaller
55 | # Usually these files are written by a python script from a template
56 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
57 | *.manifest
58 | *.spec
59 |
60 | # Installer logs
61 | pip-log.txt
62 | pip-delete-this-directory.txt
63 |
64 | # Unit test / coverage reports
65 | htmlcov/
66 | .tox/
67 | .nox/
68 | .coverage
69 | .coverage.*
70 | .cache
71 | nosetests.xml
72 | coverage.xml
73 | *.cover
74 | *.py,cover
75 | .hypothesis/
76 | .pytest_cache/
77 | cover/
78 |
79 | # Translations
80 | *.mo
81 | *.pot
82 |
83 | # Django stuff:
84 | *.log
85 | local_settings.py
86 | db.sqlite3
87 | db.sqlite3-journal
88 |
89 | # Flask stuff:
90 | instance/
91 | .webassets-cache
92 |
93 | # Scrapy stuff:
94 | .scrapy
95 |
96 | # Sphinx documentation
97 | docs/_build/
98 |
99 | # PyBuilder
100 | .pybuilder/
101 | target/
102 |
103 | # Jupyter Notebook
104 | .ipynb_checkpoints
105 |
106 | # IPython
107 | profile_default/
108 | ipython_config.py
109 |
110 | # pyenv
111 | # For a library or package, you might want to ignore these files since the code is
112 | # intended to run in multiple environments; otherwise, check them in:
113 | # .python-version
114 |
115 | # pipenv
116 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
117 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
118 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
119 | # install all needed dependencies.
120 | #Pipfile.lock
121 |
122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
123 | __pypackages__/
124 |
125 | # Celery stuff
126 | celerybeat-schedule
127 | celerybeat.pid
128 |
129 | # SageMath parsed files
130 | *.sage.py
131 |
132 | # Environments
133 | .env
134 | .venv
135 | env/
136 | venv/
137 | ENV/
138 | env.bak/
139 | venv.bak/
140 |
141 | # Spyder project settings
142 | .spyderproject
143 | .spyproject
144 |
145 | # Rope project settings
146 | .ropeproject
147 |
148 | # mkdocs documentation
149 | /site
150 |
151 | # mypy
152 | .mypy_cache/
153 | .dmypy.json
154 | dmypy.json
155 |
156 | # Pyre type checker
157 | .pyre/
158 |
159 | # pytype static type analyzer
160 | .pytype/
161 |
162 | # Cython debug symbols
163 | cython_debug/
164 |
165 | .vscode/settings.json
166 | .sagemaker
167 | model
168 | tests/tmp
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute to the Hugging Face Inference Toolkit?
2 |
3 | Everyone is welcome to contribute, and we value everybody's contribution. Code
4 | is not the only way to help the community. Answering questions, helping
5 | others, reaching out, and improving the documentation are immensely valuable to
6 | the community.
7 |
8 | It also helps us if you spread the word: reference the library from blog posts
9 | on the awesome projects it made possible, shout out on Twitter every time it has
10 | helped you, or simply star the repo to say "thank you".
11 |
12 | Whichever way you choose to contribute, please be mindful to respect our
13 | [code of conduct](https://github.com/huggingface/transformers/blob/master/CODE_OF_CONDUCT.md).
14 |
15 |
16 | ## Reporting Bugs/Feature Requests
17 |
18 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
19 |
20 | When filing an issue, please check [existing open](https://github.com/aws/sagemaker-huggingface-inference-toolkit/issues), or [recently closed](https://github.com/aws/sagemaker-huggingface-inference-toolkit/issues?utf8=%E2%9C%93&q=is%3Aissue%20is%3Aclosed%20), issues to make sure somebody else hasn't already
21 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
22 |
23 | * A reproducible test case or series of steps
24 | * The version of our code being used
25 | * Any modifications you've made relevant to the bug
26 | * Anything unusual about your environment or deployment
27 |
28 |
29 | ## Contributing via Pull Requests
30 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
31 |
32 | 1. You are working against the latest source on the *master* branch.
33 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
34 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
35 |
36 | To send us a pull request, please:
37 |
38 | 1. Fork the repository.
39 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
40 | 3. Ensure local tests pass.
41 | 4. Commit to your fork using clear commit messages.
42 | 5. Send us a pull request, answering any default questions in the pull request interface.
43 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
44 |
45 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
46 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
47 |
48 |
49 | ## Finding contributions to work on
50 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws/sagemaker-huggingface-inference-toolkit/labels/help%20wanted) issues is a great place to start.
51 |
52 |
53 | ## Code of Conduct
54 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
55 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
56 | opensource-codeofconduct@amazon.com with any additional questions or comments.
57 |
58 |
59 | ## Security issue notifications
60 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
61 |
62 |
63 | ## Licensing
64 |
65 | See the [LICENSE](https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/main/LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
66 |
67 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 |
2 | include LICENSE
3 | include README.md
4 |
5 | recursive-exclude * __pycache__
6 | recursive-exclude * *.py[co]
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | # SageMaker Hugging Face Inference Toolkit
10 |
11 | [Latest version on PyPI](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [Supported Python versions](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [Code style: black](https://github.com/python/black)
12 |
13 |
14 | SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers and Diffusers models on Amazon SageMaker. This library provides default pre-processing, prediction, and post-processing for certain 🤗 Transformers and Diffusers models and tasks. It utilizes the [SageMaker Inference Toolkit](https://github.com/aws/sagemaker-inference-toolkit) for starting up the model server, which is responsible for handling inference requests.
15 |
16 | For Training, see [Run training on Amazon SageMaker](https://huggingface.co/docs/sagemaker/train).
17 |
18 | For the Dockerfiles used for building SageMaker Hugging Face Containers, see [AWS Deep Learning Containers](https://github.com/aws/deep-learning-containers/tree/master/huggingface).
19 |
20 | For information on running Hugging Face jobs on Amazon SageMaker, please refer to the [🤗 Transformers documentation](https://huggingface.co/docs/sagemaker).
21 |
22 | For notebook examples: [SageMaker Notebook Examples](https://github.com/huggingface/notebooks/tree/master/sagemaker).
23 |
24 | ---
25 | ## 💻 Getting Started with 🤗 Inference Toolkit
26 |
27 | _Note: the following snippets are close to pseudo code and may need to be adjusted for your setup._
28 |
29 | **Install Amazon SageMaker Python SDK**
30 |
31 | ```bash
32 | pip install sagemaker --upgrade
33 | ```
34 |
35 | **Create an Amazon SageMaker endpoint with a trained model.**
36 |
37 | ```python
38 | from sagemaker.huggingface import HuggingFaceModel
39 |
40 | # create Hugging Face Model Class
41 | huggingface_model = HuggingFaceModel(
42 | transformers_version='4.6',
43 | pytorch_version='1.7',
44 | py_version='py36',
45 | model_data='s3://my-trained-model/artifacts/model.tar.gz',
46 | role=role,
47 | )
48 | # deploy model to SageMaker Inference
49 | huggingface_model.deploy(initial_instance_count=1,instance_type="ml.m5.xlarge")
50 | ```
51 |
52 |
53 | **Create an Amazon SageMaker endpoint with a model from the [🤗 Hub](https://huggingface.co/models).**
54 | _Note: This is an experimental feature, where the model is loaded after the endpoint is created. Not all SageMaker features are supported, e.g. multi-model endpoints (MME)._
55 | ```python
56 | from sagemaker.huggingface import HuggingFaceModel
57 | # Hub Model configuration. https://huggingface.co/models
58 | hub = {
59 | 'HF_MODEL_ID':'distilbert-base-uncased-distilled-squad',
60 | 'HF_TASK':'question-answering'
61 | }
62 | # create Hugging Face Model Class
63 | huggingface_model = HuggingFaceModel(
64 | transformers_version='4.6',
65 | pytorch_version='1.7',
66 | py_version='py36',
67 | env=hub,
68 | role=role,
69 | )
70 | # deploy model to SageMaker Inference
71 | huggingface_model.deploy(initial_instance_count=1,instance_type="ml.m5.xlarge")
72 | ```
73 |
74 | ---
75 |
76 | ## 🛠️ Environment variables
77 |
78 | The SageMaker Hugging Face Inference Toolkit implements various additional environment variables to simplify your deployment experience. A full list of environment variables is given below.
79 |
80 | #### `HF_TASK`
81 |
82 | The `HF_TASK` environment variable defines the task for the used 🤗 Transformers pipeline. A full list of tasks can be found [here](https://huggingface.co/transformers/main_classes/pipelines.html).
83 |
84 | ```bash
85 | HF_TASK="question-answering"
86 | ```
87 |
88 | #### `HF_MODEL_ID`
89 |
90 | The `HF_MODEL_ID` environment variable defines the model ID, which will be automatically loaded from [huggingface.co/models](https://huggingface.co/models) when creating your SageMaker Endpoint. The 🤗 Hub provides more than 10,000 models, all available through this environment variable.
91 |
92 | ```bash
93 | HF_MODEL_ID="distilbert-base-uncased-finetuned-sst-2-english"
94 | ```
95 |
96 | #### `HF_MODEL_REVISION`
97 |
98 | The `HF_MODEL_REVISION` is an extension to `HF_MODEL_ID` and allows you to define/pin a revision of the model to make sure you always load the same model on your SageMaker Endpoint.
99 |
100 | ```bash
101 | HF_MODEL_REVISION="03b4d196c19d0a73c7e0322684e97db1ec397613"
102 | ```
103 |
104 | #### `HF_API_TOKEN`
105 |
106 | The `HF_API_TOKEN` environment variable defines your Hugging Face authorization token. The `HF_API_TOKEN` is used as HTTP bearer authorization for remote files, like private models. You can find your token at your [settings page](https://huggingface.co/settings/token).
107 |
108 | ```bash
109 | HF_API_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
110 | ```
111 |
112 | #### `HF_TRUST_REMOTE_CODE`
113 |
114 | The `HF_TRUST_REMOTE_CODE` environment variable defines whether or not to allow custom models defined on the Hub in their own modeling files. Allowed values are `"True"` and `"False"`.
115 |
116 | ```bash
117 | HF_TRUST_REMOTE_CODE="True"
118 | ```
119 |
120 | #### `HF_OPTIMUM_BATCH_SIZE`
121 |
122 | The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size used when compiling the model for Neuron. The default value is `1`. Not required when the model is already converted.
123 |
124 | ```bash
125 | HF_OPTIMUM_BATCH_SIZE="1"
126 | ```
127 |
128 | #### `HF_OPTIMUM_SEQUENCE_LENGTH`
129 |
130 | The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length used when compiling the model for Neuron. There is no default value. Not required when the model is already converted.
131 |
132 | ```bash
133 | HF_OPTIMUM_SEQUENCE_LENGTH="128"
134 | ```
135 |
136 | ---
137 |
138 | ## 🧑🏻‍💻 User defined code/modules
139 |
140 | The Hugging Face Inference Toolkit allows users to override the default methods of the `HuggingFaceHandlerService`. To do so, create a folder named `code/` with an `inference.py` file in it. You can find an example in [sagemaker/17_custom_inference_script](https://github.com/huggingface/notebooks/blob/master/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb).
141 | For example:
142 | ```bash
143 | model.tar.gz/
144 | |- pytorch_model.bin
145 | |- ....
146 | |- code/
147 | |- inference.py
148 | |- requirements.txt
149 | ```
150 | In this example, `pytorch_model.bin` is the model file saved from training, `inference.py` is the custom inference module, and `requirements.txt` is a requirements file to add additional dependencies.
151 | The custom module can override the following methods (a minimal sketch follows the list):
152 |
153 | * `model_fn(model_dir, context=None)`: overrides the default method for loading the model. The return value `model` will be used in `predict_fn()` for predictions. It receives the argument `model_dir`, the path to your unzipped `model.tar.gz`.
154 | * `transform_fn(model, data, content_type, accept_type)`: overrides the default transform function with a custom implementation. Customers using this would have to implement the `preprocess`, `predict` and `postprocess` steps in the `transform_fn`. **NOTE: This method can't be combined with `input_fn`, `predict_fn` or `output_fn` mentioned below.**
155 | * `input_fn(input_data, content_type)`: overrides the default method for preprocessing. The return value `data` will be used in `predict_fn()` for predictions. The inputs are `input_data`, the raw body of your request, and `content_type`, the content type from the request header.
156 | * `predict_fn(processed_data, model)`: overrides the default method for predictions. The return value `predictions` will be used in `output_fn()`. The input is `processed_data`, the result of `input_fn()`.
157 | * `output_fn(prediction, accept)`: overrides the default method for postprocessing. The return value `result` will be the response to your request (e.g. `JSON`). The inputs are `predictions`, the result of `predict_fn()`, and `accept`, the accept type from the HTTP request, e.g. `application/json`.
158 |
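As a rough illustration, a minimal custom module could look like the sketch below. It is not the toolkit's reference implementation; the `text-classification` task and the `{"inputs": ...}` JSON payload shape are assumptions made for the example.

```python
# code/inference.py -- minimal sketch of a custom inference module (illustrative only)
import json

from transformers import pipeline


def model_fn(model_dir, context=None):
    # Load the model from the unzipped model.tar.gz; the task is an assumption here.
    return pipeline("text-classification", model=model_dir)


def input_fn(input_data, content_type):
    # Deserialize the request body; this sketch only handles JSON.
    if content_type == "application/json":
        return json.loads(input_data)
    raise ValueError(f"Unsupported content type: {content_type}")


def predict_fn(processed_data, model):
    # Run the loaded pipeline on the deserialized payload.
    return model(processed_data["inputs"])


def output_fn(prediction, accept):
    # Serialize the prediction for the response body.
    return json.dumps(prediction)
```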
159 |
160 | ## 🏎️ Deploy Models on AWS Inferentia2
161 |
162 | The SageMaker Hugging Face Inference Toolkit provides support for deploying Hugging Face models on AWS Inferentia2. To deploy a model on Inferentia2 you have three options:
163 | * Provide `HF_MODEL_ID`, the model repo ID on huggingface.co, which contains the compiled model in the `.neuron` format, e.g. `optimum/bge-base-en-v1.5-neuronx`
164 | * Provide the `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` environment variables to compile the model on the fly, e.g. `HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128`
165 | * Include a `neuron` dictionary in the [config.json](https://huggingface.co/optimum/tiny_random_bert_neuron/blob/main/config.json) file of the model archive, e.g. `neuron: {"static_batch_size": 1, "static_sequence_length": 128}`
166 |
167 | The currently supported tasks can be found [here](https://huggingface.co/docs/optimum-neuron/en/package_reference/supported_models). If you plan to deploy an LLM, we recommend taking a look at [Neuronx TGI](https://huggingface.co/blog/text-generation-inference-on-inferentia2), which is purpose-built for LLMs.
168 |
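As a non-official sketch of the second option (on-the-fly compilation), a deployment with the SageMaker Python SDK could look roughly like the following. The model ID, the `ml.inf2.xlarge` instance type, and the placeholder image URI are assumptions; look up a released Hugging Face Neuron DLC for your region before using it.

```python
from sagemaker.huggingface import HuggingFaceModel

# Sketch: deploy a model to Inferentia2 and let the toolkit compile it on the fly.
# The env values, instance type, and image URI are illustrative assumptions.
huggingface_model = HuggingFaceModel(
    image_uri="<huggingface-pytorch-inference-neuronx-dlc-uri>",  # placeholder, not a real URI
    env={
        "HF_MODEL_ID": "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
        "HF_TASK": "text-classification",
        "HF_OPTIMUM_BATCH_SIZE": "1",
        "HF_OPTIMUM_SEQUENCE_LENGTH": "128",
    },
    role=role,  # your SageMaker execution role
)
huggingface_model.deploy(initial_instance_count=1, instance_type="ml.inf2.xlarge")
```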
169 | ---
170 | ## 🤝 Contributing
171 |
172 | Please read [CONTRIBUTING.md](https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/main/CONTRIBUTING.md)
173 | for details on our code of conduct, and the process for submitting pull
174 | requests to us.
175 |
176 | ---
177 | ## 📜 License
178 |
179 | SageMaker Hugging Face Inference Toolkit is licensed under the Apache 2.0 License.
180 |
181 | ---
182 |
183 | ## 🧑🏻‍💻 Development Environment
184 |
185 | Install all test and development packages with
186 |
187 | ```bash
188 | pip3 install -e ".[test,dev]"
189 | ```
190 | ## Run Model Locally
191 |
192 | 1. Manually change `MMS_CONFIG_FILE` to point to a local copy of the SageMaker MMS config:
193 | ```
194 | wget -O sagemaker-mms.properties https://raw.githubusercontent.com/aws/deep-learning-containers/master/huggingface/build_artifacts/inference/config.properties
195 | ```
196 |
197 | 2. Run Container, e.g. `text-to-image`
198 | ```
199 | HF_MODEL_ID="stabilityai/stable-diffusion-xl-base-1.0" HF_TASK="text-to-image" python src/sagemaker_huggingface_inference_toolkit/serving.py
200 | ```
201 | 3. Adjust `handler_service.py` and comment out `if content_type in content_types.UTF8_TYPES:`; this branch is needed on SageMaker but cannot be used locally
202 |
203 | 4. Send request
204 | ```
205 | curl --request POST \
206 | --url http://localhost:8080/invocations \
207 | --header 'Accept: image/png' \
208 | --header 'Content-Type: application/json' \
209 | --data '{"inputs": "Camera"}' \
210 | --output image.png
211 | ```
212 |
213 |
214 | ## Run Inferentia2 Model Locally
215 |
216 | _Note: You need to run this on an Inferentia2 instance._
217 |
218 | 1. Manually change `MMS_CONFIG_FILE` to point to a local copy of the SageMaker MMS config:
219 | ```
220 | wget -O sagemaker-mms.properties https://raw.githubusercontent.com/aws/deep-learning-containers/master/huggingface/build_artifacts/inference/config.properties
221 | ```
222 |
223 | 2. Adjust `handler_service.py` and comment out `if content_type in content_types.UTF8_TYPES:`; this branch is needed on SageMaker but cannot be used locally
224 |
225 | 3. Run the container:
226 |
227 | - transformers `text-classification` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
228 | ```
229 | HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" HF_TASK="text-classification" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 python src/sagemaker_huggingface_inference_toolkit/serving.py
230 | ```
231 | - sentence transformers `feature-extraction` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
232 | ```
233 | HF_MODEL_ID="sentence-transformers/all-MiniLM-L6-v2" HF_TASK="feature-extraction" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 python src/sagemaker_huggingface_inference_toolkit/serving.py
234 | ```
235 |
236 | 4. Send request
237 | ```
238 | curl --request POST \
239 | --url http://localhost:8080/invocations \
240 | --header 'Content-Type: application/json' \
241 | --data "{\"inputs\": \"I like you.\"}"
242 | ```
243 |
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
1 | .PHONY: quality style unit-test integ-test
2 |
3 | check_dirs := src tests
4 |
5 | # run tests
6 |
7 | unit-test:
8 | python -m pytest -v -s ./tests/unit/
9 |
10 | integ-test:
11 | python -m pytest -n 2 -s -v ./tests/integ/
12 | # python -m pytest -n auto -s -v ./tests/integ/
13 |
14 |
15 | # Check that source code meets quality standards
16 |
17 | quality:
18 | black --check --line-length 119 --target-version py36 $(check_dirs)
19 | isort --check-only $(check_dirs)
20 | flake8 $(check_dirs)
21 |
22 | # Format source code automatically
23 |
24 | style:
25 | # black --line-length 119 --target-version py36 tests src benchmarks datasets metrics
26 | black --line-length 119 --target-version py36 $(check_dirs)
27 | isort $(check_dirs)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 119
3 | target-version = ['py36']
4 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = sagemaker_huggingface_inference_toolkit
7 | known_third_party =
8 | transformers
9 | sagemaker_inference
10 | huggingface_hub
11 | datasets
12 | pytest
13 | sklearn
14 | tensorflow
15 | torch
16 | retrying
17 | numpy
18 |
19 |
20 | line_length = 119
21 | lines_after_imports = 2
22 | multi_line_output = 3
23 | use_parentheses = True
24 |
25 | [flake8]
26 | ignore = E203, E501, E741, W503, W605
27 | max-line-length = 119
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Build release wheels as follows
16 | # $ SM_HF_TOOLKIT_RELEASE=1 python setup.py bdist_wheel build
17 | # $ twine upload --repository-url https://test.pypi.org/legacy/ dist/* # upload to test.pypi
18 | # Test the wheel by downloading from test.pypi
19 | # $ pip install -i https://test.pypi.org/simple/ sagemaker-huggingface-inference-toolkit==
20 | # Once test is complete
21 | # Upload the wheel to pypi
22 | # $ twine upload dist/*
23 |
24 |
25 | from __future__ import absolute_import
26 | import os
27 | from datetime import date
28 | from setuptools import find_packages, setup
29 |
30 | # We don't declare our dependency on transformers here because we build with
31 | # different packages for different variants
32 |
33 | VERSION = "2.6.0"
34 |
35 |
36 | # Ubuntu packages
37 | # libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev
38 | # ffmpeg: ffmpeg is required for audio processing. On Ubuntu it can be installed as follows: apt install ffmpeg
39 | # libavcodec-extra : libavcodec-extra includes additional codecs for ffmpeg
40 |
41 | install_requires = [
42 | "sagemaker-inference>=1.8.0",
43 | "huggingface_hub>=0.0.8",
44 | "retrying",
45 | "numpy",
46 | # vision
47 | "Pillow",
48 | # speech + torchaudio
49 | "librosa",
50 | "pyctcdecode>=0.3.0",
51 | "phonemizer",
52 | ]
53 |
54 | extras = {}
55 |
56 | # Hugging Face specific dependencies
57 | extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.17.0"]
58 | extras["diffusers"] = ["diffusers>=0.23.0"]
59 |
60 | # framework specific dependencies
61 | extras["torch"] = ["torch>=1.8.0", "torchaudio"]
62 |
63 | # TODO: Remove upper bound of TF 2.11 once transformers release contains this fix: https://github.com/huggingface/evaluate/pull/372
64 | extras["tensorflow"] = ["tensorflow>=2.4.0,<2.11"]
65 |
66 | # MMS Server dependencies
67 | extras["mms"] = ["multi-model-server>=1.1.4", "retrying"]
68 |
69 |
70 | extras["test"] = [
71 | "pytest<8",
72 | "pytest-xdist",
73 | "parameterized",
74 | "psutil",
75 | "datasets",
76 | "pytest-sugar",
77 | "black==21.4b0",
78 | "sagemaker",
79 | "boto3",
80 | "mock==2.0.0",
81 | ]
82 |
83 | extras["benchmark"] = ["boto3", "locust"]
84 |
85 | extras["quality"] = [
86 | "black>=21.10",
87 | "isort>=5.5.4",
88 | "flake8>=3.8.3",
89 | ]
90 |
91 | extras["dev"] = extras["transformers"] + extras["mms"] + extras["torch"] + extras["tensorflow"] + extras["diffusers"]
92 | setup(
93 | name="sagemaker-huggingface-inference-toolkit",
94 | version=VERSION,
95 | # if os.getenv("SM_HF_TOOLKIT_RELEASE") is not None
96 | # else VERSION + "b" + str(date.today()).replace("-", ""),
97 | author="HuggingFace and Amazon Web Services",
98 | description="Open source library for running inference workload with Hugging Face Deep Learning Containers on "
99 | "Amazon SageMaker.",
100 | long_description=open("README.md", "r", encoding="utf-8").read(),
101 | long_description_content_type="text/markdown",
102 | keywords="NLP deep-learning transformer pytorch tensorflow BERT GPT GPT-2 AWS Amazon SageMaker Cloud",
103 | url="https://github.com/aws/sagemaker-huggingface-inference-toolkit",
104 | package_dir={"": "src"},
105 | packages=find_packages(where="src"),
106 | install_requires=install_requires,
107 | extras_require=extras,
108 | entry_points={"console_scripts": "serve=sagemaker_huggingface_inference_toolkit.serving:main"},
109 | python_requires=">=3.6.0",
110 | license="Apache License 2.0",
111 | classifiers=[
112 | "Development Status :: 5 - Production/Stable",
113 | "Intended Audience :: Developers",
114 | "Intended Audience :: Education",
115 | "Intended Audience :: Science/Research",
116 | "License :: OSI Approved :: Apache Software License",
117 | "Operating System :: OS Independent",
118 | "Programming Language :: Python :: 3",
119 | "Programming Language :: Python :: 3.6",
120 | "Programming Language :: Python :: 3.7",
121 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
122 | ],
123 | )
124 |
--------------------------------------------------------------------------------
/src/sagemaker_huggingface_inference_toolkit/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import absolute_import
15 |
--------------------------------------------------------------------------------
/src/sagemaker_huggingface_inference_toolkit/content_types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """This module contains constants that define MIME content types."""
15 | # Default Mime-Types
16 | JSON = "application/json"
17 | CSV = "text/csv"
18 | OCTET_STREAM = "application/octet-stream"
19 | ANY = "*/*"
20 | NPY = "application/x-npy"
21 | UTF8_TYPES = [JSON, CSV]
22 | # Vision Mime-Types
23 | JPEG = "image/jpeg"
24 | PNG = "image/png"
25 | TIFF = "image/tiff"
26 | BMP = "image/bmp"
27 | GIF = "image/gif"
28 | WEBP = "image/webp"
29 | X_IMAGE = "image/x-image"
30 | VISION_TYPES = [JPEG, PNG, TIFF, BMP, GIF, WEBP, X_IMAGE]
31 | # Speech Mime-Types
32 | FLAC = "audio/x-flac"
33 | MP3 = "audio/mpeg"
34 | WAV = "audio/wave"
35 | OGG = "audio/ogg"
36 | X_AUDIO = "audio/x-audio"
37 | AUDIO_TYPES = [FLAC, MP3, WAV, OGG, X_AUDIO]
38 |
--------------------------------------------------------------------------------
/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import base64
15 | import csv
16 | import datetime
17 | import json
18 | from io import BytesIO, StringIO
19 |
20 | import numpy as np
21 | from sagemaker_inference import errors
22 | from sagemaker_inference.decoder import _npy_to_numpy
23 |
24 | from mms.service import PredictionException
25 | from PIL import Image
26 | from sagemaker_huggingface_inference_toolkit import content_types
27 |
28 |
29 | def decode_json(content):
30 | return json.loads(content)
31 |
32 |
33 | def decode_csv(string_like): # type: (str) -> dict
34 | """Convert a CSV object to a dictionary with list attributes.
35 |
36 | Args:
37 | string_like (str): CSV string.
38 | Returns:
39 | (dict): dictionary for input
40 | """
41 | stream = StringIO(string_like)
42 | # detects if the incoming csv has headers
43 | if not any(header in string_like.splitlines()[0].lower() for header in ["question", "context", "inputs"]):
44 | raise PredictionException(
45 | "You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.",
46 | 400,
47 | )
48 | # reads csv as io
49 | request_list = list(csv.DictReader(stream))
50 | if "inputs" in request_list[0].keys():
51 | return {"inputs": [entry["inputs"] for entry in request_list]}
52 | else:
53 | return {"inputs": request_list}
54 |
55 |
56 | def decode_image(bpayload: bytearray):
57 | """Convert a .jpeg / .png / .tiff... object to a proper inputs dict.
58 | Args:
59 | bpayload (bytes): byte stream.
60 | Returns:
61 | (dict): dictionary for input
62 | """
63 | image = Image.open(BytesIO(bpayload)).convert("RGB")
64 | return {"inputs": image}
65 |
66 |
67 | def decode_audio(bpayload: bytearray):
68 | """Convert a .wav / .flac / .mp3 object to a proper inputs dict.
69 | Args:
70 | bpayload (bytes): byte stream.
71 | Returns:
72 | (dict): dictionary for input
73 | """
74 |
75 | return {"inputs": bytes(bpayload)}
76 |
77 |
78 | # https://github.com/automl/SMAC3/issues/453
79 | class _JSONEncoder(json.JSONEncoder):
80 | """
81 | custom `JSONEncoder` to make sure float and int64 are converted
82 | """
83 |
84 | def default(self, obj):
85 | if isinstance(obj, np.integer):
86 | return int(obj)
87 | elif isinstance(obj, np.floating):
88 | return float(obj)
89 | elif hasattr(obj, "tolist"):
90 | return obj.tolist()
91 | elif isinstance(obj, datetime.datetime):
92 | return obj.__str__()
93 | elif isinstance(obj, Image.Image):
94 | with BytesIO() as out:
95 | obj.save(out, format="PNG")
96 | png_string = out.getvalue()
97 | return base64.b64encode(png_string).decode("utf-8")
98 | else:
99 | return super(_JSONEncoder, self).default(obj)
100 |
101 |
102 | def encode_json(content, accept_type=None):
103 | """
104 | encodes json with custom `JSONEncoder`
105 | """
106 | return json.dumps(
107 | content,
108 | ensure_ascii=False,
109 | allow_nan=False,
110 | indent=None,
111 | cls=_JSONEncoder,
112 | separators=(",", ":"),
113 | )
114 |
115 |
116 | def _array_to_npy(array_like, accept_type=None):
117 | """Convert an array-like object to the NPY format.
118 |
119 | To understand better what an array-like object is see:
120 | https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays
121 |
122 | Args:
123 | array_like (np.array or Iterable or int or float): array-like object
124 | to be converted to NPY.
125 |
126 | Returns:
127 | (obj): NPY array.
128 | """
129 | buffer = BytesIO()
130 | np.save(buffer, array_like)
131 | return buffer.getvalue()
132 |
133 |
134 | def encode_csv(content, accept_type=None):
135 | """Convert the result of a transformers pipeline to CSV.
136 | Args:
137 | content (dict | list): result of transformers pipeline.
138 | Returns:
139 | (str): object serialized to CSV
140 | """
141 | stream = StringIO()
142 | if not isinstance(content, list):
143 | content = [content]  # wrap a single result (e.g. a dict) in a list so it can be written as one CSV row
144 |
145 | column_header = content[0].keys()
146 | writer = csv.DictWriter(stream, column_header)
147 |
148 | writer.writeheader()
149 | writer.writerows(content)
150 | return stream.getvalue()
151 |
152 |
153 | def encode_image(image, accept_type=content_types.PNG):
154 | """Convert a PIL.Image object to a byte stream.
155 | Args:
156 | image (PIL.Image): image to be converted.
157 | accept_type (str): content type of the image.
158 | Returns:
159 | (bytes): byte stream of the image.
160 | """
161 | accept_type = "PNG" if content_types.X_IMAGE == accept_type else accept_type.split("/")[-1].upper()
162 |
163 | with BytesIO() as out:
164 | image.save(out, format=accept_type)
165 | return out.getvalue()
166 |
167 |
168 | _encoder_map = {
169 | content_types.NPY: _array_to_npy,
170 | content_types.CSV: encode_csv,
171 | content_types.JSON: encode_json,
172 | content_types.JPEG: encode_image,
173 | content_types.PNG: encode_image,
174 | content_types.TIFF: encode_image,
175 | content_types.BMP: encode_image,
176 | content_types.GIF: encode_image,
177 | content_types.WEBP: encode_image,
178 | content_types.X_IMAGE: encode_image,
179 | }
180 | _decoder_map = {
181 | content_types.NPY: _npy_to_numpy,
182 | content_types.CSV: decode_csv,
183 | content_types.JSON: decode_json,
184 | # image mime-types
185 | content_types.JPEG: decode_image,
186 | content_types.PNG: decode_image,
187 | content_types.TIFF: decode_image,
188 | content_types.BMP: decode_image,
189 | content_types.GIF: decode_image,
190 | content_types.WEBP: decode_image,
191 | content_types.X_IMAGE: decode_image,
192 | # audio mime-types
193 | content_types.FLAC: decode_audio,
194 | content_types.MP3: decode_audio,
195 | content_types.WAV: decode_audio,
196 | content_types.OGG: decode_audio,
197 | content_types.X_AUDIO: decode_audio,
198 | }
199 |
200 |
201 | def decode(content, content_type=content_types.JSON):
202 | """
203 | Decodes a specific content_type into an 🤗 Transformers object.
204 | """
205 | try:
206 | decoder = _decoder_map[content_type]
207 | return decoder(content)
208 | except KeyError:
209 | raise errors.UnsupportedFormatError(content_type)
210 | except PredictionException as pred_err:
211 | raise pred_err
212 |
213 |
214 | def encode(content, accept_type=content_types.JSON):
215 | """
216 | Encode an 🤗 Transformers object in a specific content_type.
217 | """
218 | try:
219 | encoder = _encoder_map[accept_type]
220 | return encoder(content, accept_type)
221 | except KeyError:
222 | raise errors.UnsupportedFormatError(accept_type)
223 |
--------------------------------------------------------------------------------
/src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import importlib.util
15 | import logging
16 |
17 | from transformers.utils.import_utils import is_torch_bf16_gpu_available
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | _diffusers = importlib.util.find_spec("diffusers") is not None
23 |
24 |
25 | def is_diffusers_available():
26 | return _diffusers
27 |
28 |
29 | if is_diffusers_available():
30 | import torch
31 |
32 | from diffusers import DiffusionPipeline
33 |
34 |
35 | class SMDiffusionPipelineForText2Image:
36 |
37 | def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
38 | self.pipeline = None
39 | dtype = torch.float32
40 | if device == "cuda":
41 | dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
42 | if torch.cuda.device_count() > 1:
43 | device_map = "balanced"
44 | self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
45 |
46 | if not self.pipeline:
47 | self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype).to(device)
48 |
49 | def __call__(
50 | self,
51 | prompt,
52 | **kwargs,
53 | ):
54 | # TODO: add support for more images (Reason is correct output)
55 | if "num_images_per_prompt" in kwargs:
56 | kwargs.pop("num_images_per_prompt")
57 | logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")
58 |
59 | # Call pipeline with parameters
60 | out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
61 | return out.images[0]
62 |
63 |
64 | DIFFUSERS_TASKS = {
65 | "text-to-image": SMDiffusionPipelineForText2Image,
66 | }
67 |
68 |
69 | def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **kwargs):
70 | """Get a pipeline for Diffusers models."""
71 | device = "cuda" if device == 0 else "cpu"
72 | pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device)
73 | return pipeline
74 |
--------------------------------------------------------------------------------
/src/sagemaker_huggingface_inference_toolkit/handler_service.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import importlib
16 | import logging
17 | import os
18 | import sys
19 | import time
20 | from abc import ABC
21 | from inspect import signature
22 |
23 | from sagemaker_inference import environment, utils
24 | from transformers.pipelines import SUPPORTED_TASKS
25 |
26 | from mms.service import PredictionException
27 | from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder
28 | from sagemaker_huggingface_inference_toolkit.transformers_utils import (
29 | _is_gpu_available,
30 | get_pipeline,
31 | infer_task_from_model_architecture,
32 | )
33 |
34 |
35 | ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
36 | PYTHON_PATH_ENV = "PYTHONPATH"
37 | MODEL_FN = "model_fn"
38 | INPUT_FN = "input_fn"
39 | PREDICT_FN = "predict_fn"
40 | OUTPUT_FN = "output_fn"
41 | TRANSFORM_FN = "transform_fn"
42 |
43 | logger = logging.getLogger(__name__)
44 |
45 |
46 | class HuggingFaceHandlerService(ABC):
47 | """Default handler service that is executed by the model server.
48 |
49 | The handler service is responsible for defining our InferenceHandler.
50 | - The ``handle`` method is invoked for all incoming inference requests to the model server.
51 | - The ``initialize`` method is invoked at model server start up.
52 |
53 | Implementation of: https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md
54 | """
55 |
56 | def __init__(self):
57 | self.error = None
58 | self.batch_size = 1
59 | self.model_dir = None
60 | self.model = None
61 | self.device = -1
62 | self.initialized = False
63 | self.attempted_init = False
64 | self.context = None
65 | self.manifest = None
66 | self.environment = environment.Environment()
67 | self.load_extra_arg = []
68 | self.preprocess_extra_arg = []
69 | self.predict_extra_arg = []
70 | self.postprocess_extra_arg = []
71 | self.transform_extra_arg = []
72 |
73 | def initialize(self, context):
74 | """
75 | Initialize model. This will be called during model loading time
76 | :param context: Initial context contains model server system properties.
77 | :return:
78 | """
79 | self.attempted_init = True
80 | self.context = context
81 | properties = context.system_properties
82 | self.model_dir = properties.get("model_dir")
83 | self.batch_size = context.system_properties["batch_size"]
84 |
85 | code_dir_path = os.path.join(self.model_dir, "code")
86 | sys.path.insert(0, code_dir_path)
87 | self.validate_and_initialize_user_module()
88 |
89 | self.device = self.get_device()
90 | self.model = self.load(*([self.model_dir] + self.load_extra_arg))
91 | self.initialized = True
92 | # # Load methods from file
93 | # if (not self._initialized) and ENABLE_MULTI_MODEL:
94 | # code_dir = os.path.join(context.system_properties.get("model_dir"), "code")
95 | # sys.path.append(code_dir)
96 | # self._initialized = True
97 | # # add model_dir/code to python path
98 |
99 | def get_device(self):
100 | """
101 |         Returns the device id for the DL framework: the GPU id if a GPU is available, otherwise -1.
102 | """
103 | if _is_gpu_available():
104 | return int(self.context.system_properties.get("gpu_id"))
105 | else:
106 | return -1
107 |
108 | def load(self, model_dir, context=None):
109 | """
110 | The Load handler is responsible for loading the Hugging Face transformer model.
111 | It can be overridden to load the model from storage.
112 |
113 | Args:
114 | model_dir (str): The directory where model files are stored.
115 | context (obj): metadata on the incoming request data (default: None).
116 |
117 | Returns:
118 | hf_pipeline (Pipeline): A Hugging Face Transformer pipeline.
119 | """
120 | # gets pipeline from task tag
121 | if "HF_TASK" in os.environ:
122 | hf_pipeline = get_pipeline(task=os.environ["HF_TASK"], model_dir=model_dir, device=self.device)
123 | elif "config.json" in os.listdir(model_dir):
124 | task = infer_task_from_model_architecture(f"{model_dir}/config.json")
125 | hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device)
126 | elif "model_index.json" in os.listdir(model_dir):
127 | task = "text-to-image"
128 | hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device)
129 | else:
130 | raise ValueError(
131 | f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} or text-to-image as env 'HF_TASK'.",
132 | 403,
133 | )
134 | return hf_pipeline
135 |
136 | def preprocess(self, input_data, content_type, context=None):
137 | """
138 | The preprocess handler is responsible for deserializing the input data into
139 |         an object for prediction; JSON is handled by default.
140 | The preprocess handler can be overridden for data or feature transformation.
141 |
142 | Args:
143 | input_data: the request payload serialized in the content_type format.
144 | content_type: the request content_type.
145 | context (obj): metadata on the incoming request data (default: None).
146 |
147 | Returns:
148 |             decoded_input_data (dict): the input_data deserialized into a Python dictionary.
149 | """
150 |         # raises an error when CSV content is used with zero-shot-classification or table-question-answering, which is not possible due to their nested properties
151 | if (
152 | os.environ.get("HF_TASK", None) == "zero-shot-classification"
153 | or os.environ.get("HF_TASK", None) == "table-question-answering"
154 | ) and content_type == content_types.CSV:
155 | raise PredictionException(
156 |                 f"content type {content_type} not supported with {os.environ.get('HF_TASK', 'unknown task')}, use different content_type",
157 | 400,
158 | )
159 |
160 | decoded_input_data = decoder_encoder.decode(input_data, content_type)
161 | return decoded_input_data
162 |
163 | def predict(self, data, model, context=None):
164 | """The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
165 |         on decoded_input_data deserialized in input_fn. Runs prediction on GPU if one is available.
166 | The predict handler can be overridden to implement the model inference.
167 |
168 | Args:
169 | data (dict): deserialized decoded_input_data returned by the input_fn
170 | model : Model returned by the `load` method or if it is a custom module `model_fn`.
171 | context (obj): metadata on the incoming request data (default: None).
172 |
173 | Returns:
174 | obj (dict): prediction result.
175 | """
176 |
177 | # pop inputs for pipeline
178 | inputs = data.pop("inputs", data)
179 | parameters = data.pop("parameters", None)
180 |
181 | # pass inputs with all kwargs in data
182 | if parameters is not None:
183 | prediction = model(inputs, **parameters)
184 | else:
185 | prediction = model(inputs)
186 | return prediction
187 |
188 | def postprocess(self, prediction, accept, context=None):
189 | """
190 | The postprocess handler is responsible for serializing the prediction result to
191 |         the desired accept type; JSON is handled by default.
192 | The postprocess handler can be overridden for inference response transformation.
193 |
194 | Args:
195 | prediction (dict): a prediction result from predict.
196 | accept (str): type which the output data needs to be serialized.
197 | context (obj): metadata on the incoming request data (default: None).
198 |         Returns: output data serialized in the accept type
199 | """
200 | return decoder_encoder.encode(prediction, accept)
201 |
202 | def transform_fn(self, model, input_data, content_type, accept, context=None):
203 | """
204 |         The transform function ("transform_fn") can be used to write a single function that contains the pre-processing, predict and post-processing steps.
205 |         This function can't be mixed with "input_fn", "output_fn" or "predict_fn".
206 |
207 | Args:
208 | model: Model returned by the model_fn above
209 | input_data: Data received for inference
210 | content_type: The content type of the inference data
211 | accept: The response accept type.
212 | context (obj): metadata on the incoming request data (default: None).
213 |
214 | Returns: Response in the "accept" format type.
215 |
216 | """
217 | # run pipeline
218 | start_time = time.time()
219 | processed_data = self.preprocess(*([input_data, content_type] + self.preprocess_extra_arg))
220 | preprocess_time = time.time() - start_time
221 | predictions = self.predict(*([processed_data, model] + self.predict_extra_arg))
222 | predict_time = time.time() - preprocess_time - start_time
223 | response = self.postprocess(*([predictions, accept] + self.postprocess_extra_arg))
224 | postprocess_time = time.time() - predict_time - preprocess_time - start_time
225 |
226 | logger.info(
227 | f"Preprocess time - {preprocess_time * 1000} ms\n"
228 | f"Predict time - {predict_time * 1000} ms\n"
229 | f"Postprocess time - {postprocess_time * 1000} ms"
230 | )
231 |
232 | return response
233 |
234 | def handle(self, data, context):
235 | """Handles an inference request with input data and makes a prediction.
236 |
237 | Args:
238 | data (obj): the request data.
239 | context (obj): metadata on the incoming request data.
240 |
241 | Returns:
242 | list[obj]: The return value from the Transformer.transform method,
243 | which is a serialized prediction result wrapped in a list if
244 | inference is successful. Otherwise returns an error message
245 | with the context set appropriately.
246 |
247 | """
248 | try:
249 | if not self.initialized:
250 | if self.attempted_init:
251 |                     logger.warning(
252 |                         "Model is not initialized, will try to load the model again.\n"
253 |                         "Please consider increasing the wait time for model loading.\n"
254 |                     )
255 | self.initialize(context)
256 |
257 | input_data = data[0].get("body")
258 |
259 | request_property = context.request_processor[0].get_request_properties()
260 | content_type = utils.retrieve_content_type_header(request_property)
261 | accept = request_property.get("Accept") or request_property.get("accept")
262 |
263 | if not accept or accept == content_types.ANY:
264 | accept = content_types.JSON
265 |
266 | if content_type in content_types.UTF8_TYPES:
267 | input_data = input_data.decode("utf-8")
268 |
269 | predict_start = time.time()
270 | response = self.transform_fn(*([self.model, input_data, content_type, accept] + self.transform_extra_arg))
271 | predict_end = time.time()
272 |
273 | context.metrics.add_time("Transform Fn", round((predict_end - predict_start) * 1000, 2))
274 |
275 | context.set_response_content_type(0, accept)
276 | return [response]
277 |
278 | except Exception as e:
279 | raise PredictionException(str(e), 400)
280 |
281 | def validate_and_initialize_user_module(self):
282 | """Retrieves and validates the inference handlers provided within the user module.
283 |         Can override the load, preprocess, predict and postprocess functions.
284 | """
285 | user_module_name = self.environment.module_name
286 | if importlib.util.find_spec(user_module_name) is not None:
287 | logger.info("Inference script implementation found at `{}`.".format(user_module_name))
288 | user_module = importlib.import_module(user_module_name)
289 |
290 | load_fn = getattr(user_module, MODEL_FN, None)
291 | preprocess_fn = getattr(user_module, INPUT_FN, None)
292 | predict_fn = getattr(user_module, PREDICT_FN, None)
293 | postprocess_fn = getattr(user_module, OUTPUT_FN, None)
294 | transform_fn = getattr(user_module, TRANSFORM_FN, None)
295 |
296 | if transform_fn and (preprocess_fn or predict_fn or postprocess_fn):
297 | raise ValueError(
298 | "Cannot use {} implementation in conjunction with {}, {}, and/or {} implementation".format(
299 | TRANSFORM_FN, INPUT_FN, PREDICT_FN, OUTPUT_FN
300 | )
301 | )
302 | self.log_func_implementation_found_or_not(load_fn, MODEL_FN)
303 | if load_fn is not None:
304 | self.load_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.load, load_fn)
305 | self.load = load_fn
306 | self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN)
307 | if preprocess_fn is not None:
308 | self.preprocess_extra_arg = self.function_extra_arg(
309 | HuggingFaceHandlerService.preprocess, preprocess_fn
310 | )
311 | self.preprocess = preprocess_fn
312 | self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN)
313 | if predict_fn is not None:
314 | self.predict_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.predict, predict_fn)
315 | self.predict = predict_fn
316 | self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN)
317 | if postprocess_fn is not None:
318 | self.postprocess_extra_arg = self.function_extra_arg(
319 | HuggingFaceHandlerService.postprocess, postprocess_fn
320 | )
321 | self.postprocess = postprocess_fn
322 | self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN)
323 | if transform_fn is not None:
324 | self.transform_extra_arg = self.function_extra_arg(
325 | HuggingFaceHandlerService.transform_fn, transform_fn
326 | )
327 | self.transform_fn = transform_fn
328 | else:
329 | logger.info(
330 | "No inference script implementation was found at `{}`. Default implementation of all functions will be used.".format(
331 | user_module_name
332 | )
333 | )
334 |
335 | @staticmethod
336 | def log_func_implementation_found_or_not(func, func_name):
337 | if func is not None:
338 | logger.info("`{}` implementation found. It will be used in place of the default one.".format(func_name))
339 | else:
340 | logger.info(
341 | "No `{}` implementation was found. The default one from the handler service will be used.".format(
342 | func_name
343 | )
344 | )
345 |
346 | def function_extra_arg(self, default_func, func):
347 |         """Helper to determine the extra arguments for a handler function, covering 2 cases:
348 |         1. the handler function takes context
349 |         2. the handler function does not take context
350 | """
351 | default_params = signature(default_func).parameters
352 | func_params = signature(func).parameters
353 |
354 | if "self" in default_params:
355 | num_default_func_input = len(default_params) - 1
356 | else:
357 | num_default_func_input = len(default_params)
358 |
359 | num_func_input = len(func_params)
360 | if num_default_func_input == num_func_input:
361 | # function takes context
362 | extra_args = [self.context]
363 | elif num_default_func_input == num_func_input + 1:
364 | # function does not take context
365 | extra_args = []
366 | else:
367 | raise TypeError(
368 | "{} definition takes {} or {} arguments but {} were given.".format(
369 | func.__name__, num_default_func_input - 1, num_default_func_input, num_func_input
370 | )
371 | )
372 | return extra_args
373 |
--------------------------------------------------------------------------------
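A minimal sketch of a user-provided `code/inference.py` that `validate_and_initialize_user_module` above would pick up; the task is an illustrative assumption. `model_fn` omits the optional `context` argument while `predict_fn` accepts it, which `function_extra_arg` detects per function.

from transformers import pipeline


def model_fn(model_dir):
    # Replaces HuggingFaceHandlerService.load; takes no `context`, so none is passed.
    return pipeline("text-classification", model=model_dir)


def predict_fn(data, model, context=None):
    # Replaces HuggingFaceHandlerService.predict; the extra `context` argument is detected and filled in.
    inputs = data.pop("inputs", data)
    return model(inputs)
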
/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import absolute_import
15 |
16 | import os
17 | import pathlib
18 | import subprocess
19 |
20 | from sagemaker_inference import environment, logging
21 | from sagemaker_inference.model_server import (
22 | DEFAULT_MMS_LOG_FILE,
23 | DEFAULT_MMS_MODEL_NAME,
24 | ENABLE_MULTI_MODEL,
25 | MMS_CONFIG_FILE,
26 | REQUIREMENTS_PATH,
27 | _add_sigchild_handler,
28 | _add_sigterm_handler,
29 | _create_model_server_config_file,
30 | _install_requirements,
31 | _retry_retrieve_mms_server_process,
32 | _set_python_path,
33 | )
34 |
35 | from sagemaker_huggingface_inference_toolkit import handler_service
36 | from sagemaker_huggingface_inference_toolkit.optimum_utils import is_optimum_neuron_available
37 | from sagemaker_huggingface_inference_toolkit.transformers_utils import (
38 | HF_API_TOKEN,
39 | HF_MODEL_REVISION,
40 | _load_model_from_hub,
41 | )
42 |
43 |
44 | logger = logging.get_logger()
45 |
46 | DEFAULT_HANDLER_SERVICE = handler_service.__name__
47 |
48 | DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models")
49 | DEFAULT_MODEL_STORE = "/"
50 |
51 |
52 | def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
53 | """Configure and start the model server.
54 |
55 | Args:
56 | handler_service (str): python path pointing to a module that defines
57 | a class with the following:
58 |
59 | - A ``handle`` method, which is invoked for all incoming inference
60 | requests to the model server.
61 | - A ``initialize`` method, which is invoked at model server start up
62 | for loading the model.
63 |
64 | Defaults to ``sagemaker_huggingface_inference_toolkit.handler_service``.
65 |
66 | """
67 | use_hf_hub = "HF_MODEL_ID" in os.environ
68 | model_store = DEFAULT_MODEL_STORE
69 | if ENABLE_MULTI_MODEL:
70 | if not os.getenv("SAGEMAKER_HANDLER"):
71 | os.environ["SAGEMAKER_HANDLER"] = handler_service
72 | _set_python_path()
73 | elif use_hf_hub:
74 | # Use different model store directory
75 | model_store = DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY
76 | storage_dir = _load_model_from_hub(
77 | model_id=os.environ["HF_MODEL_ID"],
78 | model_dir=DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
79 | revision=HF_MODEL_REVISION,
80 | use_auth_token=HF_API_TOKEN,
81 | )
82 | _adapt_to_mms_format(handler_service, storage_dir)
83 | else:
84 | _set_python_path()
85 |
86 | env = environment.Environment()
87 |
88 | # Set the number of workers to available number if optimum neuron is available and not already set
89 | if is_optimum_neuron_available() and os.environ.get("SAGEMAKER_MODEL_SERVER_WORKERS", None) is None:
90 | from optimum.neuron.utils.cache_utils import get_num_neuron_cores
91 |
92 | try:
93 | env._model_server_workers = str(get_num_neuron_cores())
94 | except Exception:
95 | env._model_server_workers = "1"
96 |
97 | # Note: multi-model default config already sets default_service_handler
98 | handler_service_for_config = None if ENABLE_MULTI_MODEL else handler_service
99 | _create_model_server_config_file(env, handler_service_for_config)
100 |
101 | if os.path.exists(REQUIREMENTS_PATH):
102 | _install_requirements()
103 |
104 | multi_model_server_cmd = [
105 | "multi-model-server",
106 | "--start",
107 | "--model-store",
108 | model_store,
109 | "--mms-config",
110 | MMS_CONFIG_FILE,
111 | "--log-config",
112 | DEFAULT_MMS_LOG_FILE,
113 | ]
114 | if not ENABLE_MULTI_MODEL and not use_hf_hub:
115 | multi_model_server_cmd += ["--models", DEFAULT_MMS_MODEL_NAME + "=" + environment.model_dir]
116 |
117 | logger.info(multi_model_server_cmd)
118 | subprocess.Popen(multi_model_server_cmd)
119 | # retry for configured timeout
120 | mms_process = _retry_retrieve_mms_server_process(env.startup_timeout)
121 |
122 | _add_sigterm_handler(mms_process)
123 | _add_sigchild_handler()
124 |
125 | mms_process.wait()
126 |
127 |
128 | def _adapt_to_mms_format(handler_service, model_path):
129 | os.makedirs(DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY, exist_ok=True)
130 |
131 | # gets the model from the path, default is model/
132 | model = pathlib.PurePath(model_path)
133 |
134 |     # Archives (copies, since "no-archive" is used) the model from model_path into the MMS export directory under the model name
135 | model_archiver_cmd = [
136 | "model-archiver",
137 | "--model-name",
138 | model.name,
139 | "--handler",
140 | handler_service,
141 | "--model-path",
142 | model_path,
143 | "--export-path",
144 | DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
145 | "--archive-format",
146 | "no-archive",
147 | "--f",
148 | ]
149 |
150 | logger.info(model_archiver_cmd)
151 | subprocess.check_call(model_archiver_cmd)
152 |
153 | _set_python_path()
154 |
--------------------------------------------------------------------------------
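A minimal sketch of starting the server defined above with a Hub model, assuming the `multi-model-server` and `model-archiver` CLIs are installed (they come with the MMS dependencies); the model id and worker count are illustrative.

import os

from sagemaker_huggingface_inference_toolkit import mms_model_server

os.environ["HF_MODEL_ID"] = "distilbert-base-uncased-finetuned-sst-2-english"  # triggers the Hub download branch
os.environ["HF_TASK"] = "text-classification"
os.environ["SAGEMAKER_MODEL_SERVER_WORKERS"] = "1"

mms_model_server.start_model_server()  # downloads the model, archives it and launches MMS
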
/src/sagemaker_huggingface_inference_toolkit/optimum_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import importlib.util
16 | import logging
17 | import os
18 |
19 |
20 | _optimum_neuron = False
21 | if importlib.util.find_spec("optimum") is not None:
22 | if importlib.util.find_spec("optimum.neuron") is not None:
23 | _optimum_neuron = True
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | def is_optimum_neuron_available():
29 | return _optimum_neuron
30 |
31 |
32 | def get_input_shapes(model_dir):
33 |     """Method to get input shapes from the model config file. If they are not available there, they are read from the HF_OPTIMUM_* environment variables."""
34 | from transformers import AutoConfig
35 |
36 | input_shapes = {}
37 | input_shapes_available = False
38 | # try to get input shapes from config file
39 | try:
40 | config = AutoConfig.from_pretrained(model_dir)
41 | if hasattr(config, "neuron"):
42 | # check if static batch size and sequence length are available
43 | if config.neuron.get("static_batch_size", None) and config.neuron.get("static_sequence_length", None):
44 | input_shapes["batch_size"] = config.neuron["static_batch_size"]
45 | input_shapes["sequence_length"] = config.neuron["static_sequence_length"]
46 | input_shapes_available = True
47 | logger.info(
48 | f"Input shapes found in config file. Using input shapes from config with batch size {input_shapes['batch_size']} and sequence length {input_shapes['sequence_length']}"
49 | )
50 | else:
51 | # Add warning if environment variables are set but will be ignored
52 | if os.environ.get("HF_OPTIMUM_BATCH_SIZE", None) is not None:
53 | logger.warning(
54 | "HF_OPTIMUM_BATCH_SIZE environment variable is set. Environment variable will be ignored and input shapes from config file will be used."
55 | )
56 | if os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None) is not None:
57 | logger.warning(
58 | "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is set. Environment variable will be ignored and input shapes from config file will be used."
59 | )
60 | except Exception:
61 | input_shapes_available = False
62 |
63 | # return input shapes if available
64 | if input_shapes_available:
65 | return input_shapes
66 |
67 | # extract input shapes from environment variables
68 | sequence_length = os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None)
69 | if sequence_length is None:
70 | raise ValueError(
71 | "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is not set. Please set HF_OPTIMUM_SEQUENCE_LENGTH to a positive integer."
72 | )
73 |
74 | if not int(sequence_length) > 0:
75 | raise ValueError(
76 | f"HF_OPTIMUM_SEQUENCE_LENGTH must be set to a positive integer. Current value is {sequence_length}"
77 | )
78 | batch_size = os.environ.get("HF_OPTIMUM_BATCH_SIZE", 1)
79 | logger.info(
80 | f"Using input shapes from environment variables with batch size {batch_size} and sequence length {sequence_length}"
81 | )
82 | return {"batch_size": int(batch_size), "sequence_length": int(sequence_length)}
83 |
84 |
85 | def get_optimum_neuron_pipeline(task, model_dir):
86 |     """Gets the optimum-neuron pipeline for a given task. Checks that the task is supported by optimum-neuron and, if the model is not yet converted, that the required environment variables are set. Returns the pipeline if all checks pass, otherwise raises an error."""
87 | from optimum.neuron.pipelines.transformers.base import NEURONX_SUPPORTED_TASKS, pipeline
88 | from optimum.neuron.utils import NEURON_FILE_NAME
89 |
90 | # check task support
91 | if task not in NEURONX_SUPPORTED_TASKS:
92 | raise ValueError(
93 | f"Task {task} is not supported by optimum neuron and inf2. Supported tasks are: {list(NEURONX_SUPPORTED_TASKS.keys())}"
94 | )
95 |
96 | # check if model is already converted and has input shapes available
97 | export = True
98 | if NEURON_FILE_NAME in os.listdir(model_dir):
99 | export = False
100 | if export:
101 | logger.info("Model is not converted. Checking if required environment variables are set and converting model.")
102 |
103 | # get static input shapes to run inference
104 | input_shapes = get_input_shapes(model_dir)
105 | # set NEURON_RT_NUM_CORES to 1 to avoid conflicts with multiple HTTP workers
106 | os.environ["NEURON_RT_NUM_CORES"] = "1"
107 | # get optimum neuron pipeline
108 | neuron_pipe = pipeline(task, model=model_dir, export=export, input_shapes=input_shapes)
109 |
110 | return neuron_pipe
111 |
--------------------------------------------------------------------------------
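A minimal sketch of the shape-resolution fallback in `get_input_shapes` above: when the model config carries no `neuron` section, the shapes come from environment variables. The model directory and values are illustrative assumptions.

import os

from sagemaker_huggingface_inference_toolkit.optimum_utils import get_input_shapes

os.environ["HF_OPTIMUM_BATCH_SIZE"] = "1"
os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "128"

# "./bert-base-uncased" is an assumed local model directory without a `neuron` config section.
shapes = get_input_shapes("./bert-base-uncased")
print(shapes)  # {'batch_size': 1, 'sequence_length': 128}
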
/src/sagemaker_huggingface_inference_toolkit/serving.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from subprocess import CalledProcessError
15 |
16 | from retrying import retry
17 |
18 | from sagemaker_huggingface_inference_toolkit import handler_service, mms_model_server
19 |
20 |
21 | HANDLER_SERVICE = handler_service.__name__
22 |
23 |
24 | def _retry_if_error(exception):
25 |     return isinstance(exception, (CalledProcessError, OSError))
26 |
27 |
28 | @retry(stop_max_delay=1000 * 50, retry_on_exception=_retry_if_error)
29 | def _start_mms():
30 | mms_model_server.start_model_server(handler_service=HANDLER_SERVICE)
31 |
32 |
33 | def main():
34 | _start_mms()
35 |
36 |
37 | if __name__ == "__main__":
38 | main()
39 |
--------------------------------------------------------------------------------
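A minimal sketch of the retry pattern used by `_start_mms` above, assuming the `retrying` package: the predicate decides which exceptions trigger another attempt within the 50-second window.

from subprocess import CalledProcessError

from retrying import retry


def _retry_if_error(exception):
    # Retry only on process or OS level failures.
    return isinstance(exception, (CalledProcessError, OSError))


@retry(stop_max_delay=1000 * 50, retry_on_exception=_retry_if_error)
def start_flaky_process():
    # Placeholder body; a real caller would launch the model server here.
    ...
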
/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import importlib.util
15 | import json
16 | import logging
17 | import os
18 | from pathlib import Path
19 | from typing import Optional
20 |
21 | from huggingface_hub import HfApi, login, snapshot_download
22 | from transformers import AutoTokenizer, pipeline
23 | from transformers.file_utils import is_tf_available, is_torch_available
24 | from transformers.pipelines import Pipeline
25 |
26 | from sagemaker_huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, is_diffusers_available
27 | from sagemaker_huggingface_inference_toolkit.optimum_utils import (
28 | get_optimum_neuron_pipeline,
29 | is_optimum_neuron_available,
30 | )
31 |
32 |
33 | if is_tf_available():
34 | import tensorflow as tf
35 |
36 | if is_torch_available():
37 | import torch
38 |
39 | _aws_neuron_available = importlib.util.find_spec("torch_neuron") is not None
40 |
41 |
42 | def is_aws_neuron_available():
43 | return _aws_neuron_available
44 |
45 |
46 | def strtobool(val):
47 | """Convert a string representation of truth to True or False.
48 |     True values are 'y', 'yes', 't', 'true', 'on' and '1'; false values
49 |     are 'n', 'no', 'f', 'false', 'off' and '0'. The comparison is case-insensitive.
50 |     Raises ValueError if 'val' is anything else.
51 | """
52 | val = val.lower()
53 |     if val in ("y", "yes", "t", "true", "on", "1"):
54 |         return True
55 |     elif val in ("n", "no", "f", "false", "off", "0"):
56 | return False
57 | else:
58 | raise ValueError("invalid truth value %r" % (val,))
59 |
60 |
61 | logger = logging.getLogger(__name__)
62 |
63 |
64 | FRAMEWORK_MAPPING = {
65 | "pytorch": "pytorch*",
66 | "tensorflow": "tf*",
67 | "tf": "tf*",
68 | "pt": "pytorch*",
69 | "flax": "flax*",
70 | "rust": "rust*",
71 | "onnx": "*onnx*",
72 | "safetensors": "*safetensors",
73 | "coreml": "*mlmodel",
74 | "tflite": "*tflite",
75 | "savedmodel": "*tar.gz",
76 | "openvino": "*openvino*",
77 | "ckpt": "*ckpt",
78 | "neuronx": "*neuron",
79 | }
80 |
81 |
82 | REPO_ID_SEPARATOR = "__"
83 |
84 | ARCHITECTURES_2_TASK = {
85 | "TapasForQuestionAnswering": "table-question-answering",
86 | "ForQuestionAnswering": "question-answering",
87 | "ForTokenClassification": "token-classification",
88 | "ForSequenceClassification": "text-classification",
89 | "ForMultipleChoice": "multiple-choice",
90 | "ForMaskedLM": "fill-mask",
91 | "ForCausalLM": "text-generation",
92 | "ForConditionalGeneration": "text2text-generation",
93 | "MTModel": "text2text-generation",
94 | "EncoderDecoderModel": "text2text-generation",
95 | # Model specific task for backward comp
96 | "GPT2LMHeadModel": "text-generation",
97 | "T5WithLMHeadModel": "text2text-generation",
98 | }
99 |
100 |
101 | HF_API_TOKEN = os.environ.get("HF_API_TOKEN", None)
102 | HF_MODEL_REVISION = os.environ.get("HF_MODEL_REVISION", None)
103 | TRUST_REMOTE_CODE = strtobool(os.environ.get("HF_TRUST_REMOTE_CODE", "False"))
104 |
105 |
106 | def create_artifact_filter(framework):
107 | """
108 |     Returns a list of regex patterns based on the DL framework, which will be used to ignore files when downloading
109 | """
110 | ignore_regex_list = list(set(FRAMEWORK_MAPPING.values()))
111 |
112 | pattern = FRAMEWORK_MAPPING.get(framework, None)
113 | if pattern in ignore_regex_list:
114 | ignore_regex_list.remove(pattern)
115 | return ignore_regex_list
116 | else:
117 | return []
118 |
119 |
120 | def _is_gpu_available():
121 | """
122 | checks if a gpu is available.
123 | """
124 | if is_tf_available():
125 | return True if len(tf.config.list_physical_devices("GPU")) > 0 else False
126 | elif is_torch_available():
127 | return torch.cuda.is_available()
128 | else:
129 | raise RuntimeError(
130 | "At least one of TensorFlow 2.0 or PyTorch should be installed. "
131 | "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
132 | "To install PyTorch, read the instructions at https://pytorch.org/."
133 | )
134 |
135 |
136 | def _get_framework():
137 | """
138 |     extracts which DL framework is used for inference; if both are installed, PyTorch is used
139 | """
140 | if is_torch_available():
141 | return "pytorch"
142 | elif is_tf_available():
143 | return "tensorflow"
144 | else:
145 | raise RuntimeError(
146 | "At least one of TensorFlow 2.0 or PyTorch should be installed. "
147 | "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
148 | "To install PyTorch, read the instructions at https://pytorch.org/."
149 | )
150 |
151 |
152 | def _build_storage_path(model_id: str, model_dir: Path, revision: Optional[str] = None):
153 | """
154 | creates storage path for hub model based on model_id and revision
155 | """
156 |     if "/" in model_id and revision is None:
157 |         storage_path = os.path.join(model_dir, model_id.replace("/", REPO_ID_SEPARATOR))
158 |     elif "/" in model_id and revision is not None:
159 | storage_path = os.path.join(model_dir, model_id.replace("/", REPO_ID_SEPARATOR) + "." + revision)
160 | elif revision is not None:
161 | storage_path = os.path.join(model_dir, model_id + "." + revision)
162 | else:
163 | storage_path = os.path.join(model_dir, model_id)
164 | return storage_path
165 |
166 |
167 | def _load_model_from_hub(
168 | model_id: str, model_dir: Path, revision: Optional[str] = None, use_auth_token: Optional[str] = None
169 | ):
170 | """
171 | Downloads a model repository at the specified revision from the Hugging Face Hub.
172 | All files are nested inside a folder in order to keep their actual filename
173 | relative to that folder. `org__model.revision`
174 | """
175 |     logger.warning(
176 |         "This is an experimental beta feature, which allows downloading a model from the Hugging Face Hub on start up. "
177 |         "It loads the model defined in the env var `HF_MODEL_ID`."
178 |     )
179 | if use_auth_token is not None:
180 | login(token=use_auth_token)
181 | # extracts base framework
182 | framework = _get_framework()
183 |
184 | # creates directory for saved model based on revision and model
185 | storage_folder = _build_storage_path(model_id, model_dir, revision)
186 | os.makedirs(storage_folder, exist_ok=True)
187 |
188 |     # check if neuron or safetensors weights are available
189 | if framework == "pytorch":
190 | files = HfApi().model_info(model_id).siblings
191 | if is_optimum_neuron_available() and any(f.rfilename.endswith("neuron") for f in files):
192 | framework = "neuronx"
193 | elif any(f.rfilename.endswith("safetensors") for f in files):
194 | framework = "safetensors"
195 |
196 | # create regex to only include the framework specific weights
197 | ignore_regex = create_artifact_filter(framework)
198 |
199 | # Download the repository to the workdir and filter out non-framework specific weights
200 | snapshot_download(
201 | model_id,
202 | revision=revision,
203 | local_dir=str(storage_folder),
204 | local_dir_use_symlinks=False,
205 | ignore_patterns=ignore_regex,
206 | )
207 |
208 | return storage_folder
209 |
210 |
211 | def infer_task_from_model_architecture(model_config_path: str, architecture_index=0) -> str:
212 | """
213 |     Infer the task from the `config.json` of a trained model. Detection is not guaranteed, e.g. some models implement multiple architectures or
214 |     are trained on different tasks, see https://huggingface.co/facebook/bart-large/blob/main/config.json. It should work for every model fine-tuned on Amazon SageMaker.
215 |     It is always recommended to set the task explicitly through the env var `HF_TASK`.
216 | """
217 | with open(model_config_path, "r") as config_file:
218 | config = json.loads(config_file.read())
219 | architecture = config.get("architectures", [None])[architecture_index]
220 |
221 | task = None
222 | for arch_options in ARCHITECTURES_2_TASK:
223 | if architecture.endswith(arch_options):
224 | task = ARCHITECTURES_2_TASK[arch_options]
225 |
226 | if task is None:
227 | raise ValueError(
228 |             f"Task couldn't be inferred from {architecture}. "
229 |             f"The Inference Toolkit can only infer tasks from architectures ending with {list(ARCHITECTURES_2_TASK.keys())}. "
230 | "Use env `HF_TASK` to define your task."
231 | )
232 | # set env to work with
233 | os.environ["HF_TASK"] = task
234 | return task
235 |
236 |
237 | def infer_task_from_hub(model_id: str, revision: Optional[str] = None, use_auth_token: Optional[str] = None) -> str:
238 | """
239 |     Infer the task from the Hub by extracting the `pipeline_tag` from model_info.
240 | """
241 | _api = HfApi()
242 | model_info = _api.model_info(repo_id=model_id, revision=revision, token=use_auth_token)
243 | if model_info.pipeline_tag is not None:
244 | # set env to work with
245 | os.environ["HF_TASK"] = model_info.pipeline_tag
246 | return model_info.pipeline_tag
247 | else:
248 | raise ValueError(
249 |             f"Task couldn't be inferred for model {model_id}: no `pipeline_tag` is set on the Hub. " "Use env `HF_TASK` to define your task."
250 | )
251 |
252 |
253 | def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline:
254 | """
255 | create pipeline class for a specific task based on local saved model
256 | """
257 | if task is None:
258 | raise EnvironmentError(
259 | "The task for this model is not set: Please set one: https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined"
260 | )
261 |     # define tokenizer or feature extractor as kwargs to load the pipeline correctly
262 | if task in {
263 | "automatic-speech-recognition",
264 | "image-segmentation",
265 | "image-classification",
266 | "audio-classification",
267 | "object-detection",
268 | "zero-shot-image-classification",
269 | }:
270 | kwargs["feature_extractor"] = model_dir
271 | else:
272 | kwargs["tokenizer"] = model_dir
273 |     # check if optimum neuron is available and try to use it
274 | if is_optimum_neuron_available():
275 | hf_pipeline = get_optimum_neuron_pipeline(task=task, model_dir=model_dir)
276 | elif TRUST_REMOTE_CODE and os.environ.get("HF_MODEL_ID", None) is not None and device == 0:
277 | tokenizer = AutoTokenizer.from_pretrained(os.environ["HF_MODEL_ID"])
278 |
279 | hf_pipeline = pipeline(
280 | task=task,
281 | model=os.environ["HF_MODEL_ID"],
282 | tokenizer=tokenizer,
283 | trust_remote_code=TRUST_REMOTE_CODE,
284 | model_kwargs={"device_map": "auto", "torch_dtype": "auto"},
285 | )
286 | elif is_diffusers_available() and task == "text-to-image":
287 | hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
288 | else:
289 | # load pipeline
290 | hf_pipeline = pipeline(
291 | task=task, model=model_dir, device=device, trust_remote_code=TRUST_REMOTE_CODE, **kwargs
292 | )
293 |
294 | return hf_pipeline
295 |
--------------------------------------------------------------------------------
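A minimal sketch of the Hub-download and pipeline-creation flow implemented above; the model id and target directory are illustrative assumptions and the example runs on CPU (device=-1).

from sagemaker_huggingface_inference_toolkit.transformers_utils import (
    _load_model_from_hub,
    get_pipeline,
    infer_task_from_model_architecture,
)

# Download the framework-specific weights of an assumed model to an assumed local directory.
storage_dir = _load_model_from_hub(
    model_id="distilbert-base-uncased-finetuned-sst-2-english",
    model_dir="/tmp/hf-models",
)

# "DistilBertForSequenceClassification" ends with "ForSequenceClassification" -> "text-classification".
task = infer_task_from_model_architecture(f"{storage_dir}/config.json")

pipe = get_pipeline(task=task, device=-1, model_dir=storage_dir)
print(pipe("I love using the SageMaker Hugging Face Inference Toolkit."))
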
/tests/integ/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/integ/__init__.py
--------------------------------------------------------------------------------
/tests/integ/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from integ.utils import (
4 | validate_automatic_speech_recognition,
5 | validate_classification,
6 | validate_feature_extraction,
7 | validate_fill_mask,
8 | validate_ner,
9 | validate_question_answering,
10 | validate_summarization,
11 | validate_text2text_generation,
12 | validate_text_generation,
13 | validate_translation,
14 | validate_zero_shot_classification,
15 | )
16 |
17 |
18 | task2model = {
19 | "text-classification": {
20 | "pytorch": "distilbert-base-uncased-finetuned-sst-2-english",
21 | "tensorflow": "distilbert-base-uncased-finetuned-sst-2-english",
22 | },
23 | "zero-shot-classification": {
24 | "pytorch": "joeddav/xlm-roberta-large-xnli",
25 | "tensorflow": None,
26 | },
27 | "feature-extraction": {
28 | "pytorch": "bert-base-uncased",
29 | "tensorflow": "bert-base-uncased",
30 | },
31 | "ner": {
32 | "pytorch": "dbmdz/bert-large-cased-finetuned-conll03-english",
33 | "tensorflow": "dbmdz/bert-large-cased-finetuned-conll03-english",
34 | },
35 | "question-answering": {
36 | "pytorch": "distilbert-base-uncased-distilled-squad",
37 | "tensorflow": "distilbert-base-uncased-distilled-squad",
38 | },
39 | "fill-mask": {
40 | "pytorch": "albert-base-v2",
41 | "tensorflow": "albert-base-v2",
42 | },
43 | "summarization": {
44 | "pytorch": "sshleifer/distilbart-xsum-1-1",
45 | "tensorflow": "sshleifer/distilbart-xsum-1-1",
46 | },
47 | "translation_xx_to_yy": {
48 | "pytorch": "Helsinki-NLP/opus-mt-en-de",
49 | "tensorflow": "Helsinki-NLP/opus-mt-en-de",
50 | },
51 | "text2text-generation": {
52 | "pytorch": "t5-small",
53 | "tensorflow": "t5-small",
54 | },
55 | "text-generation": {
56 | "pytorch": "gpt2",
57 | "tensorflow": "gpt2",
58 | },
59 | "image-classification": {
60 | "pytorch": "google/vit-base-patch16-224",
61 | "tensorflow": "google/vit-base-patch16-224",
62 | },
63 | "automatic-speech-recognition": {
64 | "pytorch": "facebook/wav2vec2-base-100h",
65 | "tensorflow": "facebook/wav2vec2-base-960h",
66 | },
67 | }
68 |
69 | task2input = {
70 | "text-classification": {"inputs": "I love you. I like you"},
71 | "zero-shot-classification": {
72 | "inputs": "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!",
73 | "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
74 | },
75 | "feature-extraction": {"inputs": "What is the best book."},
76 | "ner": {"inputs": "My name is Wolfgang and I live in Berlin"},
77 | "question-answering": {
78 | "inputs": {
79 | "question": "What is used for inference?",
80 | "context": "My Name is Philipp and I live in Nuremberg. This model is used with sagemaker for inference.",
81 | }
82 | },
83 | "fill-mask": {"inputs": "Paris is the [MASK] of France."},
84 | "summarization": {
85 | "inputs": "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
86 | },
87 | "translation_xx_to_yy": {"inputs": "My name is Sarah and I live in London"},
88 | "text2text-generation": {
89 | "inputs": "question: What is 42 context: 42 is the answer to life, the universe and everything."
90 | },
91 | "text-generation": {"inputs": "My name is philipp and I am"},
92 | "image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(),
93 | "automatic-speech-recognition": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(),
94 | }
95 |
96 | task2output = {
97 | "text-classification": [{"label": "POSITIVE", "score": 0.99}],
98 | "zero-shot-classification": {
99 | "sequence": "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!",
100 | "labels": ["refund", "faq", "legal"],
101 | "scores": [0.96, 0.027, 0.008],
102 | },
103 | "ner": [
104 | {"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19},
105 | {"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40},
106 | ],
107 | "question-answering": {"score": 0.99, "start": 68, "end": 77, "answer": "sagemaker"},
108 | "summarization": [{"summary_text": " The A The The ANew York City has been installed in the US."}],
109 | "translation_xx_to_yy": [{"translation_text": "Mein Name ist Sarah und ich lebe in London"}],
110 | "text2text-generation": [{"generated_text": "42 is the answer to life, the universe and everything"}],
111 | "feature-extraction": None,
112 | "fill-mask": None,
113 | "text-generation": None,
114 | "image-classification": [
115 | {"score": 0.8858247399330139, "label": "tiger, Panthera tigris"},
116 | {"score": 0.10940514504909515, "label": "tiger cat"},
117 | {"score": 0.0006216464680619538, "label": "jaguar, panther, Panthera onca, Felis onca"},
118 | {"score": 0.0004262699221726507, "label": "dhole, Cuon alpinus"},
119 | {"score": 0.00030842673731967807, "label": "lion, king of beasts, Panthera leo"},
120 | ],
121 | "automatic-speech-recognition": {
122 | "text": "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP OAUDIENCES IN DROFTY SCHOOL ROOMS DAY AFTER DAY FOR A FORT NIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS"
123 | },
124 | }
125 |
126 | task2performance = {
127 | "text-classification": {
128 | "cpu": {
129 | "average_request_time": 4,
130 | },
131 | "gpu": {
132 | "average_request_time": 1,
133 | },
134 | },
135 | "zero-shot-classification": {
136 | "cpu": {
137 | "average_request_time": 4,
138 | },
139 | "gpu": {
140 | "average_request_time": 4,
141 | },
142 | },
143 | "feature-extraction": {
144 | "cpu": {
145 | "average_request_time": 4,
146 | },
147 | "gpu": {
148 | "average_request_time": 1,
149 | },
150 | },
151 | "ner": {
152 | "cpu": {
153 | "average_request_time": 4,
154 | },
155 | "gpu": {
156 | "average_request_time": 1,
157 | },
158 | },
159 | "question-answering": {
160 | "cpu": {
161 | "average_request_time": 4,
162 | },
163 | "gpu": {
164 | "average_request_time": 4,
165 | },
166 | },
167 | "fill-mask": {
168 | "cpu": {
169 | "average_request_time": 4,
170 | },
171 | "gpu": {
172 | "average_request_time": 3,
173 | },
174 | },
175 | "summarization": {
176 | "cpu": {
177 | "average_request_time": 26,
178 | },
179 | "gpu": {
180 | "average_request_time": 3,
181 | },
182 | },
183 | "translation_xx_to_yy": {
184 | "cpu": {
185 | "average_request_time": 8,
186 | },
187 | "gpu": {
188 | "average_request_time": 3,
189 | },
190 | },
191 | "text2text-generation": {
192 | "cpu": {
193 | "average_request_time": 4,
194 | },
195 | "gpu": {
196 | "average_request_time": 3,
197 | },
198 | },
199 | "text-generation": {
200 | "cpu": {
201 | "average_request_time": 15,
202 | },
203 | "gpu": {
204 | "average_request_time": 3,
205 | },
206 | },
207 | "image-classification": {
208 | "cpu": {
209 | "average_request_time": 4,
210 | },
211 | "gpu": {
212 | "average_request_time": 1,
213 | },
214 | },
215 | "automatic-speech-recognition": {
216 | "cpu": {
217 | "average_request_time": 6,
218 | },
219 | "gpu": {
220 | "average_request_time": 6,
221 | },
222 | },
223 | }
224 |
225 | task2validation = {
226 | "text-classification": validate_classification,
227 | "zero-shot-classification": validate_zero_shot_classification,
228 | "feature-extraction": validate_feature_extraction,
229 | "ner": validate_ner,
230 | "question-answering": validate_question_answering,
231 | "fill-mask": validate_fill_mask,
232 | "summarization": validate_summarization,
233 | "translation_xx_to_yy": validate_translation,
234 | "text2text-generation": validate_text2text_generation,
235 | "text-generation": validate_text_generation,
236 | "image-classification": validate_classification,
237 | "automatic-speech-recognition": validate_automatic_speech_recognition,
238 | }
239 |
--------------------------------------------------------------------------------
/tests/integ/test_diffusers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from io import BytesIO
4 |
5 | import boto3
6 | from integ.utils import clean_up, timeout_and_delete_by_name
7 | from PIL import Image
8 | from sagemaker import Session
9 | from sagemaker.model import Model
10 |
11 |
12 | os.environ["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
13 | SAGEMAKER_EXECUTION_ROLE = os.environ.get("SAGEMAKER_EXECUTION_ROLE", "sagemaker_execution_role")
14 |
15 |
16 | def get_framework_ecr_image(registry_id="763104351884", repository_name="huggingface-pytorch-inference", device="cpu"):
17 | client = boto3.client("ecr")
18 |
19 | def get_all_ecr_images(registry_id, repository_name, result_key):
20 | response = client.list_images(
21 | registryId=registry_id,
22 | repositoryName=repository_name,
23 | )
24 | results = response[result_key]
25 | while "nextToken" in response:
26 | response = client.list_images(
27 | registryId=registry_id,
28 | nextToken=response["nextToken"],
29 | repositoryName=repository_name,
30 | )
31 | results.extend(response[result_key])
32 | return results
33 |
34 | images = get_all_ecr_images(registry_id=registry_id, repository_name=repository_name, result_key="imageIds")
35 | image_tags = [image["imageTag"] for image in images]
36 |     image_regex = re.compile(r"\d\.\d\.\d-" + device + "-.{4}$")
37 | tag = sorted(list(filter(image_regex.match, image_tags)), reverse=True)[0]
38 | return f"{registry_id}.dkr.ecr.{os.environ.get('AWS_DEFAULT_REGION','us-east-1')}.amazonaws.com/{repository_name}:{tag}"
39 |
40 |
41 | # TODO: needs existing container
42 | def test_text_to_image_model():
43 | image_uri = get_framework_ecr_image(repository_name="huggingface-pytorch-inference", device="gpu")
44 |
45 | name = "hf-test-text-to-image"
46 | task = "text-to-image"
47 | model = "echarlaix/tiny-random-stable-diffusion-xl"
48 | # instance_type = "ml.m5.large" if device == "cpu" else "ml.g4dn.xlarge"
49 | instance_type = "local_gpu"
50 | env = {"HF_MODEL_ID": model, "HF_TASK": task}
51 |
52 | sagemaker_session = Session()
53 | client = boto3.client("sagemaker-runtime")
54 |
55 | hf_model = Model(
56 | image_uri=image_uri, # A Docker image URI.
57 | model_data=None, # The S3 location of a SageMaker model data .tar.gz
58 | env=env, # Environment variables to run with image_uri when hosted in SageMaker (default: None).
59 | role=SAGEMAKER_EXECUTION_ROLE, # An AWS IAM role (either name or full ARN).
60 | name=name, # The model name
61 | sagemaker_session=sagemaker_session,
62 | )
63 |
64 | with timeout_and_delete_by_name(name, sagemaker_session, minutes=59):
65 | # Use accelerator type to differentiate EI vs. CPU and GPU. Don't use processor value
66 | hf_model.deploy(
67 | initial_instance_count=1,
68 | instance_type=instance_type,
69 | endpoint_name=name,
70 | )
71 | response = client.invoke_endpoint(
72 | EndpointName=name,
73 |             Body='{"inputs": "a yellow lemon tree"}',  # Body must be a JSON string (or bytes), not a dict
74 | ContentType="application/json",
75 | Accept="image/png",
76 | )
77 |
78 | # validate response
79 |         response_body = response["Body"].read()  # raw PNG bytes; do not decode as text
80 |
81 | img = Image.open(BytesIO(response_body))
82 | assert isinstance(img, Image.Image)
83 |
84 | clean_up(endpoint_name=name, sagemaker_session=sagemaker_session)
85 |
--------------------------------------------------------------------------------
/tests/integ/test_models_from_hub.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | import boto3
9 | from integ.config import task2input, task2model, task2output, task2performance, task2validation
10 | from integ.utils import clean_up, timeout_and_delete_by_name, track_infer_time
11 | from sagemaker import Session
12 | from sagemaker.model import Model
13 |
14 |
15 | os.environ["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
16 | SAGEMAKER_EXECUTION_ROLE = os.environ.get("SAGEMAKER_EXECUTION_ROLE", "sagemaker_execution_role")
17 |
18 |
19 | def get_framework_ecr_image(registry_id="763104351884", repository_name="huggingface-pytorch-inference", device="cpu"):
20 | client = boto3.client("ecr")
21 |
22 | def get_all_ecr_images(registry_id, repository_name, result_key):
23 | response = client.list_images(
24 | registryId=registry_id,
25 | repositoryName=repository_name,
26 | )
27 | results = response[result_key]
28 | while "nextToken" in response:
29 | response = client.list_images(
30 | registryId=registry_id,
31 | nextToken=response["nextToken"],
32 | repositoryName=repository_name,
33 | )
34 | results.extend(response[result_key])
35 | return results
36 |
37 | images = get_all_ecr_images(registry_id=registry_id, repository_name=repository_name, result_key="imageIds")
38 | image_tags = [image["imageTag"] for image in images]
39 |     image_regex = re.compile(r"\d\.\d\.\d-" + device + "-.{4}$")
40 | tag = sorted(list(filter(image_regex.match, image_tags)), reverse=True)[0]
41 | return f"{registry_id}.dkr.ecr.{os.environ.get('AWS_DEFAULT_REGION','us-east-1')}.amazonaws.com/{repository_name}:{tag}"
42 |
43 |
44 | @pytest.mark.parametrize(
45 | "task",
46 | [
47 | "text-classification",
48 | "zero-shot-classification",
49 | "ner",
50 | "question-answering",
51 | "fill-mask",
52 | "summarization",
53 | "translation_xx_to_yy",
54 | "text2text-generation",
55 | "text-generation",
56 | "feature-extraction",
57 | "image-classification",
58 | "automatic-speech-recognition",
59 | ],
60 | )
61 | @pytest.mark.parametrize(
62 | "framework",
63 | ["pytorch", "tensorflow"],
64 | )
65 | @pytest.mark.parametrize(
66 | "device",
67 | [
68 | "gpu",
69 | "cpu",
70 | ],
71 | )
72 | def test_deployment_from_hub(task, device, framework):
73 | image_uri = get_framework_ecr_image(repository_name=f"huggingface-{framework}-inference", device=device)
74 | name = f"hf-test-{framework}-{device}-{task}".replace("_", "-")
75 | model = task2model[task][framework]
76 | # instance_type = "ml.m5.large" if device == "cpu" else "ml.g4dn.xlarge"
77 | instance_type = "local" if device == "cpu" else "local_gpu"
78 | number_of_requests = 100
79 | if model is None:
80 | return
81 |
82 | env = {"HF_MODEL_ID": model, "HF_TASK": task}
83 |
84 | sagemaker_session = Session()
85 | client = boto3.client("sagemaker-runtime")
86 |
87 | hf_model = Model(
88 | image_uri=image_uri, # A Docker image URI.
89 | model_data=None, # The S3 location of a SageMaker model data .tar.gz
90 | env=env, # Environment variables to run with image_uri when hosted in SageMaker (default: None).
91 | role=SAGEMAKER_EXECUTION_ROLE, # An AWS IAM role (either name or full ARN).
92 | name=name, # The model name
93 | sagemaker_session=sagemaker_session,
94 | )
95 |
96 | with timeout_and_delete_by_name(name, sagemaker_session, minutes=59):
97 | # Use accelerator type to differentiate EI vs. CPU and GPU. Don't use processor value
98 | hf_model.deploy(
99 | initial_instance_count=1,
100 | instance_type=instance_type,
101 | endpoint_name=name,
102 | )
103 |
104 | # Keep track of the inference time
105 | time_buffer = []
106 |
107 | # Warm up the model
108 | if task == "image-classification":
109 | response = client.invoke_endpoint(
110 | EndpointName=name,
111 | Body=task2input[task],
112 | ContentType="image/jpeg",
113 | Accept="application/json",
114 | )
115 | elif task == "automatic-speech-recognition":
116 | response = client.invoke_endpoint(
117 | EndpointName=name,
118 | Body=task2input[task],
119 | ContentType="audio/x-flac",
120 | Accept="application/json",
121 | )
122 | else:
123 | response = client.invoke_endpoint(
124 | EndpointName=name,
125 | Body=json.dumps(task2input[task]),
126 | ContentType="application/json",
127 | Accept="application/json",
128 | )
129 |
130 | # validate response
131 | response_body = response["Body"].read().decode("utf-8")
132 |
133 | assert True is task2validation[task](result=json.loads(response_body), snapshot=task2output[task])
134 |
135 | for _ in range(number_of_requests):
136 | with track_infer_time(time_buffer):
137 | if task == "image-classification":
138 | response = client.invoke_endpoint(
139 | EndpointName=name,
140 | Body=task2input[task],
141 | ContentType="image/jpeg",
142 | Accept="application/json",
143 | )
144 | elif task == "automatic-speech-recognition":
145 | response = client.invoke_endpoint(
146 | EndpointName=name,
147 | Body=task2input[task],
148 | ContentType="audio/x-flac",
149 | Accept="application/json",
150 | )
151 | else:
152 | response = client.invoke_endpoint(
153 | EndpointName=name,
154 | Body=json.dumps(task2input[task]),
155 | ContentType="application/json",
156 | Accept="application/json",
157 | )
158 | with open(f"{name}.json", "w") as outfile:
159 | data = {
160 | "index": name,
161 | "framework": framework,
162 | "device": device,
163 | "model": model,
164 | "number_of_requests": number_of_requests,
165 | "average_request_time": np.mean(time_buffer),
166 | "max_request_time": max(time_buffer),
167 | "min_request_time": min(time_buffer),
168 | "p95_request_time": np.percentile(time_buffer, 95),
169 | "body": json.loads(response_body),
170 | }
171 | json.dump(data, outfile)
172 |
173 | assert task2performance[task][device]["average_request_time"] >= np.mean(time_buffer)
174 |
175 | clean_up(endpoint_name=name, sagemaker_session=sagemaker_session)
176 |
--------------------------------------------------------------------------------
/tests/integ/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import logging
4 | import re
5 | import signal
6 | from contextlib import contextmanager
7 | from time import time
8 |
9 | from botocore.exceptions import ClientError
10 |
11 |
12 | LOGGER = logging.getLogger("timeout")
13 |
14 |
15 | class TimeoutError(Exception):
16 | pass
17 |
18 |
19 | def clean_up(endpoint_name, sagemaker_session):
20 | try:
21 | sagemaker_session.delete_endpoint(endpoint_name)
22 | sagemaker_session.delete_endpoint_config(endpoint_name)
23 | sagemaker_session.delete_model(endpoint_name)
24 | LOGGER.info("deleted endpoint {}".format(endpoint_name))
25 | except ClientError as ce:
26 | if ce.response["Error"]["Code"] == "ValidationException":
27 | # avoids the inner exception to be overwritten
28 | pass
29 |
30 |
31 | @contextmanager
32 | def timeout(seconds=0, minutes=0, hours=0):
33 | """Add a signal-based timeout to any block of code.
34 | If multiple time units are specified, they will be added together to determine time limit.
35 | Usage:
36 | with timeout(seconds=5):
37 | my_slow_function(...)
38 | Args:
39 | - seconds: The time limit, in seconds.
40 | - minutes: The time limit, in minutes.
41 | - hours: The time limit, in hours.
42 | """
43 |
44 | limit = seconds + 60 * minutes + 3600 * hours
45 |
46 | def handler(signum, frame):
47 | raise TimeoutError("timed out after {} seconds".format(limit))
48 |
49 | try:
50 | signal.signal(signal.SIGALRM, handler)
51 | signal.alarm(limit)
52 |
53 | yield
54 | finally:
55 | signal.alarm(0)
56 |
57 |
58 | @contextmanager
59 | def timeout_and_delete_by_name(endpoint_name, sagemaker_session, seconds=0, minutes=0, hours=0):
60 | with timeout(seconds=seconds, minutes=minutes, hours=hours) as t:
61 | try:
62 | yield [t]
63 | finally:
64 | clean_up(endpoint_name, sagemaker_session)
65 |
66 |
67 | @contextmanager
68 | def track_infer_time(buffer=[]):
69 | start = time()
70 | yield
71 | end = time()
72 |
73 | buffer.append(end - start)
74 |
75 |
76 | _re_word_boundaries = re.compile(r"\b")
77 |
78 |
79 | def count_tokens(inputs: dict, task: str) -> int:
80 | if task == "question-answering":
81 | context_len = len(_re_word_boundaries.findall(inputs["context"])) >> 1
82 | question_len = len(_re_word_boundaries.findall(inputs["question"])) >> 1
83 | return question_len + context_len
84 | else:
85 | return len(_re_word_boundaries.findall(inputs)) >> 1
86 |
87 |
88 | def validate_classification(result=None, snapshot=None):
89 | for idx, _ in enumerate(result):
90 | assert result[idx].keys() == snapshot[idx].keys()
91 | assert result[idx]["score"] >= snapshot[idx]["score"]
92 | return True
93 |
94 |
95 | def validate_zero_shot_classification(result=None, snapshot=None):
96 | assert result.keys() == snapshot.keys()
97 | assert result["labels"] == snapshot["labels"]
98 | assert result["sequence"] == snapshot["sequence"]
99 | for idx in range(len(result["scores"])):
100 | assert result["scores"][idx] >= snapshot["scores"][idx]
101 | return True
102 |
103 |
104 | def validate_ner(result=None, snapshot=None):
105 | for idx, _ in enumerate(result):
106 | assert result[idx].keys() == snapshot[idx].keys()
107 | assert result[idx]["score"] >= snapshot[idx]["score"]
108 | assert result[idx]["entity"] == snapshot[idx]["entity"]
109 | assert result[idx]["entity"] == snapshot[idx]["entity"]
110 | return True
111 |
112 |
113 | def validate_question_answering(result=None, snapshot=None):
114 | assert result.keys() == snapshot.keys()
115 | assert result["answer"] == snapshot["answer"]
116 | assert result["score"] >= snapshot["score"]
117 | return True
118 |
119 |
120 | def validate_summarization(result=None, snapshot=None):
121 | assert result == snapshot
122 | return True
123 |
124 |
125 | def validate_text2text_generation(result=None, snapshot=None):
126 | assert result == snapshot
127 | return True
128 |
129 |
130 | def validate_translation(result=None, snapshot=None):
131 | assert result == snapshot
132 | return True
133 |
134 |
135 | def validate_text_generation(result=None, snapshot=None):
136 | assert result is not None
137 | return True
138 |
139 |
140 | def validate_feature_extraction(result=None, snapshot=None):
141 | assert result is not None
142 | return True
143 |
144 |
145 | def validate_fill_mask(result=None, snapshot=None):
146 | assert result is not None
147 | return True
148 |
149 |
150 | def validate_automatic_speech_recognition(result=None, snapshot=None):
151 | assert result is not None
152 | assert "text" in result
153 | return True
154 |
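
A minimal sketch of how the helpers above compose in an integration test; `send_request` is a hypothetical stand-in for the `invoke_endpoint` calls made in test_models_from_hub.py, and the latency budget is illustrative:

    import numpy as np

    from tests.integ.utils import timeout_and_delete_by_name, track_infer_time


    def example_latency_check(sagemaker_session, endpoint_name, send_request):
        time_buffer = []
        # the endpoint is cleaned up even if the block below raises or hits the time limit
        with timeout_and_delete_by_name(endpoint_name, sagemaker_session, minutes=15):
            for _ in range(10):
                with track_infer_time(time_buffer):
                    send_request(endpoint_name)  # hypothetical call against the endpoint
        assert np.mean(time_buffer) < 1.0  # illustrative latency budget
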
--------------------------------------------------------------------------------
/tests/resources/audio/sample1.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/audio/sample1.flac
--------------------------------------------------------------------------------
/tests/resources/audio/sample1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/audio/sample1.mp3
--------------------------------------------------------------------------------
/tests/resources/audio/sample1.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/audio/sample1.ogg
--------------------------------------------------------------------------------
/tests/resources/audio/sample1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/audio/sample1.wav
--------------------------------------------------------------------------------
/tests/resources/image/tiger.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.bmp
--------------------------------------------------------------------------------
/tests/resources/image/tiger.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.gif
--------------------------------------------------------------------------------
/tests/resources/image/tiger.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.jpeg
--------------------------------------------------------------------------------
/tests/resources/image/tiger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.png
--------------------------------------------------------------------------------
/tests/resources/image/tiger.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.tiff
--------------------------------------------------------------------------------
/tests/resources/image/tiger.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/resources/image/tiger.webp
--------------------------------------------------------------------------------
/tests/resources/model_input_predict_output_fn_with_context/code/inference.py:
--------------------------------------------------------------------------------
1 | def model_fn(model_dir, context=None):
2 | return "model"
3 |
4 |
5 | def input_fn(data, content_type, context=None):
6 | return "data"
7 |
8 |
9 | def predict_fn(data, model, context=None):
10 | return "output"
11 |
12 |
13 | def output_fn(prediction, accept, context=None):
14 | return prediction
15 |
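
The stubs above only pin down the hook signatures the handler looks for. A more realistic user module keeps the same signatures and does real work; a minimal sketch, assuming a transformers text-classification model was packaged into model_dir (the pipeline call is illustrative, not mandated by the toolkit):

    import json

    from transformers import pipeline


    def model_fn(model_dir, context=None):
        # load whatever artifacts were packaged into model_dir
        return pipeline("text-classification", model=model_dir)


    def input_fn(input_data, content_type, context=None):
        # deserialize the request body
        return json.loads(input_data)


    def predict_fn(data, model, context=None):
        return model(data["inputs"])


    def output_fn(prediction, accept, context=None):
        return json.dumps(prediction)
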
--------------------------------------------------------------------------------
/tests/resources/model_input_predict_output_fn_without_context/code/inference.py:
--------------------------------------------------------------------------------
1 | def model_fn(model_dir):
2 | return "model"
3 |
4 |
5 | def input_fn(data, content_type):
6 | return "data"
7 |
8 |
9 | def predict_fn(data, model):
10 | return "output"
11 |
12 |
13 | def output_fn(prediction, accept):
14 | return prediction
15 |
--------------------------------------------------------------------------------
/tests/resources/model_transform_fn_with_context/code/inference_tranform_fn.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def model_fn(model_dir, context=None):
5 | return f"Loading {os.path.basename(__file__)}"
6 |
7 |
8 | def transform_fn(a, b, c, d, context=None):
9 | return f"output {b}"
10 |
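
The positional parameters a, b, c and d correspond to (model, input_data, content_type, accept); the handler tests further down call transform_fn("model", "dummy", "application/json", "application/json"). A slightly more explicit sketch of the same single-hook override, assuming model_fn returns a callable pipeline:

    import json

    from transformers import pipeline


    def model_fn(model_dir, context=None):
        # whatever is returned here is handed back to transform_fn as `model`
        return pipeline("text-classification", model=model_dir)


    def transform_fn(model, input_data, content_type, accept, context=None):
        # decode, predict and encode in one hook
        data = json.loads(input_data)
        prediction = model(data["inputs"])
        return json.dumps(prediction)
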
--------------------------------------------------------------------------------
/tests/resources/model_transform_fn_without_context/code/inference_tranform_fn.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def model_fn(model_dir):
5 | return f"Loading {os.path.basename(__file__)}"
6 |
7 |
8 | def transform_fn(a, b, c, d):
9 | return f"output {b}"
10 |
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/sagemaker-huggingface-inference-toolkit/69bc5ef978e1acce63386c10aa2fdb97c92c977e/tests/unit/__init__.py
--------------------------------------------------------------------------------
/tests/unit/test_decoder_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import json
15 | import os
16 |
17 | import numpy as np
18 | import pytest
19 | from transformers.testing_utils import require_torch
20 |
21 | from mms.service import PredictionException
22 | from PIL import Image
23 | from sagemaker_huggingface_inference_toolkit import decoder_encoder
24 |
25 |
26 | ENCODE_JSON_INPUT = {"upper": [1425], "lower": [576], "level": [2], "datetime": ["2012-08-08 15:30"]}
27 | ENCODE_CSV_INPUT = [
28 | {"answer": "Nuremberg", "end": 42, "score": 0.9926825761795044, "start": 33},
29 | {"answer": "Berlin is the capital of Germany", "end": 32, "score": 0.26097726821899414, "start": 0},
30 | ]
31 | ENCODE_TOLOIST_INPUT = [1, 0.5, 5.0]
32 |
33 | DECODE_JSON_INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
34 | DECODE_CSV_INPUT = "question,context\r\nwhere do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany"
35 |
36 | CONTENT_TYPE = "application/json"
37 |
38 |
39 | def test_decode_json():
40 | decoded_data = decoder_encoder.decode_json(json.dumps(DECODE_JSON_INPUT))
41 | assert decoded_data == DECODE_JSON_INPUT
42 |
43 |
44 | def test_decode_csv():
45 | decoded_data = decoder_encoder.decode_csv(DECODE_CSV_INPUT)
46 | assert decoded_data == {
47 | "inputs": [
48 | {"question": "where do i live?", "context": "My name is Philipp and I live in Nuremberg"},
49 | {"question": "where is Berlin?", "context": "Berlin is the capital of Germany"},
50 | ]
51 | }
52 | text_classification_input = "inputs\r\nI love you\r\nI like you"
53 | decoded_data = decoder_encoder.decode_csv(text_classification_input)
54 | assert decoded_data == {"inputs": ["I love you", "I like you"]}
55 |
56 |
57 | def test_decode_image():
58 | image_files_path = os.path.join(os.getcwd(), "tests/resources/image")
59 |
60 | for image_file in os.listdir(image_files_path):
61 | image_bytes = open(os.path.join(image_files_path, image_file), "rb").read()
62 | decoded_data = decoder_encoder.decode_image(bytearray(image_bytes))
63 |
64 | assert isinstance(decoded_data, dict)
65 | assert isinstance(decoded_data["inputs"], Image.Image)
66 |
67 |
68 | def test_decode_audio():
69 | audio_files_path = os.path.join(os.getcwd(), "tests/resources/audio")
70 |
71 | for audio_file in os.listdir(audio_files_path):
72 | audio_bytes = open(os.path.join(audio_files_path, audio_file), "rb").read()
73 | decoded_data = decoder_encoder.decode_audio(bytearray(audio_bytes))
74 |
75 | assert {"inputs": audio_bytes} == decoded_data
76 |
77 |
78 | def test_decode_csv_without_header():
79 | with pytest.raises(PredictionException):
80 | decoder_encoder.decode_csv(
81 | "where do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany"
82 | )
83 |
84 |
85 | def test_encode_json():
86 | encoded_data = decoder_encoder.encode_json(ENCODE_JSON_INPUT)
87 | assert json.loads(encoded_data) == ENCODE_JSON_INPUT
88 |
89 |
90 | @require_torch
91 | def test_encode_json_torch():
92 | import torch
93 |
94 | encoded_data = decoder_encoder.encode_json({"data": torch.tensor(ENCODE_TOLOIST_INPUT)})
95 | assert json.loads(encoded_data) == {"data": ENCODE_TOLOIST_INPUT}
96 |
97 |
98 | def test_encode_json_numpy():
99 | encoded_data = decoder_encoder.encode_json({"data": np.array(ENCODE_TOLOIST_INPUT)})
100 | assert json.loads(encoded_data) == {"data": ENCODE_TOLOIST_INPUT}
101 |
102 |
103 | def test_encode_csv():
104 | decoded_data = decoder_encoder.encode_csv(ENCODE_CSV_INPUT)
105 | assert (
106 | decoded_data
107 | == "answer,end,score,start\r\nNuremberg,42,0.9926825761795044,33\r\nBerlin is the capital of Germany,32,0.26097726821899414,0\r\n"
108 | )
109 |
110 |
111 | def test_decode_content_type():
112 | decoded_content_type = decoder_encoder._decoder_map[CONTENT_TYPE]
113 | assert decoded_content_type == decoder_encoder.decode_json
114 |
115 |
116 | def test_encode_content_type():
117 | encoded_content_type = decoder_encoder._encoder_map[CONTENT_TYPE]
118 | assert encoded_content_type == decoder_encoder.encode_json
119 |
120 |
121 | def test_decode():
122 | decode = decoder_encoder.decode(json.dumps(DECODE_JSON_INPUT), CONTENT_TYPE)
123 | assert decode == DECODE_JSON_INPUT
124 |
125 |
126 | def test_encode():
127 | encode = decoder_encoder.encode(ENCODE_JSON_INPUT, CONTENT_TYPE)
128 | assert json.loads(encode) == ENCODE_JSON_INPUT
129 |
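
Putting the pieces exercised above together, a minimal JSON round trip through the content-type dispatch looks like this (mirroring the decode/encode calls used in the tests):

    import json

    from sagemaker_huggingface_inference_toolkit import decoder_encoder

    payload = {"inputs": "My name is Wolfgang and I live in Berlin"}
    decoded = decoder_encoder.decode(json.dumps(payload), "application/json")  # -> dict
    encoded = decoder_encoder.encode(decoded, "application/json")              # -> JSON string
    assert json.loads(encoded) == payload
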
--------------------------------------------------------------------------------
/tests/unit/test_diffusers_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import tempfile
15 |
16 | from transformers.testing_utils import require_torch, slow
17 |
18 | from PIL import Image
19 | from sagemaker_huggingface_inference_toolkit.diffusers_utils import SMDiffusionPipelineForText2Image
20 | from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline
21 |
22 |
23 | @require_torch
24 | def test_get_diffusers_pipeline():
25 | with tempfile.TemporaryDirectory() as tmpdirname:
26 | storage_dir = _load_model_from_hub(
27 | "hf-internal-testing/tiny-stable-diffusion-torch",
28 | tmpdirname,
29 | )
30 | pipe = get_pipeline("text-to-image", -1, storage_dir)
31 | assert isinstance(pipe, SMDiffusionPipelineForText2Image)
32 |
33 |
34 | @slow
35 | @require_torch
36 | def test_pipe_on_gpu():
37 | with tempfile.TemporaryDirectory() as tmpdirname:
38 | storage_dir = _load_model_from_hub(
39 | "hf-internal-testing/tiny-stable-diffusion-torch",
40 | tmpdirname,
41 | )
42 | pipe = get_pipeline("text-to-image", 0, storage_dir)
43 | assert pipe.device.type == "cuda"
44 |
45 |
46 | @require_torch
47 | def test_text_to_image_task():
48 | with tempfile.TemporaryDirectory() as tmpdirname:
49 | storage_dir = _load_model_from_hub("hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname)
50 | pipe = get_pipeline("text-to-image", -1, storage_dir)
51 | res = pipe("Lets create an embedding")
52 | assert isinstance(res, Image.Image)
53 |
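
A minimal sketch of the text-to-image path shown above, saving the generated image to disk (the save call is plain PIL, not part of the toolkit):

    import tempfile

    from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline

    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_dir = _load_model_from_hub("hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname)
        pipe = get_pipeline("text-to-image", -1, storage_dir)  # -1 selects CPU
        image = pipe("a photo of a tiger")                     # returns a PIL.Image.Image
        image.save("tiger_generated.png")
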
--------------------------------------------------------------------------------
/tests/unit/test_handler_service_with_context.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import json
15 | import os
16 | import tempfile
17 |
18 | import pytest
19 | from sagemaker_inference import content_types
20 | from transformers.testing_utils import require_torch, slow
21 |
22 | from mms.context import Context, RequestProcessor
23 | from mms.metrics.metrics_store import MetricsStore
24 | from mock import Mock
25 | from sagemaker_huggingface_inference_toolkit import handler_service
26 | from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline
27 |
28 |
29 | TASK = "text-classification"
30 | MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
31 | INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
32 | OUTPUT = [
33 | {"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19},
34 | {"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40},
35 | ]
36 |
37 |
38 | @pytest.fixture()
39 | def inference_handler():
40 | return handler_service.HuggingFaceHandlerService()
41 |
42 |
43 | def test_get_device_cpu(inference_handler):
44 | device = inference_handler.get_device()
45 | assert device == -1
46 |
47 |
48 | @slow
49 | def test_get_device_gpu(inference_handler):
50 | device = inference_handler.get_device()
51 | assert device > -1
52 |
53 |
54 | @require_torch
55 | def test_initialize(inference_handler):
56 | with tempfile.TemporaryDirectory() as tmpdirname:
57 | storage_folder = _load_model_from_hub(
58 | model_id=MODEL,
59 | model_dir=tmpdirname,
60 | )
61 | CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")
62 |
63 | inference_handler.initialize(CONTEXT)
64 | assert inference_handler.initialized is True
65 |
66 |
67 | @require_torch
68 | def test_handle(inference_handler):
69 | with tempfile.TemporaryDirectory() as tmpdirname:
70 | storage_folder = _load_model_from_hub(
71 | model_id=MODEL,
72 | model_dir=tmpdirname,
73 | )
74 | CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")
75 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
76 | CONTEXT.metrics = MetricsStore(1, MODEL)
77 |
78 | inference_handler.initialize(CONTEXT)
79 | json_data = json.dumps(INPUT)
80 | prediction = inference_handler.handle([{"body": json_data.encode()}], CONTEXT)
81 | loaded_response = json.loads(prediction[0])
82 | assert "entity" in loaded_response[0]
83 | assert "score" in loaded_response[0]
84 |
85 |
86 | @require_torch
87 | def test_load(inference_handler):
88 | context = Mock()
89 | with tempfile.TemporaryDirectory() as tmpdirname:
90 | storage_folder = _load_model_from_hub(
91 | model_id=MODEL,
92 | model_dir=tmpdirname,
93 | )
94 |         # test with automatically inferred task
95 | hf_pipeline_without_task = inference_handler.load(storage_folder, context)
96 | assert hf_pipeline_without_task.task == "token-classification"
97 |
98 |         # test with the task set explicitly via HF_TASK
99 | os.environ["HF_TASK"] = TASK
100 | hf_pipeline_with_task = inference_handler.load(storage_folder, context)
101 | assert hf_pipeline_with_task.task == TASK
102 |
103 |
104 | def test_preprocess(inference_handler):
105 | context = Mock()
106 | json_data = json.dumps(INPUT)
107 | decoded_input_data = inference_handler.preprocess(json_data, content_types.JSON, context)
108 | assert "inputs" in decoded_input_data
109 |
110 |
111 | def test_preprocess_bad_content_type(inference_handler):
112 | context = Mock()
113 | with pytest.raises(json.decoder.JSONDecodeError):
114 | inference_handler.preprocess(b"", content_types.JSON, context)
115 |
116 |
117 | @require_torch
118 | def test_predict(inference_handler):
119 | context = Mock()
120 | with tempfile.TemporaryDirectory() as tmpdirname:
121 | storage_folder = _load_model_from_hub(
122 | model_id=MODEL,
123 | model_dir=tmpdirname,
124 | )
125 | inference_handler.model = get_pipeline(task=TASK, device=-1, model_dir=storage_folder)
126 | prediction = inference_handler.predict(INPUT, inference_handler.model, context)
127 | assert "label" in prediction[0]
128 | assert "score" in prediction[0]
129 |
130 |
131 | def test_postprocess(inference_handler):
132 | context = Mock()
133 | output = inference_handler.postprocess(OUTPUT, content_types.JSON, context)
134 | assert isinstance(output, str)
135 |
136 |
137 | def test_validate_and_initialize_user_module(inference_handler):
138 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_with_context")
139 | CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")
140 |
141 | inference_handler.initialize(CONTEXT)
142 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
143 | CONTEXT.metrics = MetricsStore(1, MODEL)
144 |
145 | prediction = inference_handler.handle([{"body": b""}], CONTEXT)
146 | assert "output" in prediction[0]
147 |
148 | assert inference_handler.load({}, CONTEXT) == "model"
149 | assert inference_handler.preprocess({}, "", CONTEXT) == "data"
150 | assert inference_handler.predict({}, "model", CONTEXT) == "output"
151 | assert inference_handler.postprocess("output", "", CONTEXT) == "output"
152 |
153 |
154 | def test_validate_and_initialize_user_module_transform_fn():
155 | os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
156 | inference_handler = handler_service.HuggingFaceHandlerService()
157 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context")
158 | CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")
159 |
160 | inference_handler.initialize(CONTEXT)
161 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
162 | CONTEXT.metrics = MetricsStore(1, MODEL)
163 | assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
164 | assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py"
165 | assert (
166 | inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
167 | == "output dummy"
168 | )
169 |
170 |
171 | def test_validate_and_initialize_user_module_transform_fn_race_condition():
172 | os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
173 | inference_handler = handler_service.HuggingFaceHandlerService()
174 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context")
175 | CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")
176 |
177 |     # Simulate 2 threads bypassing the check in handle() - calling initialize twice
178 | inference_handler.initialize(CONTEXT)
179 | inference_handler.initialize(CONTEXT)
180 |
181 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
182 | CONTEXT.metrics = MetricsStore(1, MODEL)
183 | assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
184 | assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py"
185 | assert (
186 | inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
187 | == "output dummy"
188 | )
189 |
--------------------------------------------------------------------------------
/tests/unit/test_handler_service_without_context.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import json
15 | import os
16 | import tempfile
17 |
18 | import pytest
19 | from sagemaker_inference import content_types
20 | from transformers.testing_utils import require_torch, slow
21 |
22 | from mms.context import Context, RequestProcessor
23 | from mms.metrics.metrics_store import MetricsStore
24 | from sagemaker_huggingface_inference_toolkit import handler_service
25 | from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline
26 |
27 |
28 | TASK = "text-classification"
29 | MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
30 | INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
31 | OUTPUT = [
32 | {"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19},
33 | {"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40},
34 | ]
35 |
36 |
37 | @pytest.fixture()
38 | def inference_handler():
39 | return handler_service.HuggingFaceHandlerService()
40 |
41 |
42 | def test_get_device_cpu(inference_handler):
43 | device = inference_handler.get_device()
44 | assert device == -1
45 |
46 |
47 | @slow
48 | def test_get_device_gpu(inference_handler):
49 | device = inference_handler.get_device()
50 | assert device > -1
51 |
52 |
53 | @require_torch
54 | def test_initialize(inference_handler):
55 | with tempfile.TemporaryDirectory() as tmpdirname:
56 | storage_folder = _load_model_from_hub(
57 | model_id=MODEL,
58 | model_dir=tmpdirname,
59 | )
60 | CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")
61 |
62 | inference_handler.initialize(CONTEXT)
63 | assert inference_handler.initialized is True
64 |
65 |
66 | @require_torch
67 | def test_handle(inference_handler):
68 | with tempfile.TemporaryDirectory() as tmpdirname:
69 | storage_folder = _load_model_from_hub(
70 | model_id=MODEL,
71 | model_dir=tmpdirname,
72 | )
73 | CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")
74 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
75 | CONTEXT.metrics = MetricsStore(1, MODEL)
76 |
77 | inference_handler.initialize(CONTEXT)
78 | json_data = json.dumps(INPUT)
79 | prediction = inference_handler.handle([{"body": json_data.encode()}], CONTEXT)
80 | assert "output" in prediction[0]
81 |
82 |
83 | @require_torch
84 | def test_load(inference_handler):
85 | with tempfile.TemporaryDirectory() as tmpdirname:
86 | storage_folder = _load_model_from_hub(
87 | model_id=MODEL,
88 | model_dir=tmpdirname,
89 | )
90 |         # test with automatically inferred task
91 | if "HF_TASK" in os.environ:
92 | del os.environ["HF_TASK"]
93 | hf_pipeline_without_task = inference_handler.load(storage_folder)
94 | assert hf_pipeline_without_task.task == "token-classification"
95 |
96 |         # test with the task set explicitly via HF_TASK
97 | os.environ["HF_TASK"] = "text-classification"
98 | hf_pipeline_with_task = inference_handler.load(storage_folder)
99 | assert hf_pipeline_with_task.task == "text-classification"
100 |
101 |
102 | def test_preprocess(inference_handler):
103 | json_data = json.dumps(INPUT)
104 | decoded_input_data = inference_handler.preprocess(json_data, content_types.JSON)
105 | assert "inputs" in decoded_input_data
106 |
107 |
108 | def test_preprocess_bad_content_type(inference_handler):
109 | with pytest.raises(json.decoder.JSONDecodeError):
110 | inference_handler.preprocess(b"", content_types.JSON)
111 |
112 |
113 | @require_torch
114 | def test_predict(inference_handler):
115 | with tempfile.TemporaryDirectory() as tmpdirname:
116 | storage_folder = _load_model_from_hub(
117 | model_id=MODEL,
118 | model_dir=tmpdirname,
119 | )
120 | inference_handler.model = get_pipeline(task=TASK, device=-1, model_dir=storage_folder)
121 | prediction = inference_handler.predict(INPUT, inference_handler.model)
122 | assert "label" in prediction[0]
123 | assert "score" in prediction[0]
124 |
125 |
126 | def test_postprocess(inference_handler):
127 | output = inference_handler.postprocess(OUTPUT, content_types.JSON)
128 | assert isinstance(output, str)
129 |
130 |
131 | def test_validate_and_initialize_user_module(inference_handler):
132 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_without_context")
133 | CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")
134 |
135 | inference_handler.initialize(CONTEXT)
136 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
137 | CONTEXT.metrics = MetricsStore(1, MODEL)
138 |
139 | prediction = inference_handler.handle([{"body": b""}], CONTEXT)
140 | assert "output" in prediction[0]
141 |
142 | assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
143 |
144 |
145 | def test_validate_and_initialize_user_module_transform_fn():
146 | os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
147 | inference_handler = handler_service.HuggingFaceHandlerService()
148 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context")
149 | CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")
150 |
151 | inference_handler.initialize(CONTEXT)
152 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
153 | CONTEXT.metrics = MetricsStore(1, MODEL)
154 | assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
155 | assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
156 | assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy"
157 |
158 |
159 | def test_validate_and_initialize_user_module_transform_fn_race_condition():
160 | os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
161 | inference_handler = handler_service.HuggingFaceHandlerService()
162 | model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context")
163 | CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")
164 |
165 |     # Simulate 2 threads bypassing the check in handle() - calling initialize twice
166 | inference_handler.initialize(CONTEXT)
167 | inference_handler.initialize(CONTEXT)
168 |
169 | CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
170 | CONTEXT.metrics = MetricsStore(1, MODEL)
171 | assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
172 | assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
173 | assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy"
174 |
--------------------------------------------------------------------------------
/tests/unit/test_mms_model_server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 |
16 | from sagemaker_inference.environment import model_dir
17 |
18 | from mock import patch
19 | from sagemaker_huggingface_inference_toolkit import mms_model_server, transformers_utils
20 |
21 |
22 | PYTHON_PATH = "python_path"
23 | DEFAULT_CONFIGURATION = "default_configuration"
24 |
25 |
26 | @patch("subprocess.call")
27 | @patch("subprocess.Popen")
28 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process")
29 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler")
30 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements")
31 | @patch("os.path.exists", return_value=True)
32 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file")
33 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format")
34 | @patch("sagemaker_inference.environment.Environment")
35 | def test_start_mms_default_service_handler(
36 | env,
37 | adapt,
38 | create_config,
39 | exists,
40 | install_requirements,
41 | sigterm,
42 | retrieve,
43 | subprocess_popen,
44 | subprocess_call,
45 | ):
46 | env.return_value.startup_timeout = 10000
47 | mms_model_server.start_model_server()
48 |
49 | # In this case, we should not rearchive the model
50 | adapt.assert_not_called()
51 |
52 | create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE)
53 | exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH)
54 | install_requirements.assert_called_once_with()
55 |
56 | multi_model_server_cmd = [
57 | "multi-model-server",
58 | "--start",
59 | "--model-store",
60 | mms_model_server.DEFAULT_MODEL_STORE,
61 | "--mms-config",
62 | mms_model_server.MMS_CONFIG_FILE,
63 | "--log-config",
64 | mms_model_server.DEFAULT_MMS_LOG_FILE,
65 | "--models",
66 | "{}={}".format(mms_model_server.DEFAULT_MMS_MODEL_NAME, model_dir),
67 | ]
68 |
69 | subprocess_popen.assert_called_once_with(multi_model_server_cmd)
70 | sigterm.assert_called_once_with(retrieve.return_value)
71 |
72 |
73 | @patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True)
74 | @patch("subprocess.call")
75 | @patch("subprocess.Popen")
76 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process")
77 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub")
78 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler")
79 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements")
80 | @patch("os.makedirs", return_value=True)
81 | @patch("os.remove", return_value=True)
82 | @patch("os.path.exists", return_value=True)
83 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file")
84 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format")
85 | @patch("sagemaker_inference.environment.Environment")
86 | def test_start_mms_neuron(
87 | env,
88 | adapt,
89 | create_config,
90 | exists,
91 | remove,
92 | dir,
93 | install_requirements,
94 | sigterm,
95 | load_model_from_hub,
96 | retrieve,
97 | subprocess_popen,
98 | subprocess_call,
99 | is_aws_neuron_available,
100 | ):
101 | env.return_value.startup_timeout = 10000
102 | mms_model_server.start_model_server()
103 |
104 | # In this case, we should not call model archiver
105 | adapt.assert_not_called()
106 |
107 | create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE)
108 | exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH)
109 | install_requirements.assert_called_once_with()
110 |
111 | multi_model_server_cmd = [
112 | "multi-model-server",
113 | "--start",
114 | "--model-store",
115 | mms_model_server.DEFAULT_MODEL_STORE,
116 | "--mms-config",
117 | mms_model_server.MMS_CONFIG_FILE,
118 | "--log-config",
119 | mms_model_server.DEFAULT_MMS_LOG_FILE,
120 | "--models",
121 | "{}={}".format(mms_model_server.DEFAULT_MMS_MODEL_NAME, model_dir),
122 | ]
123 |
124 | subprocess_popen.assert_called_once_with(multi_model_server_cmd)
125 | sigterm.assert_called_once_with(retrieve.return_value)
126 |
127 |
128 | @patch("subprocess.call")
129 | @patch("subprocess.Popen")
130 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process")
131 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub")
132 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler")
133 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements")
134 | @patch("os.makedirs", return_value=True)
135 | @patch("os.remove", return_value=True)
136 | @patch("os.path.exists", return_value=True)
137 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file")
138 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format")
139 | @patch("sagemaker_inference.environment.Environment")
140 | def test_start_mms_with_model_from_hub(
141 | env,
142 | adapt,
143 | create_config,
144 | exists,
145 | remove,
146 | dir,
147 | install_requirements,
148 | sigterm,
149 | load_model_from_hub,
150 | retrieve,
151 | subprocess_popen,
152 | subprocess_call,
153 | ):
154 | env.return_value.startup_timeout = 10000
155 |
156 | os.environ["HF_MODEL_ID"] = "lysandre/tiny-bert-random"
157 |
158 | mms_model_server.start_model_server()
159 |
160 | load_model_from_hub.assert_called_once_with(
161 | model_id=os.environ["HF_MODEL_ID"],
162 | model_dir=mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
163 | revision=transformers_utils.HF_MODEL_REVISION,
164 | use_auth_token=transformers_utils.HF_API_TOKEN,
165 | )
166 |
167 |     # When loading a model from the Hub, we do call the model archiver
168 | adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, load_model_from_hub())
169 |
170 | create_config.assert_called_once_with(env.return_value, mms_model_server.DEFAULT_HANDLER_SERVICE)
171 | exists.assert_called_with(mms_model_server.REQUIREMENTS_PATH)
172 | install_requirements.assert_called_once_with()
173 |
174 | multi_model_server_cmd = [
175 | "multi-model-server",
176 | "--start",
177 | "--model-store",
178 | mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
179 | "--mms-config",
180 | mms_model_server.MMS_CONFIG_FILE,
181 | "--log-config",
182 | mms_model_server.DEFAULT_MMS_LOG_FILE,
183 | ]
184 |
185 | subprocess_popen.assert_called_once_with(multi_model_server_cmd)
186 | sigterm.assert_called_once_with(retrieve.return_value)
187 | os.remove(mms_model_server.DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY)
188 |
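
The last test above exercises the Hub-download path. Inside a serving container the same behaviour is driven purely by environment variables; a minimal sketch (variable names taken from the tests, values illustrative):

    import os

    from sagemaker_huggingface_inference_toolkit import mms_model_server

    os.environ["HF_MODEL_ID"] = "lysandre/tiny-bert-random"  # model to pull from the Hub
    os.environ["HF_TASK"] = "text-classification"            # optional task override

    # downloads the model, archives it for MMS and launches multi-model-server
    mms_model_server.start_model_server()
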
--------------------------------------------------------------------------------
/tests/unit/test_optimum_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import tempfile
16 |
17 | import pytest
18 | from transformers.testing_utils import require_torch
19 |
20 | from sagemaker_huggingface_inference_toolkit.optimum_utils import (
21 | get_input_shapes,
22 | get_optimum_neuron_pipeline,
23 | is_optimum_neuron_available,
24 | )
25 | from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub
26 |
27 |
28 | require_inferentia = pytest.mark.skipif(
29 | not is_optimum_neuron_available(),
30 | reason="Skipping tests, since optimum neuron is not available or not running on inf2 instances.",
31 | )
32 |
33 |
34 | REMOTE_NOT_CONVERTED_MODEL = "hf-internal-testing/tiny-random-BertModel"
35 | REMOTE_CONVERTED_MODEL = "optimum/tiny_random_bert_neuron"
36 | TASK = "text-classification"
37 |
38 |
39 | @require_torch
40 | @require_inferentia
41 | def test_not_supported_task():
42 | os.environ["HF_TASK"] = "not-supported-task"
43 | with pytest.raises(Exception):
44 | get_optimum_neuron_pipeline(task=TASK, model_dir=os.getcwd())
45 |
46 |
47 | @require_torch
48 | @require_inferentia
49 | def test_get_input_shapes_from_file():
50 | with tempfile.TemporaryDirectory() as tmpdirname:
51 | storage_folder = _load_model_from_hub(
52 | model_id=REMOTE_CONVERTED_MODEL,
53 | model_dir=tmpdirname,
54 | )
55 | input_shapes = get_input_shapes(model_dir=storage_folder)
56 | assert input_shapes["batch_size"] == 1
57 | assert input_shapes["sequence_length"] == 32
58 |
59 |
60 | @require_torch
61 | @require_inferentia
62 | def test_get_input_shapes_from_env():
63 | os.environ["HF_OPTIMUM_BATCH_SIZE"] = "4"
64 | os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32"
65 | with tempfile.TemporaryDirectory() as tmpdirname:
66 | storage_folder = _load_model_from_hub(
67 | model_id=REMOTE_NOT_CONVERTED_MODEL,
68 | model_dir=tmpdirname,
69 | )
70 | input_shapes = get_input_shapes(model_dir=storage_folder)
71 | assert input_shapes["batch_size"] == 4
72 | assert input_shapes["sequence_length"] == 32
73 |
74 |
75 | @require_torch
76 | @require_inferentia
77 | def test_get_optimum_neuron_pipeline_from_converted_model():
78 | with tempfile.TemporaryDirectory() as tmpdirname:
79 | os.system(
80 | f"optimum-cli export neuron --model philschmid/tiny-distilbert-classification --sequence_length 32 --batch_size 1 {tmpdirname}"
81 | )
82 | pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=tmpdirname)
83 | r = pipe("This is a test")
84 |
85 | assert r[0]["score"] > 0.0
86 | assert isinstance(r[0]["label"], str)
87 |
88 |
89 | @require_torch
90 | @require_inferentia
91 | def test_get_optimum_neuron_pipeline_from_non_converted_model():
92 | os.environ["OPTIMUM_NEURON_SEQUENCE_LENGTH"] = "32"
93 | with tempfile.TemporaryDirectory() as tmpdirname:
94 | storage_folder = _load_model_from_hub(
95 | model_id=REMOTE_NOT_CONVERTED_MODEL,
96 | model_dir=tmpdirname,
97 | )
98 | pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=storage_folder)
99 | r = pipe("This is a test")
100 |
101 | assert r[0]["score"] > 0.0
102 | assert isinstance(r[0]["label"], str)
103 |
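
A minimal sketch of the non-precompiled path exercised above, with the compilation shapes supplied via environment variables as in test_get_input_shapes_from_env (requires optimum-neuron on an Inferentia host; the model_dir value is illustrative):

    import os

    from sagemaker_huggingface_inference_toolkit.optimum_utils import get_optimum_neuron_pipeline

    os.environ["HF_OPTIMUM_BATCH_SIZE"] = "1"         # static shapes used for neuron compilation
    os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "128"

    pipe = get_optimum_neuron_pipeline(task="text-classification", model_dir="/opt/ml/model")
    print(pipe("This is a test"))
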
--------------------------------------------------------------------------------
/tests/unit/test_serving.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from mock import patch
16 |
17 |
18 | @patch("sagemaker_huggingface_inference_toolkit.mms_model_server.start_model_server")
19 | def test_hosting_start(start_model_server):
20 | from sagemaker_huggingface_inference_toolkit import serving
21 |
22 | serving.main()
23 | start_model_server.assert_called_with(handler_service="sagemaker_huggingface_inference_toolkit.handler_service")
24 |
25 |
26 | def test_retry_if_error():
27 | from sagemaker_huggingface_inference_toolkit import serving
28 |
29 | serving._retry_if_error(Exception)
30 |
--------------------------------------------------------------------------------
/tests/unit/test_transformers_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import tempfile
16 |
17 | from transformers.file_utils import is_torch_available
18 | from transformers.testing_utils import require_tf, require_torch, slow
19 |
20 | from sagemaker_huggingface_inference_toolkit.transformers_utils import (
21 | _build_storage_path,
22 | _get_framework,
23 | _is_gpu_available,
24 | _load_model_from_hub,
25 | get_pipeline,
26 | infer_task_from_hub,
27 | infer_task_from_model_architecture,
28 | )
29 |
30 |
31 | MODEL = "lysandre/tiny-bert-random"
32 | TASK = "text-classification"
33 | TASK_MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
34 |
35 | MAIN_REVISION = "main"
36 | REVISION = "eb4c77816edd604d0318f8e748a1c606a2888493"
37 |
38 |
39 | @require_torch
40 | def test_loading_model_from_hub():
41 | with tempfile.TemporaryDirectory() as tmpdirname:
42 | storage_folder = _load_model_from_hub(
43 | model_id=MODEL,
44 | model_dir=tmpdirname,
45 | )
46 |
47 | # folder contains all config files and pytorch_model.bin
48 | folder_contents = os.listdir(storage_folder)
49 | assert "config.json" in folder_contents
50 |
51 |
52 | @require_torch
53 | def test_loading_model_from_hub_with_revision():
54 | with tempfile.TemporaryDirectory() as tmpdirname:
55 | storage_folder = _load_model_from_hub(model_id=MODEL, model_dir=tmpdirname, revision=REVISION)
56 |
57 | # folder contains all config files and pytorch_model.bin
58 | assert REVISION in storage_folder
59 | folder_contents = os.listdir(storage_folder)
60 | assert "config.json" in folder_contents
61 | assert "tokenizer_config.json" not in folder_contents
62 |
63 |
64 | @require_torch
65 | def test_loading_model_safetensor_from_hub():
66 | with tempfile.TemporaryDirectory() as tmpdirname:
67 | storage_folder = _load_model_from_hub(
68 | model_id="hf-internal-testing/tiny-random-bert-safetensors", model_dir=tmpdirname
69 | )
70 |
71 | folder_contents = os.listdir(storage_folder)
72 | assert "model.safetensors" in folder_contents
73 |
74 |
75 | def test_gpu_is_not_available():
76 | device = _is_gpu_available()
77 | assert device is False
78 |
79 |
80 | def test_build_storage_path():
81 | storage_path = _build_storage_path(model_id=MODEL, model_dir="x")
82 | assert "__" in storage_path
83 |
84 | storage_path = _build_storage_path(model_id="bert-base-uncased", model_dir="x")
85 | assert "__" not in storage_path
86 |
87 | storage_path = _build_storage_path(model_id=MODEL, model_dir="x", revision=REVISION)
88 | assert "__" in storage_path and "." in storage_path
89 |
90 |
91 | @slow
92 | def test_gpu_available():
93 | device = _is_gpu_available()
94 | assert device is True
95 |
96 |
97 | @require_torch
98 | def test_get_framework_pytorch():
99 | framework = _get_framework()
100 | assert framework == "pytorch"
101 |
102 |
103 | @require_tf
104 | def test_get_framework_tensorflow():
105 | framework = _get_framework()
106 | if is_torch_available():
107 | assert framework == "pytorch"
108 | else:
109 | assert framework == "tensorflow"
110 |
111 |
112 | def test_get_pipeline():
113 | with tempfile.TemporaryDirectory() as tmpdirname:
114 | storage_dir = _load_model_from_hub(MODEL, tmpdirname)
115 | pipe = get_pipeline(TASK, -1, storage_dir)
116 | res = pipe("Life is good, Life is bad")
117 | assert "score" in res[0]
118 |
119 |
120 | def test_infer_task_from_hub():
121 | task = infer_task_from_hub(TASK_MODEL)
122 | assert task == "token-classification"
123 |
124 |
125 | def test_infer_task_from_model_architecture():
126 | with tempfile.TemporaryDirectory() as tmpdirname:
127 | storage_dir = _load_model_from_hub(TASK_MODEL, tmpdirname)
128 | task = infer_task_from_model_architecture(f"{storage_dir}/config.json")
129 | assert task == "token-classification"
130 |
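
Taken together, the helpers exercised above give the shortest path from a Hub model id to a ready pipeline; a minimal sketch reusing the model id from the tests:

    import tempfile

    from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline

    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_dir = _load_model_from_hub("lysandre/tiny-bert-random", tmpdirname)
        pipe = get_pipeline("text-classification", -1, storage_dir)  # -1 selects CPU
        print(pipe("Life is good, Life is bad"))
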
--------------------------------------------------------------------------------