├── images ├── sample.jpg └── architecture.png ├── .gitignore ├── CODE_OF_CONDUCT.md ├── cdk.json ├── src ├── main │ └── java │ │ └── com │ │ └── amazonaws │ │ └── services │ │ └── sample │ │ └── apigateway │ │ └── websocketratelimit │ │ ├── RateLimitApp.java │ │ └── RateLimitStack.java └── test │ └── java │ └── com │ └── amazonaws │ └── services │ └── sample │ └── apigateway │ └── websocketratelimit │ └── RateLimitTest.java ├── LICENSE ├── lambda ├── Tenant.js ├── SampleClientGet.js ├── SessionTTL.js ├── WebsocketDisconnect.js ├── SampleClient.html ├── Authorizer.js ├── Session.js ├── WebsocketEcho.js ├── WebsocketConnect.js ├── SQSEcho.js ├── Common.js └── SampleClient.js ├── pom.xml ├── CONTRIBUTING.md └── README.md /images/sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/api-gateway-websocket-saas-rate-limiting-using-aws-lambda-authorizer/HEAD/images/sample.jpg -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/api-gateway-websocket-saas-rate-limiting-using-aws-lambda-authorizer/HEAD/images/architecture.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .classpath.txt 2 | target 3 | .classpath 4 | .project 5 | .idea 6 | .settings 7 | .vscode 8 | *.iml 9 | *.DS_Store 10 | dependency-reduced-pom.xml 11 | *.jar 12 | 13 | # CDK asset staging directory 14 | .cdk.staging 15 | cdk.out 16 | 17 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of 
Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "mvn -e -q compile exec:java", 3 | "context": { 4 | "@aws-cdk/aws-apigateway:usagePlanKeyOrderInsensitiveId": true, 5 | "@aws-cdk/core:stackRelativeExports": "true", 6 | "@aws-cdk/aws-ecs-patterns:removeDefaultDesiredCount": true, 7 | "@aws-cdk/aws-rds:lowercaseDbIdentifier": true, 8 | "@aws-cdk/aws-lambda:recognizeVersionProps": true, 9 | "@aws-cdk/aws-cloudfront:defaultSecurityPolicyTLSv1.2_2021": true 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/com/amazonaws/services/sample/apigateway/websocketratelimit/RateLimitApp.java: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | package com.amazonaws.services.sample.apigateway.websocketratelimit; 5 | 6 | import software.amazon.awscdk.App; 7 | import software.amazon.awscdk.StackProps; 8 | 9 | public class RateLimitApp { 10 | public static void main(final String[] args) { 11 | App app = new App(); 12 | 13 | new RateLimitStack(app, "APIGatewayWebSocketRateLimitStack", StackProps.builder() 14 | .stackName("APIGatewayWebSocketRateLimit") 15 | .build()); 16 | 17 | app.synth(); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 15 | 16 | -------------------------------------------------------------------------------- /src/test/java/com/amazonaws/services/sample/apigateway/websocketratelimit/RateLimitTest.java: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | // SPDX-License-Identifier: MIT-0 3 | 4 | package com.amazonaws.services.sample.apigateway.websocketratelimit; 5 | 6 | import org.junit.Assert; 7 | import org.junit.Test; 8 | import software.amazon.awscdk.App; 9 | import com.fasterxml.jackson.databind.JsonNode; 10 | import com.fasterxml.jackson.databind.ObjectMapper; 11 | import com.fasterxml.jackson.databind.SerializationFeature; 12 | 13 | import java.io.IOException; 14 | 15 | 16 | public class RateLimitTest { 17 | private final static ObjectMapper JSON = 18 | new ObjectMapper().configure(SerializationFeature.INDENT_OUTPUT, true); 19 | 20 | @Test 21 | public void testStack() throws IOException { 22 | App app = new App(); 23 | RateLimitStack stack = new RateLimitStack(app, "test"); 24 | 25 | // synthesize the stack to a CloudFormation template 26 | JsonNode actual = JSON.valueToTree(app.synth().getStackArtifact(stack.getArtifactId()).getTemplate()); 27 | Assert.assertNotNull(actual); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /lambda/Tenant.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const common = require("./Common.js"); 5 | 6 | // This handler is just a sample helper to fetch the current tenant ids from the database. 7 | // In a production system the tenant id would typically be known to the user and a list would not be 8 | // available as a public endpoint. 
9 | exports.handler = async(event, context) => { 10 | //console.log('Received event:', JSON.stringify(event, null, 2)); 11 | 12 | try { 13 | if (event.requestContext.http.method == "GET") { 14 | event.queryStringParameters = { 15 | tenantId: "none" 16 | }; 17 | let dynamo = common.createDynamoDBClient(event); 18 | let body = await dynamo.scan({ "TableName": process.env.TenantTableName }).promise(); 19 | return {statusCode: 200, headers: {"Content-Type": "application/json"}, body: JSON.stringify(body)}; 20 | } 21 | } 22 | catch (err) { 23 | console.error(err); 24 | return { statusCode: 400, headers: { "Content-Type": "application/json" }, body: JSON.stringify(err.message) }; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /lambda/SampleClientGet.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const fs = require('fs') 5 | 6 | // This is a simple handler to return back the sample webpage with url values created by the cloudformation stack 7 | exports.handler = async(event) => { 8 | //console.log("Sample: " + JSON.stringify(event, null, 2)); 9 | 10 | try { 11 | let filename = event.queryStringParameters && event.queryStringParameters.page ? event.queryStringParameters.page : "SampleClient.html"; 12 | const data = fs.readFileSync("./" + filename, 'utf8').replace("{{WssUrl}}", process.env.WssUrl).replace("{{sessionUrl}}", process.env.SessionUrl).replace("{{tenantUrl}}", process.env.TenantUrl); 13 | let contentType = filename == "SampleClient.js" ? 
"text/javascript" : "text/html"; 14 | return { 15 | statusCode: 200, 16 | body: data, 17 | headers: { "Content-Type": contentType } 18 | }; 19 | } 20 | catch (err) { 21 | console.error(err); 22 | return { 23 | statusCode: 500, 24 | body: JSON.stringify('Error: ' + JSON.stringify(err)) 25 | }; 26 | } 27 | }; 28 | -------------------------------------------------------------------------------- /lambda/SessionTTL.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const AWS = require("aws-sdk"); 5 | const apig = new AWS.ApiGatewayManagementApi({ endpoint: process.env.ApiGatewayEndpoint }); 6 | 7 | // This handler is used to disconnect any remaining websocket connections for a given session when the time to live (TTL) expires 8 | exports.handler = async function(event, context) { 9 | //console.log(JSON.stringify(event)); 10 | for (let x = 0; x < event.Records.length; x++) { 11 | const record = event.Records[x]; 12 | if (record.userIdentity && record.userIdentity.principalId && record.userIdentity.type && record.userIdentity.principalId == "dynamodb.amazonaws.com" && record.userIdentity.type == "Service") { 13 | if (record.eventName == 'REMOVE' && record.dynamodb && record.dynamodb.OldImage && record.dynamodb.OldImage.connectionIds) { 14 | let connectionIds = record.dynamodb.OldImage.connectionIds.SS; 15 | for (let y = 0; y < connectionIds.length; y++) { 16 | const connectionId = connectionIds[y]; 17 | //console.log("SessionTTL Removing ConnectionId: " + connectionId); 18 | try { 19 | await apig.deleteConnection({ ConnectionId: connectionId }).promise(); 20 | } 21 | catch (err) { 22 | console.error(err); 23 | } 24 | }; 25 | } 26 | } 27 | } 28 | return { statusCode: 200 }; 29 | }; 30 | -------------------------------------------------------------------------------- /lambda/WebsocketDisconnect.js: 
-------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const common = require("./Common.js"); 5 | 6 | // This handler will remove the current connection from the sessions connectionId set 7 | // and decrement the total number of connections for this tenant 8 | exports.handler = async function(event, context) { 9 | //console.log('Received event:', JSON.stringify(event, null, 2)); 10 | 11 | if (event.requestContext.routeKey == '$disconnect') { 12 | try { 13 | let dynamo = common.createDynamoDBClient(event); 14 | let tenantId = common.getTenantId(event); 15 | let sessionId = common.getSessionId(event); 16 | let deleteConnectIdParams = { 17 | "TableName": process.env.SessionTableName, 18 | "Key": {tenantId: tenantId, sessionId: sessionId}, 19 | "UpdateExpression": "DELETE connectionIds :c", 20 | "ExpressionAttributeValues": { 21 | ":c": dynamo.createSet([event.requestContext.connectionId]) 22 | }, 23 | "ReturnValues": "NONE" 24 | }; 25 | let updateConnectCountParams = { 26 | "TableName": process.env.LimitTableName, 27 | "Key": { tenantId: tenantId, key: tenantId }, 28 | "UpdateExpression": "set itemCount = if_not_exists(itemCount, :zero) - :dec", 29 | "ExpressionAttributeValues": {":dec": 1, ":zero": 0}, 30 | "ReturnValues": "NONE" 31 | }; 32 | await dynamo.transactWrite({TransactItems: [{Update: deleteConnectIdParams}, {Update: updateConnectCountParams}]}).promise(); 33 | } catch (err) { 34 | console.error(err); 35 | return {statusCode: 1011}; // return server error code 36 | } 37 | } 38 | return { statusCode: 200 }; 39 | } 40 | -------------------------------------------------------------------------------- /lambda/SampleClient.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Demo Client 6 | 7 | 8 | 9 |
10 |
11 | Settings 12 | 13 | 14 | 15 | 18 | 19 | 22 | 23 | 24 | 25 | 26 | 27 | 31 | 32 |
Tenant 16 | 17 | Session 20 | 21 |
28 | 29 | 30 |
33 |
34 |
35 |
36 | Connection 37 |
38 | 39 | 40 | 41 | 44 | 45 | 46 | 47 | 50 | 51 |
Message 42 | 43 |
48 | 49 |
52 |
53 | 54 |
55 | Log 56 | 57 | 58 | 59 | 60 |
61 | 62 |
63 | 64 | 65 | -------------------------------------------------------------------------------- /lambda/Authorizer.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const common = require("./Common.js"); 5 | let tenantSettingsCache = {}; // Defined outside the function globally 6 | 7 | // This handler will check if the given tenant id and session id are valid. 8 | // The tenant and session id are given via the query string parameter from the client. 9 | // The session id is short lived and is removed either by the end user deleting it when done 10 | // or by a time to live (TTL) timeout from DynamoBD 11 | exports.handler = async function(event, context) { 12 | //console.log('Received event:', JSON.stringify(event, null, 2)); 13 | let dynamo = common.createDynamoDBClient(event); 14 | let tenantId = common.getTenantId(event); 15 | let sessionId = common.getSessionId(event); 16 | try { 17 | let response; 18 | // Check the local tenant cache to improve loading time and reduce calls to the DynamoBD database. 19 | if (tenantId in tenantSettingsCache) { 20 | response = tenantSettingsCache[tenantId]; 21 | } else { 22 | response = await dynamo.get({ "TableName": process.env.TenantTableName, "Key": { tenantId: tenantId } }).promise(); 23 | if (!response || !response.Item || !response.Item.tenantId) { 24 | console.log(tenantId + " tenant not found"); 25 | return common.generateDeny(event.methodArn, event); 26 | } 27 | tenantSettingsCache[tenantId] = response; 28 | } 29 | let tenantSettings = response.Item; 30 | 31 | // Check if session exists 32 | // A session Id is short lived and is removed from DynamoDB via TTL. 
A sessionId must be created prior to trying to connect a websocket 33 | response = await dynamo.get({ "TableName": process.env.SessionTableName, "Key": { tenantId: tenantId, sessionId: sessionId } }).promise(); 34 | if (!response || !response.Item) { 35 | console.log("Tenant: " + tenantId + " Session: " + sessionId + " not found: " + JSON.stringify(response, null, 2)); 36 | return common.generateDeny(event.methodArn, event); 37 | } 38 | 39 | return common.generateAllow(event.methodArn, event, tenantSettings); 40 | } 41 | catch (err) { 42 | console.error(err); 43 | return common.generateDeny(event.methodArn, event); 44 | } 45 | } -------------------------------------------------------------------------------- /lambda/Session.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const common = require("./Common.js"); 5 | 6 | // This handler is used to create a session id for a given tenant id. This session id is required when creating a websocket connection. 7 | // A session time-to-live (TTL) is set for each session based on the tenant settings. 
8 | exports.handler = async(event, context) => { 9 | //console.log('Received event:', JSON.stringify(event, null, 2)); 10 | 11 | try { 12 | let tenantId = common.getTenantId(event); 13 | let sessionId = common.getSessionId(event); 14 | if (!tenantId || !sessionId) { 15 | return { statusCode: 400, headers: { "Content-Type": "application/json" }, body: JSON.stringify("Invalid request") }; 16 | } 17 | 18 | let dynamo = common.createDynamoDBClient(event); 19 | // Check for a valid tenantId 20 | let response = await dynamo.get({ "TableName": process.env.TenantTableName, "Key": { tenantId: tenantId } }).promise(); 21 | if (!response || !response.Item || !response.Item.tenantId) { 22 | return { statusCode: 400, headers: { "Content-Type": "application/json" }, body: JSON.stringify("Invalid request") }; 23 | } 24 | if (event.requestContext.http.method == "PUT") { 25 | let params = { 26 | "TableName": process.env.SessionTableName, 27 | "Key": { 28 | "tenantId": tenantId, 29 | "sessionId": sessionId, 30 | }, 31 | "UpdateExpression": "SET sessionTTL = :ttl", 32 | "ExpressionAttributeValues": { 33 | ":ttl": (Math.floor(+new Date() / 1000) + response.Item.sessionTTL) 34 | }, 35 | "ReturnValues": "UPDATED_NEW" 36 | }; 37 | let body = await dynamo.update(params).promise(); 38 | return {statusCode: 200, headers: {"Content-Type": "application/json"}, body: JSON.stringify(body)}; 39 | } else if (event.requestContext.http.method == "DELETE") { 40 | let params = { 41 | "TableName": process.env.SessionTableName, 42 | "Key": { tenantId: tenantId, sessionId: sessionId } 43 | }; 44 | let body = await dynamo.delete(params).promise(); 45 | return {statusCode: 200, headers: {"Content-Type": "application/json"}, body: JSON.stringify(body)}; 46 | } 47 | } 48 | catch (err) { 49 | console.error(err); 50 | return { statusCode: 400, headers: { "Content-Type": "application/json" }, body: JSON.stringify(err.message) }; 51 | } 52 | }; 53 | 
-------------------------------------------------------------------------------- /lambda/WebsocketEcho.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const AWS = require("aws-sdk"); 5 | const common = require("./Common.js"); 6 | const apig = new AWS.ApiGatewayManagementApi({ endpoint: process.env.ApiGatewayEndpoint }); 7 | 8 | /// This sample file is currently NOT used in this sample but is given as an example of executing a Lambda directly vs using SQS 9 | /// The blog post associated with this sample goes more in depth on why you may choose direct Lambda execution vs SQS 10 | 11 | exports.handler = async function(event, context) { 12 | //console.log('Received event:', JSON.stringify(event, null, 2)); 13 | const {body, requestContext: {connectionId, routeKey, requestId}} = event; 14 | if (routeKey == '$default') { 15 | try { 16 | let tenantId = common.getTenantId(event); 17 | let sessionId = common.getSessionId(event); 18 | let dynamo = common.createDynamoDBClient(event); 19 | 20 | // Update and check the total number of messages per minute per tenant 21 | let updateResponse = await common.incrementLimitTablePerMinute(dynamo, tenantId, tenantId, "minutemsg"); 22 | if (!updateResponse || updateResponse.Attributes.itemCount > event.requestContext.authorizer.messagesPerMinute) { 23 | console.log("Tenant: " + tenantId + " message rate limit hit"); 24 | await apig.postToConnection({ ConnectionId: connectionId, Data: common.createMessageThrottleResponse(connectionId, requestId) }).promise(); 25 | return {statusCode: 429}; 26 | } 27 | 28 | var updateParams = { 29 | "TableName": process.env.SessionTableName, 30 | "Key": {tenantId: tenantId, sessionId: sessionId}, 31 | "UpdateExpression": "set sessionTTL = :ttl", 32 | "ExpressionAttributeValues": { 33 | ":ttl": (Math.floor(+new Date() / 1000) + 
parseInt(event.requestContext.authorizer.sessionTTL)) 34 | }, 35 | "ReturnValues": "ALL_OLD" 36 | }; 37 | let results = await dynamo.update(updateParams).promise(); 38 | let connectionIds = results.Attributes.connectionIds.values; 39 | for (var x = 0; x < connectionIds.length; x++) { 40 | if (connectionIds[x] != connectionId) { 41 | await apig.postToConnection({ConnectionId: connectionIds[x], Data: `${body}`}).promise(); 42 | } 43 | } 44 | for (var x = 0; x < connectionIds.length; x++) { 45 | await apig.postToConnection({ ConnectionId: connectionIds[x], Data: common.createEchoResponse(tenantId, sessionId, connectionIds[x], body, undefined) }).promise(); 46 | } 47 | } catch (err) { 48 | console.error(err); 49 | return {statusCode: 1011}; // return server error code 50 | } 51 | } else { 52 | return {statusCode: 1008}; // return policy violation code 53 | } 54 | return {statusCode: 200}; 55 | } 56 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 4.0.0 5 | 6 | com.amazonaws 7 | api-gateway-websocket-saas-rate-limiting-using-aws-lambda-authorizer-cdk 8 | API Gateway Websocket SaaS Rate Limiting using AWS Lambda Authorizer 9 | 1.1.0 10 | 11 | 12 | UTF-8 13 | 14 | 15 | 16 | 17 | 18 | org.apache.maven.plugins 19 | maven-compiler-plugin 20 | 3.8.1 21 | 22 | 9 23 | 9 24 | 25 | 26 | 27 | 28 | org.codehaus.mojo 29 | exec-maven-plugin 30 | 3.0.0 31 | 32 | com.amazonaws.services.sample.apigateway.websocketratelimit.RateLimitApp 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | software.amazon.awscdk 41 | aws-cdk-lib 42 | [2.0.0,) 43 | 44 | 45 | 46 | software.constructs 47 | constructs 48 | [10.0.0,) 49 | 50 | 51 | 52 | software.amazon.awscdk 53 | apigatewayv2-alpha 54 | 2.65.0-alpha.0 55 | 56 | 57 | 58 | software.amazon.awscdk 59 | apigatewayv2-integrations-alpha 60 | 2.65.0-alpha.0 61 | 62 | 63 | 64 | software.amazon.awscdk 65 | 
apigatewayv2-integrations-alpha 66 | 2.65.0-alpha.0 67 | 68 | 69 | 70 | 71 | junit 72 | junit 73 | 4.13.1 74 | test 75 | 76 | 77 | org.hamcrest 78 | hamcrest-core 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. 
Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 
60 | -------------------------------------------------------------------------------- /lambda/WebsocketConnect.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const common = require("./Common.js"); 5 | 6 | exports.handler = async function(event, context) { 7 | //console.log('Received event:', JSON.stringify(event, null, 2)); 8 | 9 | if (event.requestContext.routeKey == '$connect') { 10 | let dynamo = common.createDynamoDBClient(event); 11 | let tenantId = common.getTenantId(event); 12 | let sessionId = common.getSessionId(event); 13 | try { 14 | // Check if we are over the number of connections allowed per tenant 15 | let response = await dynamo.get({ "TableName": process.env.LimitTableName, "Key": { tenantId: tenantId, key: tenantId } }).promise(); 16 | if (response && response.Item && response.Item.itemCount && response.Item.itemCount >= event.requestContext.authorizer.tenantConnections) { 17 | console.log("Tenant " + tenantId + " over tenant total limit"); 18 | return { statusCode: 429 }; 19 | } 20 | 21 | // Check if we are over the number of connections allowed per tenant session 22 | response = await dynamo.get({ "TableName": process.env.SessionTableName, "Key": { tenantId: tenantId, sessionId: sessionId } }).promise(); 23 | if (response && response.Item && response.Item.connectionIds && response.Item.connectionIds.values.length >= event.requestContext.authorizer.connectionsPerSession) { 24 | console.log("Tenant: " + tenantId + " Session: " + sessionId + " over session total limit"); 25 | return { statusCode: 429 }; 26 | } 27 | 28 | // Update and check the total number of connections per minute per tenant 29 | let updateResponse = await common.incrementLimitTablePerMinute(dynamo, tenantId, tenantId, "minute"); 30 | if (!updateResponse || updateResponse.Attributes.itemCount > 
event.requestContext.authorizer.tenantPerMinute) { 31 | console.log("Tenant: " + tenantId + " over limit per minute"); 32 | return { statusCode: 429 }; 33 | } 34 | 35 | // Update and check the total number of connections per minute per tenant/session 36 | updateResponse = await common.incrementLimitTablePerMinute(dynamo, tenantId,tenantId + ":" + sessionId, "minute"); 37 | if (!updateResponse || updateResponse.Attributes.itemCount > event.requestContext.authorizer.sessionPerMinute) { 38 | console.log(tenantId + "-" + sessionId + " over session per minute limit"); 39 | return { statusCode: 429 }; 40 | } 41 | 42 | // Update the session and limit table with the current connection Ids and counts now that we have passed all other checks 43 | var updateConnectIdParams = { 44 | "TableName": process.env.SessionTableName, 45 | "Key": { tenantId: tenantId, sessionId: sessionId }, 46 | "UpdateExpression": "set sessionTTL = :ttl ADD connectionIds :c", 47 | "ExpressionAttributeValues": { 48 | ":ttl": (Math.floor(+new Date() / 1000) + parseInt(event.requestContext.authorizer.sessionTTL)), 49 | ":c": dynamo.createSet([event.requestContext.connectionId]) 50 | }, 51 | "ReturnValues": "NONE" 52 | }; 53 | var updateConnectCountParams = { 54 | "TableName": process.env.LimitTableName, 55 | "Key": { tenantId: tenantId, key: tenantId }, 56 | "UpdateExpression": "set itemCount = if_not_exists(itemCount, :zero) + :inc", 57 | "ExpressionAttributeValues": { ":inc": 1, ":zero": 0 }, 58 | "ReturnValues": "NONE" 59 | }; 60 | await dynamo.transactWrite({ TransactItems: [ { Update: updateConnectIdParams }, { Update: updateConnectCountParams } ] }).promise(); 61 | } 62 | catch (err) { 63 | console.error(err); 64 | return { statusCode: 1011 }; // return server error code 65 | } 66 | return { statusCode: 200 }; 67 | } 68 | 69 | return { statusCode: 200 }; 70 | }; -------------------------------------------------------------------------------- /lambda/SQSEcho.js: 
-------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const AWS = require("aws-sdk"); 5 | const common = require("./Common.js"); 6 | const apig = new AWS.ApiGatewayManagementApi({ endpoint: process.env.ApiGatewayEndpoint }); 7 | 8 | // This handler will iterate the event records from the SQS queue and send a response message back to each 9 | // connection associated with the session. This keeps all session connections in sync across multiple web browser windows/tabs. 10 | exports.handler = async (event, context) => { 11 | //console.log("Event: ", JSON.stringify(event, null, 2)); 12 | if (event.Records) { 13 | for (let r = 0; r < event.Records.length; r++) { 14 | let recordEvent = { 15 | requestContext: { 16 | authorizer: { 17 | tenantId: event.Records[r].messageAttributes.tenantId.stringValue, 18 | sessionId: event.Records[r].messageAttributes.sessionId.stringValue, 19 | messagesPerMinute: event.Records[r].messageAttributes.messagesPerMinute.stringValue, 20 | sessionTTL: event.Records[r].messageAttributes.sessionTTL.stringValue 21 | } 22 | }, 23 | connectionId: event.Records[r].messageAttributes.connectionId.stringValue, 24 | requestId: event.Records[r].messageAttributes.requestId.stringValue, 25 | body: event.Records[r].body 26 | } 27 | let queueName = event.Records[r].eventSourceARN.substring(event.Records[r].eventSourceARN.indexOf("tenant-"), event.Records[r].eventSourceARN.length); 28 | try { 29 | let body = recordEvent.body; 30 | let connectionId = recordEvent.connectionId; 31 | let requestId = recordEvent.requestId; 32 | let tenantId = common.getTenantId(recordEvent); 33 | let sessionId = common.getSessionId(recordEvent); 34 | let dynamo = common.createDynamoDBClient(recordEvent); 35 | // Update and check the total number of messages per minute per tenant 36 | let updateResponse = await 
common.incrementLimitTablePerMinute(dynamo, tenantId, tenantId, "minutemsg"); 37 | if (!updateResponse || updateResponse.Attributes.itemCount > recordEvent.requestContext.authorizer.messagesPerMinute) { 38 | console.log("Tenant: " + tenantId + " message rate limit hit"); 39 | await apig.postToConnection({ ConnectionId: connectionId, Data: common.createMessageThrottleResponse(connectionId, requestId) }).promise(); 40 | continue; 41 | } 42 | 43 | let updateParams = { 44 | "TableName": process.env.SessionTableName, 45 | "Key": { tenantId: tenantId, sessionId: sessionId }, 46 | "UpdateExpression": "set sessionTTL = :ttl", 47 | "ExpressionAttributeValues": { 48 | ":ttl": (Math.floor(+new Date() / 1000) + parseInt(recordEvent.requestContext.authorizer.sessionTTL)) 49 | }, 50 | "ReturnValues": "ALL_OLD" 51 | }; 52 | let results = await dynamo.update(updateParams).promise(); 53 | let connectionIds = results.Attributes.connectionIds.values; 54 | for (let x = 0; x < connectionIds.length; x++) { 55 | if (connectionIds[x] != connectionId) { 56 | await apig.postToConnection({ ConnectionId: connectionIds[x], Data: `${body}` }).promise(); 57 | } 58 | } 59 | for (let x = 0; x < connectionIds.length; x++) { 60 | await apig.postToConnection({ ConnectionId: connectionIds[x], Data: common.createEchoResponse(tenantId, sessionId, connectionIds[x], body, queueName) }).promise(); 61 | } 62 | } 63 | catch (err) { 64 | console.error(err); 65 | return { statusCode: 1011 }; // return server error code 66 | } 67 | } 68 | } 69 | const response = { 70 | statusCode: 200 71 | }; 72 | return response; 73 | }; 74 | -------------------------------------------------------------------------------- /lambda/Common.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const AWS = require("aws-sdk"); 5 | 6 | exports.secondsPerMinute = 60; 7 | 8 | // Encapsulating the two ways in which we can access a tenant Id depending on if the system has already been authorized 9 | exports.getTenantId = function(event) { 10 | if (event.requestContext && event.requestContext.authorizer && event.requestContext.authorizer.tenantId) { 11 | return event.requestContext.authorizer.tenantId; 12 | } else if (event.queryStringParameters && event.queryStringParameters.tenantId) { 13 | return event.queryStringParameters.tenantId; 14 | } 15 | return undefined; 16 | } 17 | 18 | // Encapsulating the two ways in which we can access a session Id depending on if the system has already been authorized 19 | exports.getSessionId = function(event) { 20 | if (event.requestContext && event.requestContext.authorizer && event.requestContext.authorizer.sessionId) { 21 | return event.requestContext.authorizer.sessionId; 22 | } else if (event.queryStringParameters && event.queryStringParameters.sessionId) { 23 | return event.queryStringParameters.sessionId; 24 | } 25 | return undefined; 26 | } 27 | 28 | // During the creation of the DynamoBD connection the tenant Id is added as the transitive tag key 29 | // to make sure we can only access data for this specific tenant 30 | exports.createDynamoDBClient = function(event) { 31 | var credentials = new AWS.ChainableTemporaryCredentials({ 32 | params: { 33 | RoleArn: process.env.RoleArn, 34 | Tags: [{ 35 | Key: "tenantId", 36 | Value: exports.getTenantId(event) 37 | }], 38 | TransitiveTagKeys: [ 39 | "tenantId" 40 | ] 41 | }, 42 | Credentials: { 43 | AccessKeyId: AWS.config.credentials.AccessKeyId, 44 | SecretAccessKey: AWS.config.credentials.SecretAccessKey, 45 | SessionToken: AWS.config.credentials.SessionToken 46 | } 47 | }); 48 | return new AWS.DynamoDB.DocumentClient(new AWS.Config({ 49 | credentials: credentials 50 | })); 51 | } 52 | 53 | // Update the limit table by 
incrementing the itemCount field by 1 for the specified key/current min combo 54 | // this function returns a promise value of the update command 55 | exports.incrementLimitTablePerMinute = function(dynamo, tenantId, keyStart, keyMid) { 56 | var epoch = exports.seconds_since_epoch(); 57 | let currentMin = (Math.trunc(epoch / exports.secondsPerMinute) * exports.secondsPerMinute); 58 | let key = keyStart + ":" + keyMid + ":" + currentMin; 59 | var updateParams = { 60 | "TableName": process.env.LimitTableName, 61 | "Key": { tenantId: tenantId, key: key }, 62 | "UpdateExpression": "set itemCount = if_not_exists(itemCount, :zero) + :inc, itemTTL = :ttl", 63 | "ExpressionAttributeValues": { ":ttl": currentMin + exports.secondsPerMinute + 1, ":inc": 1, ":zero": 0 }, 64 | "ReturnValues": "UPDATED_NEW" 65 | }; 66 | return dynamo.update(updateParams).promise(); 67 | } 68 | 69 | exports.seconds_since_epoch = function() { 70 | return Math.floor(Date.now() / 1000); 71 | } 72 | 73 | exports.createMessageThrottleResponse = function(connectionId, requestId) { 74 | return JSON.stringify({ message: "Too Many Requests", connectionId: connectionId, requestId: requestId }); 75 | } 76 | 77 | exports.createEchoResponse = function(tenantId, sessionId, connectionId, message, queue) { 78 | let response = { 79 | message: JSON.parse(message), 80 | tenantId: tenantId, 81 | sessionId: sessionId, 82 | connectionId: connectionId, 83 | queue: queue 84 | }; 85 | if (queue) { 86 | response.queue = queue; 87 | } 88 | return JSON.stringify(response); 89 | } 90 | 91 | // A policy is generated with an effect (Allow/Deny) and the context is filled 92 | // with the tenant and session information 93 | let generatePolicy = function(effect, resource, event, tenantSettings) { 94 | // Required output: 95 | let authResponse = {}; 96 | authResponse.principalId = 'anonymous'; 97 | if (effect && resource) { 98 | authResponse.policyDocument = { 99 | Version: '2012-10-17', // default version 100 | Statement: [{ 101 
| Action: 'execute-api:Invoke', // default action 102 | Effect: effect, 103 | Resource: resource 104 | }] 105 | }; 106 | } 107 | authResponse.context = { 108 | tenantId: exports.getTenantId(event), 109 | sessionId: exports.getSessionId(event), 110 | sessionPerMinute: tenantSettings ? tenantSettings.sessionPerMinute : -1, 111 | tenantPerMinute: tenantSettings ? tenantSettings.tenantPerMinute : -1, 112 | tenantConnections: tenantSettings ? tenantSettings.tenantConnections : -1, 113 | connectionsPerSession: tenantSettings ? tenantSettings.connectionsPerSession : -1, 114 | sessionTTL: tenantSettings ? tenantSettings.sessionTTL : -1, 115 | messagesPerMinute: tenantSettings ? tenantSettings.messagesPerMinute : -1 116 | }; 117 | return authResponse; 118 | } 119 | exports.generateAllow = function(resource, event, tenantSettings) { return generatePolicy('Allow', resource, event, tenantSettings); } 120 | exports.generateDeny = function(resource, event) { return generatePolicy('Deny', resource, event, null); } -------------------------------------------------------------------------------- /lambda/SampleClient.js: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | // SPDX-License-Identifier: MIT-0 3 | 4 | const connectButton = document.querySelector("#connect"); 5 | const createSessionButton = document.querySelector("#createSession"); 6 | const deleteSessionButton = document.querySelector("#deleteSession"); 7 | const clearLogButton = document.querySelector("#clearLog"); 8 | const sendPooledButton = document.querySelector("#sendPooled"); 9 | const sendSiloButton = document.querySelector("#sendSilo"); 10 | const tenantId = document.querySelector("#tenantId"); 11 | const sessionId = document.querySelector("#sessionId"); 12 | const message = document.querySelector("#message"); 13 | const chatLog = document.querySelector("#chatLog"); 14 | const sessionUrl = "{{sessionUrl}}"; 15 | const tenantUrl = "{{tenantUrl}}"; 16 | const websocketUrl = "{{WssUrl}}"; 17 | let connection; 18 | let msgNumber = 1; 19 | const messageNumberTemplate = "{{msgNumber}}"; 20 | 21 | function addCommunication(input, event) { 22 | console.log(input, event); 23 | let li = document.createElement('li'); 24 | li.innerText = new Date().toLocaleTimeString() + ": " + input; 25 | chatLog.prepend(li); 26 | } 27 | 28 | function createUUID() { 29 | return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function(c) { 30 | var r = Math.random() * 16 | 0, 31 | v = c == 'x' ? 
r : (r & 0x3 | 0x8); 32 | return v.toString(16); 33 | }); 34 | } 35 | clearLogButton.addEventListener("click", () => { 36 | chatLog.innerHTML = ""; 37 | }); 38 | 39 | window.onload = function() { 40 | sessionId.value = getCookie("sessionuuid"); 41 | loadTenants(); 42 | }; 43 | 44 | function getCookie(cname) { 45 | let name = cname + "="; 46 | let decodedCookie = decodeURIComponent(document.cookie); 47 | let ca = decodedCookie.split(';'); 48 | for (let i = 0; i < ca.length; i++) { 49 | let c = ca[i]; 50 | while (c.charAt(0) == ' ') { 51 | c = c.substring(1); 52 | } 53 | if (c.indexOf(name) == 0) { 54 | return c.substring(name.length, c.length); 55 | } 56 | } 57 | return ""; 58 | } 59 | 60 | function setCookie(cname, cvalue, exdays) { 61 | const d = new Date(); 62 | d.setTime(d.getTime() + (exdays * 24 * 60 * 60 * 1000)); 63 | let expires = "expires=" + d.toUTCString(); 64 | document.cookie = cname + "=" + cvalue + ";" + expires + ";path=/"; 65 | } 66 | 67 | function loadTenants() { 68 | tenantId.innerHTML = ""; 69 | var xhr = new XMLHttpRequest(); 70 | addCommunication("Loading tenants", undefined); 71 | xhr.onerror = function() { 72 | addCommunication("Loading tenants error " + xhr.status + ": " + xhr.statusText, xhr); 73 | }; 74 | xhr.onabort = function() { 75 | addCommunication("Load tenants aborted", xhr); 76 | }; 77 | xhr.onload = function(e) { 78 | if (xhr.status != 200) { // analyze HTTP status of the response 79 | addCommunication("Loading tenants error " + xhr.status + ": " + xhr.statusText, xhr); 80 | } 81 | else { // show the result 82 | addCommunication("Tenants loaded", xhr); 83 | let items = JSON.parse(xhr.response); 84 | for (let x = 0; x < items.Items.length; x++) { 85 | console.log("Item: " + JSON.stringify(items.Items[x])); 86 | var opt = document.createElement('option'); 87 | opt.value = items.Items[x].tenantId; 88 | opt.innerHTML = escape(items.Items[x].tenantId); 89 | tenantId.appendChild(opt); 90 | } 91 | setupSiloPoolButtons(); 92 | } 93 | }; 
94 | 95 | xhr.open("GET", tenantUrl); 96 | xhr.send(); 97 | } 98 | 99 | createSessionButton.addEventListener("click", () => { 100 | var xhr = new XMLHttpRequest(); 101 | let uuid = createUUID(); 102 | addCommunication("Creating session " + uuid, undefined); 103 | xhr.onerror = function() { 104 | addCommunication("Session create error " + xhr.status + ": " + xhr.statusText, xhr); 105 | }; 106 | xhr.onabort = function() { 107 | addCommunication("Session create aborted " + uuid, xhr); 108 | }; 109 | xhr.onload = function(e) { 110 | if (xhr.status != 200) { // analyze HTTP status of the response 111 | addCommunication("Session create error " + xhr.status + ": " + xhr.statusText, xhr); 112 | } 113 | else { // show the result 114 | addCommunication("Created session " + uuid, xhr); 115 | setCookie("sessionuuid", uuid, 1); 116 | } 117 | }; 118 | 119 | xhr.open("PUT", sessionUrl + "?tenantId=" + tenantId.value + "&sessionId=" + uuid); 120 | sessionId.value = uuid; 121 | xhr.send(); 122 | }); 123 | 124 | deleteSessionButton.addEventListener("click", () => { 125 | var xhr = new XMLHttpRequest(); 126 | let uuid = sessionId.value; 127 | addCommunication("Deleting session " + uuid, undefined); 128 | xhr.onerror = function() { 129 | addCommunication("Session delete error " + xhr.status + ": " + xhr.statusText, xhr); 130 | }; 131 | xhr.onabort = function() { 132 | addCommunication("Session delete aborted " + uuid, xhr); 133 | }; 134 | xhr.onload = function(e) { 135 | if (xhr.status != 200) { // analyze HTTP status of the response 136 | addCommunication("Session delete error " + xhr.status + ": " + xhr.statusText, xhr); 137 | } 138 | else { // show the result 139 | addCommunication("Deleted session " + uuid, xhr); 140 | setCookie("sessionuuid", "", 1); 141 | sessionId.value = ""; 142 | } 143 | }; 144 | 145 | xhr.open("DELETE", sessionUrl + "?tenantId=" + tenantId.value + "&sessionId=" + uuid); 146 | xhr.send(); 147 | }); 148 | 149 | connectButton.addEventListener("click", () => { 
150 | if (connection) { 151 | addCommunication("Disconnecting", undefined); 152 | connection.close(); 153 | } 154 | else { 155 | addCommunication("Connecting", undefined); 156 | connection = new WebSocket(websocketUrl + "?tenantId=" + tenantId.value + "&sessionId=" + sessionId.value); 157 | connectButton.innerHTML = "Disconnect"; 158 | connection.onopen = (event) => { 159 | addCommunication("Connected", event); 160 | }; 161 | 162 | connection.onclose = (event) => { 163 | addCommunication("Disconnected", event); 164 | connectButton.innerHTML = "Connect"; 165 | connection = undefined; 166 | }; 167 | 168 | connection.onerror = (event) => { 169 | addCommunication("Connection error. See console for details.", event); 170 | }; 171 | 172 | connection.onmessage = (event) => { 173 | let msg = JSON.parse(event.data); 174 | if (msg.message == "Too Many Requests") { 175 | addCommunication("THROTTLED!: " + event.data, event); 176 | } else { 177 | let starter = msg.tenantId ? "Recv: " : "Sent: "; 178 | addCommunication(starter + event.data, event); 179 | } 180 | }; 181 | } 182 | }); 183 | 184 | function sendMesssage(silo) { 185 | let data = {}; 186 | data.message = message.value; 187 | if (data.message.includes(messageNumberTemplate)) { 188 | msgNumber = msgNumber + 1; 189 | data.message = data.message.replace(messageNumberTemplate, "" + msgNumber); 190 | } 191 | data.action = silo ? 
"SiloSQS" : "PooledSQS"; 192 | let sendData = JSON.stringify(data); 193 | connection.send(sendData); 194 | addCommunication("Sent: " + sendData); 195 | } 196 | 197 | function setupSiloPoolButtons() { 198 | sendPooledButton.disabled = false; 199 | sendSiloButton.disabled = false; 200 | if (tenantId.value == "31a2e8c6-1826-11ec-9621-0242ac130002") { 201 | sendPooledButton.disabled = true; 202 | } else { 203 | sendSiloButton.disabled = true; 204 | } 205 | } 206 | 207 | sendPooledButton.addEventListener("click", () => { 208 | sendMesssage(false); 209 | }); 210 | 211 | sendSiloButton.addEventListener("click", () => { 212 | sendMesssage(true); 213 | }); 214 | 215 | tenantId.addEventListener("change", () => { 216 | setupSiloPoolButtons(); 217 | }); 218 | 219 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # API Gateway WebSocket SaaS Multi-Tenant Rate Limiting 2 | 3 | When creating a SaaS multi-tenant systems which require websocket connections we need a way to rate limit those connections on a per tenant basis. 4 | With Amazon API Gateway you have the option to use usage plans with HTTP connections however they are not available for websockets. 5 | To enable rate limiting we can use a API Gateway Lambda Authorizer to validate a connection and control access. 6 | Using a Lambda Authorizer we can implement code to allow the system to valid connection rates and throttle inbound connections on a per tenant basis. 7 | This sample also demonstrates pool and silo modes for handling the message traffic per tenant. 8 | 9 | ## Architecture 10 | Architecture 11 | 12 | 1. The client sends an HTTP PUT request to the Amazon API Gateway HTTP endpoint to create a session for a tenant. If required, this call can be authenticated, however, that is outside the scope of this sample. 13 | 2. 
A Lambda function will create a session and store it in a DynamoDB table with a TTL (Time To Live) value specified. Amazon DynamoDB Streams are used to remove all session connections if no communication is sent or received over a specific period of time. Each call to the database layer is restricted by IAM condition keys, using sts:TransitiveTagKeys, so that each tenant can only access its own rows based on the tenant ID. 14 | 3. Once a session is created, the client initiates a WebSocket connection to the Amazon API Gateway WebSocket endpoint. A session can be used multiple times to create connections from multiple web browser windows. The session is used to keep these different connections in sync. 15 | 4. A Lambda function is used as the Authorizer for the WebSocket connection. The authorizer will do the following: 16 |
    17 |
  1. Validate the tenant exists.
  2. 18 |
  3. Validate the session exists.
  4. 19 |
  5. Add the tenant ID, session ID, connection ID, and tenant settings to the authorization context.
  6. 20 |
21 | 5. A Lambda function is used for the connect route which throttles inbound connections and returns a 429 response code if over a limit. The following checks and processing are done: 22 |
    23 |
  1. Over the total number of connections allowed for this tenant.
  2. 24 |
  3. Over the total number of connections allowed for this session.
  4. 25 |
  5. Over the total number of connections per minute allowed for the tenant.
  6. 26 |
  7. Over the total number of connections per minute allowed for the session.
  8. 27 |
  9. Add the connection ID to the session's connection ID set and update the session Time to Live (TTL).
  10. 28 |
  11. Increment the total number of connections for the tenant.
  12. 29 |
30 | 6. Messages are processed via a siloed or pooled SQS FIFO queue, depending on the API Gateway route. SQS FIFO queues are used to keep messages in order: if messages were sent directly to the Lambda function, a cold start could delay processing of the first message while a following message hits a warm Lambda function, is processed faster, and returns an out-of-order reply. The tenant ID, session ID, connection ID, and tenant settings are added to each message as SQS message attributes. The SQS message group ID is the combination of tenant ID and session ID, which keeps each session's messages in order. Each inbound message updates the DynamoDB session TTL to reset the session timeout. 31 |
    32 |
  1. Silo based messages are processed by the tenant’s corresponding SQS FIFO queue, which is named using the tenant ID. A Lambda function per tenant is used to read messages from the tenant’s SQS FIFO queue.
  2. 33 |
  3. Pool based messages are processed by a single pooled SQS FIFO queue. A Lambda function is used by all tenants to read messages from the pooled SQS FIFO queue.
  4. 34 |
35 | 7. A Lambda function is used during disconnect to do the following: 36 |
    37 |
  1. Remove the connection ID from the session connection ID set.
  2. 38 |
  3. Decrement the total number of connections for the tenant.
  4. 39 |
40 | 8. Once all connections are closed, the client will send an HTTP DELETE request to the Amazon API Gateway HTTP endpoint to remove the session. 41 | 42 | ## Requirements 43 | 1. Apache Maven 3.8.1 44 | 2. AWS CDK 1.130.0 or later installed 45 | 46 | ## Setup 47 | 1. git clone this repository 48 | 2. In the root directory of the repository execute the command ```cdk deploy``` 49 | 3. Review the permissions and follow prompts 50 | 4. After deployment the CDK will list the outputs as follows: 51 | 1. APIGatewayWebSocketRateLimitStack.SampleClient 52 | 1. The URI points to the sample web page described below 53 | 2. APIGatewayWebSocketRateLimitStack.SessionURL 54 | 1. This URI points to the endpoint which is able to create sessions 55 | 3. APIGatewayWebSocketRateLimitStack.TenantURL 56 | 1. This URI is only exposed for demo purposes and is used to get a list of the current tenant Ids 57 | 4. APIGatewayWebSocketRateLimitStack.WebSocketURL 58 | 1. This URI is the websocket connection endpoint 59 | 60 | ## Sample Web Page 61 | Sample Web Page 62 | 63 | The sample can be used to test the various aspects of the system. The following steps are the happy path: 64 | 1. Open the web page given as the output **APIGatewayWebSocketRateLimitStack.SampleClient** from the CDK deployment 65 | 2. Wait for the tenant Ids to load 66 | 3. Click the **Create Session** button to create a new session 67 | 4. Click the **Connect** button 68 | 5. Once connected try both the **Send Silo** and **Send Pooled** buttons 69 | 1. The **Send Silo** button will send via the **SiloSQS** route which uses the siloed SQS FIFO queue execution model 70 | 2. The **Send Pooled** button will send via the **PooledSQS** route which uses the pooled SQS FIFO queue execution model 71 | 6. Click the **Disconnect** button to close the connection 72 | 7. Click the **Delete Session** button to remove the current session 73 | 74 | ## Cleanup 75 | 1. 
In the root directory of the repository, execute the command ```cdk destroy``` 76 | 77 | ## Silo vs Pooled Message processing 78 | ### Silo 79 | An SQS FIFO queue and a Lambda function per tenant are used in silo mode. The API Gateway uses the authorization context's tenantId to determine each tenant's queue name (tenant-{tenantId}.fifo). Each SQS FIFO queue has a linked Lambda function that processes its messages and sends an echo reply. 80 | ### Pooled 81 | A single pooled SQS FIFO queue (tenant-Pooled.fifo) and a single Lambda function shared by all tenants are used in pooled mode. The API Gateway sends every tenant's messages to this queue, and its linked Lambda function processes them and sends an echo reply. 82 | 83 | ## DynamoDB Table Structures 84 | Access to all tables is restricted by a partition key condition that only allows access to rows whose partition key matches the current tenantId. 85 | 86 | #### Tenant Table 87 | The tenant table is used to store the tenantIds and per-tenant settings, allowing each tenant to have different rate limits. 88 | 89 | Fields 90 | 1. tenantId (String) (Partition Key) - The tenantId 91 | 2. connectionsPerSession (Number) - The max number of connections each session is allowed 92 | 3. tenantConnections (Number) - The max number of connections this tenant is allowed 93 | 4. sessionPerMinute (Number) - The max number of connections per minute for a session 94 | 5. tenantPerMinute (Number) - The max number of connections per minute for this tenant 95 | 6. sessionTTL (Number) - The session time to live value in seconds. This is used each time activity happens for a session to increase the time period before a session times out and connections are dropped. The TTL value is set as current time plus this value. 96 | 7. messagesPerMinute (Number) - The total number of messages per minute this tenant is allowed to process before the tenant is throttled.
97 | 98 | #### Limit Table 99 | The limit table is used to store the current limit counts for each tenant and also the per-minute counts. 100 | 101 | Fields 102 | 1. key (String) (Sort Key; the partition key is tenantId) - This key field can be one of the following formats. The message rate limit additionally uses tenantId:minutemsg:{epoch} to count the tenant's messages within the {epoch} value start time + 60 seconds. 103 | 1. tenantId - If the key is a single tenantId then it is tracking the total number of connections for this tenant 104 | 2. tenantId:minute:{epoch} - If the key is the tenantId:minute:{epoch} then it is tracking the current number of connections per minute for the tenant within the {epoch} value start time + 60 seconds. 105 | 3. tenantId:sessionId:minute:{epoch} - If the key is the tenantId:sessionId:minute:{epoch} then it is tracking the current number of connections per minute for the session within the {epoch} value start time + 60 seconds. 106 | 2. itemCount (Number) - The current value for the limit 107 | 3. itemTTL (Number) (TTL) - The time to live value for DynamoDB to remove this item. This is used for the per-minute counts to remove expired rows. 108 | 109 | #### Session Table 110 | The session table keeps track of sessions per tenant and will expire sessions after a set amount of time. 111 | 112 | Fields 113 | 1. tenantId (String) (Partition Key) - The tenantId 114 | 2. sessionId (String) (Sort Key) - The sessionId 115 | 3. connectionIds (Set [String]) - The current connectionIds for this session. This is used to keep track of the number of connections per session. It is also used to send reply messages to all connections on a specific session. 116 | 4. sessionTTL (Number) (TTL) - The time to live value for DynamoDB to remove this item. This value is used to remove expired sessions and disconnect any lingering associated connections. 117 | 118 | ## Security 119 | 120 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 121 | 122 | ## License 123 | 124 | This library is licensed under the MIT-0 License. See the LICENSE file.
-------------------------------------------------------------------------------- /src/main/java/com/amazonaws/services/sample/apigateway/websocketratelimit/RateLimitStack.java: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | package com.amazonaws.services.sample.apigateway.websocketratelimit; 5 | 6 | import software.amazon.awscdk.*; 7 | import software.amazon.awscdk.customresources.AwsCustomResource; 8 | import software.amazon.awscdk.customresources.AwsCustomResourcePolicy; 9 | import software.amazon.awscdk.customresources.AwsSdkCall; 10 | import software.amazon.awscdk.customresources.PhysicalResourceId; 11 | import software.amazon.awscdk.services.apigatewayv2.*; 12 | import software.amazon.awscdk.services.apigatewayv2.alpha.*; 13 | import software.amazon.awscdk.services.apigatewayv2.alpha.HttpMethod; 14 | import software.amazon.awscdk.services.apigatewayv2.integrations.alpha.HttpLambdaIntegration; 15 | import software.amazon.awscdk.services.apigatewayv2.integrations.alpha.WebSocketLambdaIntegration; 16 | import software.amazon.awscdk.services.dynamodb.*; 17 | import software.amazon.awscdk.services.iam.*; 18 | import software.amazon.awscdk.services.lambda.Runtime; 19 | import software.amazon.awscdk.services.lambda.*; 20 | import software.amazon.awscdk.services.lambda.eventsources.DynamoEventSource; 21 | import software.amazon.awscdk.services.lambda.eventsources.SqsEventSource; 22 | import software.amazon.awscdk.services.sqs.DeduplicationScope; 23 | import software.amazon.awscdk.services.sqs.FifoThroughputLimit; 24 | import software.amazon.awscdk.services.sqs.Queue; 25 | import software.constructs.Construct; 26 | 27 | import java.util.List; 28 | import java.util.Map; 29 | 30 | public class RateLimitStack extends Stack { 31 | private Table tenantTable; 32 | private Table sessionTable; 33 | private Table limitTable; 
34 | private Function sessionTTLLambda; 35 | private Function sampleClientFunction; 36 | private Function sessionFunction; 37 | private Function tenantFunction; 38 | private Function webSocketConnectFunction; 39 | private Function webSocketDisconnectFunction; 40 | private Function authorizerFunction; 41 | private WebSocketApi api; 42 | private HttpApi sessionApi; 43 | private CfnAuthorizer authorizer; 44 | private WebSocketStage stage; 45 | 46 | public RateLimitStack(final Construct scope, final String id) { 47 | this(scope, id, null); 48 | } 49 | 50 | public RateLimitStack(final Construct scope, final String id, final StackProps props) { 51 | super(scope, id, props); 52 | 53 | createTenantDynamoDBTable(); 54 | createSessionDynamoDBTable(); 55 | createLimitDynamoDBTable(); 56 | createSessionTTLLambda(); 57 | createSampleClientLambda(); 58 | createSessionLambda(); 59 | createTenantLambda(); 60 | createWebSocketConnectLambda(); 61 | createWebSocketDisconnectLambda(); 62 | createAuthorizerLambda(); 63 | createAPIGatewayWebSocket(); 64 | createQueueWebSocketAPIRoute(true); // silo 65 | createQueueWebSocketAPIRoute(false); // pooled 66 | createAuthorizer(); 67 | createConnectIntegration(); 68 | createStage(); 69 | createAPIGatewaySessionAndSample(); 70 | setupAPIGatewayLambdaFunctions(); 71 | addSampleTenantIds(); 72 | createOutputs(); 73 | } 74 | 75 | private void createTenantDynamoDBTable() { 76 | tenantTable = Table.Builder.create(this, "TenantTable") 77 | .removalPolicy(RemovalPolicy.DESTROY) 78 | .partitionKey(Attribute.builder() 79 | .name("tenantId") 80 | .type(AttributeType.STRING) 81 | .build()) 82 | .build(); 83 | 84 | EnableScalingProps esp = EnableScalingProps.builder().maxCapacity(10).minCapacity(1).build(); 85 | tenantTable.autoScaleReadCapacity(esp).scaleOnUtilization(UtilizationScalingProps.builder().targetUtilizationPercent(70).build()); 86 | 
tenantTable.autoScaleWriteCapacity(esp).scaleOnUtilization(UtilizationScalingProps.builder().targetUtilizationPercent(70).build()); 87 | } 88 | 89 | private void createSessionDynamoDBTable() { 90 | sessionTable = Table.Builder.create(this, "SessionTable") 91 | .removalPolicy(RemovalPolicy.DESTROY) 92 | .partitionKey(Attribute.builder() 93 | .name("tenantId") 94 | .type(AttributeType.STRING) 95 | .build()) 96 | .sortKey(Attribute.builder() 97 | .name("sessionId") 98 | .type(AttributeType.STRING) 99 | .build()) 100 | .stream(StreamViewType.NEW_AND_OLD_IMAGES) 101 | .timeToLiveAttribute("sessionTTL") 102 | .build(); 103 | 104 | EnableScalingProps esp = EnableScalingProps.builder().maxCapacity(10).minCapacity(1).build(); 105 | sessionTable.autoScaleReadCapacity(esp).scaleOnUtilization(UtilizationScalingProps.builder().targetUtilizationPercent(70).build()); 106 | sessionTable.autoScaleWriteCapacity(esp).scaleOnUtilization(UtilizationScalingProps.builder().targetUtilizationPercent(70).build()); 107 | } 108 | 109 | private void createLimitDynamoDBTable() { 110 | limitTable = Table.Builder.create(this, "LimitTable") 111 | .removalPolicy(RemovalPolicy.DESTROY) 112 | .partitionKey(Attribute.builder() 113 | .name("tenantId") 114 | .type(AttributeType.STRING) 115 | .build()) 116 | .sortKey(Attribute.builder() 117 | .name("key") 118 | .type(AttributeType.STRING) 119 | .build()) 120 | .timeToLiveAttribute("itemTTL") 121 | .build(); 122 | EnableScalingProps esp = EnableScalingProps.builder().maxCapacity(10).minCapacity(1).build(); 123 | limitTable.autoScaleReadCapacity(esp).scaleOnUtilization(UtilizationScalingProps.builder().targetUtilizationPercent(70).build()); 124 | limitTable.autoScaleWriteCapacity(esp).scaleOnUtilization(UtilizationScalingProps.builder().targetUtilizationPercent(70).build()); 125 | } 126 | 127 | private void createSessionTTLLambda() { 128 | sessionTTLLambda = Function.Builder.create(this, "SessionTTLLambda") 129 | .runtime(Runtime.NODEJS_14_X) 130 | 
.code(Code.fromAsset("lambda")) 131 | .handler("SessionTTL.handler") 132 | .events(List.of(DynamoEventSource.Builder.create(sessionTable).startingPosition(StartingPosition.LATEST).build())) 133 | .build(); 134 | } 135 | 136 | private void createSampleClientLambda() { 137 | sampleClientFunction = Function.Builder.create(this, "SampleClientHandler") 138 | .runtime(Runtime.NODEJS_14_X) 139 | .code(Code.fromAsset("lambda")) 140 | .handler("SampleClientGet.handler") 141 | .build(); 142 | } 143 | 144 | private void createSessionLambda() { 145 | sessionFunction = Function.Builder.create(this, "Session") 146 | .runtime(Runtime.NODEJS_14_X) 147 | .code(Code.fromAsset("lambda")) 148 | .handler("Session.handler") 149 | .build(); 150 | } 151 | 152 | private void createTenantLambda() { 153 | tenantFunction = Function.Builder.create(this, "Tenant") 154 | .runtime(Runtime.NODEJS_14_X) 155 | .code(Code.fromAsset("lambda")) 156 | .handler("Tenant.handler") 157 | .build(); 158 | } 159 | 160 | 161 | private void createWebSocketConnectLambda() { 162 | webSocketConnectFunction = Function.Builder.create(this, "WebSocketConnect") 163 | .runtime(Runtime.NODEJS_14_X) 164 | .code(Code.fromAsset("lambda")) 165 | .handler("WebSocketConnect.handler") 166 | .build(); 167 | } 168 | 169 | private void createWebSocketDisconnectLambda() { 170 | webSocketDisconnectFunction = Function.Builder.create(this, "WebSocketDisconnect") 171 | .runtime(Runtime.NODEJS_14_X) 172 | .code(Code.fromAsset("lambda")) 173 | .handler("WebSocketDisconnect.handler") 174 | .build(); 175 | sessionTable.grantReadWriteData(webSocketDisconnectFunction); 176 | } 177 | 178 | private void createAuthorizerLambda() { 179 | authorizerFunction = Function.Builder.create(this, "Authorizer") 180 | .runtime(Runtime.NODEJS_14_X) 181 | .code(Code.fromAsset("lambda")) 182 | .handler("Authorizer.handler") 183 | .build(); 184 | } 185 | 186 | private void createAPIGatewayWebSocket() { 187 | // Create a websocket API endpoint with routing to 
our echo lambda 188 | // We do not create the connect route at this point due to the authorizer not being enabled for the WebSocketRouteOptions 189 | // yet, we will instead use the low level Cfn style functions later. 190 | api = WebSocketApi.Builder.create(this, "WebSocketAPIGateway") 191 | .apiName("WebSocketRateLimitSample") 192 | .description("Rate limit websocket connections using a Lambda Authorizer.") 193 | .disconnectRouteOptions(WebSocketRouteOptions.builder().integration(new WebSocketLambdaIntegration("WebSocketAPIGatewayDisconnectRoute",webSocketDisconnectFunction)).build()) 194 | .build(); 195 | } 196 | 197 | 198 | private void createQueueWebSocketAPIRoute(boolean silo) { 199 | String nameExt = silo ? "Silo" : "Pooled"; 200 | Role apiGatewayWebSocketSQSRole = Role.Builder.create(this, "ApiGatewayWebSocket" + nameExt + "SQSRole") 201 | .assumedBy(ServicePrincipal.Builder.create("apigateway.amazonaws.com").build()) 202 | .inlinePolicies(Map.of("APIGateway" + nameExt + "SQSSendMessagePolicy", PolicyDocument.Builder.create() 203 | .statements(List.of(PolicyStatement.Builder.create() 204 | .effect(Effect.ALLOW) 205 | .actions(List.of("sqs:SendMessage")) 206 | .resources(List.of("arn:aws:sqs:" + getRegion() +":" + getAccount() + ":tenant-" + (silo ? 
"*" : nameExt) + ".fifo")) 207 | .build())) 208 | .build())) 209 | .build(); 210 | String requestTemplateItem = ""; 211 | requestTemplateItem += "Action=SendMessage"; 212 | requestTemplateItem += "&MessageGroupId=$context.authorizer.tenantId:$context.authorizer.sessionId"; 213 | requestTemplateItem += "&MessageDeduplicationId=$context.requestId"; 214 | requestTemplateItem += "&MessageAttribute.1.Name=tenantId&MessageAttribute.1.Value.StringValue=$context.authorizer.tenantId&MessageAttribute.1.Value.DataType=String"; 215 | requestTemplateItem += "&MessageAttribute.2.Name=sessionId&MessageAttribute.2.Value.StringValue=$context.authorizer.sessionId&MessageAttribute.2.Value.DataType=String"; 216 | requestTemplateItem += "&MessageAttribute.3.Name=connectionId&MessageAttribute.3.Value.StringValue=$context.connectionId&MessageAttribute.3.Value.DataType=String"; 217 | requestTemplateItem += "&MessageAttribute.4.Name=requestId&MessageAttribute.4.Value.StringValue=$context.requestId&MessageAttribute.4.Value.DataType=String"; 218 | requestTemplateItem += "&MessageAttribute.5.Name=sessionPerMinute&MessageAttribute.5.Value.StringValue=$context.authorizer.sessionPerMinute&MessageAttribute.5.Value.DataType=String"; 219 | requestTemplateItem += "&MessageAttribute.6.Name=tenantPerMinute&MessageAttribute.6.Value.StringValue=$context.authorizer.tenantPerMinute&MessageAttribute.6.Value.DataType=String"; 220 | requestTemplateItem += "&MessageAttribute.7.Name=connectionsPerSession&MessageAttribute.7.Value.StringValue=$context.authorizer.connectionsPerSession&MessageAttribute.7.Value.DataType=String"; 221 | requestTemplateItem += "&MessageAttribute.8.Name=sessionTTL&MessageAttribute.8.Value.StringValue=$context.authorizer.sessionTTL&MessageAttribute.8.Value.DataType=String"; 222 | requestTemplateItem += "&MessageAttribute.9.Name=tenantConnections&MessageAttribute.9.Value.StringValue=$context.authorizer.tenantConnections&MessageAttribute.9.Value.DataType=String"; 223 | requestTemplateItem 
+= "&MessageAttribute.10.Name=messagesPerMinute&MessageAttribute.10.Value.StringValue=$context.authorizer.messagesPerMinute&MessageAttribute.10.Value.DataType=String"; 224 | /* NOTE(review): the raw $input.json('$') payload is appended to a form-urlencoded body without URL-encoding; confirm that messages containing '&', '=', '+' or '%' reach SQS intact ($util.urlEncode may be needed). */ requestTemplateItem += "&MessageBody=$input.json('$')"; 225 | CfnIntegration integration = CfnIntegration.Builder.create(this, nameExt + "Integration") 226 | .apiId(api.getApiId()) 227 | .connectionType("INTERNET") 228 | .integrationType("AWS") 229 | .credentialsArn(apiGatewayWebSocketSQSRole.getRoleArn()) 230 | .templateSelectionExpression("\\$default") 231 | .integrationMethod("POST") 232 | .integrationUri("arn:aws:apigateway:" + getRegion() + ":sqs:path/" + getAccount() + "/tenant-{queue}.fifo") 233 | .passthroughBehavior("NEVER") 234 | .requestParameters(Map.of( 235 | "integration.request.header.Content-Type", 236 | "'application/x-www-form-urlencoded'", 237 | "integration.request.path.queue", 238 | silo ? "context.authorizer.tenantId" : "'" + nameExt + "'")) 239 | .requestTemplates(Map.of( 240 | "$default", 241 | requestTemplateItem 242 | )) 243 | .build(); 244 | CfnRoute.Builder.create(this, nameExt + "Route") 245 | .apiId(api.getApiId()) 246 | .routeKey(nameExt + "SQS") 247 | .target("integrations/" + integration.getRef()) 248 | .build(); 249 | } 250 | 251 | /** Creates the REQUEST custom authorizer "RateLimitAuthorizer" that invokes authorizerFunction, using the tenantId and sessionId query-string parameters as identity sources. */ private void createAuthorizer() { 252 | authorizer = CfnAuthorizer.Builder.create(this, "RateLimitAuthorizer") 253 | .identitySource(List.of("route.request.querystring.tenantId", "route.request.querystring.sessionId")) 254 | .authorizerType("REQUEST") 255 | .authorizerUri("arn:aws:apigateway:" + getRegion() + ":lambda:path/2015-03-31/functions/" + authorizerFunction.getFunctionArn() + "/invocations") 256 | .apiId(api.getApiId()) 257 | .name("RateLimitAuthorizer") 258 | .build(); 259 | } 260 | 261 | /** Routes $connect to webSocketConnectFunction through an AWS_PROXY Lambda integration and protects the route with the custom authorizer. */ private void createConnectIntegration() { 262 | CfnIntegration integration = CfnIntegration.Builder.create(this, "ConnectLambdaIntegration") 263 | .integrationType("AWS_PROXY") 264 | .integrationMethod("POST") 265 |
.integrationUri("arn:aws:apigateway:" + getRegion() + ":lambda:path/2015-03-31/functions/" + webSocketConnectFunction.getFunctionArn() + "/invocations") 266 | .apiId(api.getApiId()) 267 | .build(); 268 | 269 | CfnRoute.Builder.create(this, "ConnectRoute") 270 | .apiId(api.getApiId()) 271 | .routeKey("$connect") 272 | .authorizationType("CUSTOM") 273 | .target("integrations/" + integration.getRef()) 274 | .authorizerId(authorizer.getRef()) 275 | .build(); 276 | } 277 | 278 | /** Creates the auto-deployed "production" stage of the WebSocket API and stores it in the stage field. */ private void createStage() { 279 | // Set up a production stage with auto-deploy so the API is ready to use as soon as the CloudFormation stack completes 280 | stage = WebSocketStage.Builder.create(this, "EchoWebSocketAPIGatewayProd") 281 | .stageName("production") 282 | .webSocketApi(api) 283 | .autoDeploy(true) 284 | .build(); 285 | } 286 | 287 | /** Creates the HTTP API (without a default stage) with PUT and DELETE /session, GET /tenant and GET /SampleClient routes backed by the session, tenant and sample-client Lambdas, and adds an auto-deployed "production" stage. */ private void createAPIGatewaySessionAndSample() { 288 | sessionApi = HttpApi.Builder.create(this, "SessionAndSampleAPIGateway") 289 | .apiName("WebSocketRateLimitSessionSample") 290 | .description("Creates and removes sessions and loads sample client") 291 | .createDefaultStage(false) 292 | .build(); 293 | HttpLambdaIntegration sessionLambdaIntegration = new HttpLambdaIntegration("SessionLambdaIntegration", sessionFunction); 294 | HttpLambdaIntegration tenantLambdaIntegration = new HttpLambdaIntegration("TenantLambdaIntegration", tenantFunction); 295 | HttpLambdaIntegration sampleClientLambdaIntegration = new HttpLambdaIntegration("SampleClientLambdaIntegration", sampleClientFunction); 296 | sessionApi.addRoutes(AddRoutesOptions.builder() 297 | .methods(List.of(HttpMethod.PUT)) 298 | .path("/session") 299 | .integration(sessionLambdaIntegration) 300 | .build()); 301 | sessionApi.addRoutes(AddRoutesOptions.builder() 302 | .methods(List.of(HttpMethod.DELETE)) 303 | .path("/session") 304 | .integration(sessionLambdaIntegration) 305 | .build()); 306 | sessionApi.addRoutes(AddRoutesOptions.builder() 307 | .methods(List.of(HttpMethod.GET)) 308
| .path("/tenant") 309 | .integration(tenantLambdaIntegration) 310 | .build()); 311 | sessionApi.addRoutes(AddRoutesOptions.builder() 312 | .methods(List.of(HttpMethod.GET)) 313 | .path("/SampleClient") 314 | .integration(sampleClientLambdaIntegration) 315 | .build()); 316 | sessionApi.addStage("SessionApiProductionStage", HttpStageOptions.builder() 317 | .autoDeploy(true) 318 | .stageName("production") 319 | .build()); 320 | } 321 | 322 | /** Creates a session-tag-assumable table role for each API Lambda and configures the Lambda with setupWebSocketFunction (invoke permissions for $connect, $disconnect and the authorizer; ManageConnections policies where requested), then passes the WebSocket, session and tenant URLs to the sample client Lambda. */ private void setupAPIGatewayLambdaFunctions() { 323 | // Update the lambdas to allow callbacks to this WebSocket endpoint and set environment variables so they can reach the resources they need. 324 | // Each lambda gets its own table role that it can assume with session tags; the DynamoDB grants on that role are conditioned on the aws:PrincipalTag/tenantId tag. 325 | Role sessionTTLLambdaTableRole = Role.Builder.create(this, "SessionTTLLambdaTableRole").assumedBy(new SessionTagsPrincipal(sessionTTLLambda.getRole())).build(); 326 | setupWebSocketFunction(sessionTTLLambda, sessionTTLLambdaTableRole, null, true, true); 327 | Role webSocketConnectFunctionTableRole = Role.Builder.create(this, "WebSocketConnectFunctionTableRole").assumedBy(new SessionTagsPrincipal(webSocketConnectFunction.getRole())).build(); 328 | setupWebSocketFunction(webSocketConnectFunction, webSocketConnectFunctionTableRole, "/*/$connect", true); 329 | Role webSocketDisconnectFunctionTableRole = Role.Builder.create(this, "WebSocketDisconnectFunctionTableRole").assumedBy(new SessionTagsPrincipal(webSocketDisconnectFunction.getRole())).build(); 330 | setupWebSocketFunction(webSocketDisconnectFunction, webSocketDisconnectFunctionTableRole, "/*/$disconnect", true); 331 | Role authorizerFunctionTableRole = Role.Builder.create(this, "AuthorizerFunctionTableRole").assumedBy(new SessionTagsPrincipal(authorizerFunction.getRole())).build(); 332 | setupWebSocketFunction(authorizerFunction, authorizerFunctionTableRole, "/authorizers/" + authorizer.getRef(), false); 333 | Role sessionFunctionTableRole = Role.Builder.create(this,
"SessionFunctionTableRole").assumedBy(new SessionTagsPrincipal(sessionFunction.getRole())).build(); 334 | setupWebSocketFunction(sessionFunction, sessionFunctionTableRole,null, false); 335 | Role tenantFunctionTableRole = Role.Builder.create(this, "TenantFunctionTableRole").assumedBy(new SessionTagsPrincipal(tenantFunction.getRole())).build(); 336 | setupWebSocketFunction(tenantFunction, tenantFunctionTableRole,null, false, false); 337 | 338 | sampleClientFunction.addEnvironment("WssUrl", stage.getUrl()); 339 | sampleClientFunction.addEnvironment("SessionUrl", sessionApi.getApiEndpoint() + "/production/session"); 340 | sampleClientFunction.addEnvironment("TenantUrl", sessionApi.getApiEndpoint() + "/production/tenant"); 341 | } 342 | 343 | /** Same as the five-argument overload with includeDeletePolicy set to false. */ private void setupWebSocketFunction(Function function, Role tableRole, String permissionEndpoint, boolean includePostPolicy) { 344 | setupWebSocketFunction(function, tableRole, permissionEndpoint, includePostPolicy, false); 345 | } 346 | 347 | /** Configures a function for this stack. It sets the ApiGatewayEndpoint and table-name environment variables. It optionally grants execute-api:ManageConnections on the stage's POST (includePostPolicy) and DELETE (includeDeletePolicy) connection resources. When permissionEndpoint is non-null, it allows apigateway.amazonaws.com to invoke the function for that execute-api path suffix. Finally, it lets the function assume tableRole (exported as RoleArn); tableRole gets read access to the tenant table and read/write access to the session and limit tables, conditioned on dynamodb:LeadingKeys matching the aws:PrincipalTag/tenantId tag (prefix match for the limit table). */ private void setupWebSocketFunction(Function function, Role tableRole, String permissionEndpoint, boolean includePostPolicy, boolean includeDeletePolicy) { 348 | function.addEnvironment("ApiGatewayEndpoint", stage.getUrl().replace("wss://", "")); 349 | function.addEnvironment("TenantTableName", tenantTable.getTableName()); 350 | function.addEnvironment("SessionTableName", sessionTable.getTableName()); 351 | function.addEnvironment("LimitTableName", limitTable.getTableName()); 352 | if (includePostPolicy) { 353 | function.addToRolePolicy(PolicyStatement.Builder.create() 354 | .actions(List.of("execute-api:ManageConnections")) 355 | .resources(List.of(formatArn(ArnComponents.builder() 356 | .resource(api.getApiId()) 357 | .service("execute-api") 358 | .resourceName(stage.getStageName() + "/POST/*") 359 | .build()))) 360 | .build()); 361 | } 362 | if (includeDeletePolicy) { 363 | function.addToRolePolicy(PolicyStatement.Builder.create() 364 | .actions(List.of("execute-api:ManageConnections"))
365 | .resources(List.of(formatArn(ArnComponents.builder() 366 | .resource(api.getApiId()) 367 | .service("execute-api") 368 | .resourceName(stage.getStageName() + "/DELETE/*") 369 | .build()))) 370 | .build()); 371 | } 372 | if (permissionEndpoint != null) { 373 | function.addPermission("APIGatewayConnect", Permission.builder() 374 | .action("lambda:InvokeFunction") 375 | .principal(ServicePrincipal.Builder.create("apigateway.amazonaws.com").build()) 376 | .sourceArn("arn:aws:execute-api:" + getRegion() + ":" + getAccount() + ":" + api.getApiId() + permissionEndpoint) 377 | .build()); 378 | } 379 | /* NOTE(review): the commented-out policy below is dead code. Assume-role access is now granted by tableRole.grantAssumeRole, and the roles are built with SessionTagsPrincipal; confirm sts:TagSession is still covered, then delete it. */ // function.addToRolePolicy(PolicyStatement.Builder.create() 380 | // .effect(Effect.ALLOW) 381 | // .actions(List.of("sts:AssumeRole", "sts:TagSession")) 382 | // .resources(List.of(tableRole.getRoleArn())) 383 | // .build()); 384 | function.addEnvironment("RoleArn", tableRole.getRoleArn()); 385 | tableRole.grantAssumeRole(function.getRole()); 386 | tenantTable.grantReadData(tableRole).getPrincipalStatement().addCondition("ForAllValues:StringEquals", Map.of("dynamodb:LeadingKeys", List.of("${aws:PrincipalTag/tenantId}"))); 387 | sessionTable.grantReadWriteData(tableRole).getPrincipalStatement().addCondition("ForAllValues:StringEquals", Map.of("dynamodb:LeadingKeys", List.of("${aws:PrincipalTag/tenantId}"))); 388 | limitTable.grantReadWriteData(tableRole).getPrincipalStatement().addCondition("ForAllValues:StringLike", Map.of("dynamodb:LeadingKeys", List.of("${aws:PrincipalTag/tenantId}*"))); 389 | } 390 | 391 | /** Seeds three sample tenants with their limits (each also gets its own SQS echo Lambda and FIFO queue via addSampleTenantId), then creates the shared "Pooled" echo Lambda and queue. */ private void addSampleTenantIds() { 392 | addSampleTenantId("a5a82459-3f18-4ecd-89a6-2d13af314751", "60", "5", "2", "10", "200", "60", 1); 393 | addSampleTenantId("9175b21a-332a-4a7a-a72d-9184ad7186c0", "120", "10", "5", "100", "300","600",2); 394 | addSampleTenantId("31a2e8c6-1826-11ec-9621-0242ac130002", "180", "30", "10", "1000", "300","6000",3); 395 | Function sqsEchoFunction = createSQSEchoLambda("Pooled"); 396 | createSQSFifoQueuePerTenant("Pooled",
sqsEchoFunction); 397 | } 398 | 399 | /** Inserts one tenant's limits into the tenant table when the stack is created. It uses an AwsCustomResource putItem that runs onCreate only, and the attribute_not_exists(tenantId) condition prevents overwriting an existing item. The limit values are passed as strings and written as DynamoDB numbers; index keeps the construct id and physical resource id unique. It then creates this tenant's SQS echo Lambda and FIFO queue. */ private void addSampleTenantId(String tenantId, String tenantPerMinute, String sessionPerMinute, String connectionsPerSession, String tenantConnections, String sessionTTL, String messagesPerMinute, int index) { 400 | AwsSdkCall initializeData = AwsSdkCall.builder() 401 | .service("DynamoDB") 402 | .action("putItem") 403 | .physicalResourceId(PhysicalResourceId.of(tenantTable.getTableName() + "_initialization" + index)) 404 | .parameters(Map.ofEntries( 405 | Map.entry("TableName", tenantTable.getTableName()), 406 | Map.entry("Item", Map.ofEntries( 407 | Map.entry("tenantId", Map.of("S", tenantId)), 408 | Map.entry("tenantPerMinute", Map.of("N", tenantPerMinute)), 409 | Map.entry("sessionPerMinute", Map.of("N", sessionPerMinute)), 410 | Map.entry("connectionsPerSession", Map.of("N", connectionsPerSession)), 411 | Map.entry("tenantConnections", Map.of("N", tenantConnections)), 412 | Map.entry("sessionTTL", Map.of("N", sessionTTL)), 413 | Map.entry("messagesPerMinute", Map.of("N", messagesPerMinute)) 414 | )), 415 | Map.entry("ConditionExpression", "attribute_not_exists(tenantId)") 416 | )) 417 | .build(); 418 | 419 | AwsCustomResource tableInitializationResource = AwsCustomResource.Builder.create(this, "TenantSampleDataTableInitializationResource" + index) 420 | .policy(AwsCustomResourcePolicy.fromStatements(List.of( 421 | PolicyStatement.Builder.create() 422 | .effect(Effect.ALLOW) 423 | .actions(List.of("dynamodb:PutItem")) 424 | .resources(List.of(tenantTable.getTableArn())) 425 | .build() 426 | ))) 427 | .onCreate(initializeData) 428 | .build(); 429 | tableInitializationResource.getNode().addDependency(tenantTable); 430 | Function sqsEchoFunction = createSQSEchoLambda(tenantId); 431 | createSQSFifoQueuePerTenant(tenantId, sqsEchoFunction); 432 | } 433 | 434 | /** Creates the SQSEcho Lambda (handler SQSEcho.handler, code from the lambda asset directory) for the given tenantId or "Pooled". It tags the function with tenantId and configures it via setupWebSocketFunction with its own session-tag table role and the POST ManageConnections policy. NOTE(review): Runtime.NODEJS_14_X is past end of support in AWS Lambda; consider a newer Node.js runtime after checking SQSEcho.js. */ private Function createSQSEchoLambda(String tenantId) { 435 | Function function = Function.Builder.create(this, "SQSEcho" + tenantId) 436 | .runtime(Runtime.NODEJS_14_X)
437 | .code(Code.fromAsset("lambda")) 438 | .handler("SQSEcho.handler") 439 | .build(); 440 | Tags.of(function).add("tenantId", tenantId); 441 | Role lambdaTableRole = Role.Builder.create(this, "SQSEcho" + tenantId + "TableRole").assumedBy(new SessionTagsPrincipal(function.getRole())).build(); 442 | setupWebSocketFunction(function, lambdaTableRole, null, true); 443 | return function; 444 | } 445 | 446 | /** Creates the FIFO queue named tenant-{tenantId}.fifo (the name pattern the SQS integration URI targets), with a per-message-group throughput limit and deduplication scope. It tags the queue with tenantId and attaches it as an enabled SQS event source of sqsEchoFunction. */ private void createSQSFifoQueuePerTenant(String tenantId, Function sqsEchoFunction) { 447 | Queue tenantQueue = Queue.Builder.create(this, "TenantQueue" + tenantId) 448 | .fifo(true) 449 | .fifoThroughputLimit(FifoThroughputLimit.PER_MESSAGE_GROUP_ID) 450 | .deduplicationScope(DeduplicationScope.MESSAGE_GROUP) 451 | .queueName("tenant-" + tenantId + ".fifo") 452 | .build(); 453 | Tags.of(tenantQueue).add("tenantId", tenantId); 454 | sqsEchoFunction.addEventSource(SqsEventSource.Builder.create(tenantQueue).enabled(true).build()); 455 | } 456 | 457 | /** Exports the session, tenant, WebSocket and sample-client URLs as CloudFormation outputs. The HTTP paths hard-code the "production" stage name that createAPIGatewaySessionAndSample also uses. */ private void createOutputs() { 458 | CfnOutput.Builder.create(this, "SessionURL") 459 | .exportName("SessionURL") 460 | .value(sessionApi.getApiEndpoint() + "/production/session") 461 | .build(); 462 | CfnOutput.Builder.create(this, "TenantURL") 463 | .exportName("TenantURL") 464 | .value(sessionApi.getApiEndpoint() + "/production/tenant") 465 | .build(); 466 | CfnOutput.Builder.create(this, "WebSocketURL") 467 | .exportName("WebSocketURL") 468 | .value(stage.getUrl()) 469 | .build(); 470 | CfnOutput.Builder.create(this, "SampleClient") 471 | .exportName("SampleClient") 472 | .value(sessionApi.getApiEndpoint() + "/production/SampleClient") 473 | .build(); 474 | } 475 | } 476 | --------------------------------------------------------------------------------