├── .github ├── solutionid_validator.sh └── workflows │ └── maintainer_workflows.yml ├── .gitignore ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── deploy_stack.sh ├── destroy_stack.sh ├── images ├── api_id.png ├── architecture_new.png └── resource_id.png ├── lambdas ├── cost_tracking │ ├── index.py │ ├── models.json │ └── utils.py ├── invoke_model │ ├── BedrockInference.py │ ├── SageMakerInference.py │ └── index.py ├── invoke_model_streaming │ ├── BedrockInference.py │ ├── SageMakerInference.py │ └── index.py ├── lambda_layer_requirements │ ├── cfnresponse.py │ └── index.py └── list_foundation_models │ └── index.py ├── notebooks ├── 01_bedrock_api.ipynb ├── 02_bedrock_api_langchain.ipynb └── images │ ├── backpack.png │ └── battery_image.png ├── requirements.txt ├── setup ├── app.py ├── cdk.json ├── configs.json └── stack_constructs │ ├── api.py │ ├── api_gw.py │ ├── api_key.py │ ├── dynamodb.py │ ├── iam.py │ ├── lambda_function.py │ ├── lambda_layer.py │ ├── network.py │ └── scheduler.py └── utils └── update_cost_files.py /.github/solutionid_validator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #set -e 3 | 4 | echo "checking solution id $1" 5 | echo "grep -nr --exclude-dir='.github' "$1" ./.." 6 | result=$(grep -nr --exclude-dir='.github' "$1" ./..) 7 | if [ $? -eq 0 ] 8 | then 9 | echo "Solution ID $1 found\n" 10 | echo "$result" 11 | exit 0 12 | else 13 | echo "Solution ID $1 not found" 14 | exit 1 15 | fi 16 | 17 | export result 18 | -------------------------------------------------------------------------------- /.github/workflows/maintainer_workflows.yml: -------------------------------------------------------------------------------- 1 | # Workflows managed by aws-solutions-library-samples maintainers 2 | name: Maintainer Workflows 3 | on: 4 | # Triggers the workflow on push or pull request events but only for the "main" branch 5 | push: 6 | branches: [ "main" ] 7 | pull_request: 8 | branches: [ "main" ] 9 | types: [opened, reopened, edited] 10 | 11 | jobs: 12 | CheckSolutionId: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Run solutionid validator 17 | run: | 18 | chmod u+x ./.github/solutionid_validator.sh 19 | ./.github/solutionid_validator.sh ${{ vars.SOLUTIONID }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | *.swp 7 | package-lock.json 8 | __pycache__ 9 | .pyc 10 | .pytest_cache 11 | .env 12 | *.egg-info 13 | 14 | # CDK asset staging directory 15 | .cdk.staging 16 | cdk.out 17 | 18 | # Jupyter Notebook 19 | .ipynb_checkpoints 20 | 21 | # IPython 22 | profile_default/ 23 | ipython_config.py 24 | 25 | # pyenv 26 | .python-version 27 | .venv/ 28 | 29 | .idea/ 30 | **/.DS_Store 31 | .DS_Store 32 | dependencies/ -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | CODEOWNERS @aws-solutions-library-samples/maintainers 2 | /.github/workflows/maintainer_workflows.yml @aws-solutions-library-samples/maintainers 3 | /.github/solutionid_validator.sh @aws-solutions-library-samples/maintainers 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: 
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 | 
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 | 
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 | 
9 | 
10 | ## Reporting Bugs/Feature Requests
11 | 
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 | 
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 | 
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 | 
22 | 
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 | 
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 | 
30 | To send us a pull request, please:
31 | 
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 | 
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 | 
42 | 
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 | 
46 | 
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 | 
52 | 
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 | 
56 | 
57 | ## Licensing
58 | 
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 | 
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 | 
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 | 
18 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-tenant Generative AI gateway with cost and usage tracking on AWS
2 | 
3 | In this repository, we show you how to build a multi-tenant SaaS solution to access foundation models with [Amazon Bedrock](https://aws.amazon.com/bedrock/)
4 | and [Amazon SageMaker](https://aws.amazon.com/sagemaker/).
5 | 
6 | Enterprise IT teams may need to track the usage of FMs across teams, charge back costs, and provide visibility to the relevant cost center in the line of business (LOB). Additionally, they may need to regulate access to different models per team, for example when only specific FMs are approved for use.
7 | 
8 | An internal software as a service (SaaS) for foundation models can address governance requirements while providing a simple
9 | and consistent interface for the end users. API gateways are a common design pattern that enables consumption of services with
10 | standardization and governance. They provide loose coupling between model consumers and the model endpoint service, which
11 | gives flexibility to adapt to changing model versions, architectures, and invocation methods.
12 | 
13 | 1. [Project Description](#project-description)
14 | 2. [API Specifications](#api-specifications)
15 | 3. [API Consumption](#api-consumption)
16 | 4. [Reporting Costs Example](#reporting-costs-example)
17 | 5. [Deploy Stack](#deploy-stack)
18 |    1. [Full Deployment](#full-deployment)
19 |    2. [API Key Deployment](#api-key-deployment)
20 | 6. [Destroy Stack](#destroy-stack)
21 |    1. [Destroy all the stacks](#destroy-all-the-stacks)
22 |    2. [Destroy a specific stack](#destroy-a-specific-stack)
23 | 7. 
[SageMaker Endpoints](#sagemaker-endpoints)
24 | 
25 | ## Project Description
26 | 
27 | Multiple tenants within an enterprise can simply correspond to multiple teams or projects accessing LLMs via REST APIs, just like other SaaS services. IT teams can add additional governance and controls over this SaaS layer. In this CDK example, we focus specifically on showcasing multiple tenants with different cost centers accessing the service via API Gateway. An internal service is responsible for performing usage and cost tracking per tenant and aggregating that cost for reporting. The CDK template provided here deploys all the required resources to the AWS account.
28 | 
29 | ![Architecture](images/architecture_new.png)
30 | 
31 | The CDK stack provides the following deployments:
32 | 
33 | #### Full Deployment: It deploys the following resources:
34 | 
35 | 1. Private networking environment with VPC, private subnets, and VPC endpoints for Lambda, API Gateway, and Amazon Bedrock
36 | 2. API Gateway REST API
37 | 3. API Gateway usage plan
38 | 4. API Gateway key
39 | 5. Lambda functions to list foundation models on Bedrock
40 | 6. Lambda functions to invoke models on Bedrock and SageMaker
41 | 7. Lambda functions to invoke models on Bedrock and SageMaker with streaming response
42 | 8. DynamoDB table for saving streaming responses asynchronously
43 | 9. Lambda function to aggregate usage and cost tracking
44 | 10. EventBridge rule to trigger the cost aggregation on a regular frequency
45 | 11. S3 buckets to store the cost tracking logs
46 | 12. CloudWatch logs to collect logs from Lambda invocations
47 | 
48 | #### API Key Deployment: It deploys the following resources:
49 | 1. API Gateway usage plan
50 | 2. API Gateway key
51 | 
52 | The sample notebooks in the notebooks folder can be used to invoke Bedrock as any one of the teams/cost centers. API Gateway
53 | then routes the request to the Lambda function that invokes Bedrock models or SageMaker-hosted models and logs the usage metrics to CloudWatch.
54 | EventBridge triggers the cost tracking Lambda function on a regular frequency to aggregate metrics from the CloudWatch logs and
55 | generate aggregate usage and cost metrics for the chosen granularity level. The metrics are stored in S3 and can further
56 | be visualized with custom reports.
57 | 
58 | ## API Specifications
59 | 
60 | The CDK stack creates a REST API compliant with the OpenAPI specification.
61 | 
62 | The solution currently supports both **REST** invocation and **streaming** invocation with long polling for Bedrock and SageMaker. A minimal invocation sketch follows; the full contract is given in the OpenAPI document below.
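The sketch below shows a minimal REST call against the deployed gateway using the Python `requests` library. It is a hedged example, not part of the shipped code: the URL, API key, tenant ID, and model ID are placeholders for values from your own deployment, the payload follows the `InvokeModelRequest` schema defined in the specification, and only `model_id` is shown (the `model_arn` query parameter applies to custom Bedrock models).

```
import requests

# Placeholders: the URL comes from the CloudFormation output `awsiapigwurl`,
# the key from the API Gateway console; `tenant1` is a sample team_id.
API_URL = "https://<api-id>.execute-api.<region>.amazonaws.com/prod"
API_KEY = "<api-key>"

response = requests.post(
    f"{API_URL}/invoke_model",
    params={"model_id": "anthropic.claude-3-haiku-20240307-v1:0"},
    headers={
        "x-api-key": API_KEY,  # api_key security scheme from the spec
        "team_id": "tenant1",  # tenant used for usage and cost attribution
        "streaming": "false",  # set to "true" for streaming with long polling
    },
    json={
        "inputs": [{"role": "user", "content": "What is Amazon Bedrock?"}],
        "parameters": {"maxTokens": 512, "temperature": 0.2},
    },
    timeout=60,
)
print(response.status_code, response.json())
```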
63 | 64 | ### OpenAPI 3 65 | 66 | ``` 67 | openapi: 3.0.1 68 | info: 69 | title: "" 70 | version: '2023-12-13T12:12:15Z' 71 | servers: 72 | - url: https://.execute-api..amazonaws.com/{basePath} 73 | variables: 74 | basePath: 75 | default: prod 76 | paths: 77 | "/list_foundation_models": 78 | get: 79 | responses: 80 | '401': 81 | description: 401 response 82 | headers: 83 | Access-Control-Allow-Origin: 84 | schema: 85 | type: string 86 | content: 87 | application/json: 88 | schema: 89 | "$ref": "#/components/schemas/Error" 90 | security: 91 | - api_key: [] 92 | "/invoke_model": 93 | requestBody: 94 | required: true 95 | content: 96 | application/json: 97 | schema: 98 | $ref: '#/components/schemas/InvokeModelRequest' 99 | parameters: 100 | - name: model_id 101 | in: query 102 | required: true 103 | schema: 104 | type: string 105 | description: Id of the base model to invoke 106 | - name: model_arn 107 | in: query 108 | required: true 109 | schema: 110 | type: string 111 | description: ARN of the custom model in Amazon Bedrock 112 | - name: requestId 113 | in: query 114 | required: false 115 | schema: 116 | type: string 117 | description: Request ID for long-polling functionality. Requires streaming=true 118 | - name: team_id 119 | in: header 120 | required: true 121 | schema: 122 | type: string 123 | - name: messages_api 124 | in: header 125 | required: false 126 | schema: 127 | type: string 128 | - name: streaming 129 | in: header 130 | required: false 131 | schema: 132 | type: string 133 | - name: type 134 | in: header 135 | required: false 136 | schema: 137 | type: string 138 | responses: 139 | '401': 140 | description: 401 response 141 | headers: 142 | Access-Control-Allow-Origin: 143 | schema: 144 | type: string 145 | content: 146 | application/json: 147 | schema: 148 | "$ref": "#/components/schemas/Error" 149 | security: 150 | - api_key: [] 151 | components: 152 | schemas: 153 | InvokeModelRequest: 154 | type: object 155 | required: 156 | - inputs 157 | - parameters 158 | properties: 159 | inputs: 160 | $ref: '#/components/schemas/Prompt' 161 | parameters: 162 | $ref: '#/components/schemas/ModelParameters' 163 | tool_config: 164 | $ref: '#/components/schemas/ToolConfig' 165 | Prompt: 166 | type: object 167 | example: 168 | - role: 'user' 169 | content: 'What is Amazon Bedrock?' 170 | ModelParameters: 171 | type: object 172 | properties: 173 | maxTokens: 174 | type: integer 175 | required: false 176 | temperature: 177 | type: number 178 | required: false 179 | topP: 180 | type: number 181 | required: false 182 | stopSequences: 183 | type: array 184 | required: false 185 | items: 186 | type: string 187 | system: 188 | type: string 189 | required: false 190 | ToolConfig: 191 | type: object 192 | required: false 193 | properties: 194 | tools: 195 | type: array 196 | required: true 197 | properties: 198 | toolSpec: 199 | type: object 200 | required: true 201 | 202 | Error: 203 | title: Error Schema 204 | type: object 205 | properties: 206 | message: 207 | type: string 208 | securitySchemes: 209 | api_key: 210 | type: apiKey 211 | name: x-api-key 212 | in: header 213 | ``` 214 | 215 | ## API Consumption 216 | 217 | The solution is providing two example notebooks for testing API requests with raw API invocation and with Langchain integration: 218 | 1. Raw API: [01_bedrock_api.ipynb](./notebooks/01_bedrock_api.ipynb) 219 | 2. 
Langchain Integration: [02_bedrock_api_langchain.ipynb](./notebooks/02_bedrock_api_langchain.ipynb)
220 | 
221 | ### How to get the API Gateway Endpoint:
222 | Navigate to the CloudFormation deployment and get the value under `awsiapigwurl`
223 | 
224 | ### How to get the API Key:
225 | 1. Navigate to the AWS Console
226 | 2. Search for API Gateway
227 | 3. Select the deployed `API Gateway`
228 | 4. Copy the value from `API keys`
229 | 
230 | ## Reporting Costs Example
231 | 
232 | | team_id | model_id | input_tokens | output_tokens | invocations | input_cost (USD) | output_cost (USD) |
233 | |---------|----------|--------------|---------------|-------------|------------------|-------------------|
234 | | tenant1 | amazon.titan-tg1-large | 24000 | 2473 | 1000 | 0.0072 | 0.00099 |
235 | | tenant1 | anthropic.claude-v2 | 2448 | 4800 | 24 | 0.02698 | 0.15686 |
236 | | tenant2 | amazon.titan-tg1-large | 35000 | 52500 | 350 | 0.0105 | 0.021 |
237 | | tenant2 | ai21.j2-grande-instruct | 4590 | 9000 | 45 | 0.05738 | 0.1125 |
238 | | tenant2 | anthropic.claude-v2 | 1080 | 4400 | 20 | 0.0119 | 0.14379 |
239 | 
240 | ## Deploy Stack
241 | 
242 | ### Note
243 | 
244 | The following examples provide guidelines on the structure of the configuration file.
245 | Please make sure to look at [setup/configs.json](./setup/configs.json) for the most up-to-date version of the file.
246 | 
247 | ### Full Deployment
248 | 
249 | #### Step 1
250 | 
251 | Edit the global configs used in the CDK stack. For each organizational unit that requires a dedicated multi-tenant SaaS environment, create an entry in [setup/configs.json](./setup/configs.json):
252 | 
253 | ```
254 | [
255 |     {
256 |       "STACK_PREFIX": "", # unit 1 with dedicated SaaS resources
257 |       "BEDROCK_ENDPOINT": "https://bedrock-runtime.{}.amazonaws.com", # bedrock-runtime endpoint used for invoking Amazon Bedrock
258 |       "BEDROCK_REQUIREMENTS": "boto3>=1.34.62 awscli>=1.32.62 botocore>=1.34.62", # Requirements for Amazon Bedrock
259 |       "LANGCHAIN_REQUIREMENTS": "aws-lambda-powertools langchain==0.1.12 pydantic PyYaml", # python modules installed for langchain layer
260 |       "PANDAS_REQUIREMENTS": "pandas", # python modules installed for pandas layer
261 |       "VPC_CIDR": "10.10.0.0/16", # CIDR used for the private VPC Env
262 |       "API_THROTTLING_RATE": 10000, # Throttling limit assigned to the usage plan
263 |       "API_BURST_RATE": 5000 # Burst limit assigned to the usage plan
264 |     },
265 |     {
266 |       "STACK_PREFIX": "", # unit 2 with dedicated SaaS resources
267 |       "BEDROCK_ENDPOINT": "https://bedrock-runtime.{}.amazonaws.com", # bedrock-runtime endpoint used for invoking Amazon Bedrock
268 |       "BEDROCK_REQUIREMENTS": "boto3>=1.34.62 awscli>=1.32.62 botocore>=1.34.62", # Requirements for Amazon Bedrock
269 |       "LANGCHAIN_REQUIREMENTS": "aws-lambda-powertools langchain==0.1.12 pydantic PyYaml", # python modules installed for langchain layer
270 |       "PANDAS_REQUIREMENTS": "pandas", # python modules installed for pandas layer
271 |       "VPC_CIDR": "10.20.0.0/16", # CIDR used for the private VPC Env
272 |       "API_THROTTLING_RATE": 10000,
273 |       "API_BURST_RATE": 5000
274 |     }
275 | ]
276 | ```
277 | 
278 | #### Step 2
279 | 
280 | Execute the following commands:
281 | 
282 | ```
283 | chmod +x deploy_stack.sh
284 | ```
285 | 
286 | ```
287 | ./deploy_stack.sh
288 | ```
289 | 
290 | #### Optional Step 3
291 | 
292 | We can also deploy a specific stack as follows:
293 | 
294 | ```
295 | ./deploy_stack.sh <STACK_PREFIX>-bedrock-saas
296 | ```
297 | 
298 | ### API Key Deployment
299 | 
300 | #### Step 1
301 | 
302 | ##### Option 1
303 | 
304 | Edit the global 
configs used in the CDK stack. For each organizational unit that requires a dedicated API key associated with an already created API Gateway REST API, create an entry in [setup/configs.json](./setup/configs.json)
305 | by specifying `API_GATEWAY_ID` and `API_GATEWAY_RESOURCE_ID`:
306 | 
307 | ```
308 | [
309 |     {
310 |       "STACK_PREFIX": "", # unit 1 with dedicated SaaS resources
311 |       "API_GATEWAY_ID": "", # Rest API ID
312 |       "API_GATEWAY_RESOURCE_ID": "", # Resource ID of the Rest API
313 |       "API_THROTTLING_RATE": 10000, # Throttling limit assigned to the usage plan
314 |       "API_BURST_RATE": 5000 # Burst limit assigned to the usage plan
315 | 
316 |     }
317 | ]
318 | ```
319 | 
320 | ##### Option 2
321 | 
322 | Edit the global configs used in the CDK stack. For each organizational unit that requires a dedicated API key associated with an already created API Gateway REST API, create an entry in [setup/configs.json](./setup/configs.json)
323 | by specifying `PARENT_STACK_PREFIX`:
324 | 
325 | ```
326 | [
327 |     {
328 |       "STACK_PREFIX": "", # unit 1 with dedicated SaaS resources
329 |       "PARENT_STACK_PREFIX": "", # parent unit from which to import configurations
330 |       "API_THROTTLING_RATE": 10000, # Throttling limit assigned to the usage plan
331 |       "API_BURST_RATE": 5000 # Burst limit assigned to the usage plan
332 | 
333 |     }
334 | ]
335 | ```
336 | 
337 | ![Rest API Id](images/api_id.png)
338 | 
339 | ![Resource Id](images/resource_id.png)
340 | 
341 | #### Step 2
342 | 
343 | Execute the following commands:
344 | 
345 | ```
346 | chmod +x deploy_stack.sh
347 | ```
348 | 
349 | ```
350 | ./deploy_stack.sh
351 | ```
352 | 
353 | #### Optional Step 3
354 | 
355 | We can also deploy a specific stack as follows:
356 | 
357 | ```
358 | ./deploy_stack.sh <STACK_PREFIX>-bedrock-saas
359 | ```
360 | 
361 | ## Destroy Stack
362 | 
363 | ### Destroy all the stacks
364 | 
365 | We can delete all the deployed stacks by running:
366 | 
367 | ```
368 | ./destroy_stack.sh
369 | ```
370 | 
371 | ### Destroy a specific stack
372 | 
373 | We can delete a specific stack by running:
374 | 
375 | ```
376 | ./destroy_stack.sh <STACK_PREFIX>-bedrock-saas
377 | ```
378 | 
379 | ## SageMaker Endpoints
380 | 
381 | Add FMs through Amazon SageMaker:
382 | 
383 | We can expose foundation models hosted on Amazon SageMaker by providing the endpoint names as a JSON object in string representation,
384 | as described in the example below:
385 | 
386 | ```
387 | [
388 |     {
389 |       "STACK_PREFIX": "", # unit 1 with dedicated SaaS resources
390 |       "BEDROCK_ENDPOINT": "https://bedrock-runtime.{}.amazonaws.com", # bedrock-runtime endpoint used for invoking Amazon Bedrock
391 |       "BEDROCK_REQUIREMENTS": "boto3>=1.34.62 awscli>=1.32.62 botocore>=1.34.62", # Requirements for Amazon Bedrock
392 |       "LANGCHAIN_REQUIREMENTS": "aws-lambda-powertools langchain==0.1.12 pydantic PyYaml", # python modules installed for langchain layer
393 |       "PANDAS_REQUIREMENTS": "pandas", # python modules installed for pandas layer
394 |       "VPC_CIDR": "10.10.0.0/16", # CIDR used for the private VPC Env
395 |       "API_THROTTLING_RATE": 10000, # Throttling limit assigned to the usage plan
396 |       "API_BURST_RATE": 5000, # Burst limit assigned to the usage plan
397 |       "SAGEMAKER_ENDPOINTS": "{'Mixtral 8x7B': 'Mixtral-SM-Endpoint'}" # Mapping of model names to SageMaker endpoints
398 |     }
399 | ]
400 | ```
401 | 
402 | #### InferenceComponentName with SageMaker Endpoint
403 | 
404 | We can provide an `InferenceComponentName` specification for the model invocation, as in the sketch below. 
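Behind the gateway, the invocation Lambda functions ultimately call the SageMaker Runtime, where `InferenceComponentName` selects a specific inference component hosted on the endpoint. The following is a hedged sketch of that underlying call, not the gateway request format itself: the endpoint name matches the `SAGEMAKER_ENDPOINTS` example above, while the component name and payload shape are hypothetical and depend on the deployed container.

```
import boto3
import json

smr = boto3.client("sagemaker-runtime")

response = smr.invoke_endpoint(
    EndpointName="Mixtral-SM-Endpoint",          # name from SAGEMAKER_ENDPOINTS
    InferenceComponentName="mixtral-component",  # hypothetical component name
    ContentType="application/json",
    Body=json.dumps({
        "inputs": "What is Amazon Bedrock?",
        "parameters": {"max_new_tokens": 256},
    }),
)
print(json.loads(response["Body"].read()))
```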
Please refer to the notebook
405 | [01_bedrock_api.ipynb](./notebooks/01_bedrock_api.ipynb) for a full example.
406 | 
407 | #### Important note
408 | 
409 | Amazon SageMaker Hosting provides flexibility in the definition of the inference container. This solution currently
410 | supports the general-purpose inference scripts provided by SageMaker JumpStart and the Hugging Face TGI container.
411 | 
412 | You need to adapt the Lambda functions [invoke_model](./lambdas/invoke_model) and [invoke_model_streaming](./lambdas/invoke_model_streaming)
413 | in case of custom inference scripts.
414 | 
415 | ## Reading resources
416 | For additional reading, refer to:
417 | 1. [Build an internal SaaS service with cost and usage tracking for foundation models on Amazon Bedrock](https://aws.amazon.com/blogs/machine-learning/build-an-internal-saas-service-with-cost-and-usage-tracking-for-foundation-models-on-amazon-bedrock/)
418 | 2. [Create a Generative AI Gateway to allow secure and compliant consumption of foundation models](https://aws.amazon.com/blogs/machine-learning/create-a-generative-ai-gateway-to-allow-secure-and-compliant-consumption-of-foundation-models/)
419 | 
--------------------------------------------------------------------------------
/deploy_stack.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # export JSII_SILENCE_WARNING_DEPRECATED_NODE_VERSION=true
4 | 
5 | # npm install -g aws-cdk
6 | # python3 -m venv .venv
7 | # source .venv/bin/activate
8 | # pip3 install -r requirements.txt
9 | 
10 | if [ $# -eq 0 ]; then
11 |     # No parameter was passed
12 |     DEPLOY_TARGET="--all"
13 | else
14 |     # A parameter was passed
15 |     DEPLOY_TARGET="$1"
16 | fi
17 | 
18 | ACCOUNT_ID=$(aws sts get-caller-identity --query Account | tr -d '"')
19 | AWS_REGION=$(aws configure get region)
20 | cd ./setup
21 | cdk bootstrap aws://${ACCOUNT_ID}/${AWS_REGION}
22 | cdk deploy --require-approval never $DEPLOY_TARGET
--------------------------------------------------------------------------------
/destroy_stack.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # export JSII_SILENCE_WARNING_DEPRECATED_NODE_VERSION=true
4 | 
5 | # npm install -g aws-cdk
6 | # python3 -m venv .venv
7 | # source .venv/bin/activate
8 | # pip3 install -r requirements.txt
9 | 
10 | 
11 | if [ $# -eq 0 ]; then
12 |     # No parameter was passed
13 |     DESTROY_TARGET=""
14 | else
15 |     # A parameter was passed
16 |     DESTROY_TARGET="$1"
17 | fi
18 | 
19 | ACCOUNT_ID=$(aws sts get-caller-identity --query Account | tr -d '"')
20 | AWS_REGION=$(aws configure get region)
21 | cd ./setup
22 | 
23 | if [ -z "$DESTROY_TARGET" ]; then
24 |     cdk destroy --all
25 | else
26 |     cdk destroy $DESTROY_TARGET
27 | fi
--------------------------------------------------------------------------------
/images/api_id.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-a-multi-tenant-generative-ai-gateway-with-cost-and-usage-tracking-on-aws/f750573133a9d6189d56f060007a5553c6099de3/images/api_id.png
--------------------------------------------------------------------------------
/images/architecture_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-a-multi-tenant-generative-ai-gateway-with-cost-and-usage-tracking-on-aws/f750573133a9d6189d56f060007a5553c6099de3/images/architecture_new.png -------------------------------------------------------------------------------- /images/resource_id.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-a-multi-tenant-generative-ai-gateway-with-cost-and-usage-tracking-on-aws/f750573133a9d6189d56f060007a5553c6099de3/images/resource_id.png -------------------------------------------------------------------------------- /lambdas/cost_tracking/index.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import datetime 3 | from io import StringIO 4 | import logging 5 | import os 6 | import pytz 7 | import traceback 8 | from utils import merge_and_process_logs, run_query, results_to_df, calculate_cost 9 | 10 | logger = logging.getLogger(__name__) 11 | if len(logging.getLogger().handlers) > 0: 12 | logging.getLogger().setLevel(logging.INFO) 13 | else: 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | log_group_name_api = os.environ.get("LOG_GROUP_API", None) 17 | s3_bucket = os.environ.get("S3_BUCKET", None) 18 | 19 | s3_resource = boto3.resource('s3') 20 | 21 | QUERY_API = """ 22 | fields 23 | message.team_id as team_id, 24 | message.requestId as request_id, 25 | message.region as region, 26 | message.model_id as model_id, 27 | message.inputTokens as input_tokens, 28 | message.outputTokens as output_tokens, 29 | message.height as height, 30 | message.width as width, 31 | message.steps as steps 32 | | filter level = "INFO" 33 | """ 34 | 35 | QUERY_API_WITH_KEY = """ 36 | fields 37 | message.team_id as team_id, 38 | message.api_key as api_key, 39 | message.requestId as request_id, 40 | message.region as region, 41 | message.model_id as model_id, 42 | message.inputTokens as input_tokens, 43 | message.outputTokens as output_tokens, 44 | message.height as height, 45 | message.width as width, 46 | message.steps as steps 47 | | filter level = "INFO" 48 | """ 49 | 50 | def process_event(event): 51 | try: 52 | if "date" in event: 53 | date = event["date"] 54 | else: 55 | date = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=1) 56 | date = date.strftime("%Y-%m-%d") 57 | 58 | #Create an error buffer for tracking the errors 59 | error_buffer = StringIO() 60 | 61 | # querying the cloudwatch logs from the API 62 | query_results_api = run_query(QUERY_API, log_group_name_api, date) 63 | query_results_api_key = run_query(QUERY_API_WITH_KEY, log_group_name_api, date) 64 | 65 | query_results_api = merge_and_process_logs(query_results_api, query_results_api_key) 66 | df_bedrock_cost_tracking = results_to_df(query_results_api) 67 | 68 | if len(df_bedrock_cost_tracking) > 0: 69 | # Apply the calculate_cost function to the DataFrame 70 | df_bedrock_cost_tracking[["input_tokens", "output_tokens", "input_cost", "output_cost","invocations"]] = df_bedrock_cost_tracking.apply( 71 | lambda row: calculate_cost(row, error_buffer), axis=1, result_type="expand" 72 | ) 73 | 74 | # Remove rows where calculate_cost returned None 75 | df_bedrock_cost_tracking = df_bedrock_cost_tracking.dropna(subset=["input_tokens", "output_tokens", "input_cost", "output_cost", "invocations"]) 76 | 77 | # aggregate cost for each model_id 78 | df_bedrock_cost_tracking_aggregated = 
df_bedrock_cost_tracking.groupby(["api_key", "team_id", "model_id"]).sum()[ 79 | ["input_tokens", "output_tokens", "input_cost", "output_cost", "invocations"] 80 | ] 81 | 82 | df_bedrock_cost_tracking_aggregated["date"] = date 83 | 84 | logger.info(df_bedrock_cost_tracking_aggregated.to_string()) 85 | 86 | csv_buffer = StringIO() 87 | df_bedrock_cost_tracking_aggregated.to_csv(csv_buffer) 88 | 89 | file_name = f"succeed/{date}.csv" 90 | 91 | s3_resource.Object(s3_bucket, file_name).put(Body=csv_buffer.getvalue()) 92 | 93 | # Save error file to S3 if there are any errors 94 | if error_buffer.getvalue(): 95 | error_file_name = f"errors/{date}.txt" 96 | s3_resource.Object(s3_bucket, error_file_name).put(Body=error_buffer.getvalue()) 97 | except Exception as e: 98 | stacktrace = traceback.format_exc() 99 | logger.error(stacktrace) 100 | 101 | raise e 102 | 103 | def lambda_handler(event, context): 104 | try: 105 | process_event(event) 106 | return {"statusCode": 200, "body": "OK"} 107 | except Exception as e: 108 | stacktrace = traceback.format_exc() 109 | logger.error(stacktrace) 110 | 111 | return {"statusCode": 500, "body": str(e)} 112 | -------------------------------------------------------------------------------- /lambdas/cost_tracking/models.json: -------------------------------------------------------------------------------- 1 | { 2 | "us-east-1": { 3 | "text": { 4 | "ai21.j2-mid-v1": {"input_cost": 0.0125, "output_cost": 0.0125}, 5 | "ai21.j2-ultra-v1": {"input_cost": 0.0188, "output_cost": 0.0188}, 6 | "ai21.j2-mid": {"input_cost": 0.0125, "output_cost": 0.0125}, 7 | "ai21.j2-ultra": {"input_cost": 0.0188, "output_cost": 0.0188}, 8 | "ai21.jamba-instruct-v1:0": {"input_cost": 0.0005, "output_cost": 0.0007}, 9 | "ai21.jamba-1-5-large-v1:0": {"input_cost": 0.002, "output_cost": 0.008}, 10 | "ai21.jamba-1-5-mini-v1:0": {"input_cost": 0.0002, "output_cost": 0.0004}, 11 | "amazon.titan-text-lite-v1": {"input_cost": 0.00015, "output_cost": 0.0002}, 12 | "amazon.titan-text-express-v1": {"input_cost": 0.0002, "output_cost": 0.0006}, 13 | "amazon.titan-text-premier-v1:0": {"input_cost": 0.0005, "output_cost": 0.0015}, 14 | "anthropic.claude-instant-v1": {"input_cost": 0.00080, "output_cost": 0.00240}, 15 | "anthropic.claude-v2": {"input_cost": 0.00800, "output_cost": 0.02400}, 16 | "anthropic.claude-v2:1": {"input_cost": 0.00800, "output_cost": 0.02400}, 17 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.00300, "output_cost": 0.01500}, 18 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 19 | "anthropic.claude-3-opus-20240229-v1:0": {"input_cost": 0.01500, "output_cost": 0.07500}, 20 | "anthropic.claude-3-5-sonnet-20240620-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 21 | "cohere.command-text-v14": {"input_cost": 0.0015, "output_cost": 0.0020}, 22 | "cohere.command-light-text-v14": {"input_cost": 0.0003, "output_cost": 0.0006}, 23 | "cohere.command-r-plus-v1:0": {"input_cost": 0.0030, "output_cost": 0.0150}, 24 | "cohere.command-r-v1:0": {"input_cost": 0.0005, "output_cost": 0.0015}, 25 | "meta.llama2-13b-v1": {"input_cost": 0.00075, "output_cost": 0.00100}, 26 | "meta.llama2-70b-v1": {"input_cost": 0.00195, "output_cost": 0.00256}, 27 | "meta.llama2-13b-chat-v1": {"input_cost": 0.00075, "output_cost": 0.00100}, 28 | "meta.llama2-70b-chat-v1": {"input_cost": 0.00195, "output_cost": 0.00256}, 29 | "meta.llama3-8b-instruct-v1:0": {"input_cost": 0.0004, "output_cost": 0.0006}, 30 | "meta.llama3-70b-instruct-v1:0": {"input_cost": 
0.00265, "output_cost": 0.0035}, 31 | "meta.llama3-2-1b-instruct-v1:0": {"input_cost": 0.0001, "output_cost": 0.0001}, 32 | "meta.llama3-2-3b-instruct-v1:0": {"input_cost": 0.00015, "output_cost": 0.00015}, 33 | "meta.llama3-2-11b-instruct-v1:0": {"input_cost": 0.00035, "output_cost": 0.00035}, 34 | "meta.llama3-2-90b-instruct-v1:0": {"input_cost": 0.002, "output_cost": 0.002}, 35 | "mistral.mistral-7b-instruct-v0:2": {"input_cost": 0.00015, "output_cost": 0.0002}, 36 | "mistral.mixtral-8x7b-instruct-v0:1": {"input_cost": 0.00045, "output_cost": 0.0007}, 37 | "mistral.mistral-large-2402-v1:0": {"input_cost": 0.008, "output_cost": 0.024}, 38 | "mistral.mistral-small-2402-v1:0": {"input_cost": 0.001, "output_cost": 0.003} 39 | }, 40 | "embeddings": { 41 | "amazon.titan-embed-text-v1": {"input_cost": 0.0001, "output_cost": 0}, 42 | "amazon.titan-embed-text-v2:0": {"input_cost": 0.00002, "output_cost": 0}, 43 | "amazon.titan-embed-image-v1": {"input_cost": 0.0008, "output_cost": 0}, 44 | "amazon.titan-embed-image-v1-image": {"input_cost": 0.00006, "output_cost": 0}, 45 | "cohere.embed-english-v3": {"input_cost": 0.0001, "output_cost": 0}, 46 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 47 | }, 48 | "image": { 49 | "amazon.titan-image-generator-v1": { 50 | "512x512": { 51 | "standard": 0.008, 52 | "premium": 0.01 53 | }, 54 | "larger": { 55 | "standard": 0.01, 56 | "premium": 0.012 57 | } 58 | }, 59 | "stability.stable-diffusion-xl-v1": { 60 | "512x512": { 61 | "standard": 0.018, 62 | "premium": 0.036 63 | }, 64 | "larger": { 65 | "standard": 0.036, 66 | "premium": 0.072 67 | } 68 | } 69 | } 70 | }, 71 | "us-west-2": { 72 | "text": { 73 | "ai21.j2-mid-v1": {"input_cost": 0.0125, "output_cost": 0.0125}, 74 | "ai21.j2-ultra-v1": {"input_cost": 0.0188, "output_cost": 0.0188}, 75 | "amazon.titan-text-lite-v1": {"input_cost": 0.00015, "output_cost": 0.0002}, 76 | "amazon.titan-text-express-v1": {"input_cost": 0.0002, "output_cost": 0.0006}, 77 | "anthropic.claude-instant-v1": {"input_cost": 0.00080, "output_cost": 0.00240}, 78 | "anthropic.claude-v2": {"input_cost": 0.00800, "output_cost": 0.02400}, 79 | "anthropic.claude-v2:1": {"input_cost": 0.00800, "output_cost": 0.02400}, 80 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.00300, "output_cost": 0.01500}, 81 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 82 | "anthropic.claude-3-opus-20240229-v1:0": {"input_cost": 0.01500, "output_cost": 0.07500}, 83 | "anthropic.claude-3-5-sonnet-20240620-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 84 | "cohere.command-text-v14": {"input_cost": 0.0015, "output_cost": 0.0020}, 85 | "cohere.command-light-text-v14": {"input_cost": 0.0003, "output_cost": 0.0006}, 86 | "cohere.command-r-plus-v1:0": {"input_cost": 0.0030, "output_cost": 0.0150}, 87 | "cohere.command-r-v1:0": {"input_cost": 0.0005, "output_cost": 0.0015}, 88 | "meta.llama2-13b-v1": {"input_cost": 0.00075, "output_cost": 0.00100}, 89 | "meta.llama2-70b-v1": {"input_cost": 0.00195, "output_cost": 0.00256}, 90 | "meta.llama2-13b-chat-v1": {"input_cost": 0.00075, "output_cost": 0.00100}, 91 | "meta.llama2-70b-chat-v1": {"input_cost": 0.00195, "output_cost": 0.00256}, 92 | "meta.llama3-8b-instruct-v1:0": {"input_cost": 0.0004, "output_cost": 0.0006}, 93 | "meta.llama3-70b-instruct-v1:0": {"input_cost": 0.00265, "output_cost": 0.0035}, 94 | "meta.llama3-1-8b-instruct-v1:0": {"input_cost": 0.0003, "output_cost": 0.0006}, 95 | "meta.llama3-1-70b-instruct-v1:0": 
{"input_cost": 0.00265, "output_cost": 0.0035}, 96 | "meta.llama3-1-405b-instruct-v1:0": {"input_cost": 0.00532, "output_cost": 0.016}, 97 | "meta.llama3-2-1b-instruct-v1:0": {"input_cost": 0.0001, "output_cost": 0.0001}, 98 | "meta.llama3-2-3b-instruct-v1:0": {"input_cost": 0.00015, "output_cost": 0.00015}, 99 | "meta.llama3-2-11b-instruct-v1:0": {"input_cost": 0.00035, "output_cost": 0.00035}, 100 | "meta.llama3-2-90b-instruct-v1:0": {"input_cost": 0.002, "output_cost": 0.002}, 101 | "mistral.mistral-7b-instruct-v0:2": {"input_cost": 0.00015, "output_cost": 0.0002}, 102 | "mistral.mixtral-8x7b-instruct-v0:1": {"input_cost": 0.00045, "output_cost": 0.0007}, 103 | "mistral.mistral-large-2402-v1:0": {"input_cost": 0.008, "output_cost": 0.024}, 104 | "mistral.mistral-large-2407-v1:0": {"input_cost": 0.003, "output_cost": 0.009} 105 | }, 106 | "embeddings": { 107 | "amazon.titan-embed-text-v1": {"input_cost": 0.0001, "output_cost": 0}, 108 | "amazon.titan-embed-text-v2:0": {"input_cost": 0.00002, "output_cost": 0}, 109 | "amazon.titan-embed-image-v1": {"input_cost": 0.0008, "output_cost": 0}, 110 | "amazon.titan-embed-image-v1-image": {"input_cost": 0.00006, "output_cost": 0}, 111 | "cohere.embed-english-v3": {"input_cost": 0.0001, "output_cost": 0}, 112 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 113 | }, 114 | "image": { 115 | "amazon.titan-image-generator-v1": { 116 | "512x512": { 117 | "standard": 0.008, 118 | "premium": 0.01 119 | }, 120 | "larger": { 121 | "standard": 0.01, 122 | "premium": 0.012 123 | } 124 | }, 125 | "stability.stable-diffusion-xl-v1": { 126 | "512x512": { 127 | "standard": 0.018, 128 | "premium": 0.036 129 | }, 130 | "larger": { 131 | "standard": 0.036, 132 | "premium": 0.072 133 | } 134 | }, 135 | "stability.sd3-large-v1:0": { 136 | "512x512": { 137 | "standard": 0.08, 138 | "premium": 0.08 139 | }, 140 | "larger": { 141 | "standard": 0.08, 142 | "premium": 0.08 143 | } 144 | }, 145 | "stability.stable-image-core-v1:0": { 146 | "512x512": { 147 | "standard": 0.04, 148 | "premium": 0.04 149 | }, 150 | "larger": { 151 | "standard": 0.04, 152 | "premium": 0.04 153 | } 154 | }, 155 | "stability.stable-image-ultra-v1:0": { 156 | "512x512": { 157 | "standard": 0.14, 158 | "premium": 0.14 159 | }, 160 | "larger": { 161 | "standard": 0.14, 162 | "premium": 0.14 163 | } 164 | } 165 | } 166 | }, 167 | "ap-south-1": { 168 | "text": { 169 | "amazon.titan-text-lite-v1": {"input_cost": 0.0004, "output_cost": 0.0005}, 170 | "amazon.titan-text-express-v1": {"input_cost": 0.001, "output_cost": 0.0019}, 171 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 172 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 173 | "meta.llama3-8b-instruct-v1:0": {"input_cost": 0.00048, "output_cost": 0.00072}, 174 | "meta.llama3-70b-instruct-v1:0": {"input_cost": 0.00318, "output_cost": 0.0042}, 175 | "mistral.mistral-7b-instruct-v0:2": {"input_cost": 0.00018, "output_cost": 0.00024}, 176 | "mistral.mixtral-8x7b-instruct-v0:1": {"input_cost": 0.00054, "output_cost": 0.00084}, 177 | "mistral.mistral-large-2402-v1:0": {"input_cost": 0.0096, "output_cost": 0.0288} 178 | }, 179 | "embeddings": { 180 | "amazon.titan-embed-text-v1": {"input_cost": 0.00012, "output_cost": 0}, 181 | "amazon.titan-embed-text-v2:0": {"input_cost": 0.0001, "output_cost": 0}, 182 | "amazon.titan-embed-image-v1": {"input_cost": 0.001, "output_cost": 0}, 183 | "amazon.titan-embed-image-v1-image": {"input_cost": 
0.00007, "output_cost": 0}, 184 | "cohere.embed-english-v3": {"input_cost": 0.0001, "output_cost": 0}, 185 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 186 | }, 187 | "image": { 188 | "amazon.titan-image-generator-v1": { 189 | "512x512": { 190 | "standard": 0.01, 191 | "premium": 0.012 192 | }, 193 | "larger": { 194 | "standard": 0.012, 195 | "premium": 0.014 196 | } 197 | } 198 | } 199 | }, 200 | "ap-southeast-1": { 201 | "text": {}, 202 | "embeddings": { 203 | "cohere.embed-english-v3": {"input_cost": 0.0001, "output_cost": 0}, 204 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 205 | }, 206 | "image": {} 207 | }, 208 | "ap-southeast-2": { 209 | "text": { 210 | "amazon.titan-text-lite-v1": {"input_cost": 0.0002, "output_cost": 0.00025}, 211 | "amazon.titan-text-express-v1": {"input_cost": 0.00025, "output_cost": 0.000788}, 212 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 213 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 214 | "mistral.mistral-7b-instruct-v0:2": {"input_cost": 0.0002, "output_cost": 0.00026}, 215 | "mistral.mixtral-8x7b-instruct-v0:1": {"input_cost": 0.00059, "output_cost": 0.00091}, 216 | "mistral.mistral-large-2402-v1:0": {"input_cost": 0.0104, "output_cost": 0.0312} 217 | }, 218 | "embeddings": { 219 | "amazon.titan-embed-image-v1": {"input_cost": 0.001, "output_cost": 0}, 220 | "amazon.titan-embed-image-v1-image": {"input_cost": 0.00008, "output_cost": 0}, 221 | "cohere.embed-english-v3": {"input_cost": 0.0001, "output_cost": 0}, 222 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 223 | }, 224 | "image": {} 225 | }, 226 | "ap-northeast-1": { 227 | "text": { 228 | "amazon.titan-text-express-v1": {"input_cost": 0.000275, "output_cost": 0.000825}, 229 | "anthropic.claude-instant-v1": {"input_cost": 0.0008, "output_cost": 0.0024}, 230 | "anthropic.claude-v2:1": {"input_cost": 0.008, "output_cost": 0.024}, 231 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 232 | "anthropic.claude-3-5-sonnet-20240620-v1:0": {"input_cost": 0.003, "output_cost": 0.015} 233 | }, 234 | "embeddings": { 235 | "amazon.titan-embed-text-v1": {"input_cost": 0.0002, "output_cost": 0}, 236 | "cohere.embed-english-v3": {"input_cost": 0.0001, "output_cost": 0}, 237 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 238 | }, 239 | "image": {} 240 | }, 241 | "ca-central-1": { 242 | "text": { 243 | "amazon.titan-text-lite-v1": {"input_cost": 0.0002, "output_cost": 0.0003}, 244 | "amazon.titan-text-express-v1": {"input_cost": 0.0003, "output_cost": 0.0008}, 245 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 246 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 247 | "meta.llama3-8b-instruct-v1:0": {"input_cost": 0.00035, "output_cost": 0.00069}, 248 | "meta.llama3-70b-instruct-v1:0": {"input_cost": 0.00305, "output_cost": 0.00403}, 249 | "mistral.mistral-7b-instruct-v0:2": {"input_cost": 0.00017, "output_cost": 0.00023}, 250 | "mistral.mixtral-8x7b-instruct-v0:1": {"input_cost": 0.00052, "output_cost": 0.00081}, 251 | "mistral.mistral-large-2402-v1:0": {"input_cost": 0.0046, "output_cost": 0.0138} 252 | }, 253 | "embeddings": { 254 | "amazon.titan-embed-text-v2:0": {"input_cost": 0.0001, "output_cost": 0}, 255 | "amazon.titan-embed-image-v1": {"input_cost": 0.0009, "output_cost": 
0}, 256 | "amazon.titan-embed-image-v1-image": {"input_cost": 0.0001, "output_cost": 0}, 257 | "cohere.embed-english-v3": {"input_cost": 0.0001, "output_cost": 0}, 258 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 259 | }, 260 | "image": {} 261 | }, 262 | "eu-central-1": { 263 | "text": { 264 | "amazon.titan-text-lite-v1": {"input_cost": 0.0002, "output_cost": 0.0003}, 265 | "amazon.titan-text-express-v1": {"input_cost": 0.0003, "output_cost": 0.000863}, 266 | "anthropic.claude-instant-v1": {"input_cost": 0.0008, "output_cost": 0.0024}, 267 | "anthropic.claude-v2": {"input_cost": 0.00800, "output_cost": 0.02400}, 268 | "anthropic.claude-v2:1": {"input_cost": 0.008, "output_cost": 0.024}, 269 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 270 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 271 | "anthropic.claude-3-5-sonnet-20240620-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 272 | "meta.llama3-2-1b-instruct-v1:0": {"input_cost": 0.0001, "output_cost": 0.0001}, 273 | "meta.llama3-2-3b-instruct-v1:0": {"input_cost": 0.00015, "output_cost": 0.00015} 274 | }, 275 | "embeddings": { 276 | "amazon.titan-embed-text-v1": {"input_cost": 0.0002, "output_cost": 0}, 277 | "amazon.titan-embed-text-v2:0": {"input_cost": 0.0001, "output_cost": 0}, 278 | "amazon.titan-embed-image-v1": {"input_cost": 0.001, "output_cost": 0}, 279 | "amazon.titan-embed-image-v1-image": {"input_cost": 0.0001, "output_cost": 0}, 280 | "cohere.embed-english-v3": {"input_cost": 0.0001, "output_cost": 0}, 281 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 282 | }, 283 | "image": {} 284 | }, 285 | "eu-west-1": { 286 | "text": { 287 | "amazon.titan-text-lite-v1": {"input_cost": 0.0003, "output_cost": 0.0004}, 288 | "amazon.titan-text-express-v1": {"input_cost": 0.001, "output_cost": 0.0017}, 289 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 290 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 291 | "anthropic.claude-3-5-sonnet-20240620-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 292 | "meta.llama3-8b-instruct-v1:0": {"input_cost": 0.00043, "output_cost": 0.00065}, 293 | "meta.llama3-70b-instruct-v1:0": {"input_cost": 0.00286, "output_cost": 0.00378}, 294 | "meta.llama3-2-1b-instruct-v1:0": {"input_cost": 0.0001, "output_cost": 0.0001}, 295 | "meta.llama3-2-3b-instruct-v1:0": {"input_cost": 0.00015, "output_cost": 0.00015}, 296 | "mistral.mistral-7b-instruct-v0:2": {"input_cost": 0.00016, "output_cost": 0.00022}, 297 | "mistral.mixtral-8x7b-instruct-v0:1": {"input_cost": 0.00049, "output_cost": 0.00076}, 298 | "mistral.mistral-large-2402-v1:0": {"input_cost": 0.0086, "output_cost": 0.0259} 299 | }, 300 | "embeddings": { 301 | "amazon.titan-embed-text-v1": {"input_cost": 0.0001, "output_cost": 0}, 302 | "amazon.titan-embed-text-v2:0": {"input_cost": 0.0001, "output_cost": 0}, 303 | "amazon.titan-embed-image-v1": {"input_cost": 0.001, "output_cost": 0}, 304 | "amazon.titan-embed-image-v1-image": {"input_cost": 0.00007, "output_cost": 0} 305 | }, 306 | "image": { 307 | "amazon.titan-image-generator-v1": { 308 | "512x512": { 309 | "standard": 0.009, 310 | "premium": 0.0112 311 | }, 312 | "larger": { 313 | "standard": 0.011, 314 | "premium": 0.013 315 | } 316 | } 317 | } 318 | }, 319 | "eu-west-2": { 320 | "text": { 321 | "amazon.titan-text-lite-v1": {"input_cost": 0.0002, "output_cost": 
0.0003}, 322 | "amazon.titan-text-express-v1": {"input_cost": 0.0003, "output_cost": 0.0009}, 323 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 324 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 325 | "meta.llama3-8b-instruct-v1:0": {"input_cost": 0.0003, "output_cost": 0.0006}, 326 | "meta.llama3-70b-instruct-v1:0": {"input_cost": 0.00265, "output_cost": 0.0035}, 327 | "mistral.mistral-7b-instruct-v0:2": {"input_cost": 0.00015, "output_cost": 0.0002}, 328 | "mistral.mixtral-8x7b-instruct-v0:1": {"input_cost": 0.00045, "output_cost": 0.0007}, 329 | "mistral.mistral-small-2402-v1:0": {"input_cost": 0.001, "output_cost": 0.003}, 330 | "mistral.mistral-large-2402-v1:0": {"input_cost": 0.004, "output_cost": 0.012} 331 | }, 332 | "embeddings": { 333 | "amazon.titan-embed-text-v2:0": {"input_cost": 0.0001, "output_cost": 0}, 334 | "amazon.titan-embed-image-v1": {"input_cost": 0.001, "output_cost": 0}, 335 | "amazon.titan-embed-image-v1-image": {"input_cost": 0.0001, "output_cost": 0} 336 | }, 337 | "image": { 338 | "amazon.titan-image-generator-v1": { 339 | "512x512": { 340 | "standard": 0.009, 341 | "premium": 0.0112 342 | }, 343 | "larger": { 344 | "standard": 0.011, 345 | "premium": 0.013 346 | } 347 | } 348 | } 349 | }, 350 | "eu-west-3": { 351 | "text": { 352 | "amazon.titan-text-lite-v1": {"input_cost": 0.0002, "output_cost": 0.00025}, 353 | "amazon.titan-text-express-v1": {"input_cost": 0.00025, "output_cost": 0.000788}, 354 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 355 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 356 | "meta.llama3-2-1b-instruct-v1:0": {"input_cost": 0.0001, "output_cost": 0.0001}, 357 | "meta.llama3-2-3b-instruct-v1:0": {"input_cost": 0.00015, "output_cost": 0.00015}, 358 | "mistral.mistral-7b-instruct-v0:2": {"input_cost": 0.0002, "output_cost": 0.00026}, 359 | "mistral.mixtral-8x7b-instruct-v0:1": {"input_cost": 0.00059, "output_cost": 0.00091}, 360 | "mistral.mistral-large-2402-v1:0": {"input_cost": 0.0104, "output_cost": 0.0312} 361 | }, 362 | "embeddings": { 363 | "amazon.titan-embed-image-v1": {"input_cost": 0.001, "output_cost": 0}, 364 | "amazon.titan-embed-image-v1-image": {"input_cost": 0.00008, "output_cost": 0}, 365 | "cohere.embed-english-v3": {"input_cost": 0.0001, "output_cost": 0}, 366 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 367 | }, 368 | "image": {} 369 | }, 370 | "sa-east-1": { 371 | "text": { 372 | "amazon.titan-text-lite-v1": {"input_cost": 0.0003, "output_cost": 0.0004}, 373 | "amazon.titan-text-express-v1": {"input_cost": 0.0004, "output_cost": 0.0011}, 374 | "anthropic.claude-3-sonnet-20240229-v1:0": {"input_cost": 0.003, "output_cost": 0.015}, 375 | "anthropic.claude-3-haiku-20240307-v1:0": {"input_cost": 0.00025, "output_cost": 0.00125}, 376 | "mistral.mistral-7b-instruct-v0:2": {"input_cost": 0.00025, "output_cost": 0.00034}, 377 | "mistral.mixtral-8x7b-instruct-v0:1": {"input_cost": 0.00076, "output_cost": 0.00118}, 378 | "mistral.mistral-large-2402-v1:0": {"input_cost": 0.0104, "output_cost": 0.0312} 379 | }, 380 | "embeddings": { 381 | "amazon.titan-embed-text-v2:0": {"input_cost": 0.0002, "output_cost": 0}, 382 | "amazon.titan-embed-image-v1": {"input_cost": 0.0012, "output_cost": 0}, 383 | "amazon.titan-embed-image-v1-image": {"input_cost": 0.0001, "output_cost": 0}, 384 | "cohere.embed-english-v3": {"input_cost": 0.0001, 
"output_cost": 0}, 385 | "cohere.embed-multilingual-v3": {"input_cost": 0.0001, "output_cost": 0} 386 | }, 387 | "image": {} 388 | } 389 | } -------------------------------------------------------------------------------- /lambdas/cost_tracking/utils.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import datetime 3 | import json 4 | import logging 5 | import pandas as pd 6 | import pytz 7 | import time 8 | import traceback 9 | 10 | logger = logging.getLogger(__name__) 11 | if len(logging.getLogger().handlers) > 0: 12 | logging.getLogger().setLevel(logging.INFO) 13 | else: 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | def _is_in_model_list(model_id, model_list): 17 | if model_id in model_list: 18 | return True 19 | else: 20 | parts = model_id.split('.') 21 | for i in range(len(parts), 0, -1): 22 | partial_id = '.'.join(parts[-i:]) 23 | if partial_id in model_list: 24 | return True 25 | 26 | return False 27 | 28 | def _read_model_list(filename): 29 | try: 30 | with open(filename, "r", encoding="utf-8") as f: 31 | config = json.load(f) 32 | 33 | return config 34 | except Exception as e: 35 | stacktrace = traceback.format_exc() 36 | logger.error(stacktrace) 37 | 38 | raise e 39 | 40 | def get_model_pricing(model_id, MODEL_PRICES): 41 | matched = [v for k, v in MODEL_PRICES.items() if model_id in k] 42 | if matched: 43 | return matched[0] 44 | else: 45 | parts = model_id.split('.') 46 | for i in range(len(parts), 0, -1): 47 | partial_id = '.'.join(parts[-i:]) 48 | matched = [v for k, v in MODEL_PRICES.items() if partial_id in k] 49 | if matched: 50 | return matched[0] 51 | 52 | return None 53 | 54 | def merge_and_process_logs(object1, object2): 55 | # Step 1: Process object1 56 | if len(object1) > 0: 57 | final_list = [] 58 | 59 | for item in object1: 60 | if not any(d['field'] == 'api_key' for d in item): 61 | item.append({'field': 'api_key', 'value': ''}) 62 | 63 | request_id_item = next(el['value'] for el in item if el['field'] == 'request_id') 64 | found = False 65 | 66 | for item2 in object2: 67 | request_id_item_2 = next(el['value'] for el in item2 if el['field'] == 'request_id') 68 | if request_id_item == request_id_item_2: 69 | found = True 70 | final_list.append(item2) 71 | break 72 | 73 | if not found: 74 | final_list.append(item) 75 | return final_list 76 | else: 77 | return object2 78 | 79 | # Convert the merged dictionary back to the original format 80 | result = [[{'field': k, 'value': v} for k, v in item.items()] for item in merged.values()] 81 | return result 82 | 83 | def run_query(query, log_group_name, date=None): 84 | cloudwatch = boto3.client("logs") 85 | 86 | max_retries = 5 87 | 88 | if date is None: 89 | date = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=1) 90 | else: 91 | date = datetime.datetime.strptime(date, "%Y-%m-%d") 92 | 93 | start = date.replace(hour=0, minute=0, second=0, microsecond=0) 94 | end = date.replace(hour=23, minute=59, second=59, microsecond=0) 95 | 96 | response = cloudwatch.start_query( 97 | logGroupName=log_group_name, 98 | startTime=int(start.timestamp() * 1000), 99 | endTime=int(end.timestamp() * 1000), 100 | queryString=query, 101 | ) 102 | 103 | query_id = response["queryId"] 104 | 105 | retry_count = 0 106 | 107 | while True: 108 | response = cloudwatch.get_query_results(queryId=query_id) 109 | 110 | if response["results"] or retry_count == max_retries: 111 | break 112 | 113 | time.sleep(2) 114 | retry_count += 1 115 | 116 | return response["results"] 117 | 118 | 
def model_price_embeddings(model_list, row): 119 | input_token_count = float(row["input_tokens"]) if "input_tokens" in row else 0.0 120 | output_token_count = float(row["output_tokens"]) if "output_tokens" in row else 0.0 121 | 122 | model_id = row["model_id"] 123 | 124 | # get model pricing for each region 125 | model_pricing = get_model_pricing(model_id, model_list) 126 | 127 | # calculate costs of prompt and completion 128 | input_cost = input_token_count * model_pricing["input_cost"] / 1000 129 | output_cost = output_token_count * model_pricing["output_cost"] / 1000 130 | 131 | return input_token_count, output_token_count, input_cost, output_cost 132 | 133 | def model_price_image(model_list, row): 134 | height = float(row["height"]) if "height" in row else 0.0 135 | width = float(row["width"]) if "width" in row else 0.0 136 | steps = float(row["steps"]) if "steps" in row else 0.0 137 | 138 | model_id = row["model_id"] 139 | 140 | # get model pricing from utils 141 | model_pricing = get_model_pricing(model_id, model_list) 142 | 143 | if width <= 512 and height <= 512: 144 | size = "512x512" 145 | else: 146 | size = "larger" 147 | 148 | model_pricing = get_model_pricing(size, model_pricing) 149 | 150 | if steps > 50: 151 | price = model_pricing["premium"] 152 | else: 153 | price = model_pricing["standard"] 154 | 155 | return 0.0, 0.0, 0.0, price 156 | 157 | def model_price_text(model_list, row): 158 | input_token_count = float(row["input_tokens"]) if "input_tokens" in row else 0.0 159 | output_token_count = float(row["output_tokens"]) if "output_tokens" in row else 0.0 160 | 161 | model_id = row["model_id"] 162 | 163 | # get model pricing for each region 164 | model_pricing = get_model_pricing(model_id, model_list) 165 | 166 | # calculate costs of prompt and completion 167 | input_cost = input_token_count * model_pricing["input_cost"] / 1000 168 | output_cost = output_token_count * model_pricing["output_cost"] / 1000 169 | 170 | return input_token_count, output_token_count, input_cost, output_cost 171 | 172 | def results_to_df(results): 173 | column_names = set() 174 | rows = [] 175 | 176 | for result in results: 177 | row = { 178 | item["field"]: item["value"] 179 | for item in result 180 | if "@ptr" not in item["field"] 181 | } 182 | column_names.update(row.keys()) 183 | rows.append(row) 184 | 185 | df = pd.DataFrame(rows, columns=list(column_names)) 186 | 187 | return df 188 | 189 | def calculate_cost(row, error_buffer): 190 | try: 191 | model_id = row["model_id"] 192 | 193 | models = _read_model_list("./models.json") 194 | 195 | region = row["region"] if "region" in row else "us-east-1" 196 | 197 | model_list = models[region] 198 | 199 | if _is_in_model_list(model_id, list(model_list["text"].keys())): 200 | input_token_count, output_token_count, input_cost, output_cost = model_price_text(model_list["text"], row) 201 | elif _is_in_model_list(model_id, list(model_list["embeddings"].keys())): 202 | input_token_count, output_token_count, input_cost, output_cost = model_price_embeddings(model_list["embeddings"], row) 203 | elif _is_in_model_list(model_id, list(model_list["image"].keys())): 204 | input_token_count, output_token_count, input_cost, output_cost = model_price_image(model_list["image"], row) 205 | else: 206 | raise Exception(f"Unknown model: {model_id}") 207 | 208 | return input_token_count, output_token_count, input_cost, output_cost, 1 209 | except Exception as e: 210 | stacktrace = traceback.format_exc() 211 | 212 | logger.error(f"Error processing row: {row}\n{stacktrace}") 
213 | 214 | error_buffer.write(f"Row:\n{row}\n\n Stacktrace:\n{stacktrace}\n\n") 215 | 216 | return None, None, None, None, None 217 | -------------------------------------------------------------------------------- /lambdas/invoke_model/BedrockInference.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | from langchain_community.llms.bedrock import LLMInputOutputAdapter 4 | import logging 5 | import math 6 | import traceback 7 | 8 | logger = logging.getLogger(__name__) 9 | if len(logging.getLogger().handlers) > 0: 10 | logging.getLogger().setLevel(logging.INFO) 11 | else: 12 | logging.basicConfig(level=logging.INFO) 13 | 14 | """ 15 | Return an approximation of tokens in a string 16 | Args: 17 | string (str): Input string 18 | 19 | Returns: 20 | int: Number of approximated tokens 21 | """ 22 | def get_tokens(string): 23 | logger.info("Counting approximation tokens") 24 | 25 | return math.floor(len(string) / 4) 26 | 27 | """ 28 | This class handles inference requests for Bedrock models. 29 | """ 30 | class BedrockInference: 31 | """ 32 | Initialize the BedrockInference instance. 33 | 34 | Args: 35 | bedrock_client (boto3.client): The Bedrock client instance. 36 | model_id (str): The ID of the model to use. 37 | model_arn (str, optional): The ARN of the model to use. Defaults to None. 38 | messages_api (str, optional): Whether to use the messages API. Defaults to "false". 39 | """ 40 | def __init__(self, bedrock_client, model_id, model_arn=None, messages_api="false"): 41 | self.bedrock_client = bedrock_client 42 | self.model_id = model_id 43 | self.model_arn = model_arn 44 | self.messages_api = messages_api 45 | self.input_tokens = 0 46 | self.output_tokens = 0 47 | 48 | """ 49 | Decode base64-encoded documents in the input messages. 50 | 51 | Args: 52 | messages (list): A list of message dictionaries. 53 | 54 | Returns: 55 | list: The updated list of message dictionaries with decoded images. 56 | """ 57 | def _decode_documents(self, messages): 58 | for item in messages: 59 | if 'content' in item: 60 | for content_item in item['content']: 61 | if 'document' in content_item and 'bytes' in content_item['document']['source']: 62 | encoded_document = content_item['document']['source']['bytes'] 63 | base64_bytes = encoded_document.encode('utf-8') 64 | document_bytes = base64.b64decode(base64_bytes) 65 | content_item['document']['source']['bytes'] = document_bytes 66 | return messages 67 | 68 | """ 69 | Decode base64-encoded images in the input messages. 70 | 71 | Args: 72 | messages (list): A list of message dictionaries. 73 | 74 | Returns: 75 | list: The updated list of message dictionaries with decoded images. 76 | """ 77 | def _decode_images(self, messages): 78 | for item in messages: 79 | if 'content' in item: 80 | for content_item in item['content']: 81 | if 'image' in content_item and 'bytes' in content_item['image']['source']: 82 | encoded_image = content_item['image']['source']['bytes'] 83 | base64_bytes = encoded_image.encode('utf-8') 84 | image_bytes = base64.b64decode(base64_bytes) 85 | content_item['image']['source']['bytes'] = image_bytes 86 | return messages 87 | 88 | """ 89 | Get the number of input tokens. 90 | 91 | Returns: 92 | int: The number of input tokens. 93 | """ 94 | def get_input_tokens(self): 95 | return self.input_tokens 96 | 97 | """ 98 | Get the number of output tokens. 99 | 100 | Returns: 101 | int: The number of output tokens. 
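    Note: when the Converse messages API is used, these counters come from the
    service-reported usage metadata; otherwise they fall back to the
    4-characters-per-token approximation computed by get_tokens above.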
102 | """ 103 | def get_output_tokens(self): 104 | return self.output_tokens 105 | 106 | """ 107 | Invoke the Bedrock model to generate embeddings for text inputs. 108 | 109 | Args: 110 | body (dict): The request body containing the input text. 111 | model_kwargs (dict): Additional model parameters. 112 | 113 | Returns: 114 | list: A list of embeddings for the input text. 115 | 116 | Raises: 117 | Exception: If an error occurs during the inference process. 118 | """ 119 | def invoke_embeddings(self, body, model_kwargs): 120 | try: 121 | provider = self.model_id.split(".")[0] 122 | 123 | if provider == "cohere": 124 | if "input_type" not in model_kwargs.keys(): 125 | model_kwargs["input_type"] = "search_document" 126 | if isinstance(body["inputs"], str): 127 | body["inputs"] = [body["inputs"]] 128 | 129 | request_body = {**model_kwargs, "texts": body["inputs"]} 130 | else: 131 | request_body = {**model_kwargs, "inputText": body["inputs"]} 132 | 133 | request_body = json.dumps(request_body) 134 | 135 | modelId = self.model_arn if self.model_arn is not None else self.model_id 136 | 137 | response = self.bedrock_client.invoke_model( 138 | body=request_body, 139 | modelId=modelId, 140 | accept="application/json", 141 | contentType="application/json", 142 | ) 143 | 144 | response_body = json.loads(response.get("body").read()) 145 | 146 | if provider == "cohere": 147 | response = response_body.get("embeddings")[0] 148 | else: 149 | response = response_body.get("embedding") 150 | 151 | return response 152 | except Exception as e: 153 | stacktrace = traceback.format_exc() 154 | 155 | logger.error(stacktrace) 156 | 157 | raise e 158 | 159 | """ 160 | Invoke the Bedrock model to generate embeddings for image inputs. 161 | 162 | Args: 163 | body (dict): The request body containing the input image. 164 | model_kwargs (dict): Additional model parameters. 165 | 166 | Returns: 167 | list: A list of embeddings for the input image. 168 | 169 | Raises: 170 | Exception: If an error occurs during the inference process. 171 | """ 172 | def invoke_embeddings_image(self, body, model_kwargs): 173 | try: 174 | provider = self.model_id.split(".")[0] 175 | 176 | request_body = {**model_kwargs, "inputImage": body["inputs"]} 177 | 178 | request_body = json.dumps(request_body) 179 | 180 | modelId = self.model_arn if self.model_arn is not None else self.model_id 181 | 182 | response = self.bedrock_client.invoke_model( 183 | body=request_body, 184 | modelId=modelId, 185 | accept="application/json", 186 | contentType="application/json", 187 | ) 188 | 189 | response_body = json.loads(response.get("body").read()) 190 | 191 | if provider == "cohere": 192 | response = response_body.get("embeddings")[0] 193 | else: 194 | response = response_body.get("embedding") 195 | 196 | return response 197 | except Exception as e: 198 | stacktrace = traceback.format_exc() 199 | 200 | logger.error(stacktrace) 201 | 202 | raise e 203 | 204 | """ 205 | Invoke the Bedrock model to generate images from text prompts. 206 | 207 | Args: 208 | body (dict): The request body containing the text prompts. 209 | model_kwargs (dict): Additional model parameters. 210 | 211 | Returns: 212 | dict: A dictionary containing the generated images and their dimensions. 213 | int: The height of the generated images. 214 | int: The width of the generated images. 215 | int: The number of steps used to generate the images. 216 | 217 | Raises: 218 | Exception: If an error occurs during the inference process. 
219 | """ 220 | def invoke_image(self, body, model_kwargs): 221 | try: 222 | provider = self.model_id.split(".")[0] 223 | 224 | if provider == "stability": 225 | request_body = {**model_kwargs, "text_prompts": body["text_prompts"]} 226 | 227 | height = model_kwargs["height"] if "height" in model_kwargs else 512 228 | width = model_kwargs["width"] if "width" in model_kwargs else 512 229 | steps = model_kwargs["steps"] if "steps" in model_kwargs else 50 230 | else: 231 | request_body = {**model_kwargs, "textToImageParams": body["textToImageParams"]} 232 | 233 | height = model_kwargs["imageGenerationConfig"]["height"] if "height" in model_kwargs[ 234 | "imageGenerationConfig"] else 512 235 | width = model_kwargs["imageGenerationConfig"]["width"] if "width" in model_kwargs[ 236 | "imageGenerationConfig"] else 512 237 | 238 | if "quality" in model_kwargs["imageGenerationConfig"]: 239 | if model_kwargs["imageGenerationConfig"]["quality"] == "standard": 240 | steps = 50 241 | else: 242 | steps = 51 243 | else: 244 | steps = 50 245 | 246 | request_body = json.dumps(request_body) 247 | 248 | modelId = self.model_arn if self.model_arn is not None else self.model_id 249 | 250 | response = self.bedrock_client.invoke_model( 251 | body=request_body, 252 | modelId=modelId, 253 | accept="application/json", 254 | contentType="application/json", 255 | ) 256 | 257 | response_body = json.loads(response.get("body").read()) 258 | 259 | if provider == "stability": 260 | response = {"artifacts": response_body.get("artifacts")} 261 | else: 262 | response = {"images": response_body.get("images")} 263 | 264 | return response, height, width, steps 265 | except Exception as e: 266 | stacktrace = traceback.format_exc() 267 | 268 | logger.error(stacktrace) 269 | 270 | raise e 271 | 272 | """ 273 | Invoke the Bedrock model to generate text from prompts. 274 | 275 | Args: 276 | body (dict): The request body containing the input prompts. 277 | model_kwargs (dict, optional): Additional model parameters. Defaults to an empty dict. 278 | additional_model_fields (dict, optional): Additional model fields. Defaults to an empty dict. 279 | 280 | Returns: 281 | str: The generated text. 282 | 283 | Raises: 284 | Exception: If an error occurs during the inference process. 
285 | """ 286 | def invoke_text(self, body, model_kwargs: dict = dict(), additional_model_fields: dict = dict(), tool_config: dict = dict()): 287 | try: 288 | provider = self.model_id.split(".")[0] 289 | is_messages_api = self.messages_api.lower() in ["true"] 290 | 291 | if is_messages_api: 292 | system = [{"text": model_kwargs["system"]}] if "system" in model_kwargs else list() 293 | 294 | if "system" in model_kwargs: 295 | del model_kwargs["system"] 296 | 297 | messages = self._decode_documents(body["inputs"]) 298 | messages = self._decode_images(messages) 299 | 300 | if bool(tool_config): 301 | logger.info(f"Using tools {tool_config}") 302 | 303 | response = self.bedrock_client.converse( 304 | modelId=self.model_id, 305 | messages=messages, 306 | system=system, 307 | inferenceConfig=model_kwargs, 308 | additionalModelRequestFields=additional_model_fields, 309 | toolConfig=tool_config 310 | ) 311 | else: 312 | response = self.bedrock_client.converse( 313 | modelId=self.model_id, 314 | messages=messages, 315 | system=system, 316 | inferenceConfig=model_kwargs, 317 | additionalModelRequestFields=additional_model_fields 318 | ) 319 | 320 | output_message = response['output']['message'] 321 | 322 | self.input_tokens = response['usage']['inputTokens'] 323 | self.output_tokens = response['usage']['outputTokens'] 324 | 325 | tmp_response = "" 326 | tmp_tools = [] 327 | 328 | for content in output_message['content']: 329 | if "text" in content: 330 | tmp_response += content['text'] + " " 331 | if "toolUse" in content: 332 | tmp_tools.append({"toolUse": content["toolUse"]}) 333 | 334 | if len(tmp_tools) > 0: 335 | response = tmp_tools 336 | else: 337 | response = tmp_response.rstrip() 338 | else: 339 | request_body = LLMInputOutputAdapter.prepare_input( 340 | provider=provider, 341 | prompt=body["inputs"], 342 | model_kwargs=model_kwargs 343 | ) 344 | 345 | request_body = json.dumps(request_body) 346 | model_id = self.model_arn or self.model_id 347 | 348 | response = self.bedrock_client.invoke_model( 349 | body=request_body, 350 | modelId=model_id, 351 | accept="application/json", 352 | contentType="application/json" 353 | ) 354 | 355 | response = LLMInputOutputAdapter.prepare_output(provider, response) 356 | response = response["text"] 357 | 358 | self.input_tokens = get_tokens(body["inputs"]) 359 | self.output_tokens = get_tokens(response) 360 | 361 | return response 362 | except Exception as e: 363 | stacktrace = traceback.format_exc() 364 | logger.error(stacktrace) 365 | 366 | raise e -------------------------------------------------------------------------------- /lambdas/invoke_model/SageMakerInference.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import traceback 5 | 6 | logger = logging.getLogger(__name__) 7 | if len(logging.getLogger().handlers) > 0: 8 | logging.getLogger().setLevel(logging.INFO) 9 | else: 10 | logging.basicConfig(level=logging.INFO) 11 | 12 | """ 13 | Return an approximation of tokens in a string 14 | Args: 15 | string (str): Input string 16 | 17 | Returns: 18 | int: Number of approximated tokens 19 | """ 20 | def get_tokens(string): 21 | logger.info("Counting approximation tokens") 22 | 23 | return math.floor(len(string) / 4) 24 | 25 | """ 26 | This class handles inference requests for SageMaker models. 27 | """ 28 | class SageMakerInference: 29 | """ 30 | Initialize the SageMakerInference instance. 
31 | 32 | Args: 33 | sagemaker_client (boto3.client): The SageMaker client instance. 34 | endpoint_name (str): The name of the SageMaker endpoint. 35 | messages_api (str, optional): Whether to use the messages API. Defaults to "false". 36 | """ 37 | def __init__(self, sagemaker_client, endpoint_name, messages_api="false"): 38 | self.sagemaker_client = sagemaker_client 39 | self.endpoint_name = endpoint_name 40 | self.messages_api = messages_api 41 | self.input_tokens = 0 42 | self.output_tokens = 0 43 | 44 | """ 45 | Get the number of input tokens. 46 | 47 | Returns: 48 | int: The number of input tokens. 49 | """ 50 | def get_input_tokens(self): 51 | return self.input_tokens 52 | 53 | """ 54 | Get the number of output tokens. 55 | 56 | Returns: 57 | int: The number of output tokens. 58 | """ 59 | def get_output_tokens(self): 60 | return self.output_tokens 61 | 62 | """ 63 | Invoke the SageMaker model to generate embeddings for text inputs. 64 | 65 | Args: 66 | body (dict): The request body containing the input text. 67 | model_kwargs (dict): Additional model parameters. 68 | 69 | Returns: 70 | list: A list of embeddings for the input text. 71 | 72 | Raises: 73 | Exception: If an error occurs during the inference process. 74 | """ 75 | def invoke_embeddings(self, body, model_kwargs): 76 | try: 77 | if "InferenceComponentName" in model_kwargs: 78 | inference_component = model_kwargs.pop("InferenceComponentName") 79 | else: 80 | inference_component = None 81 | 82 | if isinstance(body["inputs"], dict): 83 | # If body["inputs"] is a dictionary, merge it with model_kwargs 84 | request_data = {**body["inputs"], **model_kwargs} 85 | else: 86 | # If body["inputs"] is not a dictionary, use the original format 87 | request_data = { 88 | "inputs": body["inputs"], 89 | "parameters": model_kwargs 90 | } 91 | 92 | request_body = json.dumps(request_data) 93 | 94 | if inference_component: 95 | response = self.sagemaker_client.invoke_endpoint( 96 | EndpointName=self.endpoint_name, 97 | ContentType="application/json", 98 | Body=request_body, 99 | InferenceComponentName=inference_component 100 | ) 101 | else: 102 | response = self.sagemaker_client.invoke_endpoint( 103 | EndpointName=self.endpoint_name, 104 | ContentType="application/json", 105 | Body=request_body 106 | ) 107 | 108 | response = json.loads(response['Body'].read().decode()) 109 | 110 | self.input_tokens = get_tokens(body["inputs"][list(body["inputs"].keys())[0]]) 111 | 112 | return response["embedding"] 113 | except Exception as e: 114 | stacktrace = traceback.format_exc() 115 | 116 | logger.error(stacktrace) 117 | 118 | raise e 119 | 120 | """ 121 | Invoke the SageMaker model to generate text from prompts. 122 | 123 | Args: 124 | body (dict): The request body containing the input prompts. 125 | model_kwargs (dict): Additional model parameters. 126 | 127 | Returns: 128 | str: The generated text. 129 | 130 | Raises: 131 | Exception: If an error occurs during the inference process. 
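    Example (a sketch; the endpoint name and TGI-style parameters are assumptions):
        inference = SageMakerInference(sagemaker_client, "my-llm-endpoint")
        text = inference.invoke_text(
            {"inputs": "Write a haiku about the sea."},
            {"max_new_tokens": 128, "temperature": 0.7},
        )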
132 | """ 133 | def invoke_text(self, body, model_kwargs): 134 | try: 135 | if "InferenceComponentName" in model_kwargs: 136 | inference_component = model_kwargs.pop("InferenceComponentName") 137 | else: 138 | inference_component = None 139 | 140 | is_messages_api = self.messages_api.lower() in ["true"] 141 | 142 | if is_messages_api: 143 | request_data = {"messages": body["inputs"], **model_kwargs} 144 | else: 145 | if isinstance(body["inputs"], dict): 146 | # If body["inputs"] is a dictionary, merge it with model_kwargs 147 | request_data = {**body["inputs"], **model_kwargs} 148 | else: 149 | # If body["inputs"] is not a dictionary, use the original format 150 | request_data = { 151 | "inputs": body["inputs"], 152 | "parameters": model_kwargs 153 | } 154 | 155 | request_body = json.dumps(request_data) 156 | 157 | if inference_component: 158 | response = self.sagemaker_client.invoke_endpoint( 159 | EndpointName=self.endpoint_name, 160 | ContentType="application/json", 161 | Body=request_body, 162 | InferenceComponentName=inference_component 163 | ) 164 | else: 165 | response = self.sagemaker_client.invoke_endpoint( 166 | EndpointName=self.endpoint_name, 167 | ContentType="application/json", 168 | Body=request_body 169 | ) 170 | 171 | response = json.loads(response['Body'].read().decode()) 172 | 173 | if is_messages_api: 174 | response = response["choices"][0]["message"]["content"].strip() 175 | 176 | self.input_tokens = get_tokens(body["inputs"]) 177 | self.output_tokens = get_tokens(response) 178 | 179 | return response 180 | else: 181 | self.input_tokens = get_tokens(body["inputs"]) 182 | self.output_tokens = get_tokens(response[0]["generated_text"]) 183 | 184 | return response[0]["generated_text"] 185 | except Exception as e: 186 | stacktrace = traceback.format_exc() 187 | 188 | logger.error(stacktrace) 189 | 190 | raise e -------------------------------------------------------------------------------- /lambdas/invoke_model/index.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from aws_lambda_powertools import Logger 3 | from BedrockInference import BedrockInference, get_tokens 4 | import boto3 5 | from botocore.config import Config 6 | import datetime 7 | import json 8 | import logging 9 | import os 10 | from SageMakerInference import SageMakerInference 11 | import time 12 | import traceback 13 | 14 | logger = logging.getLogger(__name__) 15 | if len(logging.getLogger().handlers) > 0: 16 | logging.getLogger().setLevel(logging.INFO) 17 | else: 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | cloudwatch_logger = Logger() 21 | 22 | lambda_client = boto3.client('lambda') 23 | dynamodb = boto3.resource('dynamodb') 24 | s3_client = boto3.client('s3') 25 | 26 | bedrock_region = os.environ.get("BEDROCK_REGION", "us-east-1") 27 | bedrock_url = os.environ.get("BEDROCK_URL", None) 28 | iam_role = os.environ.get("IAM_ROLE", None) 29 | lambda_streaming = os.environ.get("LAMBDA_STREAMING", None) 30 | logs_table_name = os.environ.get("LOGS_TABLE_NAME", None) 31 | streaming_table_name = os.environ.get("STREAMING_TABLE_NAME", None) 32 | s3_bucket = os.environ.get("S3_BUCKET", None) 33 | sagemaker_endpoints = os.environ.get("SAGEMAKER_ENDPOINTS", "") # If FMs are exposed through SageMaker 34 | sagemaker_region = os.environ.get("SAGEMAKER_REGION", "us-east-1") # If FMs are exposed through SageMaker 35 | sagemaker_url = os.environ.get("SAGEMAKER_URL", None) # If FMs are exposed through SageMaker 36 | 37 | def _get_bedrock_client(): 38 | try: 39 
| logger.info(f"Create new client\n Using region: {bedrock_region}") 40 | session_kwargs = {"region_name": bedrock_region} 41 | client_kwargs = {**session_kwargs} 42 | 43 | retry_config = Config( 44 | region_name=bedrock_region, 45 | retries={ 46 | "max_attempts": 10, 47 | "mode": "standard", 48 | }, 49 | ) 50 | session = boto3.Session(**session_kwargs) 51 | 52 | if iam_role is not None: 53 | logger.info(f"Using role: {iam_role}") 54 | sts = session.client("sts") 55 | 56 | response = sts.assume_role( 57 | RoleArn=str(iam_role), # 58 | RoleSessionName="amazon-bedrock-assume-role" 59 | ) 60 | 61 | client_kwargs = dict( 62 | aws_access_key_id=response['Credentials']['AccessKeyId'], 63 | aws_secret_access_key=response['Credentials']['SecretAccessKey'], 64 | aws_session_token=response['Credentials']['SessionToken'] 65 | ) 66 | 67 | if bedrock_url: 68 | client_kwargs["endpoint_url"] = bedrock_url 69 | 70 | bedrock_client = session.client( 71 | service_name="bedrock-runtime", 72 | config=retry_config, 73 | **client_kwargs 74 | ) 75 | 76 | logger.info("boto3 Bedrock client successfully created!") 77 | logger.info(bedrock_client._endpoint) 78 | return bedrock_client 79 | 80 | except Exception as e: 81 | stacktrace = traceback.format_exc() 82 | logger.error(stacktrace) 83 | 84 | raise e 85 | 86 | def _get_sagemaker_client(): 87 | try: 88 | logger.info(f"Create new client\n Using region: {sagemaker_region}") 89 | session_kwargs = {"region_name": sagemaker_region} 90 | client_kwargs = {**session_kwargs} 91 | 92 | retry_config = Config( 93 | region_name=sagemaker_region, 94 | retries={ 95 | "max_attempts": 10, 96 | "mode": "standard", 97 | }, 98 | ) 99 | session = boto3.Session(**session_kwargs) 100 | 101 | if iam_role is not None: 102 | logger.info(f"Using role: {iam_role}") 103 | sts = session.client("sts") 104 | 105 | response = sts.assume_role( 106 | RoleArn=str(iam_role), # 107 | RoleSessionName="amazon-sagemaker-assume-role" 108 | ) 109 | 110 | client_kwargs = dict( 111 | aws_access_key_id=response['Credentials']['AccessKeyId'], 112 | aws_secret_access_key=response['Credentials']['SecretAccessKey'], 113 | aws_session_token=response['Credentials']['SessionToken'] 114 | ) 115 | 116 | if bedrock_url: 117 | client_kwargs["endpoint_url"] = sagemaker_url 118 | 119 | sagemaker_client = session.client( 120 | service_name="sagemaker-runtime", 121 | config=retry_config, 122 | **client_kwargs 123 | ) 124 | 125 | logger.info("boto3 SageMaker client successfully created!") 126 | logger.info(sagemaker_client._endpoint) 127 | return sagemaker_client 128 | 129 | except Exception as e: 130 | stacktrace = traceback.format_exc() 131 | logger.error(stacktrace) 132 | 133 | raise e 134 | 135 | """ 136 | Return the json list of enabled SageMaker Endpoints 137 | 138 | Returns: 139 | dict: json list of enabled SageMaker Endpoints 140 | """ 141 | def _read_sagemaker_endpoints(): 142 | if not sagemaker_endpoints: 143 | return {} 144 | 145 | try: 146 | endpoints = json.loads(sagemaker_endpoints) 147 | except json.JSONDecodeError: 148 | try: 149 | endpoints = ast.literal_eval(sagemaker_endpoints) 150 | except (ValueError, SyntaxError) as e: 151 | raise ValueError(f"Error: Invalid format for SAGEMAKER_ENDPOINTS: {e}") 152 | else: 153 | if not isinstance(endpoints, dict): 154 | raise ValueError("Error: SAGEMAKER_ENDPOINTS is not a dictionary") 155 | 156 | return endpoints 157 | 158 | """ 159 | Save model logs in CloudWatch and DynamoDB 160 | 161 | Args: 162 | logs (dict): logs generated by the model 163 | 164 | Raises: 165 | 
Exception: If an error occurs during the cloudwatch or dynamodb save. 166 | """ 167 | def _store_logs(logs): 168 | try: 169 | cloudwatch_logger.info(logs) 170 | 171 | current_datetime = datetime.datetime.now() 172 | formatted_datetime = current_datetime.strftime("%Y-%m-%d %H:%M:%S") 173 | 174 | logs["date"] = formatted_datetime 175 | logs["ttl"] = int(time.time()) + 86400 176 | 177 | logs_table_connection = dynamodb.Table(logs_table_name) 178 | logs_table_connection.put_item(Item=logs) 179 | except Exception as e: 180 | stacktrace = traceback.format_exc() 181 | logger.error(stacktrace) 182 | 183 | raise e 184 | 185 | """ 186 | This function handles inference requests for Bedrock models. 187 | 188 | Args: 189 | event (dict): The event object containing the request details. 190 | 191 | Returns: 192 | dict: A dictionary containing the response status code and body. 193 | """ 194 | def bedrock_handler(event): 195 | logger.info("Bedrock Endpoint") 196 | 197 | model_id = event["queryStringParameters"]["model_id"] 198 | model_arn = event["queryStringParameters"].get("model_arn") 199 | team_id = event["headers"]["team_id"] 200 | api_key = event["headers"]["x-api-key"] 201 | 202 | bedrock_client = _get_bedrock_client() 203 | custom_request_id = event["queryStringParameters"].get("requestId") 204 | messages_api = event["headers"].get("messages_api", "false") 205 | 206 | bedrock_inference = BedrockInference( 207 | bedrock_client=bedrock_client, 208 | model_id=model_id, 209 | model_arn=model_arn, 210 | messages_api=messages_api 211 | ) 212 | 213 | if custom_request_id is None: 214 | request_id = event["requestContext"]["requestId"] 215 | streaming = event["headers"].get("streaming", "false") 216 | embeddings = event["headers"].get("type", "").lower() == "embeddings" 217 | embeddings_image = event["headers"].get("type", "").lower() == "embeddings-image" 218 | image = event["headers"].get("type", "").lower() == "image" 219 | 220 | logger.info(f"Model ID: {model_id}") 221 | logger.info(f"Request ID: {request_id}") 222 | 223 | body = json.loads(event["body"]) 224 | model_kwargs = body.get("parameters", {}) 225 | additional_model_fields = body.get("additional_model_fields", {}) 226 | tool_config = body.get("tool_config", {}) 227 | 228 | if embeddings: 229 | logger.info("Request type: embeddings") 230 | response = bedrock_inference.invoke_embeddings(body, model_kwargs) 231 | results = {"statusCode": 200, "body": json.dumps([{"embedding": response}])} 232 | logs = { 233 | "team_id": team_id, 234 | "api_key": api_key, 235 | "requestId": request_id, 236 | "region": bedrock_region, 237 | "model_id": model_id, 238 | "inputTokens": get_tokens(body["inputs"]), 239 | "outputTokens": get_tokens(response), 240 | "height": None, 241 | "width": None, 242 | "steps": None 243 | } 244 | 245 | _store_logs(logs) 246 | 247 | elif embeddings_image: 248 | logger.info("Request type: embeddings-image") 249 | response = bedrock_inference.invoke_embeddings_image(body, model_kwargs) 250 | results = {"statusCode": 200, "body": json.dumps([{"embedding": response}])} 251 | logs = { 252 | "team_id": team_id, 253 | "api_key": api_key, 254 | "requestId": request_id, 255 | "region": bedrock_region, 256 | "model_id": model_id + "-image", 257 | "inputTokens": get_tokens(body["inputs"]), 258 | "outputTokens": get_tokens(response), 259 | "height": None, 260 | "width": None, 261 | "steps": None 262 | } 263 | 264 | _store_logs(logs) 265 | elif image: 266 | logger.info("Request type: image") 267 | response, height, width, steps = 
bedrock_inference.invoke_image(body, model_kwargs) 268 | results = {"statusCode": 200, "body": json.dumps([response])} 269 | logs = { 270 | "team_id": team_id, 271 | "api_key": api_key, 272 | "requestId": request_id, 273 | "region": bedrock_region, 274 | "model_id": model_id, 275 | "inputTokens": None, 276 | "outputTokens": None, 277 | "height": height, 278 | "width": width, 279 | "steps": steps 280 | } 281 | 282 | _store_logs(logs) 283 | else: 284 | logger.info("Request type: text") 285 | 286 | if streaming.lower() in ["true"] and custom_request_id is None: 287 | logger.info("Send streaming request") 288 | event["queryStringParameters"]["request_id"] = request_id 289 | s3_client.put_object( 290 | Bucket=s3_bucket, 291 | Key=f"{request_id}.json", 292 | Body=json.dumps(event).encode("utf-8") 293 | ) 294 | lambda_client.invoke( 295 | FunctionName=lambda_streaming, 296 | InvocationType="Event", 297 | Payload=json.dumps({"request_json": f"{request_id}.json"}) 298 | ) 299 | results = {"statusCode": 200, "body": json.dumps([{"request_id": request_id}])} 300 | 301 | else: 302 | response = bedrock_inference.invoke_text(body, model_kwargs, additional_model_fields, tool_config) 303 | results = {"statusCode": 200, "body": json.dumps([{"generated_text": response}])} 304 | logs = { 305 | "team_id": team_id, 306 | "api_key": api_key, 307 | "requestId": request_id, 308 | "region": bedrock_region, 309 | "model_id": model_id, 310 | "inputTokens": bedrock_inference.get_input_tokens(), 311 | "outputTokens": bedrock_inference.get_output_tokens(), 312 | "height": None, 313 | "width": None, 314 | "steps": None 315 | } 316 | 317 | _store_logs(logs) 318 | return results 319 | 320 | else: 321 | logger.info("Check streaming request") 322 | connections = dynamodb.Table(streaming_table_name) 323 | response = connections.get_item(Key={"composite_pk": f"{custom_request_id}_{api_key}"}) 324 | 325 | if "Item" in response: 326 | response = response["Item"] 327 | results = { 328 | "statusCode": response["status"], 329 | "body": json.dumps([{"generated_text": response["generated_text"]}]) 330 | } 331 | connections.delete_item(Key={"composite_pk": f"{custom_request_id}_{api_key}"}) 332 | logs = { 333 | "team_id": team_id, 334 | "api_key": api_key, 335 | "requestId": custom_request_id, 336 | "region": bedrock_region, 337 | "model_id": response.get("model_id"), 338 | "inputTokens": int(response.get("inputTokens", 0)), 339 | "outputTokens": int(response.get("outputTokens", 0)), 340 | "height": None, 341 | "width": None, 342 | "steps": None 343 | } 344 | 345 | _store_logs(logs) 346 | else: 347 | results = {"statusCode": 200, "body": json.dumps([{"request_id": custom_request_id}])} 348 | 349 | return results 350 | 351 | """ 352 | This function handles inference requests for SageMaker models. 353 | 354 | Args: 355 | event (dict): The event object containing the request details. 356 | 357 | Returns: 358 | dict: A dictionary containing the response status code and body. 
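    Example event (abridged; values are placeholders):
        {
            "queryStringParameters": {"model_id": "my-llm-endpoint-model"},
            "headers": {"team_id": "team-a", "x-api-key": "...", "streaming": "false"},
            "requestContext": {"requestId": "..."},
            "body": "{\"inputs\": \"...\", \"parameters\": {}}"
        }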
359 | """ 360 | def sagemaker_handler(event): 361 | logger.info("SageMaker Endpoint") 362 | 363 | model_id = event["queryStringParameters"]["model_id"] 364 | team_id = event["headers"]["team_id"] 365 | api_key = event["headers"]["x-api-key"] 366 | 367 | sagemaker_client = _get_sagemaker_client() 368 | 369 | messages_api = event["headers"].get("messages_api", "false") 370 | custom_request_id = event["queryStringParameters"].get("requestId") 371 | 372 | endpoints = _read_sagemaker_endpoints() 373 | endpoint_name = endpoints[model_id] 374 | sagemaker_inference = SageMakerInference(sagemaker_client, endpoint_name, messages_api) 375 | 376 | if custom_request_id is None: 377 | request_id = event["requestContext"]["requestId"] 378 | streaming = event["headers"].get("streaming", "false") 379 | embeddings = event["headers"].get("type", "").lower() == "embeddings" 380 | 381 | logger.info(f"Model ID: {model_id}") 382 | logger.info(f"Request ID: {request_id}") 383 | 384 | body = json.loads(event["body"]) 385 | model_kwargs = body.get("parameters", {}) 386 | 387 | if embeddings: 388 | response = sagemaker_inference.invoke_embeddings(body, model_kwargs) 389 | results = {"statusCode": 200, "body": json.dumps([{"embedding": response}])} 390 | 391 | logs = { 392 | "team_id": team_id, 393 | "api_key": api_key, 394 | "requestId": request_id, 395 | "region": sagemaker_region, 396 | "model_id": model_id, 397 | "inputTokens": sagemaker_inference.get_input_tokens(), 398 | "outputTokens": 0, 399 | "height": None, 400 | "width": None, 401 | "steps": None 402 | } 403 | 404 | _store_logs(logs) 405 | else: 406 | logger.info("Request type: text") 407 | 408 | if streaming.lower() in ["true"] and custom_request_id is None: 409 | logger.info("Send streaming request") 410 | event["queryStringParameters"]["request_id"] = request_id 411 | s3_client.put_object( 412 | Bucket=s3_bucket, 413 | Key=f"{request_id}.json", 414 | Body=json.dumps(event).encode("utf-8") 415 | ) 416 | lambda_client.invoke( 417 | FunctionName=lambda_streaming, 418 | InvocationType="Event", 419 | Payload=json.dumps({"request_json": f"{request_id}.json"}) 420 | ) 421 | results = {"statusCode": 200, "body": json.dumps([{"request_id": request_id}])} 422 | else: 423 | response = sagemaker_inference.invoke_text(body, model_kwargs) 424 | results = {"statusCode": 200, "body": json.dumps([{"generated_text": response}])} 425 | logs = { 426 | "team_id": team_id, 427 | "api_key": api_key, 428 | "requestId": request_id, 429 | "region": sagemaker_region, 430 | "model_id": model_id, 431 | "inputTokens": sagemaker_inference.get_input_tokens(), 432 | "outputTokens": sagemaker_inference.get_output_tokens(), 433 | "height": None, 434 | "width": None, 435 | "steps": None 436 | } 437 | 438 | _store_logs(logs) 439 | 440 | return results 441 | 442 | else: 443 | logger.info("Check streaming request") 444 | connections = dynamodb.Table(streaming_table_name) 445 | response = connections.get_item(Key={"composite_pk": f"{custom_request_id}_{api_key}"}) 446 | 447 | if "Item" in response: 448 | response = response["Item"] 449 | results = { 450 | "statusCode": response["status"], 451 | "body": json.dumps([{"generated_text": response["generated_text"]}]) 452 | } 453 | connections.delete_item(Key={"composite_pk": f"{custom_request_id}_{api_key}"}) 454 | logs = { 455 | "team_id": team_id, 456 | "api_key": api_key, 457 | "requestId": custom_request_id, 458 | "region": sagemaker_region, 459 | "model_id": response.get("model_id"), 460 | "inputTokens": int(response.get("inputTokens", 0)), 
461 | "outputTokens": int(response.get("outputTokens", 0)), 462 | "height": None, 463 | "width": None, 464 | "steps": None 465 | } 466 | _store_logs(logs) 467 | else: 468 | results = {"statusCode": 200, "body": json.dumps([{"request_id": custom_request_id}])} 469 | 470 | return results 471 | 472 | def lambda_handler(event, context): 473 | try: 474 | logger.info(event) 475 | 476 | team_id = event["headers"].get("team_id") 477 | if not team_id: 478 | logger.error("Bad Request: Header 'team_id' is missing") 479 | return {"statusCode": 400, "body": "Bad Request"} 480 | 481 | model_id = event["queryStringParameters"]["model_id"] 482 | endpoints = _read_sagemaker_endpoints() 483 | 484 | if model_id in endpoints: 485 | return sagemaker_handler(event) 486 | else: 487 | return bedrock_handler(event) 488 | 489 | except Exception as e: 490 | stacktrace = traceback.format_exc() 491 | logger.error(stacktrace) 492 | return {"statusCode": 500, "body": json.dumps([{"generated_text": stacktrace}])} 493 | -------------------------------------------------------------------------------- /lambdas/invoke_model_streaming/BedrockInference.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | from langchain_community.llms.bedrock import LLMInputOutputAdapter 4 | from langchain_core.outputs import GenerationChunk 5 | import logging 6 | import math 7 | import traceback 8 | 9 | logger = logging.getLogger(__name__) 10 | if len(logging.getLogger().handlers) > 0: 11 | logging.getLogger().setLevel(logging.INFO) 12 | else: 13 | logging.basicConfig(level=logging.INFO) 14 | 15 | GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment" 16 | 17 | def get_tokens(string): 18 | logger.info("Counting approximation tokens") 19 | 20 | return math.floor(len(string) / 4) 21 | 22 | class BedrockInferenceStream: 23 | def __init__(self, bedrock_client, model_id, model_arn=None, messages_api="false"): 24 | self.bedrock_client = bedrock_client 25 | self.model_id = model_id 26 | self.model_arn = model_arn 27 | self.messages_api = messages_api 28 | self.input_tokens = 0 29 | self.output_tokens = 0 30 | 31 | """ 32 | Decode base64-encoded documents in the input messages. 33 | 34 | Args: 35 | messages (list): A list of message dictionaries. 36 | 37 | Returns: 38 | list: The updated list of message dictionaries with decoded images. 39 | """ 40 | def _decode_documents(self, messages): 41 | for item in messages: 42 | if 'content' in item: 43 | for content_item in item['content']: 44 | if 'document' in content_item and 'bytes' in content_item['document']['source']: 45 | encoded_document = content_item['document']['source']['bytes'] 46 | base64_bytes = encoded_document.encode('utf-8') 47 | document_bytes = base64.b64decode(base64_bytes) 48 | content_item['document']['source']['bytes'] = document_bytes 49 | return messages 50 | 51 | """ 52 | Decode base64-encoded images in the input messages. 53 | 54 | Args: 55 | messages (list): A list of message dictionaries. 56 | 57 | Returns: 58 | list: The updated list of message dictionaries with decoded images. 
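    Example of the expected input shape (illustrative; it follows the Converse API
    message format):
        [{"role": "user",
          "content": [{"image": {"format": "png",
                                 "source": {"bytes": "<base64-encoded string>"}}}]}]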
59 | """ 60 | def _decode_images(self, messages): 61 | for item in messages: 62 | if 'content' in item: 63 | for content_item in item['content']: 64 | if 'image' in content_item and 'bytes' in content_item['image']['source']: 65 | encoded_image = content_item['image']['source']['bytes'] 66 | base64_bytes = encoded_image.encode('utf-8') 67 | image_bytes = base64.b64decode(base64_bytes) 68 | content_item['image']['source']['bytes'] = image_bytes 69 | return messages 70 | 71 | def get_input_tokens(self): 72 | return self.input_tokens 73 | 74 | def get_output_tokens(self): 75 | return self.output_tokens 76 | 77 | def invoke_text_streaming(self, body, model_kwargs:dict = dict(), additional_model_fields:dict = dict()): 78 | try: 79 | provider = self.model_id.split(".")[0] 80 | 81 | if self.messages_api.lower() in ["true"]: 82 | system = [{"text": model_kwargs["system"]}] if "system" in model_kwargs else list() 83 | 84 | if "system" in model_kwargs: 85 | del model_kwargs["system"] 86 | 87 | messages = self._decode_documents(body["inputs"]) 88 | messages = self._decode_images(messages) 89 | 90 | modelId = self.model_arn if self.model_arn is not None else self.model_id 91 | 92 | response = self.bedrock_client.converse_stream( 93 | modelId=modelId, 94 | messages=messages, 95 | system=system, 96 | inferenceConfig=model_kwargs, 97 | additionalModelRequestFields=additional_model_fields 98 | ) 99 | 100 | return self.prepare_output_stream(provider, response, messages_api=True) 101 | else: 102 | request_body = LLMInputOutputAdapter.prepare_input( 103 | provider=provider, 104 | prompt=body["inputs"], 105 | model_kwargs=model_kwargs 106 | ) 107 | 108 | request_body = json.dumps(request_body) 109 | 110 | return self.stream(request_body) 111 | 112 | except Exception as e: 113 | stacktrace = traceback.format_exc() 114 | 115 | logger.error(stacktrace) 116 | 117 | raise e 118 | 119 | def prepare_output_stream(self, provider, response, stop=None, messages_api=False): 120 | if messages_api: 121 | stream = response.get("stream") 122 | else: 123 | stream = response.get("body") 124 | 125 | if not stream: 126 | return 127 | 128 | if messages_api: 129 | output_key = "message" 130 | else: 131 | output_key = LLMInputOutputAdapter.provider_to_output_key_map.get(provider, "") 132 | 133 | if not output_key: 134 | raise ValueError( 135 | f"Unknown streaming response output key for provider: {provider}" 136 | ) 137 | 138 | for event in stream: 139 | if messages_api: 140 | if 'contentBlockDelta' in event: 141 | chunk_obj = event['contentBlockDelta'] 142 | if "delta" in chunk_obj and "text" in chunk_obj["delta"]: 143 | chk = GenerationChunk( 144 | text=chunk_obj["delta"]["text"], 145 | generation_info=dict( 146 | finish_reason=chunk_obj.get("stop_reason", None), 147 | ), 148 | ) 149 | yield chk 150 | 151 | if "metadata" in event and "usage" in event["metadata"]: 152 | usage = event["metadata"]["usage"] 153 | if "inputTokens" in usage: 154 | self.input_tokens += usage["inputTokens"] 155 | if "outputTokens" in usage: 156 | self.output_tokens += usage["outputTokens"] 157 | 158 | else: 159 | chunk = event.get("chunk") 160 | if not chunk: 161 | continue 162 | 163 | chunk_obj = json.loads(chunk.get("bytes").decode()) 164 | 165 | if provider == "cohere" and ( 166 | chunk_obj["is_finished"] or chunk_obj[output_key] == "" 167 | ): 168 | return 169 | 170 | elif ( 171 | provider == "mistral" 172 | and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop" 173 | ): 174 | return 175 | 176 | elif messages_api and 
(chunk_obj.get("type") == "content_block_stop"): 177 | return 178 | 179 | if messages_api and chunk_obj.get("type") in ( 180 | "message_start", 181 | "content_block_start", 182 | "content_block_delta", 183 | ): 184 | if chunk_obj.get("type") == "content_block_delta": 185 | if not chunk_obj["delta"]: 186 | chk = GenerationChunk(text="") 187 | else: 188 | chk = GenerationChunk( 189 | text=chunk_obj["delta"]["text"], 190 | generation_info=dict( 191 | finish_reason=chunk_obj.get("stop_reason", None), 192 | ), 193 | ) 194 | yield chk 195 | else: 196 | continue 197 | else: 198 | if messages_api: 199 | if chunk_obj["type"] == "message_start" and "message" in chunk_obj and "usage" in chunk_obj["message"]: 200 | if "input_tokens" in chunk_obj["message"]["usage"]: 201 | self.input_tokens += int(chunk_obj["message"]["usage"]["input_tokens"]) 202 | if "output_tokens" in chunk_obj["message"]["usage"]: 203 | self.output_tokens += int(chunk_obj["message"]["usage"]["output_tokens"]) 204 | if chunk_obj["type"] == "message_delta" and "usage" in chunk_obj: 205 | if "input_tokens" in chunk_obj["usage"]: 206 | self.input_tokens += int(chunk_obj["usage"]["input_tokens"]) 207 | if "output_tokens" in chunk_obj["usage"]: 208 | self.output_tokens += int(chunk_obj["usage"]["output_tokens"]) 209 | 210 | # chunk obj format varies with provider 211 | yield GenerationChunk( 212 | text=( 213 | chunk_obj[output_key] 214 | if provider != "mistral" 215 | else chunk_obj[output_key][0]["text"] 216 | ), 217 | generation_info={ 218 | GUARDRAILS_BODY_KEY: ( 219 | chunk_obj.get(GUARDRAILS_BODY_KEY) 220 | if GUARDRAILS_BODY_KEY in chunk_obj 221 | else None 222 | ), 223 | }, 224 | ) 225 | 226 | def stream(self, request_body): 227 | try: 228 | provider = self.model_id.split(".")[0] 229 | 230 | modelId = self.model_arn if self.model_arn is not None else self.model_id 231 | 232 | response = self.bedrock_client.invoke_model_with_response_stream( 233 | body=request_body, 234 | modelId=modelId, 235 | accept="application/json", 236 | contentType="application/json", 237 | ) 238 | except Exception as e: 239 | stacktrace = traceback.format_exc() 240 | 241 | logger.error(stacktrace) 242 | 243 | raise e 244 | 245 | if self.messages_api.lower() in ["true"]: 246 | for chunk in self.prepare_output_stream(provider, response, messages_api=True): 247 | yield chunk 248 | else: 249 | for chunk in self.prepare_output_stream(provider, response, messages_api=False): 250 | yield chunk -------------------------------------------------------------------------------- /lambdas/invoke_model_streaming/SageMakerInference.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import logging 4 | import math 5 | import traceback 6 | 7 | logger = logging.getLogger(__name__) 8 | if len(logging.getLogger().handlers) > 0: 9 | logging.getLogger().setLevel(logging.INFO) 10 | else: 11 | logging.basicConfig(level=logging.INFO) 12 | def get_tokens(string): 13 | logger.info("Counting approximation tokens") 14 | 15 | return math.floor(len(string) / 4) 16 | 17 | class SageMakerInferenceStream: 18 | def __init__(self, sagemaker_runtime, endpoint_name): 19 | self.sagemaker_runtime = sagemaker_runtime 20 | self.endpoint_name = endpoint_name 21 | # A buffered I/O stream to combine the payload parts: 22 | self.buff = io.BytesIO() 23 | self.read_pos = 0 24 | self.input_tokens = 0 25 | self.output_tokens = 0 26 | 27 | def get_input_tokens(self): 28 | return self.input_tokens 29 | 30 | def get_output_tokens(self): 31 | 
return self.output_tokens 32 | 33 | def invoke_text_streaming(self, body, model_kwargs): 34 | try: 35 | request_body = { 36 | "inputs": body["inputs"], 37 | "parameters": model_kwargs 38 | } 39 | 40 | stream = self.stream(request_body) 41 | 42 | response = self.prepare_output_stream_messages_api(stream) 43 | 44 | self.input_tokens = get_tokens(body["inputs"]) 45 | 46 | return response 47 | 48 | except Exception as e: 49 | stacktrace = traceback.format_exc() 50 | 51 | logger.error(stacktrace) 52 | 53 | raise e 54 | 55 | def prepare_output_stream_messages_api(self, stream): 56 | tmp_response = "" 57 | for part in stream: 58 | tmp_response += part 59 | 60 | try: 61 | response = json.loads(tmp_response) 62 | except json.JSONDecodeError: 63 | # Invalid JSON, try to fix it 64 | if not tmp_response.endswith("}"): 65 | # Missing closing bracket 66 | tmp_response = tmp_response + "}" 67 | if not tmp_response.endswith("]"): 68 | # Uneven brackets 69 | tmp_response = tmp_response + "]" 70 | 71 | # Try again 72 | response = json.loads(tmp_response) 73 | 74 | response = response[0]["generated_text"] 75 | 76 | self.output_tokens = get_tokens(response) 77 | 78 | return response 79 | 80 | def stream(self, request_body): 81 | # Gets a streaming inference response 82 | # from the specified model endpoint: 83 | response = self.sagemaker_runtime \ 84 | .invoke_endpoint_with_response_stream( 85 | EndpointName=self.endpoint_name, 86 | Body=json.dumps(request_body), 87 | ContentType="application/json" 88 | ) 89 | # Gets the EventStream object returned by the SDK: 90 | event_stream = response['Body'] 91 | for event in event_stream: 92 | # Passes the contents of each payload part 93 | # to be concatenated: 94 | self._write(event['PayloadPart']['Bytes']) 95 | # Iterates over lines to parse whole JSON objects: 96 | for line in self._readlines(): 97 | # Returns parts incrementally: 98 | yield line.decode("utf-8") 99 | 100 | # Writes to the buffer to concatenate the contents of the parts: 101 | def _write(self, content): 102 | self.buff.seek(0, io.SEEK_END) 103 | self.buff.write(content) 104 | 105 | # The JSON objects in buffer end with '\n'. 
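    # (For example, TGI-style streaming endpoints emit newline-delimited JSON
    # payload parts; that framing is an assumption this buffering relies on.)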
106 | # This method reads lines to yield a series of JSON objects: 107 | def _readlines(self): 108 | self.buff.seek(self.read_pos) 109 | for line in self.buff.readlines(): 110 | self.read_pos += len(line) 111 | yield line[:-1] 112 | -------------------------------------------------------------------------------- /lambdas/invoke_model_streaming/index.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from BedrockInference import BedrockInferenceStream, get_tokens 3 | import boto3 4 | from botocore.config import Config 5 | import json 6 | import logging 7 | import os 8 | from SageMakerInference import SageMakerInferenceStream 9 | import time 10 | import traceback 11 | from typing import Dict 12 | 13 | logger = logging.getLogger(__name__) 14 | if len(logging.getLogger().handlers) > 0: 15 | logging.getLogger().setLevel(logging.INFO) 16 | else: 17 | logging.basicConfig(level=logging.INFO) 18 | 19 | dynamodb = boto3.resource('dynamodb') 20 | s3_client = boto3.client('s3') 21 | 22 | bedrock_region = os.environ.get("BEDROCK_REGION", "us-east-1") 23 | bedrock_url = os.environ.get("BEDROCK_URL", None) 24 | iam_role = os.environ.get("IAM_ROLE", None) 25 | streaming_table_name = os.environ.get("STREAMING_TABLE_NAME", None) 26 | s3_bucket = os.environ.get("S3_BUCKET", None) 27 | sagemaker_endpoints = os.environ.get("SAGEMAKER_ENDPOINTS", "") # If FMs are exposed through SageMaker 28 | sagemaker_region = os.environ.get("SAGEMAKER_REGION", "us-east-1") # If FMs are exposed through SageMaker 29 | sagemaker_url = os.environ.get("SAGEMAKER_URL", None) # If FMs are exposed through SageMaker 30 | 31 | def _get_bedrock_client(): 32 | try: 33 | logger.info(f"Create new client\n Using region: {bedrock_region}") 34 | session_kwargs = {"region_name": bedrock_region} 35 | client_kwargs = {**session_kwargs} 36 | 37 | retry_config = Config( 38 | region_name=bedrock_region, 39 | retries={ 40 | "max_attempts": 10, 41 | "mode": "standard", 42 | }, 43 | ) 44 | session = boto3.Session(**session_kwargs) 45 | 46 | if iam_role is not None: 47 | logger.info(f"Using role: {iam_role}") 48 | sts = session.client("sts") 49 | 50 | response = sts.assume_role( 51 | RoleArn=str(iam_role), # 52 | RoleSessionName="amazon-bedrock-assume-role" 53 | ) 54 | 55 | client_kwargs = dict( 56 | aws_access_key_id=response['Credentials']['AccessKeyId'], 57 | aws_secret_access_key=response['Credentials']['SecretAccessKey'], 58 | aws_session_token=response['Credentials']['SessionToken'] 59 | ) 60 | 61 | if bedrock_url: 62 | client_kwargs["endpoint_url"] = bedrock_url 63 | 64 | bedrock_client = session.client( 65 | service_name="bedrock-runtime", 66 | config=retry_config, 67 | **client_kwargs 68 | ) 69 | 70 | logger.info("boto3 Bedrock client successfully created!") 71 | logger.info(bedrock_client._endpoint) 72 | return bedrock_client 73 | 74 | except Exception as e: 75 | stacktrace = traceback.format_exc() 76 | logger.error(stacktrace) 77 | 78 | raise e 79 | 80 | def _get_sagemaker_client(): 81 | try: 82 | logger.info(f"Create new client\n Using region: {sagemaker_region}") 83 | session_kwargs = {"region_name": sagemaker_region} 84 | client_kwargs = {**session_kwargs} 85 | 86 | retry_config = Config( 87 | region_name=sagemaker_region, 88 | retries={ 89 | "max_attempts": 10, 90 | "mode": "standard", 91 | }, 92 | ) 93 | session = boto3.Session(**session_kwargs) 94 | 95 | if iam_role is not None: 96 | logger.info(f"Using role: {iam_role}") 97 | sts = session.client("sts") 98 | 99 | response = 
sts.assume_role( 100 | RoleArn=str(iam_role), 101 | RoleSessionName="amazon-sagemaker-assume-role" 102 | ) 103 | 104 | client_kwargs = dict( 105 | aws_access_key_id=response['Credentials']['AccessKeyId'], 106 | aws_secret_access_key=response['Credentials']['SecretAccessKey'], 107 | aws_session_token=response['Credentials']['SessionToken'] 108 | ) 109 | 110 | if sagemaker_url: 111 | client_kwargs["endpoint_url"] = sagemaker_url 112 | 113 | sagemaker_client = session.client( 114 | service_name="sagemaker-runtime", 115 | config=retry_config, 116 | **client_kwargs 117 | ) 118 | 119 | logger.info("boto3 SageMaker client successfully created!") 120 | logger.info(sagemaker_client._endpoint) 121 | return sagemaker_client 122 | 123 | except Exception as e: 124 | stacktrace = traceback.format_exc() 125 | logger.error(stacktrace) 126 | 127 | raise e 128 | 129 | def _read_json_event(event): 130 | try: 131 | request_json = event["request_json"] 132 | 133 | response = s3_client.get_object(Bucket=s3_bucket, Key=request_json) 134 | content = response['Body'].read() 135 | 136 | json_data = content.decode('utf-8') 137 | 138 | event = json.loads(json_data) 139 | 140 | s3_client.delete_object(Bucket=s3_bucket, Key=request_json) 141 | 142 | return event 143 | except Exception as e: 144 | stacktrace = traceback.format_exc() 145 | 146 | logger.error(stacktrace) 147 | 148 | raise e 149 | 150 | def _read_sagemaker_endpoints(): 151 | if not sagemaker_endpoints: 152 | return {} 153 | 154 | try: 155 | endpoints = json.loads(sagemaker_endpoints) 156 | except json.JSONDecodeError: 157 | try: 158 | endpoints = ast.literal_eval(sagemaker_endpoints) 159 | except (ValueError, SyntaxError) as e: 160 | raise ValueError(f"Error: Invalid format for SAGEMAKER_ENDPOINTS: {e}") 161 | else: 162 | if not isinstance(endpoints, dict): 163 | raise ValueError("Error: SAGEMAKER_ENDPOINTS is not a dictionary") 164 | 165 | return endpoints 166 | 167 | def bedrock_handler(event: Dict) -> Dict: 168 | try: 169 | bedrock_client = _get_bedrock_client() 170 | 171 | model_id = event["queryStringParameters"]['model_id'] 172 | model_arn = event["queryStringParameters"].get('model_arn', None) 173 | request_id = event['queryStringParameters']['request_id'] 174 | messages_api = event["headers"].get("messages_api", "false") 175 | api_key = event["headers"]["x-api-key"] 176 | 177 | logger.info(f"Model ID: {model_id}") 178 | logger.info(f"Request ID: {request_id}") 179 | 180 | body = json.loads(event["body"]) 181 | model_kwargs = body.get("parameters", {}) 182 | additional_model_fields = body.get("additional_model_fields", {}) 183 | logger.info(f"Input body: {body}") 184 | 185 | bedrock_streaming = BedrockInferenceStream( 186 | bedrock_client=bedrock_client, 187 | model_id=model_id, 188 | model_arn=model_arn, 189 | messages_api=messages_api 190 | ) 191 | 192 | response = "".join(chunk.text for chunk in bedrock_streaming.invoke_text_streaming(body, model_kwargs, additional_model_fields)) 193 | logger.info(f"Answer: {response}") 194 | 195 | if messages_api.lower() in ["true"]: 196 | if bedrock_streaming.get_input_tokens() != 0: 197 | inputTokens = bedrock_streaming.get_input_tokens() 198 | else: 199 | messages_text = "" 200 | 201 | if "system" in model_kwargs: 202 | messages_text += f"{model_kwargs['system']}\n" 203 | 204 | for message in body["inputs"]: 205 | messages_text += f"{message['content']}\n" 206 | 207 | inputTokens = get_tokens(messages_text) 208 | else: 209 | inputTokens = get_tokens(body["inputs"]) 210 | 211 | if
bedrock_streaming.get_output_tokens() != 0: 212 | outputTokens = bedrock_streaming.get_output_tokens() 213 | else: 214 | outputTokens = get_tokens(response) 215 | 216 | item = { 217 | "composite_pk": f"{request_id}_{api_key}", 218 | "request_id": request_id, 219 | "api_key": api_key, 220 | "status": 200, 221 | "generated_text": response, 222 | "inputTokens": inputTokens, 223 | "outputTokens": outputTokens, 224 | "model_id": model_id, 225 | "ttl": int(time.time()) + 2 * 60 226 | } 227 | 228 | logger.info(f"Streaming answer: {item}") 229 | 230 | connections = dynamodb.Table(streaming_table_name) 231 | connections.put_item(Item=item) 232 | 233 | logger.info(f"Put item: {response}") 234 | 235 | return {"statusCode": 200, "body": response} 236 | 237 | except Exception as e: 238 | stacktrace = traceback.format_exc() 239 | logger.error(stacktrace) 240 | 241 | model_id = event.get("queryStringParameters", {}).get('model_id', None) 242 | request_id = event.get("queryStringParameters", {}).get('request_id', None) 243 | 244 | api_key = event["headers"]["x-api-key"] 245 | 246 | if request_id is not None: 247 | item = { 248 | "composite_pk": f"{request_id}_{api_key}", 249 | "request_id": request_id, 250 | "api_key": api_key, 251 | "status": 500, 252 | "generated_text": stacktrace, 253 | "model_id": model_id, 254 | "ttl": int(time.time()) + 2 * 60 255 | } 256 | 257 | connections = dynamodb.Table(streaming_table_name) 258 | connections.put_item(Item=item) 259 | 260 | logger.info(f"Put exception item: {stacktrace}") 261 | 262 | return {"statusCode": 500, "body": json.dumps([{"generated_text": stacktrace}])} 263 | 264 | def sagemaker_handler(event: Dict) -> Dict: 265 | try: 266 | sagemaker_client = _get_sagemaker_client() 267 | 268 | model_id = event["queryStringParameters"]['model_id'] 269 | request_id = event['queryStringParameters']['request_id'] 270 | 271 | api_key = event["headers"]["x-api-key"] 272 | 273 | logger.info(f"Model ID: {model_id}") 274 | logger.info(f"Request ID: {request_id}") 275 | 276 | body = json.loads(event["body"]) 277 | model_kwargs = body.get("parameters", {}) 278 | 279 | logger.info(f"Input body: {body}") 280 | 281 | endpoints = _read_sagemaker_endpoints() 282 | endpoint_name = endpoints[model_id] 283 | 284 | sagemaker_streaming = SageMakerInferenceStream(sagemaker_client, endpoint_name) 285 | 286 | response = sagemaker_streaming.invoke_text_streaming(body, model_kwargs) 287 | logger.info(f"Answer: {response}") 288 | 289 | item = { 290 | "composite_pk": f"{request_id}_{api_key}", 291 | "request_id": request_id, 292 | "api_key": api_key, 293 | "status": 200, 294 | "generated_text": response, 295 | "inputs": body["inputs"], 296 | "inputTokens": sagemaker_streaming.get_input_tokens(), 297 | "outputTokens": sagemaker_streaming.get_output_tokens(), 298 | "model_id": model_id, 299 | "ttl": int(time.time()) + 2 * 60 300 | } 301 | 302 | connections = dynamodb.Table(streaming_table_name) 303 | connections.put_item(Item=item) 304 | 305 | logger.info(f"Put item: {response}") 306 | 307 | return {"statusCode": 200, "body": response} 308 | 309 | except Exception as e: 310 | stacktrace = traceback.format_exc() 311 | logger.error(stacktrace) 312 | 313 | model_id = event.get("queryStringParameters", {}).get('model_id', None) 314 | request_id = event.get("queryStringParameters", {}).get('request_id', None) 315 | 316 | api_key = event["headers"]["x-api-key"] 317 | 318 | if request_id is not None: 319 | item = { 320 | "composite_pk": f"{request_id}_{api_key}", 321 | "request_id": request_id, 322 | 
"api_key": api_key, 323 | "status": 500, 324 | "generated_text": stacktrace, 325 | "model_id": model_id, 326 | "ttl": int(time.time()) + 2 * 60 327 | } 328 | 329 | connections = dynamodb.Table(streaming_table_name) 330 | connections.put_item(Item=item) 331 | 332 | logger.info(f"Put exception item: {stacktrace}") 333 | 334 | return {"statusCode": 500, "body": json.dumps([{"generated_text": stacktrace}])} 335 | 336 | def lambda_handler(event: Dict, context) -> Dict: 337 | event = _read_json_event(event) 338 | 339 | logger.info(event) 340 | 341 | model_id = event["queryStringParameters"]['model_id'] 342 | 343 | endpoints = _read_sagemaker_endpoints() 344 | 345 | if model_id in endpoints: 346 | return sagemaker_handler(event) 347 | else: 348 | return bedrock_handler(event) 349 | -------------------------------------------------------------------------------- /lambdas/lambda_layer_requirements/cfnresponse.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | from __future__ import print_function 5 | import urllib3 6 | import json 7 | 8 | SUCCESS = "SUCCESS" 9 | FAILED = "FAILED" 10 | 11 | http = urllib3.PoolManager() 12 | 13 | 14 | def send(event, context, responseStatus, responseData, physicalResourceId=None, noEcho=False, reason=None): 15 | responseUrl = event['ResponseURL'] 16 | 17 | print(responseUrl) 18 | 19 | responseBody = { 20 | 'Status': responseStatus, 21 | 'Reason': reason or "See the details in CloudWatch Log Stream: {}".format(context.log_stream_name), 22 | 'PhysicalResourceId': physicalResourceId or context.log_stream_name, 23 | 'StackId': event['StackId'], 24 | 'RequestId': event['RequestId'], 25 | 'LogicalResourceId': event['LogicalResourceId'], 26 | 'NoEcho': noEcho, 27 | 'Data': responseData 28 | } 29 | 30 | json_responseBody = json.dumps(responseBody) 31 | 32 | print("Response body:") 33 | print(json_responseBody) 34 | 35 | headers = { 36 | 'content-type': '', 37 | 'content-length': str(len(json_responseBody)) 38 | } 39 | 40 | try: 41 | response = http.request('PUT', responseUrl, headers=headers, body=json_responseBody) 42 | print("Status code:", response.status) 43 | 44 | 45 | except Exception as e: 46 | 47 | print("send(..) failed executing http.request(..):", e) -------------------------------------------------------------------------------- /lambdas/lambda_layer_requirements/index.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from datetime import datetime 3 | import cfnresponse 4 | import os 5 | import shutil 6 | import subprocess 7 | import sys 8 | import zipfile 9 | 10 | requirements = os.environ['REQUIREMENTS'] 11 | s3_bucket = os.environ['S3_BUCKET'] 12 | 13 | 14 | def upload_file_to_s3(file_path, bucket, key): 15 | s3 = boto3.client('s3') 16 | s3.upload_file(file_path, bucket, key) 17 | print(f"Upload successful. 
{file_path} uploaded to {bucket}/{key}") 18 | 19 | 20 | def make_zip_filename(): 21 | now = datetime.now() 22 | timestamp = now.strftime('%Y%m%d_%H%M%S') 23 | filename = f'LambdaLayer_{timestamp}.zip' 24 | return filename 25 | 26 | 27 | def zipdir(path, zipname): 28 | zipf = zipfile.ZipFile(zipname, 'w', zipfile.ZIP_DEFLATED) 29 | for root, dirs, files in os.walk(path): 30 | for file in files: 31 | zipf.write(os.path.join(root, file), 32 | os.path.relpath(os.path.join(root, file), 33 | os.path.join(path, '..'))) 34 | zipf.close() 35 | 36 | 37 | def empty_bucket(bucket_name): 38 | s3_client = boto3.client('s3') 39 | response = s3_client.list_objects_v2(Bucket=bucket_name) 40 | if 'Contents' in response: 41 | keys = [{'Key': obj['Key']} for obj in response['Contents']] 42 | s3_client.delete_objects(Bucket=bucket_name, Delete={'Objects': keys}) 43 | return 44 | 45 | 46 | def lambda_handler(event, context): 47 | print("Event: ", event) 48 | responseData = {} 49 | reason = "" 50 | status = cfnresponse.SUCCESS 51 | try: 52 | if event['RequestType'] != 'Delete': 53 | os.chdir('/tmp') 54 | # install each requirement into a local "python" folder, the layout expected by a Lambda layer 55 | requirements_list = requirements.split(" ") 56 | 57 | if os.path.exists("python"): 58 | shutil.rmtree("python") 59 | 60 | for requirement in requirements_list: 61 | subprocess.check_call([sys.executable, "-m", "pip", "install", requirement, "-t", "python"]) 62 | 63 | boto3_zip_name = make_zip_filename() 64 | zipdir("python", boto3_zip_name) 65 | 66 | print(f"uploading {boto3_zip_name} to s3 bucket {s3_bucket}") 67 | upload_file_to_s3(boto3_zip_name, s3_bucket, boto3_zip_name) 68 | responseData = {"Bucket": s3_bucket, "Key": boto3_zip_name} 69 | else: 70 | # delete - empty the bucket so it can be deleted by the stack. 71 | empty_bucket(s3_bucket) 72 | except Exception as e: 73 | print(e) 74 | status = cfnresponse.FAILED 75 | reason = f"Exception thrown: {e}" 76 | cfnresponse.send(event, context, status, responseData, reason=reason) -------------------------------------------------------------------------------- /lambdas/list_foundation_models/index.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore.config import Config 3 | import json 4 | import logging 5 | import os 6 | import traceback 7 | 8 | logger = logging.getLogger(__name__) 9 | if len(logging.getLogger().handlers) > 0: 10 | logging.getLogger().setLevel(logging.INFO) 11 | else: 12 | logging.basicConfig(level=logging.INFO) 13 | 14 | bedrock_region = os.environ.get("BEDROCK_REGION", "us-east-1") 15 | bedrock_role = os.environ.get("BEDROCK_ROLE", None) 16 | bedrock_url = os.environ.get("BEDROCK_URL", None) 17 | 18 | def _get_bedrock_client(): 19 | try: 20 | logger.info(f"Create new client\n Using region: {bedrock_region}") 21 | session_kwargs = {"region_name": bedrock_region} 22 | client_kwargs = {**session_kwargs} 23 | 24 | retry_config = Config( 25 | region_name=bedrock_region, 26 | retries={ 27 | "max_attempts": 10, 28 | "mode": "standard", 29 | }, 30 | ) 31 | session = boto3.Session(**session_kwargs) 32 | 33 | if bedrock_role is not None: 34 | logger.info(f"Using role: {bedrock_role}") 35 | sts = session.client("sts") 36 | 37 | response = sts.assume_role( 38 | RoleArn=str(bedrock_role), 39 | RoleSessionName="amazon-bedrock-assume-role" 40 | ) 41 | 42 | client_kwargs = dict( 43 | aws_access_key_id=response['Credentials']['AccessKeyId'], 44 | aws_secret_access_key=response['Credentials']['SecretAccessKey'], 45 |
aws_session_token=response['Credentials']['SessionToken'] 46 | ) 47 | 48 | if bedrock_url: 49 | client_kwargs["endpoint_url"] = bedrock_url 50 | 51 | bedrock_client = session.client( 52 | service_name="bedrock", 53 | config=retry_config, 54 | **client_kwargs 55 | ) 56 | 57 | logger.info("boto3 Bedrock client successfully created!") 58 | logger.info(bedrock_client._endpoint) 59 | return bedrock_client 60 | 61 | except Exception as e: 62 | stacktrace = traceback.format_exc() 63 | logger.error(stacktrace) 64 | 65 | raise e 66 | 67 | def lambda_handler(event, context): 68 | try: 69 | bedrock_client = _get_bedrock_client() 70 | 71 | logger.info(event) 72 | 73 | response = bedrock_client.list_foundation_models() 74 | 75 | return {"statusCode": 200, "body": json.dumps([response])} 76 | 77 | except Exception as e: 78 | stacktrace = traceback.format_exc() 79 | 80 | logger.error(stacktrace) 81 | return {"statusCode": 500, "body": json.dumps([{"generated_text": stacktrace}])} -------------------------------------------------------------------------------- /notebooks/images/backpack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-a-multi-tenant-generative-ai-gateway-with-cost-and-usage-tracking-on-aws/f750573133a9d6189d56f060007a5553c6099de3/notebooks/images/backpack.png -------------------------------------------------------------------------------- /notebooks/images/battery_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-a-multi-tenant-generative-ai-gateway-with-cost-and-usage-tracking-on-aws/f750573133a9d6189d56f060007a5553c6099de3/notebooks/images/battery_image.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aws-cdk-lib==2.154.1 2 | constructs==10.3.0 -------------------------------------------------------------------------------- /setup/app.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | App, 3 | Fn, 4 | CfnOutput, 5 | RemovalPolicy, 6 | Stack, 7 | Tags, 8 | aws_s3 9 | ) 10 | from constructs import Construct 11 | import json 12 | from stack_constructs.api import API 13 | from stack_constructs.api_gw import APIGW 14 | from stack_constructs.api_key import APIKey 15 | from stack_constructs.dynamodb import DynamoDB 16 | from stack_constructs.iam import IAM 17 | from stack_constructs.lambda_function import LambdaFunction 18 | from stack_constructs.lambda_layer import LambdaLayer 19 | from stack_constructs.network import Network 20 | from stack_constructs.scheduler import LambdaFunctionScheduler 21 | import traceback 22 | 23 | def _load_configs(filename): 24 | """ 25 | Loads config from file 26 | """ 27 | 28 | with open(filename, "r", encoding="utf-8") as f: 29 | config = json.load(f) 30 | 31 | return config 32 | 33 | class BedrockAPIStack(Stack): 34 | def __init__( 35 | self, scope: 36 | Construct, id: str, 37 | config: dict, 38 | **kwargs) -> None: 39 | super().__init__(scope, id, description=config["description"], **kwargs) 40 | # ================================================== 41 | # ============== STATIC PARAMETERS ================= 42 | # ================================================== 43 | self.id = id 44 | self.lambdas_directory = "./../lambdas" 45 | 
self.prefix_id = config.get("STACK_PREFIX", None) 46 | self.vpc_cidr = config.get("VPC_CIDR", None) 47 | self.bedrock_endpoint_url = "https://bedrock.{}.amazonaws.com".format(self.region) 48 | self.bedrock_runtime_endpoint_url = "https://bedrock-runtime.{}.amazonaws.com".format(self.region) 49 | 50 | # ================================================== 51 | # ================= PARAMETERS ===================== 52 | # ================================================== 53 | self.parent_prefix_id = config.get("PARENT_STACK_PREFIX", None) 54 | self.bedrock_requirements = config.get("BEDROCK_REQUIREMENTS", None) 55 | self.langchain_requirements = config.get("LANGCHAIN_REQUIREMENTS", None) 56 | self.pandas_requirements = config.get("PANDAS_REQUIREMENTS", None) 57 | self.api_throttling_rate = config.get("API_THROTTLING_RATE", 10000) 58 | self.api_burst_rate = config.get("API_BURST_RATE", 10000) 59 | self.api_gw_id = config.get("API_GATEWAY_ID", None) 60 | self.api_gw_resource_id = config.get("API_GATEWAY_RESOURCE_ID", None) 61 | self.sagemaker_endpoints = config.get("SAGEMAKER_ENDPOINTS", "") 62 | 63 | if self.prefix_id is None: 64 | raise Exception("STACK_PREFIX not defined") 65 | 66 | if self.api_throttling_rate is None: 67 | raise Exception("You must configure throttling for API Gateway") 68 | 69 | if self.api_burst_rate is None: 70 | raise Exception("You must configure burst rate for API Gateway") 71 | 72 | if (self.vpc_cidr is not None and 73 | self.bedrock_requirements is not None and 74 | self.langchain_requirements is not None and 75 | self.pandas_requirements is not None 76 | ): 77 | if self.api_gw_id is not None: 78 | raise Exception("API Gateway ID is required only for the partial deployment of an API Key") 79 | if self.api_gw_resource_id is not None: 80 | raise Exception("API Gateway Resource ID is required only for the partial deployment of an API Key") 81 | 82 | print("Deploying full API") 83 | self.full_deployment = True 84 | else: 85 | if self.api_gw_id is not None and self.api_gw_resource_id is not None: 86 | 87 | print("Deploying partial API") 88 | self.full_deployment = False 89 | else: 90 | if self.parent_prefix_id is not None: 91 | try: 92 | self.api_gw_id = Fn.import_value(f"{self.parent_prefix_id}ApiGatewayId") 93 | self.api_gw_resource_id = Fn.import_value(f"{self.parent_prefix_id}ApiGatewayResourceId") 94 | 95 | print("Deploying partial API") 96 | self.full_deployment = False 97 | except Exception as e: 98 | raise Exception("API Gateway ID and API Gateway Resource ID are required. Make sure you are fully deploying the infrastructure or specify valid IDs")
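# A partial (API-key-only) deployment reuses the gateway created by a parent full deployment:
# ApiGatewayId and ApiGatewayResourceId are resolved here from the parent stack's CfnOutput
# exports, which are emitted at the end of build_full below.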
99 | else: 100 | raise Exception("You must specify a PARENT_STACK_PREFIX for importing API Gateway information") 101 | 102 | def build_full(self): 103 | # ================================================== 104 | # ================== IAM ROLE ====================== 105 | # ================================================== 106 | iam = IAM( 107 | scope=self, 108 | id="iam_role_lambda" 109 | ) 110 | 111 | iam_role = iam.build() 112 | 113 | # ================================================== 114 | # =================== NETWORK ====================== 115 | # ================================================== 116 | 117 | network_class = Network( 118 | scope=self, 119 | id="network_stack", 120 | account=self.account, 121 | region=self.region 122 | ) 123 | 124 | vpc, private_subnet1, private_subnet2, security_group = network_class.build( 125 | vpc_cidr=self.vpc_cidr 126 | ) 127 | 128 | # ================================================== 129 | # =================== DYNAMODB ===================== 130 | # ================================================== 131 | 132 | dynamodb_class = DynamoDB( 133 | scope=self, 134 | id="streaming_dynamodb_stack", 135 | ) 136 | 137 | table_streaming = dynamodb_class.build(suffix="streaming_messages", key_name="composite_pk") 138 | 139 | dynamodb_class_logs = DynamoDB( 140 | scope=self, 141 | id="logs_dynamodb_stack" 142 | ) 143 | 144 | table_logs = dynamodb_class_logs.build(suffix="logs", key_name="requestId") 145 | 146 | # ================================================== 147 | # ================= S3 BUCKETS ===================== 148 | # ================================================== 149 | 150 | s3_bucket_layer = aws_s3.Bucket( 151 | self, 152 | f"{self.prefix_id}_s3_bucket_layer", 153 | auto_delete_objects=True, 154 | removal_policy=RemovalPolicy.DESTROY 155 | ) 156 | 157 | s3_bucket_configs = aws_s3.Bucket( 158 | self, 159 | f"{self.prefix_id}_s3_bucket_configs", 160 | auto_delete_objects=True, 161 | removal_policy=RemovalPolicy.DESTROY 162 | ) 163 | 164 | # ================================================== 165 | # =============== LAMBDA LAYERS ==================== 166 | # ================================================== 167 | 168 | lambda_layer = LambdaLayer( 169 | scope=self, 170 | id=f"{self.prefix_id}_lambda_layer", 171 | s3_bucket=s3_bucket_layer.bucket_name, 172 | role=iam_role.role_name, 173 | ) 174 | 175 | boto3_layer = lambda_layer.build( 176 | layer_name=f"{self.prefix_id}_boto3_sdk_layer", 177 | code_dir=f"{self.lambdas_directory}/lambda_layer_requirements", 178 | environments={ 179 | "REQUIREMENTS": self.bedrock_requirements, 180 | "S3_BUCKET": s3_bucket_layer.bucket_name 181 | } 182 | ) 183 | 184 | langchain_layer = lambda_layer.build( 185 | layer_name=f"{self.prefix_id}_langchain_layer", 186 | code_dir=f"{self.lambdas_directory}/lambda_layer_requirements", 187 | environments={ 188 | "REQUIREMENTS": self.langchain_requirements, 189 | "S3_BUCKET": s3_bucket_layer.bucket_name 190 | } 191 | ) 192 | 193 | pandas_layer = lambda_layer.build( 194 | layer_name=f"{self.prefix_id}_pandas_layer", 195 | code_dir=f"{self.lambdas_directory}/lambda_layer_requirements", 196 | environments={ 197 | "REQUIREMENTS": self.pandas_requirements, 198 | "S3_BUCKET": s3_bucket_layer.bucket_name 199 | } 200 | ) 201 | 202 | # ================================================== 203 | # ============= BEDROCK FUNCTIONS ================== 204 | # ==================================================
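# The three functions below package lambdas/invoke_model_streaming, lambdas/invoke_model and
# lambdas/list_foundation_models. They share the Lambda IAM role, private subnets and security
# group created above; invoke_model receives the streaming function's name through the
# LAMBDA_STREAMING environment variable so it can delegate streaming requests to it.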
205 | 206 | lambda_function = LambdaFunction( 207 | scope=self, 208 | id=f"{self.prefix_id}_lambda_function", 209 | role=iam_role.role_name, 210 | ) 211 | 212 | bedrock_invoke_model_streaming = lambda_function.build( 213 | function_name=f"{self.prefix_id}_bedrock_invoke_model_streaming", 214 | code_dir=f"{self.lambdas_directory}/invoke_model_streaming", 215 | memory=512, 216 | timeout=900, 217 | environment={ 218 | "BEDROCK_URL": self.bedrock_runtime_endpoint_url, 219 | "BEDROCK_REGION": self.region, 220 | "STREAMING_TABLE_NAME": table_streaming.table_name, 221 | "S3_BUCKET": s3_bucket_configs.bucket_name, 222 | "SAGEMAKER_ENDPOINTS": self.sagemaker_endpoints 223 | }, 224 | vpc=vpc, 225 | subnets=[private_subnet1, private_subnet2], 226 | security_groups=[security_group], 227 | layers=[boto3_layer, langchain_layer] 228 | ) 229 | 230 | bedrock_invoke_model = lambda_function.build( 231 | function_name=f"{self.prefix_id}_bedrock_invoke_model", 232 | code_dir=f"{self.lambdas_directory}/invoke_model", 233 | memory=512, 234 | timeout=900, 235 | environment={ 236 | "BEDROCK_URL": self.bedrock_runtime_endpoint_url, 237 | "BEDROCK_REGION": self.region, 238 | "LAMBDA_STREAMING": bedrock_invoke_model_streaming.function_name, 239 | "LOGS_TABLE_NAME": table_logs.table_name, 240 | "STREAMING_TABLE_NAME": table_streaming.table_name, 241 | "S3_BUCKET": s3_bucket_configs.bucket_name, 242 | "SAGEMAKER_ENDPOINTS": self.sagemaker_endpoints 243 | }, 244 | vpc=vpc, 245 | subnets=[private_subnet1, private_subnet2], 246 | security_groups=[security_group], 247 | layers=[boto3_layer, langchain_layer] 248 | ) 249 | 250 | bedrock_list_model = lambda_function.build( 251 | function_name=f"{self.prefix_id}_bedrock_list_foundation_models", 252 | code_dir=f"{self.lambdas_directory}/list_foundation_models", 253 | memory=512, 254 | timeout=900, 255 | environment={ 256 | "BEDROCK_URL": self.bedrock_endpoint_url, 257 | "BEDROCK_REGION": self.region, 258 | }, 259 | vpc=vpc, 260 | subnets=[private_subnet1, private_subnet2], 261 | security_groups=[security_group], 262 | layers=[boto3_layer] 263 | ) 264 | 265 | # ================================================== 266 | # ============= LAMBDA COST TRACKING =============== 267 | # ================================================== 268 | 269 | s3_bucket_cost_tracking = aws_s3.Bucket( 270 | self, 271 | f"{self.prefix_id}_s3_bucket_cost_tracking", 272 | bucket_name=f"{self.prefix_id}-bucket-cost-tracking-bedrock", 273 | auto_delete_objects=True, 274 | removal_policy=RemovalPolicy.DESTROY 275 | ) 276 | 277 | bedrock_cost_tracking = lambda_function.build( 278 | function_name=f"{self.prefix_id}_bedrock_cost_tracking", 279 | code_dir=f"{self.lambdas_directory}/cost_tracking", 280 | memory=512, 281 | timeout=900, 282 | environment={ 283 | "LOG_GROUP_API": f"/aws/lambda/{self.prefix_id}_bedrock_invoke_model", 284 | "S3_BUCKET": s3_bucket_cost_tracking.bucket_name 285 | }, 286 | vpc=vpc, 287 | subnets=[private_subnet1, private_subnet2], 288 | security_groups=[security_group], 289 | layers=[pandas_layer] 290 | ) 291 | 292 | scheduler = LambdaFunctionScheduler( 293 | self, 294 | id=f"{self.prefix_id}_lambda_scheduler" 295 | ) 296 | 297 | scheduler.build( 298 | lambda_function=bedrock_cost_tracking 299 | ) 300 | 301 | # ================================================== 302 | # ================== API GATEWAY =================== 303 | # ================================================== 304 | 305 | api_gw_class = APIGW( 306 | self, 307 | id=f"{self.prefix_id}_api_gw", 308 | 
api_gw_name=f"{self.prefix_id}_bedrock_api_gw" 309 | ) 310 | 311 | api_gw = api_gw_class.build() 312 | 313 | # ================================================== 314 | # ================== API ROUTES ==================== 315 | # ================================================== 316 | 317 | api_route = API( 318 | self, 319 | id=f"{self.prefix_id}_api_route", 320 | api_gw=api_gw, 321 | 322 | ) 323 | 324 | api_invoke = api_route.build( 325 | lambda_function=bedrock_invoke_model, 326 | route="invoke_model", 327 | method="POST", 328 | validator=True 329 | ) 330 | 331 | api_list = api_route.build( 332 | lambda_function=bedrock_list_model, 333 | route="list_foundation_models", 334 | method="GET", 335 | validator=False 336 | ) 337 | 338 | # ================================================== 339 | # =================== API KEY ====================== 340 | # ================================================== 341 | 342 | api_key_class = APIKey( 343 | self, 344 | id=f"{self.prefix_id}_api_key", 345 | prefix=self.prefix_id, 346 | dependencies=[api_gw, api_invoke, api_list] 347 | ) 348 | 349 | stage = api_key_class.build( 350 | rest_api_id=api_gw.rest_api_id, 351 | resource_id=api_gw.rest_api_root_resource_id, 352 | throttling_rate=self.api_throttling_rate, 353 | burst_rate=self.api_burst_rate 354 | ) 355 | 356 | CfnOutput(self, f"{self.prefix_id}_api_gw_url", export_name=f"{self.prefix_id}ApiGatewayUrl", value=stage.url_for_path(path=None)) 357 | CfnOutput(self, f"{self.prefix_id}_api_gw_id", export_name=f"{self.prefix_id}ApiGatewayId", value=api_gw.rest_api_id) 358 | CfnOutput(self, f"{self.prefix_id}_api_gw_resource_id", export_name=f"{self.prefix_id}ApiGatewayResourceId", value=api_gw.rest_api_root_resource_id) 359 | CfnOutput(self, f"{self.prefix_id}_stack_name", export_name=f"{self.prefix_id}StackName", value=self.stack_name) 360 | 361 | def build_api_key(self): 362 | # ================================================== 363 | # =================== API KEY ====================== 364 | # ================================================== 365 | 366 | api_key_class = APIKey( 367 | self, 368 | id=f"{self.prefix_id}_api_key", 369 | prefix=self.prefix_id 370 | ) 371 | 372 | stage = api_key_class.build( 373 | rest_api_id=self.api_gw_id, 374 | resource_id=self.api_gw_resource_id, 375 | throttling_rate=self.api_throttling_rate, 376 | burst_rate=self.api_burst_rate 377 | ) 378 | 379 | CfnOutput(self, f"{self.prefix_id}_api_gw_url", export_name=f"{self.prefix_id}ApiGatewayUrl", value=stage.url_for_path(path=None)) 380 | CfnOutput(self, f"{self.prefix_id}_stack_name", export_name=f"{self.prefix_id}StackName", value=self.stack_name) 381 | 382 | # ================================================== 383 | # ============== STACK WITH COST CENTER ============ 384 | # ================================================== 385 | 386 | app = App() 387 | 388 | project_tag = "SO9482" 389 | configs = _load_configs("./configs.json") 390 | 391 | for config in configs: 392 | config["description"] = f"{project_tag} - This template creates the required AWS resources for accessiing LLMs in Amazon Bedrock and Amazon SageMaker through a centralized gateway, by monitoring usage and costs" 393 | 394 | api_stack = BedrockAPIStack( 395 | scope=app, 396 | id=f"{config['STACK_PREFIX']}-bedrock-saas", 397 | config=config 398 | ) 399 | 400 | if api_stack.full_deployment: 401 | api_stack.build_full() 402 | else: 403 | api_stack.build_api_key() 404 | 405 | # Add a cost tag to all constructs in the stack 406 | Tags.of(api_stack).add("Tenant", 
405 | # Add a cost tag to all constructs in the stack 406 | Tags.of(api_stack).add("Tenant", api_stack.prefix_id) 407 | Tags.of(api_stack).add("Project", project_tag) 408 | 409 | try: 410 | app.synth() 411 | except Exception as e: 412 | stacktrace = traceback.format_exc() 413 | print(stacktrace) 414 | 415 | raise e -------------------------------------------------------------------------------- /setup/cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "python3 app.py" 3 | } -------------------------------------------------------------------------------- /setup/configs.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "STACK_PREFIX": "", 4 | "BEDROCK_REQUIREMENTS": "boto3>=1.35.38 awscli>=1.35.4 botocore>=1.35.38", 5 | "LANGCHAIN_REQUIREMENTS": "aws-lambda-powertools langchain==0.3.3 langchain-community==0.3.2 pydantic PyYaml", 6 | "PANDAS_REQUIREMENTS": "pandas", 7 | "SAGEMAKER_ENDPOINTS": "", 8 | "VPC_CIDR": "10.10.0.0/16", 9 | "API_THROTTLING_RATE": 10000, 10 | "API_BURST_RATE": 5000 11 | } 12 | ] -------------------------------------------------------------------------------- /setup/stack_constructs/api.py: -------------------------------------------------------------------------------- 1 | from constructs import Construct 2 | from aws_cdk import ( 3 | aws_apigateway as apigw, 4 | aws_iam as iam, 5 | aws_lambda as lambda_ 6 | ) 7 | 8 | class API(Construct): 9 | def __init__( 10 | self, 11 | scope: Construct, 12 | id: str, 13 | api_gw: apigw.LambdaRestApi, 14 | dependencies: list = [] 15 | ): 16 | super().__init__(scope, id) 17 | 18 | self.id = id 19 | self.api_gw = api_gw 20 | self.dependencies = dependencies 21 | 22 | def build( 23 | self, 24 | lambda_function: lambda_.Function, 25 | route: str, 26 | method: str, 27 | validator: bool = False 28 | ): 29 | # Add method/route 30 | 31 | lambda_function.add_permission( 32 | id=f"{self.id}_{route}_permission", 33 | action="lambda:InvokeFunction", 34 | principal=iam.ServicePrincipal("apigateway.amazonaws.com"), 35 | source_arn=self.api_gw.arn_for_execute_api( 36 | stage="*", 37 | method=method, 38 | path=f"/{route}" 39 | ) 40 | ) 41 | 42 | resource = self.api_gw.root.add_resource(route) 43 | 44 | if validator: 45 | resource.add_method( 46 | http_method=method, 47 | integration=apigw.LambdaIntegration(lambda_function), 48 | api_key_required=True, 49 | request_parameters={ 50 | "method.request.header.team_id": True, 51 | "method.request.header.streaming": False, 52 | "method.request.header.type": False 53 | }, 54 | request_validator_options={ 55 | "request_validator_name": "parameter-validator", 56 | "validate_request_parameters": True, 57 | "validate_request_body": False 58 | }, 59 | method_responses=[ 60 | apigw.MethodResponse( 61 | status_code="401", 62 | response_parameters={ 63 | "method.response.header.Access-Control-Allow-Origin": True, 64 | }, 65 | response_models={ 66 | "application/json": apigw.Model.ERROR_MODEL, 67 | } 68 | ) 69 | ] 70 | ) 71 | else: 72 | resource.add_method( 73 | http_method=method, 74 | integration=apigw.LambdaIntegration(lambda_function), 75 | api_key_required=True, 76 | method_responses=[ 77 | apigw.MethodResponse( 78 | status_code="401", 79 | response_parameters={ 80 | "method.response.header.Access-Control-Allow-Origin": True, 81 | }, 82 | response_models={ 83 | "application/json": apigw.Model.ERROR_MODEL, 84 | } 85 | ) 86 | ] 87 | ) 88 | 89 | return resource 90 | -------------------------------------------------------------------------------- /setup/stack_constructs/api_gw.py:
-------------------------------------------------------------------------------- 1 | from constructs import Construct 2 | from aws_cdk import ( 3 | aws_apigateway as apigw, 4 | Duration 5 | ) 6 | 7 | class APIGW(Construct): 8 | def __init__( 9 | self, 10 | scope: Construct, 11 | id: str, 12 | api_gw_name: str, 13 | dependencies: list = [] 14 | ): 15 | super().__init__(scope, id) 16 | 17 | self.id = id 18 | self.api_gw_name = api_gw_name 19 | self.dependencies = dependencies 20 | 21 | def build( 22 | self 23 | ): 24 | # Create API Gateway REST 25 | api = apigw.RestApi( 26 | scope=self, 27 | id=f"{self.id}_api_gateway", 28 | rest_api_name=self.api_gw_name, 29 | deploy=False 30 | ) 31 | 32 | # api.timeout = Duration.seconds(300) 33 | 34 | for el in self.dependencies: 35 | api.node.add_dependency(el) 36 | 37 | return api 38 | -------------------------------------------------------------------------------- /setup/stack_constructs/api_key.py: -------------------------------------------------------------------------------- 1 | from constructs import Construct 2 | from aws_cdk import ( 3 | aws_apigateway as apigw 4 | ) 5 | 6 | class APIKey(Construct): 7 | def __init__( 8 | self, 9 | scope: Construct, 10 | id: str, 11 | prefix: str, 12 | dependencies: list = [] 13 | ): 14 | super().__init__(scope, id) 15 | 16 | self.id = id 17 | self.prefix = prefix 18 | self.dependencies = dependencies 19 | 20 | def build( 21 | self, 22 | rest_api_id: str, 23 | resource_id: str, 24 | throttling_rate: int = 10000, 25 | burst_rate: int = 5000 26 | ): 27 | # Lookup RestApi 28 | api = apigw.RestApi.from_rest_api_attributes( 29 | self, 30 | f"{self.id}_rest_api", 31 | rest_api_id=rest_api_id, 32 | root_resource_id=resource_id 33 | ) 34 | 35 | # Create API key 36 | api_key = apigw.ApiKey( 37 | self, 38 | f"{self.id}_api_key", 39 | description=f"API Key for {self.id}", 40 | enabled=True 41 | ) 42 | 43 | # Create Deployment 44 | deployment = apigw.Deployment(self, f"{self.id}_deployment", api=api) 45 | 46 | # Create Stage 47 | 48 | stage = apigw.Stage( 49 | self, 50 | f"{self.id}_stage", 51 | deployment=deployment, 52 | metrics_enabled=True, 53 | throttling_rate_limit=throttling_rate, 54 | throttling_burst_limit=burst_rate, 55 | stage_name=f"{self.prefix}_prod" 56 | ) 57 | 58 | # Create Usage Plan 59 | usage_plan = api.add_usage_plan( 60 | id=f"{self.id}_usage_plan", 61 | api_stages=[ 62 | { 63 | "api": api, 64 | "stage": stage 65 | } 66 | ], 67 | name=f"{self.id}_plan_prod" 68 | ) 69 | 70 | usage_plan.add_api_key(api_key) 71 | 72 | for el in self.dependencies: 73 | deployment.node.add_dependency(el) 74 | 75 | for el in self.dependencies: 76 | api_key.node.add_dependency(el) 77 | 78 | return stage 79 | -------------------------------------------------------------------------------- /setup/stack_constructs/dynamodb.py: -------------------------------------------------------------------------------- 1 | from constructs import Construct 2 | from aws_cdk import ( 3 | aws_dynamodb as ddb, 4 | RemovalPolicy 5 | ) 6 | 7 | class DynamoDB(Construct): 8 | def __init__( 9 | self, 10 | scope: Construct, 11 | id: str, 12 | dependencies: list = [] 13 | ): 14 | super().__init__(scope, id) 15 | 16 | self.id = id 17 | self.dependencies = dependencies 18 | 19 | def build( 20 | self, 21 | suffix: str, 22 | key_name: str 23 | ): 24 | table = ddb.Table( 25 | self, 26 | f"{self.id}_{suffix}", 27 | partition_key=ddb.Attribute( 28 | name=key_name, 29 | type=ddb.AttributeType.STRING 30 | ), 31 | time_to_live_attribute="ttl", 32 | 
removal_policy=RemovalPolicy.DESTROY 33 | ) 34 | 35 | for el in self.dependencies: 36 | table.node.add_dependency(el) 37 | 38 | return table 39 | -------------------------------------------------------------------------------- /setup/stack_constructs/iam.py: -------------------------------------------------------------------------------- 1 | from constructs import Construct 2 | from aws_cdk import ( 3 | aws_iam as iam 4 | ) 5 | 6 | class IAM(Construct): 7 | def __init__( 8 | self, 9 | scope: Construct, 10 | id: str, 11 | dependencies: list = [] 12 | ): 13 | super().__init__(scope, id) 14 | 15 | self.id = id 16 | self.dependencies = dependencies 17 | 18 | def build(self): 19 | # ================================================== 20 | # ================= IAM ROLE ======================= 21 | # ================================================== 22 | lambda_role = iam.Role( 23 | self, 24 | id=f"{self.id}_role", 25 | assumed_by=iam.ServicePrincipal(service="lambda.amazonaws.com"), 26 | managed_policies=[ 27 | iam.ManagedPolicy.from_aws_managed_policy_name("AWSLambdaExecute"), 28 | iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchFullAccess"), 29 | iam.ManagedPolicy.from_aws_managed_policy_name("CloudWatchLogsFullAccess") 30 | ], 31 | ) 32 | 33 | ec2_policy = iam.Policy( 34 | scope=self, 35 | id=f"{self.id}_policy_ec2", 36 | policy_name="EC2Policy", 37 | statements=[ 38 | iam.PolicyStatement( 39 | effect=iam.Effect.ALLOW, 40 | actions=[ 41 | 'ec2:AssignPrivateIpAddresses', 42 | 'ec2:CreateNetworkInterface', 43 | 'ec2:DeleteNetworkInterface', 44 | 'ec2:DescribeNetworkInterfaces', 45 | 'ec2:DescribeSecurityGroups', 46 | 'ec2:DescribeSubnets', 47 | 'ec2:DescribeVpcs', 48 | 'ec2:UnassignPrivateIpAddresses', 49 | 'ec2:*VpcEndpoint*' 50 | ], 51 | resources=["*"], 52 | ) 53 | ], 54 | ) 55 | 56 | lambda_policy = iam.Policy( 57 | scope=self, 58 | id=f"{self.id}_policy_lambda", 59 | policy_name="LambdaPolicy", 60 | statements=[ 61 | iam.PolicyStatement( 62 | effect=iam.Effect.ALLOW, 63 | actions=[ 64 | 'lambda:InvokeFunction' 65 | ], 66 | resources=["*"], 67 | ) 68 | ], 69 | ) 70 | 71 | s3_policy = iam.Policy( 72 | scope=self, 73 | id=f"{self.id}_policy_s3", 74 | policy_name="S3Policy", 75 | statements=[ 76 | iam.PolicyStatement( 77 | effect=iam.Effect.ALLOW, 78 | actions=[ 79 | 's3:PutObject', 80 | 's3:DeleteObject', 81 | 's3:ListBucket' 82 | ], 83 | resources=["*"], 84 | ) 85 | ], 86 | ) 87 | 88 | dynamodb_policy = iam.Policy( 89 | scope=self, 90 | id=f"{self.id}_policy_dynamodb", 91 | policy_name="DynamoDBPolicy", 92 | statements=[ 93 | iam.PolicyStatement( 94 | effect=iam.Effect.ALLOW, 95 | actions=[ 96 | "dynamodb:BatchGetItem", 97 | "dynamodb:DeleteItem", 98 | "dynamodb:GetItem", 99 | "dynamodb:PutItem" 100 | 101 | ], 102 | resources=["*"], 103 | ) 104 | ], 105 | ) 106 | 107 | bedrock_policy = iam.Policy( 108 | scope=self, 109 | id=f"{self.id}_policy_bedrock", 110 | policy_name="BedrockPolicy", 111 | statements=[ 112 | iam.PolicyStatement( 113 | effect=iam.Effect.ALLOW, 114 | actions=[ 115 | "sts:AssumeRole" 116 | ], 117 | resources=["*"], 118 | ), 119 | iam.PolicyStatement( 120 | effect=iam.Effect.ALLOW, 121 | actions=[ 122 | "bedrock:GetInferenceProfile", 123 | "bedrock:InvokeModel", 124 | "bedrock:InvokeModelWithResponseStream", 125 | "bedrock:ListFoundationModels", 126 | "bedrock:ListInferenceProfiles" 127 | ], 128 | resources=["*"], 129 | ) 130 | ], 131 | ) 132 | 133 | sagemaker_policy = iam.Policy( 134 | scope=self, 135 | id=f"{self.id}_policy_sagemaker", 136 | 
policy_name="SageMakerPolicy", 137 | statements=[ 138 | iam.PolicyStatement( 139 | effect=iam.Effect.ALLOW, 140 | actions=[ 141 | "sagemaker:InvokeEndpoint" 142 | ], 143 | resources=["*"], 144 | ) 145 | ], 146 | ) 147 | 148 | bedrock_policy.attach_to_role(lambda_role) 149 | dynamodb_policy.attach_to_role(lambda_role) 150 | ec2_policy.attach_to_role(lambda_role) 151 | lambda_policy.attach_to_role(lambda_role) 152 | s3_policy.attach_to_role(lambda_role) 153 | sagemaker_policy.attach_to_role(lambda_role) 154 | 155 | for el in self.dependencies: 156 | lambda_role.node.add_dependency(el) 157 | 158 | return lambda_role 159 | -------------------------------------------------------------------------------- /setup/stack_constructs/lambda_function.py: -------------------------------------------------------------------------------- 1 | from constructs import Construct 2 | from aws_cdk import ( 3 | aws_ec2 as ec2, 4 | aws_iam as iam, 5 | aws_lambda as lambda_, 6 | Duration, 7 | ) 8 | 9 | class LambdaFunction(Construct): 10 | def __init__( 11 | self, 12 | scope: Construct, 13 | id: str, 14 | role: str, 15 | provisioned_concurrency: int = None, 16 | dependencies: list = [] 17 | ): 18 | super().__init__(scope, id) 19 | 20 | self.id = id 21 | self.role = role 22 | self.provisioned_concurrency = provisioned_concurrency 23 | self.dependencies = dependencies 24 | 25 | def build( 26 | self, 27 | function_name: str, 28 | code_dir: str, 29 | environment: dict, 30 | memory: int, 31 | timeout: int, 32 | vpc: ec2.Vpc = [], 33 | subnets: list = [], 34 | security_groups: list = [], 35 | layers: list = [] 36 | ): 37 | if vpc is not None and len(subnets) > 0 and len(security_groups) > 0: 38 | fn = lambda_.Function( 39 | self, 40 | id=f"{self.id}_{function_name}_function", 41 | function_name=function_name, 42 | runtime=lambda_.Runtime.PYTHON_3_10, 43 | handler="index.lambda_handler", 44 | code=lambda_.Code.from_asset(code_dir), 45 | timeout=Duration.seconds(timeout), 46 | memory_size=memory, 47 | environment=environment, 48 | layers=layers, 49 | role=iam.Role.from_role_name(self, f"{self.id}_{function_name}_role", self.role), 50 | vpc=vpc, 51 | vpc_subnets=ec2.SubnetSelection(subnets=subnets), 52 | security_groups=security_groups 53 | ) 54 | else: 55 | fn = lambda_.Function( 56 | self, 57 | id=f"{self.id}_{function_name}_function", 58 | function_name=function_name, 59 | runtime=lambda_.Runtime.PYTHON_3_10, 60 | handler="index.lambda_handler", 61 | code=lambda_.Code.from_asset(code_dir), 62 | timeout=Duration.seconds(timeout), 63 | memory_size=memory, 64 | environment=environment, 65 | layers=layers, 66 | role=iam.Role.from_role_name(self, f"{self.id}_{function_name}_role", self.role) 67 | ) 68 | 69 | for el in self.dependencies: 70 | fn.node.add_dependency(el) 71 | 72 | return fn -------------------------------------------------------------------------------- /setup/stack_constructs/lambda_layer.py: -------------------------------------------------------------------------------- 1 | from constructs import Construct 2 | from aws_cdk import ( 3 | aws_iam as iam, 4 | aws_lambda as lambda_, 5 | aws_s3 as s3, 6 | CustomResource, 7 | Duration 8 | ) 9 | from typing import Optional, List 10 | 11 | class LambdaLayer(Construct): 12 | def __init__( 13 | self, 14 | scope: Construct, 15 | id: str, 16 | s3_bucket: str, 17 | role: str, 18 | dependencies: list = [] 19 | ) -> None: 20 | super().__init__(scope, id) 21 | 22 | self.id = id 23 | self.s3_bucket = s3_bucket 24 | self.role = role 25 | self.dependencies = dependencies 26 | 27 
| def build( 28 | self, 29 | layer_name: str, 30 | code_dir: str, 31 | environments: dict 32 | ): 33 | fn = lambda_.Function( 34 | self, 35 | id=f"{self.id}_{layer_name}_function", 36 | function_name=f"{self.id}_{layer_name}_function", 37 | runtime=lambda_.Runtime.PYTHON_3_10, 38 | handler="index.lambda_handler", 39 | code=lambda_.Code.from_asset(code_dir), 40 | timeout=Duration.seconds(300), 41 | memory_size=512, 42 | environment=environments, 43 | role=iam.Role.from_role_name(self, f"{self.id}_{layer_name}_role", self.role) 44 | ) 45 | 46 | custom = CustomResource( 47 | self, 48 | id=f"{self.id}_{layer_name}_custom_resource", 49 | service_token=fn.function_arn, 50 | properties=environments 51 | ) 52 | 53 | layer = lambda_.LayerVersion( 54 | self, 55 | id=f"{layer_name}_{layer_name}_layer", 56 | layer_version_name=layer_name, 57 | code=lambda_.Code.from_bucket(s3.Bucket.from_bucket_name(self, f"{self.id}_{layer_name}_S3BucketLayers", self.s3_bucket), custom.get_att("Key").to_string()), 58 | compatible_runtimes=[ 59 | lambda_.Runtime.PYTHON_3_10, 60 | lambda_.Runtime.PYTHON_3_9, 61 | lambda_.Runtime.PYTHON_3_8 62 | ] 63 | ) 64 | 65 | for el in self.dependencies: 66 | fn.node.add_dependency(el) 67 | 68 | for el in self.dependencies: 69 | custom.node.add_dependency(el) 70 | 71 | for el in self.dependencies: 72 | layer.node.add_dependency(el) 73 | 74 | return layer 75 | -------------------------------------------------------------------------------- /setup/stack_constructs/network.py: -------------------------------------------------------------------------------- 1 | from constructs import Construct 2 | from aws_cdk import ( 3 | aws_ec2 as ec2 4 | ) 5 | 6 | class Network(Construct): 7 | def __init__( 8 | self, 9 | scope: Construct, 10 | id: str, 11 | account: str, 12 | region: str, 13 | dependencies: list = [] 14 | ): 15 | super().__init__(scope, id) 16 | 17 | self.id = id 18 | self.account = account 19 | self.region = region 20 | self.dependencies = dependencies 21 | 22 | def build( 23 | self, 24 | vpc_cidr: str 25 | ): 26 | # Resources 27 | vpc = ec2.Vpc( 28 | self, 29 | f"{self.id}_vpc", 30 | ip_addresses=ec2.IpAddresses.cidr(vpc_cidr), 31 | enable_dns_hostnames=True, 32 | enable_dns_support=True, 33 | gateway_endpoints={ 34 | "S3": ec2.GatewayVpcEndpointOptions(service=ec2.GatewayVpcEndpointAwsService.S3), 35 | "DynamoDB": ec2.GatewayVpcEndpointOptions(service=ec2.GatewayVpcEndpointAwsService.DYNAMODB) 36 | }, 37 | subnet_configuration=[ 38 | ec2.SubnetConfiguration( 39 | name=f"{self.id}_private_subnet_1", 40 | subnet_type=ec2.SubnetType.PRIVATE_ISOLATED, 41 | cidr_mask=24 42 | ), 43 | ec2.SubnetConfiguration( 44 | name=f"{self.id}_private_subnet_2", 45 | subnet_type=ec2.SubnetType.PRIVATE_ISOLATED, 46 | cidr_mask=24 47 | ) 48 | ] 49 | ) 50 | 51 | # Lookup private subnets 52 | private_subnet1 = ec2.Subnet.from_subnet_attributes( 53 | self, f"{self.id}_private_subnet_1", 54 | subnet_id=vpc.isolated_subnets[0].subnet_id 55 | ) 56 | 57 | private_subnet2 = ec2.Subnet.from_subnet_attributes( 58 | self, f"{self.id}_private_subnet_2", 59 | subnet_id=vpc.isolated_subnets[1].subnet_id 60 | ) 61 | 62 | security_group = ec2.SecurityGroup( 63 | self, 64 | f"{self.id}_security_group", 65 | vpc=vpc, 66 | allow_all_outbound=True, 67 | description="security group for bedrock workload in private subnets", 68 | ) 69 | 70 | endpoint_security_group = ec2.SecurityGroup( 71 | self, 72 | f"{self.id}_vpce_security_group", 73 | vpc=vpc, 74 | description="Allow TLS for VPC Endpoint", 75 | ) 76 | 77 | 
endpoint_security_group.add_ingress_rule( 78 | peer=security_group, 79 | connection=ec2.Port.tcp(443) 80 | ) 81 | 82 | # Bedrock VPCE 83 | ec2.CfnVPCEndpoint( 84 | self, 85 | f"{self.id}_vpce_bedrock", 86 | service_name=f"com.amazonaws.{self.region}.bedrock", 87 | vpc_id=vpc.vpc_id, 88 | private_dns_enabled=True, 89 | security_group_ids=[endpoint_security_group.security_group_id], 90 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 91 | vpc_endpoint_type="Interface" 92 | ) 93 | 94 | # Bedrock Runtime VPCE 95 | ec2.CfnVPCEndpoint( 96 | self, 97 | f"{self.id}_vpce_bedrock_runtime", 98 | service_name=f"com.amazonaws.{self.region}.bedrock-runtime", 99 | vpc_id=vpc.vpc_id, 100 | private_dns_enabled=True, 101 | security_group_ids=[endpoint_security_group.security_group_id], 102 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 103 | vpc_endpoint_type="Interface" 104 | ) 105 | 106 | # API Gateway VPCE 107 | ec2.CfnVPCEndpoint( 108 | self, 109 | f"{self.id}_vpce_api_gw", 110 | service_name=f"com.amazonaws.{self.region}.execute-api", 111 | vpc_id=vpc.vpc_id, 112 | private_dns_enabled=True, 113 | security_group_ids=[endpoint_security_group.security_group_id], 114 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 115 | vpc_endpoint_type="Interface" 116 | ) 117 | 118 | # CloudWatch Logs VPCE 119 | ec2.CfnVPCEndpoint( 120 | self, 121 | f"{self.id}_vpce_logs", 122 | service_name=f"com.amazonaws.{self.region}.logs", 123 | vpc_id=vpc.vpc_id, 124 | private_dns_enabled=True, 125 | security_group_ids=[endpoint_security_group.security_group_id], 126 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 127 | vpc_endpoint_type="Interface" 128 | ) 129 | 130 | # Events VPCE 131 | ec2.CfnVPCEndpoint( 132 | self, 133 | f"{self.id}_vpce_events", 134 | service_name=f"com.amazonaws.{self.region}.events", 135 | vpc_id=vpc.vpc_id, 136 | private_dns_enabled=True, 137 | security_group_ids=[endpoint_security_group.security_group_id], 138 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 139 | vpc_endpoint_type="Interface" 140 | ) 141 | 142 | # Lambda VPCE 143 | ec2.CfnVPCEndpoint( 144 | self, 145 | f"{self.id}_vpce_lambda", 146 | service_name=f"com.amazonaws.{self.region}.lambda", 147 | vpc_id=vpc.vpc_id, 148 | private_dns_enabled=True, 149 | security_group_ids=[endpoint_security_group.security_group_id], 150 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 151 | vpc_endpoint_type="Interface" 152 | ) 153 | 154 | # SageMaker API VPCE 155 | ec2.CfnVPCEndpoint( 156 | self, 157 | f"{self.id}_vpce_sagemaker_api", 158 | service_name=f"com.amazonaws.{self.region}.sagemaker.api", 159 | vpc_id=vpc.vpc_id, 160 | private_dns_enabled=True, 161 | security_group_ids=[endpoint_security_group.security_group_id], 162 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 163 | vpc_endpoint_type="Interface" 164 | ) 165 | 166 | # SageMaker Runtime VPCE 167 | ec2.CfnVPCEndpoint( 168 | self, 169 | f"{self.id}_vpce_sagemaker_runtime", 170 | service_name=f"com.amazonaws.{self.region}.sagemaker.runtime", 171 | vpc_id=vpc.vpc_id, 172 | private_dns_enabled=True, 173 | security_group_ids=[endpoint_security_group.security_group_id], 174 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 175 | vpc_endpoint_type="Interface" 176 | ) 177 | 178 | # SageMaker Metrics VPCE 179 | ec2.CfnVPCEndpoint( 180 | self, 181 | f"{self.id}_vpce_sagemaker_metrics", 182 | 
service_name=f"com.amazonaws.{self.region}.sagemaker.metrics", 183 | vpc_id=vpc.vpc_id, 184 | private_dns_enabled=True, 185 | security_group_ids=[endpoint_security_group.security_group_id], 186 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 187 | vpc_endpoint_type="Interface" 188 | ) 189 | 190 | # SageMaker Runtime FIPS VPCE 191 | ec2.CfnVPCEndpoint( 192 | self, 193 | f"{self.id}_vpce_sagemaker_runtime_fips", 194 | service_name=f"com.amazonaws.{self.region}.sagemaker.runtime-fips", 195 | vpc_id=vpc.vpc_id, 196 | private_dns_enabled=True, 197 | security_group_ids=[endpoint_security_group.security_group_id], 198 | subnet_ids=[private_subnet1.subnet_id, private_subnet2.subnet_id], 199 | vpc_endpoint_type="Interface" 200 | ) 201 | 202 | for el in self.dependencies: 203 | vpc.node.add_dependency(el) 204 | 205 | return vpc, private_subnet1, private_subnet2, security_group 206 | -------------------------------------------------------------------------------- /setup/stack_constructs/scheduler.py: -------------------------------------------------------------------------------- 1 | from constructs import Construct 2 | from aws_cdk import ( 3 | aws_events as events, 4 | aws_events_targets as targets, 5 | aws_lambda as lambda_ 6 | ) 7 | 8 | 9 | class LambdaFunctionScheduler(Construct): 10 | def __init__( 11 | self, 12 | scope: Construct, 13 | id: str, 14 | dependencies: list = [] 15 | ): 16 | super().__init__(scope, id) 17 | 18 | self.id = id 19 | self.dependencies = dependencies 20 | 21 | def build( 22 | self, 23 | lambda_function: lambda_.Function, 24 | ): 25 | # ================================================== 26 | # ================== SCHEDULING ==================== 27 | # ================================================== 28 | cron_rule = events.Rule( 29 | scope=self, 30 | id=f"{self.id}_cron_rule", 31 | rule_name=f"{self.id}_usage_aggregator_schedule", 32 | schedule=events.Schedule.expression('cron(0 0 * * ? 
*)') 33 | ) 34 | 35 | cron_rule.add_target(target=targets.LambdaFunction(lambda_function)) 36 | 37 | for el in self.dependencies: 38 | cron_rule.node.add_dependency(el) 39 | -------------------------------------------------------------------------------- /utils/update_cost_files.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import pandas as pd 3 | import io 4 | 5 | # Initialize the S3 client 6 | s3 = boto3.client('s3') 7 | 8 | # Specify your bucket name 9 | bucket_name = '' 10 | folder_name = 'succeed' 11 | 12 | def update(): 13 | 14 | # List all objects in the specified folder 15 | response = s3.list_objects_v2(Bucket=bucket_name, Prefix=folder_name + '/') 16 | 17 | # Iterate through each object in the folder (an empty prefix has no 'Contents' key) 18 | for obj in response.get('Contents', []): 19 | # Get the object key (file name) 20 | key = obj['Key'] 21 | 22 | # Check if the file is a CSV 23 | if key.endswith('.csv'): 24 | print(f"Processing file: {key}") 25 | 26 | # Download the CSV file 27 | obj_response = s3.get_object(Bucket=bucket_name, Key=key) 28 | csv_content = obj_response['Body'].read() 29 | 30 | # Load the CSV into a pandas DataFrame 31 | df = pd.read_csv(io.BytesIO(csv_content), delimiter=',') 32 | 33 | # Check if 'api_key' column exists, if not, add it 34 | if 'api_key' not in df.columns: 35 | df.insert(0, 'api_key', '') 36 | print(f"Added 'api_key' column to {key}") 37 | 38 | # Convert the updated DataFrame back to CSV 39 | csv_buffer = io.StringIO() 40 | df.to_csv(csv_buffer, index=False) 41 | 42 | # Upload the updated CSV back to S3 43 | s3.put_object(Bucket=bucket_name, Key=key, Body=csv_buffer.getvalue()) 44 | print(f"Updated {key} in S3") 45 | 46 | print("Processing complete!") 47 | 48 | if __name__ == '__main__': 49 | update() 50 | --------------------------------------------------------------------------------
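A minimal client-side sketch of how the deployed gateway can be called. It is illustrative only: the stage URL and API key are placeholders you would read from the ApiGatewayUrl stack output and the generated API key, the requests library is assumed to be available, and the invoke_model parameters mirror the Lambda handlers shown above rather than a documented contract.

import requests  # third-party HTTP client, assumed installed in the client environment

API_URL = "https://<api-id>.execute-api.<region>.amazonaws.com/<stack-prefix>_prod"  # ApiGatewayUrl output
API_KEY = "<api-key-value>"  # value of the API key created by the APIKey construct

# GET /list_foundation_models only needs the API key (validator=False, so no extra headers)
models = requests.get(f"{API_URL}/list_foundation_models", headers={"x-api-key": API_KEY})
print(models.status_code, models.json())

# POST /invoke_model additionally requires the team_id header declared in setup/stack_constructs/api.py;
# model_id selects either a Bedrock model or a SageMaker endpoint listed in SAGEMAKER_ENDPOINTS
answer = requests.post(
    f"{API_URL}/invoke_model",
    params={"model_id": "anthropic.claude-3-haiku-20240307-v1:0"},
    headers={"x-api-key": API_KEY, "team_id": "tenant-a"},
    json={"inputs": "Hello!", "parameters": {"max_tokens": 256}},
)
print(answer.status_code, answer.json())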