├── .gitignore ├── .npmignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── architecture ├── ddl_automation_arch.png └── rag_app_arch.png ├── bin └── rag-with-amazon-bedrock-and-rds-pgvector.ts ├── cdk.json ├── jest.config.js ├── knowledgebase └── .gitignore ├── lambda ├── app-client-create-trigger │ └── app.py ├── call-back-url-init │ └── app.py ├── call-back-url-update │ └── app.py ├── pdf-processor │ ├── Dockerfile │ ├── lambda_function.py │ └── requirements.txt ├── pgvector-trigger │ └── app.py ├── pgvector-update │ ├── Dockerfile │ ├── app.py │ └── requirements.txt ├── rds-ddl-change │ ├── Dockerfile │ ├── app.py │ └── requirements.txt ├── rds-ddl-init │ ├── Dockerfile │ ├── app.py │ └── requirements.txt └── rds-ddl-trigger │ └── app.py ├── lib ├── base-infra-stack.ts ├── ddl-source-rds-stack.ts ├── pgvector-update-stack.ts ├── rag-app-stack.ts ├── rds-ddl-automation-stack.ts ├── rds-stack.ts └── test-compute-stack.ts ├── package-lock.json ├── package.json ├── rag-app ├── Dockerfile ├── __init__.py ├── app.py ├── app_init.py ├── helper_functions.py ├── images │ ├── ai-icon.png │ └── user-icon.png ├── pgvec_chat_bedrock.py ├── requirements.txt └── run_app.sh ├── screenshots ├── app_screenshot.png ├── cog_login_page.png └── invalid_cert.png ├── scripts ├── api-key-secret-manager-upload │ ├── api-key-secret-manager-upload.py │ └── requirements.txt ├── rds-ddl-sql │ └── rds-ddl.sql └── self-signed-cert-utility │ ├── .gitignore │ ├── default_cert_params.json │ ├── requirements.txt │ └── self-signed-cert-utility.py ├── test └── rag-with-amazon-bedrock-and-rds-pgvector.test.ts └── tsconfig.json /.gitignore: -------------------------------------------------------------------------------- 1 | *.js 2 | !jest.config.js 3 | *.d.ts 4 | node_modules 5 | 6 | # CDK asset staging directory 7 | .cdk.staging 8 | cdk.out 9 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | *.ts 2 | !*.d.ts 3 | 4 | # CDK asset staging directory 5 | .cdk.staging 6 | cdk.out 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 
13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RAG with Amazon Bedrock and PGVector on Amazon RDS

Opinionated sample on how to configure and deploy a [RAG (Retrieval Augmented Generation)](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) application.

It comprises a few core pieces:

* [Amazon Bedrock](https://aws.amazon.com/bedrock/) as the managed service providing easy API-based access to [foundation models (FMs)](https://aws.amazon.com/what-is/foundation-models/).

* [Amazon Relational Database Service (RDS)](https://aws.amazon.com/rds/) + [PGVector](https://github.com/pgvector/pgvector) as the vector database. This is an open-source alternative to using [Amazon Kendra](https://aws.amazon.com/kendra/).

* [LangChain](https://www.langchain.com/) as a [Large Language Model (LLM)](https://www.elastic.co/what-is/large-language-models) application framework. It is also used to update PGVector when new documents get added to the knowledgebase S3 bucket.

* [Amazon Elastic Container Service (ECS)](https://aws.amazon.com/ecs/) to run the RAG Application.

* [Streamlit](https://streamlit.io/) for the frontend user interface of the RAG Application.

* [Application Load Balancer](https://aws.amazon.com/elasticloadbalancing/application-load-balancer/) to route HTTPS traffic to the ECS service (which is running the RAG App).

* [Amazon Cognito](https://aws.amazon.com/cognito/) for secure user authentication.

## Architecture

![Architecture](./architecture/rag_app_arch.png)

## Short note on vector data stores

A [vector database](https://en.wikipedia.org/wiki/Vector_database) is an essential component of any RAG application. The LLM framework uses the vector data store to search for information based on the question that comes from the user.

The typical assumption (*and a strong constraint on this sample project*) is that a knowledgebase comprises PDF documents stored somewhere. Ideally, a true knowledgebase would encompass a lot more - it would scrape websites, wiki pages and so on. But to limit the scope of this sample, the knowledgebase is an [S3](https://aws.amazon.com/s3/) bucket containing a bunch of PDF documents.

A popular choice of vector database in an AWS-based RAG app is Amazon Kendra. It does [optical character recognition (OCR)](https://en.wikipedia.org/wiki/Optical_character_recognition) for PDFs under the hood, and it is a fully managed search service with seamless integration with AWS services like S3. Additionally, Amazon Bedrock also has a vector database offering in the form of ["Knowledgebases"](https://aws.amazon.com/bedrock/knowledge-bases/).
NOTE - "Bedrock Knowledgebases" is another vector store offering, and **it should not** be confused with the term "knowledgebase" and/or "knowledgebase bucket", which refers to the S3 bucket containing PDF documents in this project.

However, the purpose of this sample was to show how to set up an open-source vector database, and since Kendra and Bedrock Knowledgebases are not open source, this sample focuses on PGVector (*running on Amazon RDS*). Unlike Kendra, PGVector cannot directly query PDF documents, so we need to extract the text first, and then feed the text to PGVector.

## PGVector orchestration

The expectation is that PDF files will land in the knowledgebase S3 bucket - either by manually uploading them via the console, programmatically via the [AWS CLI](https://aws.amazon.com/cli/), or by running `cdk deploy BaseInfraStack`. NOTE - the last option (*`cdk deploy`*) requires that you put the PDF files in the ["knowledgebase"](./knowledgebase/) directory of this project. The [S3 Bucket Deployment](https://docs.aws.amazon.com/cdk/api/v2/docs/aws-cdk-lib.aws_s3_deployment.BucketDeployment.html) construct will then upload these files to the knowledgebase bucket.

Once the files land in the knowledgebase S3 bucket, [S3 Event Notifications](https://docs.aws.amazon.com/AmazonS3/latest/userguide/EventNotifications.html) initiate a [lambda](https://docs.aws.amazon.com/lambda/latest/dg/welcome.html) function to extract text from the PDF file(s) and upload the converted text files into the "processed text S3 Bucket". The code/logic for this conversion [lambda function](./lambda/pdf-processor/lambda_function.py) is in the [lambda/pdf-processor](./lambda/pdf-processor/) directory. The function uses the [pypdf](https://github.com/py-pdf/pypdf) Python library to achieve the text extraction.

After the processed text files land in the "processed text S3 bucket", another S3 Event Notification triggers another lambda function ([pgvector-trigger](./lambda/pgvector-trigger/app.py)) that extracts the necessary information about the file and pushes it off to an [Amazon SQS](https://aws.amazon.com/sqs/) queue.

That message push to SQS initiates another lambda function ([pgvector-update](./lambda/pgvector-update/app.py)) that finally updates the vector database with the contents of the processed text file to be indexed (*which will enable it to be searched by the RAG app*). This lambda function uses LangChain to [add documents to PGVector](https://python.langchain.com/docs/integrations/vectorstores/pgvector#add-documents). Additionally, it uses the [S3FileLoader](https://python.langchain.com/docs/integrations/document_loaders/aws_s3_file) component from LangChain to extract document contents to feed PGVector.

### Short note on Embeddings

[Embeddings](https://www.elastic.co/what-is/vector-embedding) are a way to convert words and sentences into numbers that capture their meaning and relationships. In the context of RAG, these "vector embeddings" aid in ["similarity search"](https://en.wikipedia.org/wiki/Similarity_search) capabilities. Adding documents to PGVector also requires creating/provisioning embeddings. This project/sample has utilized [OpenAI's Embeddings](https://platform.openai.com/docs/guides/embeddings).
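To make the orchestration above concrete, here is a minimal sketch of what the `pgvector-update` step boils down to when wiring LangChain, OpenAI embeddings and PGVector together. This is not the actual Lambda handler (see [lambda/pgvector-update](./lambda/pgvector-update/) for that); the connection details, bucket/key and collection name below are illustrative placeholders.

```
# Sketch only: add an S3-hosted processed-text document to PGVector via LangChain.
# DB_HOST, DB_PASSWORD, OPENAI_API_KEY, bucket/key and collection name are placeholders.
from langchain.document_loaders import S3FileLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.pgvector import PGVector

connection_string = PGVector.connection_string_from_db_params(
    driver="psycopg2",
    host="DB_HOST",          # RDS endpoint (fetched from Secrets Manager in this project)
    port=5432,
    database="postgres",
    user="postgres",
    password="DB_PASSWORD",
)

# Load the text file that landed in the "processed text" S3 bucket
docs = S3FileLoader("processed-text-bucket", "some-document.txt").load()

# Chunk the document so each embedding covers a reasonably sized passage
chunks = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100).split_documents(docs)

# Embed the chunks with OpenAI and add them to the pgvector-backed collection
store = PGVector(
    connection_string=connection_string,
    collection_name="pgvector-collection-example",
    embedding_function=OpenAIEmbeddings(openai_api_key="OPENAI_API_KEY"),
)
store.add_documents(chunks)
```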
If you wish to build/run this app in your own AWS environment, you will need to create an account with OpenAI and obtain their [API Key](https://help.openai.com/en/articles/4936850-where-do-i-find-my-api-key).

**OpenAI has its own pricing for API usage**, so be mindful of that. You can find the details on their [pricing page](https://openai.com/pricing). You should be able to get going with the free credits, but if you keep this app running long enough, it will start accruing additional charges.

Some other options to obtain embeddings -
* [HuggingFace](https://huggingface.co/blog/getting-started-with-embeddings)
* [Amazon Titan](https://aws.amazon.com/about-aws/whats-new/2023/09/amazon-titan-embeddings-generally-available/)

NOTE - If you wish to use alternative embeddings, you will need to change the code in the [rag-app](./rag-app/) and the [pgvector-update lambda function](./lambda/pgvector-update/) accordingly.

## Deploying the app

This project is divided into a few sub-stacks, so deploying it also requires a few additional steps. It uses [AWS CDK](https://aws.amazon.com/cdk/) for [Infrastructure as Code (IaC)](https://en.wikipedia.org/wiki/Infrastructure_as_code).

### Pre-requisites

* Since this is a [TypeScript](https://www.typescriptlang.org/) CDK project, you should have [npm](https://www.npmjs.com/) installed (which is the package manager for TypeScript/JavaScript).
    * You can find installation instructions for npm [here](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm).

* Install the [AWS CLI](https://aws.amazon.com/cli/) on your computer (*if not already done so*).
    * `pip install awscli`. This means you need to have Python installed on your computer (if it is not already installed).
    * You also need to configure and authenticate your AWS CLI to be able to interact with AWS programmatically. Detailed instructions on how to do that are provided [here](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html).

* You need to have [docker](https://www.docker.com/) installed on your computer.
    * You can check out these options for building and running docker containers on your local machine:
        * [Docker desktop](https://www.docker.com/products/docker-desktop/). The most popular container management app. Note - it does require a license if the organization you work at is bigger than a certain threshold.
        * [Rancher desktop](https://rancherdesktop.io/). A popular open-source container management tool.
        * [Finch](https://github.com/runfinch/finch). Another open-source tool for container management. Note - currently it only supports macOS machines.

* Have an API Key from [OpenAI](https://openai.com/). This key is needed for programmatic access to use their embeddings for PGVector. You need to create an account with OpenAI (*if you don't have one already*). Details on finding/creating an API Key can be found [here](https://help.openai.com/en/articles/4936850-where-do-i-find-my-api-key).

### Create a self-signed SSL certificate

* Set the `IAM_SELF_SIGNED_SERVER_CERT_NAME` environment variable. This is the name of the self-signed server certificate that will be created ([via IAM](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_server-certs.html)) as part of the deployment.
```
export IAM_SELF_SIGNED_SERVER_CERT_NAME=
```

* Run the [self-signed-cert-utility.py](./scripts/self-signed-cert-utility/self-signed-cert-utility.py) script in the [scripts](./scripts/) directory to create a self-signed certificate and upload its contents to AWS via `boto3` API calls.

This is needed because the Application Load Balancer requires [SSL certificates](https://www.cloudflare.com/en-gb/learning/ssl/what-is-an-ssl-certificate/) to have a functioning [HTTPS listener](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/create-https-listener.html).

```
# switch to the self-signed-cert-utility directory
cd scripts/self-signed-cert-utility

# create a python3 virtual environment (highly recommended)
python3 -m virtualenv .certenv

# activate the virtual environment
source .certenv/bin/activate
# for a different shell like fish, just add a `.fish` at the end of the previous command

# install requirements
pip install -r requirements.txt

# run the script
python self-signed-cert-utility.py
# optionally specify a `--profile` if you're not using the default AWS profile

# deactivate virtual environment
deactivate

# return to the root directory of the project
cd -
```

If the script runs successfully, you should see a JSON-like object printed out in the log output with parameters like `ServerCertificateName`, `ServerCertificateId`, `Arn` etc. Moreover, the `HTTPStatusCode` should have the value `200`.

The parameters encoded in the certificate are in a JSON file. By default the script expects a file named ["default_cert_params.json"](./scripts/self-signed-cert-utility/default_cert_params.json) unless otherwise specified. You may change the values in the default JSON file if you wish to. If you wish to use your own config file (*instead of the default*), you can do so by specifying the `--config-file` parameter.

You can also specify a custom domain for the certificate by setting the `APP_DOMAIN` environment variable.

NOTE - an alternative would be to use [AWS Certificate Manager](https://aws.amazon.com/certificate-manager/), but it requires additional steps (*creating and registering your own domain, involving [Route53 hosted zones](https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/hosted-zones-working-with.html), etc.*). And since the focus of this sample is to show the deployment of a RAG app, and not registering domains, it does not get into configuring that bit.

### Define the domain name for the Cognito hosted UI

Set the `COGNITO_DOMAIN_NAME` environment variable. This will be the domain of the [Cognito hosted UI](https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-pools-app-integration.html) which will be used to "log in" and/or "sign up" to the app.

```
export COGNITO_DOMAIN_NAME=
```

### Install dependencies (if not already done)

```
npm install
```

### Bootstrap CDK environment (if not already done)

Bootstrapping provisions resources in your environment such as an Amazon Simple Storage Service (Amazon S3) bucket for storing files and AWS Identity and Access Management (IAM) roles that grant permissions needed to perform deployments.
These resources get provisioned in an AWS CloudFormation stack, called the bootstrap stack. It is usually named `CDKToolkit`. Like any AWS CloudFormation stack, it will appear in the AWS CloudFormation console of your environment once it has been deployed. More details can be found [here](https://docs.aws.amazon.com/cdk/v2/guide/bootstrapping.html).

```
npx cdk bootstrap

# You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile.
```

NOTE - you only need to do this once per account. If there are other CDK projects deployed in your AWS account, you won't need to do this.

### Set environment variable (if you are on an M1/M2 Mac)

Depending on the architecture of your computer, you may need to set this environment variable for the docker container. This is because docker containers are dependent on the architecture of the host machine that is building/running them.

**If your machine runs on the [x86](https://en.wikipedia.org/wiki/X86) architecture, you can ignore this step.**

```
export DOCKER_CONTAINER_PLATFORM_ARCH=arm
```

### Deploy the BaseInfraStack

```
npx cdk deploy BaseInfraStack

# You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile.
```

This will deploy the base infrastructure - consisting of a VPC, an Application Load Balancer for the app, S3 buckets (for the knowledgebase and the processed text), Lambda functions to process the PDF documents, some SQS queues for decoupling, a Secrets Manager secret for the OpenAI API key, a Cognito user pool, and some more bits and pieces of the cloud infrastructure. The CDK code for this is in the [lib](./lib) directory within the [base-infra-stack.ts](./lib/base-infra-stack.ts) file.

### Upload the OpenAI API key to Secrets Manager

The secret was created as part of the `BaseInfraStack` deployment, but the value inside it is not valid yet. You can either enter your OpenAI API key via the AWS Secrets Manager console, or you can use the [api-key-secret-manager-upload.py](./scripts/api-key-secret-manager-upload/api-key-secret-manager-upload.py) script to do that for you.

[AWS Secrets Manager](https://aws.amazon.com/secrets-manager/) is the recommended way to store credentials in AWS, as it provides API-based access to credentials for databases etc. Since OpenAI (*the provider we are using the vector embeddings from*) is an external service with its own API keys, we need to manually upload that key to Secrets Manager so that the app infrastructure can access it securely.
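Under the hood, uploading the key is essentially a single Secrets Manager `put_secret_value` call. A rough sketch of that call is shown below (assuming the default secret name `openAiApiKey` used by this project); the bundled script does the same thing with extra logging and profile handling.

```
# Sketch only: write an OpenAI API key into an existing Secrets Manager secret.
# "openAiApiKey" is this project's default secret name; adjust it if you renamed the secret.
import getpass

import boto3

api_key = getpass.getpass("Please enter the API Key: ")  # avoids echoing the key in plain text

secrets_client = boto3.client("secretsmanager")
secrets_client.put_secret_value(
    SecretId="openAiApiKey",
    SecretString=api_key,
)
```

To use the bundled helper script instead: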
```
# switch to the api-key-secret-manager-upload directory
cd scripts/api-key-secret-manager-upload

# create a python3 virtual environment (highly recommended)
python3 -m virtualenv .keyenv

# activate the virtual environment
source .keyenv/bin/activate
# for a different shell like fish, just add a `.fish` at the end of the previous command

# install requirements
pip install -r requirements.txt

# run the script; optionally specify a `--profile` if you're not using the default AWS profile
python api-key-secret-manager-upload.py -s openAiApiKey

2024-01-14 19:42:59,341 INFO [__main__]:[MainThread] AWS Profile being used: default
2024-01-14 19:42:59,421 INFO [__main__]:[MainThread] Updating Secret: openAiApiKey
Please enter the API Key:
2024-01-14 19:44:02,221 INFO [__main__]:[MainThread] Successfully updated secret value
2024-01-14 19:44:02,221 INFO [__main__]:[MainThread] Total time elapsed: 62.88090920448303 seconds

# deactivate virtual environment
deactivate

# return to the root directory of the project
cd -
```

The script will prompt you to enter your OpenAI API key. It uses the [getpass](https://docs.python.org/3/library/getpass.html) Python library so that you don't have to enter it in plain text.

NOTE - the instructions above specify `-s openAiApiKey`, which is the same name as defined in [base-infra-stack.ts](./lib/base-infra-stack.ts?plain=1#L164). If you change the value there, you will need to change it when running the script too.

### Deploy the RDS Stack

```
npx cdk deploy rdsStack

# You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile.
```

This will deploy an Amazon RDS instance in a private subnet of the VPC that was deployed as part of the "BaseInfraStack", along with some security groups (*and their associated egress and ingress rules*). The CDK code for the "rdsStack" is in the [rds-stack.ts](./lib/rds-stack.ts) file in the [lib](./lib/) directory.

NOTE - you can also use [Amazon Aurora Serverless V2 (PostgreSQL)](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-serverless-v2.html) for PGVector. More details can be found [here](https://aws.amazon.com/about-aws/whats-new/2023/07/amazon-aurora-postgresql-pgvector-vector-storage-similarity-search/).

The reason for not using Aurora Serverless was to also demonstrate [running Lambda functions in a VPC](https://docs.aws.amazon.com/lambda/latest/dg/configuration-vpc.html), and interacting with a database in a private subnet.

### Deploy the DDL Source Stack

```
npx cdk deploy ddlSourceStack

# You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile.
```

This will deploy an S3 Bucket that contains a DDL file (in this project's case, the [rds-ddl.sql](./scripts/rds-ddl-sql/rds-ddl.sql) in the [scripts](./scripts/) directory) via the S3 Bucket Deployment CDK construct.

This is needed because the RDS instance that gets deployed does not have the PGVector extension (*it needs to be installed separately*). The command is as simple as `CREATE EXTENSION vector;`.
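For reference, the DDL in question is tiny. A minimal sketch of what the automation ends up executing against the instance - here via [psycopg2](https://pypi.org/project/psycopg2/), the same library the Lambda functions use - is shown below; the connection details are hypothetical placeholders.

```
# Sketch only: install the pgvector extension on a PostgreSQL instance and verify it.
# Host and credentials are placeholders; in this project they come from Secrets Manager.
import psycopg2

conn = psycopg2.connect(
    host="DB_HOST", port=5432, dbname="postgres",
    user="postgres", password="DB_PASSWORD",
)
conn.autocommit = True
with conn.cursor() as cur:
    cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
    # confirm the extension is now present
    cur.execute("SELECT extname, extversion FROM pg_extension WHERE extname = 'vector';")
    print(cur.fetchone())
conn.close()
```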
To avoid manually installing anything, the automation mechanism inspired by this [open-source project/sample](https://github.com/aws-samples/ddl-deployment-for-amazon-rds) was employed in this solution. The architecture of that automation pattern can be found [here](./architecture/ddl_automation_arch.png).

Note that that sample used Aurora Serverless V1 (*instead of an RDS instance*), which had access to the [RDS Data API](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/data-api.html). Since the database in this project sits in a private subnet and does not expose a Data API, the Lambda function(s) interacting with it run inside the VPC. Additionally, the lambda function(s) use the [psycopg2](https://pypi.org/project/psycopg2/) Python library to run queries against the database.

The CDK code for the "ddlSourceStack" is in the [ddl-source-rds-stack.ts](./lib/ddl-source-rds-stack.ts) file in the [lib](./lib/) directory.

### Deploy the RDS DDL Automation Stack

```
npx cdk deploy ddlAutomationStack

# You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile.
```

This will deploy the lambda function(s) that orchestrate automatically initializing the RDS instance with the DDL SQL file (as part of the automation mechanism mentioned in the previous section). The architecture of that automation pattern can be found [here](./architecture/ddl_automation_arch.png).

It features 2 lambda functions -
* one that initializes the RDS Instance upon creation (*with the DDL SQL statements*). The code/logic for that function is in the [lambda/rds-ddl-init](./lambda/rds-ddl-init/) directory.
* another that executes the DDL statements when there is a change to the DDL SQL file. The code/logic for that function is in the [lambda/rds-ddl-change](./lambda/rds-ddl-change/) directory.

The CDK code for the "ddlAutomationStack" is in the [rds-ddl-automation-stack.ts](./lib/rds-ddl-automation-stack.ts) file in the [lib](./lib/) directory.

### Deploy the PGVector Update Stack

```
npx cdk deploy PGVectorUpdateStack

# You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile.
```

This will deploy the lambda function which is responsible for updating the PGVector data store whenever there is a new document in the "processed text S3 Bucket". It leverages LangChain to do this operation. The code/logic for that lambda function is in the [lambda/pgvector-update](./lambda/pgvector-update/) directory.

The CDK code for the "PGVectorUpdateStack" is in the [pgvector-update-stack.ts](./lib/pgvector-update-stack.ts) file in the [lib](./lib/) directory.

### Deploy the RAG App Stack

```
npx cdk deploy RagStack

# You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile.
```

This will deploy the ECS Fargate service running the code for the RAG Application. It will also add this service as a target to the Application Load Balancer defined in the "BaseInfraStack". The CDK infrastructure code for this stack is in the [rag-app-stack.ts](./lib/rag-app-stack.ts) file in the [lib](./lib/) directory.

This app leverages LangChain for interacting with Bedrock and PGVector, and Streamlit for the frontend user interface. The application code is in the [rag-app](./rag-app/) directory.
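A condensed sketch of how an app like this wires those pieces together with LangChain is shown below. It is not the actual application code (see the [rag-app](./rag-app/) directory for that); the model ID, collection name and connection string are illustrative, and the credentials would normally be fetched from Secrets Manager at runtime.

```
# Sketch only: a retrieval chain over PGVector using a Bedrock-hosted model.
# Model ID, collection name, connection string and API key are placeholders.
import boto3
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.bedrock import Bedrock
from langchain.vectorstores.pgvector import PGVector

llm = Bedrock(
    client=boto3.client("bedrock-runtime"),
    model_id="anthropic.claude-instant-v1",
)

store = PGVector(
    connection_string="postgresql+psycopg2://postgres:DB_PASSWORD@DB_HOST:5432/postgres",
    collection_name="pgvector-collection-example",
    embedding_function=OpenAIEmbeddings(openai_api_key="OPENAI_API_KEY"),
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=store.as_retriever(search_kwargs={"k": 3}),  # fetch the 3 most similar chunks
    return_source_documents=True,
)

print(qa_chain({"query": "What is AWS?"})["result"])
```

Streamlit then provides the chat UI around a chain along these lines.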
### [Optional] Deploy the TestCompute Stack

```
npx cdk deploy TestComputeStack

# You can optionally specify `--profile` at the end of that command if you wish to not use the default AWS profile.
```

This will deploy an EC2 instance capable of connecting to the RDS instance.

This is optional. Since our RDS instance does not provide a query interface through the AWS console, it can be useful to have a Bastion EC2 host to interact with it, and maybe also test out some more LangChain functionality. Note - the EC2 instance is pretty bare bones, so you will need to install the necessary packages to be able to connect to the RDS instance and/or do anything else.

### Add some PDF documents to the knowledgebase S3 Bucket

After deploying the "RagStack", the necessary infrastructure to have the app running is complete. That being said, you still need to populate the knowledgebase S3 bucket.

You can either do it manually by going into the console and uploading some files.

Or you can add some PDF files to the [knowledgebase](./knowledgebase/) directory in this project, and then run `npx cdk deploy BaseInfraStack`. The reason to prefer this option would be that you can then track knowledgebase documents in your source control (*if that is a requirement for your use case*).

This should upload the document(s) to the Knowledgebase S3 bucket via the S3 Bucket Deployment construct.

After the upload to the S3 knowledgebase bucket is complete, it will trigger the [pdf-processor](./lambda/pdf-processor/) lambda function to extract the text from the PDF and upload it to the "processed text s3 bucket".

The upload to the processed text bucket will trigger the [pgvector-update](./lambda/pgvector-update/) lambda function to then add that document to the PGVector (*on RDS*) data store.

You can verify that it has been added to the vector store in a couple of different ways:
* Check the logs of the lambda function; if there are no errors, the document has probably been indexed in PGVector successfully.
* However, to be certain, you can connect to the RDS instance via a Bastion Host (*the TestCompute Stack might come in handy here*):

```
# install postgresql (if not done already)
sudo apt install postgresql

# grab the connection details from the secrets manager, and then run
psql -h <RDS endpoint> -U postgres

# enter the DB password when prompted
Password for user postgres:

# after connecting, run the following SQL query
postgres=> SELECT * FROM langchain_pg_collection;

# it should spit out a name, and a uuid;
# the name should be something like `pgvector-collection-..`
# run the following SQL, using that uuid as the collection_id
postgres=> SELECT uuid, COUNT(*) FROM langchain_pg_embedding WHERE collection_id = '<collection uuid>' GROUP BY 1;

# if the count is not 0 (i.e. a positive integer), that means the collection with the embeddings has been added successfully
```

### Testing the RAG App

After adding document(s) to the knowledgebase, you can now test the app.
If you log into the AWS console, find the Application Load Balancer page (*under the EC2 section*), and select the load balancer that was created as part of the "BaseInfraStack", it should have a "DNS name". If you copy that name and open `https://` followed by that DNS name in your browser, it should direct you to the app.

Note - since we are using a self-signed SSL certificate (*via IAM Server Certificates*), you might see this warning in your browser (showing Chrome below):

![InvalidCert](./screenshots/invalid_cert.png)

If you see that warning, click on "Advanced" and proceed to the URL; it should then direct you to the Login UI (*served via Cognito*):

![CognitoLoginPage](./screenshots/cog_login_page.png)

* You can either click sign-up and create a new user from this console (*note - you will have to verify the email you sign up with by entering the code that gets sent to that email*)

* Alternatively, you could create a user in the AWS Console (by navigating to the Cognito service)

* There is a programmatic way to create your user via the SDK; or you could use this [open-source helper utility](https://github.com/aws-samples/cognito-user-token-helper).

Once you've successfully signed in (*or signed up*), you will see the UI, and you can start asking questions based on the document(s) you've uploaded. The example document used is the [2021 Amazon letter to shareholders](https://www.aboutamazon.com/news/company-news/2021-letter-to-shareholders), and the question asked was "What is AWS?":

![AppScreen](./screenshots/app_screenshot.png)

## Miscellaneous notes / technical hiccups / recommendations

* The **frontend user interface (UI)** was built using Streamlit, and inspired by another [open source project](https://github.com/aws-samples/amazon-kendra-langchain-extensions/tree/main/kendra_retriever_samples).

* **Cognito Callback URL hiccup** -
When creating an application load balancer via Infrastructure as Code (IaC), the DNS name is generated with some random characters (*that can be both UPPER CASE and lower case*). When configuring this with the Cognito User Pool integration (*app client*), the DNS name is used for the callback URL. The problem here is that Cognito does not like UPPER CASE characters, and whilst deploying this solution via IaC, there isn't much you can do about converting the DNS name to lower case (*because it is actually a token, and not the actual string value of the DNS name*). There is an [open Github issue](https://github.com/aws/aws-cdk/issues/11171) on this.

In order to fix this, the project has EventBridge triggers in place: when the app integration client is created, a lambda function is invoked that pushes a message to an SQS queue, which in turn invokes another Lambda function that updates the app client via the [update_user_pool_client](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/cognito-idp/client/update_user_pool_client.html) boto3 API call, with the lower-cased DNS name in the callback URL.

The code for the lambda function is in the [lambda/call-back-url-init](./lambda/call-back-url-init/) directory.

* If you were to **deploy a solution like this in a production environment**, you would need to create and register your own domain to host the app.
The recommendation would be to use **AWS Certificate Manager** to issue the certificate (*rather than a self-signed one*), and link that with a **Route53 hosted zone**. More details can be found in the [AWS Documentation](https://docs.aws.amazon.com/acm/latest/userguide/dns-validation.html).

* While Streamlit is good for quickly deploying UIs, it may not be best suited for production if the intent is to add more functionality to the app (*i.e. extending it beyond the RAG Chatbot app*). It may be worth looking at [AWS Amplify](https://aws.amazon.com/amplify). Decoupling the frontend from the backend could also introduce the possibility of running the backend as a Lambda Function with API Gateway.

* Alternate vector embedding providers like HuggingFace and/or Amazon Titan would require some code changes (*specifically in the Lambda function(s) that update PGVector via LangChain, and the ECS application running the RAG app*).

* The model used in this sample is [Anthropic's Claude Instant V1](https://aws.amazon.com/bedrock/claude/). You can change the model by providing the `FOUNDATION_MODEL_ID` environment variable to the rag app in [rag-app-stack.ts](./lib/rag-app-stack.ts?plain=1#L143). You can find the different model IDs in the [AWS Documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html).

* Key concepts / techniques covered in this sample -
    * PGVector as an open-source vector database option for RAG applications
    * Running Lambda functions in VPCs to interact with RDS databases in private subnet(s)
    * Using LangChain to serve a RAG application and update PGVector
    * Application Load Balancer (ALB) + ECS Fargate Service to serve an app
    * Using self-signed certificates to configure the HTTPS listener for the ALB
    * Integrating a Cognito Login UI with the ALB

## Generic CDK instructions

This is a blank project for CDK development with TypeScript.

The `cdk.json` file tells the CDK Toolkit how to execute your app.
386 | 387 | ## Useful commands 388 | 389 | * `npm run build` compile typescript to js 390 | * `npm run watch` watch for changes and compile 391 | * `npm run test` perform the jest unit tests 392 | * `cdk deploy` deploy this stack to your default AWS account/region 393 | * `cdk diff` compare deployed stack with current state 394 | * `cdk synth` emits the synthesized CloudFormation template 395 | -------------------------------------------------------------------------------- /architecture/ddl_automation_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-pgvector/93c5fc9ee9378d8d99a114299c49507a89d9cae2/architecture/ddl_automation_arch.png -------------------------------------------------------------------------------- /architecture/rag_app_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-pgvector/93c5fc9ee9378d8d99a114299c49507a89d9cae2/architecture/rag_app_arch.png -------------------------------------------------------------------------------- /bin/rag-with-amazon-bedrock-and-rds-pgvector.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | import 'source-map-support/register'; 3 | import * as cdk from 'aws-cdk-lib'; 4 | import { RagAppStack } from '../lib/rag-app-stack'; 5 | import { BaseInfraStack } from '../lib/base-infra-stack'; 6 | import { RDSStack } from '../lib/rds-stack'; 7 | import { DDLSourceRDSStack } from '../lib/ddl-source-rds-stack'; 8 | import { RdsDdlAutomationStack } from '../lib/rds-ddl-automation-stack'; 9 | import { TestComputeStack } from '../lib/test-compute-stack'; 10 | import { PGVectorUpdateStack } from '../lib/pgvector-update-stack'; 11 | 12 | 13 | const app = new cdk.App(); 14 | 15 | // contains vpc, 16 | const baseInfra = new BaseInfraStack(app, 'BaseInfraStack', { 17 | }); 18 | 19 | // contains the RDS instance and its associated security group 20 | const rds = new RDSStack(app, 'rdsStack', { 21 | vpc: baseInfra.vpc, 22 | sgLambda: baseInfra.lambdaSG, 23 | sgEc2: baseInfra.ec2SecGroup, 24 | ecsSecGroup: baseInfra.ecsTaskSecGroup, 25 | }); 26 | 27 | // contains s3 bucket containing the RDS DDL file 28 | const ddlSource = new DDLSourceRDSStack(app, 'ddlSourceStack', { 29 | rdsInstance: rds.dbInstance, 30 | }); 31 | 32 | // contains s3 bucket containing the RDS DDL file 33 | const ddlAutomation = new RdsDdlAutomationStack(app, 'ddlAutomationStack', { 34 | vpc: baseInfra.vpc, 35 | ddlTriggerQueue: baseInfra.rdsDdlTriggerQueue, 36 | dbName: rds.rdsDBName, 37 | ddlSourceS3Bucket: ddlSource.sourceS3Bucket, 38 | rdsInstance: rds.dbInstance, 39 | lambdaSG: baseInfra.lambdaSG, 40 | ddlSourceStackName: ddlSource.stackName, 41 | }); 42 | 43 | // vector store update stack 44 | const pgvectorUpdate = new PGVectorUpdateStack(app, 'PGVectorUpdateStack', { 45 | vpc: baseInfra.vpc, 46 | processedBucket: baseInfra.processedBucket, 47 | collectionName: baseInfra.pgvectorCollectionName, 48 | apiKeySecret: baseInfra.apiKeySecret, 49 | databaseCreds: rds.dbInstance.secret?.secretArn || "", 50 | triggerQueue: baseInfra.pgvectorQueue, 51 | dbInstance: rds.dbInstance, 52 | lambdaSG: baseInfra.lambdaSG, 53 | }); 54 | 55 | // for a test EC2 instance to play around with (optional) 56 | const testComputeStack = new TestComputeStack(app, 'TestComputeStack', { 57 | vpc: baseInfra.vpc, 58 | ec2SG: 
baseInfra.ec2SecGroup, 59 | }); 60 | 61 | // ECS service running the RAG App 62 | new RagAppStack(app, 'RagStack', { 63 | vpc: baseInfra.vpc, 64 | databaseCreds: rds.dbInstance.secret?.secretArn || "", 65 | collectionName: baseInfra.pgvectorCollectionName, 66 | apiKeySecret: baseInfra.apiKeySecret, 67 | dbInstance: rds.dbInstance, 68 | taskSecGroup: baseInfra.ecsTaskSecGroup, 69 | elbTargetGroup: baseInfra.appTargetGroup 70 | }); -------------------------------------------------------------------------------- /cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "npx ts-node --prefer-ts-exts bin/rag-with-amazon-bedrock-and-rds-pgvector.ts", 3 | "watch": { 4 | "include": [ 5 | "**" 6 | ], 7 | "exclude": [ 8 | "README.md", 9 | "cdk*.json", 10 | "**/*.d.ts", 11 | "**/*.js", 12 | "tsconfig.json", 13 | "package*.json", 14 | "yarn.lock", 15 | "node_modules", 16 | "test" 17 | ] 18 | }, 19 | "context": { 20 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true, 21 | "@aws-cdk/core:checkSecretUsage": true, 22 | "@aws-cdk/core:target-partitions": [ 23 | "aws", 24 | "aws-cn" 25 | ], 26 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true, 27 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true, 28 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true, 29 | "@aws-cdk/aws-iam:minimizePolicies": true, 30 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true, 31 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true, 32 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true, 33 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true, 34 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true, 35 | "@aws-cdk/core:enablePartitionLiterals": true, 36 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true, 37 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true, 38 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true, 39 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true, 40 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true, 41 | "@aws-cdk/aws-route53-patters:useCertificate": true, 42 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false, 43 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true, 44 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true, 45 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true, 46 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true, 47 | "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true, 48 | "@aws-cdk/aws-redshift:columnId": true, 49 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true, 50 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true, 51 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true, 52 | "@aws-cdk/aws-kms:aliasNameRef": true, 53 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true, 54 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true, 55 | "@aws-cdk/aws-efs:denyAnonymousAccess": true, 56 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true, 57 | "@aws-cdk/aws-lambda-nodejs:useLatestRuntimeVersion": true, 58 | "@aws-cdk/aws-efs:mountTargetOrderInsensitiveLogicalId": true 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /jest.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | testEnvironment: 'node', 3 | roots: 
['/test'], 4 | testMatch: ['**/*.test.ts'], 5 | transform: { 6 | '^.+\\.tsx?$': 'ts-jest' 7 | } 8 | }; 9 | -------------------------------------------------------------------------------- /knowledgebase/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /lambda/app-client-create-trigger/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | 7 | 8 | LOGGER = logging.getLogger() 9 | 10 | QUEUE_URL_ENV_VAR = "TRIGGER_QUEUE" 11 | 12 | 13 | class MalformedEvent(Exception): 14 | """Raised if a malformed event received""" 15 | 16 | 17 | class MissingEnvironmentVariable(Exception): 18 | """Raised if a required environment variable is missing""" 19 | 20 | 21 | def _silence_noisy_loggers(): 22 | """Silence chatty libraries for better logging""" 23 | for logger in ['boto3', 'botocore', 24 | 'botocore.vendored.requests.packages.urllib3']: 25 | logging.getLogger(logger).setLevel(logging.WARNING) 26 | 27 | 28 | def _check_missing_field(validation_dict, extraction_key): 29 | """Check if a field exists in a dictionary 30 | 31 | :param validation_dict: Dictionary 32 | :param extraction_key: String 33 | 34 | :raises: MalformedEvent 35 | """ 36 | extracted_value = validation_dict.get(extraction_key) 37 | 38 | if not extracted_value: 39 | LOGGER.error(f"Missing '{extraction_key}' field in the event") 40 | raise MalformedEvent 41 | 42 | 43 | def _validate_field(validation_dict, extraction_key, expected_value): 44 | """Validate the passed in field 45 | 46 | :param validation_dict: Dictionary 47 | :param extraction_key: String 48 | :param expected_value: String 49 | 50 | :raises: ValueError 51 | """ 52 | extracted_value = validation_dict.get(extraction_key) 53 | 54 | _check_missing_field(validation_dict, extraction_key) 55 | 56 | if extracted_value != expected_value: 57 | LOGGER.error(f"Incorrect value found for '{extraction_key}' field") 58 | raise ValueError 59 | 60 | 61 | def _extract_valid_event(event): 62 | """Validate incoming event and extract necessary attributes 63 | 64 | :param event: Dictionary 65 | 66 | :raises: MalformedEvent 67 | :raises: ValueError 68 | 69 | :rtype: Dictionary 70 | """ 71 | 72 | _validate_field(event, "source", "aws.cognito-idp") 73 | 74 | _check_missing_field(event, "detail") 75 | event_detail = event["detail"] 76 | 77 | _validate_field( 78 | event_detail, 79 | "sourceIPAddress", 80 | "cloudformation.amazonaws.com" 81 | ) 82 | 83 | _validate_field( 84 | event_detail, 85 | "eventSource", 86 | "cognito-idp.amazonaws.com" 87 | ) 88 | _validate_field(event_detail, "eventName", "CreateUserPoolClient") 89 | 90 | _check_missing_field(event_detail, "responseElements") 91 | _check_missing_field(event_detail["responseElements"], "userPoolClient") 92 | 93 | return event_detail["responseElements"]["userPoolClient"] 94 | 95 | 96 | def _configure_logger(): 97 | """Configure python logger""" 98 | level = logging.INFO 99 | verbose = os.environ.get("VERBOSE", "") 100 | if verbose.lower() == "true": 101 | print("Will set the logging output to DEBUG") 102 | level = logging.DEBUG 103 | 104 | if len(logging.getLogger().handlers) > 0: 105 | # The Lambda environment pre-configures a handler logging to stderr. 106 | # If a handler is already configured, `.basicConfig` does not execute. 107 | # Thus we set the level directly. 
108 | logging.getLogger().setLevel(level) 109 | else: 110 | logging.basicConfig(level=level) 111 | 112 | 113 | def _send_message_to_sqs(client, queue_url, message_dict): 114 | """Send message to SQS Queue 115 | 116 | :param client: Boto3 client object (SQS) 117 | :param queue_url: String 118 | :param message_dict: Dictionary 119 | 120 | :raises: Exception 121 | """ 122 | LOGGER.info(f"Attempting to send message to: {queue_url}") 123 | resp = client.send_message( 124 | QueueUrl=queue_url, 125 | MessageBody=json.dumps(message_dict) 126 | ) 127 | 128 | _check_missing_field(resp, "ResponseMetadata") 129 | resp_metadata = resp["ResponseMetadata"] 130 | 131 | _check_missing_field(resp_metadata, "HTTPStatusCode") 132 | status_code = resp_metadata["HTTPStatusCode"] 133 | 134 | if status_code == 200: 135 | LOGGER.info("Successfully pushed message") 136 | else: 137 | raise Exception("Unable to push message") 138 | 139 | 140 | def lambda_handler(event, context): 141 | """What executes when the program is run""" 142 | 143 | # configure python logger 144 | _configure_logger() 145 | # silence chatty libraries 146 | _silence_noisy_loggers() 147 | 148 | client_details = _extract_valid_event(event) 149 | LOGGER.info("Extracted user pool details") 150 | 151 | sqs_client = boto3.client("sqs") 152 | rds_ddl_queue_url = os.environ.get(QUEUE_URL_ENV_VAR) 153 | if not rds_ddl_queue_url: 154 | raise MissingEnvironmentVariable( 155 | f"{QUEUE_URL_ENV_VAR} environment variable is required") 156 | 157 | # send message to Triggering Queue 158 | _send_message_to_sqs( 159 | sqs_client, 160 | rds_ddl_queue_url, 161 | client_details) 162 | 163 | sqs_client.close() 164 | -------------------------------------------------------------------------------- /lambda/call-back-url-init/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | 7 | 8 | LOGGER = logging.getLogger() 9 | 10 | USER_POOL_ENV_VAR = "USER_POOL_ID" 11 | APP_CLIENT_ENV_VAR = "APP_CLIENT_ID" 12 | ALB_DNS_ENV_VAR = "ALB_DNS_NAME" 13 | SQS_QUEUE_ENV_VAR = "SQS_QUEUE_URL" 14 | 15 | 16 | class MalformedEvent(Exception): 17 | """Raised if a malformed event received""" 18 | 19 | 20 | class MissingEnvironmentVariable(Exception): 21 | """Raised if a required environment variable is missing""" 22 | 23 | 24 | def _silence_noisy_loggers(): 25 | """Silence chatty libraries for better logging""" 26 | for logger in ['boto3', 'botocore', 27 | 'botocore.vendored.requests.packages.urllib3']: 28 | logging.getLogger(logger).setLevel(logging.WARNING) 29 | 30 | 31 | def _check_missing_field(validation_dict, extraction_key): 32 | """Check if a field exists in a dictionary 33 | 34 | :param validation_dict: Dictionary 35 | :param extraction_key: String 36 | 37 | :raises: MalformedEvent 38 | """ 39 | extracted_value = validation_dict.get(extraction_key) 40 | 41 | if not extracted_value: 42 | LOGGER.error(f"Missing '{extraction_key}' field in the event") 43 | raise MalformedEvent 44 | 45 | 46 | def _validate_field(validation_dict, extraction_key, expected_value): 47 | """Validate the passed in field 48 | 49 | :param validation_dict: Dictionary 50 | :param extraction_key: String 51 | :param expected_value: String 52 | 53 | :raises: ValueError 54 | """ 55 | extracted_value = validation_dict.get(extraction_key) 56 | 57 | _check_missing_field(validation_dict, extraction_key) 58 | 59 | if extracted_value != expected_value: 60 | LOGGER.error(f"Incorrect value found for '{extraction_key}' 
field") 61 | raise ValueError 62 | 63 | 64 | def _configure_logger(): 65 | """Configure python logger""" 66 | level = logging.INFO 67 | verbose = os.environ.get("VERBOSE", "") 68 | if verbose.lower() == "true": 69 | print("Will set the logging output to DEBUG") 70 | level = logging.DEBUG 71 | 72 | if len(logging.getLogger().handlers) > 0: 73 | # The Lambda environment pre-configures a handler logging to stderr. 74 | # If a handler is already configured, `.basicConfig` does not execute. 75 | # Thus we set the level directly. 76 | logging.getLogger().setLevel(level) 77 | else: 78 | logging.basicConfig(level=level) 79 | 80 | 81 | def _get_message_body(event): 82 | """Extract message body from the event 83 | 84 | :param event: Dictionary 85 | 86 | :raises: MalformedEvent 87 | 88 | :rtype: Dictionary 89 | """ 90 | body = "" 91 | test_event = event.get("test_event", "") 92 | if test_event.lower() == "true": 93 | LOGGER.info("processing test event (and not from SQS)") 94 | LOGGER.debug("Test body: %s", event) 95 | return event 96 | else: 97 | LOGGER.info("Attempting to extract message body from SQS") 98 | 99 | _check_missing_field(event, "Records") 100 | records = event["Records"] 101 | 102 | first_record = records[0] 103 | 104 | try: 105 | body = first_record.get("body") 106 | except AttributeError: 107 | raise MalformedEvent("First record is not a proper dict") 108 | 109 | if not body: 110 | raise MalformedEvent("Missing 'body' in the record") 111 | 112 | try: 113 | return json.loads(body) 114 | except json.decoder.JSONDecodeError: 115 | raise MalformedEvent("'body' is not valid JSON") 116 | 117 | 118 | def _get_sqs_message_attributes(event): 119 | """Extract receiptHandle from message 120 | 121 | :param event: Dictionary 122 | 123 | :raises: MalformedEvent 124 | 125 | :rtype: Dictionary 126 | """ 127 | LOGGER.info("Attempting to extract receiptHandle from SQS") 128 | records = event.get("Records") 129 | if not records: 130 | LOGGER.warning("No receiptHandle found, probably not an SQS message") 131 | return 132 | try: 133 | first_record = records[0] 134 | except IndexError: 135 | raise MalformedEvent("Records seem to be empty") 136 | 137 | _check_missing_field(first_record, "receiptHandle") 138 | receipt_handle = first_record["receiptHandle"] 139 | 140 | _check_missing_field(first_record, "messageId") 141 | message_id = first_record["messageId"] 142 | 143 | return { 144 | "message_id": message_id, 145 | "receipt_handle": receipt_handle 146 | } 147 | 148 | 149 | def lambda_handler(event, context): 150 | """What executes when the program is run""" 151 | 152 | # configure python logger 153 | _configure_logger() 154 | # silence chatty libraries 155 | _silence_noisy_loggers() 156 | 157 | msg_attr = _get_sqs_message_attributes(event) 158 | 159 | if msg_attr: 160 | 161 | # Because messages remain in the queue 162 | LOGGER.info( 163 | f"Deleting message {msg_attr['message_id']} from sqs") 164 | sqs_client = boto3.client("sqs") 165 | queue_url = os.environ.get(SQS_QUEUE_ENV_VAR) 166 | if not queue_url: 167 | raise MissingEnvironmentVariable( 168 | f"{SQS_QUEUE_ENV_VAR} environment variable is required") 169 | 170 | deletion_resp = sqs_client.delete_message( 171 | QueueUrl=queue_url, 172 | ReceiptHandle=msg_attr["receipt_handle"]) 173 | 174 | sqs_client.close() 175 | 176 | resp_metadata = deletion_resp.get("ResponseMetadata") 177 | if not resp_metadata: 178 | raise Exception( 179 | "No response metadata from deletion call") 180 | status_code = resp_metadata.get("HTTPStatusCode") 181 | 182 | if 
status_code == 200: 183 | LOGGER.info(f"Successfully deleted message") 184 | else: 185 | raise Exception("Unable to delete message") 186 | 187 | client_details = _get_message_body(event) 188 | LOGGER.info("Extracted user pool details") 189 | 190 | user_pool_id = os.environ.get(USER_POOL_ENV_VAR) 191 | if not user_pool_id: 192 | raise MissingEnvironmentVariable( 193 | f"{USER_POOL_ENV_VAR} environment variable is required") 194 | 195 | app_client_id = os.environ.get(APP_CLIENT_ENV_VAR) 196 | if not app_client_id: 197 | raise MissingEnvironmentVariable( 198 | f"{APP_CLIENT_ENV_VAR} environment variable is required") 199 | 200 | alb_dns = os.environ.get(ALB_DNS_ENV_VAR) 201 | if not alb_dns: 202 | raise MissingEnvironmentVariable( 203 | f"{ALB_DNS_ENV_VAR} environment variable is required") 204 | 205 | _validate_field(client_details, "userPoolId", user_pool_id) 206 | _validate_field(client_details, "clientId", app_client_id) 207 | 208 | expected_callback_url = f"https://{alb_dns}/oauth2/idpresponse" 209 | lowered_callback_url = expected_callback_url.lower() 210 | 211 | _check_missing_field(client_details, "callbackURLs") 212 | callback_urls = client_details["callbackURLs"] 213 | if len(callback_urls) != 1: 214 | LOGGER.warning("Unexpected number of callback URLs") 215 | else: 216 | if callback_urls[0] != expected_callback_url: 217 | LOGGER.warning( 218 | "Looks like the callback URL is not " 219 | "associated with the correct load balancer. Please verify.") 220 | 221 | cog_client = boto3.client("cognito-idp") 222 | 223 | LOGGER.info("Updating the user pool client URL") 224 | resp = cog_client.update_user_pool_client( 225 | UserPoolId=user_pool_id, 226 | ClientId=app_client_id, 227 | ExplicitAuthFlows=client_details["explicitAuthFlows"], 228 | SupportedIdentityProviders=client_details["supportedIdentityProviders"], 229 | CallbackURLs=[lowered_callback_url], 230 | AllowedOAuthFlows=client_details["allowedOAuthFlows"], 231 | AllowedOAuthScopes=client_details["allowedOAuthScopes"], 232 | AllowedOAuthFlowsUserPoolClient=client_details["allowedOAuthFlowsUserPoolClient"], 233 | EnableTokenRevocation=client_details["enableTokenRevocation"], 234 | EnablePropagateAdditionalUserContextData=client_details["enablePropagateAdditionalUserContextData"], 235 | AuthSessionValidity=client_details["authSessionValidity"] 236 | ) 237 | _check_missing_field(resp, "ResponseMetadata") 238 | resp_metadata = resp["ResponseMetadata"] 239 | 240 | _check_missing_field(resp_metadata, "HTTPStatusCode") 241 | status_code = resp_metadata["HTTPStatusCode"] 242 | 243 | if status_code == 200: 244 | LOGGER.info("Successfully updated callback URL") 245 | else: 246 | raise Exception("Unable to update user pool client") 247 | 248 | cog_client.close() 249 | -------------------------------------------------------------------------------- /lambda/call-back-url-update/app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import boto3 5 | 6 | 7 | LOGGER = logging.getLogger() 8 | 9 | USER_POOL_ENV_VAR = "USER_POOL_ID" 10 | APP_CLIENT_ENV_VAR = "APP_CLIENT_ID" 11 | ALB_DNS_ENV_VAR = "ALB_DNS_NAME" 12 | 13 | 14 | class MalformedEvent(Exception): 15 | """Raised if a malformed event received""" 16 | 17 | 18 | class MissingEnvironmentVariable(Exception): 19 | """Raised if a required environment variable is missing""" 20 | 21 | 22 | def _silence_noisy_loggers(): 23 | """Silence chatty libraries for better logging""" 24 | for logger in ['boto3', 'botocore', 25 | 
'botocore.vendored.requests.packages.urllib3']: 26 | logging.getLogger(logger).setLevel(logging.WARNING) 27 | 28 | 29 | def _check_missing_field(validation_dict, extraction_key): 30 | """Check if a field exists in a dictionary 31 | 32 | :param validation_dict: Dictionary 33 | :param extraction_key: String 34 | 35 | :raises: MalformedEvent 36 | """ 37 | extracted_value = validation_dict.get(extraction_key) 38 | 39 | if not extracted_value: 40 | LOGGER.error(f"Missing '{extraction_key}' field in the event") 41 | raise MalformedEvent 42 | 43 | 44 | def _validate_field(validation_dict, extraction_key, expected_value): 45 | """Validate the passed in field 46 | 47 | :param validation_dict: Dictionary 48 | :param extraction_key: String 49 | :param expected_value: String 50 | 51 | :raises: ValueError 52 | """ 53 | extracted_value = validation_dict.get(extraction_key) 54 | 55 | _check_missing_field(validation_dict, extraction_key) 56 | 57 | if extracted_value != expected_value: 58 | LOGGER.error(f"Incorrect value found for '{extraction_key}' field") 59 | raise ValueError 60 | 61 | 62 | def _extract_valid_event(event): 63 | """Validate incoming event and extract necessary attributes 64 | 65 | :param event: Dictionary 66 | 67 | :raises: MalformedEvent 68 | :raises: ValueError 69 | 70 | :rtype: Dictionary 71 | """ 72 | 73 | _validate_field(event, "source", "aws.cognito-idp") 74 | 75 | _check_missing_field(event, "detail") 76 | event_detail = event["detail"] 77 | 78 | _validate_field( 79 | event_detail, 80 | "sourceIPAddress", 81 | "cloudformation.amazonaws.com" 82 | ) 83 | 84 | _validate_field( 85 | event_detail, 86 | "eventSource", 87 | "cognito-idp.amazonaws.com" 88 | ) 89 | 90 | _validate_field(event_detail, "eventName", "UpdateUserPoolClient") 91 | 92 | _check_missing_field(event_detail, "responseElements") 93 | _check_missing_field(event_detail["responseElements"], "userPoolClient") 94 | 95 | return event_detail["responseElements"]["userPoolClient"] 96 | 97 | 98 | def _configure_logger(): 99 | """Configure python logger""" 100 | level = logging.INFO 101 | verbose = os.environ.get("VERBOSE", "") 102 | if verbose.lower() == "true": 103 | print("Will set the logging output to DEBUG") 104 | level = logging.DEBUG 105 | 106 | if len(logging.getLogger().handlers) > 0: 107 | # The Lambda environment pre-configures a handler logging to stderr. 108 | # If a handler is already configured, `.basicConfig` does not execute. 109 | # Thus we set the level directly. 
110 | logging.getLogger().setLevel(level) 111 | else: 112 | logging.basicConfig(level=level) 113 | 114 | 115 | def lambda_handler(event, context): 116 | """What executes when the program is run""" 117 | 118 | # configure python logger 119 | _configure_logger() 120 | # silence chatty libraries 121 | _silence_noisy_loggers() 122 | 123 | user_pool_id = os.environ.get(USER_POOL_ENV_VAR) 124 | if not user_pool_id: 125 | raise MissingEnvironmentVariable( 126 | f"{USER_POOL_ENV_VAR} environment variable is required") 127 | 128 | app_client_id = os.environ.get(APP_CLIENT_ENV_VAR) 129 | if not app_client_id: 130 | raise MissingEnvironmentVariable( 131 | f"{APP_CLIENT_ENV_VAR} environment variable is required") 132 | 133 | alb_dns = os.environ.get(ALB_DNS_ENV_VAR) 134 | if not alb_dns: 135 | raise MissingEnvironmentVariable( 136 | f"{ALB_DNS_ENV_VAR} environment variable is required") 137 | 138 | client_details = _extract_valid_event(event) 139 | LOGGER.info("Extracted user pool details") 140 | 141 | expected_callback_url = f"https://{alb_dns}/oauth2/idpresponse" 142 | lowered_callback_url = expected_callback_url.lower() 143 | 144 | _validate_field(client_details, "userPoolId", user_pool_id) 145 | _validate_field(client_details, "clientId", app_client_id) 146 | 147 | _check_missing_field(client_details, "callbackURLs") 148 | callback_urls = client_details["callbackURLs"] 149 | if len(callback_urls) != 1: 150 | LOGGER.warning("Unexpected number of callback URLs") 151 | else: 152 | if callback_urls[0] != expected_callback_url: 153 | LOGGER.warning( 154 | "Looks like the callback URL is not " 155 | "associated with the correct load balancer. Please verify.") 156 | 157 | cog_client = boto3.client("cognito-idp") 158 | 159 | LOGGER.info("Updating the user pool client URL") 160 | resp = cog_client.update_user_pool_client( 161 | UserPoolId=user_pool_id, 162 | ClientId=app_client_id, 163 | ExplicitAuthFlows=client_details["explicitAuthFlows"], 164 | SupportedIdentityProviders=client_details["supportedIdentityProviders"], 165 | CallbackURLs=[lowered_callback_url], 166 | AllowedOAuthFlows=client_details["allowedOAuthFlows"], 167 | AllowedOAuthScopes=client_details["allowedOAuthScopes"], 168 | AllowedOAuthFlowsUserPoolClient=client_details["allowedOAuthFlowsUserPoolClient"], 169 | EnableTokenRevocation=client_details["enableTokenRevocation"], 170 | EnablePropagateAdditionalUserContextData=client_details["enablePropagateAdditionalUserContextData"], 171 | AuthSessionValidity=client_details["authSessionValidity"] 172 | ) 173 | _check_missing_field(resp, "ResponseMetadata") 174 | resp_metadata = resp["ResponseMetadata"] 175 | 176 | _check_missing_field(resp_metadata, "HTTPStatusCode") 177 | status_code = resp_metadata["HTTPStatusCode"] 178 | 179 | if status_code == 200: 180 | LOGGER.info("Successfully updated callback URL") 181 | else: 182 | raise Exception("Unable to update user pool client") 183 | 184 | cog_client.close() 185 | -------------------------------------------------------------------------------- /lambda/pdf-processor/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lambda/python:3.11 2 | 3 | COPY requirements.txt requirements.txt 4 | 5 | RUN pip3 install -r requirements.txt 6 | 7 | COPY lambda_function.py lambda_function.py 8 | 9 | CMD [ "lambda_function.lambda_handler"] 10 | -------------------------------------------------------------------------------- /lambda/pdf-processor/lambda_function.py: 
-------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import logging 3 | import os 4 | 5 | import boto3 6 | from pypdf import PdfReader 7 | from pypdf.errors import PdfReadError 8 | 9 | 10 | LOGGER = logging.getLogger(__name__) 11 | 12 | SOURCE_BUCKET_ENV_VAR = "SOURCE_BUCKET_NAME" 13 | DEST_BUCKET_ENV_VAR = "DESTINATION_BUCKET_NAME" 14 | 15 | MALFORMED_EVENT_LOG_MSG = "Malformed event. Skipping this record." 16 | 17 | 18 | class MissingEnvironmentVariable(Exception): 19 | """Raised if a required environment variable is missing""" 20 | 21 | 22 | def _silence_noisy_loggers(): 23 | """Silence chatty libraries for better logging""" 24 | for logger in ['boto3', 'botocore', 25 | 'botocore.vendored.requests.packages.urllib3']: 26 | logging.getLogger(logger).setLevel(logging.WARNING) 27 | 28 | 29 | def _configure_logger(): 30 | """Configure python logger for lambda function""" 31 | default_log_args = { 32 | "level": logging.DEBUG if os.environ.get("VERBOSE", False) else logging.INFO, 33 | "format": "%(asctime)s [%(levelname)s] %(name)s - %(message)s", 34 | "datefmt": "%d-%b-%y %H:%M", 35 | "force": True, 36 | } 37 | logging.basicConfig(**default_log_args) 38 | 39 | 40 | def _check_missing_field(validation_dict, extraction_key): 41 | """Check if a field exists in a dictionary 42 | 43 | :param validation_dict: Dictionary 44 | :param extraction_key: String 45 | 46 | :raises: KeyError 47 | """ 48 | extracted_value = validation_dict.get(extraction_key) 49 | 50 | if not extracted_value: 51 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 52 | raise KeyError 53 | 54 | 55 | def _validate_field(validation_dict, extraction_key, expected_value): 56 | """Validate the passed in field 57 | 58 | :param validation_dict: Dictionary 59 | :param extraction_key: String 60 | :param expected_value: String 61 | 62 | :raises: ValueError 63 | """ 64 | extracted_value = validation_dict.get(extraction_key) 65 | _check_missing_field(validation_dict, extraction_key) 66 | 67 | if extracted_value != expected_value: 68 | LOGGER.error(f"Incorrect value found for '{extraction_key}' key") 69 | raise ValueError 70 | 71 | 72 | def _record_validation(record, source_bucket_name): 73 | """Validate record 74 | 75 | :param record: Dictionary 76 | :param source_bucket_name: String 77 | 78 | :rtype: Boolean 79 | """ 80 | # validate eventSource 81 | _validate_field(record, "eventSource", "aws:s3") 82 | 83 | # validate eventSource 84 | _check_missing_field(record, "eventName") 85 | if not record["eventName"].startswith("ObjectCreated"): 86 | LOGGER.warning("Found a non ObjectCreated event, ignoring this record") 87 | return False 88 | 89 | # check for 's3' in response elements 90 | _check_missing_field(record, "s3") 91 | 92 | s3_data = record["s3"] 93 | # validate s3 data 94 | _check_missing_field(s3_data, "bucket") 95 | _validate_field(s3_data["bucket"], "name", source_bucket_name) 96 | 97 | # check for object 98 | _check_missing_field(s3_data, "object") 99 | # check for key 100 | _check_missing_field(s3_data["object"], "key") 101 | 102 | return True 103 | 104 | 105 | def _get_source_file_contents(client, bucket, filename): 106 | """Fetch the contents of the file 107 | 108 | :param client: boto3 Client Object (S3) 109 | :param bucket: String 110 | :param filename: String 111 | 112 | :rtype Bytes 113 | """ 114 | resp = client.get_object(Bucket=bucket, Key=filename) 115 | 116 | _check_missing_field(resp, "ResponseMetadata") 117 | 118 | 
_validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200) 119 | 120 | _check_missing_field(resp, "Body") 121 | 122 | return resp["Body"].read() 123 | 124 | 125 | def lambda_handler(event, context): 126 | """What executes when the program is run""" 127 | 128 | # configure python logger for Lambda 129 | _configure_logger() 130 | # silence chatty libraries for better logging 131 | _silence_noisy_loggers() 132 | 133 | # check for SOURCE_BUCKET_NAME env var 134 | source_bucket_name = os.environ.get(SOURCE_BUCKET_ENV_VAR) 135 | if not source_bucket_name: 136 | raise MissingEnvironmentVariable(SOURCE_BUCKET_ENV_VAR) 137 | LOGGER.info(f"source bucket: {source_bucket_name}") 138 | 139 | # check for DESTINATION_BUCKET_NAME env var 140 | dest_bucket_name = os.environ.get(DEST_BUCKET_ENV_VAR) 141 | if not dest_bucket_name: 142 | raise MissingEnvironmentVariable(DEST_BUCKET_ENV_VAR) 143 | LOGGER.info(f"destination bucket: {dest_bucket_name}") 144 | 145 | # check for 'records' field in the event 146 | _check_missing_field(event, "Records") 147 | records = event["Records"] 148 | 149 | if not isinstance(records, list): 150 | raise Exception("'Records' is not a list") 151 | LOGGER.info("Extracted 'Records' from the event") 152 | 153 | for record in records: 154 | try: 155 | valid_record = _record_validation(record, source_bucket_name) 156 | except KeyError: 157 | LOGGER.warning(MALFORMED_EVENT_LOG_MSG) 158 | continue 159 | except ValueError: 160 | LOGGER.warning(MALFORMED_EVENT_LOG_MSG) 161 | continue 162 | 163 | if not valid_record: 164 | LOGGER.warning("record could not be validated. Skipping this one.") 165 | continue 166 | file_name = record["s3"]["object"]["key"] 167 | LOGGER.info( 168 | f"Valid record found. Will attempt to process the file: {file_name}") 169 | 170 | s3_client = boto3.client("s3") 171 | file_contents = _get_source_file_contents( 172 | s3_client, source_bucket_name, file_name) 173 | 174 | try: 175 | pdf = PdfReader(BytesIO(file_contents)) 176 | except PdfReadError as err: 177 | LOGGER.error(err) 178 | LOGGER.warning( 179 | f"{file_name} is invalid and/or corrupt, skipping.") 180 | continue 181 | 182 | LOGGER.info("Extracting text from pdf..") 183 | text_file_contents = "" 184 | for page in pdf.pages: 185 | text_file_contents = f"{text_file_contents}\n{page.extract_text()}" 186 | 187 | s3_resource = boto3.resource("s3") 188 | LOGGER.debug("Writing file to S3") 189 | s3_resource.Bucket(dest_bucket_name).put_object( 190 | Key=file_name.replace(".pdf", ".txt"), 191 | Body=text_file_contents.encode("utf-8")) 192 | LOGGER.info("Successfully converted pdf to txt, and uploaded to s3") 193 | 194 | LOGGER.debug("Closing s3 boto3 client") 195 | s3_client.close() 196 | -------------------------------------------------------------------------------- /lambda/pdf-processor/requirements.txt: -------------------------------------------------------------------------------- 1 | pypdf 2 | -------------------------------------------------------------------------------- /lambda/pgvector-trigger/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | 7 | 8 | LOGGER = logging.getLogger() 9 | 10 | QUEUE_URL_ENV_VAR = "PGVECTOR_UPDATE_QUEUE" 11 | BUCKET_ENV_VAR = "BUCKET_NAME" 12 | 13 | MALFORMED_EVENT_LOG_MSG = "Malformed event. Skipping this record." 
14 | 15 | 16 | class MalformedEvent(Exception): 17 | """Raised if a malformed event received""" 18 | 19 | 20 | class MissingEnvironmentVariable(Exception): 21 | """Raised if a required environment variable is missing""" 22 | 23 | 24 | def _silence_noisy_loggers(): 25 | """Silence chatty libraries for better logging""" 26 | for logger in ['boto3', 'botocore', 27 | 'botocore.vendored.requests.packages.urllib3']: 28 | logging.getLogger(logger).setLevel(logging.WARNING) 29 | 30 | 31 | def _configure_logger(): 32 | """Configure python logger""" 33 | level = logging.INFO 34 | verbose = os.environ.get("VERBOSE", "") 35 | if verbose.lower() == "true": 36 | print("Will set the logging output to DEBUG") 37 | level = logging.DEBUG 38 | 39 | if len(logging.getLogger().handlers) > 0: 40 | # The Lambda environment pre-configures a handler logging to stderr. 41 | # If a handler is already configured, `.basicConfig` does not execute. 42 | # Thus we set the level directly. 43 | logging.getLogger().setLevel(level) 44 | else: 45 | logging.basicConfig(level=level) 46 | 47 | 48 | def _check_missing_field(validation_dict, extraction_key): 49 | """Check if a field exists in a dictionary 50 | 51 | :param validation_dict: Dictionary 52 | :param extraction_key: String 53 | 54 | :raises: MalformedEvent 55 | """ 56 | extracted_value = validation_dict.get(extraction_key) 57 | 58 | if not extracted_value: 59 | LOGGER.error(f"Missing '{extraction_key}' field in the event") 60 | raise MalformedEvent 61 | 62 | 63 | def _validate_field(validation_dict, extraction_key, expected_value): 64 | """Validate the passed in field 65 | 66 | :param validation_dict: Dictionary 67 | :param extraction_key: String 68 | :param expected_value: String 69 | 70 | :raises: ValueError 71 | """ 72 | extracted_value = validation_dict.get(extraction_key) 73 | 74 | _check_missing_field(validation_dict, extraction_key) 75 | 76 | if extracted_value != expected_value: 77 | LOGGER.error(f"Incorrect value found for '{extraction_key}' field") 78 | raise ValueError 79 | 80 | 81 | def _record_validation(record, source_bucket_name): 82 | """Validate record 83 | 84 | :param record: Dictionary 85 | :param source_bucket_name: String 86 | 87 | :rtype: Boolean 88 | """ 89 | # validate eventSource 90 | _validate_field(record, "eventSource", "aws:s3") 91 | 92 | # validate eventSource 93 | _check_missing_field(record, "eventName") 94 | if not record["eventName"].startswith("ObjectCreated"): 95 | LOGGER.warning( 96 | "Found a non ObjectCreated event, ignoring this record") 97 | return False 98 | 99 | # check for 's3' in response elements 100 | _check_missing_field(record, "s3") 101 | 102 | s3_data = record["s3"] 103 | # validate s3 data 104 | _check_missing_field(s3_data, "bucket") 105 | _validate_field(s3_data["bucket"], "name", source_bucket_name) 106 | 107 | # check for object 108 | _check_missing_field(s3_data, "object") 109 | # check for key 110 | _check_missing_field(s3_data["object"], "key") 111 | 112 | return True 113 | 114 | 115 | def _send_message_to_sqs(client, queue_url, message_dict): 116 | """Send message to SQS Queue 117 | 118 | :param client: Boto3 client object (SQS) 119 | :param queue_url: String 120 | :param message_dict: Dictionary 121 | 122 | :raises: Exception 123 | """ 124 | LOGGER.info(f"Attempting to send message to: {queue_url}") 125 | resp = client.send_message( 126 | QueueUrl=queue_url, 127 | MessageBody=json.dumps(message_dict) 128 | ) 129 | 130 | _check_missing_field(resp, "ResponseMetadata") 131 | resp_metadata = 
resp["ResponseMetadata"] 132 | 133 | _check_missing_field(resp_metadata, "HTTPStatusCode") 134 | status_code = resp_metadata["HTTPStatusCode"] 135 | 136 | if status_code == 200: 137 | LOGGER.info("Successfully pushed message") 138 | else: 139 | raise Exception("Unable to push message") 140 | 141 | 142 | def lambda_handler(event, context): 143 | """What executes when the program is run""" 144 | 145 | # configure python logger 146 | _configure_logger() 147 | # silence chatty libraries 148 | _silence_noisy_loggers() 149 | 150 | # check for DESTINATION_BUCKET_NAME env var 151 | bucket_name = os.environ.get(BUCKET_ENV_VAR) 152 | if not bucket_name: 153 | raise MissingEnvironmentVariable(BUCKET_ENV_VAR) 154 | LOGGER.info(f"destination bucket: {bucket_name}") 155 | 156 | queue_url = os.environ.get(QUEUE_URL_ENV_VAR) 157 | if not queue_url: 158 | raise MissingEnvironmentVariable( 159 | f"{QUEUE_URL_ENV_VAR} environment variable is required") 160 | 161 | # check for 'records' field in the event 162 | _check_missing_field(event, "Records") 163 | records = event["Records"] 164 | 165 | if not isinstance(records, list): 166 | raise Exception("'Records' is not a list") 167 | LOGGER.info("Extracted 'Records' from the event") 168 | 169 | sqs_client = boto3.client("sqs") 170 | for record in records: 171 | try: 172 | valid_record = _record_validation(record, bucket_name) 173 | except KeyError: 174 | LOGGER.warning(MALFORMED_EVENT_LOG_MSG) 175 | continue 176 | except ValueError: 177 | LOGGER.warning(MALFORMED_EVENT_LOG_MSG) 178 | continue 179 | if not valid_record: 180 | LOGGER.warning( 181 | "record could not be validated. Skipping this one.") 182 | continue 183 | 184 | _send_message_to_sqs( 185 | sqs_client, 186 | queue_url, 187 | { 188 | "bucket": bucket_name, 189 | "file": record["s3"]["object"]["key"] 190 | } 191 | ) 192 | sqs_client.close() 193 | -------------------------------------------------------------------------------- /lambda/pgvector-update/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lambda/python:3.11 2 | 3 | RUN yum update -y \ 4 | && yum install -y postgresql-libs gcc postgresql-devel \ 5 | && pip3 install psycopg2 6 | 7 | RUN yum install -y amazon-linux-extras \ 8 | && yum repolist \ 9 | && PYTHON=python2 amazon-linux-extras install postgresql10 -y \ 10 | && pip3 install psycopg2 \ 11 | && pip3 install psycopg2-binary 12 | 13 | COPY requirements.txt requirements.txt 14 | 15 | RUN pip3 install -r requirements.txt 16 | 17 | COPY app.py app.py 18 | 19 | CMD [ "app.lambda_handler"] 20 | -------------------------------------------------------------------------------- /lambda/pgvector-update/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | from botocore.exceptions import ClientError 7 | from langchain.document_loaders import S3FileLoader 8 | from langchain.embeddings.openai import OpenAIEmbeddings 9 | from langchain.vectorstores.pgvector import PGVector 10 | 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | 14 | SQS_QUEUE_ENV_VAR = "QUEUE_URL" 15 | COLLECTION_ENV_VAR = "COLLECTION_NAME" 16 | API_KEY_SECRET_ENV_VAR = "API_KEY_SECRET_NAME" 17 | DB_SECRET_ENV_VAR = "DB_CREDS" 18 | 19 | 20 | class MalformedEvent(Exception): 21 | """Raised if a malformed event received""" 22 | 23 | 24 | class MissingEnvironmentVariable(Exception): 25 | """Raised if a required environment variable is missing""" 26 | 27 | 28 | def 
_silence_noisy_loggers(): 29 | """Silence chatty libraries for better logging""" 30 | for logger in ['boto3', 'botocore', 31 | 'botocore.vendored.requests.packages.urllib3']: 32 | logging.getLogger(logger).setLevel(logging.WARNING) 33 | 34 | 35 | def _configure_logger(): 36 | """Configure python logger for lambda function""" 37 | default_log_args = { 38 | "level": logging.DEBUG if os.environ.get("VERBOSE", False) else logging.INFO, 39 | "format": "%(asctime)s [%(levelname)s] %(name)s - %(message)s", 40 | "datefmt": "%d-%b-%y %H:%M", 41 | "force": True, 42 | } 43 | logging.basicConfig(**default_log_args) 44 | 45 | 46 | def _check_missing_field(validation_dict, extraction_key): 47 | """Check if a field exists in a dictionary 48 | 49 | :param validation_dict: Dictionary 50 | :param extraction_key: String 51 | 52 | :raises: KeyError 53 | """ 54 | extracted_value = validation_dict.get(extraction_key) 55 | 56 | if not extracted_value: 57 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 58 | raise KeyError 59 | 60 | 61 | def _get_message_body(event): 62 | """Extract message body from the event 63 | 64 | :param event: Dictionary 65 | 66 | :raises: MalformedEvent 67 | 68 | :rtype: Dictionary 69 | """ 70 | body = "" 71 | test_event = event.get("test_event", "") 72 | if test_event.lower() == "true": 73 | LOGGER.info("processing test event (and not from SQS)") 74 | LOGGER.debug("Test body: %s", event) 75 | return event 76 | else: 77 | LOGGER.info("Attempting to extract message body from SQS") 78 | 79 | _check_missing_field(event, "Records") 80 | records = event["Records"] 81 | 82 | first_record = records[0] 83 | 84 | try: 85 | body = first_record.get("body") 86 | except AttributeError: 87 | raise MalformedEvent("First record is not a proper dict") 88 | 89 | if not body: 90 | raise MalformedEvent("Missing 'body' in the record") 91 | 92 | try: 93 | return json.loads(body) 94 | except json.decoder.JSONDecodeError: 95 | raise MalformedEvent("'body' is not valid JSON") 96 | 97 | 98 | def _get_sqs_message_attributes(event): 99 | """Extract receiptHandle from message 100 | 101 | :param event: Dictionary 102 | 103 | :raises: MalformedEvent 104 | 105 | :rtype: Dictionary 106 | """ 107 | LOGGER.info("Attempting to extract receiptHandle from SQS") 108 | records = event.get("Records") 109 | if not records: 110 | LOGGER.warning("No receiptHandle found, probably not an SQS message") 111 | return 112 | try: 113 | first_record = records[0] 114 | except IndexError: 115 | raise MalformedEvent("Records seem to be empty") 116 | 117 | _check_missing_field(first_record, "receiptHandle") 118 | receipt_handle = first_record["receiptHandle"] 119 | 120 | _check_missing_field(first_record, "messageId") 121 | message_id = first_record["messageId"] 122 | 123 | return { 124 | "message_id": message_id, 125 | "receipt_handle": receipt_handle 126 | } 127 | 128 | 129 | def get_secret_from_name(secret_name, kv=True): 130 | """Return secret from secret name 131 | 132 | :param secret_name: String 133 | :param kv: Boolean (weather it is json or not) 134 | 135 | :raises: botocore.exceptions.ClientError 136 | 137 | :rtype: Dictionary 138 | """ 139 | session = boto3.session.Session() 140 | 141 | # Initializing Secret Manager's client 142 | client = session.client( 143 | service_name='secretsmanager', 144 | region_name=os.environ.get("AWS_REGION", session.region_name) 145 | ) 146 | LOGGER.info(f"Attempting to get secret value for: {secret_name}") 147 | try: 148 | get_secret_value_response = client.get_secret_value( 149 | 
SecretId=secret_name) 150 | except ClientError as e: 151 | # For a list of exceptions thrown, see 152 | # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html 153 | LOGGER.error("Unable to fetch details from Secrets Manager") 154 | raise e 155 | 156 | _check_missing_field( 157 | get_secret_value_response, "SecretString") 158 | 159 | if kv: 160 | return json.loads( 161 | get_secret_value_response["SecretString"]) 162 | else: 163 | return get_secret_value_response["SecretString"] 164 | 165 | 166 | def lambda_handler(event, context): 167 | """What executes when the program is run""" 168 | 169 | # configure python logger for Lambda 170 | _configure_logger() 171 | # silence chatty libraries for better logging 172 | _silence_noisy_loggers() 173 | 174 | msg_attr = _get_sqs_message_attributes(event) 175 | 176 | if msg_attr: 177 | 178 | # Because messages remain in the queue 179 | LOGGER.info( 180 | f"Deleting message {msg_attr['message_id']} from sqs") 181 | sqs_client = boto3.client("sqs") 182 | queue_url = os.environ.get(SQS_QUEUE_ENV_VAR) 183 | if not queue_url: 184 | raise MissingEnvironmentVariable( 185 | f"{SQS_QUEUE_ENV_VAR} environment variable is required") 186 | 187 | deletion_resp = sqs_client.delete_message( 188 | QueueUrl=queue_url, 189 | ReceiptHandle=msg_attr["receipt_handle"]) 190 | 191 | sqs_client.close() 192 | 193 | resp_metadata = deletion_resp.get("ResponseMetadata") 194 | if not resp_metadata: 195 | raise Exception( 196 | "No response metadata from deletion call") 197 | status_code = resp_metadata.get("HTTPStatusCode") 198 | 199 | if status_code == 200: 200 | LOGGER.info(f"Successfully deleted message") 201 | else: 202 | raise Exception("Unable to delete message") 203 | 204 | body = _get_message_body(event) 205 | 206 | _check_missing_field(body, "bucket") 207 | _check_missing_field(body, "file") 208 | 209 | secret_name = os.environ.get(DB_SECRET_ENV_VAR) 210 | if not secret_name: 211 | raise MissingEnvironmentVariable( 212 | f"{DB_SECRET_ENV_VAR} environment variable is required") 213 | 214 | db_secret_dict = get_secret_from_name(secret_name) 215 | conn_string = PGVector.connection_string_from_db_params( 216 | driver=os.environ.get("PGVECTOR_DRIVER", "psycopg2"), 217 | host=db_secret_dict["host"], 218 | port=db_secret_dict["port"], 219 | database=os.environ.get("PGVECTOR_DATABASE", "postgres"), 220 | user=db_secret_dict["username"], 221 | password=db_secret_dict["password"], 222 | ) 223 | collection = os.environ.get(COLLECTION_ENV_VAR) 224 | if not collection: 225 | raise MissingEnvironmentVariable( 226 | f"{COLLECTION_ENV_VAR} environment variable is required") 227 | 228 | openai_secret = os.environ.get(API_KEY_SECRET_ENV_VAR) 229 | if not openai_secret: 230 | raise MissingEnvironmentVariable( 231 | f"{API_KEY_SECRET_ENV_VAR} environment variable is required") 232 | os.environ["OPENAI_API_KEY"] = get_secret_from_name( 233 | openai_secret, kv=False) 234 | LOGGER.info("Fetching OpenAI embeddings") 235 | embeddings = OpenAIEmbeddings() 236 | 237 | LOGGER.info("Initializing vector store connection") 238 | store = PGVector( 239 | collection_name=collection, 240 | connection_string=conn_string, 241 | embedding_function=embeddings, 242 | ) 243 | 244 | LOGGER.info("Initializing S3FileLoader") 245 | loader = S3FileLoader(body['bucket'], body['file']) 246 | 247 | LOGGER.info( 248 | f"Loading document: {body['file']} from bucket: {body['bucket']}") 249 | docs = loader.load() 250 | 251 | LOGGER.info("Adding new document to the vector store") 252 | 
store.add_documents(docs) 253 | -------------------------------------------------------------------------------- /lambda/pgvector-update/requirements.txt: -------------------------------------------------------------------------------- 1 | pgvector 2 | openai 3 | tiktoken 4 | langchain 5 | unstructured 6 | nltk 7 | -------------------------------------------------------------------------------- /lambda/rds-ddl-change/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lambda/python:3.11 2 | 3 | RUN yum update -y \ 4 | && yum install -y postgresql-libs gcc postgresql-devel \ 5 | && pip3 install psycopg2 6 | 7 | RUN yum install -y amazon-linux-extras \ 8 | && yum repolist \ 9 | && PYTHON=python2 amazon-linux-extras install postgresql10 -y \ 10 | # && amazon-linux-extras install postgresql10 \ 11 | && pip3 install psycopg2 \ 12 | && pip3 install psycopg2-binary 13 | 14 | COPY requirements.txt requirements.txt 15 | 16 | RUN pip3 install -r requirements.txt 17 | 18 | COPY app.py app.py 19 | 20 | CMD [ "app.lambda_handler"] 21 | -------------------------------------------------------------------------------- /lambda/rds-ddl-change/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import time 5 | 6 | import boto3 7 | from botocore.exceptions import ClientError 8 | import queries 9 | 10 | 11 | DB_NAME_ENV_VAR = "DB_NAME" 12 | REGION_ENV_VAR = "AWS_REGION" 13 | DDL_SOURCE_BUCKET_ENV_VAR = "DDL_SOURCE_BUCKET" 14 | 15 | LOGGER = logging.getLogger() 16 | 17 | DDL_FILE = "rds-ddl.sql" 18 | 19 | DB_IDENTIFIER_KEY = "dBInstanceIdentifier" 20 | 21 | MALFORMED_EVENT_LOG_MSG = "Malformed event. Skipping this record." 
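# NOTE: this handler re-applies the DDL whenever the source SQL changes: it waits briefly
# for the DDL source bucket to sync, downloads DDL_FILE (rds-ddl.sql) from the
# "ddl-source-<db instance identifier>" bucket, resolves the database credentials from
# Secrets Manager, and executes each ";"-separated statement through a `queries` session
# (see lambda_handler below).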
22 | 23 | 24 | class MalformedEvent(Exception): 25 | """Raised if a malformed event received""" 26 | 27 | 28 | class MissingEnvironmentVariable(Exception): 29 | """Raised if a required environment variable is missing""" 30 | 31 | 32 | def _check_missing_field(validation_dict, extraction_key): 33 | """Check if a field exists in a dictionary 34 | 35 | :param validation_dict: Dictionary 36 | :param extraction_key: String 37 | 38 | :raises: MalformedEvent 39 | """ 40 | extracted_value = validation_dict.get(extraction_key) 41 | 42 | if not extracted_value: 43 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 44 | raise MalformedEvent 45 | 46 | 47 | def _validate_field(validation_dict, extraction_key, expected_value): 48 | """Validate the passed in field 49 | 50 | :param validation_dict: Dictionary 51 | :param extraction_key: String 52 | :param expected_value: String 53 | 54 | :raises: ValueError 55 | """ 56 | extracted_value = validation_dict.get(extraction_key) 57 | _check_missing_field(validation_dict, extraction_key) 58 | 59 | if extracted_value != expected_value: 60 | LOGGER.error(f"Incorrect value found for '{extraction_key}' key") 61 | raise ValueError 62 | 63 | 64 | def _silence_noisy_loggers(): 65 | """Silence chatty libraries for better logging""" 66 | for logger in ['boto3', 'botocore', 67 | 'botocore.vendored.requests.packages.urllib3']: 68 | logging.getLogger(logger).setLevel(logging.WARNING) 69 | 70 | 71 | def _configure_logger(): 72 | """Configure python logger""" 73 | level = logging.INFO 74 | verbose = os.environ.get("VERBOSE", "") 75 | if verbose.lower() == "true": 76 | print("Will set the logging output to DEBUG") 77 | level = logging.DEBUG 78 | 79 | if len(logging.getLogger().handlers) > 0: 80 | # The Lambda environment pre-configures a handler logging to stderr. 81 | # If a handler is already configured, `.basicConfig` does not execute. 82 | # Thus we set the level directly. 
83 | logging.getLogger().setLevel(level) 84 | else: 85 | logging.basicConfig(level=level) 86 | 87 | 88 | def _get_ddl_source_file_contents(client, bucket, filename): 89 | """Fetch the contents of the DDL SQL file 90 | 91 | :param client: boto3 Client Object (S3) 92 | :param bucket: String 93 | :param filename: String 94 | 95 | :raises: Exception 96 | 97 | :rtype String 98 | """ 99 | resp = client.get_object(Bucket=bucket, Key=filename) 100 | 101 | _check_missing_field(resp, "ResponseMetadata") 102 | 103 | _validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200) 104 | 105 | _check_missing_field(resp, "Body") 106 | body_obj = resp["Body"] 107 | 108 | return body_obj.read().decode("utf-8") 109 | 110 | 111 | def get_db_secret_from_secret_name(secret_name): 112 | """Return DB secret from secret name 113 | 114 | :param secret_name: String 115 | 116 | :raises: botocore.exceptions.ClientError 117 | 118 | :rtype: Dictionary 119 | """ 120 | session = boto3.session.Session() 121 | 122 | # Initializing Secret Manager's client 123 | client = session.client( 124 | service_name='secretsmanager', 125 | region_name=os.environ.get("AWS_REGION", session.region_name) 126 | ) 127 | LOGGER.info(f"Attempting to get secret value for: {secret_name}") 128 | try: 129 | get_secret_value_response = client.get_secret_value( 130 | SecretId=secret_name) 131 | except ClientError as e: 132 | # For a list of exceptions thrown, see 133 | # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html 134 | LOGGER.error("Unable to fetch details from Secrets Manager") 135 | raise e 136 | 137 | _check_missing_field( 138 | get_secret_value_response, "SecretString") 139 | 140 | try: 141 | return json.loads( 142 | get_secret_value_response["SecretString"]) 143 | except json.decoder.JSONDecodeError: 144 | LOGGER.warning("Secret value is not a valid dictionary") 145 | return {} 146 | 147 | 148 | def _fetch_secret_for_db(db_identifier): 149 | """Fetch the secret arn, name for the database 150 | 151 | :param db_identifier: String 152 | 153 | :rtype: Dictionary 154 | """ 155 | ret_dict = None 156 | sm_client = boto3.client("secretsmanager") 157 | 158 | resp = sm_client.list_secrets() 159 | 160 | _check_missing_field(resp, "ResponseMetadata") 161 | 162 | _validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200) 163 | 164 | _check_missing_field(resp, "SecretList") 165 | 166 | for secret in resp["SecretList"]: 167 | _check_missing_field(secret, "Name") 168 | db_secret = get_db_secret_from_secret_name(secret["Name"]) 169 | 170 | db_id = db_secret.get( 171 | # super annoying they didn't name it consistently 172 | DB_IDENTIFIER_KEY.replace("dBInstance", "dbInstance")) 173 | if not db_id: 174 | LOGGER.warning("No database ID fetched from secret name") 175 | continue 176 | 177 | if db_id == db_identifier: 178 | LOGGER.info("Found matching secret for the database") 179 | _check_missing_field(secret, "ARN") 180 | ret_dict = db_secret 181 | break 182 | 183 | sm_client.close() 184 | return ret_dict 185 | 186 | 187 | def lambda_handler(event, context): 188 | """What executes when the program is run""" 189 | 190 | # configure python logger for Lambda 191 | _configure_logger() 192 | # silence chatty libraries for better logging 193 | _silence_noisy_loggers() 194 | 195 | LOGGER.info("Waiting for DDL source to be updated..") 196 | time.sleep(120) 197 | 198 | ddl_source_file = os.environ.get("DDL_SOURCE_FILE_RDS", DDL_FILE) 199 | 200 | source_s3_bucket = os.environ.get(DDL_SOURCE_BUCKET_ENV_VAR) 201 | if not 
source_s3_bucket: 202 | raise MissingEnvironmentVariable(DDL_SOURCE_BUCKET_ENV_VAR) 203 | 204 | db_name = os.environ.get(DB_NAME_ENV_VAR) 205 | if not db_name: 206 | raise MissingEnvironmentVariable(DB_NAME_ENV_VAR) 207 | 208 | db_arn = "" 209 | db_id = source_s3_bucket.replace("ddl-source-", "") 210 | 211 | rds_client = boto3.client("rds") 212 | LOGGER.info("Attempting to get db arn from RDS") 213 | resp = rds_client.describe_db_instances(DBInstanceIdentifier=db_id) 214 | rds_client.close() 215 | 216 | _check_missing_field(resp, "ResponseMetadata") 217 | _validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200) 218 | 219 | print(resp) 220 | _check_missing_field(resp, "DBInstances") 221 | try: 222 | db_details = resp["DBInstances"][0] 223 | _check_missing_field(db_details, "DBInstanceArn") 224 | db_arn = db_details["DBInstanceArn"] 225 | except IndexError: 226 | LOGGER.error("No databases returned from the API call") 227 | 228 | if not db_arn: 229 | LOGGER.warning("Unable to fetch database arn from RDS API call." 230 | " Will attempt to infer it.") 231 | region = os.environ.get(REGION_ENV_VAR) 232 | if not region: 233 | raise MissingEnvironmentVariable(REGION_ENV_VAR) 234 | 235 | account_id = boto3.client('sts').get_caller_identity().get('Account') 236 | if not account_id: 237 | LOGGER.warning("Unable to fetch account_id from sts") 238 | else: 239 | db_arn = f"arn:aws:rds:{region}:{account_id}:db:{db_id}" 240 | 241 | if not db_arn: 242 | LOGGER.error("Unable to find a matching db ARN. Exiting.") 243 | raise Exception 244 | 245 | LOGGER.info(f"DB ARN: {db_arn}") 246 | 247 | secret_dict = _fetch_secret_for_db(db_id) 248 | if not secret_dict: 249 | LOGGER.error( 250 | f"No secret found associated with the db: {db_id}. Exiting") 251 | raise Exception 252 | 253 | s3_client = boto3.client("s3") 254 | file_content_string = _get_ddl_source_file_contents( 255 | s3_client, source_s3_bucket, ddl_source_file) 256 | s3_client.close() 257 | 258 | db_session = queries.Session( 259 | queries.uri( 260 | secret_dict["host"], 261 | int(secret_dict["port"]), 262 | db_name, 263 | secret_dict["username"], 264 | secret_dict["password"] 265 | ) 266 | ) 267 | 268 | with db_session as session: 269 | for sql in file_content_string.split(";"): 270 | # get rid of white spaces 271 | eff_sql = sql.strip(" \n\t") 272 | LOGGER.info(f"Executing: {eff_sql}") 273 | if eff_sql: 274 | results = session.query(eff_sql) 275 | print(results) 276 | -------------------------------------------------------------------------------- /lambda/rds-ddl-change/requirements.txt: -------------------------------------------------------------------------------- 1 | queries 2 | -------------------------------------------------------------------------------- /lambda/rds-ddl-init/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lambda/python:3.11 2 | 3 | RUN yum update -y \ 4 | && yum install -y postgresql-libs gcc postgresql-devel \ 5 | && pip3 install psycopg2 6 | 7 | RUN yum install -y amazon-linux-extras \ 8 | && yum repolist \ 9 | && PYTHON=python2 amazon-linux-extras install postgresql10 -y \ 10 | # && amazon-linux-extras install postgresql10 \ 11 | && pip3 install psycopg2 \ 12 | && pip3 install psycopg2-binary 13 | 14 | COPY requirements.txt requirements.txt 15 | 16 | RUN pip3 install -r requirements.txt 17 | 18 | COPY app.py app.py 19 | 20 | CMD [ "app.lambda_handler"] 21 | -------------------------------------------------------------------------------- 
/lambda/rds-ddl-init/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | from botocore.exceptions import ClientError 7 | import queries 8 | 9 | 10 | DB_NAME_ENV_VAR = "DB_NAME" 11 | SQS_QUEUE_ENV_VAR = "SQS_QUEUE_URL" 12 | DDL_SOURCE_BUCKET_ENV_VAR = "DDL_SOURCE_BUCKET" 13 | 14 | DB_IDENTIFIER_KEY = "dBInstanceIdentifier" 15 | 16 | LOGGER = logging.getLogger() 17 | 18 | DDL_FILE = "rds-ddl.sql" 19 | 20 | 21 | class MalformedEvent(Exception): 22 | """Raised if a malformed event received""" 23 | 24 | 25 | class MissingEnvironmentVariable(Exception): 26 | """Raised if a required environment variable is missing""" 27 | 28 | 29 | def _silence_noisy_loggers(): 30 | """Silence chatty libraries for better logging""" 31 | for logger in ['boto3', 'botocore', 32 | 'botocore.vendored.requests.packages.urllib3']: 33 | logging.getLogger(logger).setLevel(logging.WARNING) 34 | 35 | 36 | def _configure_logger(): 37 | """Configure python logger""" 38 | level = logging.INFO 39 | verbose = os.environ.get("VERBOSE", "") 40 | if verbose.lower() == "true": 41 | print("Will set the logging output to DEBUG") 42 | level = logging.DEBUG 43 | 44 | if len(logging.getLogger().handlers) > 0: 45 | # The Lambda environment pre-configures a handler logging to stderr. 46 | # If a handler is already configured, `.basicConfig` does not execute. 47 | # Thus we set the level directly. 48 | logging.getLogger().setLevel(level) 49 | else: 50 | logging.basicConfig(level=level) 51 | 52 | 53 | def _check_missing_field(validation_dict, extraction_key): 54 | """Check if a field exists in a dictionary 55 | 56 | :param validation_dict: Dictionary 57 | :param extraction_key: String 58 | 59 | :raises: MalformedEvent 60 | """ 61 | extracted_value = validation_dict.get(extraction_key) 62 | 63 | if not extracted_value: 64 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 65 | raise MalformedEvent 66 | 67 | 68 | def _validate_field(validation_dict, extraction_key, expected_value): 69 | """Validate the passed in field 70 | 71 | :param validation_dict: Dictionary 72 | :param extraction_key: String 73 | :param expected_value: String 74 | 75 | :raises: ValueError 76 | """ 77 | extracted_value = validation_dict.get(extraction_key) 78 | _check_missing_field(validation_dict, extraction_key) 79 | 80 | if extracted_value != expected_value: 81 | LOGGER.error(f"Incorrect value found for '{extraction_key}' key") 82 | raise ValueError 83 | 84 | 85 | def _get_message_body(event): 86 | """Extract message body from the event 87 | 88 | :param event: Dictionary 89 | 90 | :raises: MalformedEvent 91 | 92 | :rtype: Dictionary 93 | """ 94 | body = "" 95 | test_event = event.get("test_event", "") 96 | if test_event.lower() == "true": 97 | LOGGER.info("processing test event (and not from SQS)") 98 | LOGGER.debug("Test body: %s", event) 99 | return event 100 | else: 101 | LOGGER.info("Attempting to extract message body from SQS") 102 | 103 | _check_missing_field(event, "Records") 104 | records = event["Records"] 105 | 106 | first_record = records[0] 107 | 108 | try: 109 | body = first_record.get("body") 110 | except AttributeError: 111 | raise MalformedEvent("First record is not a proper dict") 112 | 113 | if not body: 114 | raise MalformedEvent("Missing 'body' in the record") 115 | 116 | try: 117 | return json.loads(body) 118 | except json.decoder.JSONDecodeError: 119 | raise MalformedEvent("'body' is not valid JSON") 120 | 121 | 122 | def 
_get_sqs_message_attributes(event): 123 | """Extract receiptHandle from message 124 | 125 | :param event: Dictionary 126 | 127 | :raises: MalformedEvent 128 | 129 | :rtype: Dictionary 130 | """ 131 | LOGGER.info("Attempting to extract receiptHandle from SQS") 132 | records = event.get("Records") 133 | if not records: 134 | LOGGER.warning("No receiptHandle found, probably not an SQS message") 135 | return 136 | try: 137 | first_record = records[0] 138 | except IndexError: 139 | raise MalformedEvent("Records seem to be empty") 140 | 141 | _check_missing_field(first_record, "receiptHandle") 142 | receipt_handle = first_record["receiptHandle"] 143 | 144 | _check_missing_field(first_record, "messageId") 145 | message_id = first_record["messageId"] 146 | 147 | return { 148 | "message_id": message_id, 149 | "receipt_handle": receipt_handle 150 | } 151 | 152 | 153 | def get_db_secret_from_secret_name(secret_name): 154 | """Return DB secret from secret name 155 | 156 | :param secret_name: String 157 | 158 | :raises: botocore.exceptions.ClientError 159 | 160 | :rtype: Dictionary 161 | """ 162 | session = boto3.session.Session() 163 | 164 | # Initializing Secret Manager's client 165 | client = session.client( 166 | service_name='secretsmanager', 167 | region_name=os.environ.get("AWS_REGION", session.region_name) 168 | ) 169 | LOGGER.info(f"Attempting to get secret value for: {secret_name}") 170 | try: 171 | get_secret_value_response = client.get_secret_value( 172 | SecretId=secret_name) 173 | except ClientError as e: 174 | # For a list of exceptions thrown, see 175 | # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html 176 | LOGGER.error("Unable to fetch details from Secrets Manager") 177 | raise e 178 | 179 | _check_missing_field( 180 | get_secret_value_response, "SecretString") 181 | try: 182 | return json.loads( 183 | get_secret_value_response["SecretString"]) 184 | except json.decoder.JSONDecodeError: 185 | LOGGER.warning("Secret value is not a valid dictionary") 186 | return {} 187 | 188 | 189 | def _fetch_secret_for_db(db_identifier): 190 | """Fetch the secret arn, name for the database 191 | 192 | :param db_identifier: String 193 | 194 | :rtype: Dictionary 195 | """ 196 | ret_dict = None 197 | sm_client = boto3.client("secretsmanager") 198 | 199 | resp = sm_client.list_secrets() 200 | 201 | _check_missing_field(resp, "ResponseMetadata") 202 | 203 | _validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200) 204 | 205 | _check_missing_field(resp, "SecretList") 206 | 207 | for secret in resp["SecretList"]: 208 | _check_missing_field(secret, "Name") 209 | db_secret = get_db_secret_from_secret_name(secret["Name"]) 210 | 211 | db_id = db_secret.get( 212 | # super annoying they didn't name it consistently 213 | DB_IDENTIFIER_KEY.replace("dBInstance", "dbInstance")) 214 | if not db_id: 215 | LOGGER.warning("No database ID fetched from secret name") 216 | continue 217 | 218 | if db_id == db_identifier: 219 | LOGGER.info("Found matching secret for the database") 220 | _check_missing_field(secret, "ARN") 221 | ret_dict = db_secret 222 | break 223 | 224 | sm_client.close() 225 | return ret_dict 226 | 227 | 228 | def _get_ddl_source_file_contents(client, bucket, filename): 229 | """Fetch the contents of the DDL SQL file 230 | 231 | :param client: boto3 Client Object (S3) 232 | :param bucket: String 233 | :param filename: String 234 | 235 | :raises: Exception 236 | 237 | :rtype String 238 | """ 239 | resp = client.get_object(Bucket=bucket, Key=filename) 240 | 241 | 
_check_missing_field(resp, "ResponseMetadata") 242 | 243 | _validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200) 244 | 245 | _check_missing_field(resp, "Body") 246 | body_obj = resp["Body"] 247 | 248 | return body_obj.read().decode("utf-8") 249 | 250 | 251 | def lambda_handler(event, context): 252 | """What executes when the program is run""" 253 | 254 | # configure python logger for Lambda 255 | _configure_logger() 256 | # silence chatty libraries for better logging 257 | _silence_noisy_loggers() 258 | 259 | msg_attr = _get_sqs_message_attributes(event) 260 | 261 | if msg_attr: 262 | 263 | # Because messages remain in the queue until explicitly deleted 264 | LOGGER.info( 265 | f"Deleting message {msg_attr['message_id']} from sqs") 266 | sqs_client = boto3.client("sqs") 267 | queue_url = os.environ.get(SQS_QUEUE_ENV_VAR) 268 | if not queue_url: 269 | raise MissingEnvironmentVariable( 270 | f"{SQS_QUEUE_ENV_VAR} environment variable is required") 271 | 272 | deletion_resp = sqs_client.delete_message( 273 | QueueUrl=queue_url, 274 | ReceiptHandle=msg_attr["receipt_handle"]) 275 | 276 | sqs_client.close() 277 | 278 | resp_metadata = deletion_resp.get("ResponseMetadata") 279 | if not resp_metadata: 280 | raise Exception( 281 | "No response metadata from deletion call") 282 | status_code = resp_metadata.get("HTTPStatusCode") 283 | 284 | if status_code == 200: 285 | LOGGER.info("Successfully deleted message") 286 | else: 287 | raise Exception("Unable to delete message") 288 | 289 | body = _get_message_body(event) 290 | 291 | _check_missing_field(body, DB_IDENTIFIER_KEY) 292 | cluster_id = body[DB_IDENTIFIER_KEY] 293 | LOGGER.info(f"cluster id: {cluster_id}") 294 | 295 | source_s3_bucket = os.environ.get(DDL_SOURCE_BUCKET_ENV_VAR) 296 | if not source_s3_bucket: 297 | raise MissingEnvironmentVariable(DDL_SOURCE_BUCKET_ENV_VAR) 298 | 299 | if cluster_id.lower() not in source_s3_bucket.lower(): 300 | LOGGER.warning( 301 | "DDL Source bucket name does not contain database ID. Exiting.") 302 | return 303 | 304 | # TODO: make more robust 305 | body_dbname = body.get("databaseName") 306 | if not body_dbname: 307 | # Other database engines may have something different 308 | # this will need some more thought to make it more resilient 309 | LOGGER.warning( 310 | "No databaseName found in the CreateDBInstance event body") 311 | body_dbname = "information_schema" 312 | 313 | env_db_name = os.environ.get(DB_NAME_ENV_VAR) 314 | if not env_db_name: 315 | LOGGER.info( 316 | f"{DB_NAME_ENV_VAR} environment variable is not supplied") 317 | db_name = body_dbname 318 | else: 319 | LOGGER.warning( 320 | f"{DB_NAME_ENV_VAR} environment variable will be used as dbname") 321 | db_name = env_db_name 322 | 323 | secret_dict = _fetch_secret_for_db(cluster_id) 324 | if not secret_dict: 325 | LOGGER.error( 326 | f"No secret found associated with the cluster: {cluster_id}. 
Exiting") 327 | raise Exception 328 | 329 | ddl_source_file = os.environ.get("DDL_SOURCE_FILE_RDS", DDL_FILE) 330 | 331 | s3_client = boto3.client("s3") 332 | file_content_string = _get_ddl_source_file_contents( 333 | s3_client, source_s3_bucket, ddl_source_file) 334 | s3_client.close() 335 | 336 | sql_statements = file_content_string.split(";") 337 | 338 | db_session = queries.Session( 339 | queries.uri( 340 | secret_dict["host"], 341 | int(secret_dict["port"]), 342 | db_name, 343 | secret_dict["username"], 344 | secret_dict["password"] 345 | ) 346 | ) 347 | 348 | with db_session as session: 349 | for sql in sql_statements: 350 | # get rid of white spaces 351 | eff_sql = sql.strip(" \n\t") 352 | LOGGER.info(f"Executing: {eff_sql}") 353 | if eff_sql: 354 | results = session.query(eff_sql) 355 | print(results) 356 | -------------------------------------------------------------------------------- /lambda/rds-ddl-init/requirements.txt: -------------------------------------------------------------------------------- 1 | queries 2 | -------------------------------------------------------------------------------- /lambda/rds-ddl-trigger/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | 7 | 8 | LOGGER = logging.getLogger() 9 | 10 | RDS_DDL_QUEUE_URL_ENV_VAR = "RDS_DDL_QUEUE_URL" 11 | 12 | 13 | class MalformedEvent(Exception): 14 | """Raised if a malformed event received""" 15 | 16 | 17 | class MissingEnvironmentVariable(Exception): 18 | """Raised if a required environment variable is missing""" 19 | 20 | 21 | def _silence_noisy_loggers(): 22 | """Silence chatty libraries for better logging""" 23 | for logger in ['boto3', 'botocore', 24 | 'botocore.vendored.requests.packages.urllib3']: 25 | logging.getLogger(logger).setLevel(logging.WARNING) 26 | 27 | 28 | def _configure_logger(): 29 | """Configure python logger""" 30 | level = logging.INFO 31 | verbose = os.environ.get("VERBOSE", "") 32 | if verbose.lower() == "true": 33 | print("Will set the logging output to DEBUG") 34 | level = logging.DEBUG 35 | 36 | if len(logging.getLogger().handlers) > 0: 37 | # The Lambda environment pre-configures a handler logging to stderr. 38 | # If a handler is already configured, `.basicConfig` does not execute. 39 | # Thus we set the level directly. 
40 | logging.getLogger().setLevel(level) 41 | else: 42 | logging.basicConfig(level=level) 43 | 44 | 45 | def _check_missing_field(validation_dict, extraction_key): 46 | """Check if a field exists in a dictionary 47 | 48 | :param validation_dict: Dictionary 49 | :param extraction_key: String 50 | 51 | :raises: MalformedEvent 52 | """ 53 | extracted_value = validation_dict.get(extraction_key) 54 | 55 | if not extracted_value: 56 | LOGGER.error(f"Missing '{extraction_key}' field in the event") 57 | raise MalformedEvent 58 | 59 | 60 | def _validate_field(validation_dict, extraction_key, expected_value): 61 | """Validate the passed in field 62 | 63 | :param validation_dict: Dictionary 64 | :param extraction_key: String 65 | :param expected_value: String 66 | 67 | :raises: ValueError 68 | """ 69 | extracted_value = validation_dict.get(extraction_key) 70 | 71 | _check_missing_field(validation_dict, extraction_key) 72 | 73 | if extracted_value != expected_value: 74 | LOGGER.error(f"Incorrect value found for '{extraction_key}' field") 75 | raise ValueError 76 | 77 | 78 | def _extract_valid_event(event): 79 | """Validate incoming event and extract necessary attributes 80 | 81 | :param event: Dictionary 82 | 83 | :raises: MalformedEvent 84 | :raises: ValueError 85 | 86 | :rtype: Dictionary 87 | """ 88 | valid_event = {} 89 | 90 | _validate_field(event, "source", "aws.rds") 91 | 92 | _check_missing_field(event, "detail") 93 | event_detail = event["detail"] 94 | 95 | _validate_field(event_detail, "eventName", "CreateDBInstance") 96 | 97 | _check_missing_field(event_detail, "responseElements") 98 | 99 | return event_detail["responseElements"] 100 | 101 | 102 | def _send_message_to_sqs(client, queue_url, message_dict): 103 | """Send message to SQS Queue 104 | 105 | :param client: Boto3 client object (SQS) 106 | :param queue_url: String 107 | :param message_dict: Dictionary 108 | 109 | :raises: Exception 110 | """ 111 | LOGGER.info(f"Attempting to send message to: {queue_url}") 112 | resp = client.send_message( 113 | QueueUrl=queue_url, 114 | MessageBody=json.dumps(message_dict) 115 | ) 116 | 117 | _check_missing_field(resp, "ResponseMetadata") 118 | resp_metadata = resp["ResponseMetadata"] 119 | 120 | _check_missing_field(resp_metadata, "HTTPStatusCode") 121 | status_code = resp_metadata["HTTPStatusCode"] 122 | 123 | if status_code == 200: 124 | LOGGER.info("Successfully pushed message") 125 | else: 126 | raise Exception("Unable to push message") 127 | 128 | 129 | def lambda_handler(event, context): 130 | """What executes when the program is run""" 131 | 132 | # configure python logger 133 | _configure_logger() 134 | # silence chatty libraries 135 | _silence_noisy_loggers() 136 | 137 | valid_event = _extract_valid_event(event) 138 | LOGGER.info("Extracted data to send to SQS") 139 | 140 | sqs_client = boto3.client("sqs") 141 | rds_ddl_queue_url = os.environ.get(RDS_DDL_QUEUE_URL_ENV_VAR) 142 | if not rds_ddl_queue_url: 143 | raise MissingEnvironmentVariable( 144 | f"{RDS_DDL_QUEUE_URL_ENV_VAR} environment variable is required") 145 | 146 | # send message to DDL Triggering Queue 147 | _send_message_to_sqs( 148 | sqs_client, 149 | rds_ddl_queue_url, 150 | valid_event) 151 | 152 | sqs_client.close() 153 | -------------------------------------------------------------------------------- /lib/base-infra-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from 
'aws-cdk-lib/aws-ec2'; 4 | import * as s3 from "aws-cdk-lib/aws-s3"; 5 | import * as iam from "aws-cdk-lib/aws-iam"; 6 | import * as s3deploy from 'aws-cdk-lib/aws-s3-deployment'; 7 | import * as lambda from 'aws-cdk-lib/aws-lambda'; 8 | import * as s3notif from 'aws-cdk-lib/aws-s3-notifications'; 9 | import * as sqs from 'aws-cdk-lib/aws-sqs'; 10 | import * as events from 'aws-cdk-lib/aws-events'; 11 | import * as targets from 'aws-cdk-lib/aws-events-targets'; 12 | import * as cognito from 'aws-cdk-lib/aws-cognito'; 13 | import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; 14 | import * as elbv2 from "aws-cdk-lib/aws-elasticloadbalancingv2"; 15 | import * as elbv2_actions from "aws-cdk-lib/aws-elasticloadbalancingv2-actions"; 16 | import { SqsEventSource } from 'aws-cdk-lib/aws-lambda-event-sources'; 17 | 18 | import path = require("path"); 19 | 20 | export class BaseInfraStack extends cdk.Stack { 21 | readonly vpc: ec2.Vpc; 22 | readonly lambdaSG: ec2.SecurityGroup; 23 | readonly ecsTaskSecGroup: ec2.SecurityGroup; 24 | readonly knowledgeBaseBucket: s3.Bucket; 25 | readonly processedBucket: s3.Bucket; 26 | readonly rdsDdlTriggerQueue: sqs.Queue; 27 | readonly pgvectorQueue: sqs.Queue; 28 | readonly pgvectorCollectionName: string; 29 | readonly apiKeySecret: secretsmanager.Secret; 30 | readonly appTargetGroup: elbv2.ApplicationTargetGroup; 31 | readonly ec2SecGroup: ec2.SecurityGroup; 32 | 33 | constructor(scope: Construct, id: string, props?: cdk.StackProps) { 34 | super(scope, id, props); 35 | 36 | /* 37 | capturing region env var to know which region to deploy this infrastructure 38 | 39 | NOTE - the AWS profile that is used to deploy should have the same default region 40 | */ 41 | let validRegions: string[] = ['us-east-1', 'us-west-2']; 42 | const regionPrefix = process.env.CDK_DEFAULT_REGION || 'us-east-1'; 43 | console.log(`CDK_DEFAULT_REGION: ${regionPrefix}`); 44 | // throw error if unsupported CDK_DEFAULT_REGION specified 45 | if (!(validRegions.includes(regionPrefix))) { 46 | throw new Error('Unsupported CDK_DEFAULT_REGION specified') 47 | }; 48 | 49 | // collection name used by the vector store (used to update and retrieve content) 50 | this.pgvectorCollectionName = `pgvector-collection-${regionPrefix}-${this.account}` 51 | 52 | // create VPC to deploy the infrastructure in 53 | const vpc = new ec2.Vpc(this, "InfraNetwork", { 54 | ipAddresses: ec2.IpAddresses.cidr('10.80.0.0/20'), 55 | availabilityZones: [`${regionPrefix}a`, `${regionPrefix}b`, `${regionPrefix}c`], 56 | subnetConfiguration: [ 57 | { 58 | name: "public", 59 | subnetType: ec2.SubnetType.PUBLIC, 60 | }, 61 | { 62 | name: "private", 63 | subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS, 64 | } 65 | ], 66 | }); 67 | this.vpc = vpc; 68 | 69 | // create bucket for knowledgeBase 70 | const docsBucket = new s3.Bucket(this, `knowledgeBase`, {}); 71 | this.knowledgeBaseBucket = docsBucket; 72 | // use s3 bucket deploy to upload documents from local repo to the knowledgebase bucket 73 | new s3deploy.BucketDeployment(this, 'knowledgeBaseBucketDeploy', { 74 | sources: [s3deploy.Source.asset(path.join(__dirname, "../knowledgebase"))], 75 | destinationBucket: docsBucket 76 | }); 77 | 78 | // create bucket for processed text (from PDF to txt) 79 | const processedTextBucket = new s3.Bucket(this, `processedText`, {}); 80 | this.processedBucket = processedTextBucket; 81 | 82 | // capturing architecture for docker container (arm or x86) 83 | const dockerPlatform = process.env["DOCKER_CONTAINER_PLATFORM_ARCH"] 84 | 85 | 
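    // A minimal usage sketch (an assumption, not taken from the repo docs): the architecture
    // toggle above is resolved from the shell environment at synth/deploy time, e.g.
    //   export DOCKER_CONTAINER_PLATFORM_ARCH=arm   # any other value, or unset, falls back to x86_64
    //   npx cdk deploy --all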
// Docker assets for lambda function 86 | const dockerfile = path.join(__dirname, "../lambda/pdf-processor/"); 87 | // create a Lambda function to process knowledgebase pdf documents 88 | const lambdaFn = new lambda.Function(this, "pdfProcessorFn", { 89 | code: lambda.Code.fromAssetImage(dockerfile), 90 | handler: lambda.Handler.FROM_IMAGE, 91 | runtime: lambda.Runtime.FROM_IMAGE, 92 | timeout: cdk.Duration.minutes(15), 93 | memorySize: 512, 94 | architecture: dockerPlatform == "arm" ? lambda.Architecture.ARM_64 : lambda.Architecture.X86_64, 95 | environment: { 96 | "SOURCE_BUCKET_NAME": docsBucket.bucketName, 97 | "DESTINATION_BUCKET_NAME": processedTextBucket.bucketName 98 | } 99 | }); 100 | // grant lambda function permissions to read knowledgebase bucket 101 | docsBucket.grantRead(lambdaFn); 102 | // grant lambda function permissions to write to the processed text bucket 103 | processedTextBucket.grantWrite(lambdaFn); 104 | 105 | // create a new S3 notification that triggers the pdf processor lambda function 106 | const kbNotification = new s3notif.LambdaDestination(lambdaFn); 107 | // assign notification for the s3 event type 108 | docsBucket.addEventNotification(s3.EventType.OBJECT_CREATED, kbNotification); 109 | 110 | // Queue for triggering initialization (DDL deployment) of RDS 111 | const rdsDdlDetectionQueue = new sqs.Queue(this, 'rdsDdlDetectionQueue', { 112 | queueName: "RDS_DDL_Detection_Queue", 113 | visibilityTimeout: cdk.Duration.minutes(6) 114 | }); 115 | this.rdsDdlTriggerQueue = rdsDdlDetectionQueue; 116 | 117 | // Function that gets triggered on the creation of an RDS cluster 118 | const rdsDdlTriggerFn = new lambda.Function(this, "rdsDdlTriggerFn", { 119 | code: lambda.Code.fromAsset(path.join(__dirname, "../lambda/rds-ddl-trigger")), 120 | runtime: lambda.Runtime.PYTHON_3_11, 121 | timeout: cdk.Duration.minutes(2), 122 | handler: "app.lambda_handler", 123 | environment:{ 124 | "RDS_DDL_QUEUE_URL": rdsDdlDetectionQueue.queueUrl, 125 | }, 126 | }); 127 | // give permission to the function to be able to send messages to the queues 128 | rdsDdlDetectionQueue.grantSendMessages(rdsDdlTriggerFn); 129 | 130 | // Trigger an event when there is a RDS CreateDB API call recorded in CloudTrail 131 | const eventBridgeCreateDBRule = new events.Rule(this, 'eventBridgeCreateDBRule', { 132 | eventPattern: { 133 | source: ["aws.rds"], 134 | detail: { 135 | eventSource: ["rds.amazonaws.com"], 136 | eventName: ["CreateDBInstance"] 137 | } 138 | }, 139 | }); 140 | // Invoke the rdsDdlTriggerFn upon a matching event 141 | eventBridgeCreateDBRule.addTarget(new targets.LambdaFunction(rdsDdlTriggerFn)); 142 | 143 | // Create security group for Lambda functions interacting with RDS (not defined in this stack) 144 | const lambdaSecGroupName = "lambda-security-group"; 145 | const lambdaSecurityGroup = new ec2.SecurityGroup(this, lambdaSecGroupName, { 146 | securityGroupName: lambdaSecGroupName, 147 | vpc: vpc, 148 | // for internet access 149 | allowAllOutbound: true 150 | }); 151 | this.lambdaSG = lambdaSecurityGroup; 152 | 153 | // Create security group for test ec2 instance (will be removed later) 154 | const ec2SecGroupName = "ec2-security-group"; 155 | const ec2SecurityGroup = new ec2.SecurityGroup(this, ec2SecGroupName, { 156 | securityGroupName: ec2SecGroupName, 157 | vpc: vpc, 158 | // for internet access 159 | allowAllOutbound: true 160 | }); 161 | this.ec2SecGroup = ec2SecurityGroup; 162 | 163 | // to store the API KEY for OpenAI embeddings 164 | const oaiSecret = 'openAiApiKey'; 165 | 
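    // The secret below holds the OpenAI API key used for embeddings. A sketch of how its value
    // is expected to be set after deployment, using the helper script shipped in this repo
    // (the exact invocation is an assumption):
    //   python3 scripts/api-key-secret-manager-upload/api-key-secret-manager-upload.py \
    //     --secret-name openAiApiKey --aws-profile default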
const openAiApiKey = new secretsmanager.Secret(this, oaiSecret, { 166 | secretName: oaiSecret 167 | }); 168 | this.apiKeySecret = openAiApiKey; 169 | 170 | // Queue for triggering pgvector update 171 | const pgVectorUpdateQueue = new sqs.Queue(this, 'pgVectorUpdateQueue', { 172 | queueName: "PGVector_Update_Queue", 173 | visibilityTimeout: cdk.Duration.minutes(5) 174 | }); 175 | this.pgvectorQueue = pgVectorUpdateQueue; 176 | 177 | // create a Lambda function to send message to SQS for vector store updates 178 | const pgvectorTriggerFn = new lambda.Function(this, "pgvectorTrigger", { 179 | code: lambda.Code.fromAsset(path.join(__dirname, "../lambda/pgvector-trigger")), 180 | runtime: lambda.Runtime.PYTHON_3_11, 181 | handler: "app.lambda_handler", 182 | timeout: cdk.Duration.minutes(2), 183 | environment: { 184 | "PGVECTOR_UPDATE_QUEUE": pgVectorUpdateQueue.queueUrl, 185 | "BUCKET_NAME": processedTextBucket.bucketName 186 | } 187 | }); 188 | // create a new S3 notification that triggers the pgvector trigger lambda function 189 | const processedBucketNotif = new s3notif.LambdaDestination(pgvectorTriggerFn); 190 | // assign notification for the s3 event type 191 | processedTextBucket.addEventNotification(s3.EventType.OBJECT_CREATED, processedBucketNotif); 192 | // give permission to the function to be able to send messages to the queues 193 | pgVectorUpdateQueue.grantSendMessages(pgvectorTriggerFn); 194 | 195 | // Security group for ECS tasks 196 | const ragAppSecGroup = new ec2.SecurityGroup(this, "ragAppSecGroup", { 197 | securityGroupName: "ecs-rag-sec-group", 198 | vpc: vpc, 199 | allowAllOutbound: true, 200 | }); 201 | this.ecsTaskSecGroup = ragAppSecGroup; 202 | 203 | // Security group for ALB 204 | const albSecGroup = new ec2.SecurityGroup(this, "albSecGroup", { 205 | securityGroupName: "alb-sec-group", 206 | vpc: vpc, 207 | allowAllOutbound: true, 208 | }); 209 | 210 | // create load balancer 211 | const appLoadBalancer = new elbv2.ApplicationLoadBalancer(this, 'ragAppLb', { 212 | vpc: vpc, 213 | internetFacing: true, 214 | securityGroup: albSecGroup 215 | }); 216 | 217 | const certName = process.env.IAM_SELF_SIGNED_SERVER_CERT_NAME; 218 | // throw error if IAM_SELF_SIGNED_SERVER_CERT_NAME is undefined 219 | if (certName === undefined || certName === '') { 220 | throw new Error('Please specify the "IAM_SELF_SIGNED_SERVER_CERT_NAME" env var') 221 | }; 222 | console.log(`self signed cert name: ${certName}`); 223 | 224 | const cognitoDomain = process.env.COGNITO_DOMAIN_NAME; 225 | // throw error if COGNITO_DOMAIN_NAME is undefined 226 | if (cognitoDomain === undefined || cognitoDomain === '') { 227 | throw new Error('Please specify the "COGNITO_DOMAIN_NAME" env var') 228 | }; 229 | console.log(`cognito domain name: ${cognitoDomain}`); 230 | 231 | // create Target group for ECS service 232 | const ecsTargetGroup = new elbv2.ApplicationTargetGroup(this, 'default', { 233 | vpc: vpc, 234 | protocol: elbv2.ApplicationProtocol.HTTP, 235 | port: 8501 236 | }); 237 | this.appTargetGroup = ecsTargetGroup; 238 | 239 | // Queue for triggering app client creation 240 | const appClientCreationQueue = new sqs.Queue(this, 'appClientCreateQueue', { 241 | queueName: "COG_APP_CLIENT_CREATE_QUEUE", 242 | visibilityTimeout: cdk.Duration.minutes(5) 243 | }); 244 | 245 | // create a Lambda function to send message to SQS for vector store updates 246 | const appClientCreateTriggerFn = new lambda.Function(this, "appClientCreateTrigger", { 247 | code: lambda.Code.fromAsset(path.join(__dirname, 
"../lambda/app-client-create-trigger")), 248 | runtime: lambda.Runtime.PYTHON_3_11, 249 | handler: "app.lambda_handler", 250 | timeout: cdk.Duration.minutes(2), 251 | environment: { 252 | "TRIGGER_QUEUE": appClientCreationQueue.queueUrl, 253 | } 254 | }); 255 | // give permission to the function to be able to send messages to the queues 256 | appClientCreationQueue.grantSendMessages(appClientCreateTriggerFn); 257 | 258 | // Trigger an event when there is a Cognito CreateUserPoolClient call recorded in CloudTrail 259 | const appClientCreateRule = new events.Rule(this, 'appClientCreateRule', { 260 | eventPattern: { 261 | source: ["aws.cognito-idp"], 262 | detail: { 263 | eventSource: ["cognito-idp.amazonaws.com"], 264 | eventName: ["CreateUserPoolClient"], 265 | sourceIPAddress: ["cloudformation.amazonaws.com"] 266 | } 267 | }, 268 | }); 269 | appClientCreateRule.node.addDependency(appClientCreationQueue); 270 | // Invoke the callBack update fn upon a matching event 271 | appClientCreateRule.addTarget(new targets.LambdaFunction(appClientCreateTriggerFn)); 272 | 273 | // create cognito user pool 274 | const userPool = new cognito.UserPool(this, "UserPool", { 275 | removalPolicy: cdk.RemovalPolicy.DESTROY, 276 | selfSignUpEnabled: true, 277 | signInAliases: { email: true}, 278 | autoVerify: { email: true } 279 | }); 280 | userPool.node.addDependency(appClientCreateRule); 281 | 282 | // create cognito user pool domain 283 | const userPoolDomain = new cognito.UserPoolDomain(this, 'upDomain', { 284 | userPool, 285 | cognitoDomain: { 286 | domainPrefix: cognitoDomain 287 | } 288 | }); 289 | 290 | // create and add Application Integration for the User Pool 291 | const client = userPool.addClient("WebClient", { 292 | userPoolClientName: "MyAppWebClient", 293 | idTokenValidity: cdk.Duration.days(1), 294 | accessTokenValidity: cdk.Duration.days(1), 295 | generateSecret: true, 296 | authFlows: { 297 | adminUserPassword: true, 298 | userPassword: true, 299 | userSrp: true 300 | }, 301 | oAuth: { 302 | flows: {authorizationCodeGrant: true}, 303 | scopes: [cognito.OAuthScope.OPENID], 304 | callbackUrls: [ `https://${appLoadBalancer.loadBalancerDnsName}/oauth2/idpresponse` ] 305 | }, 306 | supportedIdentityProviders: [cognito.UserPoolClientIdentityProvider.COGNITO] 307 | }); 308 | client.node.addDependency(appClientCreateRule); 309 | 310 | // add https listener to the load balancer 311 | const httpsListener = appLoadBalancer.addListener("httpsListener", { 312 | port: 443, 313 | open: true, 314 | certificates: [ 315 | { 316 | certificateArn: `arn:aws:iam::${this.account}:server-certificate/${certName}` 317 | }, 318 | ], 319 | defaultAction: new elbv2_actions.AuthenticateCognitoAction({ 320 | userPool: userPool, 321 | userPoolClient: client, 322 | userPoolDomain: userPoolDomain, 323 | next: elbv2.ListenerAction.forward([ecsTargetGroup]) 324 | }) 325 | }); 326 | /* 327 | 328 | create lambda function because ALB dns name is not lowercase, 329 | and cognito does not function as intended due to that 330 | 331 | Reference - https://github.com/aws/aws-cdk/issues/11171 332 | 333 | */ 334 | const callBackInitFn = new lambda.Function(this, "callBackInit", { 335 | code: lambda.Code.fromAsset(path.join(__dirname, "../lambda/call-back-url-init")), 336 | runtime: lambda.Runtime.PYTHON_3_11, 337 | timeout: cdk.Duration.minutes(2), 338 | handler: "app.lambda_handler", 339 | environment:{ 340 | "USER_POOL_ID": userPool.userPoolId, 341 | "APP_CLIENT_ID": client.userPoolClientId, 342 | "ALB_DNS_NAME": 
appLoadBalancer.loadBalancerDnsName, 343 | "SQS_QUEUE_URL": appClientCreationQueue.queueUrl, 344 | }, 345 | }); 346 | callBackInitFn.role?.addManagedPolicy( 347 | iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonCognitoPowerUser") 348 | ); 349 | // create SQS event source 350 | const appClientCreateSqsEventSource = new SqsEventSource(appClientCreationQueue); 351 | // trigger Lambda function upon message in SQS queue 352 | callBackInitFn.addEventSource(appClientCreateSqsEventSource); 353 | 354 | const callBackUpdateFn = new lambda.Function(this, "callBackUpdate", { 355 | code: lambda.Code.fromAsset(path.join(__dirname, "../lambda/call-back-url-update")), 356 | runtime: lambda.Runtime.PYTHON_3_11, 357 | timeout: cdk.Duration.minutes(2), 358 | handler: "app.lambda_handler", 359 | environment:{ 360 | "USER_POOL_ID": userPool.userPoolId, 361 | "APP_CLIENT_ID": client.userPoolClientId, 362 | "ALB_DNS_NAME": appLoadBalancer.loadBalancerDnsName 363 | }, 364 | }); 365 | callBackUpdateFn.role?.addManagedPolicy( 366 | iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonCognitoPowerUser") 367 | ); 368 | 369 | // Trigger an event when there is a Cognito CreateUserPoolClient call recorded in CloudTrail 370 | const appClientUpdateRule = new events.Rule(this, 'appClientUpdateRule', { 371 | eventPattern: { 372 | source: ["aws.cognito-idp"], 373 | detail: { 374 | eventSource: ["cognito-idp.amazonaws.com"], 375 | eventName: ["UpdateUserPoolClient"], 376 | sourceIPAddress: ["cloudformation.amazonaws.com"] 377 | } 378 | }, 379 | }); 380 | // Invoke the callBack update fn upon a matching event 381 | appClientUpdateRule.addTarget(new targets.LambdaFunction(callBackUpdateFn)); 382 | } 383 | } 384 | -------------------------------------------------------------------------------- /lib/ddl-source-rds-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import path = require("path"); 4 | import * as s3 from 'aws-cdk-lib/aws-s3'; 5 | import * as s3deploy from 'aws-cdk-lib/aws-s3-deployment'; 6 | import * as rds from 'aws-cdk-lib/aws-rds'; 7 | 8 | export interface DDLSourceRDSStackProps extends cdk.StackProps { 9 | rdsInstance: rds.DatabaseInstance; 10 | } 11 | 12 | export class DDLSourceRDSStack extends cdk.Stack { 13 | readonly sourceS3Bucket: s3.Bucket; 14 | 15 | constructor(scope: Construct, id: string, props: DDLSourceRDSStackProps) { 16 | super(scope, id, props); 17 | 18 | // create S3 bucket to host DDL file 19 | const ddlSourceBucket = new s3.Bucket(this, `ddlSourceBucket`, { 20 | bucketName: `ddl-source-${props.rdsInstance.instanceIdentifier}` 21 | }); 22 | this.sourceS3Bucket = ddlSourceBucket; 23 | 24 | // create s3 bucket deployment to upload the DDL file 25 | new s3deploy.BucketDeployment(this, 'deployDDLSourceRDS', { 26 | sources: [s3deploy.Source.asset(path.join(__dirname, "../scripts/rds-ddl-sql"))], 27 | destinationBucket: ddlSourceBucket 28 | }); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /lib/pgvector-update-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from "aws-cdk-lib/aws-ec2"; 4 | import * as lambda from "aws-cdk-lib/aws-lambda"; 5 | import * as s3 from "aws-cdk-lib/aws-s3"; 6 | import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; 7 | import * as iam from 
'aws-cdk-lib/aws-iam'; 8 | import * as sqs from 'aws-cdk-lib/aws-sqs'; 9 | import * as rds from 'aws-cdk-lib/aws-rds'; 10 | import { SqsEventSource } from 'aws-cdk-lib/aws-lambda-event-sources'; 11 | import path = require("path"); 12 | 13 | 14 | export interface PGVectorUpdateStackProps extends cdk.StackProps { 15 | vpc: ec2.Vpc; 16 | processedBucket: s3.Bucket; 17 | collectionName: string; 18 | apiKeySecret: secretsmanager.Secret; 19 | databaseCreds: string; 20 | triggerQueue: sqs.Queue; 21 | dbInstance: rds.DatabaseInstance; 22 | lambdaSG: ec2.SecurityGroup; 23 | } 24 | 25 | export class PGVectorUpdateStack extends cdk.Stack { 26 | 27 | constructor(scope: Construct, id: string, props: PGVectorUpdateStackProps) { 28 | super(scope, id, props); 29 | 30 | // capturing architecture for docker container (arm or x86) 31 | const dockerPlatform = process.env["DOCKER_CONTAINER_PLATFORM_ARCH"] 32 | 33 | // Docker assets for lambda function 34 | const dockerfilePGVectorUpdate = path.join(__dirname, "../lambda/pgvector-update/"); 35 | 36 | // create a Lambda function to update the vector store everytime a new document is added to the processed bucket 37 | const pgvectorUpdateFn = new lambda.Function(this, "pgvectorUpdate", { 38 | code: lambda.Code.fromAssetImage(dockerfilePGVectorUpdate), 39 | handler: lambda.Handler.FROM_IMAGE, 40 | runtime: lambda.Runtime.FROM_IMAGE, 41 | vpc: props.vpc, 42 | securityGroups: [props.lambdaSG], 43 | vpcSubnets: props.vpc.selectSubnets({subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS}), 44 | timeout: cdk.Duration.minutes(3), 45 | memorySize: 512, 46 | architecture: dockerPlatform == "arm" ? lambda.Architecture.ARM_64 : lambda.Architecture.X86_64, 47 | environment: { 48 | "API_KEY_SECRET_NAME": props.apiKeySecret.secretName, 49 | "DB_CREDS": props.databaseCreds, 50 | "COLLECTION_NAME": props.collectionName, 51 | "QUEUE_URL": props.triggerQueue.queueUrl, 52 | // for under the hood stuff 53 | "NLTK_DATA": "/tmp" 54 | } 55 | }); 56 | // grant lambda function permissions to read processed bucket 57 | props.processedBucket.grantRead(pgvectorUpdateFn); 58 | // grant lambda function permissions to ready the api key secret 59 | props.apiKeySecret.grantRead(pgvectorUpdateFn); 60 | // grant Connection permission to the function 61 | props.dbInstance.grantConnect(pgvectorUpdateFn); 62 | // create SQS event source 63 | const eventSource = new SqsEventSource(props.triggerQueue); 64 | // trigger Lambda function upon message in SQS queue 65 | pgvectorUpdateFn.addEventSource(eventSource); 66 | 67 | // for giving permissions to lambda to be able to extract database creds from Secrets Manager 68 | const smPolicyStatementDBCreds = new iam.PolicyStatement({ 69 | effect: iam.Effect.ALLOW, 70 | actions: [ 71 | "secretsmanager:GetResourcePolicy", 72 | "secretsmanager:GetSecretValue", 73 | "secretsmanager:DescribeSecret", 74 | "secretsmanager:ListSecretVersionIds" 75 | ], 76 | resources: [props.databaseCreds], 77 | }); 78 | const smPolicyDBCreds = new iam.Policy(this, "dbCredsSecretsManagerPolicy", { 79 | statements : [smPolicyStatementDBCreds] 80 | }); 81 | pgvectorUpdateFn.role?.attachInlinePolicy(smPolicyDBCreds); 82 | 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /lib/rag-app-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from "aws-cdk-lib/aws-ec2"; 4 | import * as ecs from 
"aws-cdk-lib/aws-ecs"; 5 | import * as iam from "aws-cdk-lib/aws-iam"; 6 | import * as logs from "aws-cdk-lib/aws-logs"; 7 | import * as rds from 'aws-cdk-lib/aws-rds'; 8 | import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; 9 | import * as ecr_assets from "aws-cdk-lib/aws-ecr-assets"; 10 | import * as elbv2 from "aws-cdk-lib/aws-elasticloadbalancingv2"; 11 | import path = require("path"); 12 | 13 | export interface RagAppStackProps extends cdk.StackProps { 14 | vpc: ec2.Vpc; 15 | databaseCreds: string; 16 | collectionName: string; 17 | apiKeySecret: secretsmanager.Secret; 18 | dbInstance: rds.DatabaseInstance; 19 | taskSecGroup: ec2.SecurityGroup; 20 | elbTargetGroup: elbv2.ApplicationTargetGroup; 21 | } 22 | 23 | export class RagAppStack extends cdk.Stack { 24 | 25 | constructor(scope: Construct, id: string, props: RagAppStackProps) { 26 | super(scope, id, props); 27 | 28 | // This is the ECS cluster that we use for running tasks at. 29 | const cluster = new ecs.Cluster(this, "ecsClusterRAG", { 30 | vpc: props.vpc, 31 | containerInsights: true, 32 | executeCommandConfiguration: { 33 | logging: ecs.ExecuteCommandLogging.DEFAULT, 34 | }, 35 | }); 36 | 37 | // This IAM Role is used by tasks 38 | const ragTaskRole = new iam.Role(this, "RagTaskRole", { 39 | assumedBy: new iam.ServicePrincipal("ecs-tasks.amazonaws.com"), 40 | inlinePolicies: { 41 | dbCredsPolicy: new iam.PolicyDocument({ 42 | statements: [ 43 | new iam.PolicyStatement({ 44 | effect: iam.Effect.ALLOW, 45 | resources: [props.databaseCreds], 46 | actions: [ 47 | "secretsmanager:GetResourcePolicy", 48 | "secretsmanager:GetSecretValue", 49 | "secretsmanager:DescribeSecret", 50 | "secretsmanager:ListSecretVersionIds" 51 | ], 52 | }), 53 | ], 54 | }), 55 | bedrockPolicy: new iam.PolicyDocument({ 56 | statements: [ 57 | new iam.PolicyStatement({ 58 | effect: iam.Effect.ALLOW, 59 | resources: ["*"], 60 | actions: [ 61 | "bedrock:InvokeModel", 62 | ], 63 | }), 64 | ], 65 | }), 66 | }, 67 | }); 68 | // grant permissions to ready the api key secret 69 | props.apiKeySecret.grantRead(ragTaskRole); 70 | // grant Connection permission to the role 71 | props.dbInstance.grantConnect(ragTaskRole); 72 | 73 | // This IAM role is used to execute the tasks. It is used by task definition. 74 | const taskExecRole = new iam.Role(this, "TaskExecRole", { 75 | assumedBy: new iam.ServicePrincipal("ecs-tasks.amazonaws.com"), 76 | managedPolicies: [ 77 | iam.ManagedPolicy.fromAwsManagedPolicyName( 78 | "service-role/AmazonECSTaskExecutionRolePolicy" 79 | ), 80 | ], 81 | }); 82 | 83 | // We create Log Group in CloudWatch to follow task logs 84 | const taskLogGroup = new logs.LogGroup(this, "TaskLogGroup", { 85 | logGroupName: "/ragapp/", 86 | removalPolicy: cdk.RemovalPolicy.DESTROY, 87 | retention: logs.RetentionDays.THREE_DAYS, 88 | }); 89 | 90 | // We create a log driver for ecs 91 | const ragTaskLogDriver = new ecs.AwsLogDriver({ 92 | streamPrefix: "rag-app", 93 | logGroup: taskLogGroup, 94 | }); 95 | 96 | const dockerPlatform = process.env["DOCKER_CONTAINER_PLATFORM_ARCH"] 97 | 98 | // We create the task definition. Task definition is used to create tasks by ECS. 99 | const ragTaskDef = new ecs.FargateTaskDefinition(this, "RagTaskDef", { 100 | family: "rag-app", 101 | memoryLimitMiB: 512, 102 | cpu: 256, 103 | taskRole: ragTaskRole, 104 | executionRole: taskExecRole, 105 | runtimePlatform: { 106 | operatingSystemFamily: ecs.OperatingSystemFamily.LINUX, 107 | cpuArchitecture: dockerPlatform == "arm" ? 
ecs.CpuArchitecture.ARM64 : ecs.CpuArchitecture.X86_64 108 | } 109 | }); 110 | 111 | // We create a container image to be run by the tasks. 112 | const ragContainerImage = new ecs.AssetImage( path.join(__dirname, '../rag-app'), { 113 | platform: dockerPlatform == "arm" ? ecr_assets.Platform.LINUX_ARM64 : ecr_assets.Platform.LINUX_AMD64 114 | }); 115 | const containerName = "ragAppPostgresVec"; 116 | // We add this container image to our task definition that we created earlier. 117 | const ragContainer = ragTaskDef.addContainer("rag-container", { 118 | containerName: containerName, 119 | image: ragContainerImage, 120 | logging: ragTaskLogDriver, 121 | environment: { 122 | "AWS_REGION": `${this.region}`, 123 | "DB_CREDS": props.databaseCreds, 124 | "COLLECTION_NAME": props.collectionName, 125 | "API_KEY_SECRET_NAME": props.apiKeySecret.secretName, 126 | }, 127 | portMappings: [ 128 | { 129 | containerPort: 8501, 130 | hostPort: 8501, 131 | protocol: ecs.Protocol.TCP 132 | }, 133 | ] 134 | }); 135 | 136 | // define ECS fargate service to run the RAG app 137 | const ragAppService = new ecs.FargateService(this, "rag-app-service", { 138 | cluster, 139 | taskDefinition: ragTaskDef, 140 | desiredCount: 1,//vpc.availabilityZones.length, 141 | securityGroups: [props.taskSecGroup], 142 | minHealthyPercent: 0, 143 | }); 144 | // add fargate service as a target to the target group 145 | props.elbTargetGroup.addTarget(ragAppService); 146 | 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /lib/rds-ddl-automation-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as sqs from 'aws-cdk-lib/aws-sqs'; 4 | import * as s3 from 'aws-cdk-lib/aws-s3'; 5 | import * as lambda from "aws-cdk-lib/aws-lambda"; 6 | import path = require("path"); 7 | import * as rds from 'aws-cdk-lib/aws-rds'; 8 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 9 | import * as iam from 'aws-cdk-lib/aws-iam'; 10 | import { SqsEventSource } from 'aws-cdk-lib/aws-lambda-event-sources'; 11 | import * as events from 'aws-cdk-lib/aws-events'; 12 | import * as targets from 'aws-cdk-lib/aws-events-targets'; 13 | 14 | 15 | export interface RdsDdlAutomationStackProps extends cdk.StackProps { 16 | ddlTriggerQueue: sqs.Queue; 17 | rdsInstance: rds.DatabaseInstance; 18 | dbName: string; 19 | ddlSourceS3Bucket: s3.Bucket; 20 | vpc: ec2.Vpc; 21 | lambdaSG: ec2.SecurityGroup; 22 | ddlSourceStackName: string; 23 | } 24 | 25 | export class RdsDdlAutomationStack extends cdk.Stack { 26 | constructor(scope: Construct, id: string, props: RdsDdlAutomationStackProps) { 27 | super(scope, id, props); 28 | 29 | // setting some constants 30 | const ddlTriggerQueue = props.ddlTriggerQueue; 31 | const rdsInstance = props.rdsInstance; 32 | const dbName = props.dbName; 33 | const sourceS3Bucket = props.ddlSourceS3Bucket; 34 | const ddlSourceStackName = props.ddlSourceStackName; 35 | 36 | // private subnets 37 | const privSubnets = props.vpc.selectSubnets({subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS}); 38 | 39 | // capturing architecture for docker container (arm or x86) 40 | const dockerPlatform = process.env["DOCKER_CONTAINER_PLATFORM_ARCH"] 41 | 42 | // Docker assets for init lambda function 43 | const deployDockerfile = path.join(__dirname, "../lambda/rds-ddl-init/"); 44 | 45 | // lambda function to deploy DDL on RDS (when it is first created) 46 | const ddlInitDeployFn = new 
lambda.Function(this, "ddlDeployFn", { 47 | code: lambda.Code.fromAssetImage(deployDockerfile), 48 | handler: lambda.Handler.FROM_IMAGE, 49 | runtime: lambda.Runtime.FROM_IMAGE, 50 | timeout: cdk.Duration.minutes(3), 51 | architecture: dockerPlatform == "arm" ? lambda.Architecture.ARM_64 : lambda.Architecture.X86_64, 52 | vpc: props.vpc, 53 | securityGroups: [props.lambdaSG], 54 | vpcSubnets: privSubnets, 55 | environment:{ 56 | "DB_NAME": dbName, 57 | "SQS_QUEUE_URL": ddlTriggerQueue.queueUrl, 58 | "DDL_SOURCE_BUCKET": sourceS3Bucket.bucketName 59 | }, 60 | }); 61 | // grant Connect permission to the ddl init deploy function 62 | rdsInstance.grantConnect(ddlInitDeployFn); 63 | // create SQS event source 64 | const ddlEventSource = new SqsEventSource(ddlTriggerQueue); 65 | // trigger Lambda function upon message in SQS queue 66 | ddlInitDeployFn.addEventSource(ddlEventSource); 67 | // give S3 permissions 68 | sourceS3Bucket.grantRead(ddlInitDeployFn); 69 | // to be able to list secrets 70 | ddlInitDeployFn.role?.addManagedPolicy( 71 | iam.ManagedPolicy.fromAwsManagedPolicyName("SecretsManagerReadWrite") 72 | ); 73 | 74 | // Docker assets for change lambda function 75 | const changeDockerfile = path.join(__dirname, "../lambda/rds-ddl-change/"); 76 | 77 | // lambda function to deploy DDL on RDS (when there is a change to the DDL SQL File) 78 | const ddlChangeFn = new lambda.Function(this, "ddlChangeFn", { 79 | code: lambda.Code.fromAssetImage(changeDockerfile), 80 | handler: lambda.Handler.FROM_IMAGE, 81 | runtime: lambda.Runtime.FROM_IMAGE, 82 | timeout: cdk.Duration.minutes(10), 83 | architecture: dockerPlatform == "arm" ? lambda.Architecture.ARM_64 : lambda.Architecture.X86_64, 84 | vpc: props.vpc, 85 | securityGroups: [props.lambdaSG], 86 | vpcSubnets: privSubnets, 87 | environment:{ 88 | "DB_NAME": dbName, 89 | "SQS_QUEUE_URL": ddlTriggerQueue.queueUrl, 90 | "DDL_SOURCE_BUCKET": sourceS3Bucket.bucketName 91 | }, 92 | }); 93 | // give S3 permissions 94 | sourceS3Bucket.grantRead(ddlChangeFn); 95 | // to be able to list secrets 96 | ddlChangeFn.role?.addManagedPolicy( 97 | iam.ManagedPolicy.fromAwsManagedPolicyName("SecretsManagerReadWrite") 98 | ); 99 | // grant Connect permission to the ddl change function 100 | rdsInstance.grantConnect(ddlChangeFn); 101 | // to be able to describe the RDS instance 102 | ddlChangeFn.role?.addManagedPolicy( 103 | iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonRDSReadOnlyAccess") 104 | ); 105 | 106 | const cfnChangesetRule = new events.Rule(this, 'cfnChangesetRule', { 107 | eventPattern: { 108 | "source": ["aws.cloudformation"], 109 | "detail": { 110 | "eventSource": ["cloudformation.amazonaws.com"], 111 | "eventName": ["ExecuteChangeSet"], 112 | "requestParameters": { 113 | "stackName": [ddlSourceStackName] 114 | } 115 | } 116 | }, 117 | }); 118 | // Invoke the ddlChangeFn upon a matching event 119 | cfnChangesetRule.addTarget(new targets.LambdaFunction(ddlChangeFn)); 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /lib/rds-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from "aws-cdk-lib/aws-ec2"; 4 | import * as rds from "aws-cdk-lib/aws-rds"; 5 | 6 | 7 | export interface RDSStackProps extends cdk.StackProps { 8 | vpc: ec2.Vpc; 9 | sgLambda: ec2.SecurityGroup; 10 | sgEc2: ec2.SecurityGroup; 11 | ecsSecGroup: ec2.SecurityGroup; 12 | } 13 | 14 
| export class RDSStack extends cdk.Stack { 15 | readonly rdsDBName: string; 16 | readonly dbInstance: rds.DatabaseInstance; 17 | 18 | constructor(scope: Construct, id: string, props: RDSStackProps) { 19 | super(scope, id, props); 20 | 21 | // passed in as property 22 | const vpc = props.vpc; 23 | 24 | // create RDS bits (security group and serverless instance) 25 | const dbName = "postgres"; 26 | const rdsSecGroupName = "rds-security-group"; 27 | 28 | const rdsSecurityGroup = new ec2.SecurityGroup(this, rdsSecGroupName, { 29 | securityGroupName: rdsSecGroupName, 30 | vpc: vpc, 31 | allowAllOutbound: false, 32 | }); 33 | // allow connection from lambda 34 | rdsSecurityGroup.connections.allowFrom(props.sgLambda, ec2.Port.tcp(5432)); 35 | // allow connection from test ec2 instance (will be deleted) 36 | rdsSecurityGroup.connections.allowFrom(props.sgEc2, ec2.Port.tcp(5432)); 37 | // allow connection from ecs Task Security Group 38 | rdsSecurityGroup.connections.allowFrom(props.ecsSecGroup, ec2.Port.tcp(5432)); 39 | 40 | const rdsInstance = new rds.DatabaseInstance(this, 'rdsInstance', { 41 | engine: rds.DatabaseInstanceEngine.POSTGRES, 42 | credentials: rds.Credentials.fromGeneratedSecret('postgres'), 43 | vpc: vpc, 44 | securityGroups: [rdsSecurityGroup], 45 | }); 46 | this.dbInstance = rdsInstance; 47 | 48 | this.rdsDBName = dbName; 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /lib/test-compute-stack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 4 | import * as iam from 'aws-cdk-lib/aws-iam'; 5 | 6 | 7 | export interface TestComputeStackProps extends cdk.StackProps { 8 | vpc: ec2.Vpc; 9 | ec2SG: ec2.SecurityGroup; 10 | } 11 | 12 | 13 | export class TestComputeStack extends cdk.Stack { 14 | readonly jumpHostSG: ec2.SecurityGroup; 15 | 16 | constructor(scope: Construct, id: string, props: TestComputeStackProps) { 17 | super(scope, id, props); 18 | 19 | 20 | const userData = ec2.UserData.forLinux(); 21 | userData.addCommands( 22 | 'apt-get update -y', 23 | 'apt-get install -y git awscli ec2-instance-connect', 24 | 'apt install -y fish' 25 | ); 26 | 27 | const machineImage = ec2.MachineImage.fromSsmParameter( 28 | '/aws/service/canonical/ubuntu/server/focal/stable/current/amd64/hvm/ebs-gp2/ami-id', 29 | ); 30 | 31 | const jumpHostRole = new iam.Role(this, 'jumpHostRole', { 32 | assumedBy: new iam.CompositePrincipal( 33 | new iam.ServicePrincipal('ec2.amazonaws.com'), 34 | ), 35 | managedPolicies: [ 36 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSSMManagedInstanceCore'), 37 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonS3FullAccess'), 38 | iam.ManagedPolicy.fromAwsManagedPolicyName('AdministratorAccess'), 39 | ] 40 | }); 41 | 42 | const instanceProf = new iam.CfnInstanceProfile(this, 'jumpHostInstanceProf', { 43 | roles: [jumpHostRole.roleName] 44 | }); 45 | 46 | const ec2JumpHost = new ec2.Instance(this, 'ec2JumpHost', { 47 | vpc: props.vpc, 48 | instanceType: ec2.InstanceType.of(ec2.InstanceClass.T2, ec2.InstanceSize.MICRO), 49 | machineImage: machineImage, 50 | securityGroup: props.ec2SG, 51 | userData: userData, 52 | role: jumpHostRole, 53 | requireImdsv2: true, 54 | // for public access testing 55 | vpcSubnets: {subnetType: ec2.SubnetType.PUBLIC}, 56 | // for public access testing 57 | associatePublicIpAddress: true, 58 | blockDevices: [ 59 | { 60 | 
deviceName: '/dev/sda1', 61 | mappingEnabled: true, 62 | volume: ec2.BlockDeviceVolume.ebs(128, { 63 | deleteOnTermination: true, 64 | encrypted: true, 65 | volumeType: ec2.EbsDeviceVolumeType.GP2 66 | }) 67 | } 68 | ] 69 | }); 70 | 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "rag-with-amazon-bedrock-and-rds-pgvector", 3 | "version": "0.1.0", 4 | "bin": { 5 | "rag-with-amazon-bedrock-and-rds-pgvector": "bin/rag-with-amazon-bedrock-and-rds-pgvector.js" 6 | }, 7 | "scripts": { 8 | "build": "tsc", 9 | "watch": "tsc -w", 10 | "test": "jest", 11 | "cdk": "cdk" 12 | }, 13 | "devDependencies": { 14 | "@types/jest": "^29.5.4", 15 | "@types/node": "20.5.9", 16 | "aws-cdk": "2.96.2", 17 | "jest": "^29.6.4", 18 | "ts-jest": "^29.1.1", 19 | "ts-node": "^10.9.1", 20 | "typescript": "~5.2.2" 21 | }, 22 | "dependencies": { 23 | "aws-cdk-lib": "2.96.2", 24 | "constructs": "^10.0.0", 25 | "source-map-support": "^0.5.21" 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /rag-app/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.2 2 | 3 | WORKDIR python-docker 4 | 5 | COPY requirements.txt requirements.txt 6 | 7 | RUN pip3 install -r requirements.txt && mkdir /root/.streamlit 8 | 9 | COPY . . 10 | 11 | CMD [ "./run_app.sh" ] 12 | -------------------------------------------------------------------------------- /rag-app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-pgvector/93c5fc9ee9378d8d99a114299c49507a89d9cae2/rag-app/__init__.py -------------------------------------------------------------------------------- /rag-app/app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import uuid 5 | 6 | import streamlit as st 7 | 8 | import pgvec_chat_bedrock as bedrock_claude 9 | 10 | 11 | # TODO: clean up the way this app is written 12 | USER_ICON = "images/user-icon.png" 13 | AI_ICON = "images/ai-icon.png" 14 | MAX_HISTORY_LENGTH = 5 15 | 16 | COLLECTION_ENV_VAR = "COLLECTION_NAME" 17 | DB_SECRET_ENV_VAR = "DB_CREDS" 18 | API_KEY_SECRET_ENV_VAR = "API_KEY_SECRET_NAME" 19 | 20 | DEFAULT_LOG_LEVEL = logging.INFO 21 | LOGGER = logging.getLogger(__name__) 22 | LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \ 23 | "[%(name)s]:[%(threadName)s] " \ 24 | "%(message)s" 25 | 26 | 27 | class MissingEnvironmentVariable(Exception): 28 | """Raised if a required environment variable is missing""" 29 | 30 | 31 | # logging configuration 32 | log_level = DEFAULT_LOG_LEVEL 33 | if os.environ.get("VERBOSE", "").lower() == "true": 34 | log_level = logging.DEBUG 35 | logging.basicConfig(level=log_level, format=LOGGING_FORMAT) 36 | 37 | os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"] 38 | 39 | # get collection name 40 | collection = os.environ.get(COLLECTION_ENV_VAR) 41 | if not collection: 42 | raise MissingEnvironmentVariable(f"{COLLECTION_ENV_VAR} environment variable is required") 43 | 44 | #function to read a properties file and create environment variables 45 | def read_properties_file(filename): 46 | import os 47 | import re 48 | with open(filename, 'r') as f: 49 | for line in f: 50 | m = re.match(r'^\s*(\w+)\s*=\s*(.*)\s*$', line) 
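            # e.g. a properties line like `SOME_KEY = some_value` (hypothetical example)
            # becomes os.environ["SOME_KEY"] = "some_value" in the branch below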
51 | if m: 52 | os.environ[m.group(1)] = m.group(2) 53 | 54 | 55 | # Check if the user ID is already stored in the session state 56 | if 'user_id' in st.session_state: 57 | user_id = st.session_state['user_id'] 58 | 59 | # If the user ID is not yet stored in the session state, generate a random UUID 60 | else: 61 | user_id = str(uuid.uuid4()) 62 | st.session_state['user_id'] = user_id 63 | 64 | 65 | if 'llm_chain' not in st.session_state: 66 | if (len(sys.argv) > 1): 67 | if (sys.argv[1] == 'bedrock_claude'): 68 | st.session_state['llm_app'] = bedrock_claude 69 | st.session_state['llm_chain'] = bedrock_claude.build_chain( 70 | { 71 | "host": st.secrets["host"], 72 | "port": int(st.secrets["port"]), 73 | "username": st.secrets["username"], 74 | "password": st.secrets["password"] 75 | }, 76 | collection 77 | ) 78 | else: 79 | raise Exception("Unsupported LLM: ", sys.argv[1]) 80 | else: 81 | raise Exception("Usage: streamlit run app.py bedrock_claude") 82 | 83 | if 'chat_history' not in st.session_state: 84 | st.session_state['chat_history'] = [] 85 | 86 | if "chats" not in st.session_state: 87 | st.session_state.chats = [ 88 | { 89 | 'id': 0, 90 | 'question': '', 91 | 'answer': '' 92 | } 93 | ] 94 | 95 | if "questions" not in st.session_state: 96 | st.session_state.questions = [] 97 | 98 | if "answers" not in st.session_state: 99 | st.session_state.answers = [] 100 | 101 | if "input" not in st.session_state: 102 | st.session_state.input = "" 103 | 104 | 105 | st.markdown(""" 106 | 121 | """, unsafe_allow_html=True) 122 | 123 | 124 | def write_logo(): 125 | col1, col2, col3 = st.columns([5, 1, 5]) 126 | with col2: 127 | st.image(AI_ICON, use_column_width='always') 128 | 129 | 130 | def write_top_bar(): 131 | col1, col2, col3 = st.columns([1,10,2]) 132 | with col1: 133 | st.image(AI_ICON, use_column_width='always') 134 | with col2: 135 | selected_provider = sys.argv[1] 136 | provider = selected_provider.capitalize() 137 | header = f"An AI App powered by PGVector (on Amazon RDS) and {provider}!" 138 | st.write(f"

<h3 class='main-header'>{header}</h3>

", unsafe_allow_html=True) 139 | with col3: 140 | clear = st.button("Clear Chat") 141 | return clear 142 | 143 | 144 | clear = write_top_bar() 145 | 146 | if clear: 147 | st.session_state.questions = [] 148 | st.session_state.answers = [] 149 | st.session_state.input = "" 150 | st.session_state["chat_history"] = [] 151 | 152 | 153 | def handle_input(): 154 | input = st.session_state.input 155 | question_with_id = { 156 | 'question': input, 157 | 'id': len(st.session_state.questions) 158 | } 159 | st.session_state.questions.append(question_with_id) 160 | 161 | chat_history = st.session_state["chat_history"] 162 | if len(chat_history) == MAX_HISTORY_LENGTH: 163 | chat_history = chat_history[:-1] 164 | 165 | llm_chain = st.session_state['llm_chain'] 166 | chain = st.session_state['llm_app'] 167 | result = chain.run_chain(llm_chain, input, chat_history) 168 | answer = result['answer'] 169 | chat_history.append((input, answer)) 170 | 171 | document_list = [] 172 | if 'source_documents' in result: 173 | for d in result['source_documents']: 174 | if not (d.metadata['source'] in document_list): 175 | document_list.append((d.metadata['source'])) 176 | 177 | st.session_state.answers.append({ 178 | 'answer': result, 179 | 'sources': document_list, 180 | 'id': len(st.session_state.questions) 181 | }) 182 | st.session_state.input = "" 183 | 184 | 185 | def write_user_message(md): 186 | col1, col2 = st.columns([1,12]) 187 | 188 | with col1: 189 | st.image(USER_ICON, use_column_width='always') 190 | with col2: 191 | st.warning(md['question']) 192 | 193 | 194 | def render_result(result): 195 | answer, sources = st.tabs(['Answer', 'Sources']) 196 | with answer: 197 | render_answer(result['answer']) 198 | with sources: 199 | if 'source_documents' in result: 200 | render_sources(result['source_documents']) 201 | else: 202 | render_sources([]) 203 | 204 | 205 | def render_answer(answer): 206 | col1, col2 = st.columns([1,12]) 207 | with col1: 208 | st.image(AI_ICON, use_column_width='always') 209 | with col2: 210 | st.info(answer['answer']) 211 | 212 | 213 | def render_sources(sources): 214 | col1, col2 = st.columns([1,12]) 215 | with col2: 216 | with st.expander("Sources"): 217 | for s in sources: 218 | st.write(s) 219 | 220 | 221 | #Each answer will have context of the question asked in order to associate the provided feedback with the respective question 222 | def write_chat_message(md, q): 223 | chat = st.container() 224 | with chat: 225 | render_answer(md['answer']) 226 | render_sources(md['sources']) 227 | 228 | 229 | with st.container(): 230 | for (q, a) in zip(st.session_state.questions, st.session_state.answers): 231 | write_user_message(q) 232 | write_chat_message(a, q) 233 | 234 | 235 | st.markdown('---') 236 | input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input) -------------------------------------------------------------------------------- /rag-app/app_init.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import toml 5 | 6 | import helper_functions as hfn 7 | 8 | 9 | class MissingEnvironmentVariable(Exception): 10 | """Raised if a required environment variable is missing""" 11 | 12 | DB_SECRET_ENV_VAR = "DB_CREDS" 13 | API_KEY_SECRET_ENV_VAR = "API_KEY_SECRET_NAME" 14 | 15 | DEFAULT_LOG_LEVEL = logging.INFO 16 | LOGGER = logging.getLogger(__name__) 17 | LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \ 18 | "[%(name)s]:[%(threadName)s] " \ 19 | "%(message)s" 20 | 21 
| 22 | if __name__ == "__main__": 23 | 24 | streamlit_secrets = {} 25 | 26 | # logging configuration 27 | log_level = DEFAULT_LOG_LEVEL 28 | if os.environ.get("VERBOSE", "").lower() == "true": 29 | log_level = logging.DEBUG 30 | logging.basicConfig(level=log_level, format=LOGGING_FORMAT) 31 | 32 | # vector db secret fetch 33 | secret_name = os.environ.get(DB_SECRET_ENV_VAR) 34 | if not secret_name: 35 | raise MissingEnvironmentVariable(f"{DB_SECRET_ENV_VAR} environment variable is required") 36 | streamlit_secrets.update(hfn.get_secret_from_name(secret_name)) 37 | 38 | # open ai api key fetch 39 | openai_secret = os.environ.get(API_KEY_SECRET_ENV_VAR) 40 | if not openai_secret: 41 | raise MissingEnvironmentVariable(f"{API_KEY_SECRET_ENV_VAR} environment variable is required") 42 | streamlit_secrets["OPENAI_API_KEY"] = hfn.get_secret_from_name(openai_secret, kv=False) 43 | 44 | LOGGER.info("Writing streamlit secrets") 45 | with open("/root/.streamlit/secrets.toml", "w") as file: 46 | toml.dump(streamlit_secrets, file) 47 | -------------------------------------------------------------------------------- /rag-app/helper_functions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import boto3 6 | from botocore.exceptions import ClientError 7 | 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | 12 | def _check_missing_field(validation_dict, extraction_key): 13 | """Check if a field exists in a dictionary 14 | 15 | :param validation_dict: Dictionary 16 | :param extraction_key: String 17 | 18 | :raises: KeyError 19 | """ 20 | extracted_value = validation_dict.get(extraction_key) 21 | 22 | if not extracted_value: 23 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 24 | raise KeyError 25 | 26 | 27 | def get_secret_from_name(secret_name, kv=True): 28 | """Return secret from secret name 29 | 30 | :param secret_name: String 31 | :param kv: Boolean (weather it is json or not) 32 | 33 | :raises: botocore.exceptions.ClientError 34 | 35 | :rtype: Dictionary 36 | """ 37 | session = boto3.session.Session() 38 | 39 | # Initializing Secret Manager's client 40 | client = session.client( 41 | service_name='secretsmanager', 42 | region_name=os.environ.get("AWS_REGION", session.region_name) 43 | ) 44 | LOGGER.info(f"Attempting to get secret value for: {secret_name}") 45 | try: 46 | get_secret_value_response = client.get_secret_value( 47 | SecretId=secret_name) 48 | except ClientError as e: 49 | # For a list of exceptions thrown, see 50 | # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html 51 | LOGGER.error("Unable to fetch details from Secrets Manager") 52 | raise e 53 | 54 | _check_missing_field( 55 | get_secret_value_response, "SecretString") 56 | 57 | if kv: 58 | return json.loads( 59 | get_secret_value_response["SecretString"]) 60 | else: 61 | return get_secret_value_response["SecretString"] 62 | -------------------------------------------------------------------------------- /rag-app/images/ai-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-pgvector/93c5fc9ee9378d8d99a114299c49507a89d9cae2/rag-app/images/ai-icon.png -------------------------------------------------------------------------------- /rag-app/images/user-icon.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-pgvector/93c5fc9ee9378d8d99a114299c49507a89d9cae2/rag-app/images/user-icon.png -------------------------------------------------------------------------------- /rag-app/pgvec_chat_bedrock.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | from langchain.chains import ConversationalRetrievalChain 6 | from langchain.embeddings.openai import OpenAIEmbeddings 7 | from langchain.llms.bedrock import Bedrock 8 | from langchain.prompts import PromptTemplate 9 | from langchain.vectorstores.pgvector import PGVector 10 | 11 | import helper_functions as hfn 12 | 13 | 14 | class MissingEnvironmentVariable(Exception): 15 | """Raised if a required environment variable is missing""" 16 | 17 | 18 | class bcolors: 19 | HEADER = '\033[95m' 20 | OKBLUE = '\033[94m' 21 | OKCYAN = '\033[96m' 22 | OKGREEN = '\033[92m' 23 | WARNING = '\033[93m' 24 | FAIL = '\033[91m' 25 | ENDC = '\033[0m' 26 | BOLD = '\033[1m' 27 | UNDERLINE = '\033[4m' 28 | 29 | 30 | MAX_HISTORY_LENGTH = 5 31 | 32 | COLLECTION_ENV_VAR = "COLLECTION_NAME" 33 | DB_SECRET_ENV_VAR = "DB_CREDS" 34 | API_KEY_SECRET_ENV_VAR = "API_KEY_SECRET_NAME" 35 | 36 | DEFAULT_LOG_LEVEL = logging.INFO 37 | LOGGER = logging.getLogger(__name__) 38 | LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \ 39 | "[%(name)s]:[%(threadName)s] " \ 40 | "%(message)s" 41 | 42 | 43 | def build_chain(db_creds, collection): 44 | """Build conversational retrieval chain 45 | 46 | :param db_creds: Dictionary 47 | :param collection: String 48 | 49 | :rtype: ConversationalRetrievalChain 50 | """ 51 | region = os.environ["AWS_REGION"] 52 | 53 | llm = Bedrock( 54 | # credentials_profile_name=credentials_profile_name, 55 | region_name = region, 56 | model_kwargs={"max_tokens_to_sample":300,"temperature":1,"top_k":250,"top_p":0.999,"anthropic_version":"bedrock-2023-05-31"}, 57 | model_id=os.environ.get("FOUNDATION_MODEL_ID", "anthropic.claude-instant-v1") 58 | ) 59 | conn_str = PGVector.connection_string_from_db_params( 60 | driver=os.environ.get("PGVECTOR_DRIVER", "psycopg2"), 61 | host=db_creds["host"], 62 | port=db_creds["port"], 63 | database=os.environ.get("PGVECTOR_DATABASE", "postgres"), 64 | user=db_creds["username"], 65 | password=db_creds["password"], 66 | ) 67 | embeddings = OpenAIEmbeddings() 68 | store = PGVector( 69 | collection_name=collection, 70 | connection_string=conn_str, 71 | embedding_function=embeddings, 72 | ) 73 | retriever = store.as_retriever() 74 | 75 | prompt_template = """Human: This is a friendly conversation between a human and an AI. 76 | The AI is talkative and provides specific details from its context but limits it to 240 tokens. 77 | If the AI does not know the answer to a question, it truthfully says it 78 | does not know. 79 | 80 | Assistant: OK, got it, I'll be a talkative truthful AI assistant. 81 | 82 | Human: Here are a few documents in tags: 83 | 84 | {context} 85 | 86 | Based on the above documents, provide a detailed answer for, {question} 87 | Answer "don't know" if not present in the document. 88 | 89 | Assistant: 90 | """ 91 | PROMPT = PromptTemplate( 92 | template=prompt_template, input_variables=["context", "question"] 93 | ) 94 | 95 | condense_qa_template = """{chat_history} 96 | Human: 97 | Given the previous conversation and a follow up question below, rephrase the follow up question 98 | to be a standalone question. 
99 | 100 | Follow Up Question: {question} 101 | Standalone Question: 102 | 103 | Assistant:""" 104 | standalone_question_prompt = PromptTemplate.from_template(condense_qa_template) 105 | 106 | return ConversationalRetrievalChain.from_llm( 107 | llm=llm, 108 | retriever=retriever, 109 | condense_question_prompt=standalone_question_prompt, 110 | return_source_documents=True, 111 | combine_docs_chain_kwargs={"prompt":PROMPT}, 112 | verbose=True) 113 | 114 | 115 | def run_chain(chain, prompt: str, history=[]): 116 | return chain({"question": prompt, "chat_history": history}) 117 | 118 | 119 | if __name__ == "__main__": 120 | 121 | # logging configuration 122 | log_level = DEFAULT_LOG_LEVEL 123 | if os.environ.get("VERBOSE", "").lower() == "true": 124 | log_level = logging.DEBUG 125 | logging.basicConfig(level=log_level, format=LOGGING_FORMAT) 126 | 127 | # vector db secret fetch 128 | secret_name = os.environ.get(DB_SECRET_ENV_VAR) 129 | if not secret_name: 130 | raise MissingEnvironmentVariable(f"{DB_SECRET_ENV_VAR} environment variable is required") 131 | 132 | # open ai api key fetch 133 | openai_secret = os.environ.get(API_KEY_SECRET_ENV_VAR) 134 | if not openai_secret: 135 | raise MissingEnvironmentVariable(f"{API_KEY_SECRET_ENV_VAR} environment variable is required") 136 | os.environ["OPENAI_API_KEY"] = hfn.get_secret_from_name(openai_secret, kv=False) 137 | 138 | # get collection name 139 | collection = os.environ.get(COLLECTION_ENV_VAR) 140 | if not collection: 141 | raise MissingEnvironmentVariable(f"{COLLECTION_ENV_VAR} environment variable is required") 142 | 143 | LOGGER.info("starting conversational retrieval chain now..") 144 | 145 | # langchain stuff 146 | chat_history = [] 147 | qa = build_chain( 148 | hfn.get_secret_from_name(secret_name), 149 | collection 150 | ) 151 | 152 | print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC) 153 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) 154 | print(">", end=" ", flush=True) 155 | 156 | for query in sys.stdin: 157 | if (query.strip().lower().startswith("new search:")): 158 | query = query.strip().lower().replace("new search:","") 159 | chat_history = [] 160 | elif (len(chat_history) == MAX_HISTORY_LENGTH): 161 | chat_history.pop(0) 162 | 163 | result = run_chain(qa, query, chat_history) 164 | 165 | chat_history.append((query, result["answer"])) 166 | 167 | print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC) 168 | if 'source_documents' in result: 169 | print(bcolors.OKGREEN + 'Sources:') 170 | for d in result['source_documents']: 171 | print(d.metadata['source']) 172 | print(bcolors.ENDC) 173 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) 174 | print(">", end=" ", flush=True) 175 | 176 | print(bcolors.OKBLUE + "Bye" + bcolors.ENDC) 177 | -------------------------------------------------------------------------------- /rag-app/requirements.txt: -------------------------------------------------------------------------------- 1 | # langchain==0.0.308 2 | # langchain==0.0.353 3 | langchain==0.1.11 4 | boto3>=1.28.27 5 | openai 6 | anthropic 7 | streamlit 8 | pgvector 9 | psycopg2-binary 10 | tiktoken 11 | toml 12 | -------------------------------------------------------------------------------- /rag-app/run_app.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "initializing app.." 
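# app_init.py fetches the database and API-key secrets from Secrets Manager and writes them
# to /root/.streamlit/secrets.toml (the directory is created in the Dockerfile) so the
# Streamlit app started below can read them via st.secrets. A sketch of the resulting file,
# assuming the standard RDS-generated secret fields (values are illustrative only; the real
# secret may carry extra fields such as dbname and engine):
#   host = "<rds-endpoint>.rds.amazonaws.com"
#   port = 5432
#   username = "postgres"
#   password = "<generated-password>"
#   OPENAI_API_KEY = "<api-key>"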
5 | python3 app_init.py 6 | 7 | echo "starting streamlit app" 8 | streamlit run app.py bedrock_claude 9 | -------------------------------------------------------------------------------- /screenshots/app_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-pgvector/93c5fc9ee9378d8d99a114299c49507a89d9cae2/screenshots/app_screenshot.png -------------------------------------------------------------------------------- /screenshots/cog_login_page.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-pgvector/93c5fc9ee9378d8d99a114299c49507a89d9cae2/screenshots/cog_login_page.png -------------------------------------------------------------------------------- /screenshots/invalid_cert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-bedrock-and-pgvector/93c5fc9ee9378d8d99a114299c49507a89d9cae2/screenshots/invalid_cert.png -------------------------------------------------------------------------------- /scripts/api-key-secret-manager-upload/api-key-secret-manager-upload.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import getpass 3 | import logging 4 | import time 5 | 6 | import boto3 7 | 8 | 9 | DEFAULT_LOG_LEVEL = logging.INFO 10 | LOGGER = logging.getLogger(__name__) 11 | LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \ 12 | "[%(name)s]:[%(threadName)s] " \ 13 | "%(message)s" 14 | 15 | AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID" 16 | AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY" 17 | AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN" 18 | 19 | 20 | def _check_missing_field(validation_dict, extraction_key): 21 | """Check if a field exists in a dictionary 22 | 23 | :param validation_dict: Dictionary 24 | :param extraction_key: String 25 | 26 | :raises: Exception 27 | """ 28 | extracted_value = validation_dict.get(extraction_key) 29 | 30 | if not extracted_value: 31 | LOGGER.error(f"Missing '{extraction_key}' key in the dict") 32 | raise Exception 33 | 34 | 35 | def _validate_field(validation_dict, extraction_key, expected_value): 36 | """Validate the passed in field 37 | 38 | :param validation_dict: Dictionary 39 | :param extraction_key: String 40 | :param expected_value: String 41 | 42 | :raises: ValueError 43 | """ 44 | extracted_value = validation_dict.get(extraction_key) 45 | _check_missing_field(validation_dict, extraction_key) 46 | 47 | if extracted_value != expected_value: 48 | LOGGER.error(f"Incorrect value found for '{extraction_key}' key") 49 | raise ValueError 50 | 51 | 52 | def _cli_args(): 53 | """Parse CLI Args 54 | 55 | :rtype: argparse.Namespace 56 | """ 57 | parser = argparse.ArgumentParser(description="api-key-secret-manager-upload") 58 | 59 | parser.add_argument("-s", 60 | "--secret-name", 61 | type=str, 62 | help="Secret Name", 63 | required=True 64 | ) 65 | parser.add_argument("-p", 66 | "--aws-profile", 67 | type=str, 68 | default="default", 69 | help="AWS profile to be used for the API calls") 70 | parser.add_argument("-v", 71 | "--verbose", 72 | action="store_true", 73 | help="debug log output") 74 | parser.add_argument("-e", 75 | "--env", 76 | action="store_true", 77 | help="Use environment variables for AWS credentials") 78 | return parser.parse_args() 79 | 80 | 81 | def _silence_noisy_loggers(): 82 
--------------------------------------------------------------------------------
/scripts/api-key-secret-manager-upload/requirements.txt:
--------------------------------------------------------------------------------
boto3
--------------------------------------------------------------------------------
/scripts/rds-ddl-sql/rds-ddl.sql:
--------------------------------------------------------------------------------
-- Create pgvector extension
CREATE EXTENSION IF NOT EXISTS vector;
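
rds-ddl.sql above only enables the pgvector extension in the RDS database. As a rough, self-contained illustration of what that extension provides once installed, here is a toy nearest-neighbour query; the table name, connection parameters, and 3-dimensional vectors are invented for the example and do not appear in this repository.

    import psycopg2

    # Placeholder connection values; a deployed app would resolve these from Secrets Manager.
    conn = psycopg2.connect(
        host="localhost", dbname="postgres", user="postgres", password="postgres")
    with conn, conn.cursor() as cur:
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        cur.execute(
            "CREATE TABLE IF NOT EXISTS demo_items "
            "(id bigserial PRIMARY KEY, embedding vector(3));")
        cur.execute(
            "INSERT INTO demo_items (embedding) VALUES (%s::vector);", ("[1, 2, 3]",))
        # '<->' is pgvector's Euclidean-distance operator: nearest neighbour first.
        cur.execute(
            "SELECT id FROM demo_items ORDER BY embedding <-> %s::vector LIMIT 1;",
            ("[2, 3, 4]",))
        print(cur.fetchone())
    conn.close()
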
--------------------------------------------------------------------------------
/scripts/self-signed-cert-utility/.gitignore:
--------------------------------------------------------------------------------
.ssl
*env
--------------------------------------------------------------------------------
/scripts/self-signed-cert-utility/default_cert_params.json:
--------------------------------------------------------------------------------
{
    "country": "US",
    "state": "Pennsylvania",
    "locality": "Hatfield",
    "organization": "MyCo",
    "organizational_unit": "Infra",
    "email_address": "infra at myco dot com"
}
--------------------------------------------------------------------------------
/scripts/self-signed-cert-utility/requirements.txt:
--------------------------------------------------------------------------------
boto3
pyopenssl
validators
--------------------------------------------------------------------------------
/scripts/self-signed-cert-utility/self-signed-cert-utility.py:
--------------------------------------------------------------------------------
import argparse
import json
import logging
import os
import pathlib
import time

import boto3
from OpenSSL import crypto
import validators
from validators import ValidationError


DEFAULT_LOG_LEVEL = logging.INFO
LOGGER = logging.getLogger(__name__)
LOGGING_FORMAT = "%(asctime)s %(levelname)-5.5s " \
                 "[%(name)s]:[%(threadName)s] " \
                 "%(message)s"

CERT_FILE = "cert.pem"
KEY_FILE = "key.pem"

AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID"
AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY"
AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN"

DEFAULT_APP_DOMAIN = "pgvec.rag"


def generate_ssl_keys(key_config):
    """Generate ssl keys

    Solution taken from:
    https://stackoverflow.com/questions/27164354/create-a-self-signed-x509-certificate-in-python

    :param key_config: Dictionary

    :rtype: Dictionary
    """
    ssl_dir = f"{pathlib.Path.cwd()}/.ssl"
    if not os.path.isdir(ssl_dir):
        os.mkdir(ssl_dir)
        LOGGER.info("Created ssl dir: %s", ssl_dir)
    else:
        LOGGER.info(f"ssl dir: {ssl_dir} already exists")

    ssl_cert = f"{ssl_dir}/{os.environ.get('CERT_FILE', CERT_FILE)}"
    ssl_pem = f"{ssl_dir}/{os.environ.get('KEY_FILE', KEY_FILE)}"

    LOGGER.debug("Creating a key pair")
    ssl_key = crypto.PKey()
    ssl_key.generate_key(crypto.TYPE_RSA, 2048)

    LOGGER.debug("Creating a self-signed cert")
    cert = crypto.X509()
    cert.get_subject().countryName = key_config["country"]
    cert.get_subject().stateOrProvinceName = key_config["state"]
    cert.get_subject().localityName = key_config["locality"]
    cert.get_subject().organizationName = key_config["organization"]
    cert.get_subject().organizationalUnitName = \
        key_config["organizational_unit"]
    domain = os.environ.get("APP_DOMAIN", DEFAULT_APP_DOMAIN)
    if not domain:
        raise Exception("Missing 'APP_DOMAIN' environment variable")
    try:
        valid_domain = validators.domain(domain)
        if not valid_domain:
            LOGGER.error(f"'{domain}' could not be validated as a domain")
            raise Exception
    except ValidationError:
        LOGGER.error(f"'{domain}' could not be validated as a domain")
        raise Exception

    cert.get_subject().commonName = domain

    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10*365*24*60*60)
    cert.set_issuer(cert.get_subject())
    cert.set_pubkey(ssl_key)
    # sha256: SHA-1-signed certificates are rejected by modern TLS clients
    cert.sign(ssl_key, "sha256")

    cert_body = crypto.dump_certificate(
        crypto.FILETYPE_PEM, cert).decode("utf-8")
    LOGGER.info("Writing self-signed cert to: %s", ssl_cert)
    with open(ssl_cert, "wt") as cert_writer:
        cert_writer.write(cert_body)

    key_body = crypto.dump_privatekey(
        crypto.FILETYPE_PEM, ssl_key).decode("utf-8")
    LOGGER.info("Writing private key to: %s", ssl_pem)
    with open(ssl_pem, "wt") as key_writer:
        key_writer.write(key_body)
    return {
        "key": key_body,
        "cert": cert_body
    }
def _validate_config_file_path(file_path):
    """Checks if passed in file path is valid or not

    :param file_path: String

    :raises: FileNotFoundError
    """
    LOGGER.info(f"Config file path: {file_path}")
    if not os.path.isfile(file_path):
        LOGGER.error("Config file provided is not found")
        raise FileNotFoundError
    else:
        LOGGER.debug("File path is valid")
        return


def _parse_key_details_file(args):
    """Parse json file containing key details

    :param args: argparse.Namespace (CLI args)

    :rtype: Dictionary
    """
    absolute_file_path = pathlib.Path(args.config_file).resolve()
    _validate_config_file_path(absolute_file_path)

    with open(absolute_file_path) as config_file:
        try:
            return json.load(config_file)
        except json.JSONDecodeError:
            LOGGER.error("Configuration file is not valid JSON")
            raise TypeError


def _check_missing_field(validation_dict, extraction_key):
    """Check if a field exists in a dictionary

    :param validation_dict: Dictionary
    :param extraction_key: String

    :raises: KeyError
    """
    extracted_value = validation_dict.get(extraction_key)

    if not extracted_value:
        LOGGER.error(f"Missing '{extraction_key}' key in the dict")
        raise KeyError


def _validate_field(validation_dict, extraction_key, expected_value):
    """Validate the passed in field

    :param validation_dict: Dictionary
    :param extraction_key: String
    :param expected_value: String

    :raises: ValueError
    """
    extracted_value = validation_dict.get(extraction_key)
    _check_missing_field(validation_dict, extraction_key)

    if extracted_value != expected_value:
        LOGGER.error(f"Incorrect value found for '{extraction_key}' key")
        raise ValueError


def _cli_args():
    """Parse CLI Args

    :rtype: argparse.Namespace
    """
    parser = argparse.ArgumentParser(description="self-signed-cert-utility")
    parser.add_argument("-p",
                        "--aws-profile",
                        type=str,
                        default="default",
                        help="AWS profile to be used for the API calls")
    parser.add_argument("-f",
                        "--config-file",
                        type=str,
                        default="default_cert_params.json",
                        help="path to configuration file")
    parser.add_argument("-v",
                        "--verbose",
                        action="store_true",
                        help="debug log output")
    parser.add_argument("-e",
                        "--env",
                        action="store_true",
                        help="Use environment variables for AWS credentials")
    return parser.parse_args()


def _silence_noisy_loggers():
    """Silence chatty libraries for better logging"""
    for logger in ['boto3', 'botocore',
                   'botocore.vendored.requests.packages.urllib3']:
        logging.getLogger(logger).setLevel(logging.WARNING)


def main():
    """What executes when the script is run"""
    start = time.time()  # to capture elapsed time

    args = _cli_args()

    # logging configuration
    log_level = DEFAULT_LOG_LEVEL
    if args.verbose:
        log_level = logging.DEBUG
    logging.basicConfig(level=log_level, format=LOGGING_FORMAT)
    # silence chatty libraries
    _silence_noisy_loggers()

    if args.env:
        LOGGER.info(
            "Attempting to fetch AWS credentials via environment variables")
        aws_access_key_id = os.environ.get(AWS_ACCESS_KEY_ID)
        aws_secret_access_key = os.environ.get(AWS_SECRET_ACCESS_KEY)
        aws_session_token = os.environ.get(AWS_SESSION_TOKEN)
        if not aws_secret_access_key or not aws_access_key_id or not aws_session_token:
            raise Exception(
                f"Missing one or more environment variables - "
                f"'{AWS_ACCESS_KEY_ID}', '{AWS_SECRET_ACCESS_KEY}', "
                f"'{AWS_SESSION_TOKEN}'"
            )
    else:
        LOGGER.info(f"AWS Profile being used: {args.aws_profile}")
        boto3.setup_default_session(profile_name=args.aws_profile)

    cert_name = os.environ.get("IAM_SELF_SIGNED_SERVER_CERT_NAME")
    if not cert_name:
        LOGGER.error(
            "Need to export the 'IAM_SELF_SIGNED_SERVER_CERT_NAME' env var")
        raise Exception

    key_config = _parse_key_details_file(args)

    cert_files = generate_ssl_keys(key_config)

    iam_client = boto3.client("iam")

    resp = iam_client.upload_server_certificate(
        ServerCertificateName=cert_name,
        CertificateBody=cert_files["cert"],
        PrivateKey=cert_files["key"]
    )
    _check_missing_field(resp, "ResponseMetadata")
    _validate_field(resp["ResponseMetadata"], "HTTPStatusCode", 200)
    print(resp)

    LOGGER.debug("Closing iam client")
    iam_client.close()

    LOGGER.info(f"Total time elapsed: {time.time() - start} seconds")


if __name__ == "__main__":
    main()
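
Before relying on the uploaded certificate, the files written to .ssl/ can be sanity-checked locally with the same pyOpenSSL API the utility uses. The helper below is not part of the repository; it is a sketch that assumes the default cert.pem path produced by generate_ssl_keys().

    from OpenSSL import crypto


    def inspect_cert(cert_path=".ssl/cert.pem"):
        """Print the subject CN, expiry, and signature algorithm of a PEM certificate."""
        with open(cert_path) as cert_file:
            cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_file.read())
        print("CommonName:", cert.get_subject().commonName)
        print("NotAfter:", cert.get_notAfter().decode("ascii"))
        print("Signature algorithm:", cert.get_signature_algorithm().decode("ascii"))


    if __name__ == "__main__":
        inspect_cert()
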
--------------------------------------------------------------------------------
/test/rag-with-amazon-bedrock-and-rds-pgvector.test.ts:
--------------------------------------------------------------------------------
// import * as cdk from 'aws-cdk-lib';
// import { Template } from 'aws-cdk-lib/assertions';
// import * as RagWithAmazonBedrockAndRdsPgvector from '../lib/rag-with-amazon-bedrock-and-rds-pgvector-stack';

// example test. To run these tests, uncomment this file along with the
// example resource in lib/rag-with-amazon-bedrock-and-rds-pgvector-stack.ts
test('SQS Queue Created', () => {
  // const app = new cdk.App();
  // // WHEN
  // const stack = new RagWithAmazonBedrockAndRdsPgvector.RagWithAmazonBedrockAndRdsPgvectorStack(app, 'MyTestStack');
  // // THEN
  // const template = Template.fromStack(stack);

  // template.hasResourceProperties('AWS::SQS::Queue', {
  //   VisibilityTimeout: 300
  // });
});
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
{
  "compilerOptions": {
    "target": "ES2020",
    "module": "commonjs",
    "lib": [
      "es2020",
      "dom"
    ],
    "declaration": true,
    "strict": true,
    "noImplicitAny": true,
    "strictNullChecks": true,
    "noImplicitThis": true,
    "alwaysStrict": true,
    "noUnusedLocals": false,
    "noUnusedParameters": false,
    "noImplicitReturns": true,
    "noFallthroughCasesInSwitch": false,
    "inlineSourceMap": true,
    "inlineSources": true,
    "experimentalDecorators": true,
    "strictPropertyInitialization": false,
    "typeRoots": [
      "./node_modules/@types"
    ]
  },
  "exclude": [
    "node_modules",
    "cdk.out"
  ]
}
--------------------------------------------------------------------------------