├── .flake8 ├── .gitignore ├── .vscode └── settings.json ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── ab-testing-pipeline.yml ├── app.py ├── cdk.json ├── deployment_pipeline ├── README.md ├── app.py ├── cdk.json ├── dev-config.json ├── infra │ ├── deployment_config.py │ ├── model_registry.py │ ├── sagemaker_stack.py │ └── test_model_registry.py ├── prod-config.json ├── register.py ├── requirements.txt └── setup.py ├── docs ├── API_CONFIGURATION.md ├── CODE_OF_CONDUCT.md ├── CUSTOM_TEMPLATE.md ├── FAQ.md ├── OPERATIONS.md ├── SERVICE_CATALOG.md ├── ab-testing-pipeline-architecture.png ├── ab-testing-pipeline-code-pipeline.png ├── ab-testing-pipeline-deployment.png ├── ab-testing-pipeline-execution-role.png ├── ab-testing-pipeline-model-registry.png ├── ab-testing-pipeline-sagemaker-project.png ├── ab-testing-pipeline-sagemaker-template.png ├── ab-testing-pipeline-upload-file.png ├── ab-testing-pipeline-xray.png └── ab-testing-solution-overview.png ├── infra ├── __init__.py ├── api_stack.py ├── pipeline_stack.py └── service_catalog.py ├── install_layers.sh ├── lambda └── api │ ├── algorithm.py │ ├── experiment_assignment.py │ ├── experiment_metrics.py │ ├── lambda_invoke.py │ ├── lambda_metrics.py │ ├── lambda_register.py │ ├── test_algorithm.py │ ├── test_experiment_assignment.py │ └── test_experiment_metrics.py ├── layers └── requirements.txt ├── notebook ├── dashboard.json ├── mab-reviews-helpfulness.ipynb └── simulation.py ├── requirements.txt ├── setup.py └── source.bat /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E226,E302,E41 3 | max-line-length = 120 4 | exclude = cdk.out/* 5 | max-complexity = 10 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.swp 3 | package-lock.json 4 | __pycache__ 5 | .pytest_cache 6 | .env 7 | .venv 8 | *.egg-info 9 | 10 | # CDK asset staging directory 11 | .cdk.staging 12 | cdk.out 13 | dist/ 14 | 15 | # Layers 16 | layers/python 17 | layers/*.zip 18 | 19 | # Notebooks 20 | .ipynb_checkpoints/ -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/Users/julbrigh/projects/amazon-sagemaker-ab-testing-pipeline/.venv/bin/python3", 3 | "python.linting.flake8Enabled": true, 4 | "python.linting.pylintEnabled": false, 5 | "python.linting.enabled": true 6 | } -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. 
Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Amazon SageMaker A/B Testing Pipeline 2 | 3 | This sample demonstrates how to set up an Amazon SageMaker MLOps deployment pipeline for A/B Testing of machine learning models. 4 | 5 | ![Solution Overview](docs/ab-testing-solution-overview.png) 6 | 7 | The following are the high-level steps to deploy this solution: 8 | 9 | 1. Publish the SageMaker [MLOps Project template](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-projects-templates.html) in the [AWS Service Catalog](https://aws.amazon.com/servicecatalog/) 10 | 2. Deploy the [Amazon API Gateway](https://aws.amazon.com/api-gateway/) and Testing Infrastructure 11 | 3. Create a new Project in [Amazon SageMaker Studio](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-projects-create.html) 12 | 13 | Once complete, you can Train and Deploy machine learning models for A/B Testing in the sample notebook provided. 14 | 15 | ## Get Started 16 | 17 | To get started, first clone this repository. 18 | 19 | ``` 20 | git clone https://github.com/aws-samples/amazon-sagemaker-ab-testing-pipeline.git 21 | cd amazon-sagemaker-ab-testing-pipeline 22 | ``` 23 | 24 | ## Prerequisites 25 | 26 | This project uses the AWS Cloud Development Kit [CDK](https://aws.amazon.com/cdk/). To [get started](https://docs.aws.amazon.com/cdk/latest/guide/getting_started.html) with the AWS CDK you need [Node.js](https://nodejs.org/en/download/) 10.13.0 or later. 27 | 28 | ### Install the AWS CDK 29 | 30 | Install the AWS CDK Toolkit globally using the following Node Package Manager command. 31 | 32 | ``` 33 | npm install -g aws-cdk 34 | ``` 35 | 36 | Run the following command to verify correct installation and print the version number of the AWS CDK. 37 | 38 | ``` 39 | cdk --version 40 | ``` 41 | 42 | ### Setup Python Environment for CDK 43 | 44 | This project uses the AWS CDK with Python bindings to deploy resources to your AWS account. 45 | 46 | The `cdk.json` file tells the CDK Toolkit how to execute your app. 47 | 48 | This project is set up like a standard Python project. The initialization 49 | process also creates a virtualenv within this project, stored under the `.venv` 50 | directory. To create the virtualenv it assumes that there is a `python3` 51 | (or `python` for Windows) executable in your path with access to the `venv` 52 | package. If for any reason the automatic creation of the virtualenv fails, 53 | you can create the virtualenv manually.
54 | 55 | To manually create a virtualenv on MacOS and Linux: 56 | 57 | ``` 58 | python3 -m venv .venv 59 | ``` 60 | 61 | After the init process completes and the virtualenv is created, you can use the following 62 | step to activate your virtualenv. 63 | 64 | ``` 65 | source .venv/bin/activate 66 | ``` 67 | 68 | If you are on a Windows platform, you would activate the virtualenv like this: 69 | 70 | ``` 71 | .venv\Scripts\activate.bat 72 | ``` 73 | 74 | Once the virtualenv is activated, you can install the required dependencies. 75 | 76 | ``` 77 | pip install -r requirements.txt 78 | ``` 79 | 80 | ### Install Python Libraries for Lambda Layer 81 | 82 | In order to support AWS X-Ray as part of our Python functions, we require additional Python libraries. 83 | 84 | Run the following command to pip install the [AWS X-Ray SDK for Python](https://docs.aws.amazon.com/xray/latest/devguide/xray-sdk-python.html) into the `layers` folder. 85 | 86 | ``` 87 | sh install_layers.sh 88 | ``` 89 | 90 | This enables request sampling so you can visualize the access patterns and drill into any specific errors. 91 | 92 | ![AB Testing Pipeline X-Ray](docs/ab-testing-pipeline-xray.png) 93 | 94 | ### Add Permissions for CDK 95 | 96 | AWS CDK requires permissions to create AWS CloudFormation Stacks and the associated resources for your current execution role. If you have cloned this notebook into SageMaker Studio, you will need to add an inline policy to your SageMaker Studio execution role. You can find your user's role by browsing to the Studio dashboard. 97 | 98 | ![AB Testing Pipeline Execution Role](docs/ab-testing-pipeline-execution-role.png) 99 | 100 | Browse to the [IAM](https://console.aws.amazon.com/iam) section in the console, and find this role. 101 | 102 | Then, click the **Add inline policy** link, switch to the **JSON** tab, and paste the following inline policy: 103 | 104 | ``` 105 | { 106 | "Version": "2012-10-17", 107 | "Statement": [ 108 | { 109 | "Effect": "Allow", 110 | "Action": [ 111 | "apigateway:*" 112 | ], 113 | "Resource": "arn:aws:apigateway:*::/*" 114 | }, 115 | { 116 | "Action": [ 117 | "dynamodb:*" 118 | ], 119 | "Effect": "Allow", 120 | "Resource": "arn:aws:dynamodb:*:*:table/ab-testing-*" 121 | }, 122 | { 123 | "Action": [ 124 | "lambda:*" 125 | ], 126 | "Effect": "Allow", 127 | "Resource": [ 128 | "arn:aws:lambda:*:*:function:ab-testing-api-*", 129 | "arn:aws:lambda:*:*:layer:*" 130 | ] 131 | }, 132 | { 133 | "Action": [ 134 | "firehose:*" 135 | ], 136 | "Effect": "Allow", 137 | "Resource": "arn:aws:firehose:*:*:deliverystream/ab-testing-*" 138 | }, 139 | { 140 | "Action": [ 141 | "s3:*" 142 | ], 143 | "Effect": "Allow", 144 | "Resource": [ 145 | "arn:aws:s3:::cdktoolkit-*", 146 | "arn:aws:s3:::ab-testing-api-*" 147 | ] 148 | }, 149 | { 150 | "Action": [ 151 | "cloudformation:*", 152 | "servicecatalog:*", 153 | "events:*" 154 | ], 155 | "Effect": "Allow", 156 | "Resource": "*" 157 | }, 158 | { 159 | "Effect": "Allow", 160 | "Action": [ 161 | "logs:*" 162 | ], 163 | "Resource": "arn:aws:logs:*:*:log-group:ab-testing-api-*" 164 | }, 165 | { 166 | "Effect": "Allow", 167 | "Action": [ 168 | "iam:CreateRole", 169 | "iam:DeleteRole" 170 | ], 171 | "Resource": "arn:aws:iam::*:role/ab-testing-api-*" 172 | }, 173 | { 174 | "Effect": "Allow", 175 | "Action": [ 176 | "iam:GetRole", 177 | "iam:PassRole", 178 | "iam:GetRolePolicy", 179 | "iam:AttachRolePolicy", 180 | "iam:PutRolePolicy", 181 | "iam:DetachRolePolicy", 182 | "iam:DeleteRolePolicy" 183 | ], 184 | "Resource": [ 185 |
"arn:aws:iam::*:role/ab-testing-api-*", 186 | "arn:aws:iam::*:role/service-role/AmazonSageMaker*" 187 | ] 188 | } 189 | ] 190 | } 191 | ``` 192 | 193 | Click **Review policy** and provide the name `CDK-DeployPolicy`, then click **Create policy**. 194 | 195 | ### Bootstrap the CDK 196 | 197 | If this is the first time you have run the CDK, you may need to [Bootstrap](https://docs.aws.amazon.com/cdk/latest/guide/bootstrapping.html) your account. If you have multiple deployment targets, see also [Specifying your environment](https://docs.aws.amazon.com/cdk/latest/guide/cli.html#cli-environment) in the CDK documentation. 198 | 199 | ``` 200 | cdk bootstrap 201 | ``` 202 | 203 | You should now be able to list the stacks by running: 204 | 205 | ``` 206 | cdk list 207 | ``` 208 | 209 | Which will return the following stacks: 210 | 211 | * `ab-testing-api` 212 | * `ab-testing-pipeline` 213 | * `ab-testing-service-catalog` 214 | 215 | ## Publish the API and AWS Service Catalog template 216 | 217 | In this section you will publish the AWS Service Catalog template and deploy the API and testing infrastructure. 218 | 219 | ### Publish the SageMaker MLOps Project template 220 | 221 | In this step you will create a *Portfolio* and *Product* to provision a custom SageMaker MLOps Project template in the AWS Service Catalog and configure it so you can launch the project from within your SageMaker Studio domain. See more information on [customizing](docs/CUSTOM_TEMPLATE.md) the template, or import the template [manually](docs/SERVICE_CATALOG.md) into the AWS Service Catalog. 222 | 223 | ![AB Testing Pipeline](docs/ab-testing-pipeline-deployment.png) 224 | 225 | Resources include: 226 | * **AWS CodeCommit** seeded with the source from the [deployment_pipeline](deployment_pipeline). 227 | * **AWS CodeBuild** to produce the **AWS CloudFormation** template for deploying the **Amazon SageMaker Endpoint**. 228 | * **Amazon CloudWatch Event** to trigger the **AWS CodePipeline** for endpoint deployment. 229 | 230 | Run the following command to deploy the MLOps project template, passing the required `ExecutionRoleArn` parameter. You can copy this from your SageMaker Studio dashboard as shown above. 231 | 232 | ``` 233 | export EXECUTION_ROLE_ARN=<> 234 | cdk deploy ab-testing-service-catalog \ 235 | --parameters ExecutionRoleArn=$EXECUTION_ROLE_ARN \ 236 | --parameters PortfolioName="SageMaker Organization Templates" \ 237 | --parameters PortfolioOwner="administrator" \ 238 | --parameters ProductVersion=1.0 239 | ``` 240 | 241 | This stack will output the `CodeCommitSeedBucket` and `CodeCommitSeedKey` which you will need when creating the Amazon SageMaker Studio project. 242 | 243 | `NOTE`: If you are seeing errors running the above command, ensure you have [Enabled SageMaker project templates for Studio users](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-projects-studio-updates.html) to grant access to these resources in Amazon S3. 244 | 245 | ### Deploy the API and Testing infrastructure 246 | 247 | In this step you will deploy an Amazon API Gateway and supporting resources to enable dynamic A/B Testing of any Amazon SageMaker endpoint that has multiple production variants. 248 | 249 | ![AB Testing Architecture](docs/ab-testing-pipeline-architecture.png) 250 | 251 | Resources include: 252 | 253 | * **Amazon API Gateway** and **AWS Lambda** functions for invocation. 254 | * **Amazon DynamoDB** table for user variant assignment. 255 | * **Amazon DynamoDB** table for variant metrics.
256 | * **Amazon Kinesis Firehose**, **Amazon S3** Bucket and **AWS Lambda** for processing events. 257 | * **Amazon CloudWatch Event** and **AWS Lambda** to register in-service **Amazon SageMaker Endpoints**. 258 | 259 | Run the following command to deploy the API and testing infrastructure with optional [configuration](docs/API_CONFIGURATION.md). 260 | 261 | ``` 262 | cdk deploy ab-testing-api 263 | ``` 264 | 265 | This stack will output the `ApiEndpoint` which you will provide to the A/B Testing sample notebook. 266 | 267 | You’re done! Now it’s time to create a project using this template. 268 | 269 | ## Creating a new Project in Amazon SageMaker Studio 270 | 271 | Once your MLOps project template is registered in *AWS Service Catalog*, you can create a project using your new template. 272 | 273 | 1. Switch back to the Launcher. 274 | 2. Click **New Project** from the **ML tasks and components** section. 275 | 276 | On the Create project page, SageMaker templates is chosen by default. This option lists the built-in templates. However, you want to use the template you published for the A/B Testing Deployment Pipeline. 277 | 278 | 3. Choose **Organization templates**. 279 | 4. Choose **A/B Testing Deployment Pipeline**. 280 | 5. Choose **Select project template**. 281 | 282 | ![Select Template](docs/ab-testing-pipeline-sagemaker-template.png) 283 | 284 | 6. In the **Project details** section, for **Name**, enter **ab-testing-pipeline**. 285 | - The project name must have 32 characters or fewer. 286 | 7. In the **Project template parameters** section: 287 | - For **StageName**, enter `dev` 288 | - For **CodeCommitSeedBucket**, enter the `CodeCommitSeedBucket` output from the `ab-testing-service-catalog` stack 289 | - For **CodeCommitSeedKey**, enter the `CodeCommitSeedKey` output from the `ab-testing-service-catalog` stack 290 | 8. Choose **Create project**. 291 | 292 | ![Create Project](docs/ab-testing-pipeline-sagemaker-project.png) 293 | 294 | `NOTE`: If you have recently updated your AWS Service Catalog Project, you may need to refresh SageMaker Studio to ensure it picks up the latest version of your template. 295 | 296 | ## Train and Deploy machine learning models for A/B Testing 297 | 298 | In the following sections, you will learn how to **Train**, **Deploy** and **Simulate** a test against our A/B Testing Pipeline. 299 | 300 | ### Training a Model 301 | 302 | Now that your project is ready, it’s time to train, register and approve a model. 303 | 304 | 1. Download the [Sample Notebook](notebook/mab-reviews-helpfulness.ipynb) to use for this walk-through. 305 | 2. Choose the **Upload file** button. 306 | 3. Choose the Jupyter notebook you downloaded and upload it. 307 | 4. Choose the notebook to open a new tab. 308 | 309 | ![Upload File](docs/ab-testing-pipeline-upload-file.png) 310 | 311 | This notebook will step you through the process of: 312 | 1. Download a dataset. 313 | 2. Create and run an Amazon SageMaker Pipeline. 314 | 3. Approve the model. 315 | 4. Create an Amazon SageMaker Tuning Job. 316 | 5. Select the best model, register and approve the second model. 317 | 318 | ### Deploying the Multi-Variant Pipeline 319 | 320 | Once the second model has been approved, the MLOps deployment pipeline will run. 321 | 322 | See the [Deployment Pipeline](deployment_pipeline) for more information on the stages to run. 323 | 324 | ### Running an A/B Testing simulation 325 | 326 | With the Deployment Pipeline complete, you will be able to continue with the next stage: 327 | 1.
Test the multi-variant endpoint 328 | 2. Evaluate the accuracy of the models, and visualize the confusion matrix and ROC Curves 329 | 3. Test the API by simulating a series of `invocation` requests and recording reward `conversion` events (see the example request sketch near the end of this README). 330 | 4. Plot the cumulative reward, and reward rate. 331 | 5. Plot the beta distributions over the course of the test. 332 | 6. Calculate the statistical significance of the test. 333 | 334 | ## Running Costs 335 | 336 | This section outlines cost considerations for running the A/B Testing Pipeline. Completing the pipeline will deploy an endpoint with 2 production variants which will cost less than $6 per day. Further cost breakdowns are below. 337 | 338 | - **CodeBuild** – Charges per minute used. First 100 minutes each month come at no charge. For information on pricing beyond the first 100 minutes, see [AWS CodeBuild Pricing](https://aws.amazon.com/codebuild/pricing/). 339 | - **CodeCommit** – $1/month if you didn't opt to use your own GitHub repository. 340 | - **CodePipeline** – CodePipeline costs $1 per active pipeline* per month. Pipelines are free for the first 30 days after creation. More can be found at [AWS CodePipeline Pricing](https://aws.amazon.com/codepipeline/pricing/). 341 | - **SageMaker** – Prices vary based on EC2 instance usage for the Notebook Instances, Model Hosting, Model Training and Model Monitoring; each charged per hour of use. For more information, see [Amazon SageMaker Pricing](https://aws.amazon.com/sagemaker/pricing/). 342 | - The ten `ml.c5.4xlarge` *training jobs* run for approx 4 minutes at $0.81 an hour, and cost less than $1. 343 | - The two `ml.t2.large` instances for the production *hosting* endpoint cost 2 x $0.111 per hour, or $5.33 per day. 344 | - **S3** – Low cost, prices will vary depending on the size of the models/artifacts stored. The first 50 TB each month will cost only $0.023 per GB stored. For more information, see [Amazon S3 Pricing](https://aws.amazon.com/s3/pricing/). 345 | - **API Gateway** - Low cost, $1.29 for first 300 million requests. For more information, see [Amazon API Gateway pricing](https://aws.amazon.com/api-gateway/pricing/). 346 | - **Lambda** - Low cost, $0.20 per 1 million requests, see [AWS Lambda Pricing](https://aws.amazon.com/lambda/pricing/). 347 | 348 | ## Cleaning Up 349 | 350 | Once you have cleaned up the SageMaker Endpoints and Project as described in the [Sample Notebook](notebook/mab-reviews-helpfulness.ipynb), complete the clean up by deleting the **Service Catalog** and **API** resources with the AWS CDK: 351 | 352 | 1. Delete the Service Catalog Portfolio and Project Template 353 | 354 | ``` 355 | cdk destroy ab-testing-service-catalog 356 | ``` 357 | 358 | 2. Delete the API and testing infrastructure 359 | 360 | Before destroying the API stack, it is recommended that you [empty](https://docs.aws.amazon.com/AmazonS3/latest/userguide/empty-bucket.html) and [delete](https://docs.aws.amazon.com/AmazonS3/latest/userguide/delete-bucket.html) the S3 Bucket that contains the S3 logs persisted by the Kinesis Firehose. 361 | 362 | ``` 363 | cdk destroy ab-testing-api 364 | ``` 365 | 366 | ## Want to know more? 367 | 368 | The [FAQ](docs/FAQ.md) page has some answers to questions on the design principles of this sample. 369 | 370 | See also the [OPERATIONS](docs/OPERATIONS.md) page for information on configuring experiments, and the API interface. 371 | 372 | ## Security 373 | 374 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
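
## Example A/B Testing API requests

The snippet below is a minimal sketch of the `invocation` and `conversion` calls simulated by the sample notebook, using the `ApiEndpoint` output from the `ab-testing-api` stack. The paths and payload fields used here (`/invocation`, `/conversion`, `endpoint_name`, `user_id`, `reward`) are illustrative assumptions only; refer to the [OPERATIONS](docs/OPERATIONS.md) page and the [Sample Notebook](notebook/mab-reviews-helpfulness.ipynb) for the actual request format.

```
import requests

# ApiEndpoint output from the ab-testing-api stack (placeholder value)
api_endpoint = "https://<api-id>.execute-api.<region>.amazonaws.com/prod"

# Hypothetical invocation: ask the API to assign a variant for this user
# and invoke the multi-variant SageMaker endpoint on their behalf.
invocation = requests.post(
    f"{api_endpoint}/invocation",
    json={
        "endpoint_name": "sagemaker-ab-testing-pipeline-dev",
        "user_id": "user-123",
        "content_type": "application/json",
        "data": '{"instances": ["This review was really helpful"]}',
    },
).json()
print(invocation)

# Hypothetical conversion: record a reward for the variant that served this user,
# which feeds back into the bandit algorithm's variant metrics.
requests.post(
    f"{api_endpoint}/conversion",
    json={
        "endpoint_name": "sagemaker-ab-testing-pipeline-dev",
        "user_id": "user-123",
        "reward": 1,
    },
)
```

The endpoint name above follows the `sagemaker-<project-name>-<stage>` convention used by the deployment pipeline.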
375 | 376 | ## License 377 | 378 | This library is licensed under the MIT-0 License. See the [LICENSE](LICENSE) file. -------------------------------------------------------------------------------- /ab-testing-pipeline.yml: -------------------------------------------------------------------------------- 1 | Parameters: 2 | SageMakerProjectName: 3 | Type: String 4 | Description: The name of the SageMaker project. 5 | MaxLength: 32 6 | MinLength: 1 7 | SageMakerProjectId: 8 | Type: String 9 | Description: Service generated Id of the project. 10 | MaxLength: 16 11 | MinLength: 1 12 | StageName: 13 | Type: String 14 | Default: dev 15 | Description: The stage name. 16 | MaxLength: 8 17 | MinLength: 1 18 | CodeCommitSeedBucket: 19 | Type: String 20 | Description: The optional s3 seed bucket 21 | MinLength: 1 22 | CodeCommitSeedKey: 23 | Type: String 24 | Description: The optional s3 seed key 25 | MinLength: 1 26 | Resources: 27 | CodeRepo: 28 | Type: AWS::CodeCommit::Repository 29 | Properties: 30 | RepositoryName: 31 | Fn::Join: 32 | - "" 33 | - - sagemaker- 34 | - Ref: SageMakerProjectName 35 | - -repo 36 | Code: 37 | BranchName: main 38 | S3: 39 | Bucket: 40 | Ref: CodeCommitSeedBucket 41 | Key: 42 | Ref: CodeCommitSeedKey 43 | RepositoryDescription: Amazon SageMaker A/B testing pipeline 44 | Tags: 45 | - Key: sagemaker:deployment-stage 46 | Value: 47 | Ref: StageName 48 | - Key: sagemaker:project-id 49 | Value: 50 | Ref: SageMakerProjectId 51 | - Key: sagemaker:project-name 52 | Value: 53 | Ref: SageMakerProjectName 54 | CdkBuild455F642E: 55 | Type: AWS::CodeBuild::Project 56 | Properties: 57 | Artifacts: 58 | Type: CODEPIPELINE 59 | Environment: 60 | ComputeType: BUILD_GENERAL1_SMALL 61 | EnvironmentVariables: 62 | - Name: SAGEMAKER_PROJECT_NAME 63 | Type: PLAINTEXT 64 | Value: 65 | Ref: SageMakerProjectName 66 | - Name: SAGEMAKER_PROJECT_ID 67 | Type: PLAINTEXT 68 | Value: 69 | Ref: SageMakerProjectId 70 | - Name: STAGE_NAME 71 | Type: PLAINTEXT 72 | Value: 73 | Ref: StageName 74 | Image: aws/codebuild/standard:1.0 75 | ImagePullCredentialsType: CODEBUILD 76 | PrivilegedMode: false 77 | Type: LINUX_CONTAINER 78 | ServiceRole: 79 | Fn::Join: 80 | - "" 81 | - - "arn:" 82 | - Ref: AWS::Partition 83 | - ":iam::" 84 | - Ref: AWS::AccountId 85 | - :role/service-role/AmazonSageMakerServiceCatalogProductsUseRole 86 | Source: 87 | BuildSpec: >- 88 | { 89 | "version": "0.2", 90 | "phases": { 91 | "install": { 92 | "commands": [ 93 | "npm install aws-cdk", 94 | "npm update", 95 | "python -m pip install -r requirements.txt" 96 | ] 97 | }, 98 | "build": { 99 | "commands": [ 100 | "npx cdk synth -o dist --path-metadata false" 101 | ] 102 | } 103 | }, 104 | "artifacts": { 105 | "base-directory": "dist", 106 | "files": [ 107 | "*.template.json" 108 | ] 109 | }, 110 | "environment": { 111 | "buildImage": { 112 | "type": "LINUX_CONTAINER", 113 | "defaultComputeType": "BUILD_GENERAL1_SMALL", 114 | "imageId": "aws/codebuild/amazonlinux2-x86_64-standard:3.0", 115 | "imagePullPrincipalType": "CODEBUILD" 116 | } 117 | } 118 | } 119 | Type: CODEPIPELINE 120 | EncryptionKey: alias/aws/s3 121 | Name: 122 | Fn::Join: 123 | - "" 124 | - - sagemaker- 125 | - Ref: SageMakerProjectName 126 | - -cdk- 127 | - Ref: StageName 128 | S3Artifact80610462: 129 | Type: AWS::S3::Bucket 130 | Properties: 131 | BucketName: 132 | Fn::Join: 133 | - "" 134 | - - sagemaker- 135 | - Ref: SageMakerProjectId 136 | - -artifact- 137 | - Ref: StageName 138 | - "-" 139 | - Ref: AWS::Region 140 | UpdateReplacePolicy: Delete 141 | 
DeletionPolicy: Delete 142 | PipelineC660917D: 143 | Type: AWS::CodePipeline::Pipeline 144 | Properties: 145 | RoleArn: 146 | Fn::Join: 147 | - "" 148 | - - "arn:" 149 | - Ref: AWS::Partition 150 | - ":iam::" 151 | - Ref: AWS::AccountId 152 | - :role/service-role/AmazonSageMakerServiceCatalogProductsUseRole 153 | Stages: 154 | - Actions: 155 | - ActionTypeId: 156 | Category: Source 157 | Owner: AWS 158 | Provider: CodeCommit 159 | Version: "1" 160 | Configuration: 161 | RepositoryName: 162 | Fn::GetAtt: 163 | - CodeRepo 164 | - Name 165 | BranchName: main 166 | PollForSourceChanges: false 167 | Name: CodeCommit_Source 168 | OutputArtifacts: 169 | - Name: Artifact_Source_CodeCommit_Source 170 | RunOrder: 1 171 | Name: Source 172 | - Actions: 173 | - ActionTypeId: 174 | Category: Build 175 | Owner: AWS 176 | Provider: CodeBuild 177 | Version: "1" 178 | Configuration: 179 | ProjectName: 180 | Ref: CdkBuild455F642E 181 | InputArtifacts: 182 | - Name: Artifact_Source_CodeCommit_Source 183 | Name: CDK_Build 184 | OutputArtifacts: 185 | - Name: Artifact_Build_CDK_Build 186 | RunOrder: 1 187 | Name: Build 188 | - Actions: 189 | - ActionTypeId: 190 | Category: Deploy 191 | Owner: AWS 192 | Provider: CloudFormation 193 | Version: "1" 194 | Configuration: 195 | StackName: 196 | Fn::Join: 197 | - "" 198 | - - sagemaker- 199 | - Ref: SageMakerProjectName 200 | - -deploy- 201 | - Ref: StageName 202 | RoleArn: 203 | Fn::Join: 204 | - "" 205 | - - "arn:" 206 | - Ref: AWS::Partition 207 | - ":iam::" 208 | - Ref: AWS::AccountId 209 | - :role/service-role/AmazonSageMakerServiceCatalogProductsUseRole 210 | ActionMode: REPLACE_ON_FAILURE 211 | TemplatePath: Artifact_Build_CDK_Build::ab-testing-sagemaker.template.json 212 | InputArtifacts: 213 | - Name: Artifact_Build_CDK_Build 214 | Name: SageMaker_CFN_Deploy 215 | RunOrder: 1 216 | Name: Deploy 217 | ArtifactStore: 218 | Location: 219 | Ref: S3Artifact80610462 220 | Type: S3 221 | Name: 222 | Fn::Join: 223 | - "" 224 | - - sagemaker- 225 | - Ref: SageMakerProjectName 226 | - -pipeline- 227 | - Ref: StageName 228 | DeployRule0F8E909D: 229 | Type: AWS::Events::Rule 230 | Properties: 231 | Description: Rule to trigger a deployment when SageMaker Model registry is updated with a new model package. 232 | EventPattern: 233 | detail: 234 | ModelPackageGroupName: 235 | - Fn::Join: 236 | - "" 237 | - - Ref: SageMakerProjectName 238 | - -champion 239 | - Fn::Join: 240 | - "" 241 | - - Ref: SageMakerProjectName 242 | - -challenger 243 | detail-type: 244 | - SageMaker Model Package State Change 245 | source: 246 | - aws.sagemaker 247 | Name: 248 | Fn::Join: 249 | - "" 250 | - - sagemaker- 251 | - Ref: SageMakerProjectName 252 | - -model- 253 | - Ref: StageName 254 | State: ENABLED 255 | Targets: 256 | - Arn: 257 | Fn::Join: 258 | - "" 259 | - - "arn:" 260 | - Ref: AWS::Partition 261 | - ":codepipeline:" 262 | - Ref: AWS::Region 263 | - ":" 264 | - Ref: AWS::AccountId 265 | - ":" 266 | - Ref: PipelineC660917D 267 | Id: Target0 268 | RoleArn: 269 | Fn::Join: 270 | - "" 271 | - - "arn:" 272 | - Ref: AWS::Partition 273 | - ":iam::" 274 | - Ref: AWS::AccountId 275 | - :role/service-role/AmazonSageMakerServiceCatalogProductsUseRole 276 | CodeRule663E3DC0: 277 | Type: AWS::Events::Rule 278 | Properties: 279 | Description: Rule to trigger a deployment when deployment configured is updated in CodeCommit. 
280 | EventPattern: 281 | detail: 282 | event: 283 | - referenceCreated 284 | - referenceUpdated 285 | referenceType: 286 | - branch 287 | referenceName: 288 | - main 289 | detail-type: 290 | - CodeCommit Repository State Change 291 | resources: 292 | - Fn::Join: 293 | - "" 294 | - - "arn:" 295 | - Ref: AWS::Partition 296 | - ":codecommit:" 297 | - Ref: AWS::Region 298 | - ":" 299 | - Ref: AWS::AccountId 300 | - ":" 301 | - Fn::GetAtt: 302 | - CodeRepo 303 | - Name 304 | source: 305 | - aws.codecommit 306 | Name: 307 | Fn::Join: 308 | - "" 309 | - - sagemaker- 310 | - Ref: SageMakerProjectName 311 | - -code- 312 | - Ref: StageName 313 | State: ENABLED 314 | Targets: 315 | - Arn: 316 | Fn::Join: 317 | - "" 318 | - - "arn:" 319 | - Ref: AWS::Partition 320 | - ":codepipeline:" 321 | - Ref: AWS::Region 322 | - ":" 323 | - Ref: AWS::AccountId 324 | - ":" 325 | - Ref: PipelineC660917D 326 | Id: Target0 327 | RoleArn: 328 | Fn::Join: 329 | - "" 330 | - - "arn:" 331 | - Ref: AWS::Partition 332 | - ":iam::" 333 | - Ref: AWS::AccountId 334 | - :role/service-role/AmazonSageMakerServiceCatalogProductsUseRole 335 | CDKMetadata: 336 | Type: AWS::CDK::Metadata 337 | Properties: 338 | Analytics: v2:deflate64:H4sIAAAAAAAAE0WN0YrCMBBFv8X3OCIVln0T/YFSvyBOR3ZskynJRCkh/742Lfh0L3M5Z47we4Lj7mzfcY/9cMgogSDf1OJgrg/f2mAdKQXTUZQUkMxVfNSQUJf9e334z9CzsvhiFl1m6yC3MjLOVVVbMSg9oTjHClUwSWSVMNfhnnjsPxRPNLKnNsiT1k9bLSY2kC8JB6rnta3WaaO+fAW3Xgy9yGuE3KWxLkuWUkw765/4QwM/0OyekXkfkld2BN2a//2ekvEmAQAA 339 | Condition: CDKMetadataAvailable 340 | Conditions: 341 | CDKMetadataAvailable: 342 | Fn::Or: 343 | - Fn::Or: 344 | - Fn::Equals: 345 | - Ref: AWS::Region 346 | - af-south-1 347 | - Fn::Equals: 348 | - Ref: AWS::Region 349 | - ap-east-1 350 | - Fn::Equals: 351 | - Ref: AWS::Region 352 | - ap-northeast-1 353 | - Fn::Equals: 354 | - Ref: AWS::Region 355 | - ap-northeast-2 356 | - Fn::Equals: 357 | - Ref: AWS::Region 358 | - ap-south-1 359 | - Fn::Equals: 360 | - Ref: AWS::Region 361 | - ap-southeast-1 362 | - Fn::Equals: 363 | - Ref: AWS::Region 364 | - ap-southeast-2 365 | - Fn::Equals: 366 | - Ref: AWS::Region 367 | - ca-central-1 368 | - Fn::Equals: 369 | - Ref: AWS::Region 370 | - cn-north-1 371 | - Fn::Equals: 372 | - Ref: AWS::Region 373 | - cn-northwest-1 374 | - Fn::Or: 375 | - Fn::Equals: 376 | - Ref: AWS::Region 377 | - eu-central-1 378 | - Fn::Equals: 379 | - Ref: AWS::Region 380 | - eu-north-1 381 | - Fn::Equals: 382 | - Ref: AWS::Region 383 | - eu-south-1 384 | - Fn::Equals: 385 | - Ref: AWS::Region 386 | - eu-west-1 387 | - Fn::Equals: 388 | - Ref: AWS::Region 389 | - eu-west-2 390 | - Fn::Equals: 391 | - Ref: AWS::Region 392 | - eu-west-3 393 | - Fn::Equals: 394 | - Ref: AWS::Region 395 | - me-south-1 396 | - Fn::Equals: 397 | - Ref: AWS::Region 398 | - sa-east-1 399 | - Fn::Equals: 400 | - Ref: AWS::Region 401 | - us-east-1 402 | - Fn::Equals: 403 | - Ref: AWS::Region 404 | - us-east-2 405 | - Fn::Or: 406 | - Fn::Equals: 407 | - Ref: AWS::Region 408 | - us-west-1 409 | - Fn::Equals: 410 | - Ref: AWS::Region 411 | - us-west-2 412 | 413 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | 5 | from aws_cdk import core 6 | from infra.api_stack import ApiStack 7 | from infra.pipeline_stack import PipelineStack 8 | from infra.service_catalog import ServiceCatalogStack 9 | 10 | # Configure the logger 11 | logger = logging.getLogger(__name__) 12 | 
logging.basicConfig(level="INFO") 13 | 14 | # Create App and stacks 15 | app = core.App() 16 | 17 | # Create the API and SC stacks 18 | ApiStack(app, "ab-testing-api") 19 | PipelineStack(app, "ab-testing-pipeline") 20 | ServiceCatalogStack(app, "ab-testing-service-catalog") 21 | 22 | app.synth() 23 | -------------------------------------------------------------------------------- /cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "python3 app.py", 3 | "context": { 4 | "@aws-cdk/core:enableStackNameDuplicates": "true", 5 | "aws-cdk:enableDiffNoFail": "true", 6 | "@aws-cdk/core:stackRelativeExports": "true", 7 | "@aws-cdk/aws-ecr-assets:dockerIgnoreSupport": true, 8 | "@aws-cdk/aws-secretsmanager:parseOwnedSecretName": true, 9 | "@aws-cdk/aws-kms:defaultKeyPolicies": true, 10 | "@aws-cdk/aws-s3:grantWriteWithoutAcl": true, 11 | "log_level": "INFO", 12 | "api_name": "ab-testing", 13 | "stage_name": "dev", 14 | "endpoint_prefix": "sagemaker-", 15 | "api_lambda_memory": 768, 16 | "api_lambda_timeout": 60, 17 | "metrics_lambda_memory": 768, 18 | "metrics_lambda_timeout": 300, 19 | "dynamodb_read_capacity": 5, 20 | "dynamodb_write_capacity": 5, 21 | "delivery_sync": false, 22 | "firehose_interval": 60, 23 | "firehose_mb_size": 1 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /deployment_pipeline/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Amazon SageMaker A/B Testing Pipeline 3 | 4 | This folder contains the CDK infrastructure for the multi-variant deployment pipeline. 5 | 6 | ## Deployment Pipeline 7 | 8 | This deployment pipeline contains the following stages. 9 | 10 | 1. **Source**: Pull the latest deployment configuration from the AWS CodeCommit repository. 11 | 2. **Build**: AWS CodeBuild job to create the AWS CloudFormation template for deploying the endpoint. 12 | - Query the Amazon SageMaker project to get the top approved models. 13 | - Use the AWS CDK to create a CFN stack to deploy a multi-variant SageMaker Endpoint. 14 | 3. **Deploy**: Run the AWS CloudFormation stack to create/update the SageMaker endpoint, tagged with properties based on configuration: 15 | - `ab-testing:enabled` equals `true` 16 | - `ab-testing:strategy` is one of `WeightedSampling`, `EpsilonGreedy`, `UCB1` or `ThompsonSampling`. 17 | - `ab-testing:epsilon` is the exploration parameter for the `EpsilonGreedy` strategy, defaults to `0.1`. 18 | - `ab-testing:warmup` is the number of invocations to warm up with the `WeightedSampling` strategy, defaults to `0`. 19 | 20 | ![\[AWS CodePipeline\]](../docs/ab-testing-pipeline-code-pipeline.png) 21 | 22 | ## Testing 23 | 24 | Once you have created a SageMaker Project, you can test the **Build** stage and **Register** events locally by setting some environment variables. 25 | 26 | ### Build Stage 27 | 28 | Export the environment variables for the `SAGEMAKER_PROJECT_NAME` and `SAGEMAKER_PROJECT_ID` created by your SageMaker Project CloudFormation stack. Then run the `cdk synth` command: 29 | 30 | ``` 31 | export SAGEMAKER_PROJECT_NAME="<>" 32 | export SAGEMAKER_PROJECT_ID="<>" 33 | export STAGE_NAME="dev" 34 | cdk synth 35 | ``` 36 | 37 | ### Register 38 | 39 | Export the environment variable for the `REGISTER_LAMBDA` created as part of the `ab-testing-api` stack, then run the `register.py` file.
40 | 41 | ``` 42 | export REGISTER_LAMBDA="<>" 43 | python register.py 44 | ``` 45 | -------------------------------------------------------------------------------- /deployment_pipeline/app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | import logging 5 | import os 6 | 7 | from aws_cdk import core 8 | from infra.model_registry import ModelRegistry 9 | from infra.deployment_config import DeploymentConfig 10 | from infra.sagemaker_stack import SageMakerStack 11 | 12 | # Configure the logger 13 | logger = logging.getLogger(__name__) 14 | logging.basicConfig(level="INFO") 15 | 16 | # Load these from environment variables, that are passed into CodeBuild job from pipeline stack 17 | project_name = os.environ["SAGEMAKER_PROJECT_NAME"] 18 | project_id = os.environ["SAGEMAKER_PROJECT_ID"] 19 | stage_name = os.environ["STAGE_NAME"] 20 | 21 | # Create App and stacks 22 | app = core.App() 23 | 24 | # Define variables for passing down to stacks 25 | endpoint_name = f"sagemaker-{project_name}-{stage_name}" 26 | if len(endpoint_name) > 63: 27 | raise Exception( 28 | f"SageMaker endpoint: {endpoint_name} must be less than 64 characters" 29 | ) 30 | 31 | logger.info(f"Create endpoint: {endpoint_name}") 32 | 33 | 34 | # Define the deployment tags 35 | tags = [ 36 | core.CfnTag(key="sagemaker:deployment-stage", value=stage_name), 37 | core.CfnTag(key="sagemaker:project-id", value=project_id), 38 | core.CfnTag(key="sagemaker:project-name", value=project_name), 39 | ] 40 | 41 | # Get the stage specific deployment config for sagemaker 42 | with open(f"{stage_name}-config.json", "r") as f: 43 | j = json.load(f) 44 | deployment_config = DeploymentConfig(**j) 45 | # Append tags for ab-testing 46 | tags += [ 47 | core.CfnTag(key="ab-testing:enabled", value="true"), 48 | core.CfnTag(key="ab-testing:strategy", value=deployment_config.strategy), 49 | core.CfnTag(key="ab-testing:epsilon", value=str(deployment_config.epsilon)), 50 | core.CfnTag(key="ab-testing:warmup", value=str(deployment_config.warmup)), 51 | ] 52 | 53 | sagemaker = SageMakerStack( 54 | app, 55 | "ab-testing-sagemaker", 56 | deployment_config=deployment_config, 57 | project_name=project_name, 58 | project_id=project_id, 59 | endpoint_name=endpoint_name, 60 | tags=tags, 61 | ) 62 | 63 | app.synth() 64 | -------------------------------------------------------------------------------- /deployment_pipeline/cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "python3 app.py", 3 | "context": { 4 | "@aws-cdk/core:enableStackNameDuplicates": "true", 5 | "aws-cdk:enableDiffNoFail": "true", 6 | "@aws-cdk/core:stackRelativeExports": "true", 7 | "@aws-cdk/aws-ecr-assets:dockerIgnoreSupport": true, 8 | "@aws-cdk/aws-secretsmanager:parseOwnedSecretName": true, 9 | "@aws-cdk/aws-kms:defaultKeyPolicies": true, 10 | "@aws-cdk/aws-s3:grantWriteWithoutAcl": true 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /deployment_pipeline/dev-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "stage_name": "dev", 3 | "strategy": "ThompsonSampling", 4 | "instance_count": 1, 5 | "instance_type": "ml.t2.large", 6 | "challenger_variant_count": 1 7 | } -------------------------------------------------------------------------------- /deployment_pipeline/infra/deployment_config.py: 
-------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class InstanceConfig: 5 | def __init__(self, instance_count: int = 1, instance_type: str = "ml.t2.medium"): 6 | self.instance_count = instance_count 7 | self.instance_type = instance_type 8 | 9 | 10 | class VariantConfig(InstanceConfig): 11 | def __init__( 12 | self, 13 | model_package_version: str, 14 | initial_variant_weight: float = 1.0, 15 | variant_name: str = None, 16 | instance_count: int = 1, 17 | instance_type: str = "ml.t2.medium", 18 | model_package_arn: str = None, 19 | ): 20 | self.model_package_version = model_package_version 21 | self.initial_variant_weight = initial_variant_weight 22 | self.variant_name = variant_name 23 | self.model_package_arn = model_package_arn 24 | super().__init__(instance_count, instance_type) 25 | 26 | 27 | class AlgorithmStrategy(Enum): 28 | WEIGHTED_SAMPLING = 0 29 | EPSILOM_GREEDY = 1 30 | UCB1 = 2 31 | THOMPSON_SAMPLING = 3 32 | 33 | 34 | class DeploymentConfig(InstanceConfig): 35 | def __init__( 36 | self, 37 | stage_name: str, 38 | challenger_variant_count: int = 1, 39 | champion_variant_config: dict = None, 40 | challenger_variant_config: list = None, 41 | instance_count: int = 1, 42 | instance_type: str = "ml.t2.medium", 43 | strategy: str = "ThompsonSampling", 44 | warmup: int = 0, 45 | epsilon: float = 0.1, 46 | ): 47 | self.stage_name = stage_name 48 | # Provide either the challenger variant count, or specific champion/challenger config 49 | self.challenger_variant_count = challenger_variant_count 50 | # Turn dict into typed object 51 | if type(champion_variant_config) is dict: 52 | self.champion_variant_config = VariantConfig( 53 | **{ 54 | "instance_count": instance_count, 55 | "instance_type": instance_type, 56 | **champion_variant_config, 57 | } 58 | ) 59 | else: 60 | self.champion_variant_config = None 61 | # Turn list into typed objects 62 | if type(challenger_variant_config) is list: 63 | self.challenger_variant_config = [ 64 | # Use deployment instance count/type as default for variant config 65 | VariantConfig( 66 | **{ 67 | "instance_count": instance_count, 68 | "instance_type": instance_type, 69 | **vc, 70 | } 71 | ) 72 | for vc in challenger_variant_config 73 | ] 74 | else: 75 | self.challenger_variant_config = None 76 | self.strategy = strategy 77 | self.warmup = warmup 78 | self.epsilon = epsilon 79 | super().__init__(instance_count, instance_type) 80 | -------------------------------------------------------------------------------- /deployment_pipeline/infra/model_registry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | 4 | import boto3 5 | from botocore.config import Config 6 | from botocore.exceptions import ClientError 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class ModelRegistry: 12 | """ 13 | Class for managing models in the registry. 14 | """ 15 | 16 | def __init__(self): 17 | config = Config(retries={"max_attempts": 10, "mode": "standard"}) 18 | self.sm_client = boto3.client("sagemaker", config=config) 19 | 20 | def create_model_package_group( 21 | self, 22 | model_package_group_name: str, 23 | description: str, 24 | project_name: str, 25 | project_id: str, 26 | ): 27 | """ 28 | Create the model package group if it doesn't exist. 
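        Example (the group name, description, project name and id below are
        illustrative values only):

            registry = ModelRegistry()
            registry.create_model_package_group(
                model_package_group_name="my-project-champion",
                description="Champion models for A/B testing",
                project_name="my-project",
                project_id="p-abc123example",
            )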
29 | """ 30 | try: 31 | self.sm_client.create_model_package_group( 32 | ModelPackageGroupName=model_package_group_name, 33 | ModelPackageGroupDescription=description, 34 | Tags=[ 35 | {"Key": "sagemaker:project-name", "Value": project_name}, 36 | {"Key": "sagemaker:project-id", "Value": project_id}, 37 | ], 38 | ) 39 | logger.info(f"Model package group {model_package_group_name} created") 40 | return True 41 | 42 | except ClientError as e: 43 | error_code = e.response["Error"]["Code"] 44 | error_message = e.response["Error"]["Message"] 45 | if ( 46 | error_code == "ValidationException" 47 | and "Model Package Group already exists" in error_message 48 | ): 49 | logger.info( 50 | f"Model package group {model_package_group_name} already exists" 51 | ) 52 | return False 53 | else: 54 | logger.error(error_message) 55 | raise Exception(error_message) 56 | 57 | def get_latest_approved_packages( 58 | self, 59 | model_package_group_name: str, 60 | max_results: int, 61 | creation_time_after: datetime = None, 62 | ) -> list: 63 | """Gets the latest approved model packages for a model package group. 64 | 65 | Args: 66 | model_package_group_name: The model package group name. 67 | max_results: The maximum number of model packages to return. 68 | creation_time_after: Optional filter that returns only model 69 | packages created after the specified time (datetime). 70 | 71 | Returns: 72 | The list of model packages, sorted by most recently created 73 | """ 74 | try: 75 | # Get the latest approved model package 76 | args = { 77 | "ModelPackageGroupName": model_package_group_name, 78 | "ModelApprovalStatus": "Approved", 79 | "SortBy": "CreationTime", 80 | "MaxResults": max_results, 81 | } 82 | # Add optional creationg time after 83 | if creation_time_after is not None: 84 | args = {**args, "CreationTimeAfter": creation_time_after} 85 | response = self.sm_client.list_model_packages(**args) 86 | model_packages = response["ModelPackageSummaryList"] 87 | 88 | # Fetch more packages if none returned with continuation token 89 | while len(model_packages) < max_results and "NextToken" in response: 90 | logger.debug( 91 | "Getting more packages for token: {}".format(response["NextToken"]) 92 | ) 93 | # Set the NextToken to override any previous token 94 | args = {**args, "NextToken": response["NextToken"]} 95 | response = self.sm_client.list_model_packages(**args) 96 | model_packages.extend(response["ModelPackageSummaryList"]) 97 | 98 | # Return error if no packages found 99 | if len(model_packages) == 0 and creation_time_after is None: 100 | error_message = ( 101 | f"No approved packages found for: {model_package_group_name}" 102 | ) 103 | logger.error(error_message) 104 | raise Exception(error_message) 105 | 106 | # Return as a list of model packages limited by max results 107 | return model_packages[:max_results] 108 | 109 | except ClientError as e: 110 | error_message = e.response["Error"]["Message"] 111 | logger.error(error_message) 112 | raise Exception(error_message) 113 | 114 | def get_versioned_approved_packages( 115 | self, 116 | model_package_group_name: str, 117 | model_package_versions: list, 118 | ) -> list: 119 | """Gets specific versions of approved model packages for a model package group. 120 | 121 | Args: 122 | model_package_group_name: The model package group name. 123 | model_package_versions: The model package versions to return. 124 | creation_time_after: Optional filter that returns only model 125 | packages created after the specified time (timestamp). 
126 | 127 | Returns: 128 | The list of model packages, sorted by most recently created 129 | """ 130 | max_results = 100 131 | unique_versions = set(model_package_versions) 132 | 133 | try: 134 | # Get the approved model package until 135 | args = { 136 | "ModelPackageGroupName": model_package_group_name, 137 | "ModelApprovalStatus": "Approved", 138 | "SortBy": "CreationTime", 139 | "MaxResults": max_results, 140 | } 141 | response = self.sm_client.list_model_packages(**args) 142 | model_packages = self.select_versioned_packages( 143 | response["ModelPackageSummaryList"], unique_versions 144 | ) 145 | 146 | # Fetch more packages if none returned with continuation token 147 | while ( 148 | len(model_packages) < len(unique_versions) and "NextToken" in response 149 | ): 150 | logger.debug( 151 | "Getting more packages for token: {}".format(response["NextToken"]) 152 | ) 153 | args = {**args, "NextToken": response["NextToken"]} 154 | response = self.sm_client.list_model_packages(**args) 155 | model_packages.extend( 156 | self.select_versioned_packages( 157 | response["ModelPackageSummaryList"], unique_versions 158 | ) 159 | ) 160 | 161 | # Return error if no packages found 162 | if len(model_packages) == 0: 163 | error_message = f"No approved packages found for: {model_package_group_name} and versions: {model_package_versions}" 164 | logger.error(error_message) 165 | raise Exception(error_message) 166 | 167 | # Return as a list of model package group in order of versions 168 | return self.select_versioned_packages( 169 | model_packages, model_package_versions 170 | ) 171 | 172 | except ClientError as e: 173 | error_message = e.response["Error"]["Message"] 174 | logger.error(error_message) 175 | raise Exception(error_message) 176 | 177 | def select_versioned_packages( 178 | self, model_packages: list, model_package_versions: list 179 | ): 180 | """Filters the model packages based on a list of model package verisons. 181 | 182 | Args: 183 | model_packages: The list of packages. 184 | model_package_versions: The list of versions. 185 | 186 | Returns: 187 | The Filtered list of model packages in order of versions specified. 188 | Duplicate versions will be preserved. 
189 | """ 190 | 191 | filtered_packages = [] 192 | for version in model_package_versions: 193 | filtered_packages += [ 194 | p for p in model_packages if p["ModelPackageVersion"] == version 195 | ] 196 | return filtered_packages 197 | -------------------------------------------------------------------------------- /deployment_pipeline/infra/sagemaker_stack.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | core, 3 | aws_iam, 4 | aws_sagemaker, 5 | ) 6 | 7 | from datetime import datetime 8 | import logging 9 | from deployment_config import DeploymentConfig, VariantConfig 10 | from model_registry import ModelRegistry 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class SageMakerStack(core.Stack): 16 | def __init__( 17 | self, 18 | scope: core.Construct, 19 | construct_id: str, 20 | deployment_config: DeploymentConfig, 21 | project_name: str, 22 | project_id: str, 23 | endpoint_name: str, 24 | tags: list, 25 | **kwargs, 26 | ) -> None: 27 | super().__init__(scope, construct_id, **kwargs) 28 | 29 | # Define the package group names for champion and challenger 30 | champion_package_group = f"{project_name}-champion" 31 | challenger_package_group = f"{project_name}-challenger" 32 | challenger_creation_time: datetime = None 33 | 34 | # Create the model package groups if they don't exist 35 | registry = ModelRegistry() 36 | registry.create_model_package_group( 37 | champion_package_group, 38 | "Champion Models for A/B Testing", 39 | project_name, 40 | project_id, 41 | ) 42 | registry.create_model_package_group( 43 | challenger_package_group, 44 | "Challenger Models for A/B Testing", 45 | project_name, 46 | project_id, 47 | ) 48 | 49 | # If we don't have a specific champion variant defined, get the latest approved 50 | if deployment_config.champion_variant_config is None: 51 | logger.info("Selecting top champion variant") 52 | p = registry.get_latest_approved_packages( 53 | champion_package_group, max_results=1 54 | )[0] 55 | deployment_config.champion_variant_config = VariantConfig( 56 | model_package_version=p["ModelPackageVersion"], 57 | model_package_arn=p["ModelPackageArn"], 58 | initial_variant_weight=1, 59 | instance_count=deployment_config.instance_count, 60 | instance_type=deployment_config.instance_type, 61 | ) 62 | challenger_creation_time = p["CreationTime"] 63 | else: 64 | # Get the versioned package and update ARN 65 | version = deployment_config.champion_variant_config.model_package_version 66 | logger.info(f"Selecting champion version {version}") 67 | p = registry.get_versioned_approved_packages( 68 | champion_package_group, 69 | model_package_versions=[version], 70 | )[0] 71 | deployment_config.champion_variant_config.model_package_arn = p[ 72 | "ModelPackageArn" 73 | ] 74 | 75 | # If we don't have challenger variant config, get the latest after challenger creation time 76 | if deployment_config.challenger_variant_config is None: 77 | logger.info( 78 | f"Selecting top {deployment_config.challenger_variant_count} challenger variants created after {challenger_creation_time}" 79 | ) 80 | deployment_config.challenger_variant_config = [ 81 | VariantConfig( 82 | model_package_version=p["ModelPackageVersion"], 83 | model_package_arn=p["ModelPackageArn"], 84 | initial_variant_weight=1, 85 | instance_count=deployment_config.instance_count, 86 | instance_type=deployment_config.instance_type, 87 | ) 88 | for p in registry.get_latest_approved_packages( 89 | challenger_package_group, 90 | 
max_results=deployment_config.challenger_variant_count, 91 | creation_time_after=challenger_creation_time, 92 | ) 93 | ] 94 | else: 95 | # Get the versioned packages and update ARN 96 | versions = [ 97 | c.model_package_version 98 | for c in deployment_config.challenger_variant_config 99 | ] 100 | logger.info(f"Selecting challenger versions {versions}") 101 | ps = registry.get_versioned_approved_packages( 102 | challenger_package_group, 103 | model_package_versions=versions, 104 | ) 105 | for i, vc in enumerate(deployment_config.challenger_variant_config): 106 | vc.model_package_arn = ps[i]["ModelPackageArn"] 107 | 108 | # Get the service catalog role 109 | service_catalog_role = aws_iam.Role.from_role_arn( 110 | self, 111 | "SageMakerRole", 112 | f"arn:aws:iam::{self.account}:role/service-role/AmazonSageMakerServiceCatalogProductsUseRole", 113 | ) 114 | 115 | # Add the champion and challenger variants 116 | model_configs = [ 117 | deployment_config.champion_variant_config 118 | ] + deployment_config.challenger_variant_config 119 | 120 | model_variants = [] 121 | for i, variant_config in enumerate(model_configs): 122 | # If variant name not in config use "Champion" for the latest approved and "Challenge{N}" for next N pending 123 | variant_name = variant_config.variant_name or ( 124 | f"Champion{variant_config.model_package_version}" 125 | if i == 0 126 | else f"Challenger{variant_config.model_package_version}" 127 | ) 128 | 129 | # Do not use a custom named resource for models as these get replaced 130 | model = aws_sagemaker.CfnModel( 131 | self, 132 | variant_name, 133 | execution_role_arn=service_catalog_role.role_arn, 134 | primary_container=aws_sagemaker.CfnModel.ContainerDefinitionProperty( 135 | model_package_name=variant_config.model_package_arn, 136 | ), 137 | ) 138 | 139 | # Create the production variant 140 | model_variant = aws_sagemaker.CfnEndpointConfig.ProductionVariantProperty( 141 | initial_instance_count=variant_config.instance_count, 142 | initial_variant_weight=variant_config.initial_variant_weight, 143 | instance_type=variant_config.instance_type, 144 | model_name=model.attr_model_name, 145 | variant_name=variant_name, 146 | ) 147 | model_variants.append(model_variant) 148 | 149 | if len(model_variants) == 0: 150 | raise Exception("No model variants matching configuration") 151 | 152 | endpoint_config = aws_sagemaker.CfnEndpointConfig( 153 | self, 154 | "EndpointConfig", 155 | production_variants=model_variants, 156 | ) 157 | 158 | self.endpoint = aws_sagemaker.CfnEndpoint( 159 | self, 160 | "Endpoint", 161 | endpoint_config_name=endpoint_config.attr_endpoint_config_name, 162 | endpoint_name=endpoint_name, 163 | tags=tags, 164 | ) 165 | -------------------------------------------------------------------------------- /deployment_pipeline/infra/test_model_registry.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from botocore.stub import Stubber 3 | import pytest 4 | 5 | from model_registry import ModelRegistry 6 | 7 | 8 | def get_package(version: int, creation_time: datetime = datetime.fromtimestamp(0)): 9 | return { 10 | "ModelPackageName": "STUB", 11 | "ModelPackageGroupName": "test-package-group", 12 | "ModelPackageVersion": version, 13 | "ModelPackageArn": f"arn:aws:sagemaker:REGION:ACCOUNT:model-package/test-package-group/{version}", 14 | "CreationTime": creation_time, 15 | "ModelPackageStatus": "Completed", 16 | "ModelApprovalStatus": "Approved", 17 | } 18 | 19 | 20 | 
@pytest.mark.skip(reason="botocore.exceptions.ParamValidationError: fails with Tags") 21 | def test_create_model_package_group(): 22 | # Create model registry 23 | registry = ModelRegistry() 24 | 25 | with Stubber(registry.sm_client) as stubber: 26 | # Empty list with more 27 | expected_params = { 28 | "ModelPackageGroupDescription": "test package group", 29 | "ModelPackageGroupName": "test-package-group", 30 | "Tags": [ 31 | {"Key": "sagemaker:project-name", "Value": "test-project-name"}, 32 | {"Key": "sagemaker:project-id", "Value": "test-project-id"}, 33 | ], 34 | } 35 | expected_response = { 36 | "ModelPackageGroupArn": f"arn:aws:sagemaker:REGION:ACCOUNT:model-package-group/test-package-group", 37 | } 38 | stubber.add_response( 39 | "create_model_package_group", expected_response, expected_params 40 | ) 41 | 42 | # Second time, add the client error if this exists 43 | stubber.add_client_error( 44 | "create_model_package_group", 45 | "ValidationException", 46 | "Model Package Group already exists", 47 | expected_params=expected_params, 48 | ) 49 | 50 | created = registry.create_model_package_group( 51 | "test-package-group", 52 | "test package group", 53 | "test-project-name", 54 | "test-project-id", 55 | ) 56 | assert created == True 57 | 58 | created = registry.create_model_package_group( 59 | "test-package-group", 60 | "test package group", 61 | "test-project-name", 62 | "test-project-id", 63 | ) 64 | assert created == False 65 | 66 | 67 | def test_get_latest_approved_model_packages(): 68 | # Create model registry 69 | registry = ModelRegistry() 70 | 71 | with Stubber(registry.sm_client) as stubber: 72 | # Empty list with more 73 | expected_params = { 74 | "ModelPackageGroupName": "test-package-group", 75 | "ModelApprovalStatus": "Approved", 76 | "SortBy": "CreationTime", 77 | "MaxResults": 2, 78 | } 79 | expected_response = { 80 | "ModelPackageSummaryList": [], 81 | "NextToken": "MORE1", 82 | } 83 | stubber.add_response("list_model_packages", expected_response, expected_params) 84 | # Version 1 with more 85 | expected_params = { 86 | "ModelPackageGroupName": "test-package-group", 87 | "ModelApprovalStatus": "Approved", 88 | "SortBy": "CreationTime", 89 | "MaxResults": 2, 90 | "NextToken": "MORE1", 91 | } 92 | expected_response = { 93 | "ModelPackageSummaryList": [get_package(3)], 94 | "NextToken": "MORE2", 95 | } 96 | stubber.add_response("list_model_packages", expected_response, expected_params) 97 | # Version 2 with two more 98 | expected_params = { 99 | "ModelPackageGroupName": "test-package-group", 100 | "ModelApprovalStatus": "Approved", 101 | "SortBy": "CreationTime", 102 | "MaxResults": 2, 103 | "NextToken": "MORE2", 104 | } 105 | expected_response = { 106 | "ModelPackageSummaryList": [ 107 | get_package(2), 108 | get_package(1), 109 | ], 110 | } 111 | stubber.add_response("list_model_packages", expected_response, expected_params) 112 | 113 | response = registry.get_latest_approved_packages( 114 | model_package_group_name="test-package-group", 115 | max_results=2, 116 | ) 117 | # Expect to get two version 118 | assert len(response) == 2 119 | assert response == [ 120 | get_package(3), 121 | get_package(2), 122 | ] 123 | 124 | 125 | def test_empty_latest_approved_model_packages(): 126 | # Create model registry 127 | registry = ModelRegistry() 128 | 129 | with Stubber(registry.sm_client) as stubber: 130 | # Empty list with no more 131 | expected_params = { 132 | "ModelPackageGroupName": "test-package-group", 133 | "ModelApprovalStatus": "Approved", 134 | "SortBy": 
"CreationTime", 135 | "MaxResults": 2, 136 | } 137 | expected_response = { 138 | "ModelPackageSummaryList": [], 139 | } 140 | stubber.add_response("list_model_packages", expected_response, expected_params) 141 | 142 | # Expect error when no results 143 | with pytest.raises(Exception): 144 | registry.get_latest_approved_packages( 145 | model_package_group_name="test-package-group", 146 | max_results=2, 147 | ) 148 | 149 | 150 | def test_get_latest_approved_model_packages_after_creation(): 151 | # Create model registry 152 | registry = ModelRegistry() 153 | now = datetime.now() 154 | 155 | with Stubber(registry.sm_client) as stubber: 156 | # Empty list with more 157 | expected_params = { 158 | "ModelPackageGroupName": "test-package-group", 159 | "ModelApprovalStatus": "Approved", 160 | "SortBy": "CreationTime", 161 | "MaxResults": 2, 162 | "CreationTimeAfter": now - timedelta(3), 163 | } 164 | expected_response = { 165 | "ModelPackageSummaryList": [], 166 | "NextToken": "MORE1", 167 | } 168 | stubber.add_response("list_model_packages", expected_response, expected_params) 169 | # Version 1 with more 170 | expected_params = { 171 | "ModelPackageGroupName": "test-package-group", 172 | "ModelApprovalStatus": "Approved", 173 | "SortBy": "CreationTime", 174 | "MaxResults": 2, 175 | "CreationTimeAfter": now - timedelta(3), 176 | "NextToken": "MORE1", 177 | } 178 | expected_response = { 179 | "ModelPackageSummaryList": [get_package(3, now - timedelta(1))], 180 | "NextToken": "MORE2", 181 | } 182 | stubber.add_response("list_model_packages", expected_response, expected_params) 183 | # Version 2 with two more 184 | expected_params = { 185 | "ModelPackageGroupName": "test-package-group", 186 | "ModelApprovalStatus": "Approved", 187 | "SortBy": "CreationTime", 188 | "MaxResults": 2, 189 | "CreationTimeAfter": now - timedelta(3), 190 | "NextToken": "MORE2", 191 | } 192 | expected_response = { 193 | "ModelPackageSummaryList": [ 194 | get_package(2, now - timedelta(2)), 195 | get_package(1, now - timedelta(3)), 196 | ], 197 | } 198 | stubber.add_response("list_model_packages", expected_response, expected_params) 199 | 200 | response = registry.get_latest_approved_packages( 201 | model_package_group_name="test-package-group", 202 | max_results=2, 203 | creation_time_after=now - timedelta(3), 204 | ) 205 | # Expect to get two version 206 | assert len(response) == 2 207 | assert response == [ 208 | get_package(3, now - timedelta(1)), 209 | get_package(2, now - timedelta(2)), 210 | ] 211 | 212 | 213 | def test_empty_latest_approved_model_packages_after_creation(): 214 | # Create model registry 215 | registry = ModelRegistry() 216 | now = datetime.now() 217 | 218 | with Stubber(registry.sm_client) as stubber: 219 | # Empty list with no more 220 | expected_params = { 221 | "ModelPackageGroupName": "test-package-group", 222 | "ModelApprovalStatus": "Approved", 223 | "SortBy": "CreationTime", 224 | "MaxResults": 2, 225 | "CreationTimeAfter": now - timedelta(3), 226 | } 227 | expected_response = { 228 | "ModelPackageSummaryList": [], 229 | } 230 | stubber.add_response("list_model_packages", expected_response, expected_params) 231 | 232 | # Expect no error, but empty list for creation time after 233 | response = registry.get_latest_approved_packages( 234 | model_package_group_name="test-package-group", 235 | max_results=2, 236 | creation_time_after=now - timedelta(3), 237 | ) 238 | assert len(response) == 0 239 | 240 | 241 | def test_get_versioned_approved_model_packages(): 242 | # Create model registry 243 | 
registry = ModelRegistry() 244 | 245 | with Stubber(registry.sm_client) as stubber: 246 | # Empty list with more 247 | expected_params = { 248 | "ModelPackageGroupName": "test-package-group", 249 | "ModelApprovalStatus": "Approved", 250 | "SortBy": "CreationTime", 251 | "MaxResults": 100, 252 | } 253 | expected_response = { 254 | "ModelPackageSummaryList": [], 255 | "NextToken": "MORE1", 256 | } 257 | stubber.add_response("list_model_packages", expected_response, expected_params) 258 | # Version 1 with more 259 | expected_params = { 260 | "ModelPackageGroupName": "test-package-group", 261 | "ModelApprovalStatus": "Approved", 262 | "SortBy": "CreationTime", 263 | "MaxResults": 100, 264 | "NextToken": "MORE1", 265 | } 266 | expected_response = { 267 | "ModelPackageSummaryList": [get_package(3)], 268 | "NextToken": "MORE2", 269 | } 270 | stubber.add_response("list_model_packages", expected_response, expected_params) 271 | # Version 2 with two more 272 | expected_params = { 273 | "ModelPackageGroupName": "test-package-group", 274 | "ModelApprovalStatus": "Approved", 275 | "SortBy": "CreationTime", 276 | "MaxResults": 100, 277 | "NextToken": "MORE2", 278 | } 279 | expected_response = { 280 | "ModelPackageSummaryList": [ 281 | get_package(2), 282 | get_package(1), 283 | ], 284 | } 285 | stubber.add_response("list_model_packages", expected_response, expected_params) 286 | 287 | # Get model versions 288 | response = registry.get_versioned_approved_packages( 289 | model_package_group_name="test-package-group", 290 | model_package_versions=[1, 2], 291 | ) 292 | # Expect to get two version 293 | assert len(response) == 2 294 | assert response == [ 295 | get_package(1), 296 | get_package(2), 297 | ] 298 | 299 | 300 | def test_filter_package_version(): 301 | """ 302 | Select the sorted package versions. Validate we return in the order we ask for. 
303 | """ 304 | unsorted_packages = [ 305 | get_package(1), 306 | get_package(3), 307 | get_package(2), 308 | ] 309 | 310 | registry = ModelRegistry() 311 | versions = [2, 3, 2] 312 | response = registry.select_versioned_packages(unsorted_packages, versions) 313 | assert len(response) == 3 314 | assert response == [ 315 | get_package(2), 316 | get_package(3), 317 | get_package(2), 318 | ] 319 | -------------------------------------------------------------------------------- /deployment_pipeline/prod-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "stage_name": "prod", 3 | "strategy": "EpsilonGreedy", 4 | "warmup": 100, 5 | "epsilon": 0.1, 6 | "instance_count": 2, 7 | "instance_type": "ml.c5.large", 8 | "champion_variant_config": { 9 | "model_package_version": 1, 10 | "variant_name": "Champion", 11 | "instance_count": 3, 12 | "instance_type": "ml.m5.xlarge" 13 | }, 14 | "challenger_variant_config": [ 15 | { 16 | "model_package_version": 1, 17 | "variant_name": "Challenger1", 18 | "instance_type": "ml.c5.xlarge" 19 | }, 20 | { 21 | "model_package_version": 2, 22 | "variant_name": "Challenger2", 23 | "instance_count": 1 24 | } 25 | ] 26 | } -------------------------------------------------------------------------------- /deployment_pipeline/register.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | import logging 5 | import os 6 | import boto3 7 | 8 | # Configure the logger 9 | logger = logging.getLogger(__name__) 10 | logging.basicConfig(level="INFO") 11 | 12 | # Boto3 client 13 | lambda_client = boto3.client("lambda") 14 | 15 | # Load these from environment variables, that are passed into CodeBuild job from pipeline stack 16 | project_name = os.environ["SAGEMAKER_PROJECT_NAME"] 17 | project_id = os.environ["SAGEMAKER_PROJECT_ID"] 18 | stage_name = os.environ["STAGE_NAME"] 19 | register_lambda = os.environ["REGISTER_LAMBDA"] 20 | 21 | # Get endpoint 22 | endpoint_name = f"sagemaker-{project_name}-{stage_name}" 23 | logger.info(f"Register endpoint: {endpoint_name} with lambda: {register_lambda}") 24 | 25 | # Get the config and include with endpoint to register this model 26 | with open(f"{stage_name}-config.json", "r") as f: 27 | j = json.load(f) 28 | event = json.dumps( 29 | { 30 | "source": "aws.sagemaker", 31 | "detail-type": "SageMaker Endpoint State Change", 32 | "detail": { 33 | "EndpointName": endpoint_name, 34 | "EndpointStatus": "IN_SERVICE", 35 | "Tags": { 36 | "sagemaker:project-name": project_name, 37 | "sagemaker:project-id": project_id, 38 | "sagemaker:deployment-stage": stage_name, 39 | "ab-testing:enabled": "true", 40 | "ab-testing:strategy": j.get("strategy", "ThompsonSampling"), 41 | "ab-testing:epsilon": str(j.get("epsilon", 0.1)), 42 | "ab-testing:warmup": str(j.get("warmup", 0)), 43 | }, 44 | }, 45 | } 46 | ) 47 | response = lambda_client.invoke( 48 | FunctionName=register_lambda, 49 | InvocationType="RequestResponse", 50 | LogType="Tail", 51 | Payload=event.encode("utf-8"), 52 | ) 53 | # Print the result, and if not succesful raise error 54 | result = json.loads(response["Payload"].read()) 55 | print(result) 56 | if result["statusCode"] not in [200, 201]: 57 | raise Exception("Unexpected status code: {}".format(result["statusCode"])) 58 | -------------------------------------------------------------------------------- /deployment_pipeline/requirements.txt: 
-------------------------------------------------------------------------------- 1 | -e . 2 | -------------------------------------------------------------------------------- /deployment_pipeline/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | with open("README.md") as fp: 5 | long_description = fp.read() 6 | 7 | 8 | setuptools.setup( 9 | name="amazon_sagemaker_ab_testing_deployment", 10 | version="0.0.1", 11 | description="Amazon SageMaker pipeline for A/B Testing", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | author="author", 15 | package_dir={"": "infra"}, 16 | packages=setuptools.find_packages(where="infra"), 17 | install_requires=[ 18 | "boto3>=1.17.54", 19 | "aws-cdk.core==1.94.1", 20 | "aws-cdk.aws-iam==1.94.1", 21 | "aws-cdk.aws-sagemaker==1.94.1", 22 | ], 23 | python_requires=">=3.6", 24 | classifiers=[ 25 | "Development Status :: 4 - Beta", 26 | "Intended Audience :: Developers", 27 | "License :: OSI Approved :: Apache Software License", 28 | "Programming Language :: JavaScript", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Programming Language :: Python :: 3.6", 31 | "Programming Language :: Python :: 3.7", 32 | "Programming Language :: Python :: 3.8", 33 | "Topic :: Software Development :: Code Generators", 34 | "Topic :: Utilities", 35 | "Typing :: Typed", 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /docs/API_CONFIGURATION.md: -------------------------------------------------------------------------------- 1 | # API and Testing infrastructure Configuration 2 | 3 | The API and Testing infrastructure stack reads configuration from context values in `cdk.json`. These values can also be override by passing arguments to the cdk deploy command eg: 4 | 5 | ``` 6 | cdk deploy ab-testing-api -c stage_name=dev -c endpoint_prefix=sagemaker-ab-testing-pipeline 7 | ``` 8 | 9 | Following is a list of the context parameters and their defaults. 10 | 11 | | Property | Description | Default | 12 | |---------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------| 13 | | `api_name` | The API Gateway Name | "ab-testing" | 14 | | `stage_name` | The stage namespace for resource and API Gateway path | "dev" | 15 | | `endpoint_prefix` | A prefix to filter Amazon SageMaker endpoints the API can invoke. | "sagemaker-" | 16 | | `api_lambda_memory` | The [lambda memory](https://docs.aws.amazon.com/lambda/latest/dg/configuration-memory.html) allocation for API endpoint. | 768 | 17 | | `api_lambda_timeout` | The lambda timeout for the API endpoint. | 10 | 18 | | `metrics_lambda_memory` | The [lambda memory](https://docs.aws.amazon.com/lambda/latest/dg/configuration-memory.html) allocated for metrics processing Lambda | 768 | 19 | | `metrics_lambda_timeout` | The lambda timeout for the processing lambda. 
| 10 | 20 | | `dynamodb_read_capacity` | The [Read Capacity](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.ReadWriteCapacityMode.html) for the DynamoDB tables | 5 | 21 | | `dynamodb_write_capacity` | The [Write Capacity](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.ReadWriteCapacityMode.html) for the DynamoDB tables | 5 | 22 | | `delivery_sync` | When `true`, metrics will be written directly to DynamoDB instead of to Amazon Kinesis for processing. | false | 23 | | `firehose_interval` | The [buffering](https://docs.aws.amazon.com/firehose/latest/dev/create-configure.html) interval in seconds at which firehose will flush events to S3. | 60 | 24 | | `firehose_mb_size` | The buffering size in MB before the firehose will flush its events to S3. | 1 | 25 | | `log_level` | Logging level for AWS Lambda functions | "INFO" | 26 | -------------------------------------------------------------------------------- /docs/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. -------------------------------------------------------------------------------- /docs/CUSTOM_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Customize the Deployment Pipeline 2 | 3 | The [ab-testing-pipeline.yml](../ab-testing-pipeline.yml) is included as part of this distribution, and doesn't require updating unless you change the `infra/pipeline_stack.py` implementation. 4 | 5 | To generate a new pipeline you can run the following command. 6 | 7 | ``` 8 | cdk synth ab-testing-pipeline --path-metadata=false > ab-testing-pipeline.yml 9 | ``` 10 | 11 | This template will output a new policy to attach to the `AmazonSageMakerServiceCatalogProductsUseRole` service role. This policy is not required, as the managed role already has these permissions, and you will need to remove it in order for the template to run within Amazon SageMaker Studio. It is recommended that you diff against the original to see where changes need to be made; if there are additional roles or policies, the project might not validate when used inside of Amazon SageMaker Studio. 12 | 13 | ``` 14 | git diff ab-testing-pipeline.yml 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/FAQ.md: -------------------------------------------------------------------------------- 1 | ## Frequently Asked Questions (FAQ) 2 | 3 | ### Can I use this A/B Testing Pipeline for any SageMaker model? 4 | 5 | Yes, the API and testing infrastructure allow any endpoint with one or more production variants to be registered with it. The API passes the `content_type` and `data` payload down to the Amazon SageMaker endpoint after selecting the best model variant to target for a user. 6 | 7 | ### Why do I need to register my new endpoint with the API after deployment? 8 | 9 | The `Register` stage in the deployment pipeline ensures that the Amazon SageMaker Endpoint is able to be reached by the API. The API retrieves the initial weights for the Production Variants configured against the endpoint.
These are saved back to DynamoDB, and any metrics that were previously associated with this endpoint are cleared in preparation for a new test. 10 | 11 | ### How often will metrics be updated in DynamoDB? 12 | 13 | On every `invocation` and `conversion` request against the API, events are written to a Kinesis Data Firehose stream. This stream has a buffer which is configured to write these events to an S3 bucket every 60 seconds or 1MB. When these events are written to S3, an AWS Lambda is triggered which loads these events, sums up the `invocation` and `conversion` records and writes these metrics to DynamoDB. 14 | 15 | ### Why not write metrics directly to DynamoDB? 16 | 17 | The solution can be configured to write metrics to DynamoDB, however this is not recommended for a couple of reasons. 18 | 19 | 1. Less frequent writes to DynamoDB result in a lower [Write Capacity](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.ReadWriteCapacityMode.html) requirement and reduce cost. 20 | 2. Batching up metrics is recommended to ensure that Bandit Algorithms explore early and don't arrive too quickly at a winner. 21 | 3. Batches of metrics could be analyzed before writing to DynamoDB, which provides the opportunity to strip out noisy or fraudulent users by filtering on client IP, user agent or other context. 22 | 4. Metrics written as JSON lines to Kinesis Data Firehose will land in S3, partitioned by date and time, and can be queried by Athena or S3 Select. 23 | 24 | ### What happens if DynamoDB or Kinesis Firehose is unavailable? 25 | 26 | If there is an error reaching the DynamoDB store to retrieve user assignment or variant metrics, the solution will still continue to operate but will fall back to the default traffic distribution for [Multi-Variant](https://docs.aws.amazon.com/sagemaker/latest/dg/model-ab-testing.html) endpoints. Logs will still attempt to be written to Kinesis Firehose, which if available will record the fact that a `Fallback` was registered - see the [operations manual](OPERATIONS.md) for more information. 27 | 28 | This solution has been instrumented with [AWS X-Ray](https://docs.aws.amazon.com/xray/latest/devguide/aws-xray.html) so you should be able to detect any [throttling](https://aws.amazon.com/premiumsupport/knowledge-center/on-demand-table-throttling-dynamodb/), which is the most likely cause of any unavailability. -------------------------------------------------------------------------------- /docs/OPERATIONS.md: -------------------------------------------------------------------------------- 1 | # A/B Testing Pipeline Operations Manual 2 | 3 | Having created the A/B Testing Deployment Pipeline, this operations manual provides instructions on how to run your A/B Testing experiment. 4 | 5 | ## A/B Testing for Machine Learning models 6 | 7 | Successful A/B Testing for machine learning models requires measuring how effective predictions are against end users. It is important to be able to identify users consistently and to attribute success actions for the model predictions back to those users. 8 | 9 | ### Conversion Metrics 10 | 11 | A/B Testing can be applied to a number of use cases for which you have defined a success or `conversion` metric for predictions returned from an ML model. 12 | 13 | Examples include: 14 | * User **Click Through** on advertisement predicted based on browsing history and geo location. 15 | * User **Dwell Time** for personalized content exceeds a relevancy threshold.
16 | * User **Opens** marketing email with personalized subject line. 17 | * User **Watches** recommended video for more than 30 seconds based on viewing history. 18 | * User **Purchases** a product upgrade when offered a pricing discount. 19 | 20 | Conversion rates will vary for each use case, so successful models will be measured as a percentage improvement (e.g. 5%) over a baseline, or previous best model. 21 | 22 | ## Deployment Pipeline 23 | 24 | The A/B Deployment Pipeline provides an additional stage after Endpoint deployment to register the endpoint for A/B Testing. 25 | 26 | This Register stage invokes a lambda by providing an event that includes the `endpoint_name` along with configuration to select models from the registry based on the testing strategy. 27 | 28 | ### Testing Strategies 29 | 30 | Following is a list of the testing strategies available. 31 | 32 | 1. `WeightedSampling` - Random selection based on initial variant weights. Also selected during the `warmup` phase. 33 | 2. `EpsilonGreedy` - Simple strategy that picks a random variant a fraction of the time based on `epsilon`. 34 | 3. `UCB1` - Smart strategy that explores variants with upper confidence bounds until uncertainty drops. 35 | 4. `ThompsonSampling` - Smart strategy that picks random points from beta distributions to exploit variants. 36 | 37 | ### Configuration parameters 38 | 39 | The configuration is stored in the CodeCommit source repository by stage name, e.g. `dev-config.json` for the `dev` stage, and has the following parameters: 40 | 41 | * `stage_name` - The stage suffix for the SageMaker endpoint, e.g. `dev`. 42 | * `instance_count` - The number of instances to deploy per variant. 43 | * `instance_type` - The type of instance to deploy per variant. 44 | * `strategy` - The algorithm strategy for selecting user model variants. 45 | * `epsilon` - The epsilon parameter used by the `EpsilonGreedy` strategy. 46 | * `warmup` - The number of invocations to warm up before applying the strategy. 47 | 48 | In addition to the above, you must specify the `champion` and `challenger` model variants for the deployment. 49 | 50 | These will be loaded from the two Model Package Groups in the registry that include the project name suffixed with `champion` or `challenger`. For example, for the project name `ab-testing-pipeline`, these are the model package groups in the sample notebook: 51 | 52 | ![\[Model Registry\]](ab-testing-pipeline-model-registry.png) 53 | 54 | **Latest Approved Versions** 55 | 56 | To deploy the latest approved `champion` model, and the latest `N` approved `challenger` models, you can provide the single `challenger_variant_count` parameter, e.g.: 57 | 58 | ``` 59 | { 60 | "stage_name": "dev", 61 | "strategy": "ThompsonSampling", 62 | "instance_count": 1, 63 | "instance_type": "ml.t2.large", 64 | "challenger_variant_count": 1 65 | } 66 | ``` 67 | 68 | Alternatively, such as in the case of production environments, you may prefer to select specific approved model package versions. In this case you can specify the `model_package_version` for both the `champion_variant_config` and for one or more `challenger_variant_config` configuration entries. 69 | 70 | You also have the option of overriding one or both of the `instance_count` and `instance_type` parameters for each variant.
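To find which approved versions are available to reference in this configuration, you can query the model registry directly. Below is a minimal sketch using boto3; the package group name `ab-testing-pipeline-champion` assumes a project named `ab-testing-pipeline`, so substitute your own project name (and the `-challenger` suffix) as needed.

```
import boto3

sm_client = boto3.client("sagemaker")

# List the approved model packages for the champion group, newest first.
# The group name below is an example for a project named "ab-testing-pipeline";
# replace it with "<project-name>-champion" or "<project-name>-challenger".
response = sm_client.list_model_packages(
    ModelPackageGroupName="ab-testing-pipeline-champion",
    ModelApprovalStatus="Approved",
    SortBy="CreationTime",
    SortOrder="Descending",
)
for package in response["ModelPackageSummaryList"]:
    print(package["ModelPackageVersion"], package["ModelPackageArn"])
```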
71 | 72 | **Specific Versions** 73 | 74 | ``` 75 | { 76 | "stage_name": "prod", 77 | "strategy": "ThompsonSampling", 78 | "warmup": 100, 79 | "instance_count": 2, 80 | "instance_type": "ml.c5.large", 81 | "champion_variant_config": { 82 | "model_package_version": 1, 83 | "variant_name": "Champion", 84 | "instance_count": 3, 85 | "instance_type": "ml.m5.xlarge" 86 | }, 87 | "challenger_variant_config": [ 88 | { 89 | "model_package_version": 1, 90 | "variant_name": "Challenger1", 91 | "instance_type": "ml.c5.xlarge" 92 | }, 93 | { 94 | "model_package_version": 2, 95 | "variant_name": "Challenger2", 96 | "instance_count": 1 97 | } 98 | ] 99 | } 100 | ``` 101 | 102 | ## API Front-end 103 | 104 | The API has two endpoints, `invocation` and `conversion`, both of which take a `JSON` payload and return a `JSON` response. 105 | 106 | ### Invocation 107 | 108 | The invocation API requires an `endpoint_name`. It also expects a `user_id` input parameter to identify the user; if none is provided, a new `user_id` in the form of a UUID will be generated and returned in the response. 109 | 110 | ``` 111 | curl -X POST -d '<>' https://<>.execute-api.<>.amazonaws.com/<>/invocation 112 | ``` 113 | 114 | **Request**: 115 | ``` 116 | { 117 | "endpoint_name": "sagemaker-ab-testing-pipeline-dev", 118 | "user_id": "user_1", 119 | "content_type": "application/json", 120 | "data": "{\"instances\": [\"Excellent Item This is the perfect media device\"]}" 121 | } 122 | ``` 123 | 124 | The response will return the invoked `endpoint_variant` that returned the predictions, as well as the algorithm `strategy` and `target_variant` selected. 125 | 126 | **Response**: 127 | ``` 128 | { 129 | "endpoint_name": "sagemaker-ab-testing-pipeline-dev", 130 | "user_id": "user_1", 131 | "strategy": "ThompsonSampling", 132 | "target_variant": "Challenger1", 133 | "endpoint_variant": "Challenger1", 134 | "inference_id": "5aa61fe8-70d7-4eed-9419-8f4efc33662d", 135 | "predictions": [{"label": ["__label__NotHelpful"], "prob": [0.735004723072052]}] 136 | } 137 | ``` 138 | 139 | ### Manually overriding the endpoint variant 140 | 141 | You can provide a manual override for the `endpoint_variant` by specifying this in the request payload. 142 | 143 | **Request**: 144 | ``` 145 | { 146 | "endpoint_name": "sagemaker-ab-testing-pipeline-dev", 147 | "endpoint_variant": "Champion", 148 | "user_id": "user_1", 149 | "content_type": "application/json", 150 | "data": "{\"instances\": [\"Excellent Item This is the perfect media device\"]}" 151 | } 152 | ``` 153 | 154 | The response will output a `strategy` of "Manual", which will be logged. 155 | 156 | **Response**: 157 | ``` 158 | { 159 | "endpoint_name": "sagemaker-ab-testing-pipeline-dev", 160 | "user_id": "user_1", 161 | "strategy": "Manual", 162 | "target_variant": "Champion", 163 | "endpoint_variant": "Champion", 164 | "inference_id": "5aa61fe8-70d7-4eed-9419-8f4efc33662d", 165 | "predictions": [{"label": ["__label__NotHelpful"], "prob": [0.735004723072052]}] 166 | } 167 | ``` 168 | 169 | ### Fallback strategy 170 | 171 | In the event of an error reaching the DynamoDB tables for user assignment or variant metrics, the API will still continue to invoke the SageMaker endpoint. 172 | 173 | The response will return a `strategy` of "Fallback" along with an empty `target_variant` and the actual invoked `endpoint_variant`, which will be logged if there are no issues writing to Kinesis.
174 | 175 | **Response**: 176 | ``` 177 | { 178 | "endpoint_name": "sagemaker-ab-testing-pipeline-dev", 179 | "user_id": "user_1", 180 | "strategy": "Fallback", 181 | "target_variant": null, 182 | "endpoint_variant": "Challenger2", 183 | "inference_id": "5aa61fe8-70d7-4eed-9419-8f4efc33662d", 184 | "predictions": [{"label": ["__label__NotHelpful"], "prob": [0.735004723072052]}] 185 | } 186 | ``` 187 | 188 | ### Conversion 189 | 190 | The conversion API requires an `endpoint_name` and a `user_id`. 191 | You can optionally provide the `inference_id` which was returned from the original invocation request to allow correlation when querying the logs in S3. 192 | The `reward` parameter is a floating point number that defaults to `1.0` if not provided. 193 | 194 | ``` 195 | curl -X POST -d '<>' https://<>.execute-api.<>.amazonaws.com/<>/conversion 196 | ``` 197 | 198 | **Request**: 199 | ``` 200 | { 201 | "endpoint_name": "sagemaker-ab-testing-pipeline-dev", 202 | "user_id": "user_1", 203 | "inference_id": "5aa61fe8-70d7-4eed-9419-8f4efc33662d", 204 | "reward": 1.0 205 | } 206 | ``` 207 | 208 | The response will return the `endpoint_variant` assigned to the user. 209 | 210 | **Response**: 211 | ``` 212 | { 213 | "endpoint_name": "sagemaker-ab-testing-pipeline-dev", 214 | "user_id": "user_1", 215 | "strategy": "ThompsonSampling", 216 | "target_variant": "Challenger1", 217 | "endpoint_variant": "Challenger1", 218 | "inference_id": "5aa61fe8-70d7-4eed-9419-8f4efc33662d", 219 | "reward": 1.0 220 | } 221 | ``` 222 | 223 | ## Monitoring 224 | 225 | ### Metrics 226 | 227 | AWS CloudWatch metrics are published every time metrics are updated in Amazon DynamoDB. 228 | 229 | The following metrics are recorded against the dimensions `EndpointName` and `VariantName` in the namespace `aws/sagemaker/Endpoints/ab-testing`: 230 | * `Invocations` 231 | * `Conversions` 232 | * `Rewards` 233 | 234 | ### Traces 235 | 236 | The API Lambda functions are instrumented with [AWS X-Ray](https://aws.amazon.com/xray/) so you can inspect the latency for all downstream services, including: 237 | * DynamoDB 238 | * Amazon SageMaker 239 | * Kinesis Firehose 240 | 241 | ![\[AB Testing Pipeline X-Ray\]](ab-testing-pipeline-xray.png) -------------------------------------------------------------------------------- /docs/SERVICE_CATALOG.md: -------------------------------------------------------------------------------- 1 | # AWS Service Catalog Provisioning 2 | 3 | If you have an existing AWS Service Catalog Portfolio, or would like to create the Product manually, follow these steps: 4 | 5 | 1. Sign in to the console with the data science account. 6 | 2. On the AWS Service Catalog console, under **Administration**, choose **Portfolios**. 7 | 3. Choose **Create a new portfolio**. 8 | 4. Name the portfolio `SageMaker Organization Templates`. 9 | 5. Download the [AB testing template](../ab-testing-pipeline.yml) to your computer. 10 | 6. Choose the new portfolio. 11 | 7. Choose **Upload a new product.** 12 | 8. For **Product name**, enter `A/B Testing Deployment Pipeline`. 13 | 9. For **Description**, enter `Amazon SageMaker Project for A/B Testing models`. 14 | 10. For **Owner**, enter your name. 15 | 11. Under **Version details**, for **Method**, choose **Use a template file**. 16 | 12. Choose **Upload a template**. 17 | 13. Upload the template you downloaded. 18 | 14. For **Version title**, enter `1.0`. 19 | 20 | The remaining parameters are optional. 21 | 22 | 15. Choose **Review**. 23 | 16.
Review your settings and choose **Create product**. 24 | 17. Choose **Refresh** to list the new product. 25 | 18. Choose the product you just created. 26 | 19. On the **Tags** tab, add the following tag to the product: 27 | - **Key** – `sagemaker:studio-visibility` 28 | - **Value** – `True` 29 | 30 | Finally we need to add launch constraint and role permissions. 31 | 32 | 20. On the **Constraints** tab, choose Create constraint. 33 | 21. For **Product**, choose **AB Testing Pipeline** (the product you just created). 34 | 22. For **Constraint type**, choose **Launch**. 35 | 23. Under **Launch Constraint**, for **Method**, choose **Select IAM role**. 36 | 24. Choose **AmazonSageMakerServiceCatalogProductsLaunchRole**. 37 | 25. Choose **Create**. 38 | 26. On the **Groups, roles, and users** tab, choose **Add groups, roles, users**. 39 | 27. On the **Roles** tab, select the role you used when configuring your SageMaker Studio domain. 40 | 28. Choose **Add access**. 41 | 42 | If you don’t remember which role you selected, in your data science account, go to the SageMaker console and choose **Amazon SageMaker Studio**. In the Studio **Summary** section, locate the attribute **Execution role**. Search for the name of this role in the previous step. -------------------------------------------------------------------------------- /docs/ab-testing-pipeline-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-pipeline-architecture.png -------------------------------------------------------------------------------- /docs/ab-testing-pipeline-code-pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-pipeline-code-pipeline.png -------------------------------------------------------------------------------- /docs/ab-testing-pipeline-deployment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-pipeline-deployment.png -------------------------------------------------------------------------------- /docs/ab-testing-pipeline-execution-role.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-pipeline-execution-role.png -------------------------------------------------------------------------------- /docs/ab-testing-pipeline-model-registry.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-pipeline-model-registry.png -------------------------------------------------------------------------------- /docs/ab-testing-pipeline-sagemaker-project.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-pipeline-sagemaker-project.png 
-------------------------------------------------------------------------------- /docs/ab-testing-pipeline-sagemaker-template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-pipeline-sagemaker-template.png -------------------------------------------------------------------------------- /docs/ab-testing-pipeline-upload-file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-pipeline-upload-file.png -------------------------------------------------------------------------------- /docs/ab-testing-pipeline-xray.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-pipeline-xray.png -------------------------------------------------------------------------------- /docs/ab-testing-solution-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/docs/ab-testing-solution-overview.png -------------------------------------------------------------------------------- /infra/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-ab-testing-pipeline/3d5f444a5bdc5e420ed68f7c160b2ff9396839cb/infra/__init__.py -------------------------------------------------------------------------------- /infra/api_stack.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | core, 3 | aws_apigateway, 4 | aws_iam, 5 | aws_events as events, 6 | aws_events_targets as targets, 7 | aws_logs, 8 | aws_lambda, 9 | aws_dynamodb, 10 | aws_kinesisfirehose, 11 | aws_s3, 12 | aws_s3_notifications, 13 | ) 14 | 15 | 16 | class ApiStack(core.Stack): 17 | def __init__( 18 | self, 19 | scope: core.Construct, 20 | construct_id: str, 21 | **kwargs, 22 | ) -> None: 23 | super().__init__(scope, construct_id, **kwargs) 24 | 25 | # Get some context properties 26 | log_level = self.node.try_get_context("log_level") 27 | api_name = self.node.try_get_context("api_name") 28 | stage_name = self.node.try_get_context("stage_name") 29 | endpoint_prefix = self.node.try_get_context("endpoint_prefix") 30 | api_lambda_memory = self.node.try_get_context("api_lambda_memory") 31 | api_lambda_timeout = self.node.try_get_context("api_lambda_timeout") 32 | metrics_lambda_memory = self.node.try_get_context("metrics_lambda_memory") 33 | metrics_lambda_timeout = self.node.try_get_context("metrics_lambda_timeout") 34 | dynamodb_read_capacity = self.node.try_get_context("dynamodb_read_capacity") 35 | dynamodb_write_capacity = self.node.try_get_context("dynamodb_write_capacity") 36 | delivery_sync = self.node.try_get_context("delivery_sync") 37 | firehose_interval = self.node.try_get_context("firehose_interval") 38 | firehose_mb_size = self.node.try_get_context("firehose_mb_size") 39 | 40 | # Create dynamodb tables and kinesis stream per project 41 | assignment_table_name = f"{api_name}-assignment-{stage_name}" 42 | metrics_table_name 
= f"{api_name}-metrics-{stage_name}" 43 | delivery_stream_name = f"{api_name}-events-{stage_name}" 44 | log_stream_name = "ApiEvents" 45 | 46 | assignment_table = aws_dynamodb.Table( 47 | self, 48 | "AssignmentTable", 49 | table_name=assignment_table_name, 50 | partition_key=aws_dynamodb.Attribute( 51 | name="user_id", 52 | type=aws_dynamodb.AttributeType.STRING, 53 | ), 54 | sort_key=aws_dynamodb.Attribute( 55 | name="endpoint_name", 56 | type=aws_dynamodb.AttributeType.STRING, 57 | ), 58 | read_capacity=dynamodb_read_capacity, 59 | write_capacity=dynamodb_write_capacity, 60 | removal_policy=core.RemovalPolicy.DESTROY, 61 | time_to_live_attribute="ttl", 62 | ) 63 | 64 | metrics_table = aws_dynamodb.Table( 65 | self, 66 | "MetricsTable", 67 | table_name=metrics_table_name, 68 | partition_key=aws_dynamodb.Attribute( 69 | name="endpoint_name", type=aws_dynamodb.AttributeType.STRING 70 | ), 71 | read_capacity=dynamodb_read_capacity, 72 | write_capacity=dynamodb_write_capacity, 73 | removal_policy=core.RemovalPolicy.DESTROY, 74 | ) 75 | 76 | # Create lambda layer for "aws-xray-sdk" and latest "boto3" 77 | xray_layer = aws_lambda.LayerVersion( 78 | self, 79 | "XRayLayer", 80 | code=aws_lambda.AssetCode.from_asset("layers"), 81 | compatible_runtimes=[aws_lambda.Runtime.PYTHON_3_7], 82 | description="A layer containing AWS X-Ray SDK for Python", 83 | ) 84 | 85 | # Create Lambda function to read from assignment and metrics table, log metrics 86 | # 2048MB is ~3% higher than 768 MB, it runs 2.5x faster 87 | # https://aws.amazon.com/blogs/aws/new-for-aws-lambda-functions-with-up-to-10-gb-of-memory-and-6-vcpus/ 88 | lambda_invoke = aws_lambda.Function( 89 | self, 90 | "ApiFunction", 91 | code=aws_lambda.AssetCode.from_asset("lambda/api"), 92 | handler="lambda_invoke.lambda_handler", 93 | runtime=aws_lambda.Runtime.PYTHON_3_7, 94 | timeout=core.Duration.seconds(api_lambda_timeout), 95 | memory_size=api_lambda_memory, 96 | environment={ 97 | "ASSIGNMENT_TABLE": assignment_table.table_name, 98 | "METRICS_TABLE": metrics_table.table_name, 99 | "DELIVERY_STREAM_NAME": delivery_stream_name, 100 | "DELIVERY_SYNC": "true" if delivery_sync else "false", 101 | "LOG_LEVEL": log_level, 102 | }, 103 | layers=[xray_layer], 104 | tracing=aws_lambda.Tracing.ACTIVE, 105 | ) 106 | 107 | # Grant read/write permissions to assignment and metrics tables 108 | assignment_table.grant_read_data(lambda_invoke) 109 | assignment_table.grant_write_data(lambda_invoke) 110 | metrics_table.grant_read_data(lambda_invoke) 111 | 112 | # Add sagemaker invoke 113 | lambda_invoke.add_to_role_policy( 114 | aws_iam.PolicyStatement( 115 | actions=[ 116 | "sagemaker:InvokeEndpoint", 117 | ], 118 | resources=[ 119 | "arn:aws:sagemaker:{}:{}:endpoint/{}*".format( 120 | self.region, self.account, endpoint_prefix 121 | ) 122 | ], 123 | ) 124 | ) 125 | 126 | # Create API Gateway for api lambda, which will create an output 127 | aws_apigateway.LambdaRestApi( 128 | self, 129 | "Api", 130 | rest_api_name=api_name, 131 | deploy_options=aws_apigateway.StageOptions(stage_name=stage_name), 132 | proxy=True, 133 | handler=lambda_invoke, 134 | ) 135 | 136 | # Create lambda function for processing metrics 137 | lambda_register = aws_lambda.Function( 138 | self, 139 | "RegisterFunction", 140 | code=aws_lambda.AssetCode.from_asset("lambda/api"), 141 | handler="lambda_register.lambda_handler", 142 | runtime=aws_lambda.Runtime.PYTHON_3_7, 143 | timeout=core.Duration.seconds(metrics_lambda_timeout), 144 | memory_size=metrics_lambda_memory, 145 | environment={ 
146 | "METRICS_TABLE": metrics_table.table_name, 147 | "DELIVERY_STREAM_NAME": delivery_stream_name, 148 | "STAGE_NAME": stage_name, 149 | "LOG_LEVEL": log_level, 150 | "ENDPOINT_PREFIX": endpoint_prefix, 151 | }, 152 | layers=[xray_layer], 153 | tracing=aws_lambda.Tracing.ACTIVE, 154 | ) 155 | 156 | # Add write metrics 157 | metrics_table.grant_write_data(lambda_register) 158 | 159 | # Add sagemaker invoke 160 | lambda_register.add_to_role_policy( 161 | aws_iam.PolicyStatement( 162 | actions=[ 163 | "sagemaker:DescribeEndpoint", 164 | ], 165 | resources=[ 166 | "arn:aws:sagemaker:{}:{}:endpoint/{}*".format( 167 | self.region, self.account, endpoint_prefix 168 | ) 169 | ], 170 | ) 171 | ) 172 | 173 | # Add endpoint event rule to register endpoints that are created or updated. 174 | # Note CDK is unable to filter on resource prefixes, so we will need to filter on this within the RegisterLambda function. 175 | # see: https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-event-patterns-content-based-filtering.html#filtering-prefix-matching 176 | endpoint_rule = events.Rule( 177 | self, 178 | "EndpointRule", 179 | rule_name=f"sagemaker-{api_name}-endpoint-{stage_name}", 180 | description="Rule to register an Amazon SageMaker Endpoint when it is created, updated or deleted.", 181 | event_pattern=events.EventPattern( 182 | source=["aws.sagemaker"], 183 | detail_type=[ 184 | "SageMaker Endpoint State Change", 185 | ], 186 | detail={ 187 | "EndpointStatus": ["IN_SERVICE", "DELETING"], 188 | }, 189 | ), 190 | targets=[targets.LambdaFunction(lambda_register)], 191 | ) 192 | 193 | # Return the register lambda function as output 194 | core.CfnOutput(self, "RegisterLambda", value=lambda_register.function_name) 195 | 196 | # Get cloudwatch put metrics policy () 197 | cloudwatch_metric_policy = aws_iam.PolicyStatement( 198 | actions=["cloudwatch:PutMetricData"], resources=["*"] 199 | ) 200 | 201 | # If we are only using sync delivery, don't require firehose or s3 buckets 202 | if delivery_sync: 203 | metrics_table.grant_write_data(lambda_invoke) 204 | lambda_invoke.add_to_role_policy(cloudwatch_metric_policy) 205 | print("# No Firehose") 206 | return 207 | 208 | # Add kinesis stream logging 209 | lambda_invoke.add_to_role_policy( 210 | aws_iam.PolicyStatement( 211 | actions=[ 212 | "firehose:PutRecord", 213 | ], 214 | resources=[ 215 | "arn:aws:firehose:{}:{}:deliverystream/{}".format( 216 | self.region, self.account, delivery_stream_name 217 | ), 218 | ], 219 | ) 220 | ) 221 | 222 | # Create s3 bucket for event logging (name must be < 63 chars) 223 | s3_logs = aws_s3.Bucket( 224 | self, 225 | "S3Logs", 226 | removal_policy=core.RemovalPolicy.DESTROY, 227 | ) 228 | 229 | firehose_role = aws_iam.Role( 230 | self, 231 | "KinesisFirehoseRole", 232 | assumed_by=aws_iam.ServicePrincipal("firehose.amazonaws.com"), 233 | ) 234 | 235 | firehose_role.add_to_policy( 236 | aws_iam.PolicyStatement( 237 | actions=[ 238 | "s3:AbortMultipartUpload", 239 | "s3:GetBucketLocation", 240 | "s3:GetObject", 241 | "s3:ListBucket", 242 | "s3:ListBucketMultipartUploads", 243 | "s3:PutObject", 244 | ], 245 | resources=[s3_logs.bucket_arn, f"{s3_logs.bucket_arn}/*"], 246 | ) 247 | ) 248 | 249 | # Create LogGroup and Stream, and add permissions to role 250 | firehose_log_group = aws_logs.LogGroup(self, "FirehoseLogGroup") 251 | firehose_log_stream = firehose_log_group.add_stream(log_stream_name) 252 | 253 | firehose_role.add_to_policy( 254 | aws_iam.PolicyStatement( 255 | actions=[ 256 | "logs:PutLogEvents", 257 | ], 258 | 
resources=[ 259 | f"arn:{self.partition}:logs:{self.region}:{self.account}:log-group:{firehose_log_group.log_group_name}:log-stream:{firehose_log_stream.log_stream_name}", 260 | ], 261 | ) 262 | ) 263 | 264 | # Creat the firehose delivery stream with s3 destination 265 | aws_kinesisfirehose.CfnDeliveryStream( 266 | self, 267 | "KensisLogs", 268 | delivery_stream_name=delivery_stream_name, 269 | s3_destination_configuration=aws_kinesisfirehose.CfnDeliveryStream.S3DestinationConfigurationProperty( 270 | bucket_arn=s3_logs.bucket_arn, 271 | compression_format="GZIP", 272 | role_arn=firehose_role.role_arn, 273 | prefix=f"{stage_name}/", 274 | cloud_watch_logging_options=aws_kinesisfirehose.CfnDeliveryStream.CloudWatchLoggingOptionsProperty( 275 | enabled=True, 276 | log_group_name=firehose_log_group.log_group_name, 277 | log_stream_name=firehose_log_stream.log_stream_name, 278 | ), 279 | buffering_hints=aws_kinesisfirehose.CfnDeliveryStream.BufferingHintsProperty( 280 | interval_in_seconds=firehose_interval, 281 | size_in_m_bs=firehose_mb_size, 282 | ), 283 | ), 284 | ) 285 | 286 | # Create lambda function for processing metrics 287 | lambda_metrics = aws_lambda.Function( 288 | self, 289 | "MetricsFunction", 290 | code=aws_lambda.AssetCode.from_asset("lambda/api"), 291 | handler="lambda_metrics.lambda_handler", 292 | runtime=aws_lambda.Runtime.PYTHON_3_7, 293 | timeout=core.Duration.seconds(metrics_lambda_timeout), 294 | memory_size=metrics_lambda_memory, 295 | environment={ 296 | "METRICS_TABLE": metrics_table.table_name, 297 | "DELIVERY_STREAM_NAME": delivery_stream_name, 298 | "LOG_LEVEL": log_level, 299 | }, 300 | layers=[xray_layer], 301 | tracing=aws_lambda.Tracing.ACTIVE, 302 | ) 303 | 304 | # Add write metrics for dynamodb table 305 | metrics_table.grant_write_data(lambda_metrics) 306 | 307 | # Add put metrics for cloudwatch 308 | lambda_metrics.add_to_role_policy(cloudwatch_metric_policy) 309 | 310 | # Allow metrics to read form S3 and write to DynamoDB 311 | s3_logs.grant_read(lambda_metrics) 312 | 313 | # Create S3 logs notification for processing lambda 314 | notification = aws_s3_notifications.LambdaDestination(lambda_metrics) 315 | s3_logs.add_event_notification(aws_s3.EventType.OBJECT_CREATED, notification) 316 | -------------------------------------------------------------------------------- /infra/pipeline_stack.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | core, 3 | aws_iam, 4 | aws_cloudformation as cloudformation, 5 | aws_events as events, 6 | aws_events_targets as targets, 7 | aws_codebuild as codebuild, 8 | aws_codecommit as codecommit, 9 | aws_codepipeline as codepipeline, 10 | aws_codepipeline_actions as codepipeline_actions, 11 | aws_lambda as lambda_, 12 | aws_s3 as s3, 13 | aws_s3_assets as s3_assets, 14 | ) 15 | 16 | 17 | class PipelineStack(core.Stack): 18 | def __init__( 19 | self, 20 | scope: core.Construct, 21 | construct_id: str, 22 | # deployment_asset: s3_assets.Asset, 23 | **kwargs, 24 | ) -> None: 25 | super().__init__(scope, construct_id, **kwargs) 26 | 27 | # Create Required parameters for sagemaker projects 28 | # see: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-projects-templates-custom.html 29 | # see also: # https://docs.aws.amazon.com/cdk/latest/guide/parameters.html 30 | project_name = core.CfnParameter( 31 | self, 32 | "SageMakerProjectName", 33 | type="String", 34 | description="The name of the SageMaker project.", 35 | min_length=1, 36 | max_length=32, 37 | ) 38 | 
project_id = core.CfnParameter( 39 | self, 40 | "SageMakerProjectId", 41 | type="String", 42 | min_length=1, 43 | max_length=16, 44 | description="Service generated Id of the project.", 45 | ) 46 | stage_name = core.CfnParameter( 47 | self, 48 | "StageName", 49 | type="String", 50 | min_length=1, 51 | max_length=8, 52 | description="The stage name.", 53 | default="dev", 54 | ) 55 | seed_bucket = core.CfnParameter( 56 | self, 57 | "CodeCommitSeedBucket", 58 | type="String", 59 | description="The optional s3 seed bucket", 60 | min_length=1, 61 | ) 62 | seed_key = core.CfnParameter( 63 | self, 64 | "CodeCommitSeedKey", 65 | type="String", 66 | description="The optional s3 seed key", 67 | min_length=1, 68 | ) 69 | 70 | # Get the service catalog role for all permssions (if None CDK will create new roles) 71 | # CodeBuild and CodePipeline resources need to start with "sagemaker-" to be within default policy 72 | service_catalog_role = aws_iam.Role.from_role_arn( 73 | self, 74 | "PipelineRole", 75 | f"arn:{self.partition}:iam::{self.account}:role/service-role/AmazonSageMakerServiceCatalogProductsUseRole", 76 | ) 77 | 78 | # Define the repository name and branch 79 | branch_name = "main" 80 | 81 | # Create source repo from seed bucket/key 82 | repo = codecommit.CfnRepository( 83 | self, 84 | "CodeRepo", 85 | repository_name="sagemaker-{}-repo".format(project_name.value_as_string), 86 | repository_description="Amazon SageMaker A/B testing pipeline", 87 | code=codecommit.CfnRepository.CodeProperty( 88 | s3=codecommit.CfnRepository.S3Property( 89 | bucket=seed_bucket.value_as_string, 90 | key=seed_key.value_as_string, 91 | object_version=None, 92 | ), 93 | branch_name=branch_name, 94 | ), 95 | tags=[ 96 | core.CfnTag( 97 | key="sagemaker:deployment-stage", value=stage_name.value_as_string 98 | ), 99 | core.CfnTag( 100 | key="sagemaker:project-id", value=project_id.value_as_string 101 | ), 102 | core.CfnTag( 103 | key="sagemaker:project-name", value=project_name.value_as_string 104 | ), 105 | ], 106 | ) 107 | 108 | # Reference the newly created repository 109 | code = codecommit.Repository.from_repository_name( 110 | self, "ImportedRepo", repo.attr_name 111 | ) 112 | 113 | cdk_build = codebuild.PipelineProject( 114 | self, 115 | "CdkBuild", 116 | project_name="sagemaker-{}-cdk-{}".format( 117 | project_name.value_as_string, stage_name.value_as_string 118 | ), 119 | role=service_catalog_role, 120 | build_spec=codebuild.BuildSpec.from_object( 121 | dict( 122 | version="0.2", 123 | phases=dict( 124 | install=dict( 125 | commands=[ 126 | "npm install aws-cdk", 127 | "npm update", 128 | "python -m pip install -r requirements.txt", 129 | ] 130 | ), 131 | build=dict( 132 | commands=[ 133 | "npx cdk synth -o dist --path-metadata false", 134 | ] 135 | ), 136 | ), 137 | artifacts={ 138 | "base-directory": "dist", 139 | "files": ["*.template.json"], 140 | }, 141 | environment=dict( 142 | buildImage=codebuild.LinuxBuildImage.AMAZON_LINUX_2_3, 143 | ), 144 | ) 145 | ), 146 | environment_variables={ 147 | "SAGEMAKER_PROJECT_NAME": codebuild.BuildEnvironmentVariable( 148 | value=project_name.value_as_string 149 | ), 150 | "SAGEMAKER_PROJECT_ID": codebuild.BuildEnvironmentVariable( 151 | value=project_id.value_as_string 152 | ), 153 | "STAGE_NAME": codebuild.BuildEnvironmentVariable( 154 | value=stage_name.value_as_string 155 | ), 156 | }, 157 | ) 158 | 159 | source_output = codepipeline.Artifact() 160 | cdk_build_output = codepipeline.Artifact() 161 | 162 | # Create the s3 artifact (name must be < 63 chars) 163 | 
s3_artifact = s3.Bucket( 164 | self, 165 | "S3Artifact", 166 | bucket_name="sagemaker-{}-artifact-{}-{}".format( 167 | project_id.value_as_string, stage_name.value_as_string, self.region 168 | ), 169 | removal_policy=core.RemovalPolicy.DESTROY, 170 | ) 171 | 172 | deploy_pipeline = codepipeline.Pipeline( 173 | self, 174 | "Pipeline", 175 | role=service_catalog_role, 176 | artifact_bucket=s3_artifact, 177 | pipeline_name="sagemaker-{}-pipeline-{}".format( 178 | project_name.value_as_string, stage_name.value_as_string 179 | ), 180 | stages=[ 181 | codepipeline.StageProps( 182 | stage_name="Source", 183 | actions=[ 184 | codepipeline_actions.CodeCommitSourceAction( 185 | action_name="CodeCommit_Source", 186 | repository=code, 187 | trigger=codepipeline_actions.CodeCommitTrigger.NONE, # Created below 188 | event_role=service_catalog_role, 189 | output=source_output, 190 | branch=branch_name, 191 | role=service_catalog_role, 192 | ) 193 | ], 194 | ), 195 | codepipeline.StageProps( 196 | stage_name="Build", 197 | actions=[ 198 | codepipeline_actions.CodeBuildAction( 199 | action_name="CDK_Build", 200 | project=cdk_build, 201 | input=source_output, 202 | outputs=[ 203 | cdk_build_output, 204 | ], 205 | role=service_catalog_role, 206 | ), 207 | ], 208 | ), 209 | codepipeline.StageProps( 210 | stage_name="Deploy", 211 | actions=[ 212 | codepipeline_actions.CloudFormationCreateUpdateStackAction( 213 | action_name="SageMaker_CFN_Deploy", 214 | run_order=1, 215 | template_path=cdk_build_output.at_path( 216 | "ab-testing-sagemaker.template.json" 217 | ), 218 | stack_name="sagemaker-{}-deploy-{}".format( 219 | project_name.value_as_string, stage_name.value_as_string 220 | ), 221 | admin_permissions=False, 222 | role=service_catalog_role, 223 | deployment_role=service_catalog_role, 224 | replace_on_failure=True, 225 | ), 226 | ], 227 | ), 228 | ], 229 | ) 230 | 231 | # Add deploy role to target the code pipeline when model package is approved 232 | deploy_rule = events.Rule( 233 | self, 234 | "DeployRule", 235 | rule_name="sagemaker-{}-model-{}".format( 236 | project_name.value_as_string, stage_name.value_as_string 237 | ), 238 | description="Rule to trigger a deployment when SageMaker Model registry is updated with a new model package.", 239 | event_pattern=events.EventPattern( 240 | source=["aws.sagemaker"], 241 | detail_type=["SageMaker Model Package State Change"], 242 | detail={ 243 | "ModelPackageGroupName": [ 244 | f"{project_name.value_as_string}-champion", 245 | f"{project_name.value_as_string}-challenger", 246 | ] 247 | }, 248 | ), 249 | targets=[ 250 | targets.CodePipeline( 251 | pipeline=deploy_pipeline, 252 | event_role=service_catalog_role, 253 | ) 254 | ], 255 | ) 256 | 257 | code_rule = events.Rule( 258 | self, 259 | "CodeRule", 260 | rule_name="sagemaker-{}-code-{}".format( 261 | project_name.value_as_string, stage_name.value_as_string 262 | ), 263 | description="Rule to trigger a deployment when deployment configured is updated in CodeCommit.", 264 | event_pattern=events.EventPattern( 265 | source=["aws.codecommit"], 266 | detail_type=["CodeCommit Repository State Change"], 267 | detail={ 268 | "event": ["referenceCreated", "referenceUpdated"], 269 | "referenceType": ["branch"], 270 | "referenceName": [branch_name], 271 | }, 272 | resources=[code.repository_arn], 273 | ), 274 | targets=[ 275 | targets.CodePipeline( 276 | pipeline=deploy_pipeline, 277 | event_role=service_catalog_role, 278 | ) 279 | ], 280 | ) 281 | 
-------------------------------------------------------------------------------- /infra/service_catalog.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | core, 3 | aws_iam, 4 | aws_s3_assets, 5 | aws_servicecatalog, 6 | ) 7 | 8 | # Create a Portfolio and Product 9 | # see: https://docs.aws.amazon.com/cdk/api/latest/python/aws_cdk.aws_servicecatalog.html 10 | # see also: https://github.com/mattmcclean/cdk-mlops-sm-project-template/blob/main/lib/mlops-sc-portfolio-stack.ts 11 | 12 | 13 | class ServiceCatalogStack(core.Stack): 14 | def __init__( 15 | self, 16 | scope: core.Construct, 17 | construct_id: str, 18 | **kwargs, 19 | ) -> None: 20 | super().__init__(scope, construct_id, **kwargs) 21 | 22 | execution_role_arn = core.CfnParameter( 23 | self, 24 | "ExecutionRoleArn", 25 | type="String", 26 | description="The SageMaker Studio execution role", 27 | ) 28 | 29 | portfolio_name = core.CfnParameter( 30 | self, 31 | "PortfolioName", 32 | type="String", 33 | description="The name of the portfolio", 34 | default="SageMaker Organization Templates", 35 | ) 36 | 37 | portfolio_owner = core.CfnParameter( 38 | self, 39 | "PortfolioOwner", 40 | type="String", 41 | description="The owner of the portfolio.", 42 | default="administrator", 43 | ) 44 | 45 | product_version = core.CfnParameter( 46 | self, 47 | "ProductVersion", 48 | type="String", 49 | description="The product version to deploy", 50 | default="1.0", 51 | ) 52 | 53 | portfolio = aws_servicecatalog.CfnPortfolio( 54 | self, 55 | "Portfolio", 56 | display_name=portfolio_name.value_as_string, 57 | description="Organization templates for AB Testing pipeline", 58 | provider_name=portfolio_owner.value_as_string, 59 | ) 60 | 61 | asset = aws_s3_assets.Asset( 62 | self, "TemplateAsset", path="./ab-testing-pipeline.yml" 63 | ) 64 | 65 | product = aws_servicecatalog.CfnCloudFormationProduct( 66 | self, 67 | "Product", 68 | name="A/B Testing Deployment Pipeline", 69 | description="Amazon SageMaker Project for A/B Testing models", 70 | owner=portfolio_owner.value_as_string, 71 | provisioning_artifact_parameters=[ 72 | aws_servicecatalog.CfnCloudFormationProduct.ProvisioningArtifactPropertiesProperty( 73 | name=product_version.value_as_string, 74 | info={"LoadTemplateFromURL": asset.s3_url}, 75 | ), 76 | ], 77 | tags=[ 78 | core.CfnTag(key="sagemaker:studio-visibility", value="true"), 79 | ], 80 | ) 81 | 82 | aws_servicecatalog.CfnPortfolioProductAssociation( 83 | self, 84 | "ProductAssoication", 85 | portfolio_id=portfolio.ref, 86 | product_id=product.ref, 87 | ) 88 | 89 | launch_role = aws_iam.Role.from_role_arn( 90 | self, 91 | "LaunchRole", 92 | role_arn=f"arn:{self.partition}:iam::{self.account}:role/service-role/AmazonSageMakerServiceCatalogProductsLaunchRole", 93 | ) 94 | 95 | portfolio_association = aws_servicecatalog.CfnPortfolioPrincipalAssociation( 96 | self, 97 | "PortfolioPrincipalAssociation", 98 | portfolio_id=portfolio.ref, 99 | principal_arn=execution_role_arn.value_as_string, 100 | principal_type="IAM", 101 | ) 102 | portfolio_association.add_depends_on(product) 103 | 104 | # Ensure we run the LaunchRoleConstrait last as there are timing issues on product/portfolio being created 105 | role_constraint = aws_servicecatalog.CfnLaunchRoleConstraint( 106 | self, 107 | "LaunchRoleConstraint", 108 | portfolio_id=portfolio.ref, 109 | product_id=product.ref, 110 | role_arn=launch_role.role_arn, 111 | description=f"Launch as {launch_role.role_arn}", 112 | ) 113 | 
role_constraint.add_depends_on(portfolio_association) 114 | 115 | # Create the deployment asset as an output to pass to pipeline stack 116 | deployment_asset = aws_s3_assets.Asset( 117 | self, "DeploymentAsset", path="./deployment_pipeline" 118 | ) 119 | 120 | deployment_asset.grant_read(grantee=launch_role) 121 | 122 | # Output the deployment bucket and key, for input into the pipeline stack 123 | core.CfnOutput( 124 | self, 125 | "CodeCommitSeedBucket", 126 | value=deployment_asset.s3_bucket_name, 127 | ) 128 | core.CfnOutput(self, "CodeCommitSeedKey", value=deployment_asset.s3_object_key) 129 | -------------------------------------------------------------------------------- /install_layers.sh: -------------------------------------------------------------------------------- 1 | # see: https://github.com/awsdocs/aws-lambda-developer-guide/tree/main/sample-apps/blank-python 2 | cd layers 3 | rm -rf ./python *.zip 4 | pip install -t ./python -r requirements.txt -------------------------------------------------------------------------------- /lambda/api/algorithm.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | # Contains pure python class implementations for WeightedSampling, EpsilonGreedy, UCB1 and ThompsonSampling. 5 | # For the maths and theory behind these algorithms see the following resource: 6 | # https://lilianweng.github.io/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#ucb1 7 | 8 | 9 | class AlgorithmBase: 10 | """ 11 | Base class for implementing the following bandit strategies: 12 | 1. Epsilon Greedy 13 | 2. UCB1 14 | 3. Thompson Sampling 15 | """ 16 | 17 | def __init__(self, variant_metrics: list): 18 | pass 19 | 20 | @staticmethod 21 | def argmax(a): 22 | """ 23 | This is a pure-Python version of np.argmax(), since numpy is not available. 24 | """ 25 | return max(range(len(a)), key=lambda x: a[x]) 26 | 27 | @staticmethod 28 | def random_beta(alpha, beta): 29 | """ 30 | Pure Python implementation of a random Beta draw 31 | """ 32 | return random.betavariate(alpha, beta) 33 | 34 | 35 | class WeightedSampling(AlgorithmBase): 36 | STRATEGY_NAME = "WeightedSampling" 37 | 38 | def __init__(self, variant_metrics: list): 39 | if len(variant_metrics) == 0: 40 | raise Exception("Require at least one endpoint variant") 41 | self.variant_metrics = variant_metrics 42 | 43 | def select_variant(self): 44 | variant_names = [ev["variant_name"] for ev in self.variant_metrics] 45 | variant_weights = [ev["initial_variant_weight"] for ev in self.variant_metrics] 46 | return random.choices(variant_names, weights=variant_weights)[0] 47 | 48 | 49 | class EpsilonGreedy(AlgorithmBase): 50 | STRATEGY_NAME = "EpsilonGreedy" 51 | 52 | def __init__(self, variant_metrics: list, epsilon: float): 53 | if len(variant_metrics) == 0: 54 | raise Exception("Require at least one endpoint variant") 55 | self.variant_metrics = variant_metrics 56 | if epsilon < 0 or epsilon > 1: 57 | raise Exception("Epsilon must be a value between 0 and 1") 58 | self.epsilon = epsilon 59 | 60 | def select_variant(self): 61 | """ 62 | The Epsilon-Greedy algorithm balances exploitation and exploration in a simple way. 63 | It takes a parameter, epsilon, between 0 and 1, as the probability of exploring the variants 64 | as opposed to exploiting the current best variant in the test.
65 | """ 66 | if random.random() > self.epsilon: 67 | rates = [ 68 | 1.0 * v["reward_sum"] / v["invocation_count"] 69 | for v in self.variant_metrics 70 | ] 71 | variant_index = AlgorithmBase.argmax(rates) 72 | else: 73 | variant_index = random.randrange(len(self.variant_metrics)) 74 | return self.variant_metrics[variant_index]["variant_name"] 75 | 76 | 77 | class UCB1(AlgorithmBase): 78 | STRATEGY_NAME = "UCB1" 79 | 80 | def __init__(self, variant_metrics: list): 81 | if len(variant_metrics) == 0: 82 | raise Exception("Require at least one endpoint variant") 83 | self.variant_metrics = variant_metrics 84 | 85 | def select_variant(self): 86 | """ 87 | The defining feature of the UCB1 algorithm is its “curiosity bonus”. When selecting an arm, 88 | it takes the expected reward of each arm and then adds a bonus 89 | which is calculated in inverse proportion to the confidence of that reward. 90 | It is optimistic about uncertainty, so lower confidence arms are given a bit 91 | of a boost relative to higher confidence arms. 92 | """ 93 | invocation_total = sum([v["invocation_count"] for v in self.variant_metrics]) 94 | ucb_values = [] 95 | for v in self.variant_metrics: 96 | curiosity_bonus = math.sqrt( 97 | (2 * math.log(invocation_total)) / float(v["invocation_count"]) 98 | ) 99 | rate = 1.0 * v["reward_sum"] / v["invocation_count"] 100 | ucb_values.append(rate + curiosity_bonus) 101 | variant_index = AlgorithmBase.argmax(ucb_values) 102 | return self.variant_metrics[variant_index]["variant_name"] 103 | 104 | 105 | class ThompsonSampling(AlgorithmBase): 106 | STRATEGY_NAME = "ThompsonSampling" 107 | 108 | def __init__(self, variant_metrics: list): 109 | if len(variant_metrics) == 0: 110 | raise Exception("Require at least one endpoint variant") 111 | self.variant_metrics = variant_metrics 112 | 113 | def select_variant(self): 114 | """ 115 | Thompson sampling uses the Beta distribution, which takes two parameters, ‘α’ (alpha) and ‘β’ (beta). 116 | In the simplest terms these parameters can be thought of as respectively the count of successes and failures.
117 | see: https://towardsdatascience.com/thompson-sampling-fc28817eacb8 118 | """ 119 | probs = [] 120 | for v in self.variant_metrics: 121 | success = v["reward_sum"] 122 | failure = v["invocation_count"] - success 123 | probs.append(AlgorithmBase.random_beta(1 + success, 1 + failure)) 124 | variant_index = AlgorithmBase.argmax(probs) 125 | return self.variant_metrics[variant_index]["variant_name"] 126 | -------------------------------------------------------------------------------- /lambda/api/experiment_assignment.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from datetime import datetime, timedelta 3 | 4 | 5 | def get_ttl(days=90): 6 | return int((datetime.utcnow() + timedelta(days=days)).timestamp()) 7 | 8 | 9 | class ExperimentAssignment: 10 | """ 11 | Class for managing experiments 12 | """ 13 | 14 | def __init__( 15 | self, 16 | assignment_table: str, 17 | ): 18 | self.assignment_table = assignment_table 19 | self.dynamodb = boto3.resource("dynamodb") 20 | self.ddb_client = boto3.client("dynamodb") 21 | 22 | def get_assignment(self, user_id: str, endpoint_name: str): 23 | table = self.dynamodb.Table(self.assignment_table) 24 | response = table.get_item( 25 | Key={ 26 | "user_id": user_id, 27 | "endpoint_name": endpoint_name, 28 | }, 29 | AttributesToGet=["variant_name"], 30 | ) 31 | if "Item" in response: 32 | return response["Item"]["variant_name"] 33 | return None 34 | 35 | def put_assignment( 36 | self, user_id: str, endpoint_name: str, variant_name: str, ttl=get_ttl() 37 | ): 38 | """ 39 | Put the user endpoint variant with a time to live 40 | """ 41 | table = self.dynamodb.Table(self.assignment_table) 42 | response = table.put_item( 43 | Item={ 44 | "user_id": user_id, 45 | "endpoint_name": endpoint_name, 46 | "variant_name": variant_name, 47 | "ttl": ttl, 48 | } 49 | ) 50 | return response 51 | -------------------------------------------------------------------------------- /lambda/api/experiment_metrics.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from decimal import Decimal 3 | from itertools import groupby 4 | import json 5 | import logging 6 | from time import time 7 | from datetime import datetime 8 | 9 | 10 | class ExperimentMetrics: 11 | """ 12 | Class for getting and updating experiment metrics 13 | """ 14 | 15 | def __init__( 16 | self, metrics_table: str, delivery_stream_name: str, synchronous: bool = False 17 | ): 18 | self.metrics_table = metrics_table 19 | self.delivery_stream_name = delivery_stream_name 20 | self.synchronous = synchronous 21 | self.dynamodb = boto3.resource("dynamodb") 22 | self.ddb_client = boto3.client("dynamodb") 23 | self.firehose = boto3.client("firehose") 24 | self.cloudwatch = boto3.client("cloudwatch") 25 | 26 | def create_variant_metrics( 27 | self, 28 | endpoint_name: str, 29 | endpoint_variants: list, 30 | strategy: str, 31 | epsilon: float, 32 | warmup: int, 33 | timestamp: int = int(time()), 34 | ): 35 | logging.debug(f"Get metrics for endpoint: {endpoint_name}") 36 | table = self.dynamodb.Table(self.metrics_table) 37 | # Format variants as a dictionary for persistence with only the initial weight 38 | variant_names = [v["variant_name"] for v in endpoint_variants] 39 | variant_metrics = dict( 40 | [ 41 | ( 42 | v["variant_name"], 43 | { 44 | "initial_variant_weight": Decimal( 45 | str(v["initial_variant_weight"]) 46 | ) 47 | }, 48 | ) 49 | for v in endpoint_variants 50 | ] 51 | ) 52 | 
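# Illustrative shape of variant_metrics at this point (variant names are placeholders):
#   {"variantA": {"initial_variant_weight": Decimal("0.5")},
#    "variantB": {"initial_variant_weight": Decimal("0.5")}}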
logging.debug(variant_metrics) 53 | response = table.put_item( 54 | Item={ 55 | "endpoint_name": endpoint_name, 56 | "strategy": strategy, 57 | "variant_names": variant_names, 58 | "variant_metrics": variant_metrics, 59 | "epsilon": Decimal(str(epsilon)), 60 | "warmup": warmup, 61 | "created_at": timestamp, 62 | }, 63 | ReturnValues="ALL_OLD", 64 | ReturnConsumedCapacity="TOTAL", 65 | ) 66 | return response 67 | 68 | def delete_endpoint( 69 | self, 70 | endpoint_name: str, 71 | timestamp: int = int(time()), 72 | ): 73 | logging.debug(f"Delete endpoint: {endpoint_name}") 74 | table = self.dynamodb.Table(self.metrics_table) 75 | # Set the deleted_at property in DDB for this endpoint 76 | response = table.update_item( 77 | Key={"endpoint_name": endpoint_name}, 78 | UpdateExpression="SET deleted_at = :now ", 79 | ExpressionAttributeValues={ 80 | ":now": timestamp, 81 | }, 82 | ReturnValues="UPDATED_NEW", 83 | ) 84 | return response 85 | 86 | def get_variant_metrics(self, endpoint_name): 87 | """ 88 | Return the strategy and the list of variants, with the counts defaulted to zero if they do not exist 89 | """ 90 | table = self.dynamodb.Table(self.metrics_table) 91 | response = table.get_item( 92 | Key={ 93 | "endpoint_name": endpoint_name, 94 | }, 95 | ReturnConsumedCapacity="TOTAL", 96 | ) 97 | # Return the list of invocation and success counts per variant 98 | if "Item" not in response: 99 | raise Exception(f"Endpoint {endpoint_name} not found") 100 | 101 | strategy = response["Item"]["strategy"] 102 | epsilon = float(response["Item"]["epsilon"]) 103 | warmup = int(response["Item"]["warmup"]) 104 | variant_names = response["Item"]["variant_names"] 105 | variant_metrics = response["Item"]["variant_metrics"] 106 | metrics = [ 107 | { 108 | "endpoint_name": endpoint_name, 109 | "variant_name": v, 110 | "initial_variant_weight": float( 111 | variant_metrics[v]["initial_variant_weight"] 112 | ), 113 | "invocation_count": int(variant_metrics[v].get("invocation_count", 0)), 114 | "conversion_count": int(variant_metrics[v].get("conversion_count", 0)), 115 | "reward_sum": float(variant_metrics[v].get("reward_sum", 0)), 116 | } 117 | for v in variant_names 118 | ] 119 | return strategy, epsilon, warmup, metrics 120 | 121 | def put_cloudwatch_metric( 122 | self, 123 | metric_name: str, 124 | endpoint_name: str, 125 | variant_name: str, 126 | metric_value: float, 127 | dt: datetime = datetime.now(), 128 | ): 129 | logging.debug( 130 | f"Putting metric: {metric_value} for {metric_name} on endpoint: {endpoint_name}, variant: {variant_name} at {dt}" 131 | ) 132 | response = self.cloudwatch.put_metric_data( 133 | Namespace="aws/sagemaker/Endpoints/ab-testing", # Use a sub-namespace under SageMaker endpoints 134 | MetricData=[ 135 | { 136 | "MetricName": metric_name, 137 | "Dimensions": [ 138 | {"Name": "EndpointName", "Value": endpoint_name}, 139 | { 140 | "Name": "VariantName", 141 | "Value": variant_name, 142 | }, 143 | ], 144 | "Timestamp": dt, 145 | "Value": metric_value, 146 | "Unit": "Count", 147 | }, 148 | ], 149 | ) 150 | logging.debug(response) 151 | 152 | def update_variant_metrics(self, metrics: list, timestamp=int(time())): 153 | """ 154 | Group by endpoint variants and metric type to increment dynamodb counts 155 | """ 156 | table = self.dynamodb.Table(self.metrics_table) 157 | 158 | # Sort the list by endpoint_name and variant_name first so that groupby collects all records for each key together 159 | metrics = sorted( 160 | metrics, key=lambda m: (m["endpoint_name"], m["endpoint_variant"]) 161 | ) 162 | 163 | responses = [] 164 | 
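# Each metric record consumed by the loop below is expected to carry at least:
#   {"endpoint_name": ..., "endpoint_variant": ..., "type": "invocation" or "conversion", "reward": <number, conversions only>}
# (key names are taken from the lookups below; the values shown are placeholders)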
for (endpoint_name, variant_name), vg in groupby( 165 | metrics, lambda m: (m["endpoint_name"], m["endpoint_variant"]) 166 | ): 167 | # Get the total invocation and rewards 168 | invocation_count = 0 169 | conversion_count = 0 170 | reward_sum = 0.0 171 | for m in vg: 172 | if m["type"] == "invocation": 173 | invocation_count += 1 174 | elif m["type"] == "conversion": 175 | conversion_count += 1 176 | reward_sum += m["reward"] 177 | else: 178 | raise Exception("Unsupported type {}".format(m["type"])) 179 | logging.debug( 180 | f"Update metrics for endpoint: {endpoint_name}, variant: {variant_name} invocations: {invocation_count}, conversions: {conversion_count}, rewards: {reward_sum}" 181 | ) 182 | # Update variant in dynamo db with these counts 183 | response = table.update_item( 184 | Key={"endpoint_name": endpoint_name}, 185 | UpdateExpression="ADD variant_metrics.#variant.invocation_count :i, " 186 | "variant_metrics.#variant.conversion_count :c, " 187 | "variant_metrics.#variant.reward_sum :r " 188 | "SET #created_at = if_not_exists(#created_at, :now), #updated_at = :now ", 189 | ExpressionAttributeNames={ 190 | "#variant": variant_name, 191 | "#created_at": "created_at", 192 | "#updated_at": "updated_at", 193 | }, 194 | ExpressionAttributeValues={ 195 | ":i": int(invocation_count), 196 | ":c": int(conversion_count), 197 | ":r": Decimal(str(reward_sum)), 198 | ":now": timestamp, 199 | }, 200 | ReturnValues="UPDATED_NEW", 201 | ) 202 | 203 | # Return total counts per endpoint_name and endpoint_variant 204 | logging.debug(response) 205 | metrics = response["Attributes"]["variant_metrics"][variant_name] 206 | new_counts = { 207 | "endpoint_name": endpoint_name, 208 | "endpoint_variant": variant_name, 209 | "invocation_count": metrics.get("invocation_count", 0), 210 | "conversion_count": metrics.get("conversion_count", 0), 211 | "reward_sum": metrics.get("reward_sum", 0.0), 212 | } 213 | responses.append(new_counts) 214 | 215 | # Put cloudwatch metrics against this timestamp 216 | dt = datetime.fromtimestamp(timestamp) 217 | if invocation_count > 0: 218 | self.put_cloudwatch_metric( 219 | "Invocations", endpoint_name, variant_name, invocation_count, dt 220 | ) 221 | if conversion_count > 0: 222 | self.put_cloudwatch_metric( 223 | "Conversions", endpoint_name, variant_name, conversion_count, dt 224 | ) 225 | self.put_cloudwatch_metric( 226 | "Rewards", endpoint_name, variant_name, reward_sum, dt 227 | ) 228 | 229 | return responses 230 | 231 | def log_metrics(self, metrics): 232 | # Update metrics directly in DDB if required. 
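# (when synchronous is set the counts are written straight to DynamoDB; otherwise the
#  records are buffered through the Kinesis Firehose delivery stream further below)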
233 | if self.synchronous: 234 | return self.update_variant_metrics(metrics) 235 | 236 | # Dump the results as a json lines with trailing new line 237 | event_log = "\n".join([json.dumps(metric) for metric in metrics]) + "\n" 238 | logging.debug("Log kinesis events") 239 | logging.debug(event_log) 240 | 241 | # Put to delivery stream 242 | return self.firehose.put_record( 243 | DeliveryStreamName=self.delivery_stream_name, 244 | Record={"Data": event_log.encode("utf-8")}, 245 | ) 246 | -------------------------------------------------------------------------------- /lambda/api/lambda_invoke.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore.exceptions import ClientError 3 | import json 4 | import os 5 | import time 6 | import uuid 7 | import logging 8 | from aws_xray_sdk.core import xray_recorder 9 | from aws_xray_sdk.core import patch_all 10 | 11 | from experiment_metrics import ExperimentMetrics 12 | from experiment_assignment import ExperimentAssignment 13 | from algorithm import ThompsonSampling, EpsilonGreedy, UCB1, WeightedSampling 14 | 15 | # Get environment variables 16 | ASSIGNMENT_TABLE = os.environ["ASSIGNMENT_TABLE"] 17 | METRICS_TABLE = os.environ["METRICS_TABLE"] 18 | DELIVERY_STREAM_NAME = os.environ["DELIVERY_STREAM_NAME"] 19 | DELIVERY_SYNC = os.getenv("DELIVERY_SYNC", "False").lower() == "true" 20 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() 21 | 22 | # Configure logging and patch xray 23 | logger = logging.getLogger() 24 | logger.setLevel(LOG_LEVEL) 25 | patch_all() 26 | 27 | # Create the experiment classes from the lambda layer 28 | exp_assignment = ExperimentAssignment(ASSIGNMENT_TABLE) 29 | exp_metrics = ExperimentMetrics(METRICS_TABLE, DELIVERY_STREAM_NAME, DELIVERY_SYNC) 30 | 31 | # Log the boto version (Require 1.17.5 for InferenceId target) 32 | logger.info(f"boto version: {boto3.__version__}") 33 | 34 | # Define he boto3 client resources 35 | sm_runtime = boto3.client("sagemaker-runtime") 36 | sm_client = boto3.client("sagemaker") 37 | lambda_client = boto3.client("lambda") 38 | 39 | 40 | @xray_recorder.capture("Get User Variant") 41 | def get_user_variant(endpoint_name: str, user_id: str): 42 | # Get the variants metrics (this will fail if endpoint doesn't exist) 43 | strategy, epsilon, warmup, variant_metrics = exp_metrics.get_variant_metrics( 44 | endpoint_name 45 | ) 46 | 47 | # Get the configuration for the endpoint name 48 | logger.info(f"Getting variant for user: {user_id}") 49 | user_variant = exp_assignment.get_assignment( 50 | user_id=user_id, endpoint_name=endpoint_name 51 | ) 52 | 53 | # Ensure that our user variant is still in current metrics 54 | target_variant = user_variant 55 | if user_variant is not None: 56 | user_match = [v for v in variant_metrics if v["variant_name"] == user_variant] 57 | if len(user_match) == 0: 58 | logger.info(f"User variant {user_variant} not in endpoint variants") 59 | target_variant = None 60 | 61 | # Get the new target variant if not assigned 62 | status_code = 200 63 | if target_variant is None: 64 | # See if all variants have invocation metrics 65 | with_invocations = [ 66 | v for v in variant_metrics if v["invocation_count"] > warmup 67 | ] 68 | if len(with_invocations) < len(variant_metrics): 69 | strategy = WeightedSampling.STRATEGY_NAME 70 | algo = WeightedSampling(variant_metrics) 71 | elif strategy == WeightedSampling.STRATEGY_NAME: 72 | algo = WeightedSampling(variant_metrics) 73 | elif strategy == ThompsonSampling.STRATEGY_NAME: 74 
| algo = ThompsonSampling(variant_metrics) 75 | elif strategy == EpsilonGreedy.STRATEGY_NAME: 76 | algo = EpsilonGreedy(variant_metrics, epsilon) 77 | elif strategy == UCB1.STRATEGY_NAME: 78 | algo = UCB1(variant_metrics) 79 | else: 80 | raise Exception(f"Strategy {strategy} not supported") 81 | target_variant = algo.select_variant() 82 | status_code = 201 83 | 84 | # Assign the target variant to the user 85 | if user_variant != target_variant: 86 | logger.info(f"Set target variant: {target_variant} for user: {user_id}") 87 | exp_assignment.put_assignment( 88 | user_id=user_id, endpoint_name=endpoint_name, variant_name=target_variant 89 | ) 90 | 91 | # Return the result 92 | return strategy, target_variant, status_code 93 | 94 | 95 | @xray_recorder.capture("Stats") 96 | def handle_stats(endpoint_name: str): 97 | # Get the variant metrics (this will fail if the endpoint doesn't exist) 98 | strategy, epsilon, warmup, variant_metrics = exp_metrics.get_variant_metrics( 99 | endpoint_name 100 | ) 101 | result = { 102 | "endpoint_name": endpoint_name, 103 | "variant_metrics": variant_metrics, 104 | "strategy": strategy, 105 | "epsilon": epsilon, 106 | "warmup": warmup, 107 | } 108 | return result, 200 109 | 110 | 111 | @xray_recorder.capture("Invocation") 112 | def handle_invocation( 113 | strategy: str, 114 | endpoint_name: str, 115 | content_type: str, 116 | inference_id: str, 117 | user_id: str, 118 | target_variant: str, 119 | data, 120 | ): 121 | # InferenceId is not available in 1.16.31, which is the default boto3 version in Lambda 122 | # https://boto3.amazonaws.com/v1/documentation/api/1.16.31/reference/services/sagemaker-runtime.html#SageMakerRuntime.Client.invoke_endpoint 123 | if target_variant is None: 124 | logger.warning("Invoking endpoint without target variant") 125 | response = sm_runtime.invoke_endpoint( 126 | EndpointName=endpoint_name, 127 | ContentType=content_type, 128 | Body=data, 129 | InferenceId=inference_id, 130 | ) 131 | else: 132 | logger.info(f"Invoke endpoint with target variant: {target_variant}") 133 | response = sm_runtime.invoke_endpoint( 134 | EndpointName=endpoint_name, 135 | ContentType=content_type, 136 | TargetVariant=target_variant, 137 | Body=data, 138 | InferenceId=inference_id, 139 | ) 140 | invoked_variant = response["InvokedProductionVariant"] 141 | 142 | return { 143 | "strategy": strategy, 144 | "endpoint_name": endpoint_name, 145 | "target_variant": target_variant, 146 | "endpoint_variant": invoked_variant, 147 | "inference_id": inference_id, 148 | "user_id": user_id, 149 | "predictions": json.loads(response["Body"].read()), 150 | } 151 | 152 | 153 | @xray_recorder.capture("Conversion") 154 | def handle_conversion( 155 | strategy: str, 156 | endpoint_name: str, 157 | inference_id: str, 158 | user_id: str, 159 | user_variant: str, 160 | reward: float, 161 | ): 162 | return { 163 | "strategy": strategy, 164 | "endpoint_name": endpoint_name, 165 | "endpoint_variant": user_variant, 166 | "inference_id": inference_id, 167 | "user_id": user_id, 168 | "reward": reward, 169 | } 170 | 171 | 172 | @xray_recorder.capture("Log Metric") 173 | def log_metric( 174 | event_type: str, 175 | body: dict, 176 | request_identity: dict, 177 | ): 178 | # Merge all properties together into a flat dictionary 179 | metrics = [ 180 | {"timestamp": int(time.time()), "type": event_type, **body, **request_identity} 181 | ] 182 | try: 183 | response = exp_metrics.log_metrics(metrics) 184 | logger.debug("Log metric response") 185 | logger.debug(response) 186 | except Exception as e:
187 | # Log a warning that we were unable to log metrics 188 | logger.warning("Unable to log metrics") 189 | logger.warning(e) 190 | 191 | 192 | def lambda_handler(event, context): 193 | try: 194 | logger.debug(json.dumps(event)) 195 | 196 | # Get elements from API payload 197 | if event["httpMethod"] in ["POST", "PUT"] and "body" in event: 198 | body = json.loads(event["body"]) 199 | else: 200 | raise Exception("Require HTTP POST with json body") 201 | 202 | endpoint_name = body.get("endpoint_name") 203 | if endpoint_name is None: 204 | raise Exception("Require endpoint name in body") 205 | 206 | # Optionally allow overriding the endpoint variant 207 | endpoint_variant = body.get("endpoint_variant") 208 | 209 | # Route the request based on the path 210 | path = event["path"] 211 | if path == "/stats": 212 | # Get stats for existing endpoint 213 | result, status_code = handle_stats(endpoint_name) 214 | else: 215 | # Get inference id and user id from request, or generate new ones 216 | inference_id = body.get("inference_id", str(uuid.uuid4())) 217 | user_id = str(body.get("user_id", uuid.uuid4())) 218 | 219 | if endpoint_variant is None: 220 | try: 221 | # Get the configuration for the endpoint name 222 | strategy, user_variant, status_code = get_user_variant( 223 | endpoint_name, user_id 224 | ) 225 | except Exception as e: 226 | # Log warning and return fallback strategy 227 | logger.warning("Unable to get user variant") 228 | logger.warning(e) 229 | strategy, user_variant, status_code = ("Fallback", None, 202) 230 | else: 231 | # Log the manual strategy for the endpoint variant 232 | logger.info( 233 | f"Manual override endpoint: {endpoint_name} variant: {endpoint_variant}" 234 | ) 235 | strategy, user_variant, status_code = ("Manual", endpoint_variant, 202) 236 | 237 | # Get request identity fields that are non-null (e.g. sourceIp, userAgent) 238 | request_identity = { 239 | "source_ip": event["requestContext"]["identity"]["sourceIp"], 240 | "user_agent": event["requestContext"]["identity"]["userAgent"], 241 | } 242 | 243 | # Handle an invocation or conversion based on the path 244 | if path == "/invocation": 245 | content_type = body.get("content_type", "application/json") 246 | data = body["data"] 247 | result = handle_invocation( 248 | strategy=strategy, 249 | endpoint_name=endpoint_name, 250 | content_type=content_type, 251 | inference_id=inference_id, 252 | user_id=user_id, 253 | target_variant=user_variant, 254 | data=data, 255 | ) 256 | log_metric("invocation", result, request_identity) 257 | elif path == "/conversion": 258 | # Get default reward of "1" unless provided 259 | reward = float(body.get("reward", "1")) 260 | result = handle_conversion( 261 | strategy=strategy, 262 | endpoint_name=endpoint_name, 263 | inference_id=inference_id, 264 | user_id=user_id, 265 | user_variant=user_variant, 266 | reward=reward, 267 | ) 268 | log_metric("conversion", result, request_identity) 269 | else: 270 | raise Exception(f"Invalid path: {path}") 271 | 272 | # Log the successful result and return 273 | logger.debug(json.dumps(result)) 274 | return {"statusCode": status_code, "body": json.dumps(result)} 275 | except ClientError as e: 276 | logger.error(e) 277 | # Get boto3 specific error message 278 | error_message = e.response["Error"]["Message"] 279 | raise Exception(error_message) 280 | except Exception as e: 281 | logger.error(e) 282 | raise e 283 | -------------------------------------------------------------------------------- /lambda/api/lambda_metrics.py: 
-------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore.exceptions import ClientError 3 | import gzip 4 | import io 5 | import json 6 | import os 7 | import logging 8 | from aws_xray_sdk.core import xray_recorder 9 | from aws_xray_sdk.core import patch_all 10 | from urllib.parse import unquote_plus 11 | 12 | from experiment_metrics import ExperimentMetrics 13 | 14 | # Get environment variables 15 | METRICS_TABLE = os.environ["METRICS_TABLE"] 16 | DELIVERY_STREAM_NAME = os.environ["DELIVERY_STREAM_NAME"] 17 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() 18 | 19 | # Create the experiment classes from the lambda layer 20 | exp_metrics = ExperimentMetrics(METRICS_TABLE, DELIVERY_STREAM_NAME) 21 | 22 | # Configure logging and patch xray 23 | logger = logging.getLogger() 24 | logger.setLevel(LOG_LEVEL) 25 | patch_all() 26 | 27 | # Define the boto3 client resources 28 | dynamodb = boto3.resource("dynamodb") 29 | s3 = boto3.resource("s3") 30 | 31 | 32 | @xray_recorder.capture("Read Metrics") 33 | def get_metrics(event): 34 | """ 35 | Download the S3 file contents, and enumerate the JSON lines to extract metrics 36 | """ 37 | metrics = [] 38 | for record in event["Records"]: 39 | bucket = record["s3"]["bucket"]["name"] 40 | key = unquote_plus(record["s3"]["object"]["key"]) 41 | obj = s3.Object(bucket, key) 42 | with gzip.GzipFile(fileobj=obj.get()["Body"]) as gzipfile: 43 | content = gzipfile.read() 44 | buf = io.BytesIO(content) 45 | line = buf.readline() 46 | while line: 47 | metrics.append(json.loads(line)) 48 | line = buf.readline() 49 | return metrics 50 | 51 | 52 | @xray_recorder.capture("Write Metrics") 53 | def update_metrics(metrics): 54 | # TODO: Consider filtering metrics for high frequency sourceIp or bad user agent 55 | exp_metrics.update_variant_metrics(metrics) 56 | 57 | 58 | def lambda_handler(event, context): 59 | try: 60 | logger.debug(json.dumps(event)) 61 | 62 | # Get metrics from s3 json lines 63 | metrics = [] 64 | if "Records" in event: 65 | metrics = get_metrics(event) 66 | elif "Metrics" in event: 67 | metrics = event["Metrics"] 68 | 69 | update_metrics(metrics) 70 | 71 | # TODO: Consider correlating ground truth metrics against user_id invocations/clicks 72 | # see: https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-model-quality-merge.html 73 | # see also: https://github.com/aws/amazon-sagemaker-examples/blob/master/sagemaker_model_monitor/model_quality/model_quality_churn_sdk.ipynb 74 | 75 | # Log the metrics count 76 | result = { 77 | "metric_count": len(metrics), 78 | } 79 | return {"statusCode": 200, "body": json.dumps(result)} 80 | except ClientError as e: 81 | # Get boto3 specific error message 82 | error_message = e.response["Error"]["Message"] 83 | logger.error(error_message) 84 | raise Exception(error_message) 85 | except Exception as e: 86 | logger.error(e) 87 | raise e 88 | -------------------------------------------------------------------------------- /lambda/api/lambda_register.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore.exceptions import ClientError 3 | import json 4 | import logging 5 | import os 6 | from aws_xray_sdk.core import xray_recorder 7 | from aws_xray_sdk.core import patch_all 8 | 9 | from experiment_metrics import ExperimentMetrics 10 | from algorithm import ThompsonSampling 11 | 12 | # Get environment variables 13 | METRICS_TABLE = os.environ["METRICS_TABLE"] 14 | 
DELIVERY_STREAM_NAME = os.environ["DELIVERY_STREAM_NAME"] 15 | STAGE_NAME = os.environ["STAGE_NAME"] 16 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() 17 | ENDPOINT_PREFIX = os.getenv("ENDPOINT_PREFIX", "") 18 | 19 | # Configure logging and patch xray 20 | logger = logging.getLogger() 21 | logger.setLevel(LOG_LEVEL) 22 | patch_all() 23 | 24 | # Create the experiment classes from the lambda layer 25 | exp_metrics = ExperimentMetrics(METRICS_TABLE, DELIVERY_STREAM_NAME) 26 | 27 | # Configure logging and patch xray 28 | logger = logging.getLogger() 29 | logger.setLevel(LOG_LEVEL) 30 | patch_all() 31 | 32 | # Define the boto3 client resources 33 | sm_client = boto3.client("sagemaker") 34 | 35 | 36 | @xray_recorder.capture("Get Endpoint Variants") 37 | def get_endpoint_variants(endpoint_name): 38 | """ 39 | Get the list of production variant names for an endpoint 40 | """ 41 | logger.info(f"Getting variants for endpoint: {endpoint_name}") 42 | response = sm_client.describe_endpoint(EndpointName=endpoint_name) 43 | endpoint_variants = [ 44 | { 45 | "variant_name": r["VariantName"], 46 | "initial_variant_weight": r["CurrentWeight"], 47 | } 48 | for r in response["ProductionVariants"] 49 | ] 50 | logger.debug(endpoint_variants) 51 | return endpoint_variants 52 | 53 | 54 | @xray_recorder.capture("Delete") 55 | def handle_delete(endpoint_name: str): 56 | response = exp_metrics.delete_endpoint( 57 | endpoint_name=endpoint_name, 58 | ) 59 | result = { 60 | "endpoint_name": endpoint_name, 61 | } 62 | return result, 200 63 | 64 | 65 | @xray_recorder.capture("Register") 66 | def handle_register(endpoint_name: str, strategy: str, epsilon: float, warmup: int): 67 | endpoint_variants = get_endpoint_variants(endpoint_name) 68 | response = exp_metrics.create_variant_metrics( 69 | endpoint_name=endpoint_name, 70 | endpoint_variants=endpoint_variants, 71 | strategy=strategy, 72 | epsilon=epsilon, 73 | warmup=warmup, 74 | ) 75 | result = { 76 | "endpoint_name": endpoint_name, 77 | "endpoint_variants": endpoint_variants, 78 | "strategy": strategy, 79 | "epsilon": epsilon, 80 | "warmup": warmup, 81 | } 82 | if "Attributes" not in response: 83 | return result, 201 84 | return result, 200 85 | 86 | 87 | def lambda_handler(event, context): 88 | try: 89 | logger.debug(json.dumps(event)) 90 | 91 | if not ( 92 | event.get("source") == "aws.sagemaker" 93 | and event.get("detail-type") == "SageMaker Endpoint State Change" 94 | ): 95 | raise Exception( 96 | "Expect CloudWatch Event for SageMaker Endpoint State Change" 97 | ) 98 | 99 | # If this endpoint does not match the prefix or is not enabled, return Not Modified (304) 100 | endpoint_name = event["detail"]["EndpointName"] 101 | endpoint_tags = event["detail"]["Tags"] 102 | endpoint_enabled = endpoint_tags.get("ab-testing:enabled", "").lower() == "true" 103 | if not (endpoint_name.startswith(ENDPOINT_PREFIX) and endpoint_enabled): 104 | error_message = ( 105 | f"Endpoint: {endpoint_name} not enabled for prefix: {ENDPOINT_PREFIX}" 106 | ) 107 | logger.warning(error_message) 108 | return {"statusCode": 304, "body": error_message} 109 | 110 | # If the API stage name doesn't match the deployment stage name return Not Modified (304) 111 | deployment_stage = endpoint_tags.get("sagemaker:deployment-stage") 112 | if deployment_stage != STAGE_NAME: 113 | error_message = f"Endpoint: {endpoint_name} deployment stage: {deployment_stage} not equal to API stage: {STAGE_NAME}" 114 | logger.warning(error_message) 115 | return {"statusCode": 304, "body": error_message} 116 | 117 | # Delete
or register the endpoint depending on status change 118 | endpoint_status = event["detail"]["EndpointStatus"] 119 | if endpoint_status == "DELETING": 120 | logger.info(f"Deleting Endpoint: {endpoint_name}") 121 | result, status_code = handle_delete(endpoint_name) 122 | elif endpoint_status == "IN_SERVICE": 123 | # Use defaults if enabled is provided without additional arguments 124 | strategy = endpoint_tags.get("ab-testing:strategy", "ThompsonSampling") 125 | epsilon = float(endpoint_tags.get("ab-testing:epsilon", 0.1)) 126 | warmup = int(endpoint_tags.get("ab-testing:warmup", 0)) 127 | logger.info( 128 | f"Registering Endpoint: {endpoint_name} with strategy: {strategy}, epsilon: {epsilon}, warmup: {warmup}" 129 | ) 130 | result, status_code = handle_register( 131 | endpoint_name, strategy, epsilon, warmup 132 | ) 133 | else: 134 | error_message = ( 135 | f"Endpoint: {endpoint_name} Status: {endpoint_status} not supported." 136 | ) 137 | logger.warning(error_message) 138 | result = {"message": error_message} 139 | status_code = 400 140 | 141 | # Log the successful result and return 142 | logger.debug(json.dumps(result)) 143 | return {"statusCode": status_code, "body": json.dumps(result)} 144 | except ClientError as e: 145 | logger.error(e) 146 | # Get boto3 specific error message 147 | error_message = e.response["Error"]["Message"] 148 | logger.error(error_message) 149 | raise Exception(error_message) 150 | except Exception as e: 151 | logger.error(e) 152 | raise e 153 | -------------------------------------------------------------------------------- /lambda/api/test_algorithm.py: -------------------------------------------------------------------------------- 1 | from algorithm import EpsilonGreedy, UCB1, ThompsonSampling, WeightedSampling 2 | 3 | 4 | def test_epsilon_greedy(): 5 | algo = EpsilonGreedy( 6 | [ 7 | { 8 | "variant_name": "v1", 9 | "invocation_count": 10, 10 | "reward_sum": 1, 11 | }, 12 | { 13 | "variant_name": "v2", 14 | "invocation_count": 10, 15 | "reward_sum": 2, 16 | }, 17 | ], 18 | epsilon=0.1, 19 | ) 20 | 21 | # Validate that at least ~90% of the time we pick v2 22 | lst = [algo.select_variant() for i in range(100)] 23 | v1_count = lst.count("v1") 24 | v2_count = lst.count("v2") 25 | # Assert with a margin of error for randomness 26 | assert v1_count < 20 27 | assert v2_count > 80 28 | 29 | 30 | def test_UCB1_exploit(): 31 | algo = UCB1( 32 | [ 33 | { 34 | "variant_name": "v1", 35 | "invocation_count": 100, 36 | "reward_sum": 10, 37 | }, 38 | { 39 | "variant_name": "v2", 40 | "invocation_count": 100, 41 | "reward_sum": 20, 42 | }, 43 | { 44 | "variant_name": "v3", 45 | "invocation_count": 100, 46 | "reward_sum": 50, 47 | }, 48 | ] 49 | ) 50 | # For equally well-explored variants, validate that we pick the best performing one 51 | v = algo.select_variant() 52 | assert v == "v3" 53 | 54 | 55 | def test_UCB1_explore(): 56 | algo = UCB1( 57 | [ 58 | { 59 | "variant_name": "v1", 60 | "invocation_count": 10, 61 | "reward_sum": 1, 62 | }, 63 | { 64 | "variant_name": "v2", 65 | "invocation_count": 10, 66 | "reward_sum": 2, 67 | }, 68 | { 69 | "variant_name": "v3", 70 | "invocation_count": 100, 71 | "reward_sum": 50, 72 | }, 73 | ] 74 | ) 75 | # The curiosity bonus boosts the under-explored variants; validate we pick the best of those 76 | v = algo.select_variant() 77 | assert v == "v2" 78 | 79 | 80 | def test_thompson_sampling(): 81 | algo = ThompsonSampling( 82 | [ 83 | { 84 | "variant_name": "v1", 85 | "invocation_count": 10, 86 | "reward_sum": 1, 87 | }, 88 | { 89 | "variant_name": "v2", 90 | "invocation_count": 10, 91 | "reward_sum": 2, 92 | }, 93
| { 94 | "variant_name": "v3", 95 | "invocation_count": 10, 96 | "reward_sum": 5, 97 | }, 98 | ] 99 | ) 100 | 101 | lst = [algo.select_variant() for i in range(100)] 102 | assert max(lst, key=lst.count) == "v3" 103 | 104 | 105 | def test_weighted_sampling(): 106 | algo = WeightedSampling( 107 | [ 108 | {"variant_name": "v1", "initial_variant_weight": 0.9}, 109 | {"variant_name": "v2", "initial_variant_weight": 0.1}, 110 | ], 111 | ) 112 | 113 | lst = [algo.select_variant() for i in range(100)] 114 | assert max(lst, key=lst.count) == "v1" 115 | -------------------------------------------------------------------------------- /lambda/api/test_experiment_assignment.py: -------------------------------------------------------------------------------- 1 | from botocore.stub import Stubber 2 | 3 | from experiment_assignment import ExperimentAssignment 4 | 5 | 6 | def test_get_assignment(): 7 | # Create a new assignment object 8 | exp_assignment = ExperimentAssignment("test-ass") 9 | 10 | # See the dynamodb get_item 11 | # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.get_item 12 | with Stubber(exp_assignment.dynamodb.meta.client) as stubber: 13 | expected_response = { 14 | "Item": { 15 | "variant_name": {"S": "e1v1"}, 16 | } 17 | } 18 | expected_params = { 19 | "AttributesToGet": ["variant_name"], 20 | "Key": {"endpoint_name": "test-endpoint", "user_id": "user-1"}, 21 | "TableName": "test-ass", 22 | } 23 | stubber.add_response("get_item", expected_response, expected_params) 24 | 25 | response = exp_assignment.get_assignment( 26 | user_id="user-1", endpoint_name="test-endpoint" 27 | ) 28 | assert response == "e1v1" 29 | 30 | 31 | def test_put_assignment(): 32 | # Create a new assignment object 33 | exp_assignment = ExperimentAssignment("test-ass") 34 | 35 | # See the dynamodb put_item 36 | # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.put_item 37 | with Stubber(exp_assignment.dynamodb.meta.client) as stubber: 38 | expected_response = { 39 | "ConsumedCapacity": { 40 | "CapacityUnits": 1, 41 | "TableName": "test-ass", 42 | }, 43 | } 44 | expected_params = { 45 | "Item": { 46 | "endpoint_name": "test-endpoint", 47 | "ttl": 0, 48 | "user_id": "user-1", 49 | "variant_name": "e1v1", 50 | }, 51 | "TableName": "test-ass", 52 | } 53 | stubber.add_response("put_item", expected_response, expected_params) 54 | 55 | response = exp_assignment.put_assignment( 56 | user_id="user-1", endpoint_name="test-endpoint", variant_name="e1v1", ttl=0 57 | ) 58 | assert response == expected_response 59 | -------------------------------------------------------------------------------- /lambda/api/test_experiment_metrics.py: -------------------------------------------------------------------------------- 1 | from botocore.stub import Stubber 2 | from decimal import Decimal 3 | from datetime import datetime 4 | 5 | 6 | from experiment_metrics import ExperimentMetrics 7 | 8 | # 1 invocation for e1v1, 2 invocations for e1v2, and 1 conversion for e1v2 9 | good_metrics = [ 10 | { 11 | "timestamp": 1, 12 | "type": "invocation", 13 | "user_id": "a", 14 | "endpoint_name": "e1", 15 | "endpoint_variant": "e1v1", 16 | }, 17 | { 18 | "timestamp": 2, 19 | "type": "invocation", 20 | "user_id": "b", 21 | "endpoint_name": "e1", 22 | "endpoint_variant": "e1v2", 23 | }, 24 | { 25 | "timestamp": 3, 26 | "type": "invocation", 27 | "user_id": "c", 28 | "endpoint_name": "e1", 29 | "endpoint_variant": "e1v2", 30 | }, 31 | { 32
"timestamp": 4, 33 | "type": "conversion", 34 | "reward": 1, 35 | "user_id": "c", 36 | "endpoint_name": "e1", 37 | "endpoint_variant": "e1v2", 38 | }, 39 | ] 40 | 41 | 42 | def test_log_metrics(): 43 | # Create new metrics object and 44 | exp_metrics = ExperimentMetrics("test-metrics", "test-delivery-stream") 45 | 46 | # See the firehose put_record 47 | # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/firehose.html#Firehose.Client.put_record 48 | with Stubber(exp_metrics.firehose) as stubber: 49 | expected_response = {"RecordId": "xxxx", "Encrypted": True} 50 | expected_params = { 51 | "DeliveryStreamName": "test-delivery-stream", 52 | "Record": { 53 | "Data": b'{"timestamp": 1, "type": "invocation", "user_id": "a", "endpoint_name": "e1", "endpoint_variant": "e1v1"}\n' 54 | b'{"timestamp": 2, "type": "invocation", "user_id": "b", "endpoint_name": "e1", "endpoint_variant": "e1v2"}\n' 55 | }, 56 | } 57 | stubber.add_response("put_record", expected_response, expected_params) 58 | 59 | # Log first metric 60 | response = exp_metrics.log_metrics(good_metrics[:2]) 61 | assert response == expected_response 62 | 63 | 64 | def test_create_variant_metrics(): 65 | # Create new metrics object and 66 | exp_metrics = ExperimentMetrics("test-metrics", "test-delivery-stream") 67 | endpoint_variants = [ 68 | { 69 | "variant_name": "ev1", 70 | "initial_variant_weight": 1, 71 | }, 72 | { 73 | "variant_name": "ev2", 74 | "initial_variant_weight": 0.5, 75 | }, 76 | ] 77 | 78 | # See the dynamodb put_item 79 | # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.put_item 80 | with Stubber(exp_metrics.dynamodb.meta.client) as stubber: 81 | expected_response = { 82 | "ConsumedCapacity": { 83 | "CapacityUnits": 1, 84 | "TableName": "test-metrics", 85 | }, 86 | } 87 | 88 | expected_params = { 89 | "Item": { 90 | "created_at": 0, 91 | "endpoint_name": "test-endpoint", 92 | "strategy": "EpsilonGreedy", 93 | "epsilon": Decimal("0.1"), 94 | "warmup": Decimal("0"), 95 | "variant_names": ["ev1", "ev2"], 96 | "variant_metrics": { 97 | "ev1": {"initial_variant_weight": Decimal("1")}, 98 | "ev2": {"initial_variant_weight": Decimal("0.5")}, 99 | }, 100 | }, 101 | "ReturnConsumedCapacity": "TOTAL", 102 | "ReturnValues": "ALL_OLD", 103 | "TableName": "test-metrics", 104 | } 105 | stubber.add_response("put_item", expected_response, expected_params) 106 | 107 | response = exp_metrics.create_variant_metrics( 108 | endpoint_name="test-endpoint", 109 | strategy="EpsilonGreedy", 110 | epsilon=0.1, 111 | warmup=0, 112 | endpoint_variants=endpoint_variants, 113 | timestamp=0, 114 | ) 115 | assert response == expected_response 116 | 117 | 118 | def test_get_empty_variant_metrics(): 119 | # Create new metrics object and 120 | exp_metrics = ExperimentMetrics("test-metrics", "test-delivery-stream") 121 | 122 | # See the dynamodb get_item 123 | # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.get_item 124 | with Stubber(exp_metrics.dynamodb.meta.client) as stubber: 125 | expected_response = { 126 | "Item": { 127 | "endpoint_name": {"S": "test-endpoint"}, 128 | "strategy": {"S": "EpsilonGreedy"}, 129 | "epsilon": {"N": "0.1"}, 130 | "warmup": {"N": "0"}, 131 | "variant_names": {"L": [{"S": "ev1"}, {"S": "ev2"}]}, 132 | "variant_metrics": { 133 | "M": { 134 | "ev1": {"M": {"initial_variant_weight": {"N": "0.5"}}}, 135 | "ev2": { 136 | "M": { 137 | "initial_variant_weight": {"N": "0.1"}, 138 | 
"invocation_count": {"N": "10"}, 139 | "conversion_count": {"N": "1"}, 140 | "reward_sum": {"N": "0.5"}, 141 | } 142 | }, 143 | } 144 | }, 145 | } 146 | } 147 | expected_params = { 148 | "Key": {"endpoint_name": "test-endpoint"}, 149 | "TableName": "test-metrics", 150 | "ReturnConsumedCapacity": "TOTAL", 151 | } 152 | stubber.add_response("get_item", expected_response, expected_params) 153 | 154 | # Validate the transformed result 155 | expected_variants = [ 156 | { 157 | "endpoint_name": "test-endpoint", 158 | "variant_name": "ev1", 159 | "initial_variant_weight": 0.5, 160 | "invocation_count": 0, 161 | "conversion_count": 0, 162 | "reward_sum": 0, 163 | }, 164 | { 165 | "endpoint_name": "test-endpoint", 166 | "variant_name": "ev2", 167 | "initial_variant_weight": 0.1, 168 | "invocation_count": 10, 169 | "conversion_count": 1, 170 | "reward_sum": 0.5, 171 | }, 172 | ] 173 | strategy, epsilon, warmup, variants = exp_metrics.get_variant_metrics( 174 | "test-endpoint" 175 | ) 176 | assert strategy == "EpsilonGreedy" 177 | assert epsilon == 0.1 178 | assert warmup == 0 179 | assert variants == expected_variants 180 | 181 | 182 | def test_update_variant_metrics(): 183 | # Create new metrics object and 184 | exp_metrics = ExperimentMetrics("test-metrics", "test-delivery-stream") 185 | 186 | # See dynamodb update_item 187 | # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.update_item 188 | ddb_stubber = Stubber(exp_metrics.dynamodb.meta.client) 189 | cw_stubber = Stubber(exp_metrics.cloudwatch) 190 | 191 | # 1 invocation for ev1v1 192 | expected_response = { 193 | "Attributes": { 194 | "endpoint_name": { 195 | "S": "e1", 196 | }, 197 | "variant_metrics": { 198 | "M": { 199 | "e1v1": {"M": {"invocation_count": {"N": "1"}}}, 200 | } 201 | }, 202 | }, 203 | } 204 | expected_params = { 205 | "ExpressionAttributeNames": { 206 | "#created_at": "created_at", 207 | "#updated_at": "updated_at", 208 | "#variant": "e1v1", 209 | }, 210 | "ExpressionAttributeValues": { 211 | ":c": 0, 212 | ":i": 1, 213 | ":now": 0, 214 | ":r": Decimal("0.0"), 215 | }, 216 | "Key": {"endpoint_name": "e1"}, 217 | "ReturnValues": "UPDATED_NEW", 218 | "TableName": "test-metrics", 219 | "UpdateExpression": "ADD variant_metrics.#variant.invocation_count :i, " 220 | "variant_metrics.#variant.conversion_count :c, " 221 | "variant_metrics.#variant.reward_sum :r SET #created_at = " 222 | "if_not_exists(#created_at, :now), #updated_at = :now ", 223 | } 224 | ddb_stubber.add_response("update_item", expected_response, expected_params) 225 | 226 | # Add CW metrics for invocations 227 | expected_response = {} 228 | expected_params = { 229 | "MetricData": [ 230 | { 231 | "Dimensions": [ 232 | {"Name": "EndpointName", "Value": "e1"}, 233 | {"Name": "VariantName", "Value": "e1v1"}, 234 | ], 235 | "MetricName": "Invocations", 236 | "Timestamp": datetime(1970, 1, 1, 10, 0), 237 | "Unit": "Count", 238 | "Value": 1, 239 | } 240 | ], 241 | "Namespace": "aws/sagemaker/Endpoints/ab-testing", 242 | } 243 | cw_stubber.add_response("put_metric_data", expected_response, expected_params) 244 | 245 | # 2 invocation for ev1v2 246 | expected_response = { 247 | "Attributes": { 248 | "endpoint_name": { 249 | "S": "e1", 250 | }, 251 | "variant_metrics": { 252 | "M": { 253 | "e1v2": { 254 | "M": { 255 | "invocation_count": {"N": "2"}, 256 | "conversion_count": {"N": "1"}, 257 | "reward_sum": {"N": "1"}, 258 | } 259 | }, 260 | } 261 | }, 262 | }, 263 | } 264 | expected_params = { 265 | 
"ExpressionAttributeNames": { 266 | "#created_at": "created_at", 267 | "#updated_at": "updated_at", 268 | "#variant": "e1v2", 269 | }, 270 | "ExpressionAttributeValues": {":i": 2, ":c": 1, ":r": 1, ":now": 0}, 271 | "Key": {"endpoint_name": "e1"}, 272 | "ReturnValues": "UPDATED_NEW", 273 | "TableName": "test-metrics", 274 | "UpdateExpression": "ADD variant_metrics.#variant.invocation_count :i, " 275 | "variant_metrics.#variant.conversion_count :c, " 276 | "variant_metrics.#variant.reward_sum :r SET #created_at = " 277 | "if_not_exists(#created_at, :now), #updated_at = :now ", 278 | } 279 | ddb_stubber.add_response("update_item", expected_response, expected_params) 280 | 281 | # Add CW metrics for invocations/converisons/rewoards 282 | expected_response = {} 283 | expected_params = { 284 | "MetricData": [ 285 | { 286 | "Dimensions": [ 287 | {"Name": "EndpointName", "Value": "e1"}, 288 | {"Name": "VariantName", "Value": "e1v2"}, 289 | ], 290 | "MetricName": "Invocations", 291 | "Timestamp": datetime(1970, 1, 1, 10, 0), 292 | "Unit": "Count", 293 | "Value": 2, 294 | } 295 | ], 296 | "Namespace": "aws/sagemaker/Endpoints/ab-testing", 297 | } 298 | cw_stubber.add_response("put_metric_data", expected_response, expected_params) 299 | expected_params = { 300 | "MetricData": [ 301 | { 302 | "Dimensions": [ 303 | {"Name": "EndpointName", "Value": "e1"}, 304 | {"Name": "VariantName", "Value": "e1v2"}, 305 | ], 306 | "MetricName": "Conversions", 307 | "Timestamp": datetime(1970, 1, 1, 10, 0), 308 | "Unit": "Count", 309 | "Value": 1, 310 | } 311 | ], 312 | "Namespace": "aws/sagemaker/Endpoints/ab-testing", 313 | } 314 | cw_stubber.add_response("put_metric_data", expected_response, expected_params) 315 | expected_params = { 316 | "MetricData": [ 317 | { 318 | "Dimensions": [ 319 | {"Name": "EndpointName", "Value": "e1"}, 320 | {"Name": "VariantName", "Value": "e1v2"}, 321 | ], 322 | "MetricName": "Rewards", 323 | "Timestamp": datetime(1970, 1, 1, 10, 0), 324 | "Unit": "Count", 325 | "Value": 1, 326 | } 327 | ], 328 | "Namespace": "aws/sagemaker/Endpoints/ab-testing", 329 | } 330 | cw_stubber.add_response("put_metric_data", expected_response, expected_params) 331 | 332 | # Activate stubbers 333 | ddb_stubber.activate() 334 | cw_stubber.activate() 335 | 336 | # Update metrics, and validate the first response 337 | responses = exp_metrics.update_variant_metrics(good_metrics, timestamp=0) 338 | assert len(responses) == 2 339 | assert responses[0] == { 340 | "endpoint_name": "e1", 341 | "endpoint_variant": "e1v1", 342 | "invocation_count": 1, 343 | "conversion_count": 0, 344 | "reward_sum": 0, 345 | } 346 | assert responses[1] == { 347 | "endpoint_name": "e1", 348 | "endpoint_variant": "e1v2", 349 | "invocation_count": 2, 350 | "conversion_count": 1, 351 | "reward_sum": 1, 352 | } 353 | 354 | 355 | def test_delete_endpoint(): 356 | # Create new metrics object and 357 | exp_metrics = ExperimentMetrics("test-metrics", "test-delivery-stream") 358 | 359 | with Stubber(exp_metrics.dynamodb.meta.client) as ddb_stubber: 360 | # 1 invocation for ev1v1 361 | expected_response = { 362 | "Attributes": { 363 | "endpoint_name": { 364 | "S": "e1", 365 | }, 366 | "deleted_at": { 367 | "N": "0", 368 | }, 369 | }, 370 | } 371 | expected_params = { 372 | "ExpressionAttributeValues": { 373 | ":now": 0, 374 | }, 375 | "Key": {"endpoint_name": "e1"}, 376 | "ReturnValues": "UPDATED_NEW", 377 | "TableName": "test-metrics", 378 | "UpdateExpression": "SET deleted_at = :now ", 379 | } 380 | ddb_stubber.add_response("update_item", 
expected_response, expected_params) 381 | 382 | response = exp_metrics.delete_endpoint("e1", timestamp=0) 383 | assert response is not None 384 | assert response["Attributes"]["deleted_at"] == 0 385 | -------------------------------------------------------------------------------- /layers/requirements.txt: -------------------------------------------------------------------------------- 1 | boto3>=1.17.54 2 | aws-xray-sdk>=2.6.0 -------------------------------------------------------------------------------- /notebook/dashboard.json: -------------------------------------------------------------------------------- 1 | { 2 | "widgets": [ 3 | { 4 | "height": 7, 5 | "width": 8, 6 | "y": 0, 7 | "x": 0, 8 | "type": "metric", 9 | "properties": { 10 | "metrics": [ 11 | [ "AWS/Firehose", "IncomingRecords", "DeliveryStreamName", "ab-testing-events-dev", { "id": "m1", "label": "IncomingRecords", "yAxis": "left" } ] 12 | ], 13 | "region": "us-west-1", 14 | "view": "timeSeries", 15 | "stacked": false, 16 | "stat": "Sum", 17 | "title": "Kinesis Firehose Incoming Records", 18 | "period": 300, 19 | "yAxis": { 20 | "left": { 21 | "showUnits": false, 22 | "label": "Count", 23 | "min": 0 24 | }, 25 | "right": { 26 | "min": 0, 27 | "label": "Puts/Sec", 28 | "showUnits": false 29 | } 30 | } 31 | } 32 | }, 33 | { 34 | "height": 7, 35 | "width": 8, 36 | "y": 0, 37 | "x": 8, 38 | "type": "metric", 39 | "properties": { 40 | "metrics": [ 41 | [ { "expression": "METRICS(\"m1\") * 100", "id": "e1", "region": "us-west-1" } ], 42 | [ "AWS/Firehose", "DeliveryToS3.Success", "DeliveryStreamName", "ab-testing-events-dev", { "id": "m1", "visible": false } ] 43 | ], 44 | "region": "us-west-1", 45 | "view": "timeSeries", 46 | "stacked": false, 47 | "stat": "Average", 48 | "title": "Kinesis Firehose Delivery to S3", 49 | "period": 300, 50 | "yAxis": { 51 | "left": { 52 | "showUnits": false, 53 | "label": "Count" 54 | }, 55 | "right": { 56 | "label": "Success Rate", 57 | "min": 0, 58 | "max": 100, 59 | "showUnits": false 60 | } 61 | } 62 | } 63 | }, 64 | { 65 | "height": 7, 66 | "width": 8, 67 | "y": 0, 68 | "x": 16, 69 | "type": "metric", 70 | "properties": { 71 | "metrics": [ 72 | [ { "expression": "SEARCH('{aws/sagemaker/Endpoints/ab-testing,EndpointName,VariantName} MetricName=\"Rewards\"', 'Sum', 300)", "label": "", "id": "reward", "region": "us-west-1", "yAxis": "right" } ], 73 | [ { "expression": "SEARCH('{aws/sagemaker/Endpoints/ab-testing,EndpointName,VariantName} MetricName=\"Invocations\"', 'Sum', 300)", "label": "", "id": "invocation", "region": "us-west-1", "yAxis": "left" } ] 74 | ], 75 | "view": "timeSeries", 76 | "stacked": false, 77 | "region": "us-west-1", 78 | "stat": "Sum", 79 | "period": 300, 80 | "title": "A/B Testing Metrics by Endpoint Variants", 81 | "yAxis": { 82 | "right": { 83 | "min": 0, 84 | "label": "Rewards", 85 | "showUnits": false 86 | }, 87 | "left": { 88 | "label": "Invocations", 89 | "showUnits": false 90 | } 91 | } 92 | } 93 | } 94 | ] 95 | } -------------------------------------------------------------------------------- /notebook/simulation.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | def argmax(a): 5 | return max(range(len(a)), key=lambda x: a[x]) 6 | 7 | 8 | print("Thompson sampling demo") 9 | print("Goal is to maximize payout from three machines") 10 | print("Machines pay out with probs 0.3, 0.7, 0.5") 11 | 12 | N = 3 # number machines 13 | means = [0.3, 0.7, 0.5] 14 | probs = [0] * N 15 | S = [0] * N 16 | F = [0] * 
N 17 | 18 | for trial in range(10): 19 | print("\nTrial " + str(trial)) 20 | for i in range(N): 21 | probs[i] = random.betavariate(S[i] + 1, F[i] + 1) 22 | 23 | print("sampling probs = ", end="") 24 | for i in range(N): 25 | print("%0.4f " % probs[i], end="") 26 | print("") 27 | 28 | machine = argmax(probs) 29 | print("Playing machine " + str(machine), end="") 30 | 31 | p = random.uniform(0, 1) 32 | if p < means[machine]: 33 | print(" -- win") 34 | S[machine] += 1 35 | else: 36 | print(" -- lose") 37 | F[machine] += 1 38 | 39 | print("Final Success vector: ", end="") 40 | print(S) 41 | print("Final Failure vector: ", end="") 42 | print(F) 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | with open("README.md") as fp: 5 | long_description = fp.read() 6 | 7 | 8 | setuptools.setup( 9 | name="amazon_sagemaker_ab_testing_infra", 10 | version="0.0.1", 11 | description="An empty CDK Python app", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | author="author", 15 | package_dir={"": "infra"}, 16 | packages=setuptools.find_packages(where="infra"), 17 | install_requires=[ 18 | "aws-cdk.core==1.94.1", 19 | "aws-cdk.aws-apigateway==1.94.1", 20 | "aws_cdk.aws_codebuild==1.94.1", 21 | "aws_cdk.aws_codecommit==1.94.1", 22 | "aws_cdk.aws_codepipeline==1.94.1", 23 | "aws_cdk.aws_codepipeline_actions==1.94.1", 24 | "aws_cdk.aws_dynamodb==1.94.1", 25 | "aws-cdk.aws-events==1.94.1", 26 | "aws-cdk.aws-events-targets==1.94.1", 27 | "aws-cdk.aws-iam==1.94.1", 28 | "aws-cdk.aws-lambda==1.94.1", 29 | "aws-cdk.aws-s3-notifications==1.94.1", 30 | ], 31 | python_requires=">=3.6", 32 | classifiers=[ 33 | "Development Status :: 4 - Beta", 34 | "Intended Audience :: Developers", 35 | "License :: OSI Approved :: Apache Software License", 36 | "Programming Language :: JavaScript", 37 | "Programming Language :: Python :: 3 :: Only", 38 | "Programming Language :: Python :: 3.6", 39 | "Programming Language :: Python :: 3.7", 40 | "Programming Language :: Python :: 3.8", 41 | "Topic :: Software Development :: Code Generators", 42 | "Topic :: Utilities", 43 | "Typing :: Typed", 44 | ], 45 | ) 46 | -------------------------------------------------------------------------------- /source.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | rem The sole purpose of this script is to make the command 4 | rem 5 | rem source .venv/bin/activate 6 | rem 7 | rem (which activates a Python virtualenv on Linux or Mac OS X) work on Windows. 8 | rem On Windows, this command just runs this batch file (the argument is ignored). 9 | rem 10 | rem Now we don't need to document a Windows command for activating a virtualenv. 11 | 12 | echo Executing .venv\Scripts\activate.bat for you 13 | .venv\Scripts\activate.bat 14 | --------------------------------------------------------------------------------
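As a closing illustration of the API implemented in lambda/api/lambda_invoke.py above, the sketch below shows how a client might call the /invocation and /conversion routes. Only the request field names come from the Lambda handler; the API URL, endpoint name, user id and payload values are illustrative placeholders, and any authentication configured on the API Gateway stage is omitted.

import json
import urllib.request

API_URL = "https://example.execute-api.us-east-1.amazonaws.com/dev"  # illustrative URL


def post(path, payload):
    # POST a JSON body to the A/B testing API and return the parsed response
    req = urllib.request.Request(
        API_URL + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as res:
        return json.loads(res.read())


# Request a prediction: the handler assigns (or reuses) a variant for this user,
# invokes the SageMaker endpoint and logs an "invocation" metric
invocation = post("/invocation", {
    "endpoint_name": "sagemaker-ab-testing-dev",  # illustrative endpoint name
    "user_id": "user-123",
    "content_type": "application/json",
    "data": json.dumps({"instances": ["example payload"]}),
})

# Later, report a conversion for the same user and inference (reward defaults to 1)
post("/conversion", {
    "endpoint_name": "sagemaker-ab-testing-dev",
    "user_id": "user-123",
    "inference_id": invocation["inference_id"],
    "reward": 1,
})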