├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── architecture └── arch_aoss_rag.png ├── bin └── rag-with-amazon-bedrock-and-opensearch.ts ├── cdk.json ├── jest.config.js ├── knowledgebase └── .gitkeep ├── lambda ├── aoss-trigger │ └── app.py ├── aoss-update │ ├── Dockerfile │ ├── app.py │ └── requirements.txt ├── app-client-create-trigger │ └── app.py ├── call-back-url-init │ └── app.py ├── call-back-url-update │ └── app.py └── pdf-processor │ ├── Dockerfile │ ├── lambda_function.py │ └── requirements.txt ├── lib ├── aoss-update-stack.ts ├── base-infra-stack.ts ├── opensearch-stack.ts ├── rag-app-stack.ts └── test-compute-stack.ts ├── package-lock.json ├── package.json ├── rag-app ├── Dockerfile ├── __init__.py ├── aoss_chat_bedrock.py ├── app.py ├── app_init.py ├── helper_functions.py ├── images │ ├── ai-icon.png │ └── user-icon.png ├── requirements.txt └── run_app.sh ├── screenshots ├── aoss_api_dash.png ├── aoss_dashboard.png ├── app_screenshot.png ├── cog_login_page.png └── invalid_cert.png ├── scripts ├── api-key-secret-manager-upload │ ├── api-key-secret-manager-upload.py │ └── requirements.txt └── self-signed-cert-utility │ ├── default_cert_params.json │ ├── requirements.txt │ └── self-signed-cert-utility.py ├── test └── rag-with-amazon-bedrock-and-opensearch.test.ts └── tsconfig.json /.gitignore: -------------------------------------------------------------------------------- 1 | *.js 2 | !jest.config.js 3 | *.d.ts 4 | node_modules 5 | 6 | # CDK asset staging directory 7 | .cdk.staging 8 | cdk.out 9 | 10 | # ssl script 11 | .ssl 12 | 13 | .DS_Store 14 | 15 | .venv 16 | venv 17 | env 18 | .env 19 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. 
Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RAG with Amazon Bedrock and OpenSearch 2 | 3 | An opinionated sample on how to configure and deploy a [RAG (Retrieval Augmented Generation)](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) application. 4 | 5 | It comprises a few core pieces: 6 | 7 | * [Amazon Bedrock](https://aws.amazon.com/bedrock/) as the managed service providing easy API-based access to [foundation models (FMs)](https://aws.amazon.com/what-is/foundation-models/). 8 | 9 | * [Amazon OpenSearch Service](https://aws.amazon.com/opensearch-service/). This is an open-source alternative to using [Amazon Kendra](https://aws.amazon.com/kendra/). 10 | 11 | * [LangChain](https://www.langchain.com/) as a [Large Language Model (LLM)](https://www.elastic.co/what-is/large-language-models) application framework. It is also used to update the OpenSearch index when new documents get added to the knowledgebase S3 bucket. 12 | 13 | * [Amazon Elastic Container Service (ECS)](https://aws.amazon.com/ecs/) to run the RAG Application. 14 | 15 | * [Streamlit](https://streamlit.io/) for the frontend user interface of the RAG Application. 16 | 17 | * [Application Load Balancer](https://aws.amazon.com/elasticloadbalancing/application-load-balancer/) to route HTTPS traffic to the ECS service (which is running the RAG App). 18 | 19 | * [Amazon Cognito](https://aws.amazon.com/cognito/) for secure user authentication. 20 | 21 | This sample is inspired by [another sample](https://github.com/aws-samples/rag-with-amazon-bedrock-and-pgvector) that demonstrates similar functionality with PGVector (instead of OpenSearch). 22 | 23 | ## Architecture 24 | 25 | ![Architecture](./architecture/arch_aoss_rag.png) 26 | 27 | ## Short note on vector data stores 28 | 29 | A [vector database](https://en.wikipedia.org/wiki/Vector_database) is an essential component of any RAG application. The LLM framework uses the vector data store to search for information relevant to the question that comes from the user. 30 | 31 | The typical assumption (*and a strong constraint on this sample project*) is that a knowledgebase comprises PDF documents stored somewhere. Ideally, a true knowledgebase would encompass a lot more - scraped websites, wiki pages and so on. But to limit the scope of this sample, the knowledgebase is an [S3](https://aws.amazon.com/s3/) bucket containing a collection of PDF documents. 32 | 33 | A popular choice of vector store for an AWS-based RAG app is Amazon Kendra. It does [optical character recognition (OCR)](https://en.wikipedia.org/wiki/Optical_character_recognition) for PDFs under the hood, and it is a fully managed search service with seamless integration with AWS services like S3. Additionally, Amazon Bedrock also has a vector database offering in the form of ["Knowledgebases"](https://aws.amazon.com/bedrock/knowledge-bases/). 34 | 35 | NOTE - "Bedrock Knowledgebases" is another vector store offering; **it should not** be confused with the term "knowledgebase" and/or "knowledgebase bucket", which refers to the S3 bucket containing PDF documents in this project. 36 | 37 | However, the purpose of this sample is to show how to set up an open-source vector database, and since Kendra and Bedrock Knowledgebases are not open source, this sample focuses on OpenSearch.
Unlike Kendra, OpenSearch cannot directly query PDF documents, so we need to extract the text and then feed the text to OpenSearch. 38 | 39 | ## OpenSearch orchestration 40 | 41 | The expectation is that PDF files will land in the knowledgebase S3 bucket - either by manually uploading them via the console, programmatically via the [AWS CLI](https://aws.amazon.com/cli/), or by running `cdk deploy BaseInfraStack`. NOTE - the last option (*`cdk deploy`*) requires that you put the PDF files in the ["knowledgebase"](./knowledgebase/) directory of this project. The [S3 Bucket Deployment](https://docs.aws.amazon.com/cdk/api/v2/docs/aws-cdk-lib.aws_s3_deployment.BucketDeployment.html) construct will then upload these files to the knowledgebase bucket. 42 | 43 | Once the files land in the knowledgebase S3 bucket, [S3 Event Notifications](https://docs.aws.amazon.com/AmazonS3/latest/userguide/EventNotifications.html) initiate a [lambda](https://docs.aws.amazon.com/lambda/latest/dg/welcome.html) function to extract text from the PDF file(s) and upload the converted text files into the "processed text S3 bucket". The code/logic for this conversion [lambda function](./lambda/pdf-processor/lambda_function.py) is in the [lambda/pdf-processor](./lambda/pdf-processor/) directory. The function uses the [pypdf](https://github.com/py-pdf/pypdf) Python library to perform the text extraction. 44 | 45 | After the processed text files land in the "processed text S3 bucket", another S3 Event Notification triggers another lambda function ([aoss-trigger](./lambda/aoss-trigger/app.py)) that extracts the necessary information about the file and pushes it to an [Amazon SQS](https://aws.amazon.com/sqs/) queue. 46 | 47 | That message push into SQS initiates another lambda function ([aoss-update](./lambda/aoss-update/app.py)) that finally updates the vector database with the contents of the processed text file to be indexed (*which enables it to be searched by the RAG app*). It uses the [S3FileLoader](https://python.langchain.com/docs/integrations/document_loaders/aws_s3_file) component from LangChain to extract document contents to feed OpenSearch. 48 | 49 | ### Short note on Embeddings 50 | 51 | [Embeddings](https://www.elastic.co/what-is/vector-embedding) are a way to convert words and sentences into numbers that capture their meaning and relationships. In the context of RAG, these "vector embeddings" enable ["similarity search"](https://en.wikipedia.org/wiki/Similarity_search) capabilities. Adding documents to an OpenSearch index also requires creating/provisioning embeddings. This project uses [OpenAI's Embeddings](https://platform.openai.com/docs/guides/embeddings). So, if you wish to build/run this app in your own AWS environment, you will need to create an account with OpenAI and obtain their [API Key](https://help.openai.com/en/articles/4936850-where-do-i-find-my-api-key). 52 | 53 | **OpenAI has its own pricing for API usage**, so be mindful of that; you can find the details on their [pricing page](https://openai.com/pricing). You should be able to get going with the free credits, but if you keep this app running long enough, it will start accruing additional charges.
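To make this concrete, here is a minimal sketch (condensed from [lambda/aoss-update/app.py](./lambda/aoss-update/app.py)) of how a processed text file gets embedded with OpenAI and indexed into the OpenSearch Serverless collection. The bucket, key, region, index name and collection ID below are placeholders - the deployed function reads the real values from environment variables:

```
import boto3
from langchain_community.document_loaders import S3FileLoader
from langchain_community.vectorstores import OpenSearchVectorSearch
from langchain_openai import OpenAIEmbeddings
from opensearchpy import AWSV4SignerAuth, RequestsHttpConnection

# placeholder values - the real ones come from environment variables set by the CDK stacks
bucket, key = "processed-text-bucket", "some-document.txt"
collection_id, region, index_name = "abc123collection", "us-east-1", "rag-index"

docs = S3FileLoader(bucket, key).load()   # load the processed text file from S3
embeddings = OpenAIEmbeddings()           # reads the OPENAI_API_KEY environment variable

# sign requests to the OpenSearch Serverless ("aoss") collection with SigV4
auth = AWSV4SignerAuth(boto3.Session().get_credentials(), region, "aoss")

# embed the document and add it to the index
OpenSearchVectorSearch.from_documents(
    docs,
    embeddings,
    opensearch_url=f"{collection_id}.{region}.aoss.amazonaws.com:443",
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    index_name=index_name,
)
```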
54 | 55 | Some other options to obtain embeddings - 56 | * [HuggingFace](https://huggingface.co/blog/getting-started-with-embeddings) 57 | * [Amazon Titan](https://aws.amazon.com/about-aws/whats-new/2023/09/amazon-titan-embeddings-generally-available/) 58 | 59 | NOTE - If you wish to use alternative embeddings, you will need to change the code in the [rag-app](./rag-app/) and the [aoss-update lambda function](./lambda/aoss-update/) accordingly. 60 | 61 | ## Deploying the app 62 | 63 | This project is divided into a few sub-stacks, so deploying it also requires a few additional steps. It uses [AWS CDK](https://aws.amazon.com/cdk/) for [Infrastructure as Code (IaC)](https://en.wikipedia.org/wiki/Infrastructure_as_code). 64 | 65 | ### Pre-requisites 66 | 67 | * Since this is a [TypeScript](https://www.typescriptlang.org/) CDK project, you should have [npm](https://www.npmjs.com/) installed (which is the package manager for TypeScript/JavaScript). 68 | * You can find installation instructions for npm [here](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm). 69 | 70 | * Install the [AWS CLI](https://aws.amazon.com/cli/) on your computer (*if not already done so*). 71 | * `pip install awscli`. This means you need to have Python installed on your computer (if it is not already installed). 72 | * You also need to configure and authenticate your AWS CLI to be able to interact with AWS programmatically. Detailed instructions on how to do that are provided [here](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html) 73 | 74 | * You need to have [docker](https://www.docker.com/) installed on your computer. 75 | * You can check out these options for building and running docker containers on your local machine: 76 | * [Docker desktop](https://www.docker.com/products/docker-desktop/). The most popular container management app. Note - it does require a license if the organization you work at is bigger than a certain threshold. 77 | * [Rancher desktop](https://rancherdesktop.io/). A popular open-source container management tool. 78 | * [Finch](https://github.com/runfinch/finch). Another open-source tool for container management. Note - currently it only supports macOS machines. 79 | 80 | * Have an API Key from [OpenAI](https://openai.com/). This key is needed for programmatic access to use their embeddings for OpenSearch. You need to create an account with OpenAI (*if you don't already have one*). Details on how to find/create an API Key can be found [here](https://help.openai.com/en/articles/4936850-where-do-i-find-my-api-key). 81 | 82 | ### Create a self-signed SSL certificate 83 | 84 | * Set the `IAM_SELF_SIGNED_SERVER_CERT_NAME` environment variable. This is the name of the self-signed server certificate that will be created ([via IAM](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_server-certs.html)) as part of the deployment. 85 | 86 | ``` 87 | export IAM_SELF_SIGNED_SERVER_CERT_NAME= 88 | ``` 89 | 90 | * Run the [self-signed-cert-utility.py](./scripts/self-signed-cert-utility/self-signed-cert-utility.py) script in the [scripts](./scripts/) directory to create a self-signed certificate and upload its contents to AWS via `boto3` API calls. 91 | 92 | This is needed because the Application Load Balancer requires [SSL certificates](https://www.cloudflare.com/en-gb/learning/ssl/what-is-an-ssl-certificate/) to have a functioning [HTTPS listener](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/create-https-listener.html).
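For reference, the following is a rough, hypothetical sketch of what the certificate utility does under the hood - generate a self-signed certificate and register it with IAM via `boto3`. The `cryptography` package is used here purely for illustration, and the certificate name and subject are placeholders; the actual script reads its parameters from the JSON config described below:

```
import datetime
import os

import boto3
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa

# placeholder values for illustration only
cert_name = os.environ.get("IAM_SELF_SIGNED_SERVER_CERT_NAME", "my-rag-app-cert")
subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "example.com")])

# generate a private key and a self-signed certificate (issuer == subject)
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
cert = (
    x509.CertificateBuilder()
    .subject_name(subject)
    .issuer_name(issuer)
    .public_key(key.public_key())
    .serial_number(x509.random_serial_number())
    .not_valid_before(datetime.datetime.utcnow())
    .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
    .sign(key, hashes.SHA256())
)

# upload the certificate and private key to IAM as a server certificate
resp = boto3.client("iam").upload_server_certificate(
    ServerCertificateName=cert_name,
    CertificateBody=cert.public_bytes(serialization.Encoding.PEM).decode(),
    PrivateKey=key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.TraditionalOpenSSL,
        serialization.NoEncryption(),
    ).decode(),
)
print(resp["ServerCertificateMetadata"]["Arn"])
```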
93 | 94 | 95 | # switch to the self-signed-cert-utility directory 96 | cd scripts/self-signed-cert-utility 97 | 98 | # create a python3 virtual environment (highly recommended) 99 | python3 -m virtualenv .certenv 100 | 101 | # activate the virtual environment 102 | source .certenv/bin/activate 103 | # for a different shell like fish, just add a `.fish` at the end of the previous command 104 | 105 | # install requirements 106 | pip install -r requirements.txt 107 | 108 | # run the script 109 | python self-signed-cert-utility.py 110 | # optionally specify a `--profile` if you're not using the default AWS profile 111 | 112 | # deactivate virtual environment 113 | deactivate 114 | 115 | # return to the root directory of the project 116 | cd - 117 | 118 | 119 | If the script runs successfully, you should see a JSON-like object printed in the log output with parameters like `ServerCertificateName`, `ServerCertificateId`, `Arn` etc. Moreover, the `HTTPStatusCode` should have the value `200`. 120 | 121 | The parameters encoded in the certificate are in a JSON file. By default it expects a file named ["default_cert_params.json"](./scripts/self-signed-cert-utility/default_cert_params.json) unless otherwise specified. You may change the values of the default JSON file if you wish to. If you wish to use your own config file (*instead of the default*), you can do so by specifying the `--config-file` parameter. 122 | 123 | You can also specify a custom domain for the certificate by setting the `APP_DOMAIN` environment variable. 124 | 125 | NOTE - an alternative would be to use [AWS Certificate Manager](https://aws.amazon.com/certificate-manager/), but it requires additional steps (*in the form of creating and registering your own domain, involving [Route53 hosted zones](https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/hosted-zones-working-with.html) etc.*). Since the focus of this sample is to show the deployment of a RAG app, and not domain registration, it does not get into configuring that. 126 | 127 | 128 | ### Define the domain name for the Cognito hosted UI [Optional] 129 | 130 | Set the `COGNITO_DOMAIN_NAME` environment variable. This will be the domain of the [Cognito hosted UI](https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-pools-app-integration.html), which will be used to log in and/or sign up to the app. 131 | ``` 132 | export COGNITO_DOMAIN_NAME= 133 | ``` 134 | 135 | The default value is defined in the [base-infra-stack.ts](./lib/base-infra-stack.ts#L260). 136 | 137 | ### Install dependencies (if not already done) 138 | 139 | ``` 140 | npm install 141 | ``` 142 | 143 | ### Bootstrap CDK environment (if not already done) 144 | 145 | Bootstrapping provisions resources in your environment such as an Amazon Simple Storage Service (Amazon S3) bucket for storing files and AWS Identity and Access Management (IAM) roles that grant permissions needed to perform deployments. These resources get provisioned in an AWS CloudFormation stack, called the bootstrap stack. It is usually named CDKToolkit. Like any AWS CloudFormation stack, it will appear in the AWS CloudFormation console of your environment once it has been deployed. More details can be found [here](https://docs.aws.amazon.com/cdk/v2/guide/bootstrapping.html). 146 | 147 | ``` 148 | npx cdk bootstrap 149 | 150 | # You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile.
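# For example, you can bootstrap an explicit account/region pair with a named profile
# (the account ID and profile name below are placeholders):
#   npx cdk bootstrap aws://111111111111/us-east-1 --profile my-profile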
151 | ``` 152 | 153 | NOTE - you only need to do this once per account. If there are other CDK projects deployed in your AWS account, you won't need to do this. 154 | 155 | ### Set environment variable (if you are on an M1/M2 Mac) 156 | 157 | Depending on the architecture of your computer, you may need to set this environment variable for the docker container. This is because docker containers are dependent on the architecture of the host machine that is building/running them. 158 | 159 | **If your machine runs on the [x86](https://en.wikipedia.org/wiki/X86) architecture, you can ignore this step.** 160 | 161 | ``` 162 | export DOCKER_CONTAINER_PLATFORM_ARCH=arm 163 | ``` 164 | 165 | 166 | ### Deploy the BaseInfraStack 167 | 168 | ``` 169 | npx cdk deploy BaseInfraStack 170 | 171 | # You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile. 172 | ``` 173 | 174 | This will deploy the base infrastructure - consisting of a VPC, Application Load Balancer for the app, S3 buckets (for knowledgebase, and the processed text), Lambda functions to process the PDF documents, some SQS queues for decoupling, a Secret credential for the OpenAI API key, Cognito user pool and some more bits and pieces of the cloud infrastructure. The CDK code for this is in the [lib](./lib) directory within the [base-infra-stack.ts](./lib/base-infra-stack.ts) file. 175 | 176 | ### Upload the OpenAI API key to Secrets Manager 177 | 178 | The secret was created after the deployment of the `BaseInfraStack` but the value inside it is not valid. You can either enter your OpenAI API key via the AWS Secrets Manager console; Or you could use the [api-key-secret-manager-upload.py](./scripts/api-key-secret-manager-upload/api-key-secret-manager-upload.py) script to do that for you. 179 | 180 | [AWS Secrets Manager](https://aws.amazon.com/secrets-manager/) is the recommended way to store credentials in AWS, as it provides API based access to credentials for databases etc. Since OpenAI (*the provider we are using the vector emebeddings from*) is an external service and has its own API keys, we need to manually upload that key to Secrets Manager so that the app infrastructure can access it securely. 181 | 182 | 183 | # switch to the api-key-secret-manager-upload directory 184 | cd scripts/api-key-secret-manager-upload 185 | 186 | # create a python3 virtual environment (highly recommended) 187 | python3 -m virtualenv .keyenv 188 | 189 | # activate the virtual environment 190 | source .keyenv/bin/activate 191 | # for a different shell like fish, just add a `.fish` at the end of the previous command 192 | 193 | # install requirements 194 | pip install -r requirements.txt 195 | 196 | # run the script; optionally specify a `--profile` if you're not using the default AWS profile 197 | python api-key-secret-manager-upload.py -s openAiApiKey 198 | 199 | 2024-01-14 19:42:59,341 INFO [__main__]:[MainThread] AWS Profile being used: default 200 | 2024-01-14 19:42:59,421 INFO [__main__]:[MainThread] Updating Secret: openAiApiKey 201 | Please enter the API Key: 202 | 2024-01-14 19:44:02,221 INFO [__main__]:[MainThread] Successfully updated secret value 203 | 2024-01-14 19:44:02,221 INFO [__main__]:[MainThread] Total time elapsed: 62.88090920448303 seconds 204 | # deactivate virtual environment 205 | deactivate 206 | 207 | # return to the root directory of the project 208 | cd - 209 | 210 | 211 | The script will prompt you to enter you OpenAI API key. 
It uses the [getpass](https://docs.python.org/3/library/getpass.html) Python library so that you don't have to enter it in plain text. 212 | 213 | NOTE - that the instructions specify `-s openAiApiKey`. It is the same name as defined in the [base-infra-stack.ts](./lib/base-infra-stack.ts?plain=1#L124). If you change the value there, you will need to change the value whilst running the script too. 214 | 215 | ### Deploy the TestCompute Stack 216 | 217 | ``` 218 | npx cdk deploy TestComputeStack 219 | 220 | # You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile. 221 | ``` 222 | 223 | This will deploy an EC2 instance that you may use to troubleshoot OpenSearch connectivity/make API calls etc. and/or any other test/dev computing you might need to do. 224 | 225 | 226 | ### Deploy the OpenSearch stack 227 | 228 | ``` 229 | npx cdk deploy OpenSearchStack 230 | 231 | # You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile. 232 | ``` 233 | 234 | This will deploy an [Amazon managed OpenSearch serverless collection](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-overview.html#serverless-start) - specialized to do Vector Searches. The reason to use this is so that you don't have to worry about managing the OpenSearch cluster. Additionally, it will create and attach some [network security policies](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-network.html), [encryption security policies](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-encryption.html) and [data access policies](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-data-access.html) to the collection. 235 | 236 | Note that this Serverless connection has ["AllowFromPublic" set to True](./lib/opensearch-stack.ts#L41). This enables you to easily access the OpenSearch Dashboard from the Amazon OpenSearch Service console. You do this by navigating to the Collections section in the console, select the collection that got created (as part of deploying the OpenSearch stack), and clicking on the Dashboard URL. 237 | 238 | ![aoss-dash](./screenshots/aoss_dashboard.png) 239 | 240 | If you don't set the `OPENSEARCH_COLLECTION_NAME` environment variable to something, by default the name of the collection will be "rag-collection". You can also change the default value of the collection name [here](./lib/opensearch-stack.ts#L22). 241 | 242 | ### Deploy the OpenSearch Update Stack 243 | 244 | ``` 245 | npx cdk deploy aossUpdateStack 246 | 247 | # You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile. 248 | ``` 249 | 250 | This will deploy a [Lambda function](./lambda/aoss-update/) that will update the OpenSearch index whenever a new document lands in the processed text bucket. 251 | 252 | 253 | ### Deploy the RAG App Stack 254 | 255 | ``` 256 | npx cdk deploy ragStack 257 | 258 | # You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile. 259 | ``` 260 | 261 | This will deploy the ECS Fargate service running the code for the RAG Application. It will also add this service as a target to the Application Load Balancer defined in the "BaseInfraStack". The CDK infrastructure code for this stack is in the [rag-app-stack.ts](./lib/rag-app-stack.ts) file in the [lib](./lib/) directory. 
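For a rough idea of what the deployed container does at query time, here is a minimal, hypothetical sketch of a LangChain retrieval chain over Bedrock and OpenSearch Serverless. The actual implementation lives in [rag-app/aoss_chat_bedrock.py](./rag-app/aoss_chat_bedrock.py) and differs in its details; the collection ID, region and index name below are placeholders for values the app receives via environment variables:

```
import boto3
from langchain.chains import RetrievalQA
from langchain_community.llms import Bedrock
from langchain_community.vectorstores import OpenSearchVectorSearch
from langchain_openai import OpenAIEmbeddings
from opensearchpy import AWSV4SignerAuth, RequestsHttpConnection

# placeholder values - the app gets these from environment variables
collection_id, region, index_name = "abc123collection", "us-east-1", "rag-index"

auth = AWSV4SignerAuth(boto3.Session().get_credentials(), region, "aoss")
vectorstore = OpenSearchVectorSearch(
    opensearch_url=f"{collection_id}.{region}.aoss.amazonaws.com:443",
    index_name=index_name,
    embedding_function=OpenAIEmbeddings(),  # must match the embeddings used at indexing time
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)

# Claude Instant on Bedrock; swap the model_id to use a different foundation model
llm = Bedrock(model_id="anthropic.claude-instant-v1")

# retrieve the most relevant chunks from OpenSearch and let the LLM answer using them as context
qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
print(qa.invoke({"query": "What is AWS?"})["result"])
```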
262 | 263 | This app leverages LangChain for interacting with Bedrock and OpenSearch, and Streamlit for the frontend user interface. The application code is in the [rag-app](./rag-app/) directory. 264 | 265 | 266 | 267 | ### Add some PDF documents to the knowledgebase S3 Bucket 268 | 269 | After deploying the "RagStack", the necessary infrastructure to have the app running is complete. That being said, you still need to populate the knowledgebase S3 bucket. 270 | 271 | You can either do it manually by going into the console and uploading some files. 272 | 273 | Or you can add some PDF files to the [knowledgebase](./knowledgebase/) directory in this project, and then run `npx cdk deploy BaseInfraStack`. The reason to prefer this option is that you can then track knowledgebase documents in your source control (*if that is a requirement for your use case*). 274 | 275 | This should upload the document(s) to the knowledgebase S3 bucket via the S3 Bucket Deployment construct. 276 | 277 | After the upload to the S3 knowledgebase bucket is complete, it will trigger the [pdf-processor](./lambda/pdf-processor/) lambda function to extract the text from the PDF and upload it to the "processed text S3 bucket". 278 | 279 | Upload to the processed text bucket will trigger the [aoss-update](./lambda/aoss-update/) lambda function to then add that document to the OpenSearch index. 280 | You can verify that it has been added to the vector store by making API calls to the OpenSearch endpoint, or by making API calls to OpenSearch via the dashboard as shown below: 281 | 282 | ![aoss-api-dash](./screenshots/aoss_api_dash.png) 283 | 284 | ### Testing the RAG App 285 | 286 | After adding document(s) to the knowledgebase, you can now test the app. 287 | 288 | If you log into the AWS console, go to the Application Load Balancer page (*under the EC2 section*), and select the load balancer that was created as part of the "BaseInfraStack", it should have a "DNS name". If you copy that name and open it in your browser with an `https://` prefix, it should direct you to the app. 289 | 290 | Note - since we are using a self-signed SSL certificate (*via IAM Server Certificates*), you might see this warning on your browser (showing Chrome below): 291 | 292 | ![InvalidCert](./screenshots/invalid_cert.png) 293 | 294 | If you see that warning, click on "Advanced" and proceed to the URL; it should then direct you to the Login UI (*served via Cognito*): 295 | 296 | ![InvalidCert](./screenshots/cog_login_page.png) 297 | 298 | * You can either click sign-up and create a new user from this page (*note - you will have to verify the email you sign up with by entering the code that gets sent to that email*) 299 | 300 | * Alternatively you could create a user in the AWS Console (by navigating to the Cognito service) 301 | 302 | * There is a programmatic way to create your user via the SDK; or you could use this [open-source helper utility](https://github.com/aws-samples/cognito-user-token-helper). 303 | 304 | Once you've successfully signed in (*or signed up*), you will see the UI, and you can start asking questions based on the document(s) you've uploaded.
The example document used is the [2021 Amazon letter to the shareholder](https://www.aboutamazon.com/news/company-news/2021-letter-to-shareholders), and the question asked was "What is AWS?": 305 | ![AppScreen](./screenshots/app_screenshot.png) 306 | 307 | ## Miscellaneous notes / technical hiccups / recommendations 308 | 309 | * The **frontend user interface (UI)** was built using Streamlit, and inspired by another [open source project](https://github.com/aws-samples/amazon-kendra-langchain-extensions/tree/main/kendra_retriever_samples). 310 | 311 | * **Cognito Callback URL hiccup** - 312 | When creating an application load balancer via Infrastructure as Code (IaC), the DNS name is generated with some random characters (*that can be both UPPER CASE and lower case*). When configuring this with the Cognito User Pool integration (*app client*), the DNS name is used for the Callback URL. The problem here is that Cognito does not like UPPER CASE characters, and whilst deploying this solution via IaC, there isn't much you can do about converting the DNS name to lower case (*because it is actually a token, and not the actual string value of the DNS name*). There is an [open GitHub issue](https://github.com/aws/aws-cdk/issues/11171) on this. 313 | 314 | So, in order to fix this, the project has EventBridge triggers in place: when the app integration client is created, a lambda function is invoked that pushes a message to an SQS queue, which in turn invokes another Lambda function that updates the app client via the [update_user_pool_client](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/cognito-idp/client/update_user_pool_client.html) boto3 API call, setting the lower-case DNS name in the callback URL. 315 | 316 | The code for the lambda function is in the [lambda/call-back-url-init](./lambda/call-back-url-init/) directory. 317 | 318 | * If you were to **deploy a solution like this in a production environment**, you would need to create and register your own domain to host the app. The recommendation would be to use **AWS Certificate Manager** to issue the certificate, and link that with a **Route53 hosted zone**. More details can be found in the [AWS Documentation](https://docs.aws.amazon.com/acm/latest/userguide/dns-validation.html). 319 | 320 | * While Streamlit is good for quickly deploying UIs, it may not be best suited for production if the intent is to add more functionality to the app (*i.e. extending it beyond the RAG Chatbot app*). It may be worth looking at [AWS Amplify](https://aws.amazon.com/amplify). Decoupling the frontend from the backend could also introduce the possibility of running the backend as a Lambda Function with API Gateway. 321 | 322 | * Alternate vector embedding providers like HuggingFace and/or Amazon Titan would require some code changes (*specifically in the Lambda function(s) that update the OpenSearch index via LangChain, and the ECS application running the RAG app*). 323 | 324 | * The model used in this sample is [Anthropic's Claude V1 Instant](https://aws.amazon.com/bedrock/claude/). You can change the model by providing an environment variable `FOUNDATION_MODEL_ID` to the rag app in the [rag-app-stack.ts](./lib/rag-app-stack.ts?plain=1#L143).
You can find the different model IDs on the [AWS Documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html) 325 | 326 | * Key concepts / techniques covered in this sample - 327 | * OpenSearch as an open-source Vector database option for RAG applications 328 | * Using LangChain to serve a RAG application, update the OpenSearch index 329 | * Application Load Balancer (ALB) + ECS Fargate Service to serve an app 330 | * Using self signed certificates to configure the HTTPS listener for the ALB 331 | * Integrating a Cognito Login UI with the ALB 332 | 333 | ## Generic CDK instructions 334 | 335 | This is a blank project for CDK development with TypeScript. 336 | 337 | The `cdk.json` file tells the CDK Toolkit how to execute your app. 338 | 339 | ## Useful commands 340 | 341 | * `npm run build` compile typescript to js 342 | * `npm run watch` watch for changes and compile 343 | * `npm run test` perform the jest unit tests 344 | * `cdk deploy` deploy this stack to your default AWS account/region 345 | * `cdk diff` compare deployed stack with current state 346 | * `cdk synth` emits the synthesized CloudFormation template 347 | -------------------------------------------------------------------------------- /architecture/arch_aoss_rag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/architecture/arch_aoss_rag.png -------------------------------------------------------------------------------- /bin/rag-with-amazon-bedrock-and-opensearch.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | import 'source-map-support/register'; 3 | import * as cdk from 'aws-cdk-lib'; 4 | import { RagAppStack } from '../lib/rag-app-stack'; 5 | import { BaseInfraStack } from '../lib/base-infra-stack'; 6 | import { TestComputeStack } from '../lib/test-compute-stack'; 7 | import { OpenSearchStack } from '../lib/opensearch-stack'; 8 | import { OpenSearchUpdateStack } from '../lib/aoss-update-stack'; 9 | 10 | const app = new cdk.App(); 11 | 12 | // contains vpc, 13 | const baseInfra = new BaseInfraStack(app, 'BaseInfraStack', { 14 | }); 15 | 16 | 17 | // for a test EC2 instance to play around with (optional) 18 | const testComputeStack = new TestComputeStack(app, 'TestComputeStack', { 19 | vpc: baseInfra.vpc, 20 | ec2SG: baseInfra.ec2SecGroup, 21 | }); 22 | 23 | // OpenSearch Serverless Creation. 
TODO: fix the stackname to be consistent 24 | const opensearchStack = new OpenSearchStack(app, 'OpenSearchStack', { 25 | testComputeHostRole: testComputeStack.hostRole, 26 | lambdaRole: baseInfra.aossUpdateLambdaRole, 27 | ecsTaskRole: baseInfra.ecsTaskRole 28 | }); 29 | 30 | // lambda function to update the aoss index upon new document landing 31 | const aossUpdateStack = new OpenSearchUpdateStack(app, 'aossUpdateStack', { 32 | processedBucket: baseInfra.processedBucket, 33 | indexName: baseInfra.aossIndexName, 34 | apiKeySecret: baseInfra.apiKeySecret, 35 | triggerQueue: baseInfra.aossQueue, 36 | lambdaRole: baseInfra.aossUpdateLambdaRole, 37 | aossHost: opensearchStack.serverlessCollection.attrId 38 | }); 39 | 40 | // ecs service 41 | const ragApp = new RagAppStack(app, 'ragStack', { 42 | vpc: baseInfra.vpc, 43 | indexName: baseInfra.aossIndexName, 44 | apiKeySecret: baseInfra.apiKeySecret, 45 | taskSecGroup: baseInfra.ecsTaskSecGroup, 46 | aossHost: opensearchStack.serverlessCollection.attrId, 47 | elbTargetGroup: baseInfra.appTargetGroup, 48 | taskRole: baseInfra.ecsTaskRole 49 | }); 50 | -------------------------------------------------------------------------------- /cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "npx ts-node --prefer-ts-exts bin/rag-with-amazon-bedrock-and-opensearch.ts", 3 | "watch": { 4 | "include": [ 5 | "**" 6 | ], 7 | "exclude": [ 8 | "README.md", 9 | "cdk*.json", 10 | "**/*.d.ts", 11 | "**/*.js", 12 | "tsconfig.json", 13 | "package*.json", 14 | "yarn.lock", 15 | "node_modules", 16 | "test" 17 | ] 18 | }, 19 | "context": { 20 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true, 21 | "@aws-cdk/core:checkSecretUsage": true, 22 | "@aws-cdk/core:target-partitions": [ 23 | "aws", 24 | "aws-cn" 25 | ], 26 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true, 27 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true, 28 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true, 29 | "@aws-cdk/aws-iam:minimizePolicies": true, 30 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true, 31 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true, 32 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true, 33 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true, 34 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true, 35 | "@aws-cdk/core:enablePartitionLiterals": true, 36 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true, 37 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true, 38 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true, 39 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true, 40 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true, 41 | "@aws-cdk/aws-route53-patters:useCertificate": true, 42 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false, 43 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true, 44 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true, 45 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true, 46 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true, 47 | "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true, 48 | "@aws-cdk/aws-redshift:columnId": true, 49 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true, 50 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true, 51 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true, 52 | "@aws-cdk/aws-kms:aliasNameRef": true, 
53 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true, 54 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true, 55 | "@aws-cdk/aws-efs:denyAnonymousAccess": true, 56 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true, 57 | "@aws-cdk/aws-lambda-nodejs:useLatestRuntimeVersion": true, 58 | "@aws-cdk/aws-efs:mountTargetOrderInsensitiveLogicalId": true, 59 | "@aws-cdk/aws-rds:auroraClusterChangeScopeOfInstanceParameterGroupWithEachParameters": true, 60 | "@aws-cdk/aws-appsync:useArnForSourceApiAssociationIdentifier": true, 61 | "@aws-cdk/aws-rds:preventRenderingDeprecatedCredentials": true, 62 | "@aws-cdk/aws-codepipeline-actions:useNewDefaultBranchForCodeCommitSource": true 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /jest.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | testEnvironment: 'node', 3 | roots: ['/test'], 4 | testMatch: ['**/*.test.ts'], 5 | transform: { 6 | '^.+\\.tsx?$': 'ts-jest' 7 | } 8 | }; 9 | -------------------------------------------------------------------------------- /knowledgebase/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/knowledgebase/.gitkeep -------------------------------------------------------------------------------- /lambda/aoss-trigger/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | 7 | 8 | LOGGER = logging.getLogger() 9 | 10 | QUEUE_URL_ENV_VAR = "AOSS_UPDATE_QUEUE" 11 | BUCKET_ENV_VAR = "BUCKET_NAME" 12 | 13 | MALFORMED_EVENT_LOG_MSG = "Malformed event. Skipping this record." 14 | 15 | 16 | class MalformedEvent(Exception): 17 | """Raised if a malformed event received""" 18 | 19 | 20 | class MissingEnvironmentVariable(Exception): 21 | """Raised if a required environment variable is missing""" 22 | 23 | 24 | def _silence_noisy_loggers(): 25 | """Silence chatty libraries for better logging""" 26 | for logger in ['boto3', 'botocore', 27 | 'botocore.vendored.requests.packages.urllib3']: 28 | logging.getLogger(logger).setLevel(logging.WARNING) 29 | 30 | 31 | def _configure_logger(): 32 | """Configure python logger""" 33 | level = logging.INFO 34 | verbose = os.environ.get("VERBOSE", "") 35 | if verbose.lower() == "true": 36 | print("Will set the logging output to DEBUG") 37 | level = logging.DEBUG 38 | 39 | if len(logging.getLogger().handlers) > 0: 40 | # The Lambda environment pre-configures a handler logging to stderr. 41 | # If a handler is already configured, `.basicConfig` does not execute. 42 | # Thus we set the level directly. 
43 | logging.getLogger().setLevel(level) 44 | else: 45 | logging.basicConfig(level=level) 46 | 47 | 48 | def _check_missing_field(validation_dict, extraction_key): 49 | """Check if a field exists in a dictionary 50 | 51 | :param validation_dict: Dictionary 52 | :param extraction_key: String 53 | 54 | :raises: MalformedEvent 55 | """ 56 | extracted_value = validation_dict.get(extraction_key) 57 | 58 | if not extracted_value: 59 | LOGGER.error(f"Missing '{extraction_key}' field in the event") 60 | raise MalformedEvent 61 | 62 | 63 | def _validate_field(validation_dict, extraction_key, expected_value): 64 | """Validate the passed in field 65 | 66 | :param validation_dict: Dictionary 67 | :param extraction_key: String 68 | :param expected_value: String 69 | 70 | :raises: ValueError 71 | """ 72 | extracted_value = validation_dict.get(extraction_key) 73 | 74 | _check_missing_field(validation_dict, extraction_key) 75 | 76 | if extracted_value != expected_value: 77 | LOGGER.error(f"Incorrect value found for '{extraction_key}' field") 78 | raise ValueError 79 | 80 | 81 | def _record_validation(record, source_bucket_name): 82 | """Validate record 83 | 84 | :param record: Dictionary 85 | :param source_bucket_name: String 86 | 87 | :rtype: Boolean 88 | """ 89 | # validate eventSource 90 | _validate_field(record, "eventSource", "aws:s3") 91 | 92 | # validate eventSource 93 | _check_missing_field(record, "eventName") 94 | if not record["eventName"].startswith("ObjectCreated"): 95 | LOGGER.warning( 96 | "Found a non ObjectCreated event, ignoring this record") 97 | return False 98 | 99 | # check for 's3' in response elements 100 | _check_missing_field(record, "s3") 101 | 102 | s3_data = record["s3"] 103 | # validate s3 data 104 | _check_missing_field(s3_data, "bucket") 105 | _validate_field(s3_data["bucket"], "name", source_bucket_name) 106 | 107 | # check for object 108 | _check_missing_field(s3_data, "object") 109 | # check for key 110 | _check_missing_field(s3_data["object"], "key") 111 | 112 | return True 113 | 114 | 115 | def _send_message_to_sqs(client, queue_url, message_dict): 116 | """Send message to SQS Queue 117 | 118 | :param client: Boto3 client object (SQS) 119 | :param queue_url: String 120 | :param message_dict: Dictionary 121 | 122 | :raises: Exception 123 | """ 124 | LOGGER.info(f"Attempting to send message to: {queue_url}") 125 | resp = client.send_message( 126 | QueueUrl=queue_url, 127 | MessageBody=json.dumps(message_dict) 128 | ) 129 | 130 | _check_missing_field(resp, "ResponseMetadata") 131 | resp_metadata = resp["ResponseMetadata"] 132 | 133 | _check_missing_field(resp_metadata, "HTTPStatusCode") 134 | status_code = resp_metadata["HTTPStatusCode"] 135 | 136 | if status_code == 200: 137 | LOGGER.info("Successfully pushed message") 138 | else: 139 | raise Exception("Unable to push message") 140 | 141 | 142 | def lambda_handler(event, context): 143 | """What executes when the program is run""" 144 | 145 | # configure python logger 146 | _configure_logger() 147 | # silence chatty libraries 148 | _silence_noisy_loggers() 149 | 150 | # check for DESTINATION_BUCKET_NAME env var 151 | bucket_name = os.environ.get(BUCKET_ENV_VAR) 152 | if not bucket_name: 153 | raise MissingEnvironmentVariable(BUCKET_ENV_VAR) 154 | LOGGER.info(f"destination bucket: {bucket_name}") 155 | 156 | queue_url = os.environ.get(QUEUE_URL_ENV_VAR) 157 | if not queue_url: 158 | raise MissingEnvironmentVariable( 159 | f"{QUEUE_URL_ENV_VAR} environment variable is required") 160 | 161 | # check for 'records' field in 
the event 162 | _check_missing_field(event, "Records") 163 | records = event["Records"] 164 | 165 | if not isinstance(records, list): 166 | raise Exception("'Records' is not a list") 167 | LOGGER.info("Extracted 'Records' from the event") 168 | 169 | sqs_client = boto3.client("sqs") 170 | for record in records: 171 | try: 172 | valid_record = _record_validation(record, bucket_name) 173 | except KeyError: 174 | LOGGER.warning(MALFORMED_EVENT_LOG_MSG) 175 | continue 176 | except ValueError: 177 | LOGGER.warning(MALFORMED_EVENT_LOG_MSG) 178 | continue 179 | if not valid_record: 180 | LOGGER.warning( 181 | "record could not be validated. Skipping this one.") 182 | continue 183 | 184 | _send_message_to_sqs( 185 | sqs_client, 186 | queue_url, 187 | { 188 | "bucket": bucket_name, 189 | "file": record["s3"]["object"]["key"] 190 | } 191 | ) 192 | sqs_client.close() 193 | -------------------------------------------------------------------------------- /lambda/aoss-update/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lambda/python:3.11 2 | 3 | COPY requirements.txt requirements.txt 4 | 5 | RUN yum update -y && pip3 install -r requirements.txt 6 | 7 | COPY app.py app.py 8 | 9 | CMD [ "app.lambda_handler"] 10 | -------------------------------------------------------------------------------- /lambda/aoss-update/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | from botocore.exceptions import ClientError 7 | from langchain_community.document_loaders import S3FileLoader 8 | from langchain_community.vectorstores import OpenSearchVectorSearch 9 | from langchain_openai import OpenAIEmbeddings 10 | from opensearchpy import RequestsHttpConnection, AWSV4SignerAuth 11 | 12 | 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | SQS_QUEUE_ENV_VAR = "QUEUE_URL" 16 | API_KEY_SECRET_ENV_VAR = "API_KEY_SECRET_NAME" 17 | 18 | AOSS_INDEX_NAME_ENV_VAR = "AOSS_INDEX_NAME" 19 | AOSS_ID_ENV_VAR = "AOSS_ID" 20 | AOSS_AWS_REGION_ENV_VAR = "AOSS_AWS_REGION" 21 | AOSS_SVC_NAME = "aoss" 22 | 23 | DEFAULT_TIMEOUT_AOSS = 100 24 | DEFAULT_AOSS_ENGINE = "faiss" 25 | 26 | 27 | class MalformedEvent(Exception): 28 | """Raised if a malformed event received""" 29 | 30 | 31 | class MissingEnvironmentVariable(Exception): 32 | """Raised if a required environment variable is missing""" 33 | 34 | 35 | def _silence_noisy_loggers(): 36 | """Silence chatty libraries for better logging""" 37 | for logger in ['boto3', 'botocore', 38 | 'botocore.vendored.requests.packages.urllib3']: 39 | logging.getLogger(logger).setLevel(logging.WARNING) 40 | 41 | 42 | def _configure_logger(): 43 | """Configure python logger for lambda function""" 44 | default_log_args = { 45 | "level": logging.DEBUG if os.environ.get("VERBOSE", False) else logging.INFO, 46 | "format": "%(asctime)s [%(levelname)s] %(name)s - %(message)s", 47 | "datefmt": "%d-%b-%y %H:%M", 48 | "force": True, 49 | } 50 | logging.basicConfig(**default_log_args) 51 | 52 | 53 | def _check_missing_field(validation_dict, extraction_key): 54 | """Check if a field exists in a dictionary 55 | 56 | :param validation_dict: Dictionary 57 | :param extraction_key: String 58 | 59 | :raises: KeyError 60 | """ 61 | extracted_value = validation_dict.get(extraction_key) 62 | 63 | if not extracted_value: 64 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 65 | raise KeyError 66 | 67 | 68 | def _get_message_body(event): 69 | 
"""Extract message body from the event 70 | 71 | :param event: Dictionary 72 | 73 | :raises: MalformedEvent 74 | 75 | :rtype: Dictionary 76 | """ 77 | body = "" 78 | test_event = event.get("test_event", "") 79 | if test_event.lower() == "true": 80 | LOGGER.info("processing test event (and not from SQS)") 81 | LOGGER.debug("Test body: %s", event) 82 | return event 83 | else: 84 | LOGGER.info("Attempting to extract message body from SQS") 85 | 86 | _check_missing_field(event, "Records") 87 | records = event["Records"] 88 | 89 | first_record = records[0] 90 | 91 | try: 92 | body = first_record.get("body") 93 | except AttributeError: 94 | raise MalformedEvent("First record is not a proper dict") 95 | 96 | if not body: 97 | raise MalformedEvent("Missing 'body' in the record") 98 | 99 | try: 100 | return json.loads(body) 101 | except json.decoder.JSONDecodeError: 102 | raise MalformedEvent("'body' is not valid JSON") 103 | 104 | 105 | def _get_sqs_message_attributes(event): 106 | """Extract receiptHandle from message 107 | 108 | :param event: Dictionary 109 | 110 | :raises: MalformedEvent 111 | 112 | :rtype: Dictionary 113 | """ 114 | LOGGER.info("Attempting to extract receiptHandle from SQS") 115 | records = event.get("Records") 116 | if not records: 117 | LOGGER.warning("No receiptHandle found, probably not an SQS message") 118 | return 119 | try: 120 | first_record = records[0] 121 | except IndexError: 122 | raise MalformedEvent("Records seem to be empty") 123 | 124 | _check_missing_field(first_record, "receiptHandle") 125 | receipt_handle = first_record["receiptHandle"] 126 | 127 | _check_missing_field(first_record, "messageId") 128 | message_id = first_record["messageId"] 129 | 130 | return { 131 | "message_id": message_id, 132 | "receipt_handle": receipt_handle 133 | } 134 | 135 | 136 | def get_secret_from_name(secret_name, kv=True): 137 | """Return secret from secret name 138 | 139 | :param secret_name: String 140 | :param kv: Boolean (weather it is json or not) 141 | 142 | :raises: botocore.exceptions.ClientError 143 | 144 | :rtype: Dictionary 145 | """ 146 | session = boto3.session.Session() 147 | 148 | # Initializing Secret Manager's client 149 | client = session.client( 150 | service_name='secretsmanager', 151 | region_name=os.environ.get("AWS_REGION", session.region_name) 152 | ) 153 | LOGGER.info(f"Attempting to get secret value for: {secret_name}") 154 | try: 155 | get_secret_value_response = client.get_secret_value( 156 | SecretId=secret_name) 157 | except ClientError as e: 158 | # For a list of exceptions thrown, see 159 | # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html 160 | LOGGER.error("Unable to fetch details from Secrets Manager") 161 | raise e 162 | 163 | _check_missing_field( 164 | get_secret_value_response, "SecretString") 165 | 166 | if kv: 167 | return json.loads( 168 | get_secret_value_response["SecretString"]) 169 | else: 170 | return get_secret_value_response["SecretString"] 171 | 172 | 173 | def lambda_handler(event, context): 174 | """What executes when the program is run""" 175 | 176 | # configure python logger for Lambda 177 | _configure_logger() 178 | # silence chatty libraries for better logging 179 | _silence_noisy_loggers() 180 | 181 | msg_attr = _get_sqs_message_attributes(event) 182 | 183 | if msg_attr: 184 | 185 | # Because messages remain in the queue 186 | LOGGER.info( 187 | f"Deleting message {msg_attr['message_id']} from sqs") 188 | sqs_client = boto3.client("sqs") 189 | queue_url = 
os.environ.get(SQS_QUEUE_ENV_VAR) 190 | if not queue_url: 191 | raise MissingEnvironmentVariable( 192 | f"{SQS_QUEUE_ENV_VAR} environment variable is required") 193 | 194 | deletion_resp = sqs_client.delete_message( 195 | QueueUrl=queue_url, 196 | ReceiptHandle=msg_attr["receipt_handle"]) 197 | 198 | sqs_client.close() 199 | 200 | resp_metadata = deletion_resp.get("ResponseMetadata") 201 | if not resp_metadata: 202 | raise Exception( 203 | "No response metadata from deletion call") 204 | status_code = resp_metadata.get("HTTPStatusCode") 205 | 206 | if status_code == 200: 207 | LOGGER.info(f"Successfully deleted message") 208 | else: 209 | raise Exception("Unable to delete message") 210 | 211 | body = _get_message_body(event) 212 | 213 | _check_missing_field(body, "bucket") 214 | _check_missing_field(body, "file") 215 | 216 | aoss_id = os.environ.get(AOSS_ID_ENV_VAR) 217 | if not aoss_id: 218 | raise MissingEnvironmentVariable( 219 | f"{AOSS_ID_ENV_VAR} environment variable is required") 220 | 221 | aoss_region = os.environ.get(AOSS_AWS_REGION_ENV_VAR) 222 | if not aoss_region: 223 | raise MissingEnvironmentVariable( 224 | f"{AOSS_AWS_REGION_ENV_VAR} environment variable is required") 225 | 226 | index_name = os.environ.get(AOSS_INDEX_NAME_ENV_VAR) 227 | if not index_name: 228 | raise MissingEnvironmentVariable( 229 | f"{AOSS_INDEX_NAME_ENV_VAR} environment variable is required") 230 | 231 | openai_secret = os.environ.get(API_KEY_SECRET_ENV_VAR) 232 | if not openai_secret: 233 | raise MissingEnvironmentVariable( 234 | f"{API_KEY_SECRET_ENV_VAR} environment variable is required") 235 | os.environ["OPENAI_API_KEY"] = get_secret_from_name( 236 | openai_secret, kv=False) 237 | LOGGER.info("Fetching OpenAI embeddings") 238 | embeddings = OpenAIEmbeddings() 239 | 240 | LOGGER.info("Initializing S3FileLoader") 241 | loader = S3FileLoader(body['bucket'], body['file']) 242 | 243 | LOGGER.info( 244 | f"Loading document: {body['file']} from bucket: {body['bucket']}") 245 | docs = loader.load() 246 | 247 | LOGGER.info("Setting up auth for OpenSearch Serverless") 248 | auth = AWSV4SignerAuth( 249 | boto3.Session().get_credentials(), 250 | aoss_region, 251 | AOSS_SVC_NAME 252 | ) 253 | 254 | LOGGER.info("Adding new document to the vector store") 255 | docsearch = OpenSearchVectorSearch.from_documents( 256 | docs, 257 | embeddings, 258 | opensearch_url=f"{aoss_id}.{aoss_region}.{AOSS_SVC_NAME}.amazonaws.com:443", 259 | http_auth=auth, 260 | timeout=DEFAULT_TIMEOUT_AOSS, 261 | use_ssl=True, 262 | verify_certs=True, 263 | connection_class = RequestsHttpConnection, 264 | index_name=index_name, 265 | engine=DEFAULT_AOSS_ENGINE, 266 | ) 267 | -------------------------------------------------------------------------------- /lambda/aoss-update/requirements.txt: -------------------------------------------------------------------------------- 1 | openai 2 | langchain 3 | langchain-community 4 | langchain-openai 5 | opensearch-py 6 | unstructured 7 | # nltk 8 | -------------------------------------------------------------------------------- /lambda/app-client-create-trigger/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | 7 | 8 | LOGGER = logging.getLogger() 9 | 10 | QUEUE_URL_ENV_VAR = "TRIGGER_QUEUE" 11 | 12 | 13 | class MalformedEvent(Exception): 14 | """Raised if a malformed event received""" 15 | 16 | 17 | class MissingEnvironmentVariable(Exception): 18 | """Raised if a required environment 
variable is missing""" 19 | 20 | 21 | def _silence_noisy_loggers(): 22 | """Silence chatty libraries for better logging""" 23 | for logger in ['boto3', 'botocore', 24 | 'botocore.vendored.requests.packages.urllib3']: 25 | logging.getLogger(logger).setLevel(logging.WARNING) 26 | 27 | 28 | def _check_missing_field(validation_dict, extraction_key): 29 | """Check if a field exists in a dictionary 30 | 31 | :param validation_dict: Dictionary 32 | :param extraction_key: String 33 | 34 | :raises: MalformedEvent 35 | """ 36 | extracted_value = validation_dict.get(extraction_key) 37 | 38 | if not extracted_value: 39 | LOGGER.error(f"Missing '{extraction_key}' field in the event") 40 | raise MalformedEvent 41 | 42 | 43 | def _validate_field(validation_dict, extraction_key, expected_value): 44 | """Validate the passed in field 45 | 46 | :param validation_dict: Dictionary 47 | :param extraction_key: String 48 | :param expected_value: String 49 | 50 | :raises: ValueError 51 | """ 52 | extracted_value = validation_dict.get(extraction_key) 53 | 54 | _check_missing_field(validation_dict, extraction_key) 55 | 56 | if extracted_value != expected_value: 57 | LOGGER.error(f"Incorrect value found for '{extraction_key}' field") 58 | raise ValueError 59 | 60 | 61 | def _extract_valid_event(event): 62 | """Validate incoming event and extract necessary attributes 63 | 64 | :param event: Dictionary 65 | 66 | :raises: MalformedEvent 67 | :raises: ValueError 68 | 69 | :rtype: Dictionary 70 | """ 71 | 72 | _validate_field(event, "source", "aws.cognito-idp") 73 | 74 | _check_missing_field(event, "detail") 75 | event_detail = event["detail"] 76 | 77 | _validate_field( 78 | event_detail, 79 | "sourceIPAddress", 80 | "cloudformation.amazonaws.com" 81 | ) 82 | 83 | _validate_field( 84 | event_detail, 85 | "eventSource", 86 | "cognito-idp.amazonaws.com" 87 | ) 88 | _validate_field(event_detail, "eventName", "CreateUserPoolClient") 89 | 90 | _check_missing_field(event_detail, "responseElements") 91 | _check_missing_field(event_detail["responseElements"], "userPoolClient") 92 | 93 | return event_detail["responseElements"]["userPoolClient"] 94 | 95 | 96 | def _configure_logger(): 97 | """Configure python logger""" 98 | level = logging.INFO 99 | verbose = os.environ.get("VERBOSE", "") 100 | if verbose.lower() == "true": 101 | print("Will set the logging output to DEBUG") 102 | level = logging.DEBUG 103 | 104 | if len(logging.getLogger().handlers) > 0: 105 | # The Lambda environment pre-configures a handler logging to stderr. 106 | # If a handler is already configured, `.basicConfig` does not execute. 107 | # Thus we set the level directly. 
108 | logging.getLogger().setLevel(level) 109 | else: 110 | logging.basicConfig(level=level) 111 | 112 | 113 | def _send_message_to_sqs(client, queue_url, message_dict): 114 | """Send message to SQS Queue 115 | 116 | :param client: Boto3 client object (SQS) 117 | :param queue_url: String 118 | :param message_dict: Dictionary 119 | 120 | :raises: Exception 121 | """ 122 | LOGGER.info(f"Attempting to send message to: {queue_url}") 123 | resp = client.send_message( 124 | QueueUrl=queue_url, 125 | MessageBody=json.dumps(message_dict) 126 | ) 127 | 128 | _check_missing_field(resp, "ResponseMetadata") 129 | resp_metadata = resp["ResponseMetadata"] 130 | 131 | _check_missing_field(resp_metadata, "HTTPStatusCode") 132 | status_code = resp_metadata["HTTPStatusCode"] 133 | 134 | if status_code == 200: 135 | LOGGER.info("Successfully pushed message") 136 | else: 137 | raise Exception("Unable to push message") 138 | 139 | 140 | def lambda_handler(event, context): 141 | """What executes when the program is run""" 142 | 143 | # configure python logger 144 | _configure_logger() 145 | # silence chatty libraries 146 | _silence_noisy_loggers() 147 | 148 | client_details = _extract_valid_event(event) 149 | LOGGER.info("Extracted user pool details") 150 | 151 | sqs_client = boto3.client("sqs") 152 | rds_ddl_queue_url = os.environ.get(QUEUE_URL_ENV_VAR) 153 | if not rds_ddl_queue_url: 154 | raise MissingEnvironmentVariable( 155 | f"{QUEUE_URL_ENV_VAR} environment variable is required") 156 | 157 | # send message to Triggering Queue 158 | _send_message_to_sqs( 159 | sqs_client, 160 | rds_ddl_queue_url, 161 | client_details) 162 | 163 | sqs_client.close() 164 | -------------------------------------------------------------------------------- /lambda/call-back-url-init/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | 7 | 8 | LOGGER = logging.getLogger() 9 | 10 | USER_POOL_ENV_VAR = "USER_POOL_ID" 11 | APP_CLIENT_ENV_VAR = "APP_CLIENT_ID" 12 | ALB_DNS_ENV_VAR = "ALB_DNS_NAME" 13 | SQS_QUEUE_ENV_VAR = "SQS_QUEUE_URL" 14 | 15 | 16 | class MalformedEvent(Exception): 17 | """Raised if a malformed event received""" 18 | 19 | 20 | class MissingEnvironmentVariable(Exception): 21 | """Raised if a required environment variable is missing""" 22 | 23 | 24 | def _silence_noisy_loggers(): 25 | """Silence chatty libraries for better logging""" 26 | for logger in ['boto3', 'botocore', 27 | 'botocore.vendored.requests.packages.urllib3']: 28 | logging.getLogger(logger).setLevel(logging.WARNING) 29 | 30 | 31 | def _check_missing_field(validation_dict, extraction_key): 32 | """Check if a field exists in a dictionary 33 | 34 | :param validation_dict: Dictionary 35 | :param extraction_key: String 36 | 37 | :raises: MalformedEvent 38 | """ 39 | extracted_value = validation_dict.get(extraction_key) 40 | 41 | if not extracted_value: 42 | LOGGER.error(f"Missing '{extraction_key}' field in the event") 43 | raise MalformedEvent 44 | 45 | 46 | def _validate_field(validation_dict, extraction_key, expected_value): 47 | """Validate the passed in field 48 | 49 | :param validation_dict: Dictionary 50 | :param extraction_key: String 51 | :param expected_value: String 52 | 53 | :raises: ValueError 54 | """ 55 | extracted_value = validation_dict.get(extraction_key) 56 | 57 | _check_missing_field(validation_dict, extraction_key) 58 | 59 | if extracted_value != expected_value: 60 | LOGGER.error(f"Incorrect value found for '{extraction_key}' 
field") 61 | raise ValueError 62 | 63 | 64 | def _configure_logger(): 65 | """Configure python logger""" 66 | level = logging.INFO 67 | verbose = os.environ.get("VERBOSE", "") 68 | if verbose.lower() == "true": 69 | print("Will set the logging output to DEBUG") 70 | level = logging.DEBUG 71 | 72 | if len(logging.getLogger().handlers) > 0: 73 | # The Lambda environment pre-configures a handler logging to stderr. 74 | # If a handler is already configured, `.basicConfig` does not execute. 75 | # Thus we set the level directly. 76 | logging.getLogger().setLevel(level) 77 | else: 78 | logging.basicConfig(level=level) 79 | 80 | 81 | def _get_message_body(event): 82 | """Extract message body from the event 83 | 84 | :param event: Dictionary 85 | 86 | :raises: MalformedEvent 87 | 88 | :rtype: Dictionary 89 | """ 90 | body = "" 91 | test_event = event.get("test_event", "") 92 | if test_event.lower() == "true": 93 | LOGGER.info("processing test event (and not from SQS)") 94 | LOGGER.debug("Test body: %s", event) 95 | return event 96 | else: 97 | LOGGER.info("Attempting to extract message body from SQS") 98 | 99 | _check_missing_field(event, "Records") 100 | records = event["Records"] 101 | 102 | first_record = records[0] 103 | 104 | try: 105 | body = first_record.get("body") 106 | except AttributeError: 107 | raise MalformedEvent("First record is not a proper dict") 108 | 109 | if not body: 110 | raise MalformedEvent("Missing 'body' in the record") 111 | 112 | try: 113 | return json.loads(body) 114 | except json.decoder.JSONDecodeError: 115 | raise MalformedEvent("'body' is not valid JSON") 116 | 117 | 118 | def _get_sqs_message_attributes(event): 119 | """Extract receiptHandle from message 120 | 121 | :param event: Dictionary 122 | 123 | :raises: MalformedEvent 124 | 125 | :rtype: Dictionary 126 | """ 127 | LOGGER.info("Attempting to extract receiptHandle from SQS") 128 | records = event.get("Records") 129 | if not records: 130 | LOGGER.warning("No receiptHandle found, probably not an SQS message") 131 | return 132 | try: 133 | first_record = records[0] 134 | except IndexError: 135 | raise MalformedEvent("Records seem to be empty") 136 | 137 | _check_missing_field(first_record, "receiptHandle") 138 | receipt_handle = first_record["receiptHandle"] 139 | 140 | _check_missing_field(first_record, "messageId") 141 | message_id = first_record["messageId"] 142 | 143 | return { 144 | "message_id": message_id, 145 | "receipt_handle": receipt_handle 146 | } 147 | 148 | 149 | def lambda_handler(event, context): 150 | """What executes when the program is run""" 151 | 152 | # configure python logger 153 | _configure_logger() 154 | # silence chatty libraries 155 | _silence_noisy_loggers() 156 | 157 | msg_attr = _get_sqs_message_attributes(event) 158 | 159 | if msg_attr: 160 | 161 | # Because messages remain in the queue 162 | LOGGER.info( 163 | f"Deleting message {msg_attr['message_id']} from sqs") 164 | sqs_client = boto3.client("sqs") 165 | queue_url = os.environ.get(SQS_QUEUE_ENV_VAR) 166 | if not queue_url: 167 | raise MissingEnvironmentVariable( 168 | f"{SQS_QUEUE_ENV_VAR} environment variable is required") 169 | 170 | deletion_resp = sqs_client.delete_message( 171 | QueueUrl=queue_url, 172 | ReceiptHandle=msg_attr["receipt_handle"]) 173 | 174 | sqs_client.close() 175 | 176 | resp_metadata = deletion_resp.get("ResponseMetadata") 177 | if not resp_metadata: 178 | raise Exception( 179 | "No response metadata from deletion call") 180 | status_code = resp_metadata.get("HTTPStatusCode") 181 | 182 | if 
status_code == 200: 183 | LOGGER.info(f"Successfully deleted message") 184 | else: 185 | raise Exception("Unable to delete message") 186 | 187 | client_details = _get_message_body(event) 188 | LOGGER.info("Extracted user pool details") 189 | 190 | user_pool_id = os.environ.get(USER_POOL_ENV_VAR) 191 | if not user_pool_id: 192 | raise MissingEnvironmentVariable( 193 | f"{USER_POOL_ENV_VAR} environment variable is required") 194 | 195 | app_client_id = os.environ.get(APP_CLIENT_ENV_VAR) 196 | if not app_client_id: 197 | raise MissingEnvironmentVariable( 198 | f"{APP_CLIENT_ENV_VAR} environment variable is required") 199 | 200 | alb_dns = os.environ.get(ALB_DNS_ENV_VAR) 201 | if not alb_dns: 202 | raise MissingEnvironmentVariable( 203 | f"{ALB_DNS_ENV_VAR} environment variable is required") 204 | 205 | _validate_field(client_details, "userPoolId", user_pool_id) 206 | _validate_field(client_details, "clientId", app_client_id) 207 | 208 | expected_callback_url = f"https://{alb_dns}/oauth2/idpresponse" 209 | lowered_callback_url = expected_callback_url.lower() 210 | 211 | _check_missing_field(client_details, "callbackURLs") 212 | callback_urls = client_details["callbackURLs"] 213 | if len(callback_urls) != 1: 214 | LOGGER.warning("Unexpected number of callback URLs") 215 | else: 216 | if callback_urls[0] != expected_callback_url: 217 | LOGGER.warning( 218 | "Looks like the callback URL is not " 219 | "associated with the correct load balancer. Please verify.") 220 | 221 | cog_client = boto3.client("cognito-idp") 222 | 223 | LOGGER.info("Updating the user pool client URL") 224 | resp = cog_client.update_user_pool_client( 225 | UserPoolId=user_pool_id, 226 | ClientId=app_client_id, 227 | ExplicitAuthFlows=client_details["explicitAuthFlows"], 228 | SupportedIdentityProviders=client_details["supportedIdentityProviders"], 229 | CallbackURLs=[lowered_callback_url], 230 | AllowedOAuthFlows=client_details["allowedOAuthFlows"], 231 | AllowedOAuthScopes=client_details["allowedOAuthScopes"], 232 | AllowedOAuthFlowsUserPoolClient=client_details["allowedOAuthFlowsUserPoolClient"], 233 | EnableTokenRevocation=client_details["enableTokenRevocation"], 234 | EnablePropagateAdditionalUserContextData=client_details["enablePropagateAdditionalUserContextData"], 235 | AuthSessionValidity=client_details["authSessionValidity"] 236 | ) 237 | _check_missing_field(resp, "ResponseMetadata") 238 | resp_metadata = resp["ResponseMetadata"] 239 | 240 | _check_missing_field(resp_metadata, "HTTPStatusCode") 241 | status_code = resp_metadata["HTTPStatusCode"] 242 | 243 | if status_code == 200: 244 | LOGGER.info("Successfully updated callback URL") 245 | else: 246 | raise Exception("Unable to update user pool client") 247 | 248 | cog_client.close() 249 | -------------------------------------------------------------------------------- /lambda/call-back-url-update/app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import boto3 5 | 6 | 7 | LOGGER = logging.getLogger() 8 | 9 | USER_POOL_ENV_VAR = "USER_POOL_ID" 10 | APP_CLIENT_ENV_VAR = "APP_CLIENT_ID" 11 | ALB_DNS_ENV_VAR = "ALB_DNS_NAME" 12 | 13 | 14 | class MalformedEvent(Exception): 15 | """Raised if a malformed event received""" 16 | 17 | 18 | class MissingEnvironmentVariable(Exception): 19 | """Raised if a required environment variable is missing""" 20 | 21 | 22 | def _silence_noisy_loggers(): 23 | """Silence chatty libraries for better logging""" 24 | for logger in ['boto3', 'botocore', 25 | 
'botocore.vendored.requests.packages.urllib3']: 26 | logging.getLogger(logger).setLevel(logging.WARNING) 27 | 28 | 29 | def _check_missing_field(validation_dict, extraction_key): 30 | """Check if a field exists in a dictionary 31 | 32 | :param validation_dict: Dictionary 33 | :param extraction_key: String 34 | 35 | :raises: MalformedEvent 36 | """ 37 | extracted_value = validation_dict.get(extraction_key) 38 | 39 | if not extracted_value: 40 | LOGGER.error(f"Missing '{extraction_key}' field in the event") 41 | raise MalformedEvent 42 | 43 | 44 | def _validate_field(validation_dict, extraction_key, expected_value): 45 | """Validate the passed in field 46 | 47 | :param validation_dict: Dictionary 48 | :param extraction_key: String 49 | :param expected_value: String 50 | 51 | :raises: ValueError 52 | """ 53 | extracted_value = validation_dict.get(extraction_key) 54 | 55 | _check_missing_field(validation_dict, extraction_key) 56 | 57 | if extracted_value != expected_value: 58 | LOGGER.error(f"Incorrect value found for '{extraction_key}' field") 59 | raise ValueError 60 | 61 | 62 | def _extract_valid_event(event): 63 | """Validate incoming event and extract necessary attributes 64 | 65 | :param event: Dictionary 66 | 67 | :raises: MalformedEvent 68 | :raises: ValueError 69 | 70 | :rtype: Dictionary 71 | """ 72 | 73 | _validate_field(event, "source", "aws.cognito-idp") 74 | 75 | _check_missing_field(event, "detail") 76 | event_detail = event["detail"] 77 | 78 | _validate_field( 79 | event_detail, 80 | "sourceIPAddress", 81 | "cloudformation.amazonaws.com" 82 | ) 83 | 84 | _validate_field( 85 | event_detail, 86 | "eventSource", 87 | "cognito-idp.amazonaws.com" 88 | ) 89 | 90 | _validate_field(event_detail, "eventName", "UpdateUserPoolClient") 91 | 92 | _check_missing_field(event_detail, "responseElements") 93 | _check_missing_field(event_detail["responseElements"], "userPoolClient") 94 | 95 | return event_detail["responseElements"]["userPoolClient"] 96 | 97 | 98 | def _configure_logger(): 99 | """Configure python logger""" 100 | level = logging.INFO 101 | verbose = os.environ.get("VERBOSE", "") 102 | if verbose.lower() == "true": 103 | print("Will set the logging output to DEBUG") 104 | level = logging.DEBUG 105 | 106 | if len(logging.getLogger().handlers) > 0: 107 | # The Lambda environment pre-configures a handler logging to stderr. 108 | # If a handler is already configured, `.basicConfig` does not execute. 109 | # Thus we set the level directly. 
110 | logging.getLogger().setLevel(level) 111 | else: 112 | logging.basicConfig(level=level) 113 | 114 | 115 | def lambda_handler(event, context): 116 | """What executes when the program is run""" 117 | 118 | # configure python logger 119 | _configure_logger() 120 | # silence chatty libraries 121 | _silence_noisy_loggers() 122 | 123 | user_pool_id = os.environ.get(USER_POOL_ENV_VAR) 124 | if not user_pool_id: 125 | raise MissingEnvironmentVariable( 126 | f"{USER_POOL_ENV_VAR} environment variable is required") 127 | 128 | app_client_id = os.environ.get(APP_CLIENT_ENV_VAR) 129 | if not app_client_id: 130 | raise MissingEnvironmentVariable( 131 | f"{APP_CLIENT_ENV_VAR} environment variable is required") 132 | 133 | alb_dns = os.environ.get(ALB_DNS_ENV_VAR) 134 | if not alb_dns: 135 | raise MissingEnvironmentVariable( 136 | f"{ALB_DNS_ENV_VAR} environment variable is required") 137 | 138 | client_details = _extract_valid_event(event) 139 | LOGGER.info("Extracted user pool details") 140 | 141 | expected_callback_url = f"https://{alb_dns}/oauth2/idpresponse" 142 | lowered_callback_url = expected_callback_url.lower() 143 | 144 | _validate_field(client_details, "userPoolId", user_pool_id) 145 | _validate_field(client_details, "clientId", app_client_id) 146 | 147 | _check_missing_field(client_details, "callbackURLs") 148 | callback_urls = client_details["callbackURLs"] 149 | if len(callback_urls) != 1: 150 | LOGGER.warning("Unexpected number of callback URLs") 151 | else: 152 | if callback_urls[0] != expected_callback_url: 153 | LOGGER.warning( 154 | "Looks like the callback URL is not " 155 | "associated with the correct load balancer. Please verify.") 156 | 157 | cog_client = boto3.client("cognito-idp") 158 | 159 | LOGGER.info("Updating the user pool client URL") 160 | resp = cog_client.update_user_pool_client( 161 | UserPoolId=user_pool_id, 162 | ClientId=app_client_id, 163 | ExplicitAuthFlows=client_details["explicitAuthFlows"], 164 | SupportedIdentityProviders=client_details["supportedIdentityProviders"], 165 | CallbackURLs=[lowered_callback_url], 166 | AllowedOAuthFlows=client_details["allowedOAuthFlows"], 167 | AllowedOAuthScopes=client_details["allowedOAuthScopes"], 168 | AllowedOAuthFlowsUserPoolClient=client_details["allowedOAuthFlowsUserPoolClient"], 169 | EnableTokenRevocation=client_details["enableTokenRevocation"], 170 | EnablePropagateAdditionalUserContextData=client_details["enablePropagateAdditionalUserContextData"], 171 | AuthSessionValidity=client_details["authSessionValidity"] 172 | ) 173 | _check_missing_field(resp, "ResponseMetadata") 174 | resp_metadata = resp["ResponseMetadata"] 175 | 176 | _check_missing_field(resp_metadata, "HTTPStatusCode") 177 | status_code = resp_metadata["HTTPStatusCode"] 178 | 179 | if status_code == 200: 180 | LOGGER.info("Successfully updated callback URL") 181 | else: 182 | raise Exception("Unable to update user pool client") 183 | 184 | cog_client.close() 185 | -------------------------------------------------------------------------------- /lambda/pdf-processor/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lambda/python:3.11 2 | 3 | COPY requirements.txt requirements.txt 4 | 5 | RUN pip3 install -r requirements.txt 6 | 7 | COPY lambda_function.py lambda_function.py 8 | 9 | CMD [ "lambda_function.lambda_handler"] 10 | -------------------------------------------------------------------------------- /lambda/pdf-processor/lambda_function.py: 
-------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import logging 3 | import os 4 | 5 | import boto3 6 | from pypdf import PdfReader 7 | from pypdf.errors import PdfReadError 8 | 9 | 10 | LOGGER = logging.getLogger(__name__) 11 | 12 | SOURCE_BUCKET_ENV_VAR = "SOURCE_BUCKET_NAME" 13 | DEST_BUCKET_ENV_VAR = "DESTINATION_BUCKET_NAME" 14 | 15 | MALFORMED_EVENT_LOG_MSG = "Malformed event. Skipping this record." 16 | 17 | 18 | class MissingEnvironmentVariable(Exception): 19 | """Raised if a required environment variable is missing""" 20 | 21 | 22 | def _silence_noisy_loggers(): 23 | """Silence chatty libraries for better logging""" 24 | for logger in ['boto3', 'botocore', 25 | 'botocore.vendored.requests.packages.urllib3']: 26 | logging.getLogger(logger).setLevel(logging.WARNING) 27 | 28 | 29 | def _configure_logger(): 30 | """Configure python logger for lambda function""" 31 | default_log_args = { 32 | "level": logging.DEBUG if os.environ.get("VERBOSE", False) else logging.INFO, 33 | "format": "%(asctime)s [%(levelname)s] %(name)s - %(message)s", 34 | "datefmt": "%d-%b-%y %H:%M", 35 | "force": True, 36 | } 37 | logging.basicConfig(**default_log_args) 38 | 39 | 40 | def _check_missing_field(validation_dict, extraction_key): 41 | """Check if a field exists in a dictionary 42 | 43 | :param validation_dict: Dictionary 44 | :param extraction_key: String 45 | 46 | :raises: KeyError 47 | """ 48 | extracted_value = validation_dict.get(extraction_key) 49 | 50 | if not extracted_value: 51 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 52 | raise KeyError 53 | 54 | 55 | def _validate_field(validation_dict, extraction_key, expected_value): 56 | """Validate the passed in field 57 | 58 | :param validation_dict: Dictionary 59 | :param extraction_key: String 60 | :param expected_value: String 61 | 62 | :raises: ValueError 63 | """ 64 | extracted_value = validation_dict.get(extraction_key) 65 | _check_missing_field(validation_dict, extraction_key) 66 | 67 | if extracted_value != expected_value: 68 | LOGGER.error(f"Incorrect value found for '{extraction_key}' key") 69 | raise ValueError 70 | 71 | 72 | def _record_validation(record, source_bucket_name): 73 | """Validate record 74 | 75 | :param record: Dictionary 76 | :param source_bucket_name: String 77 | 78 | :rtype: Boolean 79 | """ 80 | # validate eventSource 81 | _validate_field(record, "eventSource", "aws:s3") 82 | 83 | # validate eventSource 84 | _check_missing_field(record, "eventName") 85 | if not record["eventName"].startswith("ObjectCreated"): 86 | LOGGER.warning("Found a non ObjectCreated event, ignoring this record") 87 | return False 88 | 89 | # check for 's3' in response elements 90 | _check_missing_field(record, "s3") 91 | 92 | s3_data = record["s3"] 93 | # validate s3 data 94 | _check_missing_field(s3_data, "bucket") 95 | _validate_field(s3_data["bucket"], "name", source_bucket_name) 96 | 97 | # check for object 98 | _check_missing_field(s3_data, "object") 99 | # check for key 100 | _check_missing_field(s3_data["object"], "key") 101 | 102 | return True 103 | 104 | 105 | def _get_source_file_contents(client, bucket, filename): 106 | """Fetch the contents of the file 107 | 108 | :param client: boto3 Client Object (S3) 109 | :param bucket: String 110 | :param filename: String 111 | 112 | :rtype Bytes 113 | """ 114 | resp = client.get_object(Bucket=bucket, Key=filename) 115 | 116 | _check_missing_field(resp, "ResponseMetadata") 117 | 118 | 
_validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200) 119 | 120 | _check_missing_field(resp, "Body") 121 | 122 | return resp["Body"].read() 123 | 124 | 125 | def lambda_handler(event, context): 126 | """What executes when the program is run""" 127 | 128 | # configure python logger for Lambda 129 | _configure_logger() 130 | # silence chatty libraries for better logging 131 | _silence_noisy_loggers() 132 | 133 | # check for SOURCE_BUCKET_NAME env var 134 | source_bucket_name = os.environ.get(SOURCE_BUCKET_ENV_VAR) 135 | if not source_bucket_name: 136 | raise MissingEnvironmentVariable(SOURCE_BUCKET_ENV_VAR) 137 | LOGGER.info(f"source bucket: {source_bucket_name}") 138 | 139 | # check for DESTINATION_BUCKET_NAME env var 140 | dest_bucket_name = os.environ.get(DEST_BUCKET_ENV_VAR) 141 | if not dest_bucket_name: 142 | raise MissingEnvironmentVariable(DEST_BUCKET_ENV_VAR) 143 | LOGGER.info(f"destination bucket: {dest_bucket_name}") 144 | 145 | # check for 'records' field in the event 146 | _check_missing_field(event, "Records") 147 | records = event["Records"] 148 | 149 | if not isinstance(records, list): 150 | raise Exception("'Records' is not a list") 151 | LOGGER.info("Extracted 'Records' from the event") 152 | 153 | for record in records: 154 | try: 155 | valid_record = _record_validation(record, source_bucket_name) 156 | except KeyError: 157 | LOGGER.warning(MALFORMED_EVENT_LOG_MSG) 158 | continue 159 | except ValueError: 160 | LOGGER.warning(MALFORMED_EVENT_LOG_MSG) 161 | continue 162 | 163 | if not valid_record: 164 | LOGGER.warning("record could not be validated. Skipping this one.") 165 | continue 166 | file_name = record["s3"]["object"]["key"] 167 | LOGGER.info( 168 | f"Valid record found. Will attempt to process the file: {file_name}") 169 | 170 | s3_client = boto3.client("s3") 171 | file_contents = _get_source_file_contents( 172 | s3_client, source_bucket_name, file_name) 173 | 174 | try: 175 | pdf = PdfReader(BytesIO(file_contents)) 176 | except PdfReadError as err: 177 | LOGGER.error(err) 178 | LOGGER.warning( 179 | f"{file_name} is invalid and/or corrupt, skipping.") 180 | continue 181 | 182 | LOGGER.info("Extracting text from pdf..") 183 | text_file_contents = "" 184 | for page in pdf.pages: 185 | text_file_contents = f"{text_file_contents}\n{page.extract_text()}" 186 | 187 | s3_resource = boto3.resource("s3") 188 | LOGGER.debug("Writing file to S3") 189 | s3_resource.Bucket(dest_bucket_name).put_object( 190 | Key=file_name.replace(".pdf", ".txt"), 191 | Body=text_file_contents.encode("utf-8")) 192 | LOGGER.info("Successfully converted pdf to txt, and uploaded to s3") 193 | 194 | LOGGER.debug("Closing s3 boto3 client") 195 | s3_client.close() 196 | -------------------------------------------------------------------------------- /lambda/pdf-processor/requirements.txt: -------------------------------------------------------------------------------- 1 | pypdf 2 | -------------------------------------------------------------------------------- /lib/aoss-update-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as lambda from "aws-cdk-lib/aws-lambda"; 4 | import * as s3 from "aws-cdk-lib/aws-s3"; 5 | import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; 6 | import * as iam from 'aws-cdk-lib/aws-iam'; 7 | import * as sqs from 'aws-cdk-lib/aws-sqs'; 8 | import { SqsEventSource } from 'aws-cdk-lib/aws-lambda-event-sources'; 9 | 
import path = require("path"); 10 | 11 | 12 | export interface OpenSearchUpdateStackProps extends cdk.StackProps { 13 | processedBucket: s3.Bucket; 14 | indexName: string; 15 | apiKeySecret: secretsmanager.Secret; 16 | triggerQueue: sqs.Queue; 17 | aossHost: string; 18 | lambdaRole: iam.Role; 19 | } 20 | 21 | export class OpenSearchUpdateStack extends cdk.Stack { 22 | 23 | constructor(scope: Construct, id: string, props: OpenSearchUpdateStackProps) { 24 | super(scope, id, props); 25 | 26 | // capturing architecture for docker container (arm or x86) 27 | const dockerPlatform = process.env["DOCKER_CONTAINER_PLATFORM_ARCH"] 28 | 29 | // Docker assets for lambda function 30 | const dockerfile = path.join(__dirname, "../lambda/aoss-update/"); 31 | 32 | // create a Lambda function to update the vector store everytime a new document is added to the processed bucket 33 | const aossUpdateFn = new lambda.Function(this, "aossUpdate", { 34 | code: lambda.Code.fromAssetImage(dockerfile), 35 | handler: lambda.Handler.FROM_IMAGE, 36 | runtime: lambda.Runtime.FROM_IMAGE, 37 | timeout: cdk.Duration.minutes(3), 38 | role: props.lambdaRole, 39 | memorySize: 512, 40 | architecture: dockerPlatform == "arm" ? lambda.Architecture.ARM_64 : lambda.Architecture.X86_64, 41 | environment: { 42 | "API_KEY_SECRET_NAME": props.apiKeySecret.secretName, 43 | "AOSS_ID": props.aossHost, 44 | "AOSS_INDEX_NAME": props.indexName, 45 | "QUEUE_URL": props.triggerQueue.queueUrl, 46 | "AOSS_AWS_REGION": `${this.region}`, 47 | // S3FileLoader (LangChain) under the hood 48 | "NLTK_DATA": "/tmp" 49 | } 50 | }); 51 | // grant lambda function permissions to read processed bucket 52 | props.processedBucket.grantRead(aossUpdateFn); 53 | // grant lambda function permissions to ready the api key secret 54 | props.apiKeySecret.grantRead(aossUpdateFn); 55 | // create SQS event source 56 | const eventSource = new SqsEventSource(props.triggerQueue); 57 | // trigger Lambda function upon message in SQS queue 58 | aossUpdateFn.addEventSource(eventSource); 59 | 60 | } 61 | } -------------------------------------------------------------------------------- /lib/base-infra-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 4 | import * as s3 from "aws-cdk-lib/aws-s3"; 5 | import * as iam from "aws-cdk-lib/aws-iam"; 6 | import * as s3deploy from 'aws-cdk-lib/aws-s3-deployment'; 7 | import * as lambda from 'aws-cdk-lib/aws-lambda'; 8 | import * as s3notif from 'aws-cdk-lib/aws-s3-notifications'; 9 | import * as sqs from 'aws-cdk-lib/aws-sqs'; 10 | import * as events from 'aws-cdk-lib/aws-events'; 11 | import * as targets from 'aws-cdk-lib/aws-events-targets'; 12 | import * as cognito from 'aws-cdk-lib/aws-cognito'; 13 | import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; 14 | import * as elbv2 from "aws-cdk-lib/aws-elasticloadbalancingv2"; 15 | import * as elbv2_actions from "aws-cdk-lib/aws-elasticloadbalancingv2-actions"; 16 | import { SqsEventSource } from 'aws-cdk-lib/aws-lambda-event-sources'; 17 | 18 | import path = require("path"); 19 | 20 | export class BaseInfraStack extends cdk.Stack { 21 | readonly vpc: ec2.Vpc; 22 | readonly lambdaSG: ec2.SecurityGroup; 23 | readonly ecsTaskSecGroup: ec2.SecurityGroup; 24 | readonly knowledgeBaseBucket: s3.Bucket; 25 | readonly processedBucket: s3.Bucket; 26 | readonly aossQueue: sqs.Queue; 27 | readonly apiKeySecret: 
secretsmanager.Secret; 28 | readonly appTargetGroup: elbv2.ApplicationTargetGroup; 29 | readonly ec2SecGroup: ec2.SecurityGroup; 30 | readonly aossUpdateLambdaRole: iam.Role; 31 | readonly aossIndexName: string; 32 | readonly ecsTaskRole: iam.Role; 33 | 34 | constructor(scope: Construct, id: string, props?: cdk.StackProps) { 35 | super(scope, id, props); 36 | 37 | /* 38 | capturing region env var to know which region to deploy this infrastructure 39 | 40 | NOTE - the AWS profile that is used to deploy should have the same default region 41 | */ 42 | let validRegions: string[] = ['us-east-1', 'us-west-2']; 43 | const regionPrefix = process.env.CDK_DEFAULT_REGION || 'us-east-1'; 44 | console.log(`CDK_DEFAULT_REGION: ${regionPrefix}`); 45 | // throw error if unsupported CDK_DEFAULT_REGION specified 46 | if (!(validRegions.includes(regionPrefix))) { 47 | throw new Error('Unsupported CDK_DEFAULT_REGION specified') 48 | }; 49 | 50 | const indexName = process.env.AOSS_INDEX_NAME || 'rag-oai-index'; 51 | console.log(`AOSS_INDEX_NAME: ${indexName}`); 52 | this.aossIndexName = indexName; 53 | 54 | // create VPC to deploy the infrastructure in 55 | const vpc = new ec2.Vpc(this, "InfraNetwork", { 56 | ipAddresses: ec2.IpAddresses.cidr('10.80.0.0/20'), 57 | availabilityZones: [`${regionPrefix}a`, `${regionPrefix}b`, `${regionPrefix}c`], 58 | subnetConfiguration: [ 59 | { 60 | name: "public", 61 | subnetType: ec2.SubnetType.PUBLIC, 62 | }, 63 | { 64 | name: "private", 65 | subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS, 66 | } 67 | ], 68 | }); 69 | this.vpc = vpc; 70 | 71 | // create bucket for knowledgeBase 72 | const docsBucket = new s3.Bucket(this, `knowledgeBase`, {}); 73 | this.knowledgeBaseBucket = docsBucket; 74 | // use s3 bucket deploy to upload documents from local repo to the knowledgebase bucket 75 | new s3deploy.BucketDeployment(this, 'knowledgeBaseBucketDeploy', { 76 | sources: [s3deploy.Source.asset(path.join(__dirname, "../knowledgebase"))], 77 | destinationBucket: docsBucket 78 | }); 79 | 80 | // create bucket for processed text (from PDF to txt) 81 | const processedTextBucket = new s3.Bucket(this, `processedText`, {}); 82 | this.processedBucket = processedTextBucket; 83 | 84 | // capturing architecture for docker container (arm or x86) 85 | const dockerPlatform = process.env["DOCKER_CONTAINER_PLATFORM_ARCH"] 86 | 87 | // Docker assets for lambda function 88 | const dockerfile = path.join(__dirname, "../lambda/pdf-processor/"); 89 | // create a Lambda function to process knowledgebase pdf documents 90 | const lambdaFn = new lambda.Function(this, "pdfProcessorFn", { 91 | code: lambda.Code.fromAssetImage(dockerfile), 92 | handler: lambda.Handler.FROM_IMAGE, 93 | runtime: lambda.Runtime.FROM_IMAGE, 94 | timeout: cdk.Duration.minutes(15), 95 | memorySize: 512, 96 | architecture: dockerPlatform == "arm" ? 
lambda.Architecture.ARM_64 : lambda.Architecture.X86_64, 97 | environment: { 98 | "SOURCE_BUCKET_NAME": docsBucket.bucketName, 99 | "DESTINATION_BUCKET_NAME": processedTextBucket.bucketName 100 | } 101 | }); 102 | // grant lambda function permissions to read knowledgebase bucket 103 | docsBucket.grantRead(lambdaFn); 104 | // grant lambda function permissions to write to the processed text bucket 105 | processedTextBucket.grantWrite(lambdaFn); 106 | 107 | // create a new S3 notification that triggers the pdf processor lambda function 108 | const kbNotification = new s3notif.LambdaDestination(lambdaFn); 109 | // assign notification for the s3 event type 110 | docsBucket.addEventNotification(s3.EventType.OBJECT_CREATED, kbNotification); 111 | 112 | 113 | // Create security group for test ec2 instance (will be removed later) 114 | const ec2SecGroupName = "ec2-security-group"; 115 | const ec2SecurityGroup = new ec2.SecurityGroup(this, ec2SecGroupName, { 116 | securityGroupName: ec2SecGroupName, 117 | vpc: vpc, 118 | // for internet access 119 | allowAllOutbound: true 120 | }); 121 | this.ec2SecGroup = ec2SecurityGroup; 122 | 123 | // to store the API KEY for OpenAI embeddings 124 | const oaiSecret = 'openAiApiKey'; 125 | const openAiApiKey = new secretsmanager.Secret(this, oaiSecret, { 126 | secretName: oaiSecret 127 | }); 128 | this.apiKeySecret = openAiApiKey; 129 | 130 | // Queue for triggering opensearch update 131 | const aossUpdateQueue = new sqs.Queue(this, 'aossUpdateQueue', { 132 | queueName: "AOSS_Update_Queue", 133 | visibilityTimeout: cdk.Duration.minutes(5) 134 | }); 135 | this.aossQueue = aossUpdateQueue; 136 | 137 | // create a Lambda function to send message to SQS for vector store updates 138 | const aossTriggerFn = new lambda.Function(this, "aossTrigger", { 139 | code: lambda.Code.fromAsset(path.join(__dirname, "../lambda/aoss-trigger")), 140 | runtime: lambda.Runtime.PYTHON_3_11, 141 | handler: "app.lambda_handler", 142 | timeout: cdk.Duration.minutes(2), 143 | environment: { 144 | "AOSS_UPDATE_QUEUE": aossUpdateQueue.queueUrl, 145 | "BUCKET_NAME": processedTextBucket.bucketName 146 | } 147 | }); 148 | // create a new S3 notification that triggers the opensearch trigger lambda function 149 | const processedBucketNotif = new s3notif.LambdaDestination(aossTriggerFn); 150 | // assign notification for the s3 event type 151 | processedTextBucket.addEventNotification(s3.EventType.OBJECT_CREATED, processedBucketNotif); 152 | // give permission to the function to be able to send messages to the queues 153 | aossUpdateQueue.grantSendMessages(aossTriggerFn); 154 | 155 | // lambda basic execution policy statement 156 | const lambdaBasicExecPolicy = new iam.PolicyStatement({ 157 | effect: iam.Effect.ALLOW, 158 | actions: [ 159 | "logs:CreateLogGroup", 160 | "logs:CreateLogStream", 161 | "logs:PutLogEvents" 162 | ], 163 | resources: ["*"], 164 | }); 165 | 166 | // AOSS API Access 167 | const aossAPIAccess = new iam.PolicyStatement({ 168 | effect: iam.Effect.ALLOW, 169 | actions: [ 170 | "aoss:APIAccessAll" 171 | ], 172 | resources: [`arn:aws:aoss:${this.region}:${this.account}:collection/*`], 173 | }); 174 | 175 | // role for aoss update lambda function 176 | const aossUpdateRole = new iam.Role(this, 'aossUpdateRole', { 177 | assumedBy: new iam.CompositePrincipal( 178 | new iam.ServicePrincipal('lambda.amazonaws.com'), 179 | ), 180 | }); 181 | aossUpdateRole.attachInlinePolicy( 182 | new iam.Policy(this, "basicExecutionLambda", { 183 | statements: [lambdaBasicExecPolicy] 184 | }) 185 | ); 
185 | );
186 | aossUpdateRole.attachInlinePolicy( 187 | new iam.Policy(this, "aossAPIAccess", { 188 | statements: [aossAPIAccess] 189 | }) 190 | ); 191 | this.aossUpdateLambdaRole = aossUpdateRole; 192 | 193 | // This IAM Role is used by tasks 194 | const ragTaskRole = new iam.Role(this, "RagTaskRole", { 195 | assumedBy: new iam.ServicePrincipal("ecs-tasks.amazonaws.com"), 196 | inlinePolicies: { 197 | aossAccessPolicy: new iam.PolicyDocument({ 198 | statements: [ 199 | new iam.PolicyStatement({ 200 | effect: iam.Effect.ALLOW, 201 | resources: [`arn:aws:aoss:${this.region}:${this.account}:collection/*`], 202 | actions: [ 203 | "aoss:APIAccessAll" 204 | ], 205 | }), 206 | ], 207 | }), 208 | bedrockPolicy: new iam.PolicyDocument({ 209 | statements: [ 210 | new iam.PolicyStatement({ 211 | effect: iam.Effect.ALLOW, 212 | resources: ["*"], 213 | actions: [ 214 | "bedrock:InvokeModel", 215 | ], 216 | }), 217 | ], 218 | }), 219 | }, 220 | }); 221 | this.ecsTaskRole = ragTaskRole; 222 | // grant permissions to ready the api key secret 223 | openAiApiKey.grantRead(ragTaskRole); 224 | 225 | 226 | // Security group for ECS tasks 227 | const ragAppSecGroup = new ec2.SecurityGroup(this, "ragAppSecGroup", { 228 | securityGroupName: "ecs-rag-sec-group", 229 | vpc: vpc, 230 | allowAllOutbound: true, 231 | }); 232 | // ragAppSecGroup.addIngressRule( 233 | // ec2.Peer.ipv4("0.0.0.0/0"), 234 | // ec2.Port.tcpRange(8500, 8600), 235 | // "Streamlit" 236 | // ); 237 | this.ecsTaskSecGroup = ragAppSecGroup; 238 | 239 | // Security group for ALB 240 | const albSecGroup = new ec2.SecurityGroup(this, "albSecGroup", { 241 | securityGroupName: "alb-sec-group", 242 | vpc: vpc, 243 | allowAllOutbound: true, 244 | }); 245 | 246 | // create load balancer 247 | const appLoadBalancer = new elbv2.ApplicationLoadBalancer(this, 'ragAppLb', { 248 | vpc: vpc, 249 | internetFacing: true, 250 | securityGroup: albSecGroup 251 | }); 252 | 253 | const certName = process.env.IAM_SELF_SIGNED_SERVER_CERT_NAME || ''; 254 | // throw error if IAM_SELF_SIGNED_SERVER_CERT_NAME is undefined 255 | if (certName === undefined || certName === '') { 256 | throw new Error('Please specify the "IAM_SELF_SIGNED_SERVER_CERT_NAME" env var') 257 | }; 258 | console.log(`self signed cert name: ${certName}`); 259 | 260 | const cognitoDomain = process.env.COGNITO_DOMAIN_NAME || 'rag-cog-aoss-dom'; 261 | console.log(`cognito domain name: ${cognitoDomain}`); 262 | 263 | // // create Target group for ECS service 264 | const ecsTargetGroup = new elbv2.ApplicationTargetGroup(this, 'default', { 265 | vpc: vpc, 266 | protocol: elbv2.ApplicationProtocol.HTTP, 267 | port: 8501 268 | }); 269 | this.appTargetGroup = ecsTargetGroup; 270 | 271 | // // Queue for triggering app client creation 272 | const appClientCreationQueue = new sqs.Queue(this, 'appClientCreateQueue', { 273 | queueName: "COG_APP_CLIENT_CREATE_QUEUE", 274 | visibilityTimeout: cdk.Duration.minutes(5) 275 | }); 276 | 277 | // // create a Lambda function to send message to SQS for vector store updates 278 | const appClientCreateTriggerFn = new lambda.Function(this, "appClientCreateTrigger", { 279 | code: lambda.Code.fromAsset(path.join(__dirname, "../lambda/app-client-create-trigger")), 280 | runtime: lambda.Runtime.PYTHON_3_11, 281 | handler: "app.lambda_handler", 282 | timeout: cdk.Duration.minutes(2), 283 | environment: { 284 | "TRIGGER_QUEUE": appClientCreationQueue.queueUrl, 285 | } 286 | }); 287 | // give permission to the function to be able to send messages to the queues 288 | 
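// Note on the grant below: queue.grantSendMessages() attaches an inline policy
// (sqs:SendMessage, sqs:GetQueueAttributes and sqs:GetQueueUrl on this queue's ARN)
// to the function's execution role. A hand-rolled equivalent would look roughly like
// the following sketch; the grant call is what this stack actually relies on:
//
//   appClientCreateTriggerFn.addToRolePolicy(new iam.PolicyStatement({
//     effect: iam.Effect.ALLOW,
//     actions: ["sqs:SendMessage", "sqs:GetQueueAttributes", "sqs:GetQueueUrl"],
//     resources: [appClientCreationQueue.queueArn],
//   }));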
appClientCreationQueue.grantSendMessages(appClientCreateTriggerFn); 289 | 290 | // Trigger an event when there is a Cognito CreateUserPoolClient call recorded in CloudTrail 291 | const appClientCreateRule = new events.Rule(this, 'appClientCreateRule', { 292 | eventPattern: { 293 | source: ["aws.cognito-idp"], 294 | detail: { 295 | eventSource: ["cognito-idp.amazonaws.com"], 296 | eventName: ["CreateUserPoolClient"], 297 | sourceIPAddress: ["cloudformation.amazonaws.com"] 298 | } 299 | }, 300 | }); 301 | appClientCreateRule.node.addDependency(appClientCreationQueue); 302 | // Invoke the callBack update fn upon a matching event 303 | appClientCreateRule.addTarget(new targets.LambdaFunction(appClientCreateTriggerFn)); 304 | 305 | // create cognito user pool 306 | const userPool = new cognito.UserPool(this, "UserPool", { 307 | removalPolicy: cdk.RemovalPolicy.DESTROY, 308 | selfSignUpEnabled: true, 309 | signInAliases: { email: true}, 310 | autoVerify: { email: true } 311 | }); 312 | userPool.node.addDependency(appClientCreateRule); 313 | 314 | // create cognito user pool domain 315 | const userPoolDomain = new cognito.UserPoolDomain(this, 'upDomain', { 316 | userPool, 317 | cognitoDomain: { 318 | domainPrefix: cognitoDomain 319 | } 320 | }); 321 | 322 | // create and add Application Integration for the User Pool 323 | const client = userPool.addClient("WebClient", { 324 | userPoolClientName: "MyAppWebClient", 325 | idTokenValidity: cdk.Duration.days(1), 326 | accessTokenValidity: cdk.Duration.days(1), 327 | generateSecret: true, 328 | authFlows: { 329 | adminUserPassword: true, 330 | userPassword: true, 331 | userSrp: true 332 | }, 333 | oAuth: { 334 | flows: {authorizationCodeGrant: true}, 335 | scopes: [cognito.OAuthScope.OPENID], 336 | callbackUrls: [ `https://${appLoadBalancer.loadBalancerDnsName}/oauth2/idpresponse` ] 337 | }, 338 | supportedIdentityProviders: [cognito.UserPoolClientIdentityProvider.COGNITO] 339 | }); 340 | client.node.addDependency(appClientCreateRule); 341 | 342 | // add https listener to the load balancer 343 | const httpsListener = appLoadBalancer.addListener("httpsListener", { 344 | port: 443, 345 | open: true, 346 | certificates: [ 347 | { 348 | certificateArn: `arn:aws:iam::${this.account}:server-certificate/${certName}` 349 | }, 350 | ], 351 | defaultAction: new elbv2_actions.AuthenticateCognitoAction({ 352 | userPool: userPool, 353 | userPoolClient: client, 354 | userPoolDomain: userPoolDomain, 355 | next: elbv2.ListenerAction.forward([ecsTargetGroup]) 356 | }) 357 | }); 358 | /* 359 | 360 | create lambda function because ALB dns name is not lowercase, 361 | and cognito does not function as intended due to that 362 | 363 | Reference - https://github.com/aws/aws-cdk/issues/11171 364 | 365 | */ 366 | const callBackInitFn = new lambda.Function(this, "callBackInit", { 367 | code: lambda.Code.fromAsset(path.join(__dirname, "../lambda/call-back-url-init")), 368 | runtime: lambda.Runtime.PYTHON_3_11, 369 | timeout: cdk.Duration.minutes(2), 370 | handler: "app.lambda_handler", 371 | environment:{ 372 | "USER_POOL_ID": userPool.userPoolId, 373 | "APP_CLIENT_ID": client.userPoolClientId, 374 | "ALB_DNS_NAME": appLoadBalancer.loadBalancerDnsName, 375 | "SQS_QUEUE_URL": appClientCreationQueue.queueUrl, 376 | }, 377 | }); 378 | callBackInitFn.role?.addManagedPolicy( 379 | iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonCognitoPowerUser") 380 | ); 381 | // create SQS event source 382 | const appClientCreateSqsEventSource = new SqsEventSource(appClientCreationQueue); 383 | 
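// The SqsEventSource created above uses the construct defaults. If tighter batching
// or partial-batch failure handling were ever needed, it accepts options along these
// lines (illustrative only, not applied in this stack; reportBatchItemFailures also
// requires the handler to return a batchItemFailures list):
//
//   new SqsEventSource(appClientCreationQueue, {
//     batchSize: 1,
//     reportBatchItemFailures: true,
//   });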
// trigger Lambda function upon message in SQS queue 384 | callBackInitFn.addEventSource(appClientCreateSqsEventSource); 385 | 386 | const callBackUpdateFn = new lambda.Function(this, "callBackUpdate", { 387 | code: lambda.Code.fromAsset(path.join(__dirname, "../lambda/call-back-url-update")), 388 | runtime: lambda.Runtime.PYTHON_3_11, 389 | timeout: cdk.Duration.minutes(2), 390 | handler: "app.lambda_handler", 391 | environment:{ 392 | "USER_POOL_ID": userPool.userPoolId, 393 | "APP_CLIENT_ID": client.userPoolClientId, 394 | "ALB_DNS_NAME": appLoadBalancer.loadBalancerDnsName 395 | }, 396 | }); 397 | callBackUpdateFn.role?.addManagedPolicy( 398 | iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonCognitoPowerUser") 399 | ); 400 | 401 | // Trigger an event when there is a Cognito CreateUserPoolClient call recorded in CloudTrail 402 | const appClientUpdateRule = new events.Rule(this, 'appClientUpdateRule', { 403 | eventPattern: { 404 | source: ["aws.cognito-idp"], 405 | detail: { 406 | eventSource: ["cognito-idp.amazonaws.com"], 407 | eventName: ["UpdateUserPoolClient"], 408 | sourceIPAddress: ["cloudformation.amazonaws.com"] 409 | } 410 | }, 411 | }); 412 | // Invoke the callBack update fn upon a matching event 413 | appClientUpdateRule.addTarget(new targets.LambdaFunction(callBackUpdateFn)); 414 | 415 | } 416 | } 417 | -------------------------------------------------------------------------------- /lib/opensearch-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 4 | import * as aoss from 'aws-cdk-lib/aws-opensearchserverless'; 5 | import * as iam from 'aws-cdk-lib/aws-iam'; 6 | 7 | 8 | export interface OpenSearchStackProps extends cdk.StackProps { 9 | testComputeHostRole: iam.Role; 10 | lambdaRole: iam.Role; 11 | ecsTaskRole: iam.Role; 12 | } 13 | 14 | export class OpenSearchStack extends cdk.Stack { 15 | readonly jumpHostSG: ec2.SecurityGroup; 16 | readonly collectionName: string; 17 | readonly serverlessCollection: aoss.CfnCollection; 18 | 19 | constructor(scope: Construct, id: string, props: OpenSearchStackProps) { 20 | super(scope, id, props); 21 | 22 | this.collectionName = process.env.OPENSEARCH_COLLECTION_NAME || 'rag-collection'; 23 | console.log(`Opensearch serverless collection name: ${this.collectionName}`); 24 | 25 | const networkSecurityPolicy = new aoss.CfnSecurityPolicy(this, 'aossNetworkSecPolicy', { 26 | policy: JSON.stringify([{ 27 | "Rules": [ 28 | { 29 | "Resource": [ 30 | `collection/${this.collectionName}` 31 | ], 32 | "ResourceType": "dashboard" 33 | }, 34 | { 35 | "Resource": [ 36 | `collection/${this.collectionName}` 37 | ], 38 | "ResourceType": "collection" 39 | } 40 | ], 41 | "AllowFromPublic": true 42 | }]), 43 | name: `${this.collectionName}-sec-policy`, 44 | type: "network" 45 | }); 46 | 47 | const encryptionSecPolicy = new aoss.CfnSecurityPolicy(this, 'aossEncryptionSecPolicy', { 48 | name: `${this.collectionName}-enc-sec-pol`, 49 | type: "encryption", 50 | policy: JSON.stringify({ 51 | "Rules": [ 52 | { 53 | "Resource": [ 54 | `collection/${this.collectionName}` 55 | ], 56 | "ResourceType": "collection" 57 | } 58 | ], 59 | "AWSOwnedKey": true 60 | }), 61 | }); 62 | 63 | const aossCollecton = new aoss.CfnCollection(this, 'serverlessCollectionRag', { 64 | name: this.collectionName, 65 | description: "Collection to power RAG searches", 66 | type: "VECTORSEARCH" 67 | }); 68 | 
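// The L1 CfnCollection above exposes its runtime attributes (attrId, attrArn,
// attrCollectionEndpoint, attrDashboardEndpoint). If the collection endpoint ever
// needed to be surfaced for debugging, a stack output could be added along these
// lines (sketch only; the stack does not currently emit one):
//
//   new cdk.CfnOutput(this, 'collectionEndpoint', {
//     value: aossCollecton.attrCollectionEndpoint,
//   });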
this.serverlessCollection = aossCollecton; 69 | aossCollecton.addDependency(networkSecurityPolicy); 70 | aossCollecton.addDependency(encryptionSecPolicy); 71 | 72 | const dataAccessPolicy = new aoss.CfnAccessPolicy(this, 'dataAccessPolicy', { 73 | name: `${this.collectionName}-dap`, 74 | description: `Data access policy for: ${this.collectionName}`, 75 | type: "data", 76 | policy: JSON.stringify([ 77 | { 78 | "Rules": [ 79 | { 80 | "Resource": [ 81 | `collection/${this.collectionName}` 82 | ], 83 | "Permission": [ 84 | "aoss:CreateCollectionItems", 85 | "aoss:DeleteCollectionItems", 86 | "aoss:UpdateCollectionItems", 87 | "aoss:DescribeCollectionItems" 88 | ], 89 | "ResourceType": "collection" 90 | }, 91 | { 92 | "Resource": [ 93 | `index/${this.collectionName}/*` 94 | ], 95 | "Permission": [ 96 | "aoss:CreateIndex", 97 | "aoss:DeleteIndex", 98 | "aoss:UpdateIndex", 99 | "aoss:DescribeIndex", 100 | "aoss:ReadDocument", 101 | "aoss:WriteDocument" 102 | ], 103 | "ResourceType": "index" 104 | } 105 | ], 106 | "Principal": [ 107 | props.testComputeHostRole.roleArn, 108 | `arn:aws:iam::${this.account}:role/Admin`, 109 | props.lambdaRole.roleArn, 110 | props.ecsTaskRole.roleArn 111 | ], 112 | "Description": "data-access-rule" 113 | } 114 | ]), 115 | }); 116 | 117 | 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /lib/rag-app-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from "aws-cdk-lib/aws-ec2"; 4 | import * as ecs from "aws-cdk-lib/aws-ecs"; 5 | import * as iam from "aws-cdk-lib/aws-iam"; 6 | import * as logs from "aws-cdk-lib/aws-logs"; 7 | import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; 8 | import * as ecr_assets from "aws-cdk-lib/aws-ecr-assets"; 9 | import * as elbv2 from "aws-cdk-lib/aws-elasticloadbalancingv2"; 10 | import path = require("path"); 11 | 12 | export interface RagAppStackProps extends cdk.StackProps { 13 | vpc: ec2.Vpc; 14 | indexName: string; 15 | apiKeySecret: secretsmanager.Secret; 16 | aossHost: string; 17 | taskSecGroup: ec2.SecurityGroup; 18 | elbTargetGroup: elbv2.ApplicationTargetGroup; 19 | taskRole: iam.Role; 20 | } 21 | 22 | export class RagAppStack extends cdk.Stack { 23 | 24 | constructor(scope: Construct, id: string, props: RagAppStackProps) { 25 | super(scope, id, props); 26 | 27 | // This is the ECS cluster that we use for running tasks at. 
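// The cluster below enables Container Insights and a default ECS Exec logging
// configuration. Opening an interactive shell into a running task would additionally
// require enableExecuteCommand on the Fargate service defined later in this stack
// (plus ssmmessages permissions on the task role), roughly:
//
//   const ragAppService = new ecs.FargateService(this, "rag-app-service", {
//     cluster,
//     taskDefinition: ragTaskDef,
//     enableExecuteCommand: true,  // not set in the current definition
//     // ...remaining props unchanged
//   });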
28 | const cluster = new ecs.Cluster(this, "ecsClusterRAG", { 29 | vpc: props.vpc, 30 | containerInsights: true, 31 | executeCommandConfiguration: { 32 | logging: ecs.ExecuteCommandLogging.DEFAULT, 33 | }, 34 | }); 35 | 36 | // This IAM Role is used by tasks 37 | const taskRole = new iam.Role(this, "TaskRole", { 38 | assumedBy: new iam.ServicePrincipal("ecs-tasks.amazonaws.com"), 39 | inlinePolicies: { 40 | taskRolePolicy: new iam.PolicyDocument({ 41 | statements: [ 42 | new iam.PolicyStatement({ 43 | effect: iam.Effect.ALLOW, 44 | resources: ["*"], 45 | actions: [ 46 | "cloudwatch:PutMetricData", 47 | "logs:CreateLogGroup", 48 | "logs:CreateLogStream", 49 | "logs:PutLogEvents", 50 | "logs:DescribeLogStreams", 51 | ], 52 | }), 53 | ], 54 | }), 55 | }, 56 | }); 57 | 58 | // // This IAM Role is used by tasks 59 | // const ragTaskRole = new iam.Role(this, "RagTaskRole", { 60 | // assumedBy: new iam.ServicePrincipal("ecs-tasks.amazonaws.com"), 61 | // inlinePolicies: { 62 | // aossAccessPolicy: new iam.PolicyDocument({ 63 | // statements: [ 64 | // new iam.PolicyStatement({ 65 | // effect: iam.Effect.ALLOW, 66 | // resources: [`arn:aws:aoss:${this.region}:${this.account}:collection/*`], 67 | // actions: [ 68 | // "aoss:APIAccessAll" 69 | // ], 70 | // }), 71 | // ], 72 | // }), 73 | // bedrockPolicy: new iam.PolicyDocument({ 74 | // statements: [ 75 | // new iam.PolicyStatement({ 76 | // effect: iam.Effect.ALLOW, 77 | // resources: ["*"], 78 | // actions: [ 79 | // "bedrock:InvokeModel", 80 | // ], 81 | // }), 82 | // ], 83 | // }), 84 | // }, 85 | // }); 86 | // // grant permissions to ready the api key secret 87 | // props.apiKeySecret.grantRead(ragTaskRole); 88 | 89 | // This IAM role is used to execute the tasks. It is used by task definition. 90 | const taskExecRole = new iam.Role(this, "TaskExecRole", { 91 | assumedBy: new iam.ServicePrincipal("ecs-tasks.amazonaws.com"), 92 | managedPolicies: [ 93 | iam.ManagedPolicy.fromAwsManagedPolicyName( 94 | "service-role/AmazonECSTaskExecutionRolePolicy" 95 | ), 96 | ], 97 | }); 98 | 99 | // We create Log Group in CloudWatch to follow task logs 100 | const taskLogGroup = new logs.LogGroup(this, "TaskLogGroup", { 101 | logGroupName: "/ragapp/", 102 | removalPolicy: cdk.RemovalPolicy.DESTROY, 103 | retention: logs.RetentionDays.THREE_DAYS, 104 | }); 105 | 106 | // We create a log driver for ecs 107 | const ragTaskLogDriver = new ecs.AwsLogDriver({ 108 | streamPrefix: "rag-app", 109 | logGroup: taskLogGroup, 110 | }); 111 | 112 | const dockerPlatform = process.env["DOCKER_CONTAINER_PLATFORM_ARCH"] 113 | 114 | // We create the task definition. Task definition is used to create tasks by ECS. 115 | const ragTaskDef = new ecs.FargateTaskDefinition(this, "RagTaskDef", { 116 | family: "rag-app", 117 | memoryLimitMiB: 512, 118 | cpu: 256, 119 | taskRole: props.taskRole, 120 | executionRole: taskExecRole, 121 | runtimePlatform: { 122 | operatingSystemFamily: ecs.OperatingSystemFamily.LINUX, 123 | cpuArchitecture: dockerPlatform == "arm" ? ecs.CpuArchitecture.ARM64 : ecs.CpuArchitecture.X86_64 124 | } 125 | }); 126 | 127 | // We create a container image to be run by the tasks. 128 | const ragContainerImage = new ecs.AssetImage( path.join(__dirname, '../rag-app'), { 129 | platform: dockerPlatform == "arm" ? ecr_assets.Platform.LINUX_ARM64 : ecr_assets.Platform.LINUX_AMD64 130 | }); 131 | const containerName = "ragAppOpenSearch"; 132 | // We add this container image to our task definition that we created earlier. 
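// The container added below listens on port 8501 (Streamlit), matching the ALB target
// group created in the base infra stack. addContainer() also accepts a container-level
// health check; one could be added to the options below, for example (sketch only;
// assumes curl is available in the image and that this Streamlit version serves
// /_stcore/health):
//
//   healthCheck: {
//     command: ["CMD-SHELL", "curl -f http://localhost:8501/_stcore/health || exit 1"],
//     interval: cdk.Duration.seconds(30),
//     retries: 3,
//   },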
133 | const ragContainer = ragTaskDef.addContainer("rag-container", { 134 | containerName: containerName, 135 | image: ragContainerImage, 136 | logging: ragTaskLogDriver, 137 | environment: { 138 | "AWS_REGION": `${this.region}`, 139 | "AOSS_INDEX_NAME": props.indexName, 140 | "AOSS_ID": props.aossHost, 141 | "API_KEY_SECRET_NAME": props.apiKeySecret.secretName, 142 | }, 143 | portMappings: [ 144 | { 145 | containerPort: 8501, 146 | hostPort: 8501, 147 | protocol: ecs.Protocol.TCP 148 | }, 149 | ] 150 | }); 151 | 152 | // define ECS fargate service to run the RAG app 153 | const ragAppService = new ecs.FargateService(this, "rag-app-service", { 154 | cluster, 155 | taskDefinition: ragTaskDef, 156 | desiredCount: 1,//vpc.availabilityZones.length, 157 | securityGroups: [props.taskSecGroup], 158 | minHealthyPercent: 0, 159 | }); 160 | // add fargate service as a target to the target group 161 | props.elbTargetGroup.addTarget(ragAppService); 162 | 163 | } 164 | } -------------------------------------------------------------------------------- /lib/test-compute-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 4 | import * as iam from 'aws-cdk-lib/aws-iam'; 5 | 6 | 7 | export interface TestComputeStackProps extends cdk.StackProps { 8 | vpc: ec2.Vpc; 9 | ec2SG: ec2.SecurityGroup; 10 | } 11 | 12 | 13 | export class TestComputeStack extends cdk.Stack { 14 | readonly jumpHostSG: ec2.SecurityGroup; 15 | readonly hostRole: iam.Role; 16 | 17 | constructor(scope: Construct, id: string, props: TestComputeStackProps) { 18 | super(scope, id, props); 19 | 20 | const userData = ec2.UserData.forLinux(); 21 | userData.addCommands( 22 | 'apt-get update -y', 23 | 'apt-get install -y git awscli ec2-instance-connect', 24 | 'apt install -y fish' 25 | ); 26 | 27 | const machineImage = ec2.MachineImage.fromSsmParameter( 28 | '/aws/service/canonical/ubuntu/server/focal/stable/current/amd64/hvm/ebs-gp2/ami-id', 29 | ); 30 | 31 | const jumpHostRole = new iam.Role(this, 'jumpHostRole', { 32 | assumedBy: new iam.CompositePrincipal( 33 | new iam.ServicePrincipal('ec2.amazonaws.com'), 34 | ), 35 | managedPolicies: [ 36 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSSMManagedInstanceCore'), 37 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonS3FullAccess'), 38 | iam.ManagedPolicy.fromAwsManagedPolicyName('AdministratorAccess'), 39 | ] 40 | }); 41 | this.hostRole = jumpHostRole; 42 | 43 | const instanceProf = new iam.CfnInstanceProfile(this, 'jumpHostInstanceProf', { 44 | roles: [jumpHostRole.roleName] 45 | }); 46 | 47 | // // to test locally (streamlit) 48 | // jumpHostSecurityGroup.addIngressRule( 49 | // ec2.Peer.anyIpv4(), 50 | // ec2.Port.tcp(8501), 51 | // 'Streamlit default port' 52 | // ); 53 | // this.jumpHostSG = jumpHostSecurityGroup; 54 | 55 | const ec2JumpHost = new ec2.Instance(this, 'ec2JumpHost', { 56 | vpc: props.vpc, 57 | instanceType: ec2.InstanceType.of(ec2.InstanceClass.T2, ec2.InstanceSize.MICRO), 58 | machineImage: machineImage, 59 | securityGroup: props.ec2SG, 60 | userData: userData, 61 | role: jumpHostRole, 62 | requireImdsv2: true, 63 | // for public access testing 64 | vpcSubnets: {subnetType: ec2.SubnetType.PUBLIC}, 65 | // for public access testing 66 | associatePublicIpAddress: true, 67 | blockDevices: [ 68 | { 69 | deviceName: '/dev/sda1', 70 | mappingEnabled: true, 71 | volume: ec2.BlockDeviceVolume.ebs(128, { 72 | 
deleteOnTermination: true, 73 | encrypted: true, 74 | volumeType: ec2.EbsDeviceVolumeType.GP2 75 | }) 76 | } 77 | ] 78 | }); 79 | 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "rag-with-amazon-bedrock-and-opensearch", 3 | "version": "0.1.0", 4 | "bin": { 5 | "rag-with-amazon-bedrock-and-opensearch": "bin/rag-with-amazon-bedrock-and-opensearch.js" 6 | }, 7 | "scripts": { 8 | "build": "tsc", 9 | "watch": "tsc -w", 10 | "test": "jest", 11 | "cdk": "cdk" 12 | }, 13 | "devDependencies": { 14 | "@types/jest": "^29.5.11", 15 | "@types/node": "20.10.8", 16 | "jest": "^29.7.0", 17 | "ts-jest": "^29.1.1", 18 | "aws-cdk": "2.121.1", 19 | "ts-node": "^10.9.2", 20 | "typescript": "~5.3.3" 21 | }, 22 | "dependencies": { 23 | "aws-cdk-lib": "2.121.1", 24 | "constructs": "^10.0.0", 25 | "source-map-support": "^0.5.21" 26 | } 27 | } -------------------------------------------------------------------------------- /rag-app/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.2 2 | 3 | WORKDIR python-docker 4 | 5 | COPY requirements.txt requirements.txt 6 | 7 | RUN pip3 install -r requirements.txt && mkdir /root/.streamlit 8 | 9 | COPY . . 10 | 11 | CMD [ "./run_app.sh" ] 12 | -------------------------------------------------------------------------------- /rag-app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/rag-app/__init__.py -------------------------------------------------------------------------------- /rag-app/aoss_chat_bedrock.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | import boto3 6 | from langchain.chains import ConversationalRetrievalChain 7 | from langchain.embeddings.openai import OpenAIEmbeddings 8 | from langchain.llms.bedrock import Bedrock 9 | from langchain.prompts import PromptTemplate 10 | from langchain_community.vectorstores import OpenSearchVectorSearch 11 | from opensearchpy import RequestsHttpConnection, AWSV4SignerAuth 12 | 13 | import helper_functions as hfn 14 | 15 | 16 | class MissingEnvironmentVariable(Exception): 17 | """Raised if a required environment variable is missing""" 18 | 19 | 20 | class bcolors: 21 | HEADER = '\033[95m' 22 | OKBLUE = '\033[94m' 23 | OKCYAN = '\033[96m' 24 | OKGREEN = '\033[92m' 25 | WARNING = '\033[93m' 26 | FAIL = '\033[91m' 27 | ENDC = '\033[0m' 28 | BOLD = '\033[1m' 29 | UNDERLINE = '\033[4m' 30 | 31 | 32 | MAX_HISTORY_LENGTH = 5 33 | 34 | AOSS_INDEX_NAME_ENV_VAR = "AOSS_INDEX_NAME" 35 | AOSS_ID_ENV_VAR = "AOSS_ID" 36 | # AOSS_AWS_REGION_ENV_VAR = "AOSS_AWS_REGION" 37 | AOSS_SVC_NAME = "aoss" 38 | 39 | DEFAULT_TIMEOUT_AOSS = 100 40 | # DEFAULT_AOSS_ENGINE = "faiss" # may not be needed 41 | 42 | API_KEY_SECRET_ENV_VAR = "API_KEY_SECRET_NAME" 43 | 44 | DEFAULT_LOG_LEVEL = logging.INFO 45 | LOGGER = logging.getLogger(__name__) 46 | LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \ 47 | "[%(name)s]:[%(threadName)s] " \ 48 | "%(message)s" 49 | 50 | 51 | def build_chain(host, index_name): 52 | """Build conversational retrieval chain 53 | 54 | :param host: String 55 | :param index_name: String 56 | 57 | :rtype: ConversationalRetrievalChain 58 | """ 59 | region = 
os.environ["AWS_REGION"] 60 | 61 | llm = Bedrock( 62 | # credentials_profile_name=credentials_profile_name, 63 | region_name = region, 64 | model_kwargs={"max_tokens_to_sample":300,"temperature":1,"top_k":250,"top_p":0.999,"anthropic_version":"bedrock-2023-05-31"}, 65 | model_id=os.environ.get("FOUNDATION_MODEL_ID", "anthropic.claude-instant-v1") 66 | ) 67 | 68 | embeddings = OpenAIEmbeddings() 69 | docsearch = OpenSearchVectorSearch( 70 | f"https://{host}", 71 | index_name, 72 | embeddings, 73 | http_auth=AWSV4SignerAuth(boto3.Session().get_credentials(), region, AOSS_SVC_NAME), 74 | timeout=DEFAULT_TIMEOUT_AOSS, 75 | use_ssl=True, 76 | verify_certs=True, 77 | connection_class = RequestsHttpConnection, 78 | ) 79 | retriever = docsearch.as_retriever(search_kwargs={"k": 3}) 80 | # the "k" needs to be revisited 81 | 82 | prompt_template = """Human: This is a friendly conversation between a human and an AI. 83 | The AI is talkative and provides specific details from its context but limits it to 240 tokens. 84 | If the AI does not know the answer to a question, it truthfully says it 85 | does not know. 86 | 87 | Assistant: OK, got it, I'll be a talkative truthful AI assistant. 88 | 89 | Human: Here are a few documents in tags: 90 | 91 | {context} 92 | 93 | Based on the above documents, provide a detailed answer for, {question} 94 | Answer "don't know" if not present in the document. 95 | 96 | Assistant: 97 | """ 98 | PROMPT = PromptTemplate( 99 | template=prompt_template, input_variables=["context", "question"] 100 | ) 101 | 102 | condense_qa_template = """{chat_history} 103 | Human: 104 | Given the previous conversation and a follow up question below, rephrase the follow up question 105 | to be a standalone question. 106 | 107 | Follow Up Question: {question} 108 | Standalone Question: 109 | 110 | Assistant:""" 111 | standalone_question_prompt = PromptTemplate.from_template(condense_qa_template) 112 | 113 | return ConversationalRetrievalChain.from_llm( 114 | llm=llm, 115 | retriever=retriever, 116 | condense_question_prompt=standalone_question_prompt, 117 | return_source_documents=True, 118 | combine_docs_chain_kwargs={"prompt": PROMPT}, 119 | verbose=True) 120 | 121 | 122 | def run_chain(chain, prompt: str, history=[]): 123 | return chain({"question": prompt, "chat_history": history}) 124 | 125 | 126 | if __name__ == "__main__": 127 | 128 | # logging configuration 129 | log_level = DEFAULT_LOG_LEVEL 130 | if os.environ.get("VERBOSE", "").lower() == "true": 131 | log_level = logging.DEBUG 132 | logging.basicConfig(level=log_level, format=LOGGING_FORMAT) 133 | 134 | # open ai api key fetch 135 | openai_secret = os.environ.get(API_KEY_SECRET_ENV_VAR) 136 | if not openai_secret: 137 | raise MissingEnvironmentVariable(f"{API_KEY_SECRET_ENV_VAR} environment variable is required") 138 | os.environ["OPENAI_API_KEY"] = hfn.get_secret_from_name(openai_secret, kv=False) 139 | 140 | # serverless collection ID 141 | aoss_id = os.environ.get(AOSS_ID_ENV_VAR) 142 | if not aoss_id: 143 | raise MissingEnvironmentVariable( 144 | f"{AOSS_ID_ENV_VAR} environment variable is required") 145 | 146 | # opensearch index for RAG 147 | index_name = os.environ.get(AOSS_INDEX_NAME_ENV_VAR) 148 | if not index_name: 149 | raise MissingEnvironmentVariable( 150 | f"{AOSS_INDEX_NAME_ENV_VAR} environment variable is required") 151 | 152 | LOGGER.info("starting conversational retrieval chain now..") 153 | 154 | # langchain stuff 155 | chat_history = [] 156 | qa = build_chain( 157 | 
f"{aoss_id}.{os.environ['AWS_REGION']}.{AOSS_SVC_NAME}.amazonaws.com:443", 158 | index_name 159 | ) 160 | 161 | print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC) 162 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) 163 | print(">", end=" ", flush=True) 164 | 165 | for query in sys.stdin: 166 | if (query.strip().lower().startswith("new search:")): 167 | query = query.strip().lower().replace("new search:","") 168 | chat_history = [] 169 | elif (len(chat_history) == MAX_HISTORY_LENGTH): 170 | chat_history.pop(0) 171 | 172 | result = run_chain(qa, query, chat_history) 173 | 174 | chat_history.append((query, result["answer"])) 175 | 176 | print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC) 177 | if 'source_documents' in result: 178 | print(bcolors.OKGREEN + 'Sources:') 179 | for d in result['source_documents']: 180 | print(d.metadata['source']) 181 | print(bcolors.ENDC) 182 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) 183 | print(">", end=" ", flush=True) 184 | 185 | print(bcolors.OKBLUE + "Bye" + bcolors.ENDC) 186 | -------------------------------------------------------------------------------- /rag-app/app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import uuid 5 | 6 | import streamlit as st 7 | 8 | import aoss_chat_bedrock as bedrock_claude 9 | 10 | 11 | # TODO: clean up the way this app is written 12 | USER_ICON = "images/user-icon.png" 13 | AI_ICON = "images/ai-icon.png" 14 | MAX_HISTORY_LENGTH = 5 15 | 16 | API_KEY_SECRET_ENV_VAR = "API_KEY_SECRET_NAME" 17 | 18 | AOSS_INDEX_NAME_ENV_VAR = "AOSS_INDEX_NAME" 19 | AOSS_ID_ENV_VAR = "AOSS_ID" 20 | AOSS_AWS_REGION_ENV_VAR = "AOSS_AWS_REGION" 21 | AOSS_SVC_NAME = "aoss" 22 | 23 | DEFAULT_LOG_LEVEL = logging.INFO 24 | LOGGER = logging.getLogger(__name__) 25 | LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \ 26 | "[%(name)s]:[%(threadName)s] " \ 27 | "%(message)s" 28 | 29 | 30 | class MissingEnvironmentVariable(Exception): 31 | """Raised if a required environment variable is missing""" 32 | 33 | 34 | # logging configuration 35 | log_level = DEFAULT_LOG_LEVEL 36 | if os.environ.get("VERBOSE", "").lower() == "true": 37 | log_level = logging.DEBUG 38 | logging.basicConfig(level=log_level, format=LOGGING_FORMAT) 39 | 40 | os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"] 41 | 42 | # serverless collection ID 43 | aoss_id = os.environ.get(AOSS_ID_ENV_VAR) 44 | if not aoss_id: 45 | raise MissingEnvironmentVariable( 46 | f"{AOSS_ID_ENV_VAR} environment variable is required") 47 | 48 | # opensearch index for RAG 49 | index_name = os.environ.get(AOSS_INDEX_NAME_ENV_VAR) 50 | if not index_name: 51 | raise MissingEnvironmentVariable( 52 | f"{AOSS_INDEX_NAME_ENV_VAR} environment variable is required") 53 | 54 | 55 | #function to read a properties file and create environment variables 56 | def read_properties_file(filename): 57 | import os 58 | import re 59 | with open(filename, 'r') as f: 60 | for line in f: 61 | m = re.match(r'^\s*(\w+)\s*=\s*(.*)\s*$', line) 62 | if m: 63 | os.environ[m.group(1)] = m.group(2) 64 | 65 | 66 | # Check if the user ID is already stored in the session state 67 | if 'user_id' in st.session_state: 68 | user_id = st.session_state['user_id'] 69 | 70 | # If the user ID is not yet stored in the session state, generate a random UUID 71 | else: 72 | user_id = str(uuid.uuid4()) 73 | st.session_state['user_id'] 
= user_id 74 | 75 | 76 | if 'llm_chain' not in st.session_state: 77 | if (len(sys.argv) > 1): 78 | if (sys.argv[1] == 'bedrock_claude'): 79 | st.session_state['llm_app'] = bedrock_claude 80 | st.session_state['llm_chain'] = bedrock_claude.build_chain( 81 | f"{aoss_id}.{os.environ['AWS_REGION']}.{AOSS_SVC_NAME}.amazonaws.com:443", 82 | index_name 83 | ) 84 | else: 85 | raise Exception("Unsupported LLM: ", sys.argv[1]) 86 | else: 87 | raise Exception("Usage: streamlit run app.py bedrock_claude") 88 | 89 | if 'chat_history' not in st.session_state: 90 | st.session_state['chat_history'] = [] 91 | 92 | if "chats" not in st.session_state: 93 | st.session_state.chats = [ 94 | { 95 | 'id': 0, 96 | 'question': '', 97 | 'answer': '' 98 | } 99 | ] 100 | 101 | if "questions" not in st.session_state: 102 | st.session_state.questions = [] 103 | 104 | if "answers" not in st.session_state: 105 | st.session_state.answers = [] 106 | 107 | if "input" not in st.session_state: 108 | st.session_state.input = "" 109 | 110 | 111 | st.markdown(""" 112 | 127 | """, unsafe_allow_html=True) 128 | 129 | 130 | def write_logo(): 131 | col1, col2, col3 = st.columns([5, 1, 5]) 132 | with col2: 133 | st.image(AI_ICON, use_column_width='always') 134 | 135 | 136 | def write_top_bar(): 137 | col1, col2, col3 = st.columns([1,10,2]) 138 | with col1: 139 | st.image(AI_ICON, use_column_width='always') 140 | with col2: 141 | selected_provider = sys.argv[1] 142 | provider = selected_provider.capitalize() 143 | header = f"An AI App powered by OpenSearch and {provider}!" 144 | st.write(f"
<h3 class='main-header'>{header}</h3>
", unsafe_allow_html=True) 145 | with col3: 146 | clear = st.button("Clear Chat") 147 | return clear 148 | 149 | 150 | clear = write_top_bar() 151 | 152 | if clear: 153 | st.session_state.questions = [] 154 | st.session_state.answers = [] 155 | st.session_state.input = "" 156 | st.session_state["chat_history"] = [] 157 | 158 | 159 | def handle_input(): 160 | input = st.session_state.input 161 | question_with_id = { 162 | 'question': input, 163 | 'id': len(st.session_state.questions) 164 | } 165 | st.session_state.questions.append(question_with_id) 166 | 167 | chat_history = st.session_state["chat_history"] 168 | if len(chat_history) == MAX_HISTORY_LENGTH: 169 | chat_history = chat_history[:-1] 170 | 171 | llm_chain = st.session_state['llm_chain'] 172 | chain = st.session_state['llm_app'] 173 | result = chain.run_chain(llm_chain, input, chat_history) 174 | answer = result['answer'] 175 | chat_history.append((input, answer)) 176 | 177 | document_list = [] 178 | if 'source_documents' in result: 179 | for d in result['source_documents']: 180 | if not (d.metadata['source'] in document_list): 181 | document_list.append((d.metadata['source'])) 182 | 183 | st.session_state.answers.append({ 184 | 'answer': result, 185 | 'sources': document_list, 186 | 'id': len(st.session_state.questions) 187 | }) 188 | st.session_state.input = "" 189 | 190 | 191 | def write_user_message(md): 192 | col1, col2 = st.columns([1,12]) 193 | 194 | with col1: 195 | st.image(USER_ICON, use_column_width='always') 196 | with col2: 197 | st.warning(md['question']) 198 | 199 | 200 | def render_result(result): 201 | answer, sources = st.tabs(['Answer', 'Sources']) 202 | with answer: 203 | render_answer(result['answer']) 204 | with sources: 205 | if 'source_documents' in result: 206 | render_sources(result['source_documents']) 207 | else: 208 | render_sources([]) 209 | 210 | 211 | def render_answer(answer): 212 | col1, col2 = st.columns([1,12]) 213 | with col1: 214 | st.image(AI_ICON, use_column_width='always') 215 | with col2: 216 | st.info(answer['answer']) 217 | 218 | 219 | def render_sources(sources): 220 | col1, col2 = st.columns([1,12]) 221 | with col2: 222 | with st.expander("Sources"): 223 | for s in sources: 224 | st.write(s) 225 | 226 | 227 | #Each answer will have context of the question asked in order to associate the provided feedback with the respective question 228 | def write_chat_message(md, q): 229 | chat = st.container() 230 | with chat: 231 | render_answer(md['answer']) 232 | render_sources(md['sources']) 233 | 234 | 235 | with st.container(): 236 | for (q, a) in zip(st.session_state.questions, st.session_state.answers): 237 | write_user_message(q) 238 | write_chat_message(a, q) 239 | 240 | 241 | st.markdown('---') 242 | input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input) 243 | -------------------------------------------------------------------------------- /rag-app/app_init.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import toml 5 | 6 | import helper_functions as hfn 7 | 8 | 9 | class MissingEnvironmentVariable(Exception): 10 | """Raised if a required environment variable is missing""" 11 | 12 | 13 | API_KEY_SECRET_ENV_VAR = "API_KEY_SECRET_NAME" 14 | 15 | DEFAULT_LOG_LEVEL = logging.INFO 16 | LOGGER = logging.getLogger(__name__) 17 | LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \ 18 | "[%(name)s]:[%(threadName)s] " \ 19 | "%(message)s" 20 | 21 | 22 | if __name__ == 
"__main__": 23 | 24 | streamlit_secrets = {} 25 | 26 | # logging configuration 27 | log_level = DEFAULT_LOG_LEVEL 28 | if os.environ.get("VERBOSE", "").lower() == "true": 29 | log_level = logging.DEBUG 30 | logging.basicConfig(level=log_level, format=LOGGING_FORMAT) 31 | 32 | # open ai api key fetch 33 | openai_secret = os.environ.get(API_KEY_SECRET_ENV_VAR) 34 | if not openai_secret: 35 | raise MissingEnvironmentVariable(f"{API_KEY_SECRET_ENV_VAR} environment variable is required") 36 | streamlit_secrets["OPENAI_API_KEY"] = hfn.get_secret_from_name(openai_secret, kv=False) 37 | 38 | LOGGER.info("Writing streamlit secrets") 39 | with open("/root/.streamlit/secrets.toml", "w") as file: 40 | toml.dump(streamlit_secrets, file) 41 | -------------------------------------------------------------------------------- /rag-app/helper_functions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | from botocore.exceptions import ClientError 7 | 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | 12 | def _check_missing_field(validation_dict, extraction_key): 13 | """Check if a field exists in a dictionary 14 | 15 | :param validation_dict: Dictionary 16 | :param extraction_key: String 17 | 18 | :raises: KeyError 19 | """ 20 | extracted_value = validation_dict.get(extraction_key) 21 | 22 | if not extracted_value: 23 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 24 | raise KeyError 25 | 26 | 27 | def get_secret_from_name(secret_name, kv=True): 28 | """Return secret from secret name 29 | 30 | :param secret_name: String 31 | :param kv: Boolean (weather it is json or not) 32 | 33 | :raises: botocore.exceptions.ClientError 34 | 35 | :rtype: Dictionary 36 | """ 37 | session = boto3.session.Session() 38 | 39 | # Initializing Secret Manager's client 40 | client = session.client( 41 | service_name='secretsmanager', 42 | region_name=os.environ.get("AWS_REGION", session.region_name) 43 | ) 44 | LOGGER.info(f"Attempting to get secret value for: {secret_name}") 45 | try: 46 | get_secret_value_response = client.get_secret_value( 47 | SecretId=secret_name) 48 | except ClientError as e: 49 | # For a list of exceptions thrown, see 50 | # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html 51 | LOGGER.error("Unable to fetch details from Secrets Manager") 52 | raise e 53 | 54 | _check_missing_field( 55 | get_secret_value_response, "SecretString") 56 | 57 | if kv: 58 | return json.loads( 59 | get_secret_value_response["SecretString"]) 60 | else: 61 | return get_secret_value_response["SecretString"] 62 | -------------------------------------------------------------------------------- /rag-app/images/ai-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/rag-app/images/ai-icon.png -------------------------------------------------------------------------------- /rag-app/images/user-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/rag-app/images/user-icon.png -------------------------------------------------------------------------------- /rag-app/requirements.txt: 
-------------------------------------------------------------------------------- 1 | langchain 2 | langchain-community 3 | langchain-openai 4 | boto3>=1.28.27 5 | opensearch-py 6 | unstructured 7 | openai 8 | anthropic 9 | streamlit 10 | tiktoken 11 | toml 12 | -------------------------------------------------------------------------------- /rag-app/run_app.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "initializing app.." 5 | python3 app_init.py 6 | 7 | echo "starting streamlit app" 8 | streamlit run app.py bedrock_claude 9 | -------------------------------------------------------------------------------- /screenshots/aoss_api_dash.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/screenshots/aoss_api_dash.png -------------------------------------------------------------------------------- /screenshots/aoss_dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/screenshots/aoss_dashboard.png -------------------------------------------------------------------------------- /screenshots/app_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/screenshots/app_screenshot.png -------------------------------------------------------------------------------- /screenshots/cog_login_page.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/screenshots/cog_login_page.png -------------------------------------------------------------------------------- /screenshots/invalid_cert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-opensearch/cf84d54a5e2c2ea84f7116b58244df9a3edba5cb/screenshots/invalid_cert.png -------------------------------------------------------------------------------- /scripts/api-key-secret-manager-upload/api-key-secret-manager-upload.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import getpass 3 | import logging 4 | import time 5 | 6 | import boto3 7 | 8 | 9 | DEFAULT_LOG_LEVEL = logging.INFO 10 | LOGGER = logging.getLogger(__name__) 11 | LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \ 12 | "[%(name)s]:[%(threadName)s] " \ 13 | "%(message)s" 14 | 15 | AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID" 16 | AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY" 17 | AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN" 18 | 19 | 20 | def _check_missing_field(validation_dict, extraction_key): 21 | """Check if a field exists in a dictionary 22 | 23 | :param validation_dict: Dictionary 24 | :param extraction_key: String 25 | 26 | :raises: Exception 27 | """ 28 | extracted_value = validation_dict.get(extraction_key) 29 | 30 | if not extracted_value: 31 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 32 | raise Exception 33 | 34 | 35 | def _validate_field(validation_dict, extraction_key, expected_value): 36 | 
"""Validate the passed in field 37 | 38 | :param validation_dict: Dictionary 39 | :param extraction_key: String 40 | :param expected_value: String 41 | 42 | :raises: ValueError 43 | """ 44 | extracted_value = validation_dict.get(extraction_key) 45 | _check_missing_field(validation_dict, extraction_key) 46 | 47 | if extracted_value != expected_value: 48 | LOGGER.error(f"Incorrect value found for '{extraction_key}' key") 49 | raise ValueError 50 | 51 | 52 | def _cli_args(): 53 | """Parse CLI Args 54 | 55 | :rtype: argparse.Namespace 56 | """ 57 | parser = argparse.ArgumentParser(description="api-key-secret-manager-upload") 58 | 59 | parser.add_argument("-s", 60 | "--secret-name", 61 | type=str, 62 | help="Secret Name", 63 | required=True 64 | ) 65 | parser.add_argument("-p", 66 | "--aws-profile", 67 | type=str, 68 | default="default", 69 | help="AWS profile to be used for the API calls") 70 | parser.add_argument("-v", 71 | "--verbose", 72 | action="store_true", 73 | help="debug log output") 74 | parser.add_argument("-e", 75 | "--env", 76 | action="store_true", 77 | help="Use environment variables for AWS credentials") 78 | return parser.parse_args() 79 | 80 | 81 | def _silence_noisy_loggers(): 82 | """Silence chatty libraries for better logging""" 83 | for logger in ['boto3', 'botocore', 84 | 'botocore.vendored.requests.packages.urllib3']: 85 | logging.getLogger(logger).setLevel(logging.WARNING) 86 | 87 | 88 | def main(): 89 | """What executes when the script is run""" 90 | start = time.time() # to capture elapsed time 91 | 92 | args = _cli_args() 93 | 94 | # logging configuration 95 | log_level = DEFAULT_LOG_LEVEL 96 | if args.verbose: 97 | log_level = logging.DEBUG 98 | logging.basicConfig(level=log_level, format=LOGGING_FORMAT) 99 | # silence chatty libraries 100 | _silence_noisy_loggers() 101 | 102 | if args.env: 103 | LOGGER.info( 104 | "Attempting to fetch AWS credentials via environment variables") 105 | aws_access_key_id = os.environ.get(AWS_ACCESS_KEY_ID) 106 | aws_secret_access_key = os.environ.get(AWS_SECRET_ACCESS_KEY) 107 | aws_session_token = os.environ.get(AWS_SESSION_TOKEN) 108 | if not aws_secret_access_key or not aws_access_key_id or not aws_session_token: 109 | raise Exception( 110 | f"Missing one or more environment variables - " 111 | f"'{AWS_ACCESS_KEY_ID}', '{AWS_SECRET_ACCESS_KEY}', " 112 | f"'{AWS_SESSION_TOKEN}'" 113 | ) 114 | else: 115 | LOGGER.info(f"AWS Profile being used: {args.aws_profile}") 116 | boto3.setup_default_session(profile_name=args.aws_profile) 117 | 118 | sm_client = boto3.client("secretsmanager") 119 | 120 | LOGGER.info(f"Updating Secret: {args.secret_name}") 121 | 122 | resp = sm_client.update_secret( 123 | SecretId=args.secret_name, 124 | SecretString=getpass.getpass("Please enter the API Key: ") 125 | ) 126 | _check_missing_field(resp, "ResponseMetadata") 127 | _validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200) 128 | LOGGER.info("Successfully updated secret value") 129 | 130 | LOGGER.debug("Closing secretsmanager boto3 client") 131 | sm_client.close() 132 | 133 | LOGGER.info(f"Total time elapsed: {time.time() - start} seconds") 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | -------------------------------------------------------------------------------- /scripts/api-key-secret-manager-upload/requirements.txt: -------------------------------------------------------------------------------- 1 | boto3 -------------------------------------------------------------------------------- 
/scripts/self-signed-cert-utility/default_cert_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "country": "US", 3 | "state": "Pennsylvania", 4 | "locality": "Hatfield", 5 | "organization": "MyCo", 6 | "organizational_unit": "Infra", 7 | "email_address": "infra at myco dot com" 8 | } -------------------------------------------------------------------------------- /scripts/self-signed-cert-utility/requirements.txt: -------------------------------------------------------------------------------- 1 | boto3 2 | pyopenssl 3 | validators 4 | -------------------------------------------------------------------------------- /scripts/self-signed-cert-utility/self-signed-cert-utility.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import pathlib 6 | import time 7 | 8 | import boto3 9 | from OpenSSL import crypto 10 | import validators 11 | from validators import ValidationError 12 | 13 | 14 | DEFAULT_LOG_LEVEL = logging.INFO 15 | LOGGER = logging.getLogger(__name__) 16 | LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \ 17 | "[%(name)s]:[%(threadName)s] " \ 18 | "%(message)s" 19 | 20 | CERT_FILE = "cert.pem" 21 | KEY_FILE = "key.pem" 22 | 23 | DEFAULT_APP_DOMAIN = "aoss.rag" 24 | 25 | AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID" 26 | AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY" 27 | AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN" 28 | 29 | 30 | def generate_ssl_keys(key_config): 31 | """Generate ssl keys 32 | 33 | Solution taken from: 34 | https://stackoverflow.com/questions/27164354/create-a-self-signed-x509-certificate-in-python 35 | 36 | :param data_dir_obj: DataDir 37 | :param key_config: Dictionary 38 | 39 | :rtype: Dictionary 40 | """ 41 | ssl_dir = f"{pathlib.Path.cwd()}/.ssl" 42 | if not os.path.isdir(ssl_dir): 43 | os.mkdir(ssl_dir) 44 | LOGGER.info("Created ssl dir: %s", ssl_dir) 45 | else: 46 | LOGGER.info(f"ssl dir: {ssl_dir} already exists") 47 | 48 | ssl_cert = f"{ssl_dir}/{os.environ.get('CERT_FILE', CERT_FILE)}" 49 | ssl_pem = f"{ssl_dir}/{os.environ.get('KEY_FILE', KEY_FILE)}" 50 | 51 | LOGGER.debug("Creating a key pair") 52 | ssl_key = crypto.PKey() 53 | ssl_key.generate_key(crypto.TYPE_RSA, 2048) 54 | 55 | LOGGER.debug("Creating a self-signed cert") 56 | cert = crypto.X509() 57 | cert.get_subject().countryName = key_config["country"] 58 | cert.get_subject().stateOrProvinceName = key_config["state"] 59 | cert.get_subject().localityName = key_config["locality"] 60 | cert.get_subject().organizationName = key_config["organization"] 61 | cert.get_subject().organizationalUnitName = \ 62 | key_config["organizational_unit"] 63 | domain = os.environ.get("APP_DOMAIN", DEFAULT_APP_DOMAIN) 64 | if not domain: 65 | raise Exception("Missing 'APP_DOMAIN' environement variable'") 66 | try: 67 | valid_domain = validators.domain(domain) 68 | if not valid_domain: 69 | LOGGER.error(f"'{domain}' could not be validated as a domain") 70 | raise Exception 71 | except ValidationError as ve: 72 | LOGGER.error(f"'{domain}' could not be validated as a domain") 73 | raise Exception 74 | 75 | cert.get_subject().commonName = domain 76 | 77 | cert.set_serial_number(1000) 78 | cert.gmtime_adj_notBefore(0) 79 | cert.gmtime_adj_notAfter(10*365*24*60*60) 80 | cert.set_issuer(cert.get_subject()) 81 | cert.set_pubkey(ssl_key) 82 | cert.sign(ssl_key, "sha1") 83 | 84 | cert_body = crypto.dump_certificate( 85 | crypto.FILETYPE_PEM, cert).decode("utf-8") 86 | 
LOGGER.info("Writing self-signed cert to: %s", ssl_cert) 87 | with open(ssl_cert, "wt") as cert_writer: 88 | cert_writer.write(cert_body) 89 | 90 | key_body = crypto.dump_privatekey( 91 | crypto.FILETYPE_PEM, ssl_key).decode("utf-8") 92 | LOGGER.info("Writing self-signed cert to: %s", ssl_pem) 93 | with open(ssl_pem, "wt") as key_writer: 94 | key_writer.write(key_body) 95 | return { 96 | "key": key_body, 97 | "cert": cert_body 98 | } 99 | 100 | 101 | def _validate_config_file_path(file_path): 102 | """Checks if passed in file path is valid or not 103 | 104 | :file_path: String 105 | 106 | :raises: FileNotFoundError 107 | """ 108 | LOGGER.info(f"Config file path: {file_path}") 109 | if not os.path.isfile(file_path): 110 | LOGGER.error("Config file provided is not found") 111 | raise FileNotFoundError 112 | else: 113 | LOGGER.debug("File path is valid") 114 | return 115 | 116 | 117 | def _parse_key_details_file(args): 118 | """Parse json file containing key details 119 | 120 | :param args: argparse.Namespace (CLI args) 121 | 122 | :rtype: List 123 | """ 124 | absolute_file_path = pathlib.Path(args.config_file).resolve() 125 | _validate_config_file_path(absolute_file_path) 126 | 127 | with open(absolute_file_path) as config_file: 128 | try: 129 | return json.load(config_file) 130 | except json.JSONDecodeError as e: 131 | LOGGER.error("Configuration file is not valid JSON") 132 | raise TypeError 133 | 134 | 135 | def _check_missing_field(validation_dict, extraction_key): 136 | """Check if a field exists in a dictionary 137 | 138 | :param validation_dict: Dictionary 139 | :param extraction_key: String 140 | 141 | :raises: KeyError 142 | """ 143 | extracted_value = validation_dict.get(extraction_key) 144 | 145 | if not extracted_value: 146 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 147 | raise KeyError 148 | 149 | 150 | def _validate_field(validation_dict, extraction_key, expected_value): 151 | """Validate the passed in field 152 | 153 | :param validation_dict: Dictionary 154 | :param extraction_key: String 155 | :param expected_value: String 156 | 157 | :raises: ValueError 158 | """ 159 | extracted_value = validation_dict.get(extraction_key) 160 | _check_missing_field(validation_dict, extraction_key) 161 | 162 | if extracted_value != expected_value: 163 | LOGGER.error(f"Incorrect value found for '{extraction_key}' key") 164 | raise ValueError 165 | 166 | 167 | def _cli_args(): 168 | """Parse CLI Args 169 | 170 | :rtype: argparse.Namespace 171 | """ 172 | parser = argparse.ArgumentParser(description="self-signed-cert-utility") 173 | parser.add_argument("-p", 174 | "--aws-profile", 175 | type=str, 176 | default="default", 177 | help="AWS profile to be used for the API calls") 178 | parser.add_argument("-f", 179 | "--config-file", 180 | type=str, 181 | default="default_cert_params.json", 182 | help="path to configuration file") 183 | parser.add_argument("-v", 184 | "--verbose", 185 | action="store_true", 186 | help="debug log output") 187 | parser.add_argument("-e", 188 | "--env", 189 | action="store_true", 190 | help="Use environment variables for AWS credentials") 191 | return parser.parse_args() 192 | 193 | 194 | def _silence_noisy_loggers(): 195 | """Silence chatty libraries for better logging""" 196 | for logger in ['boto3', 'botocore', 197 | 'botocore.vendored.requests.packages.urllib3']: 198 | logging.getLogger(logger).setLevel(logging.WARNING) 199 | 200 | 201 | def main(): 202 | """What executes when the script is run""" 203 | start = time.time() # to capture elapsed 
time 204 | 205 | args = _cli_args() 206 | 207 | # logging configuration 208 | log_level = DEFAULT_LOG_LEVEL 209 | if args.verbose: 210 | log_level = logging.DEBUG 211 | logging.basicConfig(level=log_level, format=LOGGING_FORMAT) 212 | # silence chatty libraries 213 | _silence_noisy_loggers() 214 | 215 | if args.env: 216 | LOGGER.info( 217 | "Attempting to fetch AWS credentials via environment variables") 218 | aws_access_key_id = os.environ.get(AWS_ACCESS_KEY_ID) 219 | aws_secret_access_key = os.environ.get(AWS_SECRET_ACCESS_KEY) 220 | aws_session_token = os.environ.get(AWS_SESSION_TOKEN) 221 | if not aws_secret_access_key or not aws_access_key_id or not aws_session_token: 222 | raise Exception( 223 | f"Missing one or more environment variables - " 224 | f"'{AWS_ACCESS_KEY_ID}', '{AWS_SECRET_ACCESS_KEY}', " 225 | f"'{AWS_SESSION_TOKEN}'" 226 | ) 227 | else: 228 | LOGGER.info(f"AWS Profile being used: {args.aws_profile}") 229 | boto3.setup_default_session(profile_name=args.aws_profile) 230 | 231 | cert_name = os.environ.get("IAM_SELF_SIGNED_SERVER_CERT_NAME") 232 | if not cert_name: 233 | LOGGER.error( 234 | "Need to export the 'IAM_SELF_SIGNED_SERVER_CERT_NAME' env var") 235 | raise Exception 236 | 237 | key_config = _parse_key_details_file(args) 238 | 239 | cert_files = generate_ssl_keys(key_config) 240 | 241 | iam_client = boto3.client("iam") 242 | 243 | resp = iam_client.upload_server_certificate( 244 | ServerCertificateName=cert_name, 245 | CertificateBody=cert_files["cert"], 246 | PrivateKey=cert_files["key"] 247 | ) 248 | _check_missing_field(resp, "ResponseMetadata") 249 | _validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200) 250 | print(resp) 251 | 252 | LOGGER.debug("Closing iam client") 253 | iam_client.close() 254 | 255 | LOGGER.info(f"Total time elapsed: {time.time() - start} seconds") 256 | 257 | 258 | if __name__ == "__main__": 259 | main() -------------------------------------------------------------------------------- /test/rag-with-amazon-bedrock-and-opensearch.test.ts: -------------------------------------------------------------------------------- 1 | // import * as cdk from 'aws-cdk-lib'; 2 | // import { Template } from 'aws-cdk-lib/assertions'; 3 | // import * as RagWithAmazonBedrockAndOpensearch from '../lib/rag-with-amazon-bedrock-and-opensearch-stack'; 4 | 5 | // example test. 
To run these tests, uncomment this file along with the 6 | // example resource in lib/rag-with-amazon-bedrock-and-opensearch-stack.ts 7 | test('SQS Queue Created', () => { 8 | // const app = new cdk.App(); 9 | // // WHEN 10 | // const stack = new RagWithAmazonBedrockAndOpensearch.RagWithAmazonBedrockAndOpensearchStack(app, 'MyTestStack'); 11 | // // THEN 12 | // const template = Template.fromStack(stack); 13 | 14 | // template.hasResourceProperties('AWS::SQS::Queue', { 15 | // VisibilityTimeout: 300 16 | // }); 17 | }); 18 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2020", 4 | "module": "commonjs", 5 | "lib": [ 6 | "es2020", 7 | "dom" 8 | ], 9 | "declaration": true, 10 | "strict": true, 11 | "noImplicitAny": true, 12 | "strictNullChecks": true, 13 | "noImplicitThis": true, 14 | "alwaysStrict": true, 15 | "noUnusedLocals": false, 16 | "noUnusedParameters": false, 17 | "noImplicitReturns": true, 18 | "noFallthroughCasesInSwitch": false, 19 | "inlineSourceMap": true, 20 | "inlineSources": true, 21 | "experimentalDecorators": true, 22 | "strictPropertyInitialization": false, 23 | "typeRoots": [ 24 | "./node_modules/@types" 25 | ] 26 | }, 27 | "exclude": [ 28 | "node_modules", 29 | "cdk.out" 30 | ] 31 | } 32 | --------------------------------------------------------------------------------
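A minimal end-to-end sketch, assuming AWS credentials plus the AWS_REGION, API_KEY_SECRET_NAME, AOSS_ID, and AOSS_INDEX_NAME environment variables are set the same way the rag-app container expects, and that it is run from the rag-app directory so the local modules import cleanly; it drives the retrieval chain from aoss_chat_bedrock.py for a single query:

import os

import aoss_chat_bedrock as acb
import helper_functions as hfn

# Fetch the OpenAI API key the same way app_init.py does and expose it for OpenAIEmbeddings.
os.environ["OPENAI_API_KEY"] = hfn.get_secret_from_name(
    os.environ["API_KEY_SECRET_NAME"], kv=False)

# Serverless collection endpoint, mirroring the __main__ block of aoss_chat_bedrock.py.
host = (f"{os.environ['AOSS_ID']}.{os.environ['AWS_REGION']}"
        f".{acb.AOSS_SVC_NAME}.amazonaws.com:443")

chain = acb.build_chain(host, os.environ["AOSS_INDEX_NAME"])

# Illustrative one-off question; any prompt string works, paired with an empty chat history.
result = acb.run_chain(chain, "What do the indexed documents cover?", history=[])

print(result["answer"])
for doc in result.get("source_documents", []):
    print(doc.metadata["source"])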