├── .github └── workflows │ └── main ├── .gitignore ├── .terraform.lock.hcl ├── BUILD.md ├── LICENSE ├── README.md ├── examples ├── async_inference │ ├── .terraform.lock.hcl │ ├── README.md │ └── main.tf ├── autoscaling_example │ ├── .terraform.lock.hcl │ ├── README.md │ └── main.tf ├── deploy_from_hub │ ├── .terraform.lock.hcl │ ├── README.md │ └── main.tf ├── deploy_from_s3 │ ├── .terraform.lock.hcl │ ├── README.md │ └── main.tf ├── deploy_private_model │ ├── .terraform.lock.hcl │ ├── README.md │ └── main.tf ├── serverless_inference │ ├── .terraform.lock.hcl │ ├── README.md │ └── main.tf ├── tensorflow_example │ ├── README.md │ └── main.tf └── use_existing_iam_role │ ├── README.md │ └── main.tf ├── main.tf ├── outputs.tf ├── variables.tf └── versions.tf /.github/workflows/main: -------------------------------------------------------------------------------- 1 | name: workflow 2 | 3 | on: 4 | pull_request: {} 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v2 15 | 16 | - name: Install Go 17 | uses: actions/setup-go@v2 18 | with: { go-version: 1.16.5 } 19 | 20 | - name: Install Terraform 21 | uses: hashicorp/setup-terraform@v1 22 | with: { terraform_version: 1.0.0 } 23 | 24 | - name: Install Taskfile 25 | run: curl -sL https://taskfile.dev/install.sh | sh 26 | 27 | - name: Run tests 28 | run: ./bin/task test 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Local .terraform directories 2 | **/.terraform/* 3 | 4 | # .tfstate files 5 | *.tfstate 6 | *.tfstate.* 7 | 8 | # Crash log files 9 | crash.log 10 | 11 | # Ignore any .tfvars files that are generated automatically for each Terraform run. Most 12 | # .tfvars files are managed as part of configuration and so should be included in 13 | # version control. 
14 | # 15 | # example.tfvars 16 | 17 | # Ignore override files as they are usually used to override resources locally and so 18 | # are not checked in 19 | override.tf 20 | override.tf.json 21 | *_override.tf 22 | *_override.tf.json 23 | 24 | # Include override files you do wish to add to version control using negated pattern 25 | # 26 | # !example_override.tf 27 | 28 | # Include tfplan files to ignore the plan output of command: terraform plan -out=tfplan 29 | # example: *tfplan* 30 | -------------------------------------------------------------------------------- /.terraform.lock.hcl: -------------------------------------------------------------------------------- 1 | # This file is maintained automatically by "terraform init". 2 | # Manual edits may be lost in future updates. 3 | 4 | provider "registry.terraform.io/hashicorp/aws" { 5 | version = "3.74.0" 6 | hashes = [ 7 | "h1:wIP53ozevE0ihhP1Fuoir4N1qI7+TcBs0y4SHlwMxho=", 8 | "zh:00767509c13c0d1c7ad6af702c6942e6572aa6d529b40a00baacc0e73faafea2", 9 | "zh:03aafdc903ad49c2eda03889f927f44212674c50e475a9c6298850381319eec2", 10 | "zh:2de8a6a97b180f909d652f215125aa4683e99db15fcf3b28d62e3d542f875ed6", 11 | "zh:3ac29ebc3af99028f4230a79f56606a0c2954b68767bd749b921a76eb4f3bd30", 12 | "zh:50add2e2d118a15a644360eabc5a34cec59f2560b491f8fabf9c52ab83ca7b09", 13 | "zh:85dd8e81910ab79f841a4a595fdd8ac358fbfe460956144afb0be3d81f91fe10", 14 | "zh:895de83d0f0941fde31bfc53fa6b1ea276901f006bec221bbdee4771a04f3693", 15 | "zh:a15c9724aac52d1ba5001d2d83e42843099b52b1638ea29d84e20be0f45fa4f1", 16 | "zh:c982a64463bd73e9bff2589de214b1de0a571438d9015001f9eae45cfc3a2559", 17 | "zh:e9ef973c18078324e43213ea1252c12b9441e566bf054ddfdbff5dd62f3035d9", 18 | "zh:f297e705b0f339c8baa27ae70db5df9aa6578adfe1ea3d2ba8edc186512464eb", 19 | ] 20 | } 21 | -------------------------------------------------------------------------------- /BUILD.md: -------------------------------------------------------------------------------- 1 | # How to publish a new 
release 2 | 3 | 1. generate new docs 4 | 2. change example + version in `README.md` 5 | 3. create new Github Release 6 | 4. Update package in registry 7 | 8 | # Generate Docs 9 | 10 | ```Bash 11 | docker run --rm --volume "$(pwd):/terraform-docs" -u $(id -u) quay.io/terraform-docs/terraform-docs:0.16.0 markdown /terraform-docs > docs.md 12 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Appsilon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hugging Face Inference SageMaker Module 2 | 3 | Terraform module for easy deployment of a [Hugging Face Transformer models](hf.co/models) to [Amazon SageMaker](https://aws.amazon.com/de/sagemaker/) real-time endpoints. This module will create all the necessary resources to deploy a model to Amazon SageMaker including IAM roles, if not provided, SageMaker Model, SageMaker Endpoint Configuration, SageMaker endpoint. 4 | 5 | With this module you can deploy [Hugging Face Transformer](hf.co/models) directly from the [Model Hub](hf.co/models) or from Amazon S3 to Amazon SageMaker for PyTorch and Tensorflow based models. 6 | 7 | ## Usage 8 | 9 | **basic example** 10 | 11 | ```hcl 12 | module "sagemaker-huggingface" { 13 | source = "philschmid/sagemaker-huggingface/aws" 14 | version = "0.5.0" 15 | name_prefix = "distilbert" 16 | pytorch_version = "1.9.1" 17 | transformers_version = "4.12.3" 18 | instance_type = "ml.g4dn.xlarge" 19 | instance_count = 1 # default is 1 20 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 21 | hf_task = "text-classification" 22 | } 23 | ``` 24 | 25 | **advanced example with autoscaling** 26 | 27 | ```hcl 28 | module "sagemaker-huggingface" { 29 | source = "philschmid/sagemaker-huggingface/aws" 30 | version = "0.5.0" 31 | name_prefix = "distilbert" 32 | pytorch_version = "1.9.1" 33 | transformers_version = "4.12.3" 34 | instance_type = "ml.g4dn.xlarge" 35 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 36 | hf_task = "text-classification" 37 | autoscaling = { 38 | max_capacity = 4 # The max capacity of the scalable target 39 | scaling_target_invocations = 200 # The scaling target invocations (requests/minute) 40 | } 41 | } 42 | ``` 43 | 44 | **examples:** 45 | * [Deploy Model from 
hf.co/models](https://github.com/philschmid/terraform-aws-sagemaker-huggingface/tree/master/examples/deploy_from_hub) 46 | * [Deploy Model from Amazon S3](https://github.com/philschmid/terraform-aws-sagemaker-huggingface/tree/master/examples/deploy_from_s3) 47 | * [Deploy Private Models from hf.co/models](https://github.com/philschmid/terraform-aws-sagemaker-huggingface/tree/master/examples/deploy_private_model) 48 | * [Autoscaling Endpoint](https://github.com/philschmid/terraform-aws-sagemaker-huggingface/tree/master/examples/autoscaling_example) 49 | * [Asynchronous Inference](https://github.com/philschmid/terraform-aws-sagemaker-huggingface/tree/master/examples/async_inference) 50 | * [Serverless Inference](https://github.com/philschmid/terraform-aws-sagemaker-huggingface/tree/master/examples/serverless_inference) 51 | * [Tensorflow example](https://github.com/philschmid/terraform-aws-sagemaker-huggingface/tree/master/examples/tensorflow_example) 52 | * [Deploy Model with existing IAM role](https://github.com/philschmid/terraform-aws-sagemaker-huggingface/tree/master/examples/use_existing_iam_role) 53 | 54 | 55 | ## Requirements 56 | 57 | | Name | Version | 58 | |------|---------| 59 | | [terraform](#requirement\_terraform) | >= 1.0.0 | 60 | | [aws](#requirement\_aws) | ~> 4.0 | 61 | 62 | ## Providers 63 | 64 | | Name | Version | 65 | |------|---------| 66 | | [aws](#provider\_aws) | 3.74.0 | 67 | | [random](#provider\_random) | n/a | 68 | 69 | ## Modules 70 | 71 | No modules. 
72 | 73 | ## Resources 74 | 75 | | Name | Type | 76 | |------|------| 77 | | [aws_appautoscaling_policy.sagemaker_policy](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/appautoscaling_policy) | resource | 78 | | [aws_appautoscaling_target.sagemaker_target](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/appautoscaling_target) | resource | 79 | | [aws_iam_role.new_role](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/iam_role) | resource | 80 | | [aws_sagemaker_endpoint.huggingface](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/sagemaker_endpoint) | resource | 81 | | [aws_sagemaker_endpoint_configuration.huggingface](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/sagemaker_endpoint_configuration) | resource | 82 | | [aws_sagemaker_endpoint_configuration.huggingface_async](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/sagemaker_endpoint_configuration) | resource | 83 | | [aws_sagemaker_endpoint_configuration.huggingface_serverless](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/sagemaker_endpoint_configuration) | resource | 84 | | [aws_sagemaker_model.model_with_hub_model](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/sagemaker_model) | resource | 85 | | [aws_sagemaker_model.model_with_model_artifact](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/sagemaker_model) | resource | 86 | | [random_string.resource_id](https://registry.terraform.io/providers/hashicorp/random/latest/docs/resources/string) | resource | 87 | | [aws_iam_role.get_role](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/data-sources/iam_role) | data source | 88 | | [aws_sagemaker_prebuilt_ecr_image.deploy_image](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/data-sources/sagemaker_prebuilt_ecr_image) | data 
source | 89 | 90 | ## Inputs 91 | 92 | | Name | Description | Type | Default | Required | 93 | |------|-------------|------|---------|:--------:| 94 | | [async\_config](#input\_async\_config) | (Optional) Specifies configuration for how an endpoint performs asynchronous inference. Required key is `s3_output_path`, which is the s3 bucket used for async inference. |
object({
s3_output_path = string,
s3_failure_path = optional(string),
kms_key_id = optional(string),
sns_error_topic = optional(string),
sns_success_topic = optional(string),
})
|
{
"kms_key_id": null,
"s3_output_path": null,
"s3_failure_path": null,
"sns_error_topic": null,
"sns_success_topic": null
}
| no | 95 | | [autoscaling](#input\_autoscaling) | A Object which defines the autoscaling target and policy for our SageMaker Endpoint. Required keys are `max_capacity` and `scaling_target_invocations` |
object({
min_capacity = optional(number),
max_capacity = number,
scaling_target_invocations = optional(number),
scale_in_cooldown = optional(number),
scale_out_cooldown = optional(number),
})
|
{
"max_capacity": null,
"min_capacity": 1,
"scale_in_cooldown": 300,
"scale_out_cooldown": 66,
"scaling_target_invocations": null
}
| no | 96 | | [hf\_api\_token](#input\_hf\_api\_token) | The HF\_API\_TOKEN environment variable defines the your Hugging Face authorization token. The HF\_API\_TOKEN is used as a HTTP bearer authorization for remote files, like private models. You can find your token at your settings page. | `string` | `null` | no | 97 | | [hf\_model\_id](#input\_hf\_model\_id) | The HF\_MODEL\_ID environment variable defines the model id, which will be automatically loaded from [hf.co/models](https://huggingface.co/models) when creating or SageMaker Endpoint. | `string` | `null` | no | 98 | | [hf\_model\_revision](#input\_hf\_model\_revision) | The HF\_MODEL\_REVISION is an extension to HF\_MODEL\_ID and allows you to define/pin a revision of the model to make sure you always load the same model on your SageMaker Endpoint. | `string` | `null` | no | 99 | | [hf\_task](#input\_hf\_task) | The HF\_TASK environment variable defines the task for the used 🤗 Transformers pipeline. A full list of tasks can be find [here](https://huggingface.co/transformers/main_classes/pipelines.html). | `string` | n/a | yes | 100 | | [image\_tag](#input\_image\_tag) | The image tag you want to use for the container you want to use. Defaults to `None`. The module tries to derive the `image_tag` from the `pytorch_version`, `tensorflow_version` & `instance_type`. If you want to override this, you can provide the `image_tag` as a variable. | `string` | `null` | no | 101 | | [instance\_count](#input\_instance\_count) | The initial number of instances to run in the Endpoint created from this Model. Defaults to 1. | `number` | `1` | no | 102 | | [instance\_type](#input\_instance\_type) | The EC2 instance type to deploy this Model to. For example, `ml.p2.xlarge`. | `string` | `null` | no | 103 | | [model\_data](#input\_model\_data) | The S3 location of a SageMaker model data .tar.gz file (default: None). Not needed when using `hf_model_id`. 
| `string` | `null` | no | 104 | [name\_prefix](#input\_name\_prefix) | A prefix used for naming resources. | `string` | n/a | yes | 105 | [pytorch\_version](#input\_pytorch\_version) | PyTorch version you want to use for executing your inference code. Defaults to `None`. Required unless `tensorflow_version` is provided. [List of supported versions](https://huggingface.co/docs/sagemaker/reference#inference-dlc-overview) | `string` | `null` | no | 106 | [sagemaker\_execution\_role](#input\_sagemaker\_execution\_role) | An AWS IAM role Name to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role if it needs to access some AWS resources. If not specified, the role will be created with the `CreateModel` permissions from the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-createmodel-perms) | `string` | `null` | no | 107 | [serverless\_config](#input\_serverless\_config) | (Optional) Specifies configuration for how an endpoint performs serverless inference. Required keys are `max_concurrency` and `memory_size_in_mb` |
object({
max_concurrency = number,
memory_size_in_mb = number
})
|
{
"max_concurrency": null,
"memory_size_in_mb": null
}
| no | 108 | | [tags](#input\_tags) | A map of tags (key-value pairs) passed to resources. | `map(string)` | `{}` | no | 109 | | [tensorflow\_version](#input\_tensorflow\_version) | TensorFlow version you want to use for executing your inference code. Defaults to `None`. Required unless `pytorch_version` is provided. [List of supported versions](https://huggingface.co/docs/sagemaker/reference#inference-dlc-overview) | `string` | `null` | no | 110 | | [transformers\_version](#input\_transformers\_version) | Transformers version you want to use for executing your model training code. Defaults to None. [List of supported versions](https://huggingface.co/docs/sagemaker/reference#inference-dlc-overview) | `string` | n/a | yes | 111 | 112 | ## Outputs 113 | 114 | | Name | Description | 115 | |------|-------------| 116 | | [iam\_role](#output\_iam\_role) | IAM role used in the endpoint | 117 | | [sagemaker\_endpoint](#output\_sagemaker\_endpoint) | created Amazon SageMaker endpoint resource | 118 | | [sagemaker\_endpoint\_configuration](#output\_sagemaker\_endpoint\_configuration) | created Amazon SageMaker endpoint configuration resource | 119 | | [sagemaker\_endpoint\_name](#output\_sagemaker\_endpoint\_name) | Name of the created Amazon SageMaker endpoint, used for invoking the endpoint, with sdks | 120 | | [sagemaker\_model](#output\_sagemaker\_model) | created Amazon SageMaker model resource | 121 | | [tags](#output\_tags) | n/a | 122 | | [used\_container](#output\_used\_container) | Used container for creating the endpoint | 123 | 124 | ## License 125 | 126 | MIT License. See [LICENSE](LICENSE) for full details. 127 | 128 | 129 | -------------------------------------------------------------------------------- /examples/async_inference/.terraform.lock.hcl: -------------------------------------------------------------------------------- 1 | # This file is maintained automatically by "terraform init". 2 | # Manual edits may be lost in future updates. 
3 | 4 | provider "registry.terraform.io/hashicorp/aws" { 5 | version = "4.2.0" 6 | constraints = "~> 4.0" 7 | hashes = [ 8 | "h1:N5oVH/WT+1U3/hfpqs2iQ6wkoK+1qrPbYZJ+6Ptx6a0=", 9 | "zh:297d6462055eac8eb5c6735bd1a0fec23574e27d56c4c14a39efd8f3931ce4ed", 10 | "zh:457319839adca3638fd76f49fd65e15756717f97ac99bd1805a1c9387a62a250", 11 | "zh:57377384fa28abc4211a0916fc0fb590af238d096ad0490434ffeb89f568df9b", 12 | "zh:578e1d21bd6d38bdaef0909b30959b884e84e6c464796a50e516822955db162a", 13 | "zh:5e7ff13cc976f609aee4ada3c1967ba1f0ce5d276f3102a0aeaedc586d25ea80", 14 | "zh:5e94f09fe1874a2365bd566fecab8f676cd720da1c0bf70875392679549ebf20", 15 | "zh:93da14d7ffb8550b161cb79fe2cfc0f66848dd5022974399ae2bf88da7b9e9c5", 16 | "zh:c51e4541f3d29627974dcb7f5919012a762391accb574ade9e28bdb3c92bada5", 17 | "zh:eff58c1680e3f29e514919346d937bbe47278434ae03ed62443c77e878e267b1", 18 | "zh:f2b749e6c6b77b26e643bbecc829977270cfefab106d5ea57e5a83e96d49cbdd", 19 | "zh:fcc17e60e55c278535c332469727cf215eaea9ec81d38e2b5f05be127ee39a5b", 20 | ] 21 | } 22 | 23 | provider "registry.terraform.io/hashicorp/random" { 24 | version = "3.1.0" 25 | hashes = [ 26 | "h1:9cCiLO/Cqr6IUvMDSApCkQItooiYNatZpEXmcu0nnng=", 27 | "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", 28 | "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", 29 | "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", 30 | "zh:7011332745ea061e517fe1319bd6c75054a314155cb2c1199a5b01fe1889a7e2", 31 | "zh:738ed82858317ccc246691c8b85995bc125ac3b4143043219bd0437adc56c992", 32 | "zh:7dbe52fac7bb21227acd7529b487511c91f4107db9cc4414f50d04ffc3cab427", 33 | "zh:a3a9251fb15f93e4cfc1789800fc2d7414bbc18944ad4c5c98f466e6477c42bc", 34 | "zh:a543ec1a3a8c20635cf374110bd2f87c07374cf2c50617eee2c669b3ceeeaa9f", 35 | "zh:d9ab41d556a48bd7059f0810cf020500635bfc696c9fc3adab5ea8915c1d886b", 36 | "zh:d9e13427a7d011dbd654e591b0337e6074eef8c3b9bb11b2e39eaaf257044fd7", 37 | 
"zh:f7605bd1437752114baf601bdf6931debe6dc6bfe3006eb7e9bb9080931dca8a", 38 | ] 39 | } 40 | -------------------------------------------------------------------------------- /examples/async_inference/README.md: -------------------------------------------------------------------------------- 1 | # Example Asynchronous Endpoint 2 | 3 | ```hcl 4 | # create bucket for async inference for inputs & outputs 5 | resource "aws_s3_bucket" "async_inference_bucket" { 6 | bucket = "async-inference-bucket" 7 | } 8 | 9 | 10 | module "huggingface_sagemaker" { 11 | source = "philschmid/sagemaker-huggingface/aws" 12 | version = "0.5.0" 13 | name_prefix = "deploy-hub" 14 | pytorch_version = "1.9.1" 15 | transformers_version = "4.12.3" 16 | instance_type = "ml.g4dn.xlarge" 17 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 18 | hf_task = "text-classification" 19 | async_config = { 20 | # needs to be a s3 uri 21 | s3_output_path = "s3://async-inference-bucket/async-distilbert" 22 | # (Optional) specify Amazon SNS topics, S3 failure path and custom KMS Key 23 | # s3_failure_path = "s3://async-inference-bucket/failed-inference" 24 | # kms_key_id = "string" 25 | # sns_error_topic = "arn:aws:sns:aws-region:account-id:topic-name" 26 | # sns_success_topic = "arn:aws:sns:aws-region:account-id:topic-name" 27 | } 28 | autoscaling = { 29 | min_capacity = 0 30 | max_capacity = 4 31 | scaling_target_invocations = 100 32 | } 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /examples/async_inference/main.tf: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------------------------------------------------- 2 | # Example Asynchronous Endpoint 3 | # --------------------------------------------------------------------------------------------------------------------- 4 | 5 | provider "aws" { 6 | region = "us-east-1" 7 | profile = 
"hf-sm" 8 | } 9 | 10 | # create bucket for async inference for inputs & outputs 11 | resource "aws_s3_bucket" "async_inference_bucket" { 12 | bucket = "async-inference-bucket" 13 | } 14 | 15 | 16 | module "huggingface_sagemaker" { 17 | source = "../../" 18 | name_prefix = "deploy-hub" 19 | pytorch_version = "1.9.1" 20 | transformers_version = "4.12.3" 21 | instance_type = "ml.g4dn.xlarge" 22 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 23 | hf_task = "text-classification" 24 | async_config = { 25 | # needs to be a s3 uri 26 | s3_output_path = "s3://async-inference-bucket/async-distilbert" 27 | # (Optional) specify Amazon SNS topics, S3 failure path and custom KMS Key 28 | # s3_failure_path = "s3://async-inference-bucket/failed-inference" 29 | # kms_key_id = "string" 30 | # sns_error_topic = "arn:aws:sns:aws-region:account-id:topic-name" 31 | # sns_success_topic = "arn:aws:sns:aws-region:account-id:topic-name" 32 | } 33 | autoscaling = { 34 | min_capacity = 0 35 | max_capacity = 4 36 | scaling_target_invocations = 100 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /examples/autoscaling_example/.terraform.lock.hcl: -------------------------------------------------------------------------------- 1 | # This file is maintained automatically by "terraform init". 2 | # Manual edits may be lost in future updates. 
3 | 4 | provider "registry.terraform.io/hashicorp/aws" { 5 | version = "3.74.0" 6 | hashes = [ 7 | "h1:wIP53ozevE0ihhP1Fuoir4N1qI7+TcBs0y4SHlwMxho=", 8 | "zh:00767509c13c0d1c7ad6af702c6942e6572aa6d529b40a00baacc0e73faafea2", 9 | "zh:03aafdc903ad49c2eda03889f927f44212674c50e475a9c6298850381319eec2", 10 | "zh:2de8a6a97b180f909d652f215125aa4683e99db15fcf3b28d62e3d542f875ed6", 11 | "zh:3ac29ebc3af99028f4230a79f56606a0c2954b68767bd749b921a76eb4f3bd30", 12 | "zh:50add2e2d118a15a644360eabc5a34cec59f2560b491f8fabf9c52ab83ca7b09", 13 | "zh:85dd8e81910ab79f841a4a595fdd8ac358fbfe460956144afb0be3d81f91fe10", 14 | "zh:895de83d0f0941fde31bfc53fa6b1ea276901f006bec221bbdee4771a04f3693", 15 | "zh:a15c9724aac52d1ba5001d2d83e42843099b52b1638ea29d84e20be0f45fa4f1", 16 | "zh:c982a64463bd73e9bff2589de214b1de0a571438d9015001f9eae45cfc3a2559", 17 | "zh:e9ef973c18078324e43213ea1252c12b9441e566bf054ddfdbff5dd62f3035d9", 18 | "zh:f297e705b0f339c8baa27ae70db5df9aa6578adfe1ea3d2ba8edc186512464eb", 19 | ] 20 | } 21 | 22 | provider "registry.terraform.io/hashicorp/random" { 23 | version = "3.1.0" 24 | hashes = [ 25 | "h1:9cCiLO/Cqr6IUvMDSApCkQItooiYNatZpEXmcu0nnng=", 26 | "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", 27 | "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", 28 | "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", 29 | "zh:7011332745ea061e517fe1319bd6c75054a314155cb2c1199a5b01fe1889a7e2", 30 | "zh:738ed82858317ccc246691c8b85995bc125ac3b4143043219bd0437adc56c992", 31 | "zh:7dbe52fac7bb21227acd7529b487511c91f4107db9cc4414f50d04ffc3cab427", 32 | "zh:a3a9251fb15f93e4cfc1789800fc2d7414bbc18944ad4c5c98f466e6477c42bc", 33 | "zh:a543ec1a3a8c20635cf374110bd2f87c07374cf2c50617eee2c669b3ceeeaa9f", 34 | "zh:d9ab41d556a48bd7059f0810cf020500635bfc696c9fc3adab5ea8915c1d886b", 35 | "zh:d9e13427a7d011dbd654e591b0337e6074eef8c3b9bb11b2e39eaaf257044fd7", 36 | "zh:f7605bd1437752114baf601bdf6931debe6dc6bfe3006eb7e9bb9080931dca8a", 37 | ] 
38 | } 39 | -------------------------------------------------------------------------------- /examples/autoscaling_example/README.md: -------------------------------------------------------------------------------- 1 | # Example Autoscaling 2 | 3 | ```hcl 4 | module "huggingface_sagemaker" { 5 | source = "philschmid/sagemaker-huggingface/aws" 6 | version = "0.5.0" 7 | name_prefix = "autoscaling" 8 | pytorch_version = "1.9.1" 9 | transformers_version = "4.12.3" 10 | instance_type = "ml.g4dn.xlarge" 11 | instance_count = 1 # default is 1 12 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 13 | hf_task = "text-classification" 14 | autoscaling = { 15 | min_capacity = 1 # The min capacity of the scalable target, default is 1 16 | max_capacity = 4 # The max capacity of the scalable target 17 | scaling_target_invocations = 200 # The scaling target invocations (requests/minute) 18 | scale_in_cooldown = 300 # The cooldown time after scale-in, default is 300 19 | scale_out_cooldown = 60 # The cooldown time after scale-out, default is 60 20 | } 21 | } 22 | ``` -------------------------------------------------------------------------------- /examples/autoscaling_example/main.tf: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------------------------------------------------- 2 | # Example Deploy from HuggingFace Hub 3 | # --------------------------------------------------------------------------------------------------------------------- 4 | 5 | # provider "aws" { 6 | # region = "us-east-1" 7 | # profile = "default" 8 | # } 9 | 10 | module "huggingface_sagemaker" { 11 | source = "../../" 12 | name_prefix = "autoscaling" 13 | pytorch_version = "1.9.1" 14 | transformers_version = "4.12.3" 15 | instance_type = "ml.g4dn.xlarge" 16 | instance_count = 1 # default is 1 17 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 18 | hf_task = 
"text-classification" 19 | autoscaling = { 20 | min_capacity = 1 # The min capacity of the scalable target, default is 1 21 | max_capacity = 4 # The max capacity of the scalable target 22 | scaling_target_invocations = 200 # The scaling target invocations (requests/minute) 23 | scale_in_cooldown = 300 # The cooldown time after scale-in, default is 300 24 | scale_out_cooldown = 60 # The cooldown time after scale-out, default is 60 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /examples/deploy_from_hub/.terraform.lock.hcl: -------------------------------------------------------------------------------- 1 | # This file is maintained automatically by "terraform init". 2 | # Manual edits may be lost in future updates. 3 | 4 | provider "registry.terraform.io/hashicorp/aws" { 5 | version = "3.74.0" 6 | hashes = [ 7 | "h1:wIP53ozevE0ihhP1Fuoir4N1qI7+TcBs0y4SHlwMxho=", 8 | "zh:00767509c13c0d1c7ad6af702c6942e6572aa6d529b40a00baacc0e73faafea2", 9 | "zh:03aafdc903ad49c2eda03889f927f44212674c50e475a9c6298850381319eec2", 10 | "zh:2de8a6a97b180f909d652f215125aa4683e99db15fcf3b28d62e3d542f875ed6", 11 | "zh:3ac29ebc3af99028f4230a79f56606a0c2954b68767bd749b921a76eb4f3bd30", 12 | "zh:50add2e2d118a15a644360eabc5a34cec59f2560b491f8fabf9c52ab83ca7b09", 13 | "zh:85dd8e81910ab79f841a4a595fdd8ac358fbfe460956144afb0be3d81f91fe10", 14 | "zh:895de83d0f0941fde31bfc53fa6b1ea276901f006bec221bbdee4771a04f3693", 15 | "zh:a15c9724aac52d1ba5001d2d83e42843099b52b1638ea29d84e20be0f45fa4f1", 16 | "zh:c982a64463bd73e9bff2589de214b1de0a571438d9015001f9eae45cfc3a2559", 17 | "zh:e9ef973c18078324e43213ea1252c12b9441e566bf054ddfdbff5dd62f3035d9", 18 | "zh:f297e705b0f339c8baa27ae70db5df9aa6578adfe1ea3d2ba8edc186512464eb", 19 | ] 20 | } 21 | 22 | provider "registry.terraform.io/hashicorp/random" { 23 | version = "3.1.0" 24 | hashes = [ 25 | "h1:9cCiLO/Cqr6IUvMDSApCkQItooiYNatZpEXmcu0nnng=", 26 | 
"zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", 27 | "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", 28 | "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", 29 | "zh:7011332745ea061e517fe1319bd6c75054a314155cb2c1199a5b01fe1889a7e2", 30 | "zh:738ed82858317ccc246691c8b85995bc125ac3b4143043219bd0437adc56c992", 31 | "zh:7dbe52fac7bb21227acd7529b487511c91f4107db9cc4414f50d04ffc3cab427", 32 | "zh:a3a9251fb15f93e4cfc1789800fc2d7414bbc18944ad4c5c98f466e6477c42bc", 33 | "zh:a543ec1a3a8c20635cf374110bd2f87c07374cf2c50617eee2c669b3ceeeaa9f", 34 | "zh:d9ab41d556a48bd7059f0810cf020500635bfc696c9fc3adab5ea8915c1d886b", 35 | "zh:d9e13427a7d011dbd654e591b0337e6074eef8c3b9bb11b2e39eaaf257044fd7", 36 | "zh:f7605bd1437752114baf601bdf6931debe6dc6bfe3006eb7e9bb9080931dca8a", 37 | ] 38 | } 39 | -------------------------------------------------------------------------------- /examples/deploy_from_hub/README.md: -------------------------------------------------------------------------------- 1 | # Example: Deploy from Hugging Face Hub (hf.co/models) 2 | 3 | ```hcl 4 | module "huggingface_sagemaker" { 5 | source = "philschmid/sagemaker-huggingface/aws" 6 | version = "0.5.0" 7 | name_prefix = "deploy-hub" 8 | pytorch_version = "1.9.1" 9 | transformers_version = "4.12.3" 10 | instance_type = "ml.g4dn.xlarge" 11 | instance_count = 1 # default is 1 12 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 13 | hf_task = "text-classification" 14 | } 15 | ``` -------------------------------------------------------------------------------- /examples/deploy_from_hub/main.tf: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------------------------------------------------- 2 | # Example Deploy from HuggingFace Hub 3 | # 
--------------------------------------------------------------------------------------------------------------------- 4 | 5 | # provider "aws" { 6 | # region = "us-east-1" 7 | # profile = "default" 8 | # } 9 | 10 | module "huggingface_sagemaker" { 11 | source = "../../" 12 | name_prefix = "deploy-hub" 13 | pytorch_version = "1.9.1" 14 | transformers_version = "4.12.3" 15 | instance_type = "ml.g4dn.xlarge" 16 | instance_count = 1 # default is 1 17 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 18 | hf_task = "text-classification" 19 | } 20 | -------------------------------------------------------------------------------- /examples/deploy_from_s3/.terraform.lock.hcl: -------------------------------------------------------------------------------- 1 | # This file is maintained automatically by "terraform init". 2 | # Manual edits may be lost in future updates. 3 | 4 | provider "registry.terraform.io/hashicorp/aws" { 5 | version = "3.74.0" 6 | hashes = [ 7 | "h1:wIP53ozevE0ihhP1Fuoir4N1qI7+TcBs0y4SHlwMxho=", 8 | "zh:00767509c13c0d1c7ad6af702c6942e6572aa6d529b40a00baacc0e73faafea2", 9 | "zh:03aafdc903ad49c2eda03889f927f44212674c50e475a9c6298850381319eec2", 10 | "zh:2de8a6a97b180f909d652f215125aa4683e99db15fcf3b28d62e3d542f875ed6", 11 | "zh:3ac29ebc3af99028f4230a79f56606a0c2954b68767bd749b921a76eb4f3bd30", 12 | "zh:50add2e2d118a15a644360eabc5a34cec59f2560b491f8fabf9c52ab83ca7b09", 13 | "zh:85dd8e81910ab79f841a4a595fdd8ac358fbfe460956144afb0be3d81f91fe10", 14 | "zh:895de83d0f0941fde31bfc53fa6b1ea276901f006bec221bbdee4771a04f3693", 15 | "zh:a15c9724aac52d1ba5001d2d83e42843099b52b1638ea29d84e20be0f45fa4f1", 16 | "zh:c982a64463bd73e9bff2589de214b1de0a571438d9015001f9eae45cfc3a2559", 17 | "zh:e9ef973c18078324e43213ea1252c12b9441e566bf054ddfdbff5dd62f3035d9", 18 | "zh:f297e705b0f339c8baa27ae70db5df9aa6578adfe1ea3d2ba8edc186512464eb", 19 | ] 20 | } 21 | -------------------------------------------------------------------------------- 
/examples/deploy_from_s3/README.md: -------------------------------------------------------------------------------- 1 | # Example: Deploy from Amazon S3 2 | 3 | ```hcl 4 | module "huggingface_sagemaker" { 5 | source = "philschmid/sagemaker-huggingface/aws" 6 | version = "0.5.0" 7 | name_prefix = "deploy-s3" 8 | pytorch_version = "1.9.1" 9 | transformers_version = "4.12.3" 10 | instance_type = "ml.g4dn.xlarge" 11 | instance_count = 1 # default is 1 12 | model_data = "s3://my-bucket/mypath/model.tar.gz" 13 | hf_task = "text-classification" 14 | } 15 | ``` -------------------------------------------------------------------------------- /examples/deploy_from_s3/main.tf: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------------------------------------------------- 2 | # Example Deploy from Amazon s3 3 | # --------------------------------------------------------------------------------------------------------------------- 4 | 5 | module "huggingface_sagemaker" { 6 | source = "../../" 7 | name_prefix = "deploy-s3" 8 | pytorch_version = "1.9.1" 9 | transformers_version = "4.12.3" 10 | instance_type = "ml.g4dn.xlarge" 11 | instance_count = 1 # default is 1 12 | model_data = "s3://my-bucket/mypath/model.tar.gz" 13 | hf_task = "text-classification" 14 | } 15 | -------------------------------------------------------------------------------- /examples/deploy_private_model/.terraform.lock.hcl: -------------------------------------------------------------------------------- 1 | # This file is maintained automatically by "terraform init". 2 | # Manual edits may be lost in future updates. 
3 | 4 | provider "registry.terraform.io/hashicorp/aws" { 5 | version = "3.74.0" 6 | hashes = [ 7 | "h1:wIP53ozevE0ihhP1Fuoir4N1qI7+TcBs0y4SHlwMxho=", 8 | "zh:00767509c13c0d1c7ad6af702c6942e6572aa6d529b40a00baacc0e73faafea2", 9 | "zh:03aafdc903ad49c2eda03889f927f44212674c50e475a9c6298850381319eec2", 10 | "zh:2de8a6a97b180f909d652f215125aa4683e99db15fcf3b28d62e3d542f875ed6", 11 | "zh:3ac29ebc3af99028f4230a79f56606a0c2954b68767bd749b921a76eb4f3bd30", 12 | "zh:50add2e2d118a15a644360eabc5a34cec59f2560b491f8fabf9c52ab83ca7b09", 13 | "zh:85dd8e81910ab79f841a4a595fdd8ac358fbfe460956144afb0be3d81f91fe10", 14 | "zh:895de83d0f0941fde31bfc53fa6b1ea276901f006bec221bbdee4771a04f3693", 15 | "zh:a15c9724aac52d1ba5001d2d83e42843099b52b1638ea29d84e20be0f45fa4f1", 16 | "zh:c982a64463bd73e9bff2589de214b1de0a571438d9015001f9eae45cfc3a2559", 17 | "zh:e9ef973c18078324e43213ea1252c12b9441e566bf054ddfdbff5dd62f3035d9", 18 | "zh:f297e705b0f339c8baa27ae70db5df9aa6578adfe1ea3d2ba8edc186512464eb", 19 | ] 20 | } 21 | -------------------------------------------------------------------------------- /examples/deploy_private_model/README.md: -------------------------------------------------------------------------------- 1 | # Example: Deploy private model from Hugging Face Hub (hf.co/models) 2 | 3 | ```hcl 4 | module "huggingface_sagemaker" { 5 | source = "philschmid/sagemaker-huggingface/aws" 6 | version = "0.5.0" 7 | name_prefix = "deploy-private-hub" 8 | pytorch_version = "1.9.1" 9 | transformers_version = "4.12.3" 10 | instance_type = "ml.g4dn.xlarge" 11 | instance_count = 1 # default is 1 12 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 13 | hf_task = "text-classification" 14 | hf_api_token = "hf_Xxxx" 15 | } 16 | ``` -------------------------------------------------------------------------------- /examples/deploy_private_model/main.tf: -------------------------------------------------------------------------------- 1 | # 
--------------------------------------------------------------------------------------------------------------------- 2 | # Example Deploy from HuggingFace Hub 3 | # --------------------------------------------------------------------------------------------------------------------- 4 | 5 | # Configure the AWS Provider 6 | # provider "aws" { 7 | # region = "us-east-1" 8 | # profile = "default" 9 | # } 10 | 11 | 12 | module "huggingface_sagemaker" { 13 | source = "../../" 14 | name_prefix = "deploy-private-hub" 15 | pytorch_version = "1.9.1" 16 | transformers_version = "4.12.3" 17 | instance_type = "ml.g4dn.xlarge" 18 | instance_count = 1 # default is 1 19 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 20 | hf_task = "text-classification" 21 | hf_api_token = "hf_Xxxx" 22 | } 23 | -------------------------------------------------------------------------------- /examples/serverless_inference/.terraform.lock.hcl: -------------------------------------------------------------------------------- 1 | # This file is maintained automatically by "terraform init". 2 | # Manual edits may be lost in future updates. 
3 | 4 | provider "registry.terraform.io/hashicorp/aws" { 5 | version = "4.37.0" 6 | constraints = "~> 4.0" 7 | hashes = [ 8 | "h1:fLTymOb7xIdMkjQU1VDzPA5s+d2vNLZ2shpcFPF7KaY=", 9 | "zh:12c2eb60cb1eb0a41d1afbca6fc6f0eed6ca31a12c51858f951a9e71651afbe0", 10 | "zh:1e17482217c39a12e930e71fd2c9af8af577bec6736b184674476ebcaad28477", 11 | "zh:1e8163c3d871bbd54c189bf2fe5e60e556d67fa399e4c88c8e6ee0834525dc33", 12 | "zh:399c41a3e096fd75d487b98b1791f7cea5bd38567ac4e621c930cb67ec45977c", 13 | "zh:40d4329eef2cc130e4cbed7a6345cb053dd258bf6f5f8eb0f8ce777ae42d5a01", 14 | "zh:625db5fa75638d543b418be7d8046c4b76dc753d9d2184daa0faaaaebc02d207", 15 | "zh:7785c8259f12b45d19fa5abdac6268f3b749fe5a35c8be762c27b7a634a4952b", 16 | "zh:8a7611f33cc6422799c217ec2eeb79c779035ef05331d12505a6002bc48582f0", 17 | "zh:9188178235a73c829872d2e82d88ac6d334d8bb01433e9be31615f1c1633e921", 18 | "zh:994895b57bf225232a5fa7422e6ab87d8163a2f0605f54ff6a18cdd71f0aeadf", 19 | "zh:9b12af85486a96aedd8d7984b0ff811a4b42e3d88dad1a3fb4c0b580d04fa425", 20 | "zh:b57de6903ef30c9f22d38d595d64b4f92a89ea717b65782e1f44f57020ce8b1f", 21 | ] 22 | } 23 | 24 | provider "registry.terraform.io/hashicorp/random" { 25 | version = "3.4.3" 26 | hashes = [ 27 | "h1:tL3katm68lX+4lAncjQA9AXL4GR/VM+RPwqYf4D2X8Q=", 28 | "zh:41c53ba47085d8261590990f8633c8906696fa0a3c4b384ff6a7ecbf84339752", 29 | "zh:59d98081c4475f2ad77d881c4412c5129c56214892f490adf11c7e7a5a47de9b", 30 | "zh:686ad1ee40b812b9e016317e7f34c0d63ef837e084dea4a1f578f64a6314ad53", 31 | "zh:78d5eefdd9e494defcb3c68d282b8f96630502cac21d1ea161f53cfe9bb483b3", 32 | "zh:84103eae7251384c0d995f5a257c72b0096605048f757b749b7b62107a5dccb3", 33 | "zh:8ee974b110adb78c7cd18aae82b2729e5124d8f115d484215fd5199451053de5", 34 | "zh:9dd4561e3c847e45de603f17fa0c01ae14cae8c4b7b4e6423c9ef3904b308dda", 35 | "zh:bb07bb3c2c0296beba0beec629ebc6474c70732387477a65966483b5efabdbc6", 36 | "zh:e891339e96c9e5a888727b45b2e1bb3fcbdfe0fd7c5b4396e4695459b38c8cb1", 37 | 
"zh:ea4739860c24dfeaac6c100b2a2e357106a89d18751f7693f3c31ecf6a996f8d", 38 | "zh:f0c76ac303fd0ab59146c39bc121c5d7d86f878e9a69294e29444d4c653786f8", 39 | "zh:f143a9a5af42b38fed328a161279906759ff39ac428ebcfe55606e05e1518b93", 40 | ] 41 | } 42 | -------------------------------------------------------------------------------- /examples/serverless_inference/README.md: -------------------------------------------------------------------------------- 1 | # Example Serverless Endpoint 2 | 3 | ```hcl 4 | module "huggingface_sagemaker" { 5 | source = "../../" 6 | name_prefix = "deploy-hub" 7 | pytorch_version = "1.9.1" 8 | transformers_version = "4.12.3" 9 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 10 | hf_task = "text-classification" 11 | serverless_config = { 12 | max_concurrency = 1 13 | memory_size_in_mb = 1024 14 | } 15 | } 16 | ``` -------------------------------------------------------------------------------- /examples/serverless_inference/main.tf: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------------------------------------------------- 2 | # Example Serverless Endpoint 3 | # --------------------------------------------------------------------------------------------------------------------- 4 | 5 | module "huggingface_sagemaker" { 6 | source = "../../" 7 | name_prefix = "deploy-hub" 8 | pytorch_version = "1.9.1" 9 | transformers_version = "4.12.3" 10 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 11 | hf_task = "text-classification" 12 | serverless_config = { 13 | max_concurrency = 1 14 | memory_size_in_mb = 1024 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /examples/tensorflow_example/README.md: -------------------------------------------------------------------------------- 1 | # Example: Deploy from Hugging Face Hub (hf.co/models) 2 | 3 | ```hcl 4 | module 
"huggingface_sagemaker" { 5 | source = "philschmid/sagemaker-huggingface/aws" 6 | version = "0.5.0" 7 | name_prefix = "deploy-hub" 8 | tensorflow_version = "2.5.1" 9 | transformers_version = "4.12.3" 10 | instance_type = "ml.g4dn.xlarge" 11 | instance_count = 1 # default is 1 12 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 13 | hf_task = "text-classification" 14 | } 15 | ``` -------------------------------------------------------------------------------- /examples/tensorflow_example/main.tf: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------------------------------------------------- 2 | # Example Deploy Tensorflow model 3 | # --------------------------------------------------------------------------------------------------------------------- 4 | 5 | module "huggingface_sagemaker" { 6 | source = "../../" 7 | name_prefix = "test" 8 | tensorflow_version = "2.5.1" 9 | transformers_version = "4.12.3" 10 | instance_type = "ml.g4dn.xlarge" 11 | instance_count = 1 # default is 1 12 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 13 | hf_task = "text-classification" 14 | } -------------------------------------------------------------------------------- /examples/use_existing_iam_role/README.md: -------------------------------------------------------------------------------- 1 | # Example: Use existing IAM Role 2 | 3 | ```hcl 4 | module "huggingface_sagemaker" { 5 | source = "philschmid/sagemaker-huggingface/aws" 6 | version = "0.5.0" 7 | name_prefix = "deploy-hub" 8 | pytorch_version = "1.9.1" 9 | transformers_version = "4.12.3" 10 | instance_type = "ml.g4dn.xlarge" 11 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 12 | hf_task = "text-classification" 13 | sagemaker_execution_role = "sagemaker_execution_role" 14 | } 15 | ``` -------------------------------------------------------------------------------- 
/examples/use_existing_iam_role/main.tf: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------------------------------------------------- 2 | # Example Deploy model with existing iam role 3 | # --------------------------------------------------------------------------------------------------------------------- 4 | 5 | module "huggingface_sagemaker" { 6 | source = "../../" 7 | name_prefix = "with-iam-role" 8 | pytorch_version = "1.9.1" 9 | transformers_version = "4.12.3" 10 | instance_type = "ml.g4dn.xlarge" 11 | instance_count = 1 # default is 1 12 | hf_model_id = "distilbert-base-uncased-finetuned-sst-2-english" 13 | hf_task = "text-classification" 14 | sagemaker_execution_role = "sagemaker_execution_role" 15 | } 16 | -------------------------------------------------------------------------------- /main.tf: -------------------------------------------------------------------------------- 1 | 2 | 3 | # ------------------------------------------------------------------------------ 4 | # Local configurations 5 | # ------------------------------------------------------------------------------ 6 | 7 | locals { 8 | framework_version = var.pytorch_version != null ? var.pytorch_version : var.tensorflow_version 9 | repository_name = var.pytorch_version != null ? "huggingface-pytorch-inference" : "huggingface-tensorflow-inference" 10 | device = var.instance_type == null ? "cpu" : length(regexall("^ml\\.[gp]", var.instance_type)) > 0 ?
"gpu" : "cpu" # null instance_type (serverless) -> cpu image; ml.g*/ml.p* GPU families -> gpu image 11 | image_key = "${local.framework_version}-${local.device}" 12 | pytorch_image_tag = { 13 | "1.7.1-gpu" = "1.7.1-transformers${var.transformers_version}-gpu-py36-cu110-ubuntu18.04" 14 | "1.7.1-cpu" = "1.7.1-transformers${var.transformers_version}-cpu-py36-ubuntu18.04" 15 | "1.8.1-gpu" = "1.8.1-transformers${var.transformers_version}-gpu-py36-cu111-ubuntu18.04" 16 | "1.8.1-cpu" = "1.8.1-transformers${var.transformers_version}-cpu-py36-ubuntu18.04" 17 | "1.9.1-gpu" = "1.9.1-transformers${var.transformers_version}-gpu-py38-cu111-ubuntu20.04" 18 | "1.9.1-cpu" = "1.9.1-transformers${var.transformers_version}-cpu-py38-ubuntu20.04" 19 | "1.10.2-cpu" = "1.10.2-transformers${var.transformers_version}-cpu-py38-ubuntu20.04" 20 | "1.10.2-gpu" = "1.10.2-transformers${var.transformers_version}-gpu-py38-cu113-ubuntu20.04" 21 | "1.13.1-cpu" = "1.13.1-transformers${var.transformers_version}-cpu-py39-ubuntu20.04" 22 | "1.13.1-gpu" = "1.13.1-transformers${var.transformers_version}-gpu-py39-cu117-ubuntu20.04" 23 | "2.0.0-cpu" = "2.0.0-transformers${var.transformers_version}-cpu-py310-ubuntu20.04" 24 | "2.0.0-gpu" = "2.0.0-transformers${var.transformers_version}-gpu-py310-cu118-ubuntu20.04" 25 | } 26 | tensorflow_image_tag = { 27 | "2.4.1-gpu" = "2.4.1-transformers${var.transformers_version}-gpu-py37-cu110-ubuntu18.04" 28 | "2.4.1-cpu" = "2.4.1-transformers${var.transformers_version}-cpu-py37-ubuntu18.04" 29 | "2.5.1-gpu" = "2.5.1-transformers${var.transformers_version}-gpu-py36-cu111-ubuntu18.04" 30 | "2.5.1-cpu" = "2.5.1-transformers${var.transformers_version}-cpu-py36-ubuntu18.04" 31 | } 32 | sagemaker_endpoint_type = { 33 | real_time = (var.async_config.s3_output_path == null && var.serverless_config.max_concurrency == null) ? true : false 34 | asynchronous = (var.async_config.s3_output_path != null && var.serverless_config.max_concurrency == null) ?
true : false 35 | serverless = (var.async_config.s3_output_path == null && var.serverless_config.max_concurrency != null) ? true : false 36 | } 37 | } 38 | 39 | # random lowercase string used for naming 40 | resource "random_string" "resource_id" { 41 | length = 8 42 | lower = true 43 | special = false 44 | upper = false 45 | numeric = false 46 | } 47 | 48 | # ------------------------------------------------------------------------------ 49 | # Container Image 50 | # ------------------------------------------------------------------------------ 51 | 52 | 53 | data "aws_sagemaker_prebuilt_ecr_image" "deploy_image" { 54 | repository_name = local.repository_name 55 | image_tag = var.pytorch_version != null ? local.pytorch_image_tag[local.image_key] : local.tensorflow_image_tag[local.image_key] 56 | } 57 | 58 | # ------------------------------------------------------------------------------ 59 | # Permission 60 | # ------------------------------------------------------------------------------ 61 | 62 | resource "aws_iam_role" "new_role" { 63 | count = var.sagemaker_execution_role == null ? 
1 : 0 # Creates IAM role if not provided 64 | name = "${var.name_prefix}-sagemaker-execution-role-${random_string.resource_id.result}" 65 | assume_role_policy = jsonencode({ 66 | Version = "2012-10-17" 67 | Statement = [ 68 | { 69 | Action = "sts:AssumeRole" 70 | Effect = "Allow" 71 | Principal = { 72 | Service = "sagemaker.amazonaws.com" 73 | } 74 | }, 75 | ] 76 | }) 77 | 78 | inline_policy { 79 | name = "terraform-inferences-policy" 80 | policy = jsonencode({ 81 | Version = "2012-10-17" 82 | Statement = [ 83 | { 84 | Effect = "Allow", 85 | Action = [ 86 | "cloudwatch:PutMetricData", 87 | "logs:CreateLogStream", 88 | "logs:PutLogEvents", 89 | "logs:CreateLogGroup", 90 | "logs:DescribeLogStreams", 91 | "s3:GetObject", 92 | "s3:PutObject", 93 | "s3:ListBucket", 94 | "ecr:GetAuthorizationToken", 95 | "ecr:BatchCheckLayerAvailability", 96 | "ecr:GetDownloadUrlForLayer", 97 | "ecr:BatchGetImage" 98 | ], 99 | Resource = "*" 100 | } 101 | ] 102 | }) 103 | 104 | } 105 | 106 | tags = var.tags 107 | } 108 | 109 | data "aws_iam_role" "get_role" { 110 | count = var.sagemaker_execution_role != null ? 1 : 0 # Creates IAM role if not provided 111 | name = var.sagemaker_execution_role 112 | } 113 | 114 | locals { 115 | role_arn = var.sagemaker_execution_role != null ? data.aws_iam_role.get_role[0].arn : aws_iam_role.new_role[0].arn 116 | model_slug = var.model_data != null ? "-${replace(reverse(split("/", replace(var.model_data, ".tar.gz", "")))[0], ".", "-")}" : "" 117 | } 118 | 119 | # ------------------------------------------------------------------------------ 120 | # SageMaker Model 121 | # ------------------------------------------------------------------------------ 122 | 123 | resource "aws_sagemaker_model" "model_with_model_artifact" { 124 | count = var.model_data != null && var.hf_model_id == null ? 
1 : 0 125 | name = "${var.name_prefix}-model-${random_string.resource_id.result}${local.model_slug}" 126 | execution_role_arn = local.role_arn 127 | tags = var.tags 128 | 129 | primary_container { 130 | # CPU Image 131 | image = data.aws_sagemaker_prebuilt_ecr_image.deploy_image.registry_path 132 | model_data_url = var.model_data 133 | environment = { 134 | HF_TASK = var.hf_task 135 | } 136 | } 137 | 138 | lifecycle { 139 | create_before_destroy = true 140 | } 141 | } 142 | 143 | 144 | resource "aws_sagemaker_model" "model_with_hub_model" { 145 | count = var.model_data == null && var.hf_model_id != null ? 1 : 0 146 | name = "${var.name_prefix}-model-${random_string.resource_id.result}${local.model_slug}" 147 | execution_role_arn = local.role_arn 148 | tags = var.tags 149 | 150 | primary_container { 151 | image = data.aws_sagemaker_prebuilt_ecr_image.deploy_image.registry_path 152 | environment = { 153 | HF_TASK = var.hf_task 154 | HF_MODEL_ID = var.hf_model_id 155 | HF_API_TOKEN = var.hf_api_token 156 | HF_MODEL_REVISION = var.hf_model_revision 157 | } 158 | } 159 | 160 | lifecycle { 161 | create_before_destroy = true 162 | } 163 | } 164 | 165 | locals { 166 | sagemaker_model = var.model_data != null && var.hf_model_id == null ? aws_sagemaker_model.model_with_model_artifact[0] : aws_sagemaker_model.model_with_hub_model[0] 167 | } 168 | 169 | # ------------------------------------------------------------------------------ 170 | # SageMaker Endpoint configuration 171 | # ------------------------------------------------------------------------------ 172 | 173 | resource "aws_sagemaker_endpoint_configuration" "huggingface" { 174 | count = local.sagemaker_endpoint_type.real_time ? 
1 : 0 175 | name = "${var.name_prefix}-ep-config-${random_string.resource_id.result}" 176 | tags = var.tags 177 | 178 | 179 | production_variants { 180 | variant_name = "AllTraffic" 181 | model_name = local.sagemaker_model.name 182 | initial_instance_count = var.instance_count 183 | instance_type = var.instance_type 184 | } 185 | } 186 | 187 | 188 | resource "aws_sagemaker_endpoint_configuration" "huggingface_async" { 189 | count = local.sagemaker_endpoint_type.asynchronous ? 1 : 0 190 | name = "${var.name_prefix}-ep-config-${random_string.resource_id.result}" 191 | tags = var.tags 192 | 193 | 194 | production_variants { 195 | variant_name = "AllTraffic" 196 | model_name = local.sagemaker_model.name 197 | initial_instance_count = var.instance_count 198 | instance_type = var.instance_type 199 | } 200 | async_inference_config { 201 | output_config { 202 | s3_output_path = var.async_config.s3_output_path 203 | s3_failure_path = var.async_config.s3_failure_path 204 | kms_key_id = var.async_config.kms_key_id 205 | notification_config { 206 | error_topic = var.async_config.sns_error_topic 207 | success_topic = var.async_config.sns_success_topic 208 | } 209 | } 210 | } 211 | } 212 | 213 | 214 | resource "aws_sagemaker_endpoint_configuration" "huggingface_serverless" { 215 | count = local.sagemaker_endpoint_type.serverless ? 1 : 0 216 | name = "${var.name_prefix}-ep-config-${random_string.resource_id.result}" 217 | tags = var.tags 218 | 219 | 220 | production_variants { 221 | variant_name = "AllTraffic" 222 | model_name = local.sagemaker_model.name 223 | 224 | serverless_config { 225 | max_concurrency = var.serverless_config.max_concurrency 226 | memory_size_in_mb = var.serverless_config.memory_size_in_mb 227 | } 228 | } 229 | } 230 | 231 | 232 | locals { 233 | sagemaker_endpoint_config = ( 234 | local.sagemaker_endpoint_type.real_time ? 235 | aws_sagemaker_endpoint_configuration.huggingface[0] : ( 236 | local.sagemaker_endpoint_type.asynchronous ? 
237 | aws_sagemaker_endpoint_configuration.huggingface_async[0] : ( 238 | local.sagemaker_endpoint_type.serverless ? 239 | aws_sagemaker_endpoint_configuration.huggingface_serverless[0] : null 240 | ) 241 | ) 242 | ) 243 | } 244 | 245 | # ------------------------------------------------------------------------------ 246 | # SageMaker Endpoint 247 | # ------------------------------------------------------------------------------ 248 | 249 | 250 | resource "aws_sagemaker_endpoint" "huggingface" { 251 | name = "${var.name_prefix}-ep-${random_string.resource_id.result}" 252 | tags = var.tags 253 | 254 | endpoint_config_name = local.sagemaker_endpoint_config.name 255 | } 256 | 257 | # ------------------------------------------------------------------------------ 258 | # AutoScaling configuration 259 | # ------------------------------------------------------------------------------ 260 | 261 | 262 | locals { 263 | use_autoscaling = var.autoscaling.max_capacity != null && var.autoscaling.scaling_target_invocations != null && !local.sagemaker_endpoint_type.serverless ? 
1 : 0 264 | } 265 | 266 | resource "aws_appautoscaling_target" "sagemaker_target" { 267 | count = local.use_autoscaling 268 | min_capacity = var.autoscaling.min_capacity 269 | max_capacity = var.autoscaling.max_capacity 270 | resource_id = "endpoint/${aws_sagemaker_endpoint.huggingface.name}/variant/AllTraffic" 271 | scalable_dimension = "sagemaker:variant:DesiredInstanceCount" 272 | service_namespace = "sagemaker" 273 | } 274 | 275 | resource "aws_appautoscaling_policy" "sagemaker_policy" { 276 | count = local.use_autoscaling 277 | name = "${var.name_prefix}-scaling-target-${random_string.resource_id.result}" 278 | policy_type = "TargetTrackingScaling" 279 | resource_id = aws_appautoscaling_target.sagemaker_target[0].resource_id 280 | scalable_dimension = aws_appautoscaling_target.sagemaker_target[0].scalable_dimension 281 | service_namespace = aws_appautoscaling_target.sagemaker_target[0].service_namespace 282 | 283 | target_tracking_scaling_policy_configuration { 284 | predefined_metric_specification { 285 | predefined_metric_type = "SageMakerVariantInvocationsPerInstance" 286 | } 287 | target_value = var.autoscaling.scaling_target_invocations 288 | scale_in_cooldown = var.autoscaling.scale_in_cooldown 289 | scale_out_cooldown = var.autoscaling.scale_out_cooldown 290 | } 291 | } 292 | -------------------------------------------------------------------------------- /outputs.tf: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Output 3 | # ------------------------------------------------------------------------------ 4 | output "used_container" { 5 | description = "Used container for creating the endpoint" 6 | value = data.aws_sagemaker_prebuilt_ecr_image.deploy_image.registry_path 7 | } 8 | 9 | output "iam_role" { 10 | description = "IAM role used in the endpoint" 11 | value = local.role_arn 12 | } 13 | 14 | output "sagemaker_model" { 15 | 
description = "created Amazon SageMaker model resource" 16 | value = local.sagemaker_model 17 | } 18 | 19 | output "sagemaker_endpoint_configuration" { 20 | description = "created Amazon SageMaker endpoint configuration resource" 21 | value = local.sagemaker_endpoint_config 22 | } 23 | 24 | output "sagemaker_endpoint" { 25 | description = "created Amazon SageMaker endpoint resource" 26 | value = aws_sagemaker_endpoint.huggingface 27 | } 28 | 29 | output "sagemaker_endpoint_name" { 30 | description = "Name of the created Amazon SageMaker endpoint, used for invoking the endpoint, with sdks" 31 | value = aws_sagemaker_endpoint.huggingface.name 32 | } 33 | 34 | 35 | output "tags" { 36 | value = var.tags 37 | } 38 | -------------------------------------------------------------------------------- /variables.tf: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Variables 3 | # ------------------------------------------------------------------------------ 4 | variable "name_prefix" { 5 | description = "A prefix used for naming resources." 6 | type = string 7 | } 8 | 9 | variable "transformers_version" { 10 | description = "Transformers version you want to use for executing your model training code. Defaults to None. [List of supported versions](https://huggingface.co/docs/sagemaker/reference#inference-dlc-overview)" 11 | type = string 12 | } 13 | 14 | variable "pytorch_version" { 15 | description = "PyTorch version you want to use for executing your inference code. Defaults to `None`. Required unless `tensorflow_version` is provided. [List of supported versions](https://huggingface.co/docs/sagemaker/reference#inference-dlc-overview)" 16 | type = string 17 | default = null 18 | } 19 | 20 | variable "tensorflow_version" { 21 | description = "TensorFlow version you want to use for executing your inference code. Defaults to `None`. 
Required unless `pytorch_version` is provided. [List of supported versions](https://huggingface.co/docs/sagemaker/reference#inference-dlc-overview)" 22 | type = string 23 | default = null 24 | } 25 | 26 | variable "image_tag" { 27 | description = "The image tag you want to use for the container you want to use. Defaults to `None`. The module tries to derive the `image_tag` from the `pytorch_version`, `tensorflow_version` & `instance_type`. If you want to override this, you can provide the `image_tag` as a variable." 28 | type = string 29 | default = null 30 | } 31 | 32 | variable "instance_type" { 33 | description = "The EC2 instance type to deploy this Model to. For example, `ml.p2.xlarge`." 34 | type = string 35 | default = null 36 | } 37 | 38 | variable "instance_count" { 39 | description = "The initial number of instances to run in the Endpoint created from this Model. Defaults to 1." 40 | type = number 41 | default = 1 42 | } 43 | 44 | 45 | variable "hf_model_id" { 46 | description = "The HF_MODEL_ID environment variable defines the model id, which will be automatically loaded from [hf.co/models](https://huggingface.co/models) when creating your SageMaker Endpoint." 47 | type = string 48 | default = null 49 | } 50 | 51 | variable "hf_task" { 52 | description = "The HF_TASK environment variable defines the task for the used 🤗 Transformers pipeline. A full list of tasks can be found [here](https://huggingface.co/transformers/main_classes/pipelines.html)." 53 | type = string 54 | } 55 | 56 | variable "hf_api_token" { 57 | description = "The HF_API_TOKEN environment variable defines your Hugging Face authorization token. The HF_API_TOKEN is used as an HTTP bearer authorization for remote files, like private models. You can find your token at your settings page." 
58 | type = string 59 | default = null 60 | } 61 | 62 | variable "hf_model_revision" { 63 | description = "The HF_MODEL_REVISION is an extension to HF_MODEL_ID and allows you to define/pin a revision of the model to make sure you always load the same model on your SageMaker Endpoint." 64 | type = string 65 | default = null 66 | } 67 | 68 | 69 | variable "model_data" { 70 | description = "The S3 location of a SageMaker model data .tar.gz file (default: None). Not needed when using `hf_model_id`." 71 | type = string 72 | default = null 73 | 74 | } 75 | 76 | variable "sagemaker_execution_role" { 77 | description = "An AWS IAM role Name to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role if it needs to access some AWS resources. If not specified, the role will created with with the `CreateModel` permissions from the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-createmodel-perms)" 78 | type = string 79 | default = null 80 | } 81 | 82 | variable "autoscaling" { 83 | description = "A Object which defines the autoscaling target and policy for our SageMaker Endpoint. Required keys are `max_capacity` and `scaling_target_invocations` " 84 | type = object({ 85 | min_capacity = optional(number), 86 | max_capacity = number, 87 | scaling_target_invocations = optional(number), 88 | scale_in_cooldown = optional(number), 89 | scale_out_cooldown = optional(number), 90 | }) 91 | 92 | default = { 93 | min_capacity = 1 94 | max_capacity = null 95 | scaling_target_invocations = null 96 | scale_in_cooldown = 300 97 | scale_out_cooldown = 66 98 | } 99 | } 100 | 101 | variable "async_config" { 102 | description = "(Optional) Specifies configuration for how an endpoint performs asynchronous inference. Required key is `s3_output_path`, which is the s3 bucket used for async inference." 
103 | type = object({ 104 | s3_output_path = string, 105 | s3_failure_path = optional(string), 106 | kms_key_id = optional(string), 107 | sns_error_topic = optional(string), 108 | sns_success_topic = optional(string), 109 | }) 110 | 111 | default = { 112 | s3_output_path = null 113 | kms_key_id = null 114 | sns_error_topic = null 115 | sns_success_topic = null 116 | } 117 | } 118 | 119 | variable "serverless_config" { 120 | description = "(Optional) Specifies configuration for how an endpoint performs serverless inference. Required keys are `max_concurrency` and `memory_size_in_mb`" 121 | type = object({ 122 | max_concurrency = number, 123 | memory_size_in_mb = number 124 | }) 125 | 126 | default = { 127 | max_concurrency = null 128 | memory_size_in_mb = null 129 | } 130 | } 131 | 132 | 133 | variable "tags" { 134 | description = "A map of tags (key-value pairs) passed to resources." 135 | type = map(string) 136 | default = {} 137 | } 138 | -------------------------------------------------------------------------------- /versions.tf: -------------------------------------------------------------------------------- 1 | terraform { 2 | required_providers { 3 | aws = { 4 | source = "hashicorp/aws" 5 | version = "~> 5.0" 6 | } 7 | } 8 | required_version = ">= 1.3.0" 9 | } 10 | --------------------------------------------------------------------------------