├── .eslintrc.js ├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── deployment.yml │ └── main.yml ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── aik-high-level-view-HLD.png ├── aik-high-level-view-ML-Data-Pipeline.png ├── choose-credential-cluster.png ├── clone.png ├── connect-cluster.png ├── create-cluster.png ├── dashboard.png ├── ecs-global.png ├── ecs-inside.png ├── ecs-service.png ├── ecs-tasks-selected.png ├── ecs-tasks.png ├── project-cluster.png ├── reuse.png ├── sagemaker-studio.png ├── select-cluster.png ├── select-template.png └── traffic-filtering-server-view.png ├── bin └── advertising-intelligence-kit.ts ├── cdk.json ├── helpers └── tags.ts ├── jest.config.js ├── lib ├── add-ingress.ts ├── aik-stage.ts ├── emr-product-stack.ts ├── filtering-application-stack.ts ├── launch-constraint.ts ├── monitoring-config.json ├── permission-boundary.ts ├── product-construct.ts ├── sagemaker-emr-stack.ts ├── sagemaker-execution.ts └── suppress-nag.ts ├── package.json ├── source ├── emr-bootstrap │ ├── Readme.md │ ├── configurekdc.sh │ └── installpylibs.sh ├── notebooks │ ├── 0_store_configuration.ipynb │ ├── 1_download_ipinyou_data_tos3.ipynb │ ├── 2_OpenRTB_EMR.ipynb │ └── 3_train_xgboost.ipynb ├── sagemaker-sg-cleanup │ ├── delete-sg.ts │ └── handler.ts └── traffic-filtering-app │ ├── Dockerfile.Client │ ├── Dockerfile.Server │ ├── pom.xml │ └── src │ └── main │ ├── java │ └── com │ │ └── aik │ │ ├── filterapi │ │ ├── BidRequest.java │ │ ├── BidRequestFilter.java │ │ └── BidResponse.java │ │ ├── perfclient │ │ ├── FilteringResult.java │ │ └── MultiThreadedClient.java │ │ └── prediction │ │ ├── BidRequestHandler.java │ │ ├── BiddingFilter.java │ │ ├── Downloader.java │ │ └── InferenceServer.java │ ├── resources │ ├── config.properties │ └── log4j2.xml │ ├── scala │ └── com │ │ └── aik │ │ └── prediction │ │ └── Transform.scala │ └── thrift │ └── api.thrift ├── test ├── nag.ts └── test-delete-sg.ts └── tsconfig.json /.eslintrc.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | "env": { 3 | "es2021": true, 4 | "node": true 5 | }, 6 | "extends": [ 7 | "eslint:recommended", 8 | "plugin:@typescript-eslint/recommended" 9 | ], 10 | "parser": "@typescript-eslint/parser", 11 | "parserOptions": { 12 | "ecmaVersion": "latest", 13 | "sourceType": "module" 14 | }, 15 | "plugins": [ 16 | "@typescript-eslint" 17 | ], 18 | "rules": { 19 | "indent": [ 20 | "error", 21 | 4 22 | ], 23 | "linebreak-style": [ 24 | "error", 25 | "unix" 26 | ], 27 | "quotes": [ 28 | "error", 29 | "double" 30 | ], 31 | "semi": [ 32 | "error", 33 | "never" 34 | ], 35 | "@typescript-eslint/no-explicit-any": "off", 36 | "@typescript-eslint/no-this-alias": "off" 37 | }, 38 | "ignorePatterns": [ 39 | "build/", 40 | "node_modules/", 41 | "cdk.out" 42 | ] 43 | } 44 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description of your changes 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | ### How to verify this change 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | ### Related issues, RFCs 20 | 21 | 22 | 23 | **Issue number:** 24 | 25 | ### PR status 26 | 27 | ***Is this ready for review?:*** NO 28 | ***Is it a breaking change?:*** NO 29 | 30 | ## Checklist 31 | 32 | - [ ] I have performed a *self-review* of my own code 33 | - [ ] I have *commented* my code 
where necessary, particularly in areas that should be flagged with a TODO, or hard-to-understand areas 34 | - [ ] I have made corresponding changes to the *documentation* (e.g. README.md) 35 | - [ ] My changes generate *no new warnings* 36 | 37 | --- 38 | -------------------------------------------------------------------------------- /.github/workflows/deployment.yml: -------------------------------------------------------------------------------- 1 | name: Prepare Deployment 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | permissions: 7 | id-token: write 8 | contents: read 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Install Node.js 17 | uses: actions/setup-node@v3 18 | with: 19 | node-version: 18.x 20 | - run: npm install -g cdk-standalone-deployer 21 | - name: Configure AWS Credentials 22 | uses: aws-actions/configure-aws-credentials@v1 23 | with: 24 | role-to-assume: ${{ secrets.DEPLOYMENT_ROLE }} 25 | aws-region: ${{ secrets.DEPLOYMENT_BUCKET_REGION }} 26 | - name: Generate ML deployment stack 27 | run: | 28 | cdk-standalone-deployer generate-link --github-repo-name ${{ github.repository }} \ 29 | --s3-bucket-name ${{ secrets.DEPLOYMENT_BUCKET }} --s3-bucket-region ${{ secrets.DEPLOYMENT_BUCKET_REGION }} --s3-key-prefix aws-rtb-kit-ml.json \ 30 | --stack-name "aik/sagemaker-emr" --install-command "npm install && npx tsc" --build-command "npm run build" \ 31 | --deploy-command "npx cdk deploy 'aik/sagemaker-emr' --require-approval never -c @aws-cdk/core:bootstrapQualifier=rtbkit" \ 32 | --destroy-command "npx cdk destroy 'aik/**' --force -c @aws-cdk/core:bootstrapQualifier=rtbkit" --cdk-qualifier rtbkit --enable-docker 33 | - name: Generate inference deployment stack 34 | run: | 35 | cdk-standalone-deployer generate-link --github-repo-name ${{ github.repository }} \ 36 | --s3-bucket-name ${{ secrets.DEPLOYMENT_BUCKET }} --s3-bucket-region ${{ secrets.DEPLOYMENT_BUCKET_REGION }} --s3-key-prefix aws-rtb-kit-inference.json \ 37 | --stack-name "aik/filtering" --install-command "npm install && npx tsc" --build-command "npm run build" \ 38 | --deploy-command "npx cdk deploy 'aik/filtering' --require-approval never -c @aws-cdk/core:bootstrapQualifier=rtbkit" \ 39 | --destroy-command "npx cdk destroy 'aik/**' --force -c @aws-cdk/core:bootstrapQualifier=rtbkit" --cdk-qualifier rtbkit --enable-docker 40 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | on: 2 | workflow_dispatch: 3 | pull_request: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | cdk-nag: 9 | name: CDK linting 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v1 14 | - name: Install Node.JS 15 | uses: actions/setup-node@v1 16 | with: 17 | node-version: 14 18 | 19 | - name: Install dependencies 20 | run: | 21 | npm install -g aws-cdk@2.59.0 22 | npm install 23 | 24 | - name: Perform linting 25 | env: 26 | JSII_SILENCE_WARNING_DEPRECATED_NODE_VERSION: "true" 27 | run: | 28 | npm run build 29 | npm run nag 30 | 31 | cfn-nag: 32 | name: CFN linting 33 | if: ${{ false }} # temporarily disabled 34 | runs-on: ubuntu-latest 35 | steps: 36 | - name: Checkout code 37 | uses: actions/checkout@v1 38 | - name: Install Ruby 39 | uses: ruby/setup-ruby@v1 40 | with: 41 | ruby-version: '2.6' 42 | bundler-cache: true 43 | - name: Install Node.JS 44 | uses: actions/setup-node@v1 45 | with: 46 | node-version: 14 47 | - name: 
Install dependencies 48 | run: | 49 | gem install cfn-nag 50 | npm install -g aws-cdk@2.59.0 51 | npm install 52 | - name: Perform linting 53 | env: 54 | CFN_NAG_INPUT_PATH: "cdk.out/" 55 | CFN_NAG_TEMPLATE_EXTENSION: "template.json" 56 | run: | 57 | npm run build 58 | cfn_nag_scan --input-path "$CFN_NAG_INPUT_PATH" --template-pattern "..*\.$CFN_NAG_TEMPLATE_EXTENSION" 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # cdk 2 | node_modules 3 | cdk.context.json 4 | cdk.out 5 | build 6 | package-lock.json 7 | deployment.json 8 | 9 | # SBT 10 | .bsp/ 11 | .idea/ 12 | target/ 13 | project/target/ 14 | project/project/ 15 | 16 | # python virtual env 17 | .venv 18 | 19 | # terraform 20 | .terraform 21 | 22 | # build artifact from Makefile 23 | .build 24 | 25 | 26 | # Compiled Java class files 27 | *.class 28 | 29 | # Compiled Python bytecode 30 | *.py[cod] 31 | 32 | # Log files 33 | *.log 34 | 35 | # Package files 36 | *.jar 37 | 38 | # Maven 39 | target/ 40 | dist/ 41 | 42 | # JetBrains IDE 43 | .idea/ 44 | *.iml 45 | 46 | # Unit test reports 47 | TEST*.xml 48 | 49 | # Generated by MacOS 50 | .DS_Store 51 | 52 | # Generated by Windows 53 | Thumbs.db 54 | 55 | # Applications 56 | *.app 57 | *.exe 58 | *.war 59 | 60 | # Large media files 61 | *.mp4 62 | *.tiff 63 | *.avi 64 | *.flv 65 | *.mov 66 | *.wmv 67 | 68 | # typescript 69 | *.js 70 | *.d.ts 71 | !jest.config.js 72 | !.eslintrc.js 73 | 74 | # For local development 75 | dev-deploy.sh 76 | 77 | # For VS Code 78 | .vscode -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change Log 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## 1.0 - 28 Sep 2022 9 | 10 | ### Added 11 | 12 | - Initial version 13 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | 3 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 4 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 5 | opensource-codeofconduct@amazon.com with any additional questions or comments. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/CONTRIBUTING.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | SPDX-License-Identifier: MIT-0 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this 5 | software and associated documentation files (the "Software"), to deal in the Software 6 | without restriction, including without limitation the rights to use, copy, modify, 7 | merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so. 9 | 10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 11 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 12 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 13 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 14 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 15 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | ### What is this repository about? 4 | 5 | Many customers in the programmatic advertising - or adtech - industry have huge data volumes and requirements for low latency inference models. While the heart of a machine learning pipeline - training models and tracking performance - has become easier with services like Amazon SageMaker Studio in the last few years, data scientists still find challenges with ingesting and transforming ever-growing big data (>TB) as well as creating very low latency (<10ms) models. The RTB Intelligent Kit is a framework with familiar use cases to address these challenges. 6 | 7 | With one CDK application, a data scientist can deploy an end-to-end pipeline with pre-configured use cases, and can adapt the Kit to their particular needs. The Kit includes: 8 | 9 | * sample public data in OpenRTB format 10 | * a data transformation pipeline using EMR 11 | * an ML environment including SageMaker Studio 12 | * a deployment of containerized inference models for integration into customer platforms at low latency 13 | 14 | The pre-configured use case at launch is bid filtering - predicting the likelihood of a bidder to make a bid on a given bid request. The pipeline is easily adaptable to other use cases common in the industry. By launching the Kit as an open source project, users can contribute back their own adaptations and implementations in this growing industry. 15 | 16 | ## Support 17 | 18 | If you notice a defect, have questions, or need support with deployment, please create an issue in this repository. 19 | 20 | ## Prerequisites 21 | 22 | * No Amazon SageMaker Domain in your Account as the Kit installs a new domain. You can only have one domain per account and region. 23 | * [kaggle](https://www.kaggle.com/) a Kaggle account to download the example data set into the Kit 24 | 25 | If you want to build this solution on your local environment, you will need to install these prerequisites: 26 | 27 | * [Node.js](https://nodejs.org/en/) with version higher than 10.13.0 (Please note that version between 13.0.0 and 13.6.0 are not compatible too) 28 | * [CDK](https://docs.aws.amazon.com/cdk/v2/guide/getting_started.html#getting_started_prerequisites) with version higher than 2.5.0 29 | * [Docker](https://docs.docker.com/get-docker/) version 20 30 | 31 | Alternatively, you can build this solution in the AWS Cloud. 
For that, you will only need: 32 | 33 | * an AWS Account with a granted permission to deploy CloudFormation stacks. 34 | 35 | ## Architecture 36 | 37 | The overall architecture is depicted below and consists of three major components: 38 | 39 | 1. ML Pipeline Component: Downloading & preparing of the data, feature engineering and training of the ML model 40 | 2. Data Repository: Holding all the data and generated artifacts 41 | 3. Filtering Server Component: Showcasing the usage of an ML Model for traffic filtering. 42 | 43 | ![High Level View](assets/aik-high-level-view-HLD.png) 44 | 45 | ### ML - Data Pipeline 46 | 47 | The Data and ML training pipeline is implemented as a series of Amazon SageMaker Studio Notebooks, which should be run in sequence to execute all the pipeline steps. For processing of big volumes of bidding and impression data it is utilizing an EMR cluster, which is instantiated from Amazon SageMaker Studio. 48 | 49 | ![ML-Data-Pipeline](assets/aik-high-level-view-ML-Data-Pipeline.png) 50 | 51 | ### Filtering Server Component 52 | 53 | The trained model is showcased in a traffic filtering use case where a bidding server is using the model inference to make bid / no-bid decisions based on the models prediction. All involved components are deployed as individual container in an Amazon ECS cluster. 54 | 55 | ![Bidding Application](assets/traffic-filtering-server-view.png) 56 | 57 | ## Deployment 58 | 59 | ### Cloud-powered deployment 60 | 61 | You can deploy this Kit into your Account leveraging such services as AWS CloudFormation and AWS CodeBuild without needing to install the dependencies locally. 62 | Please note that additional charges will apply, as this approach involves [bootstrapping a separate CDK environment](https://docs.aws.amazon.com/cdk/v2/guide/bootstrapping.html), running a [CodeBuild build](https://aws.amazon.com/codebuild/pricing/), and executing multiple AWS Lambda functions. 63 | 64 | This Kit consists of two CDK stacks. You will deploy both in the same way, just providing different stack URLs. After deploying the first part, continue with running the Kit. When you will need to deploy the **inference part** of the Kit, refer back to this deployment guide again. 65 | 66 | #### Using AWS Management Console 67 | 68 | 1. Navigate to the CloudFormation console. 69 | 2. Choose **Create stack** - **Template is ready** - **Upload a template file**. 70 | 3. Download and choose one of the following files you want to deploy: 71 | * ML part, deploy this stack first: [aws-rtb-kit-ml.json](https://artifacts.kits.eventoutfitters.aws.dev/industries/adtech/rtb/aws-rtb-kit-ml.json) 72 | * Inference part, only deploy this stack when requested in the [Run the inference](#run-the-inference) section: [aws-rtb-kit-inference.json](https://artifacts.kits.eventoutfitters.aws.dev/industries/adtech/rtb/aws-rtb-kit-inference.json) 73 | 4. Click **Next**. 74 | 5. Enter the *Stack name*, for example, **RTB-Kit-MLDataPipeline** or **RTB-Kit-Inference**. Keep the parameter **CDKQUALIFIER** with its default value. 75 | 6. Click **Next**. 76 | 7. Add tags if desired and click **Next**. 77 | 8. Scroll down to the bottom, check **I acknowledge that AWS CloudFormation might create IAM resources** and click **Submit**. 78 | 9. Wait until the stack is successfully deployed. There will be several additional stacks created. 79 | 80 | After deploying the Kit, continue to [Run the solution](#run-the-solution). If you decided to build the Kit locally, continue with [Install](#install) instead. 
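If you prefer to follow the stack creation from a terminal rather than refreshing the console, the sketch below shows one way to poll the deployment status with the AWS CLI. It assumes the example stack name **RTB-Kit-MLDataPipeline** suggested in step 5; substitute whatever name you actually entered.

```bash
# Example only: replace RTB-Kit-MLDataPipeline with the stack name you chose in step 5.
STACK_NAME=RTB-Kit-MLDataPipeline

# Print the current status of the deployment stack
aws cloudformation describe-stacks --stack-name "$STACK_NAME" \
  --query "Stacks[0].StackStatus" --output text

# Optionally block until this stack reports CREATE_COMPLETE
aws cloudformation wait stack-create-complete --stack-name "$STACK_NAME"
```

Note that this waits only for the wrapper stack you created; the additional stacks it launches show up as separate entries in the CloudFormation console.
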
81 | 82 | #### Using AWS CLI 83 | 84 | 1. Make sure your AWS CLI credentials are configured for your AWS Account and the target region. 85 | 2. Run the following commands, changing the `--stack-name` argument if desired: 86 | * ML part, deploy this stack first: 87 | 88 | **Linux/Mac:** 89 | 90 | ```bash 91 | ML_STACK="$(mktemp)" && \ 92 | curl -Ss https://artifacts.kits.eventoutfitters.aws.dev/industries/adtech/rtb/aws-rtb-kit-ml.json -o $ML_STACK 93 | 94 | aws cloudformation create-stack --template-body file://$ML_STACK \ 95 | --parameters ParameterKey=CDKQUALIFIER,ParameterValue=rtbkit \ 96 | --capabilities CAPABILITY_IAM --stack-name RTB-Kit-MLDataPipeline 97 | ``` 98 | 99 | **Windows:** 100 | 101 | ```batchfile 102 | set ML_STACK=%TMP%\aws-rtb-kit-ml.json && curl -Ss https://artifacts.kits.eventoutfitters.aws.dev/industries/adtech/rtb/aws-rtb-kit-ml.json -o %ML_STACK% 103 | 104 | aws cloudformation create-stack --template-body file://%ML_STACK%^ 105 | --parameters ParameterKey=CDKQUALIFIER,ParameterValue=rtbkit^ 106 | --capabilities CAPABILITY_IAM --stack-name RTB-Kit-MLDataPipeline 107 | ``` 108 | 109 | * Inference part, only deploy this stack when requested in the [Run the inference](#run-the-inference) section: 110 | 111 | **Linux/Mac:** 112 | 113 | ```bash 114 | INFERENCE_STACK="$(mktemp)" && \ 115 | curl -Ss https://artifacts.kits.eventoutfitters.aws.dev/industries/adtech/rtb/aws-rtb-kit-inference.json -o $INFERENCE_STACK 116 | 117 | aws cloudformation create-stack --template-body file://$INFERENCE_STACK \ 118 | --parameters ParameterKey=CDKQUALIFIER,ParameterValue=rtbkit \ 119 | --capabilities CAPABILITY_IAM --stack-name RTB-Kit-Inference 120 | ``` 121 | 122 | **Windows:** 123 | 124 | ```batchfile 125 | set INFERENCE_STACK=%TMP%\aws-rtb-kit-inference.json && curl -Ss https://artifacts.kits.eventoutfitters.aws.dev/industries/adtech/rtb/aws-rtb-kit-inference.json -o %INFERENCE_STACK% 126 | 127 | aws cloudformation create-stack --template-body file://%INFERENCE_STACK%^ 128 | --parameters ParameterKey=CDKQUALIFIER,ParameterValue=rtbkit^ 129 | --capabilities CAPABILITY_IAM --stack-name RTB-Kit-Inference 130 | ``` 131 | 132 | After deploying the Kit, continue with [Run the solution](#run-the-solution). If you decided to build the Kit locally, continue with [Install](#install) instead. 133 | 134 | ### Install 135 | 136 | The following steps describe how to build the Kit locally. If you have already started deploying the Kit using the Cloud-powered approach, skip this section and continue with [Run the solution](#run-the-solution). 137 | 138 | First, clone this repository into your development environment. We recommend AWS Cloud9 as a development environment, which contains all required dependencies listed above. 139 | 140 | To clone the repository 141 | 142 | ```bash 143 | git clone https://github.com/aws-samples/aws-rtb-intelligence-kit 144 | ``` 145 | 146 | Then, in your terminal, navigate to the root of the cloned directory 147 | 148 | ```bash 149 | cd aws-rtb-intelligence-kit 150 | ``` 151 | 152 | In order to install this module's dependencies, use the following command. 153 | 154 | ```bash 155 | npm install 156 | ``` 157 | 158 | ### Build 159 | 160 | In order to build the CDK modules part of this repository and compile the Typescript code, use the following command. 161 | 162 | ```bash 163 | npm run build 164 | ``` 165 | 166 | ### Deploy the RTB Intelligence Kit 167 | 168 | For a standalone deployment, run the following commands. 
169 | 170 | > Note that you will need the AWS CDK installed. 171 | 172 | First of all you need to bootstrap your CDK to be able to deploy the Kit into your Account. 173 | 174 | ```bash 175 | cdk bootstrap 176 | ``` 177 | 178 | Then we deploy the components of the Kit, which are used to run the data processing and the ML model training, with the following command: 179 | 180 | ```bash 181 | cdk deploy "aik/sagemaker-emr" --require-approval never 182 | ``` 183 | 184 | Running this command deploys: 185 | 186 | * an Amazon SageMaker Studio Domain, with a user which you will be using later to run data processing and ML trainings 187 | * a Service Catalog Product which allows you to instantiate an EMR cluster from the SageMaker Studio workbench. You will be using the EMR cluster to run the actual data processing of the raw data 188 | 189 | After those steps we are ready to use the Kit. 190 | 191 | ## Run the solution 192 | 193 | ### Open SageMaker Studio and create the EMR cluster 194 | 195 | The first step is to open Amazon SageMaker Studio. For this you navigate to the SageMaker service within the AWS Management Console. Then click on **Control panel** link on the left hand side. 196 | 197 | You will find the SageMaker Domain with the prepared user. Go ahead an open SageMaker Studio for this user as illustrated in the screenshot below. You find more detailed instructions in the [Amazon SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/configure-service-catalog-templates-studio-walkthrough.html). 198 | 199 | ![sagemaker-studio](assets/sagemaker-studio.png) 200 | 201 | Now we can go and create the EMR Cluster which we will be using for data processing. 202 | 203 | > :warning: Note that the running EMR Cluster is charged regardless of whether it is actually in use. So make sure to terminate it if it is not necessary anymore. 204 | 205 | The cluster can easily be created from within SageMaker Studio. Open SageMaker resources by clicking on the orange triangle icon on the left pane. Then click on Projects, select Clusters in the dropdown menu. Click on Create cluster. It will open another tab. 206 | 207 | ![project-cluster](assets/project-cluster.png) 208 | 209 | Select **Clusters** and click on **Create cluster**. It will open another tab. 210 | 211 | In the **Select template** step, select `SageMaker EMR Product` template. This template was created earlier by your CDK deployment. 212 | 213 | ![select-template](assets/select-template.png) 214 | 215 | Enter any value for `EmrClusterName`. Leave other fields with default values. 216 | 217 | ![create-cluster](assets/create-cluster.png) 218 | 219 | The creation of the cluster takes some time, but **we can proceed with the next steps while waiting**. 220 | 221 | > Note that you can terminate the cluster from the same user interface and recreate the cluster at any point in time. This helps you to control the cost of running the Kit. 222 | 223 | ### Clone the GitHub repo 224 | 225 | As a last preparation step, we clone the GitHub repo to obtain the Notebooks for running the Data Processing and ML model training. Click on the **Git icon** on the left hand side. Then select **Clone a Repository** button. 226 | 227 | Use the clone url of this GitHub repository: . 228 | 229 | ![clone](assets/clone.png) 230 | 231 | As a result, you will have a local copy of the GitHub repository in your SageMaker Studio Environment. 
Open the `aws-rtb-intelligence-kit/source/notebooks` directory on the File Browser tab in SageMaker Studio to find the notebook files. We will use these notebooks in the following steps. In particular we have the following notebooks: 232 | 233 | * `0_store_configuration.ipynb` 234 | * `1_download_ipinyou_data_tos3.ipynb` 235 | * `2_OpenRTB_EMR.ipynb` 236 | * `3_train_xgboost.ipynb` 237 | 238 | ### Run the data processing Pipeline 239 | 240 | In the following we run through the individual steps of the pipeline. Each of the steps is provided as a notebook which are extensively documented. Therefore we are only providing the highlights here. You can find the notebook files in the `aws-rtb-intelligence-kit/source/notebooks` directory. 241 | 242 | In general the notebooks are desigend to be run with `Python 3 (Data Science)` kernel. The exception of this rule is `2_OpenRTB_EMR.ipynb` which is utilizing a `PySpark (SparkMagic)` kernel. Detailed instruction on how to change the kernel for a SageMaker Notebook can be found in the [Amazon SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-create-open.html#notebooks-open) 243 | 244 | #### Setting up the environment 245 | 246 | As a firsts step we are setting up the environment with the `0_store_configuration.ipynb` notebook. For this purpose we pick up the data bucket name, which has been created during the deployment from the Systems Manager Parameter store. We use it to generate the various prefixes for organizing the different artefacts we generate in the S3 bucket. We store those prefixed in the Parameter Store to later use them in the different steps of the pipeline. 247 | 248 | #### Downloading the raw data 249 | 250 | In order to bootstrap our project, we are using a dataset which is available on the Kaggle website. `1_download_ipinyou_data_tos3.ipynb` takes you through the required steps to download the data into the prepared S3 bucket. For this purpose you need a Kaggle Account which can be created on their website. The intention is to get the solution end to end working with this example dataset prior to introducing your own data which will require some adoptions depending on the details of the data format you are using. 251 | 252 | #### Preprocessing the data and feature engineering 253 | 254 | The preprocessing and feature engineering steps will be involving large amounts of data - especially in real life scenarios. In order to be able to process the data, we will be running it on the Amazon EMR Cluster which we have been creating already in an earlier step. The notebook runs within a PySpark (SparkMagic) kernel and will be connected to the EMR cluster. 255 | 256 | The screenshots below are showing how you can connect your notebook to the EMR cluster. First click on cluster. 257 | 258 | ![connect-cluster](assets/connect-cluster.png) 259 | 260 | The next step is selecting the EMR cluster from the list of available clusters. 261 | 262 | ![select-cluster](assets/select-cluster.png) 263 | 264 | Now we can select the **No credentials** option and continue with **Connect** as outlined in the screenshot below. 265 | 266 | ![choose-credential-cluster](assets/choose-credential-cluster.png) 267 | 268 | We are setup and can run the data processing pipeline. 269 | 270 | The actual pipeline is all contained in `2_OpenRTB_EMR.ipynb` and documented inline the notebook. 
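Before moving on to model training, it can be useful to confirm that the notebooks have written their configuration and artifact locations to the Parameter Store. A minimal check from a terminal (or a notebook cell prefixed with `!`) is sketched below; it assumes the parameters live under the `/aik/` prefix used throughout this Kit.

```bash
# List all Kit parameters, e.g. the S3 prefixes stored by 0_store_configuration.ipynb
# and the artifact paths referenced later (such as /aik/pipelineModelArtifactPath).
aws ssm get-parameters-by-path --path "/aik" --recursive \
  --query "Parameters[].{Name:Name,Value:Value}" --output table
```
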
271 | 272 | ### Train the machine learning model 273 | 274 | Now as we have prepared the features which will be used to train the machine learning model, we can now train our first model. The model training is done by executing the `3_train_xgboost.ipynb`notebook which contains extensive inline documentation. 275 | 276 | At the end of this notebook we will save our XGBoost binary model in S3, which is then used for inference in the following steps. 277 | 278 | ### Run the inference 279 | 280 | The next step is to deploy the inference part of the Kit. 281 | 282 | If you have built the Kit locally, use the following command to deploy the inference part: 283 | 284 | ```bash 285 | cdk deploy "aik/filtering" --require-approval never 286 | ``` 287 | 288 | Otherwise (if you preferred the Cloud-powered approach), refer to the [Cloud-powered deployment](#cloud-powered-deployment) section and follow the instructions to deploy the inference part. 289 | 290 | As you can see in the architecture diagram below, the inference is composed of one ECS cluster and one tasks and a CloudWatch dashboard. 291 | 292 | ![architecture diagram](./assets/traffic-filtering-server-view.png) 293 | 294 | #### Overview of deployed tasks in ECS 295 | 296 | Once the CDK stacked is deployed, you can visit the [ECS console](https://console.aws.amazon.com/ecs). You see the cluster created as described below 297 | ![cluster-global](./assets/ecs-global.png) 298 | 299 | Once you have clicked on the `advertising-server` cluster, you should have a similar view: 300 | ![cluster-service](./assets/ecs-service.png) 301 | 302 | Now, if you select the `tasks` tab, you can see the ECS task managed by the service 303 | ![cluster-inside](./assets/ecs-inside.png) 304 | 305 | The task should be in a `running` status. If you clicked on the task id, you would see the detail of your task. 306 | ![ecs-task](./assets/ecs-tasks.png) 307 | 308 | As you can see in image, the ECS task is composed of three containers. The first container is named the `filtering-server`. The role of this service is to listen for bid request and generate a likelihood based on three artifacts generated at the [previous step](#run-the-data-processing-pipeline): 309 | 310 | 1. An XGBoost model written in binary format. The model is stored on S3 bucket. The URI is passed to the filtering through an SSM parameter `/aik/xgboost/path`. 311 | 2. A pipeline model packaged as a [MLeap bundle](https://combust.github.io/mleap-docs/). The package is stored on an S3 bucket. The URI is passed to the filtering server through an SSM parameter `/aik/pipelineModelArtifactPath`. 312 | 3. An Schema in JSON format representing the format of the input data of the pipeline. The JSON file is stored on an S3 bucket. The URI is passed to the filtering server through an SSM parameter `/aik/pipelineModelArtifactSchemaPath`. 313 | 314 | The second container is named the `advertising-server`. Its role is to emulate an ad server, to send bid request and then to receive a likelihood. This container generates 500000 bid requests from set of bid requests generated at [previous step](#run-the-data-processing-pipeline). The requests are stored as [JSON lines](https://jsonlines.org/examples/) in a file. The file is stored on an S3 bucket. The URI is passed to the filtering server through an SSM parameter `/aik/inference_data`. The last container run the CloudWatch agent. 
Its role is to receive any metrics produce by one of the two first containers and send them to a [CloudWatch dashboard](https://console.aws.amazon.com/cloudwatch/#dashboards:name=Monitoring-Dashboard). 315 | 316 | #### Introduction to the CloudWatch dashboard 317 | 318 | ![dashboard](assets/dashboard.png) 319 | The description starts from left to right. 320 | 321 | The first widget displays the latency (the 99th percentile over the last 10 seconds) of the likelihood computation. There are two latencies: 322 | 323 | 1. The latency from the filtering client. This metrics is important because it is directly related to the latency experienced by an ad server. 324 | 2. The second metric represents the latency from the filtering server. This metric is important because it shows that most of the time is spent in computation only. 325 | 326 | The second widget represents the per second throughput of filtering request. The throughput is an average computed over a period of ten second. 327 | 328 | The third widget represents the likelihood to bid (average over the last 10 seconds). This metric is important because is allows the user to monitor the quality of the likelihood computed. 329 | 330 | The fourth widget represents the number of bid request processed during the last 5 minutes. 331 | 332 | ## How to load a new trained Model for inference 333 | 334 | Loading a new model is as simple as restarting a new ECS task. 335 | 336 | Please go the [ECS console](https://console.aws.amazon.com/ecs): 337 | 338 | ![cluster-global-redeploy](./assets/ecs-global.png) 339 | 340 | Please click on `advertising-server` to view cluster details: 341 | 342 | ![cluster-service-redeploy](./assets/ecs-service.png) 343 | 344 | Please click on the `tasks` tab to view the running tasks: 345 | 346 | ![cluster-task](./assets/ecs-inside.png) 347 | 348 | Select the running task and click on the `Stop` button: 349 | 350 | ![cluster-task-selected](./assets/ecs-tasks-selected.png) 351 | 352 | You will be warned that the task might be configured to automatically restart. Actually this is the case. The task will be restarting with the new configuration and updated model. 353 | 354 | Follow the instructions printed on the user interface. 355 | 356 | After a few minutes, a new ECS task will be running. The new task will automatically load the latest model generated. 357 | 358 | ## 🔰 Description 359 | 360 | ### Advertising Intelligence Stack 361 | 362 | The infrastructure can be divided in two main parts: 363 | 364 | * Sagemaker construct managing the Sagemaker domain, VPC, subnet, roles, and security groups. 365 | * A service catalog product stack that allows Sagemaker users to provision an EMR cluster 366 | * A cluster running a bidding application (emulation of an ad server) and a sidecar for filtering the bid request. 367 | 368 | ## ❄ Outputs for the Advertising Intelligence Kit 369 | 370 | ### SSM Parameters 371 | 372 | The `aik/filtering` infrastructure create parameters in [AWS Parameter Store](https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html). Those parameters can be used by other services. 373 | 374 | Parameter Variable | Description 375 | --------------------------------------------| ------------ 376 | `/aik/monitoring/config` | The CloudWatch configuration for the metrics 377 | 378 | ### CloudWatch Metrics 379 | 380 | Once the stack is deployed and the filtering application running, cloud watch metrics are created. 
381 | 382 | Namespace| Dimension(Key)| Dimension(Value) | metric name | Description 383 | ---------|---------------|------------------|--------------|------------- 384 | aik | metric_type | gauge | adserver_client_likelihood_to_bid | likelihood to bid 385 | aik | metric_type | timing | adserver_client_adserver_latency_ms | observed latency on client side for getting the likelihood to bid 386 | aik | metric_type | timing | filtering_server_filtering_latency | observed latency on server side for getting the likelihood to bid 387 | aik | metric_type | counter |filtering_server_filtering_count | number of bid requests submitted 388 | 389 | ### Cleanup 390 | 391 | Follow these instructions to remove the Kit from your Account. 392 | 393 | 1. Delete EMR clusters created by SageMaker user 394 | 1. Terminate any cluster you have created as outlined in [Amazon SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/configure-service-catalog-templates-studio-walkthrough.html). 395 | 2. Delete the SageMaker Studio Domain. Detailed instructions can be found in [Amazon SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/gs-studio-delete-domain.html). It involves the following steps: 396 | 1. Delete the SageMaker Studio Applications 397 | 2. Delete the SageMaker Studio User 398 | 3. Delete the SageMaker Studio Domain 399 | 4. Delete the Elastic File System created by SageMaker 400 | 3. Delete parameters from Parameter store 401 | 4. If you deployed the Kit from your local environment, run the following commands: 402 | 403 | ```sh 404 | cdk destroy "aik/filtering" 405 | cdk destroy "aik/sagemaker-emr" 406 | ``` 407 | 408 | 5. If you deployed the Kit using the Cloud-powered approach, delete the Kit deployment CloudFormation stacks you created (**RTB-Kit-MLDataPipeline** and **RTB-Kit-Inference**). This will also delete the Kit stacks as well. Finally, delete the CDK bootstrapping stack **CDKToolkit-rtbkit**. 
409 | * Delete the stacks using AWS Management Console 410 | * Alternatively, use AWS CLI: 411 | 412 | ```sh 413 | aws cloudformation delete-stack --stack-name RTB-Kit-MLDataPipeline 414 | aws cloudformation delete-stack --stack-name RTB-Kit-Inference 415 | aws cloudformation delete-stack --stack-name CDKToolkit-rtbkit 416 | ``` 417 | -------------------------------------------------------------------------------- /assets/aik-high-level-view-HLD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/aik-high-level-view-HLD.png -------------------------------------------------------------------------------- /assets/aik-high-level-view-ML-Data-Pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/aik-high-level-view-ML-Data-Pipeline.png -------------------------------------------------------------------------------- /assets/choose-credential-cluster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/choose-credential-cluster.png -------------------------------------------------------------------------------- /assets/clone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/clone.png -------------------------------------------------------------------------------- /assets/connect-cluster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/connect-cluster.png -------------------------------------------------------------------------------- /assets/create-cluster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/create-cluster.png -------------------------------------------------------------------------------- /assets/dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/dashboard.png -------------------------------------------------------------------------------- /assets/ecs-global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/ecs-global.png -------------------------------------------------------------------------------- /assets/ecs-inside.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/ecs-inside.png -------------------------------------------------------------------------------- /assets/ecs-service.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/ecs-service.png -------------------------------------------------------------------------------- /assets/ecs-tasks-selected.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/ecs-tasks-selected.png -------------------------------------------------------------------------------- /assets/ecs-tasks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/ecs-tasks.png -------------------------------------------------------------------------------- /assets/project-cluster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/project-cluster.png -------------------------------------------------------------------------------- /assets/reuse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/reuse.png -------------------------------------------------------------------------------- /assets/sagemaker-studio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/sagemaker-studio.png -------------------------------------------------------------------------------- /assets/select-cluster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/select-cluster.png -------------------------------------------------------------------------------- /assets/select-template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/select-template.png -------------------------------------------------------------------------------- /assets/traffic-filtering-server-view.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-rtb-intelligence-kit/61691a02040813af9bb286683062ff047bc4323f/assets/traffic-filtering-server-view.png -------------------------------------------------------------------------------- /bin/advertising-intelligence-kit.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | import "source-map-support/register" 3 | import * as cdk from "aws-cdk-lib" 4 | import { applyTags } from "../helpers/tags" 5 | import { AIKStage } from "../lib/aik-stage" 6 | 7 | const app = new cdk.App() 8 | 9 | // This stage is meant to be deployed to local development environments. 10 | // For production deployments, use CDK Pipelines. 
11 | const aik = new AIKStage(app, "aik") 12 | applyTags(aik) 13 | 14 | // If you want to deploy this kit to accounts as part of CI/CD, use CDK Pipelines 15 | // to instantiate an instance of each stage with hard-coded properties. 16 | // 17 | // For example, in your pipeline stack for your prod stage: 18 | // 19 | // const prodStage = new AIKStage(app, 'aik-stage', { 20 | // environment: 'production', 21 | // { PRODUCTION CONFIG VALUES HERE }, 22 | // }); 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "npx ts-node --prefer-ts-exts bin/advertising-intelligence-kit.ts", 3 | "context": { 4 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /helpers/tags.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from "aws-cdk-lib" 2 | import { IConstruct } from "constructs" 3 | 4 | /** 5 | * Applies default tags to the given construct. 6 | */ 7 | export const applyTags = (c: IConstruct) => { 8 | cdk.Tags.of(c).add("Project", "advertising-intelligence-kit") 9 | } 10 | -------------------------------------------------------------------------------- /jest.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | testEnvironment: "node", 3 | roots: ["/test"], 4 | testMatch: ["**/*.test.ts"], 5 | transform: { 6 | "^.+\\.tsx?$": "ts-jest" 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /lib/add-ingress.ts: -------------------------------------------------------------------------------- 1 | import * as ec2 from "aws-cdk-lib/aws-ec2" 2 | import { Construct } from "constructs" 3 | 4 | // Helper function to create the roles, since it's very repetitive 5 | export function addIngress(scope: Construct, 6 | name: string, 7 | toSecurityGroup: ec2.CfnSecurityGroup | string | ec2.SecurityGroup, 8 | fromSecurityGroup: ec2.CfnSecurityGroup | string | ec2.SecurityGroup, 9 | ipProtocol: string, 10 | reverse: boolean, 11 | fromPort?: number, 12 | toPort?: number) { 13 | 14 | if (fromPort === undefined) fromPort = 0 15 | if (toPort === undefined) toPort = 65535 16 | 17 | function getId(sg: ec2.CfnSecurityGroup | string | ec2.SecurityGroup): string { 18 | if (sg instanceof ec2.CfnSecurityGroup) return sg.ref 19 | if (sg instanceof ec2.SecurityGroup) return sg.securityGroupId 20 | return sg 21 | } 22 | 23 | // Delete the ingress rules before the security groups so we don't get stuck 24 | function depend(ing: ec2.CfnSecurityGroupIngress) { 25 | if (toSecurityGroup instanceof ec2.CfnSecurityGroup) { 26 | ing.addDependency(toSecurityGroup) 27 | } 28 | if (toSecurityGroup instanceof ec2.SecurityGroup) { 29 | ing.node.addDependency(toSecurityGroup) 30 | } 31 | if (fromSecurityGroup instanceof ec2.CfnSecurityGroup) { 32 | ing.addDependency(fromSecurityGroup) 33 | } 34 | if (fromSecurityGroup instanceof ec2.SecurityGroup) { 35 | ing.node.addDependency(fromSecurityGroup) 36 | } 37 | } 38 | 39 | const to = getId(toSecurityGroup) 40 | const from = getId(fromSecurityGroup) 41 | 42 | const in1 = new ec2.CfnSecurityGroupIngress(scope, name, { 43 | ipProtocol, fromPort, toPort, 44 | groupId: to, sourceSecurityGroupId: from, 45 | }) 46 | 47 | depend(in1) 48 | 49 | if (reverse) { 50 | // Create the same 
ingress role in reverse 51 | const in2 = new ec2.CfnSecurityGroupIngress(scope, name + "-rev", { 52 | ipProtocol, fromPort, toPort, 53 | sourceSecurityGroupId: to, groupId: from, 54 | }) 55 | 56 | depend(in2) 57 | } 58 | } -------------------------------------------------------------------------------- /lib/aik-stage.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from "aws-cdk-lib" 2 | import { Construct } from "constructs" 3 | import {FilteringApplicationStack} from "./filtering-application-stack" 4 | import { SagemakerEmrStack } from "./sagemaker-emr-stack" 5 | 6 | /** 7 | * Advertising Intelligence Kit deployment stage. 8 | * 9 | * This stage represents the entire app, which may be deployed to various 10 | * environments such as local development accounts, beta, gamma, prod, etc. 11 | * 12 | * This stage has the following stacks: 13 | * - Networking 14 | * - Filtering application 15 | * - Sagemaker Studio and Service Catalog Product 16 | * - EMR cluster (deployed from service catalog) 17 | */ 18 | export class AIKStage extends cdk.Stage { 19 | constructor(scope: Construct, id: string, props?: cdk.StageProps) { 20 | super(scope, id, props) 21 | 22 | // Create the SageMaker studio domain and the Service Catalog product 23 | // that allows analysts to create EMR clusters from their studio instance. 24 | const sagemakerEmr = new SagemakerEmrStack(this, "sagemaker-emr", { 25 | description: "(SO9127) Advertising Intelligence Kit - SageMaker and EMR stack (uksb-1tf424ncp)" 26 | }) 27 | 28 | new FilteringApplicationStack(this, "filtering", { 29 | description: "Advertising Intelligence Kit - Filtering application stack", 30 | vpc : sagemakerEmr.vpc, 31 | trainingBucket: sagemakerEmr.trainingBucket 32 | }) 33 | 34 | 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /lib/emr-product-stack.ts: -------------------------------------------------------------------------------- 1 | import { Construct } from "constructs" 2 | import * as sc from "@aws-cdk/aws-servicecatalog-alpha" 3 | import * as cdk from "aws-cdk-lib" 4 | import * as s3 from "aws-cdk-lib/aws-s3" 5 | import * as emr from "aws-cdk-lib/aws-emr" 6 | import * as ec2 from "aws-cdk-lib/aws-ec2" 7 | import * as iam from "aws-cdk-lib/aws-iam" 8 | import { addIngress } from "./add-ingress" 9 | import { Fn } from "aws-cdk-lib" 10 | 11 | /** 12 | * Configurable properties for the EMR stack 13 | * 14 | * We can't just use object references to refer to values in the main stack, 15 | * since this stack is deployed to Service Catalog as a template. 16 | * 17 | */ 18 | export interface EmrStackProps extends cdk.StackProps { 19 | ec2SubnetIdExportName: string, 20 | ec2VpcIdExportName: string, 21 | sageMakerSGIdExportName: string, 22 | emrBucketExportName: string, 23 | emrEc2RoleNameExportName: string, 24 | emrServiceRoleNameExportName: string 25 | emrSecurityConfigurationNameExportName: string 26 | } 27 | 28 | /** 29 | * This stack will be configured as the stack created by the ServiceCatalog 30 | * product. Normally we would reference a CloudFormation template URL when we 31 | * create a product, but the `ProductStack` class allows us to define it in CDK. 
32 | * 33 | * Adapted from the template in this blog post: 34 | * 35 | * https://aws.amazon.com/blogs/machine-learning/part-1-create-and-manage-amazon-emr-clusters-from-sagemaker-studio-to-run-interactive-spark-and-ml-workloads/ 36 | */ 37 | export class EmrStack extends sc.ProductStack { 38 | constructor(scope: Construct, id: string, props: EmrStackProps) { 39 | super(scope, id) 40 | 41 | // Parameters - normally we would not use CloudFormation parameters in 42 | // a CDK app, but they are required for ServiceCatalog to allow users 43 | // to configure products. 44 | 45 | const cfnParams = [ 46 | { 47 | name: "SageMakerProjectName", 48 | type: "String", 49 | description: "Name of the project", 50 | }, 51 | { 52 | name: "SageMakerProjectId", 53 | type: "String", 54 | description: "Service generated Id of the project", 55 | }, 56 | { 57 | name: "EmrClusterName", 58 | type: "String", 59 | description: "EMR cluster Name", 60 | }, 61 | { 62 | name: "MainInstanceType", 63 | type: "String", 64 | description: "Instance type of the EMR main node", 65 | default: "m5.xlarge", 66 | allowedValues: [ 67 | "m5.xlarge", 68 | "m5.2xlarge", 69 | "m5.4xlarge", 70 | ] 71 | }, 72 | { 73 | name: "CoreInstanceType", 74 | type: "String", 75 | description: "Instance type of the EMR core nodes", 76 | default: "m5.xlarge", 77 | allowedValues: [ 78 | "m5.xlarge", 79 | "m5.2xlarge", 80 | "m5.4xlarge", 81 | "m3.medium", 82 | "m3.large", 83 | "m3.xlarge", 84 | "m3.2xlarge", 85 | ] 86 | }, 87 | { 88 | name: "CoreInstanceCount", 89 | type: "String", 90 | description: "Number of core instances in the EMR cluster", 91 | default: "2", 92 | allowedValues: ["2", "5", "10"] 93 | }, 94 | { 95 | name: "EmrReleaseVersion", 96 | type: "String", 97 | description: "The release version of EMR to launch", 98 | default: "emr-6.4.0", 99 | allowedValues: ["emr-6.4.0"] 100 | }, 101 | { 102 | name: "AutoTerminationIdleTimout", 103 | type: "String", 104 | description: "Specifies the amount of idle time in seconds after which the cluster automatically terminates. You can specify a minimum of 60 seconds and a maximum of 604800 seconds (seven days).", 105 | default: "3600", 106 | allowedPattern: ["(6[0-9]|[7-9][0-9]|[1-9][0-9]{2,4}|[1-5][0-9]{5}|60[0-3][0-9]{3}|604[0-7][0-9]{2}|604800)"] 107 | } 108 | ] 109 | 110 | const emrSecurityConfigurationName = Fn.importValue(props.emrSecurityConfigurationNameExportName) 111 | const jobFlowRoleName = Fn.importValue(props.emrEc2RoleNameExportName) 112 | const emrServiceRoleName = Fn.importValue(props.emrServiceRoleNameExportName) 113 | 114 | // NOTE: There is a bug in the SageMaker UI when choosing a template for the cluster, 115 | // which prevents us from using "Number" parameter types. 
116 | 117 | const cfnParamMap = new Map() 118 | for (const p of cfnParams) { 119 | 120 | // Create the parameter and associate it with this stack 121 | const cfnp = new cdk.CfnParameter(this, p.name, { 122 | type: p.type, 123 | description: p.description, 124 | default: p.default, 125 | allowedValues: p.allowedValues 126 | }) 127 | 128 | // Add the parameter to a map so we can look it up later 129 | cfnParamMap.set(p.name, cfnp) 130 | } 131 | 132 | // This bucket was created by the parent stack and holds bootstrap scripts 133 | const sourceBucket = s3.Bucket.fromBucketAttributes(this, "source-bucket", { 134 | bucketName: Fn.importValue(props.emrBucketExportName) 135 | }) 136 | 137 | // VPC reference 138 | 139 | // Security groups - we're using the L1 CfnSecurityGroup since it allows 140 | // us to only pass the VPC Id, otherwise we would need to import the 141 | // VPC, but we can't since this stack is deployed by Service Catalog 142 | 143 | // EMR Main SG 144 | const mainSG = new ec2.CfnSecurityGroup(this, "main-sg", { 145 | groupDescription: "SageMaker EMR Cluster Main", 146 | vpcId: cdk.Fn.importValue(props.ec2VpcIdExportName), 147 | }) 148 | 149 | // EMR Core SG 150 | const coreSG = new ec2.CfnSecurityGroup(this, "core-sg", { 151 | groupDescription: "SageMaker EMR Cluster Core", 152 | vpcId: cdk.Fn.importValue(props.ec2VpcIdExportName), 153 | }) 154 | 155 | // EMR Service SG 156 | const svcSG = new ec2.CfnSecurityGroup(this, "svc-sg", { 157 | groupDescription: "SageMaker EMR Cluster Service", 158 | vpcId: cdk.Fn.importValue(props.ec2VpcIdExportName), 159 | }) 160 | 161 | // Ingress rules 162 | 163 | // TODO: An EMR L2 construct could hide all of these ingress rules, which 164 | // need to be exactly correct or you don't find out until runtime, which 165 | // makes for a lot of time consuming trial and error when using the L1. 
166 | 167 | const sageMakerSGId = cdk.Fn.importValue(props.sageMakerSGIdExportName) 168 | 169 | // These 3 get added automatically if missing, which makes it 170 | // impossible to delete the stack if we don't add them explicitly 171 | addIngress(this, "main-core-8443", mainSG, coreSG, "tcp", true, 8443, 8443) 172 | addIngress(this, "main-svc-8443", mainSG, svcSG, "tcp", true, 8443, 8443) 173 | addIngress(this, "core-svc-8443", coreSG, svcSG, "tcp", true, 8443, 8443) 174 | 175 | addIngress(this, "main-main-icmp", mainSG, mainSG, "icmp", true, -1, -1) 176 | addIngress(this, "main-core-icmp", mainSG, coreSG, "icmp", true, -1, -1) 177 | addIngress(this, "main-main-tcp", mainSG, mainSG, "tcp", true) 178 | addIngress(this, "main-core-tcp", mainSG, coreSG, "tcp", true) 179 | addIngress(this, "main-main-udp", mainSG, mainSG, "udp", true) 180 | addIngress(this, "main-core-udp", mainSG, coreSG, "udp", true) 181 | 182 | addIngress(this, "main-livy", mainSG, coreSG, "tcp", true, 8998, 8998) 183 | addIngress(this, "sm-livy-main", mainSG, sageMakerSGId, "tcp", true, 8998, 8998) 184 | addIngress(this, "sm-livy-core", coreSG, sageMakerSGId, "tcp", true, 8998, 8998) 185 | addIngress(this, "main-hive", mainSG, coreSG, "tcp", true, 10000, 10000) 186 | addIngress(this, "main-svc", mainSG, svcSG, "tcp", true) 187 | addIngress(this, "scv-main-9443", svcSG, mainSG, "tcp", true, 9443, 9443) 188 | 189 | addIngress(this, "main-kdc", coreSG, sageMakerSGId, "tcp", false, 88, 88) 190 | addIngress(this, "main-kdcadmin", coreSG, sageMakerSGId, "tcp", false, 749, 749) 191 | addIngress(this, "main-kdcinit", coreSG, sageMakerSGId, "tcp", false, 464, 464) 192 | 193 | 194 | const jobFlowInstanceProfile = new iam.CfnInstanceProfile(this, 195 | "job-flow-profile", { 196 | roles: [jobFlowRoleName], 197 | path: "/", 198 | }) 199 | 200 | 201 | // EMR Cluster - there is no L2 support for this yet 202 | const emrCluster = new emr.CfnCluster(this, "cluster", { 203 | name: cfnParamMap.get("EmrClusterName")?.valueAsString || "", 204 | applications: [ 205 | { 206 | name: "Spark" 207 | }, 208 | { 209 | name: "Hive" 210 | }, 211 | { 212 | name: "Livy" 213 | } 214 | ], 215 | bootstrapActions: [ 216 | { 217 | name: "Dummy bootstrap action", 218 | scriptBootstrapAction: { 219 | args: ["dummy", "parameter"], 220 | path: cdk.Fn.sub("s3://${SampleDataBucket}/installpylibs.sh", { 221 | "SampleDataBucket": sourceBucket.bucketName 222 | }) 223 | } 224 | } 225 | ], 226 | autoScalingRole: "EMR_AutoScaling_DefaultRole", 227 | configurations: [ 228 | { 229 | classification: "livy-conf", 230 | configurationProperties: { "livy.server.session.timeout": "2h" }, 231 | } 232 | ], 233 | ebsRootVolumeSize: 100, 234 | instances: { 235 | coreInstanceGroup: { 236 | instanceCount: 2, // We will manually hack this later to get around the SageMaker UI bug 237 | //instanceCount: cfnParamMap.get("CoreInstanceCount")?.valueAsNumber || 2, 238 | instanceType: cfnParamMap.get("CoreInstanceType")?.valueAsString || "m5.xlarge", 239 | ebsConfiguration: { 240 | ebsBlockDeviceConfigs: [ 241 | { 242 | volumeSpecification: { 243 | sizeInGb: 320, 244 | volumeType: "gp2", 245 | }, 246 | }, 247 | ], 248 | ebsOptimized: true, 249 | }, 250 | market: "ON_DEMAND", 251 | name: "coreNode", 252 | }, 253 | masterInstanceGroup: { 254 | instanceCount: 1, 255 | instanceType: cfnParamMap.get("CoreInstanceType")?.valueAsString || "m5.xlarge", 256 | ebsConfiguration: { 257 | ebsBlockDeviceConfigs: [ 258 | { 259 | volumeSpecification: { 260 | sizeInGb: 320, 261 | volumeType: "gp2", 262 | }, 
263 | }, 264 | ], 265 | ebsOptimized: true, 266 | }, 267 | market: "ON_DEMAND", 268 | name: "mainNode", 269 | }, 270 | terminationProtected: false, 271 | ec2SubnetId: cdk.Fn.importValue(props.ec2SubnetIdExportName), 272 | emrManagedMasterSecurityGroup: mainSG.ref, 273 | emrManagedSlaveSecurityGroup: coreSG.ref, 274 | serviceAccessSecurityGroup: svcSG.ref, 275 | }, 276 | jobFlowRole: jobFlowInstanceProfile.ref, 277 | serviceRole: emrServiceRoleName, 278 | logUri: cdk.Fn.sub("s3://${SampleDataBucket}/logging/", { 279 | "SampleDataBucket": sourceBucket.bucketName 280 | }), 281 | releaseLabel: cfnParamMap.get("EmrReleaseVersion")?.valueAsString || "", 282 | visibleToAllUsers: true, 283 | securityConfiguration: emrSecurityConfigurationName, 284 | autoTerminationPolicy: { 285 | idleTimeout: Number(cfnParamMap.get("AutoTerminationIdleTimout")?.valueAsString) || 3600, 286 | }, 287 | steps: [ 288 | { 289 | actionOnFailure: "CONTINUE", 290 | hadoopJarStep: { 291 | args: [cdk.Fn.sub("s3://${SampleDataBucket}/configurekdc.sh", { 292 | "SampleDataBucket": sourceBucket.bucketName 293 | })], 294 | jar: cdk.Fn.sub("s3://${AWS::Region}.elasticmapreduce/libs/script-runner/script-runner.jar", {}), 295 | mainClass: "" 296 | }, 297 | name: "run any bash or java job in spark", 298 | } 299 | ] 300 | }) 301 | 302 | // Use escape hatches to hack the core instance count 303 | emrCluster.addOverride("Properties.Instances.CoreInstanceGroup.InstanceCount", 304 | cfnParamMap.get("CoreInstanceCount")?.valueAsString || "2") 305 | 306 | emrCluster.node.addDependency(jobFlowInstanceProfile) 307 | 308 | // Note: It would be great if the EMR team deprecated the master/slave terminology. 309 | // Trying to avoid using those terms where we have a choice, preferring main and core. 310 | 311 | // TODO: Cleanup bucket function: Since we can't use assets with ServiceCatalog products, 312 | // and the new S3 L2 auto delete functionality deploys a Lambda function. 313 | // We need to add the bucket cleanup function from the original templates. 314 | // Whether to delete the bucket or not should be configurable. 315 | 316 | new cdk.CfnOutput(this, "emr-main-dns-name-output", { 317 | description: "DNS Name of the EMR Master Node", 318 | value: emrCluster.attrMasterPublicDns, 319 | }) 320 | } 321 | } -------------------------------------------------------------------------------- /lib/filtering-application-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from "aws-cdk-lib" 2 | import { RemovalPolicy } from "aws-cdk-lib" 3 | import { Construct } from "constructs" 4 | import * as ecs from "aws-cdk-lib/aws-ecs" 5 | import * as ec2 from "aws-cdk-lib/aws-ec2" 6 | import * as iam from "aws-cdk-lib/aws-iam" 7 | import * as path from "path" 8 | import * as ssm from "aws-cdk-lib/aws-ssm" 9 | import * as logs from "aws-cdk-lib/aws-logs" 10 | import * as s3 from "aws-cdk-lib/aws-s3" 11 | import monitoringConfig from "./monitoring-config.json" 12 | import * as cloudwatch from "aws-cdk-lib/aws-cloudwatch" 13 | import { NagSuppressions } from "cdk-nag" 14 | import * as ecra from "aws-cdk-lib/aws-ecr-assets" 15 | 16 | 17 | 18 | export interface FilteringStackProps extends cdk.StackProps { 19 | /** 20 | * A reference to the VPC where the MLflow application will be deployed 21 | */ 22 | readonly vpc: ec2.IVpc; 23 | readonly trainingBucket: s3.IBucket; 24 | } 25 | 26 | /** 27 | * Creates the filtering application, which runs on ECS and handles inference. 
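 * The stack provisions an ECS cluster with a single Fargate task definition running three containers: the filtering (inference) server, the advertising-server load client, and a CloudWatch agent sidecar that ships StatsD metrics, plus a CloudWatch dashboard for latency, throughput, and likelihood-to-bid metrics.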
28 | */ 29 | export class FilteringApplicationStack extends cdk.Stack { 30 | constructor(scope: Construct, id: string, props: FilteringStackProps) { 31 | super(scope, id) 32 | const vpc = props.vpc 33 | const trainingBucket = props.trainingBucket 34 | 35 | 36 | 37 | /** 38 | * ============================================================================ 39 | * ========= IAM permissions associated with the ad and filtering server ========== 40 | * ============================================================================ 41 | */ 42 | 43 | const filteringApplicationRole = new iam.Role( this, "FilteringApplicationRole", { 44 | assumedBy: new iam.ServicePrincipal("ecs-tasks.amazonaws.com"), 45 | managedPolicies: [ 46 | iam.ManagedPolicy.fromAwsManagedPolicyName("CloudWatchAgentServerPolicy") 47 | ] 48 | } 49 | ) 50 | 51 | trainingBucket.grantRead(filteringApplicationRole) 52 | 53 | 54 | 55 | 56 | 57 | /** 58 | * ============================================================================ 59 | * ========= SSM Parameters used by the ad and filtering server ========== 60 | * ============================================================================ 61 | */ 62 | 63 | const monitoringConfigSSMParameter = new ssm.StringParameter(this, "MonitoringConfig", { 64 | description: "The configuration for CloudWatch agent", 65 | parameterName: "/aik/monitoring/config", 66 | stringValue: JSON.stringify(monitoringConfig), 67 | tier: ssm.ParameterTier.STANDARD, 68 | } 69 | ) 70 | monitoringConfigSSMParameter.grantRead(filteringApplicationRole) 71 | 72 | const currentRegion = new ssm.StringParameter(this, "SSMParameterCurrentRegion", { 73 | description: "current region where the stack is deployed", 74 | stringValue: cdk.Stack.of(this).region, 75 | parameterName: "/aik/current-region", 76 | tier: ssm.ParameterTier.STANDARD 77 | }) 78 | currentRegion.grantRead(filteringApplicationRole) 79 | 80 | filteringApplicationRole.addToPolicy( new iam.PolicyStatement({ 81 | actions: ["ssm:GetParameter"], 82 | resources:[ 83 | `arn:aws:ssm:${cdk.Stack.of(this).region}:${cdk.Stack.of(this).account}:parameter/aik/xgboost/path`, 84 | `arn:aws:ssm:${cdk.Stack.of(this).region}:${cdk.Stack.of(this).account}:parameter/aik/pipelineModelArtifactPath`, 85 | `arn:aws:ssm:${cdk.Stack.of(this).region}:${cdk.Stack.of(this).account}:parameter/aik/pipelineModelArtifactSchemaPath`, 86 | `arn:aws:ssm:${cdk.Stack.of(this).region}:${cdk.Stack.of(this).account}:parameter/aik/inference_data`, 87 | 88 | ] 89 | })) 90 | 91 | NagSuppressions.addResourceSuppressions(filteringApplicationRole, [ 92 | { 93 | id: "AwsSolutions-IAM4", 94 | reason: "CDK controls the policy, which is configured the way we need it to be (grantRead)" 95 | }], true) 96 | NagSuppressions.addResourceSuppressions(filteringApplicationRole, [ 97 | { 98 | id: "AwsSolutions-IAM5", 99 | reason: "CDK controls the policy, which is configured the way we need it to be (grantRead)" 100 | }], true) 101 | /** 102 | * ================================================================ 103 | * ========= TASK DEFINITION FOR AD SERVER and FILTERING ========== 104 | * ================================================================ 105 | */ 106 | 107 | // Create an ECS cluster 108 | const cluster = new ecs.Cluster(this, "Cluster", { 109 | vpc, 110 | clusterName: "advertising-server", 111 | containerInsights: true 112 | }) 113 | 114 | const taskDefinition = new ecs.FargateTaskDefinition(this, "TaskDef", { 115 | family: "adserver-application", 116 | cpu: 4096, 117 | memoryLimitMiB: 8192, 118 | taskRole: 
filteringApplicationRole 119 | }) 120 | const imageServer = new ecs.AssetImage(path.join(__dirname, "../source/traffic-filtering-app/"), { 121 | file: "Dockerfile.Server", 122 | platform: ecra.Platform.LINUX_AMD64 123 | }) 124 | const imageClient = new ecs.AssetImage(path.join(__dirname, "../source/traffic-filtering-app/"), { 125 | file: "Dockerfile.Client", 126 | platform: ecra.Platform.LINUX_AMD64 127 | }) 128 | 129 | const advertisingServerLogGroup = new logs.LogGroup(this, "advertisingServerLogGroup", { 130 | logGroupName: "/aik/advertising-server", 131 | removalPolicy: RemovalPolicy.DESTROY 132 | }) 133 | const filteringServerLogGroup = new logs.LogGroup(this, "filteringServerLogGroup", { 134 | logGroupName: "/aik/filtering-server", 135 | removalPolicy: RemovalPolicy.DESTROY 136 | }) 137 | const monitoringLogGroup = new logs.LogGroup(this, "monitoringLogGroup", { 138 | logGroupName: "/aik/monitoring", 139 | removalPolicy: RemovalPolicy.DESTROY 140 | }) 141 | 142 | /** 143 | * ================================================================= 144 | * ========= Container definition for the filtering server ========== 145 | * ================================================================= 146 | */ 147 | 148 | 149 | taskDefinition.addContainer("filtering-server", { 150 | containerName: "filtering-server", 151 | image: imageServer, 152 | memoryLimitMiB: 512, 153 | cpu: 1024, 154 | essential: true, 155 | user: "8000", 156 | portMappings: [ 157 | { containerPort: 8080 }, 158 | ], 159 | secrets: { 160 | AWS_REGION: ecs.Secret.fromSsmParameter(currentRegion) 161 | }, 162 | logging: ecs.LogDriver.awsLogs({ 163 | streamPrefix: "filtering-server", 164 | logGroup: filteringServerLogGroup 165 | }) 166 | }) 167 | 168 | NagSuppressions.addResourceSuppressions(cluster, [ 169 | { 170 | id: "AwsSolutions-ECS7", 171 | reason: "Seems to be a false positive as we are configuring logs above" 172 | } 173 | ], true) 174 | 175 | /** 176 | * ========================================================= 177 | * ========= Container definition for the adserver ========== 178 | * ========================================================= 179 | */ 180 | taskDefinition.addContainer("advertising-server", { 181 | containerName: "advertising-server", 182 | image: imageClient, 183 | command: [ 184 | "localhost", "1", "500000" 185 | ], 186 | memoryLimitMiB: 512, 187 | cpu: 1024, 188 | essential: false, 189 | secrets: { 190 | AWS_REGION: ecs.Secret.fromSsmParameter(currentRegion) 191 | }, 192 | logging: ecs.LogDriver.awsLogs({ 193 | streamPrefix: "advertising-server", 194 | logGroup: advertisingServerLogGroup 195 | }), 196 | user: "8000" 197 | }) 198 | 199 | /** 200 | * ======================================================= 201 | * ========= Container definition for monitoring ========== 202 | * ======================================================= 203 | */ 204 | taskDefinition.addContainer("cloudwatch", { 205 | containerName: "cloudwatch", 206 | image: ecs.ContainerImage.fromRegistry("public.ecr.aws/cloudwatch-agent/cloudwatch-agent:latest"), 207 | memoryLimitMiB: 256, 208 | cpu: 256, 209 | essential: true, 210 | secrets: { 211 | CW_CONFIG_CONTENT: ecs.Secret.fromSsmParameter(monitoringConfigSSMParameter) 212 | }, 213 | logging: ecs.LogDriver.awsLogs({ 214 | streamPrefix: "cloudwatch-agent", 215 | logGroup: monitoringLogGroup 216 | }) 217 | }) 218 | 219 | 220 | // Instantiate an Amazon ECS Service 221 | new ecs.FargateService(this, "FilteringService", { 222 | cluster, 223 | taskDefinition, 224 | serviceName: 
"aik-application", 225 | maxHealthyPercent: 100, 226 | minHealthyPercent: 0, 227 | vpcSubnets: { 228 | subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS 229 | }, 230 | desiredCount: 1, 231 | 232 | }) 233 | 234 | /** 235 | * ====================================== 236 | * ========= Metric definition ========== 237 | * ====================================== 238 | */ 239 | const clientSideLatency = new cloudwatch.Metric({ 240 | metricName: "adserver_client_adserver_latency_ms", 241 | namespace: "aik", 242 | unit: cloudwatch.Unit.MILLISECONDS, 243 | statistic: "p99", 244 | period: cdk.Duration.seconds(10), 245 | label: "Advertising server (client)", 246 | dimensionsMap: { 247 | metric_type: "timing" 248 | }, 249 | 250 | 251 | 252 | }) 253 | 254 | const clientSideMetricMs = new cloudwatch.MathExpression({ 255 | expression: "mLatencyClientSide/1000", 256 | usingMetrics: { 257 | "mLatencyClientSide":clientSideLatency 258 | }, 259 | label: "Advertising server (client)", 260 | color: cloudwatch.Color.ORANGE 261 | }) 262 | const serverSideLatency = new cloudwatch.Metric({ 263 | metricName: "filtering_server_filtering_latency", 264 | namespace: "aik", 265 | statistic: "p99", 266 | period: cdk.Duration.seconds(10), 267 | unit: cloudwatch.Unit.MILLISECONDS, 268 | label: "Filtering server (server)", 269 | dimensionsMap: { 270 | metric_type: "timing" 271 | } 272 | 273 | }) 274 | const serverSideMetricMs = new cloudwatch.MathExpression({ 275 | expression: "mLatencyServerSide/1000", 276 | usingMetrics: { 277 | "mLatencyServerSide":serverSideLatency 278 | }, 279 | label: "Filtering server (server)", 280 | color: cloudwatch.Color.BLUE 281 | }) 282 | const totalCountOver10s = new cloudwatch.Metric({ 283 | metricName: "filtering_server_filtering_count", 284 | namespace: "aik", 285 | statistic: "sum", 286 | period: cdk.Duration.seconds(10), 287 | label: "Total number of bid requests over 10 seconds", 288 | dimensionsMap: { 289 | metric_type: "counter" 290 | } 291 | }) 292 | const throughputOver10s = new cloudwatch.MathExpression({ 293 | expression: "filtering_server_filtering_count_10s/10", 294 | usingMetrics: { 295 | "filtering_server_filtering_count_10s":totalCountOver10s 296 | }, 297 | label: "Average throughput over 10s", 298 | color: cloudwatch.Color.BLUE 299 | }) 300 | const totalTransaction = new cloudwatch.Metric({ 301 | metricName: "filtering_server_filtering_count", 302 | namespace: "aik", 303 | statistic: "sum", 304 | period: cdk.Duration.minutes(30), 305 | label: "Total number of bid requests", 306 | dimensionsMap: { 307 | metric_type: "counter" 308 | } 309 | 310 | }) 311 | const likelihoodToBid = new cloudwatch.Metric({ 312 | metricName: "adserver_client_likelihood_to_bid", 313 | namespace: "aik", 314 | statistic: "average", 315 | period: cdk.Duration.seconds(10), 316 | label: "Likelihood to bid", 317 | dimensionsMap: { 318 | metric_type: "gauge" 319 | } 320 | 321 | }) 322 | const dashboard = new cloudwatch.Dashboard(this, "FilteringDashboard", /* all optional props */ { 323 | dashboardName: "Monitoring-Dashboard", 324 | periodOverride: cloudwatch.PeriodOverride.INHERIT, 325 | }) 326 | 327 | const latencyWidget = new cloudwatch.GraphWidget({ 328 | 329 | view: cloudwatch.GraphWidgetView.TIME_SERIES, 330 | 331 | stacked: false, 332 | liveData: true, 333 | rightYAxis: { 334 | showUnits: false 335 | }, 336 | leftYAxis: { 337 | showUnits: false, 338 | label: "Latency in ms", 339 | }, 340 | title: "Filtering Latency (ms)", 341 | legendPosition: cloudwatch.LegendPosition.BOTTOM, 342 | right:[ 343 | 
clientSideMetricMs, 344 | serverSideMetricMs 345 | ], 346 | period: cdk.Duration.seconds(10) 347 | }) 348 | 349 | const throughputWidget = new cloudwatch.GraphWidget({ 350 | 351 | view: cloudwatch.GraphWidgetView.TIME_SERIES, 352 | 353 | stacked: false, 354 | liveData: true, 355 | rightYAxis: { 356 | showUnits: false 357 | }, 358 | leftYAxis: { 359 | showUnits: false, 360 | label: "query per second", 361 | }, 362 | title: "Filtering request per second", 363 | legendPosition: cloudwatch.LegendPosition.BOTTOM, 364 | right:[ 365 | throughputOver10s 366 | ], 367 | period: cdk.Duration.seconds(10) 368 | }) 369 | 370 | const bidRequestWidget = new cloudwatch.SingleValueWidget({ 371 | title: "Number of bid requests submitted", 372 | metrics:[totalTransaction] 373 | }) 374 | const likelihoodWidget = new cloudwatch.GraphWidget({ 375 | 376 | view: cloudwatch.GraphWidgetView.TIME_SERIES, 377 | 378 | stacked: false, 379 | liveData: true, 380 | rightYAxis: { 381 | showUnits: false 382 | }, 383 | leftYAxis: { 384 | showUnits: false 385 | }, 386 | title: "Likelihood analysis", 387 | legendPosition: cloudwatch.LegendPosition.BOTTOM, 388 | right:[ 389 | likelihoodToBid 390 | ], 391 | period: cdk.Duration.seconds(10) 392 | }) 393 | const row = new cloudwatch.Row(latencyWidget,throughputWidget,likelihoodWidget,bidRequestWidget) 394 | dashboard.addWidgets(row) 395 | 396 | } 397 | } 398 | -------------------------------------------------------------------------------- /lib/launch-constraint.ts: -------------------------------------------------------------------------------- 1 | import * as iam from "aws-cdk-lib/aws-iam" 2 | import { NagSuppressions } from "cdk-nag" 3 | import { Construct } from "constructs" 4 | 5 | /** 6 | * Create the launch constraint roles that will be used by Service Catalog 7 | * when a user provisions the EMR product. 8 | * 9 | * @param construct 10 | */ 11 | export function createLaunchConstraint(construct: Construct, boundary: iam.ManagedPolicy): iam.Role { 12 | 13 | // Launch constraint - this is the role that is assumed when a Studio user 14 | // clicks on the button to create a new EMR cluster 15 | const constraint = new iam.Role(construct, "launch-constraint", { 16 | assumedBy: new iam.ServicePrincipal("servicecatalog.amazonaws.com"), 17 | managedPolicies: [ 18 | iam.ManagedPolicy.fromAwsManagedPolicyName("AWSServiceCatalogAdminFullAccess"), 19 | iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonEMRFullAccessPolicy_v2"), 20 | ], 21 | permissionsBoundary: boundary 22 | }) 23 | 24 | constraint.addToPolicy(new iam.PolicyStatement({ 25 | effect: iam.Effect.ALLOW, 26 | actions: ["s3:*"], 27 | resources: ["*"] 28 | })) 29 | 30 | constraint.addToPolicy(new iam.PolicyStatement({ 31 | effect: iam.Effect.ALLOW, 32 | actions: ["sns:Publish"], 33 | resources: ["*"] 34 | })) 35 | 36 | constraint.addToPolicy(new iam.PolicyStatement({ 37 | effect: iam.Effect.ALLOW, 38 | actions: [ 39 | "ec2:CreateSecurityGroup", 40 | "ec2:RevokeSecurityGroupEgress", 41 | "ec2:DeleteSecurityGroup", 42 | "ec2:createTags", 43 | "ec2:AuthorizeSecurityGroupEgress", 44 | "ec2:AuthorizeSecurityGroupIngress", 45 | "ec2:RevokeSecurityGroupIngress" 46 | ], 47 | resources: ["*"] 48 | })) 49 | 50 | constraint.addToPolicy(new iam.PolicyStatement({ 51 | effect: iam.Effect.ALLOW, 52 | actions: [ 53 | "lambda:CreateFunction", 54 | "lambda:InvokeFunction", 55 | "lambda:DeleteFunction", 56 | "lambda:GetFunction"], 57 | resources: ["*"] // TODO - Limit with a tag? 
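// One hedged sketch of how the TODO above could be addressed, assuming the Lambda functions created by the provisioned product carry a project tag (the tag key and value here are hypothetical, not part of this kit): conditions: { StringEquals: { "aws:ResourceTag/project": "aik" } }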
58 | })) 59 | 60 | constraint.addToPolicy(new iam.PolicyStatement({ 61 | effect: iam.Effect.ALLOW, 62 | actions: ["elasticmapreduce:RunJobFlow"], 63 | resources: ["*"] 64 | })) 65 | 66 | constraint.addToPolicy(new iam.PolicyStatement({ 67 | effect: iam.Effect.ALLOW, 68 | actions: [ 69 | "iam:CreateRole", 70 | "iam:DetachRolePolicy", 71 | "iam:AttachRolePolicy", 72 | "iam:DeleteRolePolicy", 73 | "iam:DeleteRole", 74 | "iam:PutRolePolicy", 75 | "iam:PassRole", 76 | "iam:CreateInstanceProfile", 77 | "iam:RemoveRoleFromInstanceProfile", 78 | "iam:DeleteInstanceProfile", 79 | "iam:AddRoleToInstanceProfile" 80 | ], 81 | resources: ["*"] 82 | })) 83 | 84 | constraint.addToPolicy(new iam.PolicyStatement({ 85 | effect: iam.Effect.ALLOW, 86 | actions: [ 87 | "cloudformation:CreateStack", 88 | "cloudformation:DeleteStack", 89 | "cloudformation:DescribeStackEvents", 90 | "cloudformation:DescribeStacks", 91 | "cloudformation:GetTemplateSummary", 92 | "cloudformation:SetStackPolicy", 93 | "cloudformation:ValidateTemplate", 94 | "cloudformation:UpdateStack", 95 | ], 96 | resources: ["*"] 97 | })) 98 | 99 | NagSuppressions.addResourceSuppressions(constraint, [ 100 | { 101 | id: "AwsSolutions-IAM4", 102 | reason: "Provisioning this product requires full access and the managed policy is likely better than '*'" 103 | }, 104 | { 105 | id: "AwsSolutions-IAM5", 106 | reason: "The resources have not been created yet so we can't refer to them here" 107 | } 108 | ], true) 109 | 110 | return constraint 111 | } -------------------------------------------------------------------------------- /lib/monitoring-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "agent": { 3 | "metrics_collection_interval": 5, 4 | "debug": true 5 | }, 6 | "metrics": { 7 | "force_flush_interval": 10, 8 | "namespace": "aik", 9 | "metrics_collected": { 10 | "statsd": { 11 | "metrics_collection_interval":5, 12 | "metrics_aggregation_interval":5 13 | } 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /lib/permission-boundary.ts: -------------------------------------------------------------------------------- 1 | import { Aws, Names } from "aws-cdk-lib" 2 | import * as iam from "aws-cdk-lib/aws-iam" 3 | import { NagSuppressions } from "cdk-nag" 4 | import { Construct } from "constructs" 5 | 6 | /** 7 | * Create a permissions boundary that must be applied to all IAM roles and users created by this deployment and by Service Catalog. 8 | * This permissions boundary is used to avoid unintended permission escalation. 9 | * 10 | * @param construct The parent construct that is the scope of the created permissions boundary. 
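 * @returns The managed policy to attach as a `permissionsBoundary` to every role and user created by this deployment.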
11 | */ 12 | export function createPermissionsBoundary(construct: Construct): iam.ManagedPolicy { 13 | const boundaryPolicyName = `PermissionBoundary${Names.uniqueId(construct)}` 14 | const boundaryPolicyArn = `arn:aws:iam::${Aws.ACCOUNT_ID}:policy/${boundaryPolicyName}` 15 | 16 | const boundary = new iam.ManagedPolicy(construct, "global-permissions-boundary", { 17 | managedPolicyName: boundaryPolicyName, 18 | statements: [ 19 | // disable access to cost and billing 20 | new iam.PolicyStatement({ 21 | effect: iam.Effect.DENY, 22 | actions: [ 23 | "account:*", 24 | "aws-portal:*", 25 | "savingsplans:*", 26 | "cur:*", 27 | "ce:*"], 28 | resources: ["*"], 29 | }), 30 | 31 | // disable permissions boundary policy edit 32 | new iam.PolicyStatement({ 33 | effect: iam.Effect.DENY, 34 | actions: [ 35 | "iam:DeletePolicy", 36 | "iam:DeletePolicyVersion", 37 | "iam:CreatePolicyVersion", 38 | "iam:SetDefaultPolicyVersion"], 39 | resources: [boundaryPolicyArn], 40 | }), 41 | 42 | // disable removal of permissions boundary from a user/role 43 | new iam.PolicyStatement({ 44 | effect: iam.Effect.DENY, 45 | actions: [ 46 | "iam:DeleteUserPermissionsBoundary", 47 | "iam:DeleteRolePermissionsBoundary"], 48 | resources: [ 49 | `arn:aws:iam::${Aws.ACCOUNT_ID}:user/*`, 50 | `arn:aws:iam::${Aws.ACCOUNT_ID}:role/*` 51 | ], 52 | conditions: { 53 | "StringEquals": { 54 | "iam:PermissionsBoundary": boundaryPolicyArn 55 | } 56 | } 57 | }), 58 | 59 | // disable assigning a different permissions boundary 60 | new iam.PolicyStatement({ 61 | effect: iam.Effect.DENY, 62 | actions: [ 63 | "iam:PutUserPermissionsBoundary", 64 | "iam:PutRolePermissionsBoundary"], 65 | resources: [ 66 | `arn:aws:iam::${Aws.ACCOUNT_ID}:user/*`, 67 | `arn:aws:iam::${Aws.ACCOUNT_ID}:role/*` 68 | ], 69 | conditions: { 70 | "StringNotEquals": { 71 | "iam:PermissionsBoundary": boundaryPolicyArn 72 | } 73 | } 74 | }), 75 | 76 | // disable creating users/roles without the required permissions boundary 77 | new iam.PolicyStatement({ 78 | effect: iam.Effect.DENY, 79 | actions: [ 80 | "iam:CreateUser", 81 | "iam:CreateRole"], 82 | resources: [ 83 | `arn:aws:iam::${Aws.ACCOUNT_ID}:user/*`, 84 | `arn:aws:iam::${Aws.ACCOUNT_ID}:role/*` 85 | ], 86 | conditions: { 87 | "StringNotEquals": { 88 | "iam:PermissionsBoundary": boundaryPolicyArn 89 | } 90 | } 91 | }), 92 | 93 | // only allow passing this account's roles to select services 94 | new iam.PolicyStatement({ 95 | effect: iam.Effect.DENY, 96 | actions: ["iam:PassRole"], 97 | resources: [`arn:aws:iam::${Aws.ACCOUNT_ID}:role/*`], 98 | conditions: { 99 | "StringNotEquals": { 100 | "iam:PassedToService": [ 101 | "ec2.amazonaws.com", 102 | "elasticmapreduce.amazonaws.com", 103 | "application-autoscaling.amazonaws.com", 104 | "servicecatalog.amazonaws.com", 105 | "sagemaker.amazonaws.com" 106 | ] 107 | } 108 | } 109 | }), 110 | 111 | // default - full access 112 | new iam.PolicyStatement({ 113 | effect: iam.Effect.ALLOW, 114 | actions: ["*"], 115 | resources: ["*"], 116 | }) 117 | ], 118 | }) 119 | 120 | NagSuppressions.addResourceSuppressions(boundary, [ 121 | { 122 | id: "AwsSolutions-IAM5", 123 | reason: "We need to grant the role passing permission, but we don't know the role ARNs at this point since it's a permission boundary" 124 | } 125 | ], false) 126 | 127 | return boundary 128 | } 129 | -------------------------------------------------------------------------------- /lib/product-construct.ts: -------------------------------------------------------------------------------- 1 | import { Construct } 
from "constructs" 2 | import * as cdk from "aws-cdk-lib" 3 | import * as sc from "@aws-cdk/aws-servicecatalog-alpha" 4 | import { EmrStack, EmrStackProps } from "./emr-product-stack" 5 | import * as ec2 from "aws-cdk-lib/aws-ec2" 6 | import * as s3 from "aws-cdk-lib/aws-s3" 7 | import * as sagemaker from "aws-cdk-lib/aws-sagemaker" 8 | import { addIngress } from "./add-ingress" 9 | import { IVpc } from "aws-cdk-lib/aws-ec2" 10 | import { BucketDeployment, Source } from "aws-cdk-lib/aws-s3-deployment" 11 | import { CfnOutput } from "aws-cdk-lib" 12 | import { NagSuppressions } from "cdk-nag" 13 | import { createLaunchConstraint } from "./launch-constraint" 14 | import { createPermissionsBoundary } from "./permission-boundary" 15 | import { createSagemakerExecutionRole } from "./sagemaker-execution" 16 | import * as ssm from "aws-cdk-lib/aws-ssm" 17 | import * as iam from "aws-cdk-lib/aws-iam" 18 | import * as cr from "aws-cdk-lib/custom-resources" 19 | import * as lambda from "aws-cdk-lib/aws-lambda" 20 | import * as emr from "aws-cdk-lib/aws-emr" 21 | import * as kms from "aws-cdk-lib/aws-kms" 22 | 23 | 24 | /** 25 | * This construct creates a VPC, a SageMaker domain, and a Service Catalog product. 26 | * 27 | * The service catalog product template is created in `emr-product-stack.ts`. 28 | */ 29 | export class ProductConstruct extends Construct { 30 | 31 | public vpc: IVpc 32 | public trainingBucket: s3.IBucket 33 | 34 | constructor(scope: Construct, id: string) { 35 | super(scope, id) 36 | 37 | const self = this 38 | 39 | // the global permissions boundary applied to all user/roles that will be created 40 | const permissionsBoundary = createPermissionsBoundary(this) 41 | 42 | // Create the VPC 43 | this.vpc = new ec2.Vpc(this, "vpc", {}) 44 | this.vpc.addFlowLog("flow-log-cw", { 45 | trafficType: ec2.FlowLogTrafficType.REJECT 46 | }) 47 | 48 | NagSuppressions.addResourceSuppressions(this.vpc, [ 49 | {id: "AwsSolutions-EC23", reason: "cdk-nag can't read Fn::GetAtt"} 50 | ], true) 51 | 52 | // Create a bucket to hold access logs for the other buckets 53 | const logBucket = new s3.Bucket(this, "access-logs", { 54 | encryption: s3.BucketEncryption.S3_MANAGED, 55 | autoDeleteObjects: true, 56 | objectOwnership: s3.ObjectOwnership.BUCKET_OWNER_ENFORCED, 57 | removalPolicy: cdk.RemovalPolicy.DESTROY, 58 | blockPublicAccess: s3.BlockPublicAccess.BLOCK_ALL, 59 | enforceSSL: true, 60 | versioned: true 61 | }) 62 | 63 | // Suppress nag error since this is the log bucket itself. 
64 | NagSuppressions.addResourceSuppressions(logBucket, [ 65 | { id: "AwsSolutions-S1", reason: "This is the log bucket" }, 66 | ]) 67 | 68 | // Create a bucket to be used to hold training data and trained models 69 | this.trainingBucket = new s3.Bucket(this, "training-models", { 70 | encryption: s3.BucketEncryption.S3_MANAGED, 71 | autoDeleteObjects: true, 72 | removalPolicy: cdk.RemovalPolicy.DESTROY, 73 | serverAccessLogsBucket: logBucket, 74 | serverAccessLogsPrefix: "training-models", 75 | blockPublicAccess: s3.BlockPublicAccess.BLOCK_ALL, 76 | enforceSSL: true, 77 | versioned: true 78 | }) 79 | 80 | new ssm.StringParameter(this, "Parameter", { 81 | allowedPattern: ".*", 82 | description: "The bucket which holds data and ML models", 83 | parameterName: "/aik/data-bucket", 84 | stringValue: this.trainingBucket.bucketName, 85 | tier: ssm.ParameterTier.ADVANCED, 86 | }) 87 | 88 | 89 | // Add VPC endpoints 90 | 91 | const addEndpoint = function (s: string) { 92 | self.vpc.addInterfaceEndpoint("ep-" + s.split(".").join("-"), { 93 | privateDnsEnabled: true, 94 | service: new ec2.InterfaceVpcEndpointService( 95 | cdk.Fn.sub("com.amazonaws.${AWS::Region}." + s)) 96 | }) 97 | } 98 | 99 | addEndpoint("sagemaker.api") 100 | addEndpoint("sagemaker.runtime") 101 | addEndpoint("sts") 102 | addEndpoint("monitoring") 103 | addEndpoint("logs") 104 | addEndpoint("ecr.dkr") 105 | addEndpoint("ecr.api") 106 | 107 | const ec2SubnetIdExportName = "subnet-id" 108 | const ec2VpcIdExportName = "vpc-id" 109 | 110 | new cdk.CfnOutput(this, "vpcidout", { 111 | description: "VPC Id output export", 112 | value: this.vpc.vpcId, 113 | exportName: ec2VpcIdExportName, 114 | }) 115 | 116 | new cdk.CfnOutput(this, "subnetidout", { 117 | description: "Subnet Id output export", 118 | value: this.vpc.privateSubnets[0].subnetId, 119 | exportName: ec2SubnetIdExportName 120 | }) 121 | 122 | const sageMakerSGIdExportName = "sg-id" 123 | 124 | // Create the S3 bucket to hold the EMR bootstrap scripts 125 | const emrBootstrapBucket = new s3.Bucket(this, "emr-bootstrap", { 126 | encryption: s3.BucketEncryption.S3_MANAGED, 127 | removalPolicy: cdk.RemovalPolicy.DESTROY, 128 | autoDeleteObjects: true, 129 | serverAccessLogsPrefix: "emr-bootstrap", 130 | serverAccessLogsBucket: logBucket, 131 | blockPublicAccess: s3.BlockPublicAccess.BLOCK_ALL, 132 | enforceSSL: true, 133 | versioned: true 134 | }) 135 | 136 | // Deploy the EMR shell scripts to the bucket 137 | new BucketDeployment(this, "emr-boostrap-deploy", { 138 | destinationBucket: emrBootstrapBucket, 139 | sources: [Source.asset("source/emr-bootstrap")] 140 | }) 141 | 142 | // Service catalog portfolio 143 | const portfolio = new sc.Portfolio(this, "sagemaker-emr-portfolio", { 144 | displayName: "SageMaker EMR Product Portfolio", 145 | providerName: "AWS", 146 | }) 147 | 148 | const emrBucketExportName = "sagemaker-emr-bootstrap-bucket" 149 | 150 | // Export the name of the bucket 151 | new CfnOutput(this, "emr-bootstrap-bucket-out", { 152 | value: emrBootstrapBucket.bucketName, 153 | exportName: emrBucketExportName, 154 | description: "The name of the bucket where we put EMR bootstrap scripts" 155 | }) 156 | 157 | 158 | 159 | const jobFlowRole = new iam.Role(this, "job-flow", { 160 | assumedBy: new iam.ServicePrincipal("ec2.amazonaws.com"), 161 | managedPolicies: [ 162 | iam.ManagedPolicy.fromAwsManagedPolicyName( 163 | "service-role/AmazonElasticMapReduceforEC2Role"), 164 | iam.ManagedPolicy.fromAwsManagedPolicyName( 165 | "AmazonSSMManagedInstanceCore" 166 | ) 167 | ], 168 
| permissionsBoundary: permissionsBoundary 169 | }) 170 | jobFlowRole.addToPolicy(new iam.PolicyStatement({ 171 | effect: iam.Effect.ALLOW, 172 | actions: [ 173 | "ssm:GetParameters", 174 | "ssm:GetParameter", 175 | ], 176 | resources: [ 177 | cdk.Fn.sub("arn:aws:ssm:${AWS::Region}:${AWS::AccountId}:*") 178 | ] 179 | })) 180 | NagSuppressions.addResourceSuppressions(jobFlowRole, [ 181 | { 182 | id: "AwsSolutions-IAM4", 183 | reason: "We don't know the name of the resource that will be created later by the cloud formation stack so we have to use managed policy" 184 | }], true) 185 | NagSuppressions.addResourceSuppressions(jobFlowRole, [ 186 | { 187 | id: "AwsSolutions-IAM5", 188 | reason: "The name of the SSM parameters are not know" 189 | }], true) 190 | 191 | emrBootstrapBucket.grantRead(jobFlowRole) 192 | 193 | const emrEc2RoleNameExportName = "emr-ec2-role-arn" 194 | new cdk.CfnOutput(this, "emr-ec2-role", { 195 | description: "The ec2 role of the EMR cluster used service catalog", 196 | value: jobFlowRole.roleName, 197 | exportName: emrEc2RoleNameExportName 198 | }) 199 | 200 | 201 | const serviceRole = new iam.CfnRole(this, "service-role", { 202 | assumeRolePolicyDocument: { 203 | Statement: [{ 204 | Action: [ 205 | "sts:AssumeRole" 206 | ], 207 | Effect: "Allow", 208 | Principal: { 209 | Service: [ 210 | "elasticmapreduce.amazonaws.com" 211 | ] 212 | } 213 | }], 214 | Version: "2012-10-17", 215 | }, 216 | managedPolicyArns: [ 217 | "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceRole" 218 | ], 219 | path: "/", 220 | permissionsBoundary: permissionsBoundary.managedPolicyArn 221 | }) 222 | NagSuppressions.addResourceSuppressions(serviceRole, [ 223 | { 224 | id: "AwsSolutions-IAM4", 225 | reason: "We don't know the name of the resource that will be created later by the cloud formation stack so we have to use managed policy" 226 | }], true) 227 | const emrServiceRoleNameExportName = "emr-service-role-arn" 228 | new cdk.CfnOutput(this, "emr-service-role", { 229 | description: "The service role of the EMR cluster used service catalog", 230 | value: serviceRole.attrArn, 231 | exportName: emrServiceRoleNameExportName 232 | }) 233 | 234 | const emrSecurityConfigurationName = "aik-emr-security-configuration" 235 | const ebsEncryptionKey = new kms.Key(this,'ebs-ebcryption-key', { 236 | alias: "aik/ebs", 237 | removalPolicy: cdk.RemovalPolicy.DESTROY, 238 | pendingWindow: cdk.Duration.days(7), 239 | enabled: true, 240 | enableKeyRotation: true, 241 | admins: [new iam.AccountPrincipal(cdk.Stack.of(this).account)], 242 | policy: new iam.PolicyDocument({ 243 | statements:[ 244 | new iam.PolicyStatement({ 245 | principals: [ 246 | new iam.ArnPrincipal(jobFlowRole.roleArn), 247 | new iam.ArnPrincipal(serviceRole.attrArn) 248 | ], 249 | actions: [ 250 | "kms:Encrypt", 251 | "kms:Decrypt", 252 | "kms:ReEncrypt*", 253 | "kms:GenerateDataKey*", 254 | "kms:DescribeKey" 255 | ], 256 | resources: ["*"] 257 | }), 258 | new iam.PolicyStatement({ 259 | principals: [ 260 | new iam.ArnPrincipal(jobFlowRole.roleArn), 261 | new iam.ArnPrincipal(serviceRole.attrArn) 262 | ], 263 | actions: [ 264 | "kms:CreateGrant", 265 | "kms:ListGrants", 266 | "kms:RevokeGrant" 267 | ], 268 | resources: ["*"], 269 | conditions: { 270 | "Bool": { 271 | "kms:GrantIsForAWSResource": true 272 | } 273 | } 274 | }) 275 | ] 276 | }) 277 | }) 278 | 279 | 280 | 281 | 282 | const securityConfiguration = { 283 | "EncryptionConfiguration": { 284 | "EnableInTransitEncryption": false, 285 | "EnableAtRestEncryption": true, 286 | 
"AtRestEncryptionConfiguration": { 287 | "LocalDiskEncryptionConfiguration": { 288 | "EnableEbsEncryption" : true, 289 | "EncryptionKeyProviderType": "AwsKms", 290 | "AwsKmsKey": ebsEncryptionKey.keyArn 291 | } 292 | } 293 | } 294 | } 295 | new emr.CfnSecurityConfiguration(this, 'MyCfnSecurityConfiguration', { 296 | securityConfiguration: securityConfiguration, 297 | 298 | // the properties below are optional 299 | name: emrSecurityConfigurationName, 300 | }); 301 | 302 | const emrSecurityConfigurationNameExportName = "emr-security-configuration-name" 303 | new cdk.CfnOutput(this, "emr-security-configuration-name", { 304 | description: "Name of the security configuration for the EMR cluster", 305 | value: emrSecurityConfigurationName, 306 | exportName: emrSecurityConfigurationNameExportName 307 | }) 308 | 309 | // Service catalog product 310 | const emrStackProps: EmrStackProps = { 311 | ec2SubnetIdExportName, 312 | ec2VpcIdExportName, 313 | sageMakerSGIdExportName, 314 | emrBucketExportName, 315 | emrEc2RoleNameExportName, 316 | emrServiceRoleNameExportName, 317 | emrSecurityConfigurationNameExportName 318 | } 319 | 320 | // TODO: EMR cluster auto shutdown after inactivity 321 | 322 | const emrStack = new EmrStack(this, "EmrProduct", emrStackProps) 323 | const template = sc.CloudFormationTemplate.fromProductStack(emrStack) 324 | const product = new sc.CloudFormationProduct(this, "sagemaker-emr-product", { 325 | productName: "SageMaker EMR Product", 326 | owner: "AWS", 327 | productVersions: [ 328 | { 329 | productVersionName: "v1", 330 | cloudFormationTemplate: template, 331 | }, 332 | ], 333 | }) 334 | 335 | // This tag is what makes the template visible from SageMaker Studio 336 | cdk.Tags.of(product).add("sagemaker:studio-visibility:emr", "true") 337 | 338 | // Associate the product with the portfolio 339 | portfolio.addProduct(product) 340 | 341 | // Create the constraint role 342 | const constraint = createLaunchConstraint(this, permissionsBoundary) 343 | 344 | portfolio.setLaunchRole(product, constraint) 345 | 346 | const vpceSG = new ec2.SecurityGroup(this, id + "vpc-ep-sg", { 347 | description: "Allow TLS for VPC endpoint", 348 | vpc: this.vpc, 349 | }) 350 | 351 | const smSG = new ec2.SecurityGroup(this, id + "sm-sg", { 352 | vpc: this.vpc 353 | }) 354 | 355 | addIngress(this, "sm-sm", smSG, smSG, "-1", false) 356 | addIngress(this, "sm-smtcp", smSG, smSG, "tcp", false) 357 | addIngress(this, "sm-vpce", vpceSG, smSG, "-1", false) 358 | 359 | new cdk.CfnOutput(this, id + "smsgidout", { 360 | description: "SageMaker security group id output export", 361 | value: smSG.securityGroupId, 362 | exportName: sageMakerSGIdExportName 363 | }) 364 | 365 | const smExecRole = createSagemakerExecutionRole(this, this.trainingBucket,permissionsBoundary) 366 | 367 | // Allow SageMaker to read and write SystemsManager ParameterStore 368 | smExecRole.addToPolicy(new iam.PolicyStatement({ 369 | effect: iam.Effect.ALLOW, 370 | actions: [ 371 | "ssm:PutParameter", 372 | "ssm:GetParameters", 373 | "ssm:GetParameter", 374 | ], 375 | resources: [ 376 | cdk.Fn.sub("arn:aws:ssm:${AWS::Region}:${AWS::AccountId}:*") 377 | ] 378 | })) 379 | 380 | // Allow SageMaker to access the training data bucket 381 | this.trainingBucket.grantReadWrite(smExecRole) 382 | 383 | // Principal association 384 | portfolio.giveAccessToRole(smExecRole) 385 | 386 | // SageMaker domain 387 | const domain = new sagemaker.CfnDomain(this, "domain", { 388 | appNetworkAccessType: "VpcOnly", 389 | authMode: "IAM", 390 | 
defaultUserSettings: { 391 | executionRole: smExecRole.roleArn, 392 | securityGroups: [smSG.securityGroupId], 393 | }, 394 | domainName: "CDKSample", 395 | vpcId: this.vpc.vpcId, 396 | subnetIds: [this.vpc.privateSubnets[0].subnetId] 397 | }) 398 | 399 | new cdk.CfnOutput(this, "efs-out", { 400 | value: domain.attrHomeEfsFileSystemId, 401 | description: "File system id to be deleted before the stack can be deleted" 402 | }) 403 | 404 | // Custom resource to clean up automatically created security groups 405 | 406 | // Create the lambda handler for the resource 407 | const onEvent = new lambda.Function(this, "sagemaker-sg-cleanup-fn", { 408 | runtime: lambda.Runtime.NODEJS_16_X, 409 | code: lambda.Code.fromAsset("./build/source/sagemaker-sg-cleanup"), 410 | handler: "handler.handler", 411 | memorySize: 1536, 412 | timeout: cdk.Duration.minutes(5), 413 | description: "SageMaker security group cleanup", 414 | environment: { 415 | "VPC_ID": this.vpc.vpcId, 416 | "EFS_ID": domain.attrHomeEfsFileSystemId 417 | } 418 | }) 419 | 420 | NagSuppressions.addResourceSuppressions(onEvent, [ 421 | { 422 | id: "AwsSolutions-IAM4", 423 | reason: "CDK controls the policy, which is configured the way we need it to be" 424 | }, 425 | { 426 | id: "AwsSolutions-L1", 427 | reason: "This Lambda function is only using when destroying the stack" 428 | }, 429 | ], true) 430 | 431 | // Grant the lambda permissions to describe and delete vpc, efs 432 | onEvent.addToRolePolicy(new iam.PolicyStatement({ 433 | effect: iam.Effect.ALLOW, 434 | resources: ["*"], 435 | actions: [ 436 | "ec2:DescribeSecurityGroups", 437 | "ec2:DescribeSecurityGroupRules", 438 | "ec2:DeleteSecurityGroup", 439 | "ec2:DeleteNetworkInterface", 440 | "ec2:RevokeSecurityGroupIngress", 441 | "ec2:RevokeSecurityGroupEgress", 442 | "elasticfilesystem:DescribeMountTargets", 443 | "elasticfilesystem:DeleteMountTarget", 444 | "elasticfilesystem:DeleteFileSystem", 445 | ] 446 | })) 447 | 448 | NagSuppressions.addResourceSuppressions(onEvent, [ 449 | { 450 | id: "AwsSolutions-IAM5", 451 | reason: "We need a '*' here because we are deleting resources that we did not create, so we don't control the names" 452 | }], true) 453 | 454 | 455 | // Create a provider 456 | const provider = new cr.Provider(this, "sagemaker-sg-cleanup-pr", { 457 | onEventHandler: onEvent 458 | }) 459 | 460 | // Create the custom resource 461 | const customResource = new cdk.CustomResource(this, "sagemaker-sg-cleanup-cr", { 462 | serviceToken: provider.serviceToken 463 | }) 464 | 465 | // Add a dependency on the domain so the lambda deletes first 466 | customResource.node.addDependency(domain) 467 | 468 | // Sagemaker user profile 469 | new sagemaker.CfnUserProfile(this, "user-profile", { 470 | domainId: domain.attrDomainId, 471 | userProfileName: "cdk-studio-user", 472 | userSettings: { 473 | executionRole: smExecRole.roleArn 474 | } 475 | }) 476 | 477 | } 478 | } 479 | -------------------------------------------------------------------------------- /lib/sagemaker-emr-stack.ts: -------------------------------------------------------------------------------- 1 | import {Stack, StackProps} from "aws-cdk-lib" 2 | import {IVpc} from "aws-cdk-lib/aws-ec2" 3 | import {Construct} from "constructs" 4 | import {ProductConstruct} from "./product-construct" 5 | import {IBucket} from "aws-cdk-lib/aws-s3"; 6 | 7 | /** 8 | * This stack deploys a VPC, a SageMaker domain, and a ServiceCatalog product. 
9 | * 10 | * The Service Catalog product allows SageMaker studio users to deploy an EMR 11 | * cluster with sample data to demonstrate interacting with the cluster from 12 | * SageMaker studio. 13 | */ 14 | export class SagemakerEmrStack extends Stack { 15 | 16 | public vpc: IVpc 17 | public trainingBucket: IBucket 18 | 19 | constructor(scope: Construct, id: string, props?: StackProps) { 20 | super(scope, id, props) 21 | 22 | const product = new ProductConstruct(this, "sagemaker-emr-product") 23 | 24 | this.vpc = product.vpc 25 | this.trainingBucket = product.trainingBucket 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /lib/sagemaker-execution.ts: -------------------------------------------------------------------------------- 1 | import * as iam from "aws-cdk-lib/aws-iam" 2 | import * as s3 from "aws-cdk-lib/aws-s3" 3 | import { NagSuppressions } from "cdk-nag" 4 | import { Construct } from "constructs" 5 | import * as cdk from "aws-cdk-lib" 6 | 7 | /** 8 | * Create the role that SageMaker will use for user actions. 9 | * 10 | * @param construct 11 | * @returns 12 | */ 13 | export function createSagemakerExecutionRole(construct:Construct, 14 | trainingBucket: s3.IBucket, 15 | boundary: iam.ManagedPolicy): iam.Role { 16 | 17 | const smExecRole = new iam.Role(construct, "sm-exec", { 18 | assumedBy: new iam.ServicePrincipal("sagemaker.amazonaws.com"), 19 | managedPolicies: [ 20 | iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonSageMakerFullAccess"), 21 | ], 22 | permissionsBoundary: boundary 23 | }) 24 | 25 | trainingBucket.grantReadWrite(smExecRole) 26 | 27 | smExecRole.addToPolicy(new iam.PolicyStatement({ 28 | effect: iam.Effect.ALLOW, 29 | actions: [ 30 | "ssm:PutParameter", 31 | "ssm:GetParameters", 32 | "ssm:GetParameter", 33 | ], 34 | resources: [ 35 | cdk.Fn.sub("arn:aws:ssm:${AWS::Region}:${AWS::AccountId}:*") 36 | ] 37 | })) 38 | 39 | smExecRole.addToPolicy(new iam.PolicyStatement({ 40 | effect: iam.Effect.ALLOW, 41 | actions: [ 42 | "elasticmapreduce:ListInstances", 43 | "elasticmapreduce:DescribeCluster", 44 | "elasticmapreduce:DescribeSecurityConfiguration", 45 | "elasticmapreduce:CreatePersistentAppUI", 46 | "elasticmapreduce:DescribePersistentAppUI", 47 | "elasticmapreduce:GetPersistentAppUIPresignedURL", 48 | "elasticmapreduce:GetOnClusterAppUIPresignedURL", 49 | "elasticmapreduce:ListClusters", 50 | "iam:CreateServiceLinkedRole", 51 | "iam:GetRole", 52 | ], 53 | resources: ["*"] 54 | })) 55 | 56 | smExecRole.addToPolicy(new iam.PolicyStatement({ 57 | effect: iam.Effect.ALLOW, 58 | actions: ["iam:PassRole"], 59 | resources: ["*"], 60 | conditions: { 61 | "StringEquals": {"iam:PassedToService": "sagemaker.amazonaws.com"} 62 | } 63 | })) 64 | 65 | smExecRole.addToPolicy(new iam.PolicyStatement({ 66 | effect: iam.Effect.ALLOW, 67 | actions: [ 68 | "elasticmapreduce:DescribeCluster", 69 | "elasticmapreduce:ListInstanceGroups"], 70 | resources: [cdk.Fn.sub("arn:${AWS::Partition}:elasticmapreduce:*:*:cluster/*")] 71 | })) 72 | 73 | smExecRole.addToPolicy(new iam.PolicyStatement({ 74 | effect: iam.Effect.ALLOW, 75 | actions: [ 76 | "elasticmapreduce:ListClusters"], 77 | resources: ["*"] 78 | })) 79 | 80 | NagSuppressions.addResourceSuppressions(smExecRole, [ 81 | { 82 | id: "AwsSolutions-IAM4", 83 | reason: "Domain users require full access and the managed policy is likely better than '*'" 84 | }, 85 | { 86 | id: "AwsSolutions-IAM5", 87 | reason: "The resources have not been created yet so we can't refer to them here" 88 | } 89 
| ], true) 90 | 91 | return smExecRole 92 | } -------------------------------------------------------------------------------- /lib/suppress-nag.ts: -------------------------------------------------------------------------------- 1 | import { Stack } from "aws-cdk-lib" 2 | import { NagSuppressions } from "cdk-nag" 3 | 4 | /** 5 | * Suppress cdk-nag warnings. 6 | * 7 | * @param stack 8 | */ 9 | export function suppressNag(stack:Stack) { 10 | 11 | NagSuppressions.addResourceSuppressionsByPath(stack, 12 | "/sagemaker-emr/Custom::CDKBucketDeployment8693BB64968944B69AAFB0CC9EB8756C/ServiceRole/DefaultPolicy/Resource", 13 | [ 14 | { 15 | id: "AwsSolutions-IAM4", 16 | reason: "Controlled by CDK L2 Construct" 17 | }, 18 | { 19 | id: "AwsSolutions-IAM5", 20 | reason: "Controlled by CDK L2 Construct" 21 | } 22 | ]) 23 | 24 | NagSuppressions.addResourceSuppressionsByPath(stack, 25 | "/sagemaker-emr/Custom::CDKBucketDeployment8693BB64968944B69AAFB0CC9EB8756C/ServiceRole/Resource", 26 | [ 27 | { 28 | id: "AwsSolutions-IAM4", 29 | reason: "Controlled by CDK L2 Construct" 30 | }, 31 | { 32 | id: "AwsSolutions-IAM5", 33 | reason: "Controlled by CDK L2 Construct" 34 | } 35 | ]) 36 | 37 | NagSuppressions.addResourceSuppressionsByPath(stack, 38 | "/sagemaker-emr/sagemaker-emr-product/sagemaker-sg-cleanup-pr/framework-onEvent/ServiceRole/Resource", 39 | [ 40 | { 41 | id: "AwsSolutions-IAM4", 42 | reason: "Controlled by CDK L2 Construct" 43 | }, 44 | { 45 | id: "AwsSolutions-IAM5", 46 | reason: "Controlled by CDK L2 Construct" 47 | } 48 | ]) 49 | 50 | NagSuppressions.addResourceSuppressionsByPath(stack, 51 | "/sagemaker-emr/sagemaker-emr-product/sagemaker-sg-cleanup-pr/framework-onEvent/ServiceRole/DefaultPolicy/Resource", 52 | [ 53 | { 54 | id: "AwsSolutions-IAM5", 55 | reason: "Controlled by CDK L2 Construct", 56 | appliesTo: ["Resource:::*"] 57 | } 58 | ]) 59 | 60 | NagSuppressions.addResourceSuppressionsByPath(stack, 61 | "/sagemaker-emr/sagemaker-emr-product/sagemaker-sg-cleanup-pr/framework-onEvent/Resource", 62 | [ 63 | { 64 | id: "AwsSolutions-L1", 65 | reason: "Controlled by CDK L2 Construct" 66 | } 67 | ]) 68 | 69 | NagSuppressions.addResourceSuppressionsByPath(stack, 70 | "/sagemaker-emr/Custom::CDKBucketDeployment8693BB64968944B69AAFB0CC9EB8756C/Resource", 71 | [ 72 | { 73 | id: "AwsSolutions-L1", 74 | reason: "Controlled by CDK L2 Construct" 75 | } 76 | ]) 77 | } 78 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "advertising-intelligence-kit", 3 | "version": "0.1.0", 4 | "bin": { 5 | "advertising-intelligence-kit": "bin/advertising-intelligence-kit.js" 6 | }, 7 | "scripts": { 8 | "build": "eslint && tsc && cdk synth", 9 | "watch": "tsc -w", 10 | "test": "jest", 11 | "cdk": "cdk", 12 | "nag": "cdk synth --app='npx ts-node test/nag.ts'" 13 | }, 14 | "devDependencies": { 15 | "@aws-cdk/aws-servicecatalog-alpha": "^2.22.0-alpha.0", 16 | "@aws-sdk/client-ec2": "^3.552.0", 17 | "@types/fs-extra": "^11.0.4", 18 | "@types/node": "^20.12.7", 19 | "@typescript-eslint/eslint-plugin": "^5.35.1", 20 | "@typescript-eslint/parser": "^5.35.1", 21 | "aws-cdk": "^2.137.0", 22 | "aws-sdk": "^2.1599.0", 23 | "cdk-nag": "2.28.90", 24 | "constructs": "^10.3.0", 25 | "eslint": "^8.23.0", 26 | "jest": "^29.7.0", 27 | "ts-node": "^10.0.4", 28 | "typescript": "^5.4.5", 29 | "@types/babel__traverse": "7.18.5" 30 | }, 31 | "dependencies": { 32 | "amazon-s3-uri": "^0.1.1", 33 | 
"aws-cdk-lib": "^2.137.0", 34 | "constructs": "^10.3.0", 35 | "fs-extra": "^11.0.4", 36 | "prettier": "^2.7.1", 37 | "source-map-support": "^0.5.16", 38 | "ts-jest": "^29.1.2" 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /source/emr-bootstrap/Readme.md: -------------------------------------------------------------------------------- 1 | # Explanation 2 | 3 | This folder contains scripts used by EMR to run **Steps** or **Bootstrap actions**. 4 | 5 | The files are referred to in the `emr-product-stack.ts`. You can find the result of script execution in the Amazon EMR console, under **Steps** and **Bootstrap actions** tabs. 6 | -------------------------------------------------------------------------------- /source/emr-bootstrap/configurekdc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #Add a principal to the KDC for the master node, using the master node's returned host name 3 | sudo kadmin.local -q "ktadd -k /etc/krb5.keytab host/`hostname -f`" 4 | -------------------------------------------------------------------------------- /source/emr-bootstrap/installpylibs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sudo yum install -y python3-devel 3 | sudo yum install -y libtiff-devel libjpeg-devel libzip-devel freetype-devel lcms2-devel libwebp-devel tcl-devel tk-devel 4 | 5 | #install python modules 6 | sudo /usr/bin/python3 -m pip install -U cython==0.29.24 7 | sudo /usr/bin/python3 -m pip install -U setuptools==58.1.0 8 | sudo /usr/bin/python3 -m pip install -U numpy==1.21.2 9 | sudo /usr/bin/python3 -m pip install -U matplotlib==3.4.3 10 | sudo /usr/bin/python3 -m pip install -U requests==2.26.0 11 | sudo /usr/bin/python3 -m pip install -U boto3==1.18.63 12 | sudo /usr/bin/python3 -m pip install -U pandas==1.2.5 13 | 14 | # required to parse useragent 15 | sudo /usr/bin/python3 -m pip install -U woothee==1.10.1 16 | # remaining for mleap 17 | sudo /usr/bin/python3 -m pip install -U pybind11 18 | sudo /usr/bin/python3 -m pip install -U pythran 19 | sudo /usr/bin/python3 -m pip install -U scipy 20 | sudo /usr/bin/python3 -m pip install -U mleap==0.17 21 | 22 | 23 | # intall mleap on Master node 24 | if grep isMaster /mnt/var/lib/info/instance.json | grep false; 25 | then 26 | echo "This is not master node, do nothing,exiting" 27 | exit 0 28 | fi 29 | echo "This is master, continuing to execute script" 30 | 31 | 32 | -------------------------------------------------------------------------------- /source/notebooks/0_store_configuration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Configuration Management" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "This notebook is designed to be run with `Python 3 (Data Science)` kernel." 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Input & Output Configuration" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "In the following notebooks, we will use an S3 bucket to store raw data, processed data, feataures and trained models. Therefore we retrieve the configuration, which we build and stored earlier in the Parameter Store." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "Includes" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "import os\n", 45 | "import boto3\n", 46 | "\n", 47 | "session = boto3.Session()\n", 48 | "ssm = session.client('ssm')" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "bucket = ssm.get_parameter(Name=\"/aik/data-bucket\")[\"Parameter\"][\"Value\"]\n", 58 | "bucket" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Now as we have our base bucket, we can prepare the various prefixes which will be used for data processing and the model training later on." 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# download url for the example data set\n", 75 | "download_url = \"https://www.kaggle.com/lastsummer/ipinyou/download\"\n", 76 | "\n", 77 | "# destination where we store the raw data\n", 78 | "raw_data = \"s3://\" + bucket + \"/raw/ipinyou-data\"\n", 79 | "# taking a subset of the rawdata to speed up processing and training during development\n", 80 | "bid_source = \"s3://\" + bucket + \"/raw/ipinyou-data/training1st/bid.20130311.txt.bz2\"\n", 81 | "imp_source = \"s3://\" + bucket + \"/raw/ipinyou-data/training1st/imp.20130311.txt.bz2\"\n", 82 | "\n", 83 | "# output destinations for the data processing \n", 84 | "output_train = \"s3://\" + bucket + \"/processed/sample/train\"\n", 85 | "output_test = \"s3://\" + bucket + \"/processed/sample/test\"\n", 86 | "output_verify = \"s3://\" + bucket + \"/processed/sample/valid\"\n", 87 | "output_transformed= \"s3://\" + bucket + \"/processed/sample/transformed\"\n", 88 | "pipelineModelArtifactPath = \"s3://\" + bucket + \"/pipeline-model/model.zip\"\n", 89 | "inference_data = \"s3://\" + bucket + \"/pipeline-model/inference-data/\"\n", 90 | "inference_schema = \"s3://\" + bucket + \"/pipeline-model/pipeline-schema.json\"\n", 91 | "binary_model = \"s3://\" + bucket + \"/binary-model/xgboost.bin\"" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "## Store Configuration data for consumption in following notebooks." 
99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "## Store Parameters in the Systems Manager Parameter Store" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "The AWS Systems Manager Parameter Store is a convenient place to store these parameters so they can be consumed by the following notebooks and by the filtering application." 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "ssm.put_parameter(Name=\"/aik/download_url\", Value=download_url, Type=\"String\", Overwrite=True)\n", 122 | "ssm.put_parameter(Name=\"/aik/raw_data\", Value=raw_data, Type=\"String\", Overwrite=True)\n", 123 | "ssm.put_parameter(Name=\"/aik/bid_source\", Value=bid_source, Type=\"String\", Overwrite=True)\n", 124 | "ssm.put_parameter(Name=\"/aik/imp_source\", Value=imp_source, Type=\"String\", Overwrite=True)\n", 125 | "ssm.put_parameter(Name=\"/aik/output_train\", Value=output_train, Type=\"String\", Overwrite=True)\n", 126 | "ssm.put_parameter(Name=\"/aik/output_test\", Value=output_test, Type=\"String\", Overwrite=True)\n", 127 | "ssm.put_parameter(Name=\"/aik/output_verify\", Value=output_verify, Type=\"String\", Overwrite=True)\n", 128 | "ssm.put_parameter(Name=\"/aik/output_transformed\", Value= output_transformed, Type=\"String\", Overwrite=True)\n", 129 | "ssm.put_parameter(Name=\"/aik/pipelineModelArtifactPath\", Value= pipelineModelArtifactPath, Type=\"String\", Overwrite=True)\n", 130 | "ssm.put_parameter(Name=\"/aik/inference_data\", Value=inference_data, Type=\"String\", Overwrite=True)\n", 131 | "ssm.put_parameter(Name=\"/aik/xgboost/path\", Value=binary_model, Type=\"String\", Overwrite=True)\n", 132 | "ssm.put_parameter(Name=\"/aik/pipelineModelArtifactSchemaPath\", Value=inference_schema, Type=\"String\", Overwrite=True)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## Read Parameters from the Systems Manager Parameter Store" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "bucket = ssm.get_parameter(Name=\"/aik/data-bucket\")[\"Parameter\"][\"Value\"]\n", 149 | "download_url = ssm.get_parameter(Name=\"/aik/download_url\")[\"Parameter\"][\"Value\"]\n", 150 | "raw_data = ssm.get_parameter(Name=\"/aik/raw_data\")[\"Parameter\"][\"Value\"]\n", 151 | "bid_source = ssm.get_parameter(Name=\"/aik/bid_source\")[\"Parameter\"][\"Value\"]\n", 152 | "imp_source = ssm.get_parameter(Name=\"/aik/imp_source\")[\"Parameter\"][\"Value\"]\n", 153 | "output_train = ssm.get_parameter(Name=\"/aik/output_train\")[\"Parameter\"][\"Value\"]\n", 154 | "output_test = ssm.get_parameter(Name=\"/aik/output_test\")[\"Parameter\"][\"Value\"]\n", 155 | "output_verify = ssm.get_parameter(Name=\"/aik/output_verify\")[\"Parameter\"][\"Value\"] \n", 156 | "output_transformed = ssm.get_parameter(Name=\"/aik/output_transformed\")[\"Parameter\"][\"Value\"] \n", 157 | "pipelineModelArtifactPath = ssm.get_parameter(Name=\"/aik/pipelineModelArtifactPath\")[\"Parameter\"][\"Value\"] \n", 158 | "inference_data = ssm.get_parameter(Name=\"/aik/inference_data\")[\"Parameter\"][\"Value\"]\n", 159 | "binary_model = ssm.get_parameter(Name=\"/aik/xgboost/path\")[\"Parameter\"][\"Value\"]\n", 160 | "inference_schema= ssm.get_parameter(Name=\"/aik/pipelineModelArtifactSchemaPath\")[\"Parameter\"][\"Value\"]" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | 
"metadata": {}, 166 | "source": [ 167 | "## Print current configuration" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "print(f'bucket={bucket}')\n", 177 | "print(f'download_url={download_url}')\n", 178 | "print(f'raw_data={raw_data}')\n", 179 | "print(f'bid_source={bid_source}')\n", 180 | "print(f'imp_source={imp_source}')\n", 181 | "print(f'output_train={output_train}')\n", 182 | "print(f'output_verify={output_verify}')\n", 183 | "print(f'output_test={output_test}')\n", 184 | "print(f'output_transformed={output_transformed}')\n", 185 | "print(f'pipelineModelArtifactPath={pipelineModelArtifactPath}')" 186 | ] 187 | } 188 | ], 189 | "metadata": { 190 | "instance_type": "ml.t3.medium", 191 | "kernelspec": { 192 | "display_name": "Python 3 (Data Science)", 193 | "language": "python", 194 | "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:eu-west-1:470317259841:image/datascience-1.0" 195 | }, 196 | "language_info": { 197 | "codemirror_mode": { 198 | "name": "ipython", 199 | "version": 3 200 | }, 201 | "file_extension": ".py", 202 | "mimetype": "text/x-python", 203 | "name": "python", 204 | "nbconvert_exporter": "python", 205 | "pygments_lexer": "ipython3", 206 | "version": "3.7.10" 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 4 211 | } 212 | -------------------------------------------------------------------------------- /source/notebooks/1_download_ipinyou_data_tos3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Setting Up the Configuration for the data storage in S3" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "This notebook is designed to be run with `Python 3 (Data Science)` kernel." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import os\n", 24 | "import boto3\n", 25 | "\n", 26 | "session = boto3.Session() \n", 27 | "ssm = session.client('ssm')\n", 28 | "\n", 29 | "download_url = ssm.get_parameter(Name=\"/aik/download_url\")[\"Parameter\"][\"Value\"]\n", 30 | "raw_data = ssm.get_parameter(Name=\"/aik/raw_data\")[\"Parameter\"][\"Value\"]" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "We are using the opendatasets library to easily download the dataset from kaggle" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "%pip install opendatasets==0.1.20" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import opendatasets as od" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "You will need an kaggle account. Script is asking for username(not email) and a kaggle key. Refer to https://github.com/Kaggle/kaggle-api if you need to create a kaggle key." 
63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "od.download(download_url)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "# Copy data from local storage to the S3 bucket" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "import os \n", 88 | "cwd = os.getcwd()\n", 89 | "cwd\n", 90 | "source = f\"{cwd}/ipinyou/ipinyou.contest.dataset\"\n", 91 | "source" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "!aws s3 cp --recursive $source $raw_data" 101 | ] 102 | } 103 | ], 104 | "metadata": { 105 | "instance_type": "ml.t3.medium", 106 | "kernelspec": { 107 | "display_name": "Python 3 (Data Science)", 108 | "language": "python", 109 | "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:eu-west-1:470317259841:image/datascience-1.0" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.7.10" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 4 126 | } 127 | -------------------------------------------------------------------------------- /source/sagemaker-sg-cleanup/delete-sg.ts: -------------------------------------------------------------------------------- 1 | import * as aws from "aws-sdk" 2 | 3 | /** 4 | * Delete the security groups that Sagemaker creates for a domain to connect to EFS. 5 | * 6 | * Ideally these would be deleted by the CloudFormation resource. 
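 * (Editor's illustrative sketch, not part of the original source.) A minimal ad-hoc invocation for local testing; the region and VPC id below are placeholders and valid AWS credentials are assumed:
 *
 *     process.env.AWS_REGION = process.env.AWS_REGION || "eu-west-1"   // placeholder region
 *     deleteSagemakerSecurityGroups("vpc-0123456789abcdef0")           // placeholder VPC id
 *         .then(() => console.log("done"))
 *         .catch(console.error)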
7 | * 8 | * @param vpcId 9 | */ 10 | export async function deleteSagemakerSecurityGroups(vpcId: string) { 11 | console.log("About to delete NFS security groups for ", vpcId) 12 | console.log("Region", process.env.AWS_REGION) 13 | const ec2 = new aws.EC2({ 14 | region: process.env.AWS_DEFAULT_REGION || process.env.AWS_REGION 15 | }) 16 | const sgResult = await ec2.describeSecurityGroups({ 17 | Filters: [ 18 | { 19 | Name: "vpc-id", 20 | Values: [vpcId] 21 | } 22 | ] 23 | }).promise() 24 | console.log({sgResult}) 25 | const groupsToDelete = [] 26 | if (sgResult.SecurityGroups) { 27 | for (const sg of sgResult.SecurityGroups) { 28 | console.log(sg) 29 | if (sg.VpcId === vpcId && sg.GroupName && sg.GroupId && ( 30 | sg.GroupName.indexOf("outbound-nfs") > -1 || 31 | sg.GroupName.indexOf("inbound-nfs") > -1)) { 32 | 33 | console.log(`Found sg ${sg.GroupId}`) 34 | groupsToDelete.push(sg) 35 | 36 | const ruleResult = await ec2.describeSecurityGroupRules({ 37 | Filters: [ 38 | { 39 | Name: "group-id", 40 | Values: [sg.GroupId] 41 | } 42 | ] 43 | }).promise() 44 | 45 | console.log(ruleResult.SecurityGroupRules) 46 | 47 | if (ruleResult.SecurityGroupRules) { 48 | for (const r of ruleResult.SecurityGroupRules) { 49 | if (r.IsEgress && r.SecurityGroupRuleId) { 50 | await ec2.revokeSecurityGroupEgress({ 51 | GroupId: sg.GroupId, 52 | SecurityGroupRuleIds: [r.SecurityGroupRuleId] 53 | }).promise() 54 | } else if (r.SecurityGroupRuleId) { 55 | await ec2.revokeSecurityGroupIngress({ 56 | GroupId: sg.GroupId, 57 | SecurityGroupRuleIds: [r.SecurityGroupRuleId] 58 | }).promise() 59 | } 60 | } 61 | } 62 | } 63 | } 64 | 65 | // Now that the rules are all deleted, delete the groups 66 | for (const g of groupsToDelete) { 67 | await ec2.deleteSecurityGroup({ 68 | GroupId: g.GroupId 69 | }).promise() 70 | } 71 | } 72 | } -------------------------------------------------------------------------------- /source/sagemaker-sg-cleanup/handler.ts: -------------------------------------------------------------------------------- 1 | // import * as ec2 from "@aws-sdk/client-ec2" 2 | import * as aws from "aws-sdk" 3 | import { deleteSagemakerSecurityGroups } from "./delete-sg" 4 | 5 | /** 6 | * Handle requests from 7 | * @param evt 8 | */ 9 | exports.handler = async (evt: any): Promise => { 10 | console.log(evt) 11 | 12 | const requestType = evt.RequestType 13 | 14 | const vpcId = process.env["VPC_ID"] 15 | console.info({ vpcId }) 16 | 17 | const efsId = process.env["EFS_ID"] 18 | console.info({ efsId }) 19 | 20 | if (!vpcId) { 21 | throw new Error("VPC_ID not defined") 22 | } 23 | 24 | if (!efsId) { 25 | throw new Error("EFS_ID not defined") 26 | } 27 | 28 | if (requestType === "Create" || requestType === "Update") { 29 | 30 | // We don't do anything for create or update 31 | return { "PhysicalResourceId": "N/A" } 32 | 33 | } else if (requestType == "Delete") { 34 | 35 | // Delete the file system created for the SageMaker domain 36 | try { 37 | console.log(`About to delete file system ${efsId}`) 38 | const efs = new aws.EFS() 39 | const mountTargets = await efs.describeMountTargets({ 40 | FileSystemId: efsId 41 | }).promise() 42 | if (mountTargets.MountTargets) { 43 | for (const t of mountTargets.MountTargets) { 44 | await efs.deleteMountTarget({ 45 | MountTargetId: t.MountTargetId 46 | }).promise() 47 | 48 | // The DeleteMountTarget call returns while the mount target state 49 | // is still deleting. 
You can check the mount target deletion by 50 | // calling the DescribeMountTargets operation, which returns a list of 51 | // mount target descriptions for the given file system. 52 | 53 | let remainingTargets 54 | let deleting = true 55 | const maxTries = 10 56 | let currentTry = 0 57 | console.log("About to wait for mount target to be deleted") 58 | do { 59 | currentTry += 1 60 | await new Promise(r => setTimeout(r, 5000)) 61 | remainingTargets = await efs.describeMountTargets({ 62 | FileSystemId: efsId 63 | }).promise() 64 | console.log({ remainingTargets }) 65 | let sawIt = false 66 | if (remainingTargets.MountTargets) { 67 | for (const stillThere of remainingTargets.MountTargets) { 68 | if (stillThere.MountTargetId == t.MountTargetId) { 69 | sawIt = true 70 | } 71 | } 72 | } 73 | if (!sawIt) deleting = false 74 | } while (deleting && currentTry < maxTries) 75 | } 76 | } 77 | await efs.deleteFileSystem({ 78 | FileSystemId: efsId 79 | }).promise() 80 | } catch (ex) { 81 | console.error(ex) 82 | } 83 | 84 | // SageMaker automatically adds a security group to allow access to NFS. 85 | // https://awscli.amazonaws.com/v2/documentation/api/latest/reference/sagemaker/create-domain.html 86 | // This makes it impossible to delete the stack... 87 | // Adding these rules manually does not prevent SageMaker from creating them 88 | // This custom resource handler deletes the security groups 89 | // Names: 90 | // security-group-for-outbound-nfs-d-lhtklclyyvsw 91 | // security-group-for-inbound-nfs-d-lhtklclyyvsw 92 | 93 | try { 94 | console.log(`About to delete NFS security groups from VPC ${vpcId}`) 95 | await deleteSagemakerSecurityGroups(vpcId) 96 | } catch (ex) { 97 | console.error(ex) 98 | } 99 | } 100 | } 101 | 102 | 103 | 104 | // /** 105 | // * For reference, here's the JS SDK V3 code. It's not available by default on Lambda, 106 | // * and we'd have to do some bundling to package it up, so we went back to v2. 107 | // */ 108 | // async function v3() { 109 | // (async () => { 110 | // try { 111 | // const client = new ec2.EC2Client({}) 112 | // const command = new ec2.DescribeSecurityGroupsCommand({ 113 | // Filters: [ 114 | // { 115 | // Name: "vpc-id", 116 | // Values: [vpcId] 117 | // } 118 | // ] 119 | // }) 120 | // const results = await client.send(command) 121 | // console.log(results) 122 | 123 | // if (results.SecurityGroups) { 124 | // for (const result of results.SecurityGroups) { 125 | // console.log(result) 126 | // const groupName = result.GroupName || "" 127 | // if (groupName.indexOf("outbound-nfs") > -1 || 128 | // groupName.indexOf("inbound-nfs") > -1) { 129 | // console.log(`Deleting ${groupName}`) 130 | // await client.send(new ec2.DeleteSecurityGroupCommand({ 131 | // GroupId: result.GroupId 132 | // })) 133 | // } 134 | // } 135 | // } 136 | 137 | // } catch (err) { 138 | // console.error(err) 139 | // } 140 | // })() 141 | // } 142 | -------------------------------------------------------------------------------- /source/traffic-filtering-app/Dockerfile.Client: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | 4 | # 5 | # Build stage 6 | # 7 | FROM public.ecr.aws/docker/library/maven:3.8.1-jdk-11-slim AS build 8 | COPY src /home/app/src 9 | COPY pom.xml /home/app 10 | RUN mvn -f /home/app/pom.xml clean package 11 | 12 | # 13 | # Package stages 14 | # 15 | FROM public.ecr.aws/amazoncorretto/amazoncorretto:11 16 | RUN mkdir -p /home/app 17 | RUN chown 8000 /home/app 18 | WORKDIR /home/app 19 | USER 8000 20 | COPY --from=build /home/app/target/traffic-filtering-app-1.0-SNAPSHOT-jar-with-dependencies.jar /usr/local/lib/traffic-filtering-app.jar 21 | RUN mkdir -p model 22 | RUN mkdir -p .tmp 23 | ENTRYPOINT ["java","-cp","/usr/local/lib/traffic-filtering-app.jar","-Xmx8g","com.aik.perfclient.MultiThreadedClient"] -------------------------------------------------------------------------------- /source/traffic-filtering-app/Dockerfile.Server: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | # 5 | # Build stage 6 | # 7 | FROM public.ecr.aws/docker/library/maven:3.8.1-jdk-11-slim AS build 8 | COPY src /home/app/src 9 | COPY pom.xml /home/app 10 | RUN mvn -f /home/app/pom.xml clean package 11 | 12 | # 13 | # Package stage 14 | # 15 | FROM public.ecr.aws/amazoncorretto/amazoncorretto:11 16 | RUN yum update -y ; yum install -y gcc 17 | RUN mkdir -p /home/app 18 | RUN chown 8000 /home/app 19 | WORKDIR /home/app 20 | USER 8000 21 | COPY --from=build /home/app/target/traffic-filtering-app-1.0-SNAPSHOT-jar-with-dependencies.jar /usr/local/lib/traffic-filtering-app.jar 22 | RUN mkdir -p .tmp 23 | 24 | 25 | 26 | EXPOSE 8080 27 | ENTRYPOINT ["java","-cp","/usr/local/lib/traffic-filtering-app.jar","com.aik.prediction.InferenceServer"] -------------------------------------------------------------------------------- /source/traffic-filtering-app/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | com.aik 8 | traffic-filtering-app 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 13 | software.amazon.awssdk 14 | bom 15 | 2.17.101 16 | pom 17 | import 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | net.alchim31.maven 27 | scala-maven-plugin 28 | 4.6.1 29 | 30 | 31 | 32 | 33 | 34 | net.alchim31.maven 35 | scala-maven-plugin 36 | 37 | 38 | scala-compile-first 39 | process-resources 40 | 41 | add-source 42 | compile 43 | 44 | 45 | 46 | 47 | 48 | maven-assembly-plugin 49 | 3.3.0 50 | 51 | 52 | jar-with-dependencies 53 | 54 | 55 | 56 | 57 | package 58 | 59 | single 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | org.scala-lang 70 | scala-library 71 | 2.12.15 72 | 73 | 74 | ml.combust.mleap 75 | mleap-runtime_2.12 76 | 0.19.0 77 | 78 | 79 | ml.combust.mleap 80 | mleap-spark_2.12 81 | 0.19.0 82 | 83 | 84 | org.apache.spark 85 | spark-sql_2.12 86 | 3.1.2 87 | provided 88 | 89 | 90 | com.github.plokhotnyuk.jsoniter-scala 91 | jsoniter-scala-core_2.12 92 | 2.13.18 93 | 94 | 95 | com.github.plokhotnyuk.jsoniter-scala 96 | jsoniter-scala-macros_2.12 97 | 2.13.18 98 | provided 99 | 100 | 101 | 102 | org.apache.logging.log4j 103 | log4j-api 104 | 2.17.0 105 | 106 | 107 | org.apache.logging.log4j 108 | log4j-core 109 | 2.17.1 110 | 111 | 112 | org.apache.thrift 113 | libthrift 114 | 0.15.0 115 | 116 | 117 | com.timgroup 118 | java-statsd-client 119 | 3.1.0 120 | 121 | 122 | org.apache.commons 123 | commons-math3 124 | 3.6.1 125 | 126 | 127 | com.googlecode.json-simple 128 | json-simple 129 | 1.1.1 130 
| 131 | 132 | software.amazon.awssdk 133 | s3 134 | 135 | 136 | software.amazon.awssdk 137 | ssm 138 | 139 | 140 | ml.dmlc 141 | xgboost4j-spark_2.12 142 | 1.6.0 143 | 144 | 145 | ml.dmlc 146 | xgboost4j_2.12 147 | 1.6.0 148 | 149 | 150 | org.apache.commons 151 | commons-compress 152 | 1.21 153 | 154 | 155 | com.google.guava 156 | guava 157 | 31.0.1-jre 158 | 159 | 160 | is.tagomor.woothee 161 | woothee-java 162 | 1.11.0 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 11 171 | 11 172 | 173 | 174 | -------------------------------------------------------------------------------- /source/traffic-filtering-app/src/main/java/com/aik/filterapi/BidResponse.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Autogenerated by Thrift Compiler (0.15.0) 3 | * 4 | * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING 5 | * @generated 6 | */ 7 | package com.aik.filterapi; 8 | 9 | @SuppressWarnings({"cast", "rawtypes", "serial", "unchecked", "unused"}) 10 | /** 11 | * Raw data required for filtering a bid request 12 | */ 13 | @javax.annotation.Generated(value = "Autogenerated by Thrift Compiler (0.15.0)", date = "2022-03-08") 14 | public class BidResponse implements org.apache.thrift.TBase, java.io.Serializable, Cloneable, Comparable { 15 | private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("BidResponse"); 16 | 17 | private static final org.apache.thrift.protocol.TField LIKELIHOOD_TO_BID_FIELD_DESC = new org.apache.thrift.protocol.TField("likelihoodToBid", org.apache.thrift.protocol.TType.DOUBLE, (short)1); 18 | 19 | private static final org.apache.thrift.scheme.SchemeFactory STANDARD_SCHEME_FACTORY = new BidResponseStandardSchemeFactory(); 20 | private static final org.apache.thrift.scheme.SchemeFactory TUPLE_SCHEME_FACTORY = new BidResponseTupleSchemeFactory(); 21 | 22 | public double likelihoodToBid; // required 23 | 24 | /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ 25 | public enum _Fields implements org.apache.thrift.TFieldIdEnum { 26 | LIKELIHOOD_TO_BID((short)1, "likelihoodToBid"); 27 | 28 | private static final java.util.Map byName = new java.util.HashMap(); 29 | 30 | static { 31 | for (_Fields field : java.util.EnumSet.allOf(_Fields.class)) { 32 | byName.put(field.getFieldName(), field); 33 | } 34 | } 35 | 36 | /** 37 | * Find the _Fields constant that matches fieldId, or null if its not found. 38 | */ 39 | @org.apache.thrift.annotation.Nullable 40 | public static _Fields findByThriftId(int fieldId) { 41 | switch(fieldId) { 42 | case 1: // LIKELIHOOD_TO_BID 43 | return LIKELIHOOD_TO_BID; 44 | default: 45 | return null; 46 | } 47 | } 48 | 49 | /** 50 | * Find the _Fields constant that matches fieldId, throwing an exception 51 | * if it is not found. 52 | */ 53 | public static _Fields findByThriftIdOrThrow(int fieldId) { 54 | _Fields fields = findByThriftId(fieldId); 55 | if (fields == null) throw new java.lang.IllegalArgumentException("Field " + fieldId + " doesn't exist!"); 56 | return fields; 57 | } 58 | 59 | /** 60 | * Find the _Fields constant that matches name, or null if its not found. 
61 | */ 62 | @org.apache.thrift.annotation.Nullable 63 | public static _Fields findByName(java.lang.String name) { 64 | return byName.get(name); 65 | } 66 | 67 | private final short _thriftId; 68 | private final java.lang.String _fieldName; 69 | 70 | _Fields(short thriftId, java.lang.String fieldName) { 71 | _thriftId = thriftId; 72 | _fieldName = fieldName; 73 | } 74 | 75 | public short getThriftFieldId() { 76 | return _thriftId; 77 | } 78 | 79 | public java.lang.String getFieldName() { 80 | return _fieldName; 81 | } 82 | } 83 | 84 | // isset id assignments 85 | private static final int __LIKELIHOODTOBID_ISSET_ID = 0; 86 | private byte __isset_bitfield = 0; 87 | public static final java.util.Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; 88 | static { 89 | java.util.Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new java.util.EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); 90 | tmpMap.put(_Fields.LIKELIHOOD_TO_BID, new org.apache.thrift.meta_data.FieldMetaData("likelihoodToBid", org.apache.thrift.TFieldRequirementType.DEFAULT, 91 | new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.DOUBLE))); 92 | metaDataMap = java.util.Collections.unmodifiableMap(tmpMap); 93 | org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(BidResponse.class, metaDataMap); 94 | } 95 | 96 | public BidResponse() { 97 | } 98 | 99 | public BidResponse( 100 | double likelihoodToBid) 101 | { 102 | this(); 103 | this.likelihoodToBid = likelihoodToBid; 104 | setLikelihoodToBidIsSet(true); 105 | } 106 | 107 | /** 108 | * Performs a deep copy on other. 109 | */ 110 | public BidResponse(BidResponse other) { 111 | __isset_bitfield = other.__isset_bitfield; 112 | this.likelihoodToBid = other.likelihoodToBid; 113 | } 114 | 115 | public BidResponse deepCopy() { 116 | return new BidResponse(this); 117 | } 118 | 119 | @Override 120 | public void clear() { 121 | setLikelihoodToBidIsSet(false); 122 | this.likelihoodToBid = 0.0; 123 | } 124 | 125 | public double getLikelihoodToBid() { 126 | return this.likelihoodToBid; 127 | } 128 | 129 | public BidResponse setLikelihoodToBid(double likelihoodToBid) { 130 | this.likelihoodToBid = likelihoodToBid; 131 | setLikelihoodToBidIsSet(true); 132 | return this; 133 | } 134 | 135 | public void unsetLikelihoodToBid() { 136 | __isset_bitfield = org.apache.thrift.EncodingUtils.clearBit(__isset_bitfield, __LIKELIHOODTOBID_ISSET_ID); 137 | } 138 | 139 | /** Returns true if field likelihoodToBid is set (has been assigned a value) and false otherwise */ 140 | public boolean isSetLikelihoodToBid() { 141 | return org.apache.thrift.EncodingUtils.testBit(__isset_bitfield, __LIKELIHOODTOBID_ISSET_ID); 142 | } 143 | 144 | public void setLikelihoodToBidIsSet(boolean value) { 145 | __isset_bitfield = org.apache.thrift.EncodingUtils.setBit(__isset_bitfield, __LIKELIHOODTOBID_ISSET_ID, value); 146 | } 147 | 148 | public void setFieldValue(_Fields field, @org.apache.thrift.annotation.Nullable java.lang.Object value) { 149 | switch (field) { 150 | case LIKELIHOOD_TO_BID: 151 | if (value == null) { 152 | unsetLikelihoodToBid(); 153 | } else { 154 | setLikelihoodToBid((java.lang.Double)value); 155 | } 156 | break; 157 | 158 | } 159 | } 160 | 161 | @org.apache.thrift.annotation.Nullable 162 | public java.lang.Object getFieldValue(_Fields field) { 163 | switch (field) { 164 | case LIKELIHOOD_TO_BID: 165 | return getLikelihoodToBid(); 166 | 167 | } 168 | throw new java.lang.IllegalStateException(); 169 | } 170 
| 171 | /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ 172 | public boolean isSet(_Fields field) { 173 | if (field == null) { 174 | throw new java.lang.IllegalArgumentException(); 175 | } 176 | 177 | switch (field) { 178 | case LIKELIHOOD_TO_BID: 179 | return isSetLikelihoodToBid(); 180 | } 181 | throw new java.lang.IllegalStateException(); 182 | } 183 | 184 | @Override 185 | public boolean equals(java.lang.Object that) { 186 | if (that instanceof BidResponse) 187 | return this.equals((BidResponse)that); 188 | return false; 189 | } 190 | 191 | public boolean equals(BidResponse that) { 192 | if (that == null) 193 | return false; 194 | if (this == that) 195 | return true; 196 | 197 | boolean this_present_likelihoodToBid = true; 198 | boolean that_present_likelihoodToBid = true; 199 | if (this_present_likelihoodToBid || that_present_likelihoodToBid) { 200 | if (!(this_present_likelihoodToBid && that_present_likelihoodToBid)) 201 | return false; 202 | if (this.likelihoodToBid != that.likelihoodToBid) 203 | return false; 204 | } 205 | 206 | return true; 207 | } 208 | 209 | @Override 210 | public int hashCode() { 211 | int hashCode = 1; 212 | 213 | hashCode = hashCode * 8191 + org.apache.thrift.TBaseHelper.hashCode(likelihoodToBid); 214 | 215 | return hashCode; 216 | } 217 | 218 | @Override 219 | public int compareTo(BidResponse other) { 220 | if (!getClass().equals(other.getClass())) { 221 | return getClass().getName().compareTo(other.getClass().getName()); 222 | } 223 | 224 | int lastComparison = 0; 225 | 226 | lastComparison = java.lang.Boolean.compare(isSetLikelihoodToBid(), other.isSetLikelihoodToBid()); 227 | if (lastComparison != 0) { 228 | return lastComparison; 229 | } 230 | if (isSetLikelihoodToBid()) { 231 | lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.likelihoodToBid, other.likelihoodToBid); 232 | if (lastComparison != 0) { 233 | return lastComparison; 234 | } 235 | } 236 | return 0; 237 | } 238 | 239 | @org.apache.thrift.annotation.Nullable 240 | public _Fields fieldForId(int fieldId) { 241 | return _Fields.findByThriftId(fieldId); 242 | } 243 | 244 | public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { 245 | scheme(iprot).read(iprot, this); 246 | } 247 | 248 | public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { 249 | scheme(oprot).write(oprot, this); 250 | } 251 | 252 | @Override 253 | public java.lang.String toString() { 254 | java.lang.StringBuilder sb = new java.lang.StringBuilder("BidResponse("); 255 | boolean first = true; 256 | 257 | sb.append("likelihoodToBid:"); 258 | sb.append(this.likelihoodToBid); 259 | first = false; 260 | sb.append(")"); 261 | return sb.toString(); 262 | } 263 | 264 | public void validate() throws org.apache.thrift.TException { 265 | // check for required fields 266 | // check for sub-struct validity 267 | } 268 | 269 | private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { 270 | try { 271 | write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); 272 | } catch (org.apache.thrift.TException te) { 273 | throw new java.io.IOException(te); 274 | } 275 | } 276 | 277 | private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, java.lang.ClassNotFoundException { 278 | try { 279 | // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call 
the default constructor. 280 | __isset_bitfield = 0; 281 | read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); 282 | } catch (org.apache.thrift.TException te) { 283 | throw new java.io.IOException(te); 284 | } 285 | } 286 | 287 | private static class BidResponseStandardSchemeFactory implements org.apache.thrift.scheme.SchemeFactory { 288 | public BidResponseStandardScheme getScheme() { 289 | return new BidResponseStandardScheme(); 290 | } 291 | } 292 | 293 | private static class BidResponseStandardScheme extends org.apache.thrift.scheme.StandardScheme { 294 | 295 | public void read(org.apache.thrift.protocol.TProtocol iprot, BidResponse struct) throws org.apache.thrift.TException { 296 | org.apache.thrift.protocol.TField schemeField; 297 | iprot.readStructBegin(); 298 | while (true) 299 | { 300 | schemeField = iprot.readFieldBegin(); 301 | if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { 302 | break; 303 | } 304 | switch (schemeField.id) { 305 | case 1: // LIKELIHOOD_TO_BID 306 | if (schemeField.type == org.apache.thrift.protocol.TType.DOUBLE) { 307 | struct.likelihoodToBid = iprot.readDouble(); 308 | struct.setLikelihoodToBidIsSet(true); 309 | } else { 310 | org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); 311 | } 312 | break; 313 | default: 314 | org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); 315 | } 316 | iprot.readFieldEnd(); 317 | } 318 | iprot.readStructEnd(); 319 | 320 | // check for required fields of primitive type, which can't be checked in the validate method 321 | struct.validate(); 322 | } 323 | 324 | public void write(org.apache.thrift.protocol.TProtocol oprot, BidResponse struct) throws org.apache.thrift.TException { 325 | struct.validate(); 326 | 327 | oprot.writeStructBegin(STRUCT_DESC); 328 | oprot.writeFieldBegin(LIKELIHOOD_TO_BID_FIELD_DESC); 329 | oprot.writeDouble(struct.likelihoodToBid); 330 | oprot.writeFieldEnd(); 331 | oprot.writeFieldStop(); 332 | oprot.writeStructEnd(); 333 | } 334 | 335 | } 336 | 337 | private static class BidResponseTupleSchemeFactory implements org.apache.thrift.scheme.SchemeFactory { 338 | public BidResponseTupleScheme getScheme() { 339 | return new BidResponseTupleScheme(); 340 | } 341 | } 342 | 343 | private static class BidResponseTupleScheme extends org.apache.thrift.scheme.TupleScheme { 344 | 345 | @Override 346 | public void write(org.apache.thrift.protocol.TProtocol prot, BidResponse struct) throws org.apache.thrift.TException { 347 | org.apache.thrift.protocol.TTupleProtocol oprot = (org.apache.thrift.protocol.TTupleProtocol) prot; 348 | java.util.BitSet optionals = new java.util.BitSet(); 349 | if (struct.isSetLikelihoodToBid()) { 350 | optionals.set(0); 351 | } 352 | oprot.writeBitSet(optionals, 1); 353 | if (struct.isSetLikelihoodToBid()) { 354 | oprot.writeDouble(struct.likelihoodToBid); 355 | } 356 | } 357 | 358 | @Override 359 | public void read(org.apache.thrift.protocol.TProtocol prot, BidResponse struct) throws org.apache.thrift.TException { 360 | org.apache.thrift.protocol.TTupleProtocol iprot = (org.apache.thrift.protocol.TTupleProtocol) prot; 361 | java.util.BitSet incoming = iprot.readBitSet(1); 362 | if (incoming.get(0)) { 363 | struct.likelihoodToBid = iprot.readDouble(); 364 | struct.setLikelihoodToBidIsSet(true); 365 | } 366 | } 367 | } 368 | 369 | private static S scheme(org.apache.thrift.protocol.TProtocol proto) { 370 | return 
(org.apache.thrift.scheme.StandardScheme.class.equals(proto.getScheme()) ? STANDARD_SCHEME_FACTORY : TUPLE_SCHEME_FACTORY).getScheme(); 371 | } 372 | } 373 | 374 | -------------------------------------------------------------------------------- /source/traffic-filtering-app/src/main/java/com/aik/perfclient/FilteringResult.java: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | package com.aik.perfclient; 5 | 6 | public class FilteringResult { 7 | private double executionTime ; 8 | private double likelihood; 9 | 10 | public FilteringResult(double executionTime, double likelihood) { 11 | this.executionTime = executionTime; 12 | this.likelihood = likelihood; 13 | } 14 | 15 | public double getExecutionTime() { 16 | return executionTime; 17 | } 18 | 19 | public double getLikelihood() { 20 | return likelihood; 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /source/traffic-filtering-app/src/main/java/com/aik/perfclient/MultiThreadedClient.java: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | 5 | package com.aik.perfclient; 6 | 7 | 8 | import com.aik.filterapi.BidRequest; 9 | import com.aik.filterapi.BidRequestFilter; 10 | import com.aik.filterapi.BidResponse; 11 | import com.google.common.math.Quantiles; 12 | import com.timgroup.statsd.NonBlockingStatsDClient; 13 | import com.timgroup.statsd.StatsDClient; 14 | import is.tagomor.woothee.Classifier; 15 | import org.apache.logging.log4j.LogManager; 16 | import org.apache.logging.log4j.Logger; 17 | import org.apache.thrift.TException; 18 | import org.apache.thrift.protocol.TBinaryProtocol; 19 | import org.apache.thrift.protocol.TProtocol; 20 | import org.apache.thrift.transport.TSocket; 21 | import org.apache.thrift.transport.TTransport; 22 | import org.json.simple.JSONObject; 23 | import org.json.simple.parser.JSONParser; 24 | import org.json.simple.parser.ParseException; 25 | import software.amazon.awssdk.services.s3.S3Client; 26 | import software.amazon.awssdk.services.s3.model.*; 27 | import software.amazon.awssdk.services.ssm.SsmClient; 28 | import software.amazon.awssdk.services.ssm.model.GetParameterRequest; 29 | import software.amazon.awssdk.services.ssm.model.GetParameterResponse; 30 | import software.amazon.awssdk.services.ssm.model.SsmException; 31 | 32 | import java.io.BufferedReader; 33 | import java.io.FileReader; 34 | import java.io.IOException; 35 | import java.nio.file.Paths; 36 | import java.time.Duration; 37 | import java.time.Instant; 38 | import java.util.*; 39 | import java.util.concurrent.*; 40 | import java.util.regex.Matcher; 41 | import java.util.regex.Pattern; 42 | import java.util.stream.Collectors; 43 | import java.util.stream.IntStream; 44 | 45 | 46 | public class MultiThreadedClient { 47 | private static final Logger logger = LogManager.getLogger(MultiThreadedClient.class.getName()); 48 | private static final StatsDClient statsd = new NonBlockingStatsDClient("adserver_client", "localhost", 8125); 49 | private static final Random randDeviceType = new Random(); 50 | private static final int maxDeviceType = 5; 51 | 52 | 53 | public static void main(String[] args) { 54 | logger.warn("Starting client"); 55 | logger.warn("Sleep 30 s"); 56 | 57 | try { 58 | 
TimeUnit.SECONDS.sleep(30); 59 | } catch (InterruptedException e) { 60 | logger.error("Failed while waiting"); 61 | logger.catching(e); 62 | } 63 | logger.info("Start downloading test data"); 64 | try { 65 | int nbThread = 1; 66 | int nbTest = 100000; 67 | String inputRequestUri = "" ; 68 | if (args.length > 0) { 69 | nbThread = Integer.parseInt(args[1]); 70 | nbTest = Integer.parseInt(args[2]); 71 | inputRequestUri = getValueFromSsmParameter("/aik/inference_data"); 72 | } 73 | 74 | logger.warn("nbThread " + nbThread ) ; 75 | logger.warn("nbTest " + nbTest ) ; 76 | 77 | // download file from s3 78 | AbstractMap.SimpleEntry s3URIParsed = MultiThreadedClient.parseS3Uri(inputRequestUri); 79 | MultiThreadedClient.downloadTestFile(s3URIParsed.getKey(),s3URIParsed.getValue(),"./.tmp/test.json") ; 80 | ArrayList dataset = MultiThreadedClient.loadData("./.tmp/test.json") ; 81 | 82 | perform(nbThread,nbTest,dataset); 83 | 84 | 85 | logger.info("Ending client"); 86 | } catch (TException x) { 87 | logger.error("Exception while opening TCP socket"); 88 | logger.catching(x); 89 | } 90 | } 91 | 92 | private static FilteringResult executeInPromise(BidRequestFilter.Client client, BidRequest bidRequest) { 93 | FilteringResult filteringResult = null; 94 | try { 95 | filteringResult = MultiThreadedClient.performOne(client, bidRequest); 96 | } catch (TException e) { 97 | e.printStackTrace(); 98 | } 99 | 100 | return filteringResult; 101 | } 102 | 103 | private static void perform(int nbThread,int nbTest, ArrayList bidRequests) throws TException { 104 | ExecutorService executorService = Executors.newFixedThreadPool(nbThread); 105 | List>> callables = new ArrayList<>(); 106 | logger.warn("starting load test"); 107 | Instant start = Instant.now(); 108 | 109 | for (int curentThreadIdx = 0; curentThreadIdx < nbThread; curentThreadIdx++) { 110 | final int threadId = curentThreadIdx ; 111 | Callable> callable = () -> { 112 | TTransport transport; 113 | 114 | transport = new TSocket("localhost", 9090); 115 | transport.open(); 116 | TProtocol protocol = new TBinaryProtocol(transport); 117 | BidRequestFilter.Client client = new BidRequestFilter.Client(protocol); 118 | logger.info(" nbTest " + nbTest + " nb Thread " + nbThread + " bid request size " + bidRequests.size()) ; 119 | IntStream idxStream = new Random().ints(nbTest / nbThread, threadId*(nbTest / nbThread), bidRequests.size()); 120 | List filteringResults = idxStream.mapToObj(i -> { 121 | //logger.warn("par Thread " + Thread.currentThread().getId()); 122 | logger.trace("random idx " + i); 123 | return MultiThreadedClient.executeInPromise(client, bidRequests.get(i)); 124 | }).collect(Collectors.toList()); 125 | transport.close(); 126 | return filteringResults; 127 | }; 128 | callables.add(callable) ; 129 | } 130 | List>> result = null ; 131 | try { 132 | result = executorService.invokeAll(callables) ; 133 | } catch (InterruptedException e) { 134 | e.printStackTrace(); 135 | } 136 | 137 | logger.warn("current Thread " + Thread.currentThread().getId()); 138 | 139 | assert result != null; 140 | List dataset = result.stream().map(futureExec -> { 141 | List filteringResults = null; 142 | try { 143 | filteringResults = futureExec.get(); 144 | } catch (InterruptedException | ExecutionException e) { 145 | e.printStackTrace(); 146 | } 147 | return filteringResults; 148 | }).filter(Objects::nonNull).flatMap(List::stream).collect(Collectors.toList()); 149 | 150 | logger.warn("starting load end"); 151 | 152 | 153 | List datasetExecutionTime = 
dataset.stream().map(FilteringResult::getExecutionTime).collect(Collectors.toList()); 154 | List datasetLikelihood = dataset.stream().map(FilteringResult::getLikelihood).collect(Collectors.toList()); 155 | 156 | //AWSXRay.endSegment(); 157 | Duration totalExecutionDuration = Duration.between(start, Instant.now()); 158 | float totalExecutionTime = (float) totalExecutionDuration.toMillis() / 1000.f; 159 | logger.warn("ending load test (s)" + totalExecutionTime); 160 | logger.warn("observed throughput (qps)" + (float) nbTest / totalExecutionTime); 161 | DoubleSummaryStatistics execTimeStatistics = datasetExecutionTime.stream().collect(Collectors.summarizingDouble(Double::doubleValue)); 162 | DoubleSummaryStatistics likelihoodStatistics = datasetLikelihood.stream().collect(Collectors.summarizingDouble(Double::doubleValue)); 163 | 164 | double p99ExecutionTime = Quantiles.percentiles().index(99).compute(datasetExecutionTime); 165 | double p95ExecutionTime = Quantiles.percentiles().index(95).compute(datasetExecutionTime); 166 | logger.warn("------------ execution time -------------"); 167 | logger.warn("Mean execution time (ms) = " + execTimeStatistics.getAverage()); 168 | logger.warn("Max execution time (ms) = " + execTimeStatistics.getMax()); 169 | logger.warn("Min execution time (ms) = " + execTimeStatistics.getMin()); 170 | logger.warn("p99 execution time (ms) = " + p99ExecutionTime); 171 | logger.warn("p95 execution time (ms) = " + p95ExecutionTime); 172 | logger.warn("------------ likelihood -------------"); 173 | logger.warn("Mean likelihood time (ms) = " + likelihoodStatistics.getAverage()); 174 | logger.warn("Max likelihood time (ms) = " + likelihoodStatistics.getMax()); 175 | logger.warn("Min likelihood time (ms) = " + likelihoodStatistics.getMin()); 176 | 177 | executorService.shutdown(); 178 | } 179 | 180 | private static FilteringResult performOne(BidRequestFilter.Client client, BidRequest bidRequest) throws TException { 181 | logger.info("start bid request filtering"); 182 | Duration filteringExecutionTime = null; 183 | double likelihoodToBid = 0.0 ; 184 | 185 | try { 186 | logger.info("filter bid request"); 187 | logger.info("bid request " + bidRequest) ; 188 | Instant start = Instant.now(); 189 | BidResponse response = client.filter(bidRequest); 190 | filteringExecutionTime = Duration.between(start, Instant.now()); 191 | likelihoodToBid = response.likelihoodToBid ; 192 | logger.info("bid request successfully filtered"); 193 | statsd.recordGaugeValue("adserver_latency_h", ((float) filteringExecutionTime.getNano()) / 1000000.f); 194 | statsd.recordExecutionTime("adserver_latency_ms", filteringExecutionTime.getNano() / 1000); 195 | statsd.recordGaugeValue("likelihood_to_bid", response.likelihoodToBid); 196 | logger.info("stats recorded successfully"); 197 | logger.info("Advertiser Index " + bidRequest.advertiserId + " likelihood to bid " + response.likelihoodToBid); 198 | 199 | } catch (org.apache.thrift.TException io) { 200 | logger.error("Exception while filtering bid request"); 201 | logger.catching(io); 202 | } 203 | 204 | logger.info("end bid request filtering"); 205 | assert filteringExecutionTime != null; 206 | return new FilteringResult ((double) filteringExecutionTime.getNano() / 1000000.f,likelihoodToBid); 207 | } 208 | 209 | public static AbstractMap.SimpleEntry parseS3Uri(String s3URI) { 210 | logger.info("connection start"); 211 | final String regex = "^s3://([^/]+)/(.+)$" ; 212 | 213 | final Pattern pattern = Pattern.compile(regex, Pattern.MULTILINE); 214 | final 
Matcher matcher = pattern.matcher(s3URI); 215 | 216 | AbstractMap.SimpleEntry parsedURI = null ; 217 | 218 | while (matcher.find()) { 219 | System.out.println("Full match: " + matcher.group(0)); 220 | 221 | parsedURI = new AbstractMap.SimpleEntry<>(matcher.group(1), matcher.group(2) ); 222 | 223 | } 224 | 225 | return parsedURI ; 226 | 227 | 228 | } 229 | static private void downloadTestFile(String bucketName,String prefix, String path) { 230 | 231 | try (S3Client s3 = S3Client.builder().build()){ 232 | String keyName = ""; 233 | ListObjectsRequest listObjects = ListObjectsRequest 234 | .builder() 235 | .bucket(bucketName) 236 | .prefix(prefix) 237 | .build(); 238 | 239 | ListObjectsResponse res = s3.listObjects(listObjects); 240 | List objects = res.contents(); 241 | 242 | for (S3Object firstObject : objects) { 243 | keyName = firstObject.key(); 244 | } 245 | GetObjectRequest objectRequest = GetObjectRequest 246 | .builder() 247 | .key(keyName) 248 | .bucket(bucketName) 249 | .build(); 250 | 251 | s3.getObject(objectRequest, Paths.get(path)); 252 | } catch (S3Exception e) { 253 | logger.catching(e); 254 | System.exit(1); 255 | } 256 | } 257 | 258 | 259 | static private ArrayList loadData(String path) { 260 | 261 | Object obj ; 262 | logger.warn("start loading data in memory"); 263 | ArrayList bidRequestList = new ArrayList<>(); 264 | int maxObject = 20000; 265 | int nbObject = 0; 266 | try { 267 | try (BufferedReader br = new BufferedReader(new FileReader(path))) { //creates a buffering character input stream 268 | String line; 269 | while ((line = br.readLine()) != null && nbObject <= maxObject) { 270 | obj = new JSONParser().parse(line); 271 | BidRequest bidRequest = new BidRequest(); 272 | 273 | JSONObject rawObj = (JSONObject) obj; 274 | bidRequest.bidId = rawObj.get("BidID").toString(); 275 | bidRequest.dayOfWeek = Integer.parseInt(rawObj.get("dow").toString()); 276 | bidRequest.hour = rawObj.get("hour").toString(); 277 | if (rawObj.get("AdvertiserID") == null) { 278 | bidRequest.advertiserId = ""; 279 | } else { 280 | bidRequest.advertiserId = rawObj.get("AdvertiserID").toString(); 281 | } 282 | bidRequest.domainId = rawObj.get("Domain").toString(); 283 | bidRequest.regionId = rawObj.get("RegionID").toString(); 284 | bidRequest.cityId = rawObj.get("CityID").toString(); 285 | if (rawObj.get("BiddingPrice") == null) { 286 | bidRequest.biddingPrice = 0; 287 | } else { 288 | bidRequest.biddingPrice = Long.parseLong(rawObj.get("BiddingPrice").toString()); 289 | } 290 | if (rawObj.get("PayingPrice") == null) { 291 | bidRequest.payingPrice = 0; 292 | } else { 293 | bidRequest.payingPrice = Long.parseLong(rawObj.get("PayingPrice").toString()); 294 | } 295 | if (rawObj.get("UserAgent") == null) { 296 | bidRequest.deviceTypeId = 6; 297 | } else { 298 | bidRequest.deviceTypeId = MultiThreadedClient 299 | .getDeviceTypeId(rawObj.get("UserAgent").toString()); 300 | } 301 | 302 | bidRequestList.add(bidRequest); 303 | nbObject = nbObject + 1; 304 | } 305 | } 306 | } catch (IOException | ParseException e) { 307 | e.printStackTrace(); 308 | } 309 | logger.warn("end loading data in memory " + bidRequestList.size()); 310 | return bidRequestList; 311 | } 312 | 313 | /** 314 | * return a string value from an SSM parameter 315 | * @param ssmParameterName the name of SSM parameter 316 | * @return the value stored 317 | */ 318 | private static String getValueFromSsmParameter(String ssmParameterName) { 319 | logger.traceEntry(); 320 | SsmClient ssmClient = SsmClient.builder() 321 | .build(); 322 | String 
valueFromSsmParameter = ""; 323 | try { 324 | logger.info("getting parameter for parameter name " + ssmParameterName); 325 | GetParameterRequest parameterRequest = GetParameterRequest.builder() 326 | .name(ssmParameterName) 327 | .build(); 328 | 329 | GetParameterResponse parameterResponse = ssmClient.getParameter(parameterRequest); 330 | valueFromSsmParameter = parameterResponse.parameter().value(); 331 | logger.info("successfully retrieved parameter " + valueFromSsmParameter); 332 | 333 | } catch (SsmException e) { 334 | logger.error("error while reading SSM parameter"); 335 | logger.catching(e); 336 | } finally { 337 | ssmClient.close(); 338 | } 339 | return logger.traceExit(valueFromSsmParameter); 340 | } 341 | 342 | static public int getDeviceTypeId(String userAgent) { 343 | 344 | 345 | Map r = Classifier.parse(userAgent); 346 | String category = r.get("category"); 347 | int deviceTypeId; 348 | switch (category) { 349 | case "smartphone": 350 | deviceTypeId = 0; 351 | break; 352 | case "mobilephone": 353 | deviceTypeId = 1; 354 | break; 355 | case "appliance": 356 | deviceTypeId = 2; 357 | break; 358 | case "pc": 359 | deviceTypeId = 3; 360 | break; 361 | case "crawler": 362 | deviceTypeId = 4; 363 | break; 364 | case "misc": 365 | deviceTypeId = 5; 366 | break; 367 | default: 368 | deviceTypeId = randDeviceType.nextInt(maxDeviceType+1); ; 369 | break; 370 | } 371 | return deviceTypeId; 372 | } 373 | } 374 | -------------------------------------------------------------------------------- /source/traffic-filtering-app/src/main/java/com/aik/prediction/BidRequestHandler.java: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | package com.aik.prediction; 5 | 6 | import com.aik.filterapi.BidRequest; 7 | import com.aik.filterapi.BidRequestFilter; 8 | import com.aik.filterapi.BidResponse; 9 | import com.timgroup.statsd.NonBlockingStatsDClient; 10 | import com.timgroup.statsd.StatsDClient; 11 | import org.apache.logging.log4j.LogManager; 12 | import org.apache.logging.log4j.Logger; 13 | import software.amazon.awssdk.services.ssm.SsmClient; 14 | import software.amazon.awssdk.services.ssm.model.GetParameterRequest; 15 | import software.amazon.awssdk.services.ssm.model.GetParameterResponse; 16 | import software.amazon.awssdk.services.ssm.model.SsmException; 17 | 18 | import java.io.IOException; 19 | import java.io.InputStream; 20 | import java.time.Duration; 21 | import java.time.Instant; 22 | import java.util.*; 23 | import java.util.concurrent.Executors; 24 | import java.util.concurrent.ScheduledExecutorService; 25 | import java.util.concurrent.TimeUnit; 26 | 27 | 28 | public class BidRequestHandler implements BidRequestFilter.Iface { 29 | 30 | private static final Logger logger = LogManager.getLogger(BidRequestHandler.class.getName()); 31 | private static final StatsDClient statsd = new NonBlockingStatsDClient("filtering_server", "localhost", 8125); 32 | private static final DoubleSummaryStatistics totalStats = new DoubleSummaryStatistics(); 33 | private final BiddingFilter filter; 34 | private String filteringModelSsmParameterName; 35 | private String transformationModelSsmParameterName; 36 | private String transformationModelSchemaSsmParameterName; 37 | 38 | 39 | public BidRequestHandler() { 40 | filter = new BiddingFilter(); 41 | } 42 | 43 | public void init() { 44 | logger.traceEntry(); 45 | int metricsIntervalMs = 20000; 46 | 47 | 
filteringModelSsmParameterName = "/aik/xgboost/path" ; 48 | transformationModelSsmParameterName = "/aik/pipelineModelArtifactPath" ; 49 | transformationModelSchemaSsmParameterName = "/aik/pipelineModelArtifactSchemaPath" ; 50 | 51 | try (InputStream input = BidRequestHandler.class.getClassLoader().getResourceAsStream("config.properties")) { 52 | 53 | Properties prop = new Properties(); 54 | if (input == null) { 55 | logger.error("Sorry, unable to find config.properties"); 56 | return; 57 | } 58 | 59 | //load a properties file from class path, inside static method 60 | prop.load(input); 61 | //get the property value and print it out 62 | metricsIntervalMs = new Integer(prop.getProperty("aik.inference.server.metrics.interval.ms")); 63 | 64 | 65 | } catch (IOException ex) { 66 | ex.printStackTrace(); 67 | } 68 | 69 | this.loadConfig() ; 70 | ScheduledExecutorService executorService = Executors 71 | .newSingleThreadScheduledExecutor(); 72 | // schedule printing of the metrics 73 | executorService.scheduleAtFixedRate(() -> logger.warn("current execution average (ms): " + totalStats.getAverage()), 0, metricsIntervalMs, TimeUnit.MILLISECONDS); 74 | } 75 | 76 | 77 | private void loadConfig() { 78 | //load inference model using XGB library 79 | logger.info("Getting URI of the filtering model"); 80 | String filteringModelUri = getValueFromSsmParameter(filteringModelSsmParameterName) ; 81 | String transformationModelUri = getValueFromSsmParameter(transformationModelSsmParameterName) ; 82 | String transformationModelSchemaUri = getValueFromSsmParameter(transformationModelSchemaSsmParameterName) ; 83 | logger.info("Downloading bidding filter model "+ filteringModelUri); 84 | loadModel(transformationModelUri,transformationModelSchemaUri,filteringModelUri); 85 | logger.info("Loading in memory transformer model"); 86 | } 87 | 88 | 89 | private void loadModel(String s3URITransformationModel, 90 | String s3URITransformationModelSchema, 91 | String s3URIFilteringModel ) { 92 | 93 | logger.info("Downloading transformations model for feature processing from " + s3URITransformationModel); 94 | String modelLocation = Downloader.getTransformerModel(s3URITransformationModel); 95 | logger.info("Downloading schema for feature processing"); 96 | String schemaLocation = Downloader.getTransformerSchema(s3URITransformationModelSchema); 97 | 98 | // load model for feature transformation 99 | logger.info("Loading in memory transformer model"); 100 | Transform$.MODULE$.loadModel(modelLocation); 101 | logger.info("Loading in memory schema model"); 102 | Transform$.MODULE$.loadSchema(schemaLocation); 103 | 104 | //load inference model using XGB library 105 | logger.info("Downloading bidding filter model"); 106 | String modelBiddingFilterLocation = Downloader.getFilteringModel(s3URIFilteringModel); 107 | logger.info("Loading in memory transformer model"); 108 | 109 | filter.loadModel(modelBiddingFilterLocation); 110 | } 111 | 112 | 113 | /** 114 | * return a string value from an SSM parameter 115 | * @param ssmParameterName the name of SSM parameter 116 | * @return the value stored 117 | */ 118 | private String getValueFromSsmParameter(String ssmParameterName) { 119 | logger.traceEntry(); 120 | SsmClient ssmClient = SsmClient.builder() 121 | .build(); 122 | String valueFromSsmParameter = ""; 123 | try { 124 | logger.info("getting parameter for parameter name " + ssmParameterName); 125 | GetParameterRequest parameterRequest = GetParameterRequest.builder() 126 | .name(ssmParameterName) 127 | .build(); 128 | 129 | 
GetParameterResponse parameterResponse = ssmClient.getParameter(parameterRequest); 130 | valueFromSsmParameter = parameterResponse.parameter().value(); 131 | logger.info("successfully retrieved parameter " + valueFromSsmParameter); 132 | 133 | } catch (SsmException e) { 134 | logger.error("error while reading SSM parameter"); 135 | logger.catching(e); 136 | } finally { 137 | ssmClient.close(); 138 | } 139 | return logger.traceExit(valueFromSsmParameter); 140 | } 141 | 142 | 143 | 144 | 145 | 146 | public BidResponse filter(BidRequest request) throws org.apache.thrift.TException { 147 | logger.info("starting filtering a bid request"); 148 | 149 | Instant start = Instant.now(); 150 | BidResponse response = new BidResponse(); 151 | 152 | try { 153 | List transformedFeature = Transform$.MODULE$.transform(request); 154 | logger.info("nb featured : " + transformedFeature.size()); 155 | 156 | // Compute likelihood to bid for each TP 157 | double likelihood = filter.filter(transformedFeature); 158 | logger.trace("advertiser ID " + request.advertiserId + " likelihood to bid " + likelihood); 159 | 160 | response.likelihoodToBid = likelihood; 161 | Instant stop = Instant.now(); 162 | 163 | Duration totalDuration = Duration.between(start, stop); 164 | double totalExecutionTime = totalDuration.getNano() / 1000000.d; 165 | //logger.warn("par Thread " + Thread.currentThread().getId() + "execution time " + totalExecutionTime + " micro " + totalDuration.getNano()/1000 ); 166 | totalStats.accept(totalExecutionTime); 167 | statsd.incrementCounter("filtering_count"); 168 | statsd.recordExecutionTime("filtering_latency", totalDuration.getNano() / 1000); 169 | } 170 | catch (Exception e ){ 171 | logger.warn("An exception was caught " + e) ; 172 | e.printStackTrace(); 173 | } 174 | return response; 175 | } 176 | } 177 | 178 | -------------------------------------------------------------------------------- /source/traffic-filtering-app/src/main/java/com/aik/prediction/BiddingFilter.java: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | // SPDX-License-Identifier: MIT-0 3 | 4 | package com.aik.prediction; 5 | 6 | import ml.dmlc.xgboost4j.java.Booster; 7 | import ml.dmlc.xgboost4j.java.DMatrix; 8 | import ml.dmlc.xgboost4j.java.XGBoost; 9 | import ml.dmlc.xgboost4j.java.XGBoostError; 10 | import org.apache.logging.log4j.LogManager; 11 | import org.apache.logging.log4j.Logger; 12 | 13 | import java.time.Duration; 14 | import java.time.Instant; 15 | import java.util.*; 16 | 17 | 18 | public class BiddingFilter { 19 | private static final Logger logger = LogManager.getLogger(BiddingFilter.class.getName()); 20 | private static Booster booster; 21 | 22 | final private static DoubleSummaryStatistics mainStats = new DoubleSummaryStatistics(); 23 | 24 | 25 | public void loadModel(String modelLocation) { 26 | logger.info("load model in memory"); 27 | long startTime = System.currentTimeMillis(); 28 | try { 29 | BiddingFilter.booster = XGBoost.loadModel(modelLocation); 30 | } catch (XGBoostError e) { 31 | logger.error("model location : ["+modelLocation+"]"); 32 | logger.error("error while loading filtering model " + modelLocation); 33 | logger.catching(e); 34 | } 35 | long endTime = System.currentTimeMillis() - startTime; 36 | logger.info("--- load model in: " + endTime + "ms"); 37 | } 38 | 39 | public Double filter( List bidRequest) { 40 | Instant mainStart = Instant.now() ; 41 | double likelihoodToBid = -1 ; 42 | float[] testInput = new float[bidRequest.size()]; 43 | float[][] predicts ; 44 | for (int i = 0, total = bidRequest.size(); i < total; i++) { 45 | testInput[i] = bidRequest.get(i).floatValue(); 46 | } 47 | logger.info("filtering input " + Arrays.toString(testInput)) ; 48 | try { 49 | //One row, X columns 50 | DMatrix testMatOneRow = new DMatrix(testInput, 1, testInput.length, Float.NaN); 51 | predicts = BiddingFilter.booster.predict(testMatOneRow); 52 | 53 | if (predicts.length > 0) { 54 | if (predicts[0].length > 0) { 55 | likelihoodToBid = predicts[0][0]; 56 | } 57 | } 58 | } catch (XGBoostError e) { 59 | logger.catching(e); 60 | } 61 | 62 | logger.info("likelihood to bid " + likelihoodToBid) ; 63 | Instant mainStop = Instant.now() ; 64 | Duration mainDuration = Duration.between(mainStart,mainStop) ; 65 | double mainTime = mainDuration.getNano()/1000000.d; 66 | mainStats.accept(mainTime); 67 | return likelihoodToBid; 68 | } 69 | 70 | 71 | } -------------------------------------------------------------------------------- /source/traffic-filtering-app/src/main/java/com/aik/prediction/Downloader.java: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | // SPDX-License-Identifier: MIT-0 3 | 4 | package com.aik.prediction; 5 | 6 | import org.apache.logging.log4j.LogManager; 7 | import org.apache.logging.log4j.Logger; 8 | import software.amazon.awssdk.services.s3.S3Client; 9 | import software.amazon.awssdk.services.s3.model.GetObjectRequest; 10 | 11 | 12 | import java.io.File; 13 | import java.util.AbstractMap; 14 | import java.util.Random; 15 | import java.util.regex.Matcher; 16 | import java.util.regex.Pattern; 17 | 18 | public class Downloader { 19 | 20 | private static final Logger logger = LogManager.getLogger(Downloader.class.getName()); 21 | final private static String transformerFileNamePrefix = "transformer" ; 22 | final private static String schemaFileNamePrefix = "schema" ; 23 | final private static String filteringFileNamePrefix = "filtering" ; 24 | final private static String suffix = "model" ; 25 | final private static String tempDirectory = ".tmp" ; 26 | 27 | final private static Random generator = new Random() ; 28 | final private static int MaxRandomNumber = 100000000 ; 29 | 30 | public static AbstractMap.SimpleEntry parseS3Uri(String s3URI) { 31 | logger.info("connection start"); 32 | final String regex = "^s3://([^/]+)/(.+)$" ; 33 | 34 | final Pattern pattern = Pattern.compile(regex, Pattern.MULTILINE); 35 | final Matcher matcher = pattern.matcher(s3URI); 36 | 37 | AbstractMap.SimpleEntry parsedURI = null ; 38 | 39 | while (matcher.find()) { 40 | System.out.println("Full match: " + matcher.group(0)); 41 | 42 | parsedURI = new AbstractMap.SimpleEntry<>(matcher.group(1), matcher.group(2) ); 43 | 44 | } 45 | 46 | return parsedURI ; 47 | 48 | 49 | } 50 | 51 | public static File generateFileName(String prefix,String extension) { 52 | File downloadedFile ; 53 | int randomness = generator.nextInt(MaxRandomNumber); 54 | downloadedFile = new File (tempDirectory+"/"+prefix+"-" + randomness +"-" + Downloader.suffix+extension) ; 55 | return downloadedFile ; 56 | } 57 | 58 | public static String getTransformerModel(String s3URIModelTransformer) { 59 | 60 | AbstractMap.SimpleEntry s3URIParsed = Downloader.parseS3Uri(s3URIModelTransformer) ; 61 | File downloadedFile = Downloader.generateFileName(Downloader.transformerFileNamePrefix,".zip") ; 62 | 63 | String keyName = s3URIParsed.getValue() ; 64 | String bucketName = s3URIParsed.getKey() ; 65 | 66 | logger.info("key name for transformer model " + keyName ) ; 67 | logger.info("bucket name for transformer model" + bucketName) ; 68 | S3Client s3 = S3Client.builder() 69 | .build(); 70 | 71 | GetObjectRequest objectRequest = GetObjectRequest 72 | .builder() 73 | .key(keyName) 74 | .bucket(bucketName) 75 | .build(); 76 | 77 | s3.getObject(objectRequest, downloadedFile.toPath()); 78 | s3.close() ; 79 | logger.info("Model downloaded"); 80 | 81 | return downloadedFile.getAbsolutePath(); 82 | } 83 | 84 | public static String getTransformerSchema(String s3URISchema) { 85 | logger.info("connection start"); 86 | 87 | AbstractMap.SimpleEntry s3URIParsed = Downloader.parseS3Uri(s3URISchema) ; 88 | File downloadedFile = Downloader.generateFileName(Downloader.schemaFileNamePrefix,".json") ; 89 | String keyName = s3URIParsed.getValue() ; 90 | String bucketName = s3URIParsed.getKey() ; 91 | 92 | 93 | 94 | logger.info("key name for transformer model " + keyName ) ; 95 | logger.info("bucket name for transformer model" + bucketName) ; 96 | S3Client s3 = S3Client.builder() 97 | .build(); 98 | 99 | GetObjectRequest objectRequest = GetObjectRequest 100 | .builder() 101 | .key(keyName) 102 | 
--------------------------------------------------------------------------------
/source/traffic-filtering-app/src/main/java/com/aik/prediction/InferenceServer.java:
--------------------------------------------------------------------------------
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT-0

package com.aik.prediction;

import com.aik.filterapi.BidRequestFilter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.thrift.server.TServer;
import org.apache.thrift.server.TThreadPoolServer;
import org.apache.thrift.transport.TServerSocket;
import org.apache.thrift.transport.TServerTransport;


public class InferenceServer {

    private static final Logger logger = LogManager.getLogger(InferenceServer.class.getName());

    public static BidRequestHandler handler;

    public static BidRequestFilter.Processor<BidRequestHandler> processor;

    public static void main(String[] args) {
        try {

            logger.info("Downloading transformer model");

            handler = new BidRequestHandler();
            handler.init();
            processor = new BidRequestFilter.Processor<>(handler);

            Runnable simple = () -> simple(processor);

            new Thread(simple).start();
        } catch (Exception x) {
            x.printStackTrace();
        }
    }

    public static void simple(BidRequestFilter.Processor<BidRequestHandler> processor) {
        try {
            TServerTransport serverTransport = new TServerSocket(9090);

            // Use this for a multithreaded server: keep 2 workers warm, cap the pool at 8
            TThreadPoolServer.Args pool = new TThreadPoolServer.Args(serverTransport).processor(processor);
            pool.maxWorkerThreads(8);
            pool.minWorkerThreads(2);
            TServer server = new TThreadPoolServer(pool);

            logger.info("Starting the simple server...");
            server.serve();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
--------------------------------------------------------------------------------
/source/traffic-filtering-app/src/main/resources/config.properties:
--------------------------------------------------------------------------------
aik.inference.server.metrics.interval.ms=20000
--------------------------------------------------------------------------------
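The server above exposes the Thrift BidRequestFilter service on a plain socket with the default binary protocol, so a one-off smoke test can use the client generated from api.thrift. The sketch below is an assumption-laden example: the host, the field values, and the use of the generated BidRequestFilter.Client are placeholders for illustration; the kit's own load-test client is MultiThreadedClient, which is not shown in this excerpt.

// Hypothetical single-call client; all request values are made up.
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;

import com.aik.filterapi.BidRequest;
import com.aik.filterapi.BidRequestFilter;
import com.aik.filterapi.BidResponse;

public class SingleCallClient {
    public static void main(String[] args) throws Exception {
        // The server binds a TServerSocket on 9090 (see InferenceServer.simple above)
        TTransport transport = new TSocket("localhost", 9090);
        transport.open();
        BidRequestFilter.Client client = new BidRequestFilter.Client(new TBinaryProtocol(transport));

        BidRequest request = new BidRequest();
        request.setBidId("test-bid-1");
        request.setDayOfWeek(3);
        request.setHour("14");
        request.setRegionId("15");
        request.setCityId("42");
        request.setDomainId("211");
        request.setAdvertiserId("7");
        request.setBiddingPrice(300L);
        request.setPayingPrice(0L);
        request.setUserAgent("Mozilla/5.0");
        request.setDeviceTypeId(1);

        BidResponse response = client.filter(request);
        System.out.println("likelihood to bid: " + response.getLikelihoodToBid());

        transport.close();
    }
}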
/source/traffic-filtering-app/src/main/resources/log4j2.xml:
--------------------------------------------------------------------------------
[log4j2 appender and logger configuration; the XML content was not preserved in this listing]
--------------------------------------------------------------------------------
/source/traffic-filtering-app/src/main/scala/com/aik/prediction/Transform.scala:
--------------------------------------------------------------------------------
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT-0

package com.aik.prediction


import com.aik.filterapi._
import com.github.plokhotnyuk.jsoniter_scala.core._
import com.github.plokhotnyuk.jsoniter_scala.macros._
import ml.combust.bundle.BundleFile
import ml.combust.mleap.core.types._
import ml.combust.mleap.runtime.MleapSupport._
import ml.combust.mleap.runtime.frame.{DefaultLeapFrame, Row}
import resource.managed

import scala.collection.JavaConverters._
import scala.collection.immutable.ListMap
import scala.io.Source
// https://github.com/combust/mleap/blob/master/mleap-spark-base/src/main/scala/org/apache/spark/sql/mleap/TypeConverters.scala


object Transform {

  var mleapPipeline: ml.combust.mleap.runtime.frame.Transformer = null
  var schema: ml.combust.mleap.core.types.StructType = null

  def loadModel(location: String): Unit = {
    println(s"starting loading from location $location")
    // TO DO: test loading artifact from an unzipped folder
    val bundle = (for (bundleFile <- managed(BundleFile(s"jar:file:$location"))) yield {
      bundleFile.loadMleapBundle().get
    }).opt.get
    mleapPipeline = bundle.root
    println(mleapPipeline.getClass)
  }

  def loadSchema(location: String): Unit = {
    println(s"starting loading from location $location")
    val schemaFile = Source.fromFile(location)
    val schemaFileContents = schemaFile.getLines.mkString
    schemaFile.close()
    implicit val codec: JsonValueCodec[ListMap[String, String]] = JsonCodecMaker.make[ListMap[String, String]](CodecMakerConfig)
    val schemaFieldMap = readFromArray(schemaFileContents.getBytes("UTF-8"))
    // Reconstruct MLeap schema from the JSON map
    schema = StructType(
      schemaFieldMap.toList.map {
        case (f, "DoubleType") => StructField(f, ScalarType.Double)
        case (f, "IntegerType") => StructField(f, ScalarType.Int)
        case (f, "LongType") => StructField(f, ScalarType(BasicType.Long))
        case (f, "StringType") => StructField(f, ScalarType.String)
        case (f, "ArrayType(IntegerType,true)") => StructField(f, ListType.Int)
      }
    ).get
    println(schema)
  }

  def transform(request: BidRequest): java.util.List[java.lang.Double] = {
    val rowRequest = Seq(Row(
      request.bidId,
      request.dayOfWeek,
      request.hour,
      request.regionId,
      request.cityId,
      request.domainId,
      request.advertiserId,
      request.biddingPrice,
      request.payingPrice,
      request.userAgent
    ))
    val frame = DefaultLeapFrame(schema, rowRequest)
    val predictionLeapFrame = mleapPipeline.transform(frame).get
    val vectorizedLeapFrame = predictionLeapFrame.select("dow", "hour", "IndexAdvertiserID", "IndexDomain", "IndexRegionID", "IndexCityID").get.dataset
    //val vectorizedData: List[java.lang.Double] = vectorizedLeapFrame.apply(0).getTensor(0).toDense.rawValuesIterator.toList
    val vectorizedData: List[java.lang.Double] = List(vectorizedLeapFrame.head(0)).map(item => new java.lang.Double(item.toString.toDouble)) ::: List(new java.lang.Double(request.deviceTypeId.toDouble))
    vectorizedData.asJava
  }


}
--------------------------------------------------------------------------------
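Putting the pieces together: the MLeap bundle, its schema, and the XGBoost model are downloaded from S3 and loaded once at startup, which is presumably what BidRequestHandler.init() does, although that class is not shown in this excerpt. A rough bootstrap sketch, with placeholder S3 URIs:

// Sketch only: the S3 URIs are placeholders and the real wiring lives in BidRequestHandler.
package com.aik.prediction;

public class ModelBootstrap {

    public static BiddingFilter bootstrap() {
        // Fetch the artifacts produced by the SageMaker/EMR training pipeline
        String transformerPath = Downloader.getTransformerModel("s3://example-bucket/artifacts/transformer.zip");
        String schemaPath = Downloader.getTransformerSchema("s3://example-bucket/artifacts/schema.json");
        String modelPath = Downloader.getFilteringModel("s3://example-bucket/artifacts/filtering-model.bin");

        // Scala object methods are callable from Java through their static forwarders
        Transform.loadModel(transformerPath);
        Transform.loadSchema(schemaPath);

        BiddingFilter filter = new BiddingFilter();
        filter.loadModel(modelPath);
        return filter;
    }
}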
/source/traffic-filtering-app/src/main/thrift/api.thrift:
--------------------------------------------------------------------------------
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT-0


/**
 * The available types in Thrift are:
 *
 *  bool          Boolean, one byte
 *  i8 (byte)     Signed 8-bit integer
 *  i16           Signed 16-bit integer
 *  i32           Signed 32-bit integer
 *  i64           Signed 64-bit integer
 *  double        64-bit floating point value
 *  string        String
 *  binary        Blob (byte array)
 *  map<t1,t2>    Map from one type to another
 *  list<t1>      Ordered list of one type
 *  set<t1>       Set of unique elements of one type
 *
 */



/**
 * define java packages
 */
namespace java com.aik.filterapi


/**
 * Raw data required for filtering a bid request
 */
struct BidRequest {
    1: string bidId,
    2: i32 dayOfWeek,
    3: string hour,
    4: string regionId,
    5: string cityId,
    6: string domainId,
    7: string advertiserId,
    8: i64 biddingPrice,
    9: i64 payingPrice,
    10: string userAgent,
    11: i32 deviceTypeId
}

/**
 * Result returned after filtering a bid request
 */
struct BidResponse {
    1: double likelihoodToBid
}


/**
 * Definition of the available service
 */
service BidRequestFilter {

    /**
     * This method associates, for every TP in the BidRequest,
     * an indicator specifying if the bid should be proposed to the TP
     */
    BidResponse filter(1: BidRequest request)

}
--------------------------------------------------------------------------------
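On the Java side, this contract is implemented by a class that implements the generated BidRequestFilter.Iface. The kit's actual implementation is BidRequestHandler, which is not included in this excerpt and may differ; a hypothetical handler that vectorizes with Transform and scores with BiddingFilter could look like this:

// Hypothetical Iface implementation; the real BidRequestHandler may differ in detail.
package com.aik.prediction;

import com.aik.filterapi.BidRequest;
import com.aik.filterapi.BidRequestFilter;
import com.aik.filterapi.BidResponse;
import org.apache.thrift.TException;

public class ExampleFilterHandler implements BidRequestFilter.Iface {

    private final BiddingFilter biddingFilter;

    public ExampleFilterHandler(BiddingFilter biddingFilter) {
        this.biddingFilter = biddingFilter;
    }

    @Override
    public BidResponse filter(BidRequest request) throws TException {
        // MLeap turns the raw request fields into the numeric feature vector XGBoost expects
        BidResponse response = new BidResponse();
        response.setLikelihoodToBid(biddingFilter.filter(Transform.transform(request)));
        return response;
    }
}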
/test/nag.ts:
--------------------------------------------------------------------------------
#!/usr/bin/env node

// This file exists because cdk-nag doesn't work on stages.
// We need to synthesize the stacks at the top level to run nag checks.

import "source-map-support/register"
import * as cdk from "aws-cdk-lib"
import { applyTags } from "../helpers/tags"
import { Aspects } from "aws-cdk-lib"
import { AwsSolutionsChecks, NagSuppressions } from "cdk-nag"
import { SagemakerEmrStack } from "../lib/sagemaker-emr-stack"
import { FilteringApplicationStack } from "../lib/filtering-application-stack"
import { suppressNag } from "../lib/suppress-nag"

/**
 * Create an app with our stacks directly instantiated.
 */
function main() {
    const app = new cdk.App()
    const sagemakerEmr = new SagemakerEmrStack(app, "sagemaker-emr", {})
    suppressNag(sagemakerEmr)
    const filteringStack = new FilteringApplicationStack(app, "filtering", {
        trainingBucket: sagemakerEmr.trainingBucket,
        vpc: sagemakerEmr.vpc,
    })
    NagSuppressions.addResourceSuppressionsByPath(filteringStack,
        "/filtering/TaskDef/ExecutionRole/DefaultPolicy/Resource",
        [{
            id: "AwsSolutions-IAM5",
            reason: "False positive: policy is created as part of the instantiation of the construct",
            appliesTo: ["Resource::*"]
        }],
        true)
    applyTags(app)
    Aspects.of(app).add(new AwsSolutionsChecks({
        verbose: true,
        logIgnores: false,
        reports: true
    }))
    NagSuppressions.addStackSuppressions(sagemakerEmr, [
        { id: "CdkNagValidationFailure", reason: "Nag can't parse Fn::" }
    ])

    app.synth()
}

main()
--------------------------------------------------------------------------------
/test/test-delete-sg.ts:
--------------------------------------------------------------------------------
import { deleteSagemakerSecurityGroups } from "../source/sagemaker-sg-cleanup/delete-sg"

(async function testDelete() {
    try {
        await deleteSagemakerSecurityGroups(process.argv[2])
    } catch (ex) {
        console.error(ex)
    }
})()
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
{
  "compilerOptions": {
    "target": "ES2018",
    "module": "commonjs",
    "lib": ["es2018"],
    "declaration": true,
    "strict": true,
    "noImplicitAny": true,
    "strictNullChecks": true,
    "noImplicitThis": true,
    "alwaysStrict": true,
    "noUnusedLocals": false,
    "noUnusedParameters": false,
    "noImplicitReturns": true,
    "noFallthroughCasesInSwitch": false,
    "inlineSourceMap": true,
    "inlineSources": true,
    "experimentalDecorators": true,
    "strictPropertyInitialization": false,
    "resolveJsonModule": true,
    "esModuleInterop": true,
    "typeRoots": ["./node_modules/@types"],
    "outDir": "build"
  },
  "exclude": ["cdk.out", "node_modules", "build"]
}
--------------------------------------------------------------------------------