├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── app.py ├── cdk.json ├── code ├── lambda_txt2img │ └── txt2img.py └── lambda_txt2nlu │ └── txt2nlu.py ├── construct └── sagemaker_endpoint_construct.py ├── images ├── architecture.png ├── cdk-deploy.png ├── cdk-stacks.png ├── console-apigw.png ├── console-cloudformation.png ├── console-ecs.png ├── console-lambda.png ├── console-sagemaker.png ├── console-ssm-parameter-store.png ├── foundation-models.png ├── streamlit-image-gen-01.png ├── streamlit-image-gen-02.png ├── streamlit-landing-page.png ├── streamlit-text-gen-01.png ├── streamlit-text-gen-02.png └── streamlit-text-gen-03.png ├── requirements-dev.txt ├── requirements.txt ├── script └── sagemaker_uri.py ├── source.bat ├── stack ├── __init__.py ├── generative_ai_demo_web_stack.py ├── generative_ai_txt2img_sagemaker_stack.py ├── generative_ai_txt2nlu_sagemaker_stack.py └── generative_ai_vpc_network_stack.py ├── tests ├── __init__.py └── unit │ ├── __init__.py │ └── test_generative_ai_sagemaker_cdk_demo_stack.py └── web-app ├── Dockerfile ├── Home.py ├── configs.py ├── img └── sagemaker.png ├── pages ├── 2_Image_Generation.py └── 3_Text_Generation.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | package-lock.json 3 | __pycache__ 4 | .pytest_cache 5 | .venv 6 | *.egg-info 7 | 8 | # CDK asset staging directory 9 | .cdk.staging 10 | cdk.out 11 | 12 | .DS_Store -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. 
Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deploy generative AI models from Amazon SageMaker JumpStart using the AWS CDK 2 | 3 | The seeds of a machine learning (ML) paradigm shift have existed for decades, but with the ready availability of virtually infinite compute capacity, a massive proliferation of data, and the rapid advancement of ML technologies, customers across industries are rapidly adopting and using ML technologies to transform their businesses. 4 | 5 | Just recently, generative AI applications have captured everyone's attention and imagination. 
We are truly at an exciting inflection point in the widespread adoption of ML, and we believe every customer experience and application will be reinvented with generative AI. 6 | 7 | Generative AI is a type of AI that can create new content and ideas, including conversations, stories, images, videos, and music. Like all AI, generative AI is powered by ML models: very large models that are pre-trained on vast corpora of data and commonly referred to as foundation models (FMs). 8 | 9 | The size and general-purpose nature of FMs make them different from traditional ML models, which typically perform specific tasks, like analyzing text for sentiment, classifying images, and forecasting trends. 10 | 11 | ![foundation-models](./images/foundation-models.png) 12 | 13 | With traditional ML models, to achieve each specific task, you need to gather labeled data, train a model, and deploy that model. With foundation models, instead of gathering labeled data and training a separate model for each task, you can adapt the same pre-trained FM to various tasks. You can also customize FMs to perform domain-specific functions that differentiate your business, using only a small fraction of the data and compute required to train a model from scratch. 14 | 15 | 16 | 17 | Generative AI has the potential to disrupt many industries by revolutionizing the way content is created and consumed. Original content production, code generation, customer service enhancement, and document summarization are typical use cases of generative AI. 18 | 19 | 20 | 21 | [Amazon SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html) provides pre-trained, open-source models for a wide range of problem types to help you get started with ML. You can incrementally train and tune these models before deployment.
JumpStart also provides solution templates that set up infrastructure for common use cases, and executable example notebooks for ML with [Amazon SageMaker](https://aws.amazon.com/sagemaker/). 22 | 23 | 24 | 25 | With over 600 pre-trained models available and growing every day, JumpStart enables developers to quickly and easily incorporate cutting-edge ML techniques into their production workflows. You can access the pre-trained models, solution templates, and examples through the JumpStart landing page in [Amazon SageMaker Studio](https://docs.aws.amazon.com/sagemaker/latest/dg/studio.html). You can also access JumpStart models using the SageMaker Python SDK. For information about how to use JumpStart models programmatically, see [Use SageMaker JumpStart Algorithms with Pretrained Models](https://sagemaker.readthedocs.io/en/stable/overview.html#use-sagemaker-jumpstart-algorithms-with-pretrained-models). 26 | 27 | 28 | 29 | In April 2023, AWS unveiled [Amazon Bedrock](https://aws.amazon.com/bedrock/), which provides a way to build generative AI-powered apps via pre-trained models from startups including [AI21 Labs](https://www.ai21.com/), [Anthropic](https://techcrunch.com/2023/02/27/anthropic-begins-supplying-its-text-generating-ai-models-to-startups/), and [Stability AI](https://techcrunch.com/2022/10/17/stability-ai-the-startup-behind-stable-diffusion-raises-101m/). Amazon Bedrock also offers access to Titan foundation models, a family of models trained in-house by AWS. 
With the serverless experience of Amazon Bedrock, you can easily find the right model for your needs, get started quickly, privately customize FMs with your own data, and easily integrate and deploy them into your applications using the AWS tools and capabilities you're familiar with (including integrations with SageMaker ML features like [Amazon SageMaker Experiments](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments.html) to test different models and [Amazon SageMaker Pipelines](https://aws.amazon.com/sagemaker/pipelines/) to manage your FMs at scale) without having to manage any infrastructure. 30 | 31 | 32 | 33 | In this post, we show how to deploy image and text generative AI models from JumpStart using the [AWS Cloud Development Kit](https://aws.amazon.com/cdk/) (AWS CDK). The AWS CDK is an open-source software development framework to define your cloud application resources using familiar programming languages like Python. 34 | 35 | 36 | 37 | We use the Stable Diffusion model for image generation and the FLAN-T5-XL model for [natural language understanding (NLU)](https://en.wikipedia.org/wiki/Natural-language_understanding) and text generation from [Hugging Face](https://huggingface.co/) in JumpStart. 38 | 39 | 40 | 41 | ## Solution overview 42 | 43 | The web application is built on [Streamlit](https://streamlit.io/), an open-source Python library that makes it easy to create and share beautiful, custom web apps for ML and data science. We host the web application using [Amazon Elastic Container Service](https://aws.amazon.com/ecs) (Amazon ECS) with [AWS Fargate](https://docs.aws.amazon.com/AmazonECS/latest/userguide/what-is-fargate.html) and it is accessed via an Application Load Balancer. Fargate is a technology that you can use with Amazon ECS to run [containers](https://aws.amazon.com/what-are-containers) without having to manage servers or clusters or virtual machines. 
The generative AI model endpoints are launched from JumpStart images in [Amazon Elastic Container Registry](https://aws.amazon.com/ecr/) (Amazon ECR). Model data is stored on [Amazon Simple Storage Service](https://aws.amazon.com/s3/) (Amazon S3) in the JumpStart account. The web application interacts with the models via [Amazon API Gateway](https://aws.amazon.com/api-gateway) and [AWS Lambda](http://aws.amazon.com/lambda) functions as shown in the following diagram. 44 | 45 | ![architecture](./images/architecture.png) 46 | 47 | API Gateway provides the web application and other clients a standard RESTful interface, while shielding the Lambda functions that interface with the model. This simplifies the client application code that consumes the models. The API Gateway endpoints are publicly accessible in this example, allowing for the possibility to extend this architecture to implement different [API access controls](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-control-access-to-api.html) and integrate with other applications. 48 | 49 | 50 | 51 | In this post, we walk you through the following steps: 52 | 53 | 1. Install the [AWS Command Line Interface](http://aws.amazon.com/cli) (AWS CLI) and [AWS CDK v2](https://docs.aws.amazon.com/cdk/v2/guide/getting_started.html) on your local machine. 54 | 2. Clone and set up the AWS CDK application. 55 | 3. Deploy the AWS CDK application. 56 | 4. Use the image generation AI model. 57 | 5. Use the text generation AI model. 58 | 6. View the deployed resources on the [AWS Management Console](http://aws.amazon.com/console). 59 | 60 | We provide an overview of the code in this project in the appendix at the end of this post. 
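The request flow shown in the architecture diagram (API Gateway → Lambda → SageMaker endpoint) can be illustrated with a minimal Python sketch of a Lambda handler. This is not the code in `code/lambda_txt2nlu/txt2nlu.py`. The request fields `endpoint_name` and `prompt` are hypothetical, and the FLAN-T5 payload keys `text_inputs`, `max_length`, and `generated_texts` are assumptions to verify against your JumpStart model version. The SageMaker call uses the `invoke_endpoint` operation of the boto3 `sagemaker-runtime` client. The client can be injected, so you can exercise the logic without AWS credentials.

```python
import json


def invoke_text_model(runtime_client, endpoint_name, prompt, max_length=50):
    """Send a prompt to a SageMaker endpoint and return the first generated text.

    Assumption: the endpoint accepts {"text_inputs": ..., "max_length": ...} and
    answers with {"generated_texts": [...]}; check this for your model version.
    """
    payload = {"text_inputs": prompt, "max_length": max_length}
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps(payload).encode("utf-8"),
    )
    # The response body is a streaming object; read() returns raw bytes.
    result = json.loads(response["Body"].read())
    return result["generated_texts"][0]


def _http_response(status, body):
    """Shape a reply in the format expected by an API Gateway Lambda proxy integration."""
    return {
        "statusCode": status,
        "headers": {"Content-Type": "application/json"},
        "body": json.dumps(body),
    }


def lambda_handler(event, context, runtime_client=None):
    """Hypothetical handler that forwards {"endpoint_name", "prompt"} to SageMaker."""
    # With a proxy integration, API Gateway passes the HTTP body as a string.
    request = json.loads(event.get("body") or "{}")
    missing = [key for key in ("endpoint_name", "prompt") if key not in request]
    if missing:
        return _http_response(400, {"error": f"missing fields: {missing}"})

    if runtime_client is None:
        import boto3  # imported lazily so the sketch can be tested without AWS
        runtime_client = boto3.client("sagemaker-runtime")

    text = invoke_text_model(runtime_client, request["endpoint_name"], request["prompt"])
    return _http_response(200, {"generated_text": text})
```

The handler validates the request before it creates a boto3 client, so malformed calls fail fast with a 400 response. Returning proxy-style dictionaries (`statusCode`, `headers`, `body`) assumes a Lambda proxy integration. With a non-proxy integration, you would map the response in API Gateway instead.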
 61 | 62 | 63 | 64 | ## Prerequisites 65 | 66 | You must have the following prerequisites: 67 | 68 | - An [AWS account](https://signin.aws.amazon.com/signin) 69 | - The [AWS CLI v2](https://docs.aws.amazon.com/cli/latest/userguide/install-cliv2.html) 70 | - Python 3.6 or later 71 | - Node.js 14.x or later 72 | - The [AWS CDK v2](https://docs.aws.amazon.com/cdk/v2/guide/getting_started.html) 73 | - Docker v20.10 or later 74 | 75 | You can deploy the infrastructure in this tutorial from your local computer, or you can use [AWS Cloud9](https://aws.amazon.com/cloud9/) as your deployment workstation. AWS Cloud9 comes pre-loaded with the AWS CLI, the AWS CDK, and Docker. If you opt for AWS Cloud9, [create the environment](https://docs.aws.amazon.com/cloud9/latest/user-guide/tutorial-create-environment.html) from the [AWS console](https://console.aws.amazon.com/cloud9). 76 | 77 | The estimated cost to complete this post is $50, assuming you leave the resources running for 8 hours. Make sure you delete the resources you create in this post to avoid ongoing charges. 78 | 79 | 80 | 81 | ## Install the AWS CLI and AWS CDK on your local machine 82 | 83 | If you don't already have the AWS CLI on your local machine, refer to [Installing or updating the latest version of the AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html) and [Configuring the AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html). 84 | 85 | Install the AWS CDK Toolkit (the `aws-cdk` npm package, which provides the `cdk` command) globally using the following npm command: 86 | 87 | ``` 88 | npm install -g aws-cdk@latest 89 | ``` 90 | 91 | Run the following command to verify the installation and print the version number of the AWS CDK: 92 | 93 | ``` 94 | cdk --version 95 | ``` 96 | 97 | Make sure you have Docker installed on your local machine.
Run the following command to verify the version: 98 | 99 | ``` 100 | docker --version 101 | ``` 102 | 103 | 104 | 105 | 106 | 107 | ## Clone and set up the AWS CDK application 108 | 109 | On your local machine, clone the AWS CDK application with the following command: 110 | 111 | ``` 112 | git clone https://github.com/aws-samples/generative-ai-sagemaker-cdk-demo.git 113 | ``` 114 | 115 | Navigate to the project folder: 116 | 117 | ``` 118 | cd generative-ai-sagemaker-cdk-demo 119 | ``` 120 | 121 | Before we deploy the application, let's review the directory structure: 122 | 123 | ```shell 124 | . 125 | ├── LICENSE 126 | ├── README.md 127 | ├── app.py 128 | ├── cdk.json 129 | ├── code 130 | │ ├── lambda_txt2img 131 | │ │ └── txt2img.py 132 | │ └── lambda_txt2nlu 133 | │ └── txt2nlu.py 134 | ├── construct 135 | │ └── sagemaker_endpoint_construct.py 136 | ├── images 137 | │ ├── architecture.png 138 | │ ├── ... 139 | ├── requirements-dev.txt 140 | ├── requirements.txt 141 | ├── source.bat 142 | ├── stack 143 | │ ├── __init__.py 144 | │ ├── generative_ai_demo_web_stack.py 145 | │ ├── generative_ai_txt2img_sagemaker_stack.py 146 | │ ├── generative_ai_txt2nlu_sagemaker_stack.py 147 | │ └── generative_ai_vpc_network_stack.py 148 | ├── tests 149 | │ ├── __init__.py 150 | │ └── ... 151 | └── web-app 152 | ├── Dockerfile 153 | ├── Home.py 154 | ├── configs.py 155 | ├── img 156 | │ └── sagemaker.png 157 | ├── pages 158 | │ ├── 2_Image_Generation.py 159 | │ └── 3_Text_Generation.py 160 | └── requirements.txt 161 | ``` 162 | 163 | 164 | 165 | The `stack` folder contains the code for each stack in the AWS CDK application. The `code` folder contains the code for the AWS Lambda functions. The `web-app` folder contains the Streamlit web application. 166 | 167 | The `cdk.json` file tells the AWS CDK Toolkit how to run your application.
 168 | 169 | This application was tested in the `us-east-1` Region, but it should work in any Region that offers the required services and the inference instance types specified in [app.py](app.py): `ml.p3.2xlarge` for image generation and `ml.g4dn.4xlarge` for text generation. 170 | 171 | 172 | 173 | #### Set up a virtual environment 174 | 175 | This project is set up like a standard Python project. Create a Python virtual environment using the following command: 176 | 177 | ``` 178 | python3 -m venv .venv 179 | ``` 180 | 181 | Use the following command to activate the virtual environment: 182 | 183 | ``` 184 | source .venv/bin/activate 185 | ``` 186 | 187 | If you're on a Windows platform, activate the virtual environment as follows: 188 | 189 | ``` 190 | .venv\Scripts\activate.bat 191 | ``` 192 | 193 | After the virtual environment is activated, upgrade pip to the latest version: 194 | 195 | ``` 196 | python3 -m pip install --upgrade pip 197 | ``` 198 | 199 | Install the required dependencies: 200 | 201 | ``` 202 | pip install -r requirements.txt 203 | ``` 204 | 205 | Before you deploy any AWS CDK application, you need to bootstrap the account and Region you're deploying into. Bootstrapping provisions the resources that the AWS CDK uses during deployment, such as an S3 bucket for assets. To bootstrap in your default Region, issue the following command: 206 | 207 | ``` 208 | cdk bootstrap 209 | ``` 210 | 211 | If you want to deploy into a specific account and Region, issue the following command: 212 | 213 | ``` 214 | cdk bootstrap aws://ACCOUNT-NUMBER/REGION 215 | ``` 216 | 217 | For more information about this setup, visit [Getting started with the AWS CDK](https://docs.aws.amazon.com/cdk/latest/guide/getting_started.html). 218 | 219 | 220 | 221 | #### AWS CDK application stack structure 222 | 223 | The AWS CDK application contains multiple stacks as shown in the following diagram.
 224 | 225 | ![cdk-stacks](./images/cdk-stacks.png) 226 | 227 | You can list stacks in your CDK application with the following command: 228 | 229 | ```bash 230 | cdk list 231 | ``` 232 | You should get the following output: 233 | 234 | ``` 235 | GenerativeAiTxt2imgSagemakerStack 236 | GenerativeAiTxt2nluSagemakerStack 237 | GenerativeAiVpcNetworkStack 238 | GenerativeAiDemoWebStack 239 | ``` 240 | 241 | 242 | 243 | Other useful AWS CDK commands: 244 | 245 | * `cdk ls` - Lists all stacks in the app (short form of `cdk list`) 246 | * `cdk synth` - Emits the synthesized AWS CloudFormation template 247 | * `cdk deploy` - Deploys the specified stack to your default AWS account and Region 248 | * `cdk diff` - Compares the specified stack with the deployed stack 249 | * `cdk docs` - Opens the AWS CDK documentation 250 | 251 | The next section shows you how to deploy the AWS CDK application. 252 | 253 | 254 | 255 | ## Deploy the AWS CDK application 256 | 257 | The AWS CDK application is deployed to the default Region of your workstation configuration. To deploy to a specific Region, set the `AWS_DEFAULT_REGION` environment variable accordingly. 258 | 259 | 260 | 261 | At this point, you can deploy the AWS CDK application. First, launch the VPC network stack: 262 | 263 | ``` 264 | cdk deploy GenerativeAiVpcNetworkStack 265 | ``` 266 | 267 | If you are prompted, enter `y` to proceed with the deployment. You should see a list of AWS resources that are being provisioned in the stack. This step takes around 3 minutes to complete. 268 | 269 | 270 | 271 | Then launch the web application stack: 272 | 273 | ``` 274 | cdk deploy GenerativeAiDemoWebStack 275 | ``` 276 | 277 | After analyzing the stack, the AWS CDK displays the list of resources in the stack. Enter `y` to proceed with the deployment. This step takes around 5 minutes. 278 | 279 | ![cdk-deploy](./images/cdk-deploy.png) 280 | 281 | Note the `WebApplicationServiceURL` value in the output, because you will use it later.
You can also retrieve it later in the CloudFormation console, under the `GenerativeAiDemoWebStack` stack outputs. 282 | 283 | 284 | 285 | Now, launch the image generation AI model endpoint stack: 286 | 287 | ``` 288 | cdk deploy GenerativeAiTxt2imgSagemakerStack 289 | ``` 290 | 291 | This step takes around 8 minutes. After the image generation model endpoint is deployed, you can start using it. 292 | 293 | 294 | 295 | ## Use the image generation AI model 296 | 297 | The first example demonstrates how to use Stable Diffusion, a text-to-image diffusion model that generates high-quality images from text prompts. 298 | 299 | 1. In your browser, open the `WebApplicationServiceURL` from the output of the `GenerativeAiDemoWebStack` to access the web application. 300 | 301 | ![streamlit-01](./images/streamlit-landing-page.png) 302 | 303 | 2. In the navigation pane, choose **Image Generation**. 304 | 305 | 3. The **SageMaker Endpoint Name** and **API GW Url** fields are pre-populated, but you can change the prompt for the image description if you'd like. 306 | 4. Choose **Generate image**. 307 | 308 | ![streamlit-03](./images/streamlit-image-gen-01.png) 309 | 310 | The application calls the SageMaker endpoint, which takes a few seconds, and then displays an image with the characteristics in your image description. 311 | 312 | ![streamlit-04](./images/streamlit-image-gen-02.png) 313 | 314 | 315 | 316 | ## Use the text generation AI model 317 | 318 | The second example uses the FLAN-T5-XL model, a foundation model and large language model (LLM), to perform in-context learning for text generation and to address a broad range of natural language understanding (NLU) and natural language generation (NLG) tasks. 319 | 320 | Your account's service quotas might limit the number of SageMaker endpoint instances you can run at a time. If this is the case, launch one SageMaker endpoint at a time.
To stop a SageMaker endpoint in the AWS CDK app, destroy its endpoint stack before launching the other endpoint stack. To shut down the image generation AI model endpoint, issue the following command: 321 | 322 | ``` 323 | cdk destroy GenerativeAiTxt2imgSagemakerStack 324 | ``` 325 | 326 | 327 | 328 | Then launch the text generation AI model endpoint stack: 329 | 330 | ``` 331 | cdk deploy GenerativeAiTxt2nluSagemakerStack 332 | ``` 333 | 334 | Enter `y` at the prompts. 335 | 336 | 337 | 338 | After the text generation model endpoint stack is launched, complete the following steps: 339 | 340 | 1. Go back to the web application and choose **Text Generation** in the navigation pane. 341 | 2. The **Input Context** field is pre-populated with a conversation between a customer and an agent regarding an issue with the customer's phone, but you can enter your own context if you'd like. 342 | 343 | ![streamlit-05](./images/streamlit-text-gen-01.png) 344 | 345 | Below the context, the dropdown menu offers some pre-populated queries. 346 | 347 | 3. Choose a query and choose **Generate Response**. 348 | 349 | ![streamlit-06](./images/streamlit-text-gen-02.png) 350 | 351 | You can also enter your own query in the **Input Query** field and choose **Generate Response**. 352 | 353 | ![streamlit-07](./images/streamlit-text-gen-03.png) 354 | 355 | 356 | 357 | ## View the deployed resources on the console 358 | 359 | On the AWS CloudFormation console, choose **Stacks** in the navigation pane to view the deployed stacks. 360 | 361 | ![console-cloudformation](./images/console-cloudformation.png) 362 | 363 | 364 | 365 | On the Amazon ECS console, you can see the clusters on the **Clusters** page. 366 | 367 | ![console-ecs](./images/console-ecs.png) 368 | 369 | 370 | 371 | On the AWS Lambda console, you can see the functions on the **Functions** page.
 372 | 373 | ![console-lambda](./images/console-lambda.png) 374 | 375 | 376 | 377 | On the API Gateway console, you can see the API Gateway endpoints on the **APIs** page. 378 | 379 | ![console-apigw](./images/console-apigw.png) 380 | 381 | 382 | 383 | On the SageMaker console, you can see the deployed model endpoints on the **Endpoints** page. 384 | 385 | ![console-sagemaker](./images/console-sagemaker.png) 386 | 387 | 388 | 389 | When the stacks are launched, they generate some parameters, such as the endpoint names, and store them in the [AWS Systems Manager Parameter Store](https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html). To view them, choose **Parameter Store** in the navigation pane on the [AWS Systems Manager](https://aws.amazon.com/systems-manager/) console. 390 | 391 | ![console-ssm-parameter-store](./images/console-ssm-parameter-store.png) 392 | 393 | 394 | 395 | ## Clean up 396 | 397 | To avoid unnecessary cost, clean up all the infrastructure created with the following command on your workstation: 398 | 399 | ``` 400 | cdk destroy --all 401 | ``` 402 | 403 | Enter `y` at the prompt. This step takes around 10 minutes. Verify on the console that all resources are deleted. Also delete the asset S3 buckets created by the AWS CDK on the Amazon S3 console, as well as the asset repositories in Amazon ECR. 404 | 405 | 406 | 407 | ## Conclusion 408 | 409 | As demonstrated in this post, you can use the AWS CDK to deploy generative AI models from JumpStart. We showed an image generation example and a text generation example using a user interface powered by Streamlit, Lambda, and API Gateway. 410 | 411 | You can now build your generative AI projects using pre-trained AI models in JumpStart. You can also extend this project to fine-tune the foundation models for your use case and control access to API Gateway endpoints.
 412 | 413 | We invite you to test the solution and contribute to the project on [GitHub](https://github.com/aws-samples/generative-ai-sagemaker-cdk-demo). 414 | 415 | 416 | 417 | ## License summary 418 | 419 | This sample code is made available under the MIT No Attribution (MIT-0) license. See the [LICENSE](https://github.com/Hantzley/generative-ai-sagemaker-cdk-demo/blob/main/LICENSE) file for more information. Also, review the respective licenses for the [Stable Diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) and [FLAN-T5-XL](https://huggingface.co/google/flan-t5-xl) models on Hugging Face. 420 | 421 | 422 | 423 | ## About the Authors 424 | 425 | **Hantzley Tauckoor** is an APJ Partner Solutions Architecture Leader based in Singapore. He has 20 years' experience in the ICT industry spanning multiple functional areas, including solutions architecture, business development, sales strategy, consulting, and leadership. He leads a team of Senior Solutions Architects who enable partners to develop joint solutions, build technical capabilities, and steer them through the implementation phase as customers migrate and modernize their applications to AWS. Outside of work, he enjoys spending time with his family, watching movies, and hiking. 426 | 427 | 428 | 429 | **Kwonyul Choi** is a CTO at BABITALK, a Korean beauty care platform startup, based in Seoul. Prior to this role, Kwonyul worked as a Software Development Engineer at AWS with a focus on AWS CDK and Amazon SageMaker. 430 | 431 | 432 | 433 | **Arunprasath Shankar** is a Senior AI/ML Specialist Solutions Architect with AWS, helping global customers scale their AI solutions effectively and efficiently in the cloud. In his spare time, Arun enjoys watching sci-fi movies and listening to classical music. 434 | 435 | 436 | 437 | **Satish Upreti** is a Migration Lead PSA and Security SME in the partner organization in APJ.
Satish has 20 years of experience spanning on-premises private cloud and public cloud technologies. Since joining AWS in August 2020 as a migration specialist, he has provided extensive technical advice and support to AWS partners to plan and implement complex migrations. 438 | 439 | 440 | 441 | ## Appendix: Code walk-through 442 | 443 | In this section, we provide an overview of the code in this project. 444 | 445 | ### AWS CDK Application 446 | 447 | The main AWS CDK application is contained in the `app.py` file in the root directory. The project consists of multiple stacks; `app.py` first configures logging to suppress noisy messages and then imports the stacks: 448 | 449 | ```python 450 | #!/usr/bin/env python3 451 | import aws_cdk as cdk 452 | import logging 453 | import warnings 454 | 455 | # Configure logging levels to suppress unnecessary messages 456 | logging.getLogger('sagemaker.config').setLevel(logging.ERROR) 457 | logging.getLogger('botocore.credentials').setLevel(logging.ERROR) 458 | warnings.filterwarnings("ignore", message="Field name \"json\" in \"MonitoringDatasetFormat\" shadows an attribute in parent \"Base\"") 459 | 460 | from stack.generative_ai_vpc_network_stack import GenerativeAiVpcNetworkStack 461 | from stack.generative_ai_demo_web_stack import GenerativeAiDemoWebStack 462 | from stack.generative_ai_txt2nlu_sagemaker_stack import GenerativeAiTxt2nluSagemakerStack 463 | from stack.generative_ai_txt2img_sagemaker_stack import GenerativeAiTxt2imgSagemakerStack 464 | ``` 465 | 466 | We pin each generative AI model to a specific version and retrieve the related URIs from SageMaker: 467 | 468 | ```python 469 | # Text to Image model parameters 470 | TXT2IMG_MODEL_ID = "model-txt2img-stabilityai-stable-diffusion-v2-1-base" 471 | TXT2IMG_INFERENCE_INSTANCE_TYPE = "ml.p3.2xlarge" # Fallback to ml.g4dn.4xlarge if not supported 472 | TXT2IMG_MODEL_TASK_TYPE = "txt2img" 473 | TXT2IMG_MODEL_VERSION = "2.0.9" 474 | 475 | # Text to NLU model parameters 476 | TXT2NLU_MODEL_ID = 
"huggingface-text2text-flan-t5-xl" 477 | TXT2NLU_INFERENCE_INSTANCE_TYPE = "ml.g4dn.4xlarge" 478 | TXT2NLU_MODEL_TASK_TYPE = "text2text" 479 | TXT2NLU_MODEL_VERSION = "2.2.2" 480 | ``` 481 | 482 | Then, we instantiate the stacks with proper dependencies: 483 | 484 | ```python 485 | app = cdk.App() 486 | 487 | network_stack = GenerativeAiVpcNetworkStack(app, "GenerativeAiVpcNetworkStack", env=env) 488 | GenerativeAiDemoWebStack(app, "GenerativeAiDemoWebStack", vpc=network_stack.vpc, env=env) 489 | 490 | GenerativeAiTxt2nluSagemakerStack(app, "GenerativeAiTxt2nluSagemakerStack", env=env, model_info=TXT2NLU_MODEL_INFO) 491 | GenerativeAiTxt2imgSagemakerStack(app, "GenerativeAiTxt2imgSagemakerStack", env=env, model_info=TXT2IMG_MODEL_INFO) 492 | 493 | app.synth() 494 | ``` 495 | 496 | ### VPC Network Stack 497 | 498 | In the `GenerativeAiVpcNetworkStack` stack, we create a VPC with public and private subnets across two Availability Zones (AZs): 499 | 500 | ```python 501 | self.output_vpc = ec2.Vpc(self, "VPC", 502 | nat_gateways=1, 503 | ip_addresses=ec2.IpAddresses.cidr("10.0.0.0/16"), 504 | max_azs=2, 505 | subnet_configuration=[ 506 | ec2.SubnetConfiguration(name="public", subnet_type=ec2.SubnetType.PUBLIC, cidr_mask=24), 507 | ec2.SubnetConfiguration(name="private", subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS, cidr_mask=24) 508 | ] 509 | ) 510 | ``` 511 | 512 | ### Demo Web Application Stack 513 | 514 | The `GenerativeAiDemoWebStack` stack sets up Lambda functions, API Gateway endpoints, and ECS infrastructure: 515 | 516 | 1. **Lambda Functions and API Gateway**: 517 | - Two Lambda functions for image generation and text generation services 518 | - Each Lambda has specific IAM roles and VPC configurations 519 | - API Gateway endpoints for both services 520 | 521 | 2. 
**ECS Infrastructure**: 522 | - ECS cluster with auto-scaling capabilities 523 | - Spot instance configuration for cost optimization 524 | - Launch template with specific instance type (c5.xlarge) 525 | - Auto Scaling Group with capacity provider 526 | 527 | 3. **Fargate Service**: 528 | - Application Load Balancer configuration 529 | - Task auto-scaling based on CPU utilization 530 | - IAM permissions for SSM and API Gateway access 531 | 532 | ### SageMaker Endpoint Stacks 533 | 534 | Both SageMaker endpoint stacks (`GenerativeAiTxt2imgSagemakerStack` and `GenerativeAiTxt2nluSagemakerStack`) follow a similar pattern: 535 | 536 | 1. **IAM Configuration**: 537 | - SageMaker service role with necessary permissions 538 | - STS, CloudWatch Logs, and ECR policies 539 | - S3 full access for model artifacts 540 | 541 | 2. **Endpoint Configuration**: 542 | - Model-specific environment variables 543 | - Instance type and count configuration 544 | - Model artifact location and container image settings 545 | 546 | 3. **Parameter Store**: 547 | - Endpoint names stored in SSM Parameter Store for web application access 548 | 549 | ### Web Application 550 | 551 | The web application is containerized and hosted on Amazon ECS with Fargate. The Dockerfile in the `web-app` directory contains the necessary configuration: 552 | 553 | ```dockerfile 554 | FROM --platform=linux/x86_64 python:3.9 555 | EXPOSE 8501 556 | WORKDIR /app 557 | COPY requirements.txt ./requirements.txt 558 | RUN pip3 install -r requirements.txt 559 | COPY . . 560 | CMD streamlit run Home.py \ 561 | --server.headless true \ 562 | --browser.serverAddress="0.0.0.0" \ 563 | --server.enableCORS false \ 564 | --browser.gatherUsageStats false 565 | ``` 566 | 567 | The application uses Streamlit for the user interface and interacts with the SageMaker endpoints through API Gateway. 
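The request path described in this appendix (Streamlit, then API Gateway, then Lambda, then SageMaker) can be exercised from any HTTP client. The Lambda functions read a JSON body with `prompt` and `endpoint_name` keys, and the text-generation function returns `generated_text` and `was_truncated`. The following is a minimal sketch using only the Python standard library. The helper names `build_txt2nlu_request` and `parse_txt2nlu_response` are hypothetical and not part of this project, and the URL is a placeholder for the value stored in the `txt2nlu_api_endpoint` SSM parameter.

```python
import json
import urllib.request


def build_txt2nlu_request(api_url: str, prompt: str, endpoint_name: str) -> urllib.request.Request:
    """Build (but do not send) a POST request for the text-generation API.

    The body mirrors the keys the txt2nlu Lambda reads from event['body'].
    """
    body = json.dumps({"prompt": prompt, "endpoint_name": endpoint_name}).encode("utf-8")
    return urllib.request.Request(
        api_url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def parse_txt2nlu_response(raw_body: bytes) -> dict:
    """Extract the fields returned by the txt2nlu Lambda."""
    message = json.loads(raw_body)
    return {
        "generated_text": message["generated_text"],
        "was_truncated": message.get("was_truncated", False),
    }


if __name__ == "__main__":
    # Placeholder URL: replace it with the value of the txt2nlu_api_endpoint SSM parameter.
    request = build_txt2nlu_request(
        "https://example.execute-api.us-east-1.amazonaws.com/prod/",
        "Write a haiku about the ocean.",
        "GenerativeAiDemo-HuggingfaceText2TextFlan-Endpoint",
    )
    # Sending the request requires a deployed stack:
    # with urllib.request.urlopen(request, timeout=180) as response:
    #     print(parse_txt2nlu_response(response.read()))
    print(request.get_method(), request.full_url)
```

The endpoint name in the example follows the `{project_prefix}-{model_name}-Endpoint` pattern used by `SageMakerEndpointConstruct`; in practice you would read it from the `txt2nlu_sm_endpoint` parameter instead of hard-coding it.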
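To make the layout from the VPC Network Stack section concrete, the sketch below carves `10.0.0.0/16` into /24 blocks for two subnet types across two AZs using Python's `ipaddress` module, and subtracts the five addresses that AWS reserves in every subnet. It assumes sequential allocation (all public subnets first, then all private subnets); this is a simplification, and the exact CIDRs that CDK assigns should be checked in the synthesized template. The function names are hypothetical.

```python
import ipaddress
from typing import List, Tuple

# AWS reserves the first four addresses and the last address of every subnet.
AWS_RESERVED_PER_SUBNET = 5


def plan_subnets(vpc_cidr: str, cidr_mask: int, subnet_types: List[str],
                 az_count: int) -> List[Tuple[str, int, str]]:
    """Assign consecutive /cidr_mask blocks, grouped by subnet type and then by AZ."""
    blocks = ipaddress.ip_network(vpc_cidr).subnets(new_prefix=cidr_mask)
    plan = []
    for subnet_type in subnet_types:
        for az_index in range(az_count):
            plan.append((subnet_type, az_index, str(next(blocks))))
    return plan


def usable_hosts(cidr: str) -> int:
    """Number of IPv4 addresses left for resources in an AWS subnet."""
    return ipaddress.ip_network(cidr).num_addresses - AWS_RESERVED_PER_SUBNET


if __name__ == "__main__":
    for subnet_type, az, cidr in plan_subnets("10.0.0.0/16", 24, ["public", "private"], 2):
        print(f"{subnet_type:<8} AZ{az}  {cidr:<14} usable={usable_hosts(cidr)}")
```

With two /24 subnets per AZ, the VPC uses only 4 of the 256 available /24 blocks, which leaves plenty of room to add subnets later.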
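The Fargate service in the Demo Web Application Stack scales its task count on average CPU utilization with a 50% target and a maximum of 10 tasks. Target tracking adjusts capacity roughly in proportion to how far the metric is from the target. The function below is a simplified, hypothetical model of that behavior for back-of-the-envelope estimates only: it is not the exact Application Auto Scaling algorithm, it ignores the 60-second cooldowns, and the minimum of 1 task assumes the CDK default for `auto_scale_task_count`.

```python
import math


def estimate_task_count(current_tasks: int, avg_cpu_percent: float,
                        target_percent: float = 50.0,
                        min_tasks: int = 1, max_tasks: int = 10) -> int:
    """Proportional estimate of the task count that brings average CPU back to the target."""
    if current_tasks < 1:
        raise ValueError("current_tasks must be at least 1")
    # Assume total CPU demand stays constant and spread it so each task sits near the target.
    desired = math.ceil(current_tasks * avg_cpu_percent / target_percent)
    # Clamp to the scalable target's bounds.
    return max(min_tasks, min(max_tasks, desired))


if __name__ == "__main__":
    for tasks, cpu in [(1, 90.0), (2, 75.0), (4, 20.0), (8, 95.0)]:
        print(f"{tasks} task(s) at {cpu:.0f}% CPU -> about {estimate_task_count(tasks, cpu)} task(s)")
```

For example, eight tasks running at 95% CPU would call for roughly 16 tasks, but the estimate is capped at the configured maximum of 10.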
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
import aws_cdk as cdk
import logging
import warnings

# Set SageMaker SDK logging level to ERROR to suppress INFO messages
logging.getLogger('sagemaker.config').setLevel(logging.ERROR)
# Set AWS credentials logging level to ERROR to suppress INFO messages
logging.getLogger('botocore.credentials').setLevel(logging.ERROR)
# Suppress Pydantic field shadowing warning
warnings.filterwarnings("ignore", message="Field name \"json\" in \"MonitoringDatasetFormat\" shadows an attribute in parent \"Base\"")

from stack.generative_ai_vpc_network_stack import GenerativeAiVpcNetworkStack
from stack.generative_ai_demo_web_stack import GenerativeAiDemoWebStack
from stack.generative_ai_txt2nlu_sagemaker_stack import GenerativeAiTxt2nluSagemakerStack
from stack.generative_ai_txt2img_sagemaker_stack import GenerativeAiTxt2imgSagemakerStack

from script.sagemaker_uri import get_sagemaker_uris
import boto3

region_name = boto3.Session().region_name
env = {"region": region_name}

# Text-to-image model parameters
TXT2IMG_MODEL_ID = "model-txt2img-stabilityai-stable-diffusion-v2-1-base"
TXT2IMG_INFERENCE_INSTANCE_TYPE = "ml.p3.2xlarge"  # If your Region does not support this instance type, try ml.g4dn.4xlarge
TXT2IMG_MODEL_TASK_TYPE = "txt2img"
TXT2IMG_MODEL_VERSION = "2.0.9"
TXT2IMG_MODEL_INFO = get_sagemaker_uris(model_id=TXT2IMG_MODEL_ID,
                                        model_task_type=TXT2IMG_MODEL_TASK_TYPE,
                                        instance_type=TXT2IMG_INFERENCE_INSTANCE_TYPE,
                                        model_version=TXT2IMG_MODEL_VERSION,
                                        region_name=region_name)

# Text-to-text (NLU) model parameters
TXT2NLU_MODEL_ID = "huggingface-text2text-flan-t5-xl"
TXT2NLU_INFERENCE_INSTANCE_TYPE = "ml.g4dn.4xlarge"
TXT2NLU_MODEL_TASK_TYPE = "text2text"
TXT2NLU_MODEL_VERSION = "2.2.2"
TXT2NLU_MODEL_INFO = get_sagemaker_uris(model_id=TXT2NLU_MODEL_ID,
                                        model_task_type=TXT2NLU_MODEL_TASK_TYPE,
                                        instance_type=TXT2NLU_INFERENCE_INSTANCE_TYPE,
                                        model_version=TXT2NLU_MODEL_VERSION,
                                        region_name=region_name)

app = cdk.App()

network_stack = GenerativeAiVpcNetworkStack(app, "GenerativeAiVpcNetworkStack", env=env)
GenerativeAiDemoWebStack(app, "GenerativeAiDemoWebStack", vpc=network_stack.vpc, env=env)

GenerativeAiTxt2nluSagemakerStack(app, "GenerativeAiTxt2nluSagemakerStack", env=env, model_info=TXT2NLU_MODEL_INFO)
GenerativeAiTxt2imgSagemakerStack(app, "GenerativeAiTxt2imgSagemakerStack", env=env, model_info=TXT2IMG_MODEL_INFO)

app.synth()

--------------------------------------------------------------------------------
/cdk.json:
--------------------------------------------------------------------------------
{
  "app": "python3 app.py",
  "watch": {
    "include": [
      "**"
    ],
    "exclude": [
      "README.md",
      "cdk*.json",
      "requirements*.txt",
      "source.bat",
      "**/__init__.py",
      "python/__pycache__",
      "tests"
    ]
  },
  "context": {
    "@aws-cdk/aws-lambda:recognizeLayerVersion": true,
    "@aws-cdk/core:checkSecretUsage": true,
    "@aws-cdk/core:target-partitions": [
      "aws",
      "aws-cn"
    ],
    "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true,
    "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true,
    "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true,
    "@aws-cdk/aws-iam:minimizePolicies": true,
    "@aws-cdk/core:validateSnapshotRemovalPolicy": true,
    "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true,
    "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true,
    "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true,
    "@aws-cdk/aws-apigateway:disableCloudWatchRole": true,
    "@aws-cdk/core:enablePartitionLiterals": true,
    "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true,
    "@aws-cdk/aws-iam:standardizedServicePrincipals": true,
    "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true,
    "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true,
    "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true,
    "@aws-cdk/aws-route53-patters:useCertificate": true,
    "@aws-cdk/customresources:installLatestAwsSdkDefault": false,
    "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true,
    "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true,
    "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true,
    "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true,
    "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true,
    "@aws-cdk/aws-redshift:columnId": true,
    "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true
  }
}

--------------------------------------------------------------------------------
/code/lambda_txt2img/txt2img.py:
--------------------------------------------------------------------------------
import json
import boto3

runtime = boto3.client('runtime.sagemaker')


def lambda_handler(event, context):
    body = json.loads(event['body'])
    prompt = body['prompt']
    endpoint_name = body['endpoint_name']

    response = runtime.invoke_endpoint(EndpointName=endpoint_name,
                                       Body=prompt,
                                       ContentType='application/x-text')

    response_body = json.loads(response['Body'].read().decode())
    generated_image = response_body['generated_image']

    message = {"prompt": prompt, 'image': generated_image}

    return {
        "statusCode": 200,
        "body": json.dumps(message),
        "headers": {
            "Content-Type": "application/json"
        }
    }
--------------------------------------------------------------------------------
/code/lambda_txt2nlu/txt2nlu.py:
--------------------------------------------------------------------------------
import json
import boto3

runtime = boto3.client('runtime.sagemaker')

# Generation parameters sent to the FLAN-T5 endpoint
MAX_LENGTH = 512
NUM_RETURN_SEQUENCES = 1
TOP_K = 40
TOP_P = 0.8
DO_SAMPLE = True
MAX_CHARACTERS = 1700  # Maximum number of characters accepted from the prompt

def truncate_input(prompt, max_characters=MAX_CHARACTERS):
    # Truncate by character count, a rough stand-in for the model's token limit
    if len(prompt) > max_characters:
        return prompt[:max_characters]
    return prompt

def lambda_handler(event, context):
    body = json.loads(event['body'])
    prompt = body['prompt']
    endpoint_name = body['endpoint_name']

    # Truncate the prompt if it is longer than MAX_CHARACTERS
    truncated_prompt = truncate_input(prompt)

    payload = {
        "inputs": truncated_prompt,
        "parameters": {
            "max_length": MAX_LENGTH,
            "num_return_sequences": NUM_RETURN_SEQUENCES,
            "top_k": TOP_K,
            "top_p": TOP_P,
            "do_sample": DO_SAMPLE
        }
    }
    payload = json.dumps(payload).encode('utf-8')

    response = runtime.invoke_endpoint(EndpointName=endpoint_name,
                                       ContentType='application/json',
                                       Body=payload)

    model_predictions = json.loads(response['Body'].read())
    generated_text = model_predictions[0]['generated_text']

    message = {
        "prompt": truncated_prompt,
        "original_prompt": prompt,
        "was_truncated": prompt != truncated_prompt,
        'generated_text': generated_text
    }

    return {
        "statusCode": 200,
        "body": json.dumps(message),
        "headers": {
            "Content-Type": "application/json"
        }
    }
--------------------------------------------------------------------------------
/construct/sagemaker_endpoint_construct.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | aws_sagemaker as sagemaker, 3 | CfnOutput 4 | ) 5 | from constructs import Construct 6 | 7 | 8 | class SageMakerEndpointConstruct(Construct): 9 | 10 | def __init__(self, scope: Construct, construct_id: str, 11 | project_prefix: str, 12 | role_arn: str, 13 | model_name: str, 14 | model_bucket_name: str, 15 | model_bucket_key: str, 16 | model_docker_image: str, 17 | variant_name: str, 18 | variant_weight: int, 19 | instance_count: int, 20 | instance_type: str, 21 | environment: dict, 22 | deploy_enable: bool) -> None: 23 | super().__init__(scope, construct_id) 24 | 25 | model = sagemaker.CfnModel(self, f"{model_name}-Model", 26 | execution_role_arn= role_arn, 27 | containers=[ 28 | sagemaker.CfnModel.ContainerDefinitionProperty( 29 | image= model_docker_image, 30 | environment= environment, 31 | mode="SingleModel", 32 | 33 | model_data_source = sagemaker.CfnModel.ModelDataSourceProperty( 34 | s3_data_source = sagemaker.CfnModel.S3DataSourceProperty( 35 | compression_type="None", 36 | s3_data_type="S3Prefix", 37 | s3_uri=f"s3://{model_bucket_name}/{model_bucket_key}", 38 | ) 39 | ), 40 | ) 41 | ], 42 | model_name= f"{project_prefix}-{model_name}-Model", 43 | ) 44 | 45 | config = sagemaker.CfnEndpointConfig(self, f"{model_name}-Config", 46 | endpoint_config_name= f"{project_prefix}-{model_name}-Config", 47 | production_variants=[ 48 | sagemaker.CfnEndpointConfig.ProductionVariantProperty( 49 | model_name= model.attr_model_name, 50 | variant_name= variant_name, 51 | initial_variant_weight= variant_weight, 52 | initial_instance_count= instance_count, 53 | instance_type= instance_type 54 | ) 55 | ] 56 | ) 57 | 58 | self.deploy_enable = deploy_enable 59 | if deploy_enable: 60 | self.endpoint = sagemaker.CfnEndpoint(self, f"{model_name}-Endpoint", 61 | endpoint_name= f"{project_prefix}-{model_name}-Endpoint", 62 | 
endpoint_config_name= config.attr_endpoint_config_name 63 | ) 64 | 65 | CfnOutput(scope=self,id=f"{model_name}EndpointName", value=self.endpoint.endpoint_name) 66 | 67 | 68 | @property 69 | def endpoint_name(self) -> str: 70 | return self.endpoint.attr_endpoint_name if self.deploy_enable else "not_yet_deployed" 71 | -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/architecture.png -------------------------------------------------------------------------------- /images/cdk-deploy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/cdk-deploy.png -------------------------------------------------------------------------------- /images/cdk-stacks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/cdk-stacks.png -------------------------------------------------------------------------------- /images/console-apigw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/console-apigw.png -------------------------------------------------------------------------------- /images/console-cloudformation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/console-cloudformation.png 
-------------------------------------------------------------------------------- /images/console-ecs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/console-ecs.png -------------------------------------------------------------------------------- /images/console-lambda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/console-lambda.png -------------------------------------------------------------------------------- /images/console-sagemaker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/console-sagemaker.png -------------------------------------------------------------------------------- /images/console-ssm-parameter-store.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/console-ssm-parameter-store.png -------------------------------------------------------------------------------- /images/foundation-models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/foundation-models.png -------------------------------------------------------------------------------- /images/streamlit-image-gen-01.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/streamlit-image-gen-01.png -------------------------------------------------------------------------------- /images/streamlit-image-gen-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/streamlit-image-gen-02.png -------------------------------------------------------------------------------- /images/streamlit-landing-page.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/streamlit-landing-page.png -------------------------------------------------------------------------------- /images/streamlit-text-gen-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/streamlit-text-gen-01.png -------------------------------------------------------------------------------- /images/streamlit-text-gen-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/streamlit-text-gen-02.png -------------------------------------------------------------------------------- /images/streamlit-text-gen-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/images/streamlit-text-gen-03.png -------------------------------------------------------------------------------- 
/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest==6.2.5 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aws-cdk-lib==2.141.0 2 | constructs>=10.0.0,<11.0.0 3 | 4 | boto3 5 | sagemaker==2.243.3 6 | setuptools==70.2.0 7 | -------------------------------------------------------------------------------- /script/sagemaker_uri.py: -------------------------------------------------------------------------------- 1 | import sagemaker 2 | import boto3 3 | from sagemaker import script_uris 4 | from sagemaker import image_uris 5 | from sagemaker import model_uris 6 | from sagemaker.jumpstart.notebook_utils import list_jumpstart_models 7 | 8 | session = sagemaker.Session() 9 | 10 | def get_sagemaker_uris(model_id,model_task_type,instance_type,model_version,region_name): 11 | 12 | FILTER = f"task == {model_task_type}" 13 | #txt2img_models = list_jumpstart_models(filter=FILTER) 14 | 15 | MODEL_VERSION = model_version # latest = "*" 16 | SCOPE = "inference" 17 | 18 | inference_image_uri = image_uris.retrieve(region=region_name, 19 | framework=None, 20 | model_id=model_id, 21 | model_version=MODEL_VERSION, 22 | image_scope=SCOPE, 23 | instance_type=instance_type) 24 | 25 | inference_model_uri = model_uris.retrieve(model_id=model_id, 26 | model_version=MODEL_VERSION, 27 | model_scope=SCOPE) 28 | 29 | inference_source_uri = script_uris.retrieve(model_id=model_id, 30 | model_version=MODEL_VERSION, 31 | script_scope=SCOPE) 32 | 33 | model_bucket_name = inference_model_uri.split("/")[2] 34 | model_bucket_key = "/".join(inference_model_uri.split("/")[3:]) 35 | model_docker_image = inference_image_uri 36 | 37 | return {"model_bucket_name":model_bucket_name, "model_bucket_key": model_bucket_key, \ 38 | "model_docker_image":model_docker_image, "instance_type":instance_type, \ 39 | 
"inference_source_uri":inference_source_uri, "region_name":region_name} 40 | -------------------------------------------------------------------------------- /source.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | rem The sole purpose of this script is to make the command 4 | rem 5 | rem source .venv/bin/activate 6 | rem 7 | rem (which activates a Python virtualenv on Linux or Mac OS X) work on Windows. 8 | rem On Windows, this command just runs this batch file (the argument is ignored). 9 | rem 10 | rem Now we don't need to document a Windows command for activating a virtualenv. 11 | 12 | echo Executing .venv\Scripts\activate.bat for you 13 | .venv\Scripts\activate.bat 14 | -------------------------------------------------------------------------------- /stack/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/stack/__init__.py -------------------------------------------------------------------------------- /stack/generative_ai_demo_web_stack.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | Duration, 3 | Stack, 4 | aws_lambda as _lambda, 5 | aws_apigateway as apigw, 6 | aws_ec2 as ec2, 7 | aws_iam as iam, 8 | aws_ssm as ssm, 9 | aws_ecs as ecs, 10 | aws_ecs_patterns as ecs_patterns, 11 | aws_autoscaling as autoscaling, 12 | ) 13 | from constructs import Construct 14 | 15 | class GenerativeAiDemoWebStack(Stack): 16 | 17 | def __init__(self, scope: Construct, construct_id: str, vpc: ec2.IVpc, **kwargs) -> None: 18 | super().__init__(scope, construct_id, **kwargs) 19 | 20 | # Defines role for the AWS Lambda functions 21 | role = iam.Role(self, "Gen-AI-Lambda-Policy", assumed_by=iam.ServicePrincipal("lambda.amazonaws.com")) 22 | 
role.add_managed_policy(iam.ManagedPolicy.from_aws_managed_policy_name("service-role/AWSLambdaBasicExecutionRole")) 23 | role.add_managed_policy(iam.ManagedPolicy.from_aws_managed_policy_name("service-role/AWSLambdaVPCAccessExecutionRole")) 24 | role.attach_inline_policy(iam.Policy(self, "sm-invoke-policy", 25 | statements=[iam.PolicyStatement( 26 | effect=iam.Effect.ALLOW, 27 | actions=["sagemaker:InvokeEndpoint"], 28 | resources=["*"] 29 | )] 30 | )) 31 | 32 | # Defines an AWS Lambda function for Image Generation service 33 | lambda_txt2img = _lambda.Function( 34 | self, "lambda_txt2img", 35 | runtime=_lambda.Runtime.PYTHON_3_9, 36 | code=_lambda.Code.from_asset("code/lambda_txt2img"), 37 | handler="txt2img.lambda_handler", 38 | role=role, 39 | timeout=Duration.seconds(180), 40 | memory_size=512, 41 | vpc_subnets=ec2.SubnetSelection( 42 | subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS 43 | ), 44 | vpc=vpc 45 | ) 46 | 47 | # Defines an Amazon API Gateway endpoint for Image Generation service 48 | txt2img_apigw_endpoint = apigw.LambdaRestApi( 49 | self, "txt2img_apigw_endpoint", 50 | handler=lambda_txt2img 51 | ) 52 | 53 | # Defines an AWS Lambda function for NLU & Text Generation service 54 | lambda_txt2nlu = _lambda.Function( 55 | self, "lambda_txt2nlu", 56 | runtime=_lambda.Runtime.PYTHON_3_9, 57 | code=_lambda.Code.from_asset("code/lambda_txt2nlu"), 58 | handler="txt2nlu.lambda_handler", 59 | role=role, 60 | timeout=Duration.seconds(180), 61 | memory_size=512, 62 | vpc_subnets=ec2.SubnetSelection( 63 | subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS 64 | ), 65 | vpc=vpc 66 | ) 67 | 68 | # Defines an Amazon API Gateway endpoint for NLU & Text Generation service 69 | txt2nlu_apigw_endpoint = apigw.LambdaRestApi( 70 | self, "txt2nlu_apigw_endpoint", 71 | handler=lambda_txt2nlu 72 | ) 73 | 74 | # Create ECS cluster 75 | cluster = ecs.Cluster(self, "WebDemoCluster", vpc=vpc) 76 | 77 | # Create an IAM role for the EC2 instances 78 | instance_role = iam.Role( 79 | 
self, "EcsInstanceRole", 80 | assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"), 81 | managed_policies=[ 82 | iam.ManagedPolicy.from_aws_managed_policy_name("service-role/AmazonEC2ContainerServiceforEC2Role"), 83 | iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMManagedInstanceCore") # Optional: for SSM agent 84 | ] 85 | ) 86 | 87 | # Create a launch template 88 | user_data = ec2.UserData.for_linux() 89 | user_data.add_commands( 90 | f"echo ECS_CLUSTER={cluster.cluster_name} >> /etc/ecs/ecs.config", 91 | "sudo iptables --insert FORWARD 1 --in-interface docker+ --destination 169.254.169.254/32 --jump DROP", 92 | "sudo service iptables save", 93 | "echo ECS_AWSVPC_BLOCK_IMDS=true >> /etc/ecs/ecs.config" 94 | ) 95 | 96 | # Create the launch template 97 | launch_template = ec2.LaunchTemplate( 98 | self, "EcsSpotLaunchTemplate", 99 | launch_template_name="EcsSpotLaunchTemplate", 100 | instance_type=ec2.InstanceType("c5.xlarge"), 101 | machine_image=ecs.EcsOptimizedImage.amazon_linux2(), 102 | user_data=user_data, 103 | role=instance_role, # Assign the role to the launch template 104 | block_devices=[ 105 | ec2.BlockDevice( 106 | device_name="/dev/xvda", 107 | volume=ec2.BlockDeviceVolume.ebs(30) # Increased from 20 to 30 GB to match the snapshot size requirements 108 | ) 109 | ], 110 | # Add Spot options with ONE_TIME request type (required for AutoScaling) 111 | spot_options=ec2.LaunchTemplateSpotOptions( 112 | max_price=0.0735, # Set max price for Spot Instances 113 | request_type=ec2.SpotRequestType.ONE_TIME # Changed from PERSISTENT to ONE_TIME as required by AutoScaling 114 | ) 115 | ) 116 | 117 | # Create the Auto Scaling Group with the launch template 118 | asg = autoscaling.AutoScalingGroup( 119 | self, "AsgSpotNew", # Changed ID to ensure a new resource is created 120 | vpc=vpc, 121 | min_capacity=1, 122 | max_capacity=2, 123 | #desired_capacity=2, 124 | launch_template=launch_template, # Use the launch template instead of instance properties 125 | 
) 126 | 127 | # Add the ASG capacity to the ECS cluster 128 | capacity_provider = ecs.AsgCapacityProvider( 129 | self, "AsgCapacityProvider", 130 | auto_scaling_group=asg, 131 | enable_managed_termination_protection=False, 132 | spot_instance_draining=True 133 | ) 134 | cluster.add_asg_capacity_provider(capacity_provider) 135 | 136 | # Build Dockerfile from local folder and push to ECR 137 | image = ecs.ContainerImage.from_asset("web-app") 138 | 139 | # Create Fargate service 140 | fargate_service = ecs_patterns.ApplicationLoadBalancedFargateService( 141 | self, "WebApplication", 142 | cluster=cluster, # Required 143 | cpu=2048, # Default is 256 (512 is 0.5 vCPU, 2048 is 2 vCPU) 144 | desired_count=1, # Default is 1 145 | task_image_options=ecs_patterns.ApplicationLoadBalancedTaskImageOptions( 146 | image=image, 147 | container_port=8501, 148 | ), 149 | #load_balancer_name="gen-ai-demo", 150 | memory_limit_mib=4096, # Default is 512 151 | public_load_balancer=True) # Default is True 152 | 153 | 154 | fargate_service.task_definition.add_to_task_role_policy(iam.PolicyStatement( 155 | effect=iam.Effect.ALLOW, 156 | actions = ["ssm:GetParameter"], 157 | resources = ["arn:aws:ssm:*"], 158 | ) 159 | ) 160 | 161 | fargate_service.task_definition.add_to_task_role_policy(iam.PolicyStatement( 162 | effect=iam.Effect.ALLOW, 163 | actions = ["execute-api:Invoke","execute-api:ManageConnections"], 164 | resources = ["*"], 165 | ) 166 | ) 167 | 168 | 169 | # Setup task auto-scaling 170 | scaling = fargate_service.service.auto_scale_task_count( 171 | max_capacity=10 172 | ) 173 | scaling.scale_on_cpu_utilization( 174 | "CpuScaling", 175 | target_utilization_percent=50, 176 | scale_in_cooldown=Duration.seconds(60), 177 | scale_out_cooldown=Duration.seconds(60), 178 | ) 179 | 180 | ssm.StringParameter(self, "txt2img_api_endpoint", parameter_name="txt2img_api_endpoint", string_value=txt2img_apigw_endpoint.url) 181 | ssm.StringParameter(self, "txt2nlu_api_endpoint", 
parameter_name="txt2nlu_api_endpoint", string_value=txt2nlu_apigw_endpoint.url) -------------------------------------------------------------------------------- /stack/generative_ai_txt2img_sagemaker_stack.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | Stack, 3 | aws_iam as iam, 4 | aws_ssm as ssm, 5 | ) 6 | 7 | from constructs import Construct 8 | 9 | from construct.sagemaker_endpoint_construct import SageMakerEndpointConstruct 10 | 11 | class GenerativeAiTxt2imgSagemakerStack(Stack): 12 | 13 | def __init__(self, scope: Construct, construct_id: str, model_info, **kwargs) -> None: 14 | super().__init__(scope, construct_id, **kwargs) 15 | 16 | role = iam.Role(self, "Gen-AI-SageMaker-Policy", assumed_by=iam.ServicePrincipal("sagemaker.amazonaws.com")) 17 | role.add_managed_policy(iam.ManagedPolicy.from_aws_managed_policy_name("AmazonS3FullAccess")) 18 | 19 | sts_policy = iam.Policy(self, "sm-deploy-policy-sts", 20 | statements=[iam.PolicyStatement( 21 | effect=iam.Effect.ALLOW, 22 | actions=[ 23 | "sts:AssumeRole" 24 | ], 25 | resources=["*"] 26 | )] 27 | ) 28 | 29 | logs_policy = iam.Policy(self, "sm-deploy-policy-logs", 30 | statements=[iam.PolicyStatement( 31 | effect=iam.Effect.ALLOW, 32 | actions=[ 33 | "cloudwatch:PutMetricData", 34 | "logs:CreateLogStream", 35 | "logs:PutLogEvents", 36 | "logs:CreateLogGroup", 37 | "logs:DescribeLogStreams", 38 | "ecr:GetAuthorizationToken" 39 | ], 40 | resources=["*"] 41 | )] 42 | ) 43 | 44 | ecr_policy = iam.Policy(self, "sm-deploy-policy-ecr", 45 | statements=[iam.PolicyStatement( 46 | effect=iam.Effect.ALLOW, 47 | actions=[ 48 | "ecr:*", 49 | ], 50 | resources=["*"] 51 | )] 52 | ) 53 | 54 | role.attach_inline_policy(sts_policy) 55 | role.attach_inline_policy(logs_policy) 56 | role.attach_inline_policy(ecr_policy) 57 | 58 | endpoint = SageMakerEndpointConstruct(self, "TXT2IMG", 59 | project_prefix = "GenerativeAiDemo", 60 | 61 | role_arn= role.role_arn, 62 
| 63 | model_name = "StableDiffusionText2Img", 64 | model_bucket_name = model_info["model_bucket_name"], 65 | model_bucket_key = model_info["model_bucket_key"], 66 | model_docker_image = model_info["model_docker_image"], 67 | 68 | variant_name = "AllTraffic", 69 | variant_weight = 1, 70 | instance_count = 1, 71 | instance_type = model_info["instance_type"], 72 | 73 | environment = { 74 | "MMS_MAX_RESPONSE_SIZE": "20000000", 75 | "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", 76 | "SAGEMAKER_PROGRAM": "inference.py", 77 | "SAGEMAKER_REGION": model_info["region_name"], 78 | "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", 79 | }, 80 | 81 | deploy_enable = True 82 | ) 83 | 84 | endpoint.node.add_dependency(sts_policy) 85 | endpoint.node.add_dependency(logs_policy) 86 | endpoint.node.add_dependency(ecr_policy) 87 | 88 | ssm.StringParameter(self, "txt2img_sm_endpoint", parameter_name="txt2img_sm_endpoint", string_value=endpoint.endpoint_name) 89 | -------------------------------------------------------------------------------- /stack/generative_ai_txt2nlu_sagemaker_stack.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | Stack, 3 | aws_iam as iam, 4 | aws_ssm as ssm, 5 | ) 6 | from constructs import Construct 7 | 8 | from construct.sagemaker_endpoint_construct import SageMakerEndpointConstruct 9 | 10 | class GenerativeAiTxt2nluSagemakerStack(Stack): 11 | 12 | def __init__(self, scope: Construct, construct_id: str, model_info, **kwargs) -> None: 13 | super().__init__(scope, construct_id, **kwargs) 14 | 15 | role = iam.Role(self, "Gen-AI-SageMaker-Policy", assumed_by=iam.ServicePrincipal("sagemaker.amazonaws.com")) 16 | role.add_managed_policy(iam.ManagedPolicy.from_aws_managed_policy_name("AmazonS3FullAccess")) 17 | 18 | sts_policy = iam.Policy(self, "sm-deploy-policy-sts", 19 | statements=[iam.PolicyStatement( 20 | effect=iam.Effect.ALLOW, 21 | actions=[ 22 | "sts:AssumeRole" 23 | ], 24 | resources=["*"] 25 | )] 
26 | ) 27 | 28 | logs_policy = iam.Policy(self, "sm-deploy-policy-logs", 29 | statements=[iam.PolicyStatement( 30 | effect=iam.Effect.ALLOW, 31 | actions=[ 32 | "cloudwatch:PutMetricData", 33 | "logs:CreateLogStream", 34 | "logs:PutLogEvents", 35 | "logs:CreateLogGroup", 36 | "logs:DescribeLogStreams", 37 | "ecr:GetAuthorizationToken" 38 | ], 39 | resources=["*"] 40 | )] 41 | ) 42 | 43 | ecr_policy = iam.Policy(self, "sm-deploy-policy-ecr", 44 | statements=[iam.PolicyStatement( 45 | effect=iam.Effect.ALLOW, 46 | actions=[ 47 | "ecr:*", 48 | ], 49 | resources=["*"] 50 | )] 51 | ) 52 | 53 | role.attach_inline_policy(sts_policy) 54 | role.attach_inline_policy(logs_policy) 55 | role.attach_inline_policy(ecr_policy) 56 | 57 | endpoint = SageMakerEndpointConstruct(self, "TXT2NLU", 58 | project_prefix = "GenerativeAiDemo", 59 | 60 | role_arn= role.role_arn, 61 | 62 | model_name = "HuggingfaceText2TextFlan", 63 | model_bucket_name = model_info["model_bucket_name"], 64 | model_bucket_key = model_info["model_bucket_key"], 65 | model_docker_image = model_info["model_docker_image"], 66 | 67 | variant_name = "AllTraffic", 68 | variant_weight = 1, 69 | instance_count = 1, 70 | instance_type = model_info["instance_type"], 71 | 72 | environment = { 73 | "MODEL_CACHE_ROOT": "/opt/ml/model", 74 | "HF_MODEL_ID": "/opt/ml/model", 75 | "SAGEMAKER_ENV": "1", 76 | "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", 77 | "SAGEMAKER_MODEL_SERVER_WORKERS": "1", 78 | "SAGEMAKER_PROGRAM": "inference.py", 79 | "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code/", 80 | "TS_DEFAULT_WORKERS_PER_MODEL": "1" 81 | }, 82 | 83 | deploy_enable = True 84 | ) 85 | 86 | endpoint.node.add_dependency(sts_policy) 87 | endpoint.node.add_dependency(logs_policy) 88 | endpoint.node.add_dependency(ecr_policy) 89 | 90 | ssm.StringParameter(self, "txt2nlu_sm_endpoint", parameter_name="txt2nlu_sm_endpoint", string_value=endpoint.endpoint_name) 91 | 
-------------------------------------------------------------------------------- /stack/generative_ai_vpc_network_stack.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | Stack, 3 | aws_ec2 as ec2 4 | ) 5 | from constructs import Construct 6 | 7 | 8 | class GenerativeAiVpcNetworkStack(Stack): 9 | 10 | def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None: 11 | super().__init__(scope, construct_id, **kwargs) 12 | 13 | self.output_vpc = ec2.Vpc(self, "VPC", 14 | nat_gateways=1, 15 | ip_addresses=ec2.IpAddresses.cidr("10.0.0.0/16"), 16 | max_azs=2, 17 | subnet_configuration=[ 18 | ec2.SubnetConfiguration(name="public",subnet_type=ec2.SubnetType.PUBLIC,cidr_mask=24), 19 | ec2.SubnetConfiguration(name="private",subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS,cidr_mask=24) 20 | ] 21 | ) 22 | 23 | 24 | @property 25 | def vpc(self) -> ec2.Vpc: 26 | return self.output_vpc 27 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/tests/__init__.py -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/test_generative_ai_sagemaker_cdk_demo_stack.py: -------------------------------------------------------------------------------- 1 | import aws_cdk as core 2 | import aws_cdk.assertions as assertions 3 | 4 | from stack.generative_ai_demo_web_stack import GenerativeAiDemoWebStack 5 | 6 | # 
example tests. To run these tests, uncomment this file along with the example 7 | # resource in generative_ai_sagemaker_cdk_demo/generative_ai_sagemaker_cdk_demo_stack.py 8 | def test_sqs_queue_created(): 9 | app = core.App() 10 | stack = GenerativeAiDemoWebStack(app, "GenerativeAiDemoWebStack") 11 | template = assertions.Template.from_stack(stack) 12 | 13 | # template.has_resource_properties("AWS::SQS::Queue", { 14 | # "VisibilityTimeout": 300 15 | # }) 16 | -------------------------------------------------------------------------------- /web-app/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=linux/x86_64 python:3.9 2 | EXPOSE 8501 3 | WORKDIR /app 4 | COPY requirements.txt ./requirements.txt 5 | RUN pip3 install -r requirements.txt 6 | COPY . . 7 | CMD streamlit run Home.py \ 8 | --server.headless true \ 9 | --browser.serverAddress="0.0.0.0" \ 10 | --server.enableCORS false \ 11 | --browser.gatherUsageStats false 12 | -------------------------------------------------------------------------------- /web-app/Home.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import os 3 | 4 | from PIL import Image 5 | image = Image.open("./img/sagemaker.png") 6 | st.image(image, width=80) 7 | 8 | version = os.environ.get("WEB_VERSION", "0.1") 9 | 10 | st.header(f"Generative AI Demo (Version {version})") 11 | st.markdown("This is a demo of Generative AI models in Amazon SageMaker JumpStart") 12 | st.markdown("_Please select an option from the sidebar_") -------------------------------------------------------------------------------- /web-app/configs.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | 3 | region_name = boto3.Session().region_name 4 | 5 | key_txt2img_api_endpoint = "txt2img_api_endpoint" # this value is from GenerativeAiDemoWebStack 6 | key_txt2img_sm_endpoint =
"txt2img_sm_endpoint" # this value is from GenerativeAiTxt2ImgSagemakerStack 7 | 8 | key_txt2nlu_api_endpoint = "txt2nlu_api_endpoint" # this value is from GenerativeAiDemoWebStack 9 | key_txt2nlu_sm_endpoint = "txt2nlu_sm_endpoint" # this value is from GenerativeAiTxt2nluSagemakerStack 10 | 11 | def get_parameter(name): 12 | """ 13 | This function retrieves a specific value from Systems Manager's Parameter Store. 14 | """ 15 | ssm_client = boto3.client("ssm",region_name=region_name) 16 | response = ssm_client.get_parameter(Name=name) 17 | value = response["Parameter"]["Value"] 18 | 19 | return value -------------------------------------------------------------------------------- /web-app/img/sagemaker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/generative-ai-sagemaker-cdk-demo/4e568d95666e2bce54b22673cbcc3973ff8204f0/web-app/img/sagemaker.png -------------------------------------------------------------------------------- /web-app/pages/2_Image_Generation.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import requests 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import time 6 | 7 | from configs import * 8 | 9 | from PIL import Image 10 | image = Image.open("./img/sagemaker.png") 11 | st.image(image, width=80) 12 | st.header("Image Generation") 13 | st.caption("Using Stable Diffusion model from Hugging Face") 14 | 15 | with st.spinner("Retrieving configurations..."): 16 | 17 | all_configs_loaded = False 18 | 19 | while not all_configs_loaded: 20 | try: 21 | api_endpoint = get_parameter(key_txt2img_api_endpoint) 22 | sm_endpoint = get_parameter(key_txt2img_sm_endpoint) 23 | all_configs_loaded = True 24 | except Exception: 25 | time.sleep(5) 26 | 27 | endpoint_name = st.sidebar.text_input("SageMaker Endpoint Name:",sm_endpoint) 28 | url = st.sidebar.text_input("API GW Url:",api_endpoint) 29 | 30 | 31 |
prompt = st.text_area("Input Image description:", """Dog in superhero outfit""") 32 | 33 | if st.button("Generate image"): 34 | if endpoint_name == "" or prompt == "" or url == "": 35 | st.error("Please enter a valid endpoint name, API Gateway URL and prompt!") 36 | else: 37 | with st.spinner("Wait for it..."): 38 | try: 39 | r = requests.post(url,json={"prompt":prompt,"endpoint_name":endpoint_name},timeout=180) 40 | r.raise_for_status() 41 | data = r.json() 42 | image_array = data["image"] 43 | st.image(np.array(image_array)) 44 | st.success("Done!") 45 | 46 | except requests.exceptions.ConnectionError as errc: 47 | st.error(f"Error connecting: {errc}") 48 | 49 | except requests.exceptions.HTTPError as errh: 50 | st.error(f"HTTP error: {errh}") 51 | 52 | except requests.exceptions.Timeout as errt: 53 | st.error(f"Timeout error: {errt}") 54 | 55 | except requests.exceptions.RequestException as err: 56 | st.error(f"Oops, something else went wrong: {err}") -------------------------------------------------------------------------------- /web-app/pages/3_Text_Generation.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import requests 3 | import time 4 | 5 | from configs import * 6 | 7 | from PIL import Image 8 | image = Image.open("./img/sagemaker.png") 9 | st.image(image, width=80) 10 | st.header("Text Generation") 11 | st.caption("Using FLAN-T5-XL model from Hugging Face") 12 | 13 | conversation = """Customer: Hi, my iPhone isn’t charging well, and the battery drains fast. I’ve tried different cables and adapters, but no luck. 14 | Agent: Sorry to hear that. Check Settings > Battery for apps using lots of power. 15 | Customer: Some apps are draining battery. 16 | Agent: Force quit those apps by swiping up to close them. 17 | Customer: Did that, but no improvement. 18 | Agent: Let’s reset your settings: Settings > General > Reset > Reset All Settings. This won’t erase data. 19 | Customer: Done. What next?
20 | Agent: Restart your iPhone by holding the power button, slide to power off, then turn it back on. 21 | Customer: Restarted, still not charging properly. 22 | Agent: You should get a diagnostic test at an Apple Store or authorized service provider. 23 | Customer: Do I need an appointment? 24 | Agent: Yes, it’s best to book online or by phone to avoid waiting. 25 | Customer: Will repairs cost me? 26 | Agent: If under warranty, repairs are free; otherwise, you’ll pay. 27 | Customer: How long will repairs take? 28 | Agent: Usually 1-2 business days, depending on the issue. 29 | Customer: Can I track the repair status? 30 | Agent: Yes, online or by contacting the service center. 31 | Customer: Thanks for your help. 32 | Agent: You’re welcome! Let me know if you need anything else.""" 33 | 34 | with st.spinner("Retrieving configurations..."): 35 | 36 | all_configs_loaded = False 37 | 38 | while not all_configs_loaded: 39 | try: 40 | api_endpoint = get_parameter(key_txt2nlu_api_endpoint) 41 | sm_endpoint = get_parameter(key_txt2nlu_sm_endpoint) 42 | all_configs_loaded = True 43 | except Exception: 44 | time.sleep(5) 45 | 46 | endpoint_name = st.sidebar.text_input("SageMaker Endpoint Name:",sm_endpoint) 47 | url = st.sidebar.text_input("API GW Url:",api_endpoint) 48 | 49 | context = st.text_area("Input Context:", conversation, height=300, max_chars=1700) 50 | 51 | 52 | queries = ("write a summary", 53 | "What steps were suggested to the customer to fix the issue?", 54 | "What is the overall sentiment and sentiment score of the conversation?") 55 | 56 | selection = st.selectbox( 57 | "Select a query:", queries) 58 | 59 | if st.button("Generate Response", key="generate_from_selection"): 60 | if endpoint_name == "" or selection == "" or url == "": 61 | st.error("Please enter a valid endpoint name, API Gateway URL and prompt!") 62 | else: 63 | with st.spinner("Wait for it..."): 64 | try: 65 | prompt = f"{context}\n{selection}" 66 | r = requests.post(url,json={"prompt":prompt,
"endpoint_name":endpoint_name},timeout=180) 67 | r.raise_for_status() 68 | data = r.json() 69 | generated_text = data["generated_text"] 70 | st.write(generated_text) 71 | #st.write(data) 72 | st.success("Done!") 73 | 74 | except requests.exceptions.ConnectionError as errc: 75 | st.error(f"Error connecting: {errc}") 76 | 77 | except requests.exceptions.HTTPError as errh: 78 | st.error(f"HTTP error: {errh}") 79 | 80 | except requests.exceptions.Timeout as errt: 81 | st.error(f"Timeout error: {errt}") 82 | 83 | except requests.exceptions.RequestException as err: 84 | st.error(f"Oops, something else went wrong: {err}") 85 | 86 | query = st.text_area("Input Query:", "What do you suggest as the next step for the customer?", height=100, max_chars=60) 87 | 88 | if st.button("Generate Response", key="generate_from_query"): 89 | if endpoint_name == "" or query == "" or url == "": 90 | st.error("Please enter a valid endpoint name, API Gateway URL and query!") 91 | else: 92 | with st.spinner("Wait for it..."): 93 | try: 94 | prompt = f"{context}\n{query}" 95 | r = requests.post(url,json={"prompt":prompt, "endpoint_name":endpoint_name},timeout=180) 96 | r.raise_for_status() 97 | data = r.json() 98 | generated_text = data["generated_text"] 99 | st.write(generated_text) 100 | #st.write(data) 101 | st.success("Done!") 102 | 103 | except requests.exceptions.ConnectionError as errc: 104 | st.error(f"Error connecting: {errc}") 105 | 106 | except requests.exceptions.HTTPError as errh: 107 | st.error(f"HTTP error: {errh}") 108 | 109 | except requests.exceptions.Timeout as errt: 110 | st.error(f"Timeout error: {errt}") 111 | 112 | except requests.exceptions.RequestException as err: 113 | st.error(f"Oops, something else went wrong: {err}") 114 | 115 | -------------------------------------------------------------------------------- /web-app/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | requests 3 | jinja2 4 | matplotlib 5 | boto3 --------------------------------------------------------------------------------
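Both Streamlit pages load their endpoint names with a `while not all_configs_loaded` loop that retries every five seconds with no upper limit until both SSM parameters exist. A bounded alternative could look like the sketch below. `load_parameters` is a hypothetical helper, not part of this repository. It takes the getter as an argument so it can be exercised without AWS (in the app, `configs.get_parameter` would be passed in). It catches `Exception` rather than using a bare `except:`, so `KeyboardInterrupt` and other `BaseException` subclasses still propagate.

```python
import time
from typing import Callable, Dict, Iterable, List, Optional


def load_parameters(
    names: Iterable[str],
    get_parameter: Callable[[str], str],
    max_attempts: int = 12,
    delay_seconds: float = 5.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Dict[str, str]:
    """Fetch all named parameters, retrying the whole batch a bounded number of times.

    Hypothetical helper: `get_parameter` is any callable mapping a parameter
    name to its value (for example configs.get_parameter in the web app).
    """
    name_list: List[str] = list(names)
    last_error: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            # Either every parameter resolves, or the whole batch is retried.
            return {name: get_parameter(name) for name in name_list}
        except Exception as exc:  # e.g. the parameter does not exist yet
            last_error = exc
            if attempt < max_attempts:
                sleep(delay_seconds)  # no pointless sleep after the final attempt
    raise RuntimeError(
        f"could not load parameters {name_list} after {max_attempts} attempts"
    ) from last_error
```

In `2_Image_Generation.py`, for instance, the loop could become `params = load_parameters([key_txt2img_api_endpoint, key_txt2img_sm_endpoint], get_parameter)` followed by `api_endpoint = params[key_txt2img_api_endpoint]`, with the `RuntimeError` shown through `st.error`. The default of 12 attempts at 5 seconds (about one minute) is an arbitrary assumption.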