├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app
│   ├── README.md
│   ├── app.py
│   ├── env_vars.sh
│   ├── images
│   │   ├── ai-icon.png
│   │   └── user-icon.png
│   ├── pgvector_chat_flan_xl.py
│   ├── pgvector_chat_llama2.py
│   ├── qa-with-llm-and-rag.png
│   └── requirements.txt
├── cdk_stacks
│   ├── .gitignore
│   ├── README.md
│   ├── app.py
│   ├── cdk.context.json
│   ├── cdk.json
│   ├── rag_with_pgvector
│   │   ├── __init__.py
│   │   ├── aurora_postgresql.py
│   │   ├── sm_embedding_endpoint.py
│   │   ├── sm_llm_endpoint.py
│   │   ├── sm_studio.py
│   │   └── vpc.py
│   ├── rag_with_pgvector_arch.svg
│   ├── requirements.txt
│   └── source.bat
└── data_ingestion_to_vectordb
    ├── container
    │   ├── Dockerfile
    │   ├── credentials.py
    │   ├── load_data_into_pgvector.py
    │   └── sm_helper.py
    ├── data_ingestion_to_pgvector.ipynb
    └── scripts
        └── get_data.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | Untitled*.ipynb
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | .DS_Store
108 | .idea/
109 | bin/
110 | lib64
111 | pyvenv.cfg
112 | *.bak
113 | share/
114 | cdk.out/
115 | cdk.context.json*
116 | zap/
117 |
118 | */.gitignore
119 | */setup.py
120 | */source.bat
121 |
122 | */*/.gitignore
123 | */*/setup.py
124 | */*/source.bat
125 |
126 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 |
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QA with LLM and RAG (Retrieval Augmented Generation)
2 |
3 | This project is a Question Answering application built with Large Language Models (LLMs) and Amazon Aurora PostgreSQL using [pgvector](https://github.com/pgvector/pgvector). An application that follows the RAG (Retrieval Augmented Generation) approach retrieves the information most relevant to the user’s request from the enterprise knowledge base or content, bundles it as context along with the user’s request in a prompt, and then sends it to the LLM to get a generative AI response.
4 |
5 | LLMs have limits on the maximum word count of the input prompt, so choosing the right passages from among thousands or millions of enterprise documents has a direct impact on the LLM’s accuracy.
6 |
7 | In this project, Amazon Aurora PostgreSQL with pgvector is used as the knowledge base.
8 |
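In code, this request flow can be sketched as follows (a minimal illustration only; the actual chain is built with LangChain in `app/pgvector_chat_llama2.py`, whose prompt template this sketch reuses — `retriever` and `llm` are placeholders):

```
# a minimal sketch of the RAG flow; `retriever` and `llm` stand in for the
# pgvector retriever and the SageMaker LLM endpoint built in the app code
def answer_with_rag(question: str, retriever, llm) -> str:
    docs = retriever.invoke(question)                    # 1. retrieve the most relevant passages
    context = "\n\n".join(d.page_content for d in docs)  # 2. bundle them as context
    prompt = f"Answer based on context:\n\n{context}\n\n{question}"
    return llm.invoke(prompt)                            # 3. ask the LLM for a grounded answer
```
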
9 | The overall architecture looks like this:
10 |
11 | 
12 |
13 | ### Overall Workflow
14 |
15 | 1. Deploy the cdk stacks (For more information, see [here](./cdk_stacks/README.md)).
16 | - A SageMaker Studio in a private VPC.
17 | - A SageMaker Endpoint for text generation.
18 | - A SageMaker Endpoint for generating embeddings.
19 | - An Amazon Aurora PostgreSQL cluster for storing embeddings.
20 | - The Aurora PostgreSQL cluster's access credentials (username and password), stored in AWS Secrets Manager under a name such as `RAGPgVectorStackAuroraPostg-xxxxxxxxxxxx` (see the retrieval sketch after this list).
21 | 2. Open JupyterLab in SageMaker Studio and then open a new terminal.
22 | 3. Run the following commands on the terminal to clone the code repository for this project:
23 | ```
24 | git clone --depth=1 https://github.com/aws-samples/rag-with-amazon-postgresql-using-pgvector-and-sagemaker.git
25 | ```
26 | 4. Open the `data_ingestion_to_pgvector.ipynb` notebook and run it. (For more information, see [here](./data_ingestion_to_vectordb/data_ingestion_to_pgvector.ipynb))
27 | 5. Run the Streamlit application. (For more information, see [here](./app/README.md))
28 |
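The application retrieves the stored credentials at runtime. A minimal sketch, mirroring `_get_credentials()` in `app/pgvector_chat_llama2.py` (the secret id and region below are placeholders):

```
import json
import boto3

# fetch the Aurora PostgreSQL credentials created by the CDK stack;
# the secret id below is a placeholder for the one in your account
client = boto3.client('secretsmanager', region_name='us-east-1')
response = client.get_secret_value(SecretId='RAGPgVectorStackAuroraPostg-xxxxxxxxxxxx')
secret = json.loads(response['SecretString'])  # dict with username, password, host, port
```
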
29 | ### References
30 |
31 | * [Leverage pgvector and Amazon Aurora PostgreSQL for Natural Language Processing, Chatbots and Sentiment Analysis (2023-07-13)](https://aws.amazon.com/blogs/database/leverage-pgvector-and-amazon-aurora-postgresql-for-natural-language-processing-chatbots-and-sentiment-analysis/)
32 | * [Accelerate HNSW indexing and searching with pgvector on Amazon Aurora PostgreSQL-compatible edition and Amazon RDS for PostgreSQL (2023-11-06)](https://aws.amazon.com/blogs/database/accelerate-hnsw-indexing-and-searching-with-pgvector-on-amazon-aurora-postgresql-compatible-edition-and-amazon-rds-for-postgresql/)
33 | * [Optimize generative AI applications with pgvector indexing: A deep dive into IVFFlat and HNSW techniques (2024-03-15)](https://aws.amazon.com/blogs/database/optimize-generative-ai-applications-with-pgvector-indexing-a-deep-dive-into-ivfflat-and-hnsw-techniques/)
34 | * [Improve the performance of generative AI workloads on Amazon Aurora with Optimized Reads and pgvector (2024-02-09)](https://aws.amazon.com/blogs/database/accelerate-generative-ai-workloads-on-amazon-aurora-with-optimized-reads-and-pgvector/)
35 | * [Building AI-powered search in PostgreSQL using Amazon SageMaker and pgvector (2023-05-03)](https://aws.amazon.com/blogs/database/building-ai-powered-search-in-postgresql-using-amazon-sagemaker-and-pgvector/)
36 | * [Build Streamlit apps in Amazon SageMaker Studio (2023-04-11)](https://aws.amazon.com/blogs/machine-learning/build-streamlit-apps-in-amazon-sagemaker-studio/)
37 | * [Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models (2023-05-03)](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)
38 | * [(github) Amazon Kendra Retriever Samples](https://github.com/aws-samples/amazon-kendra-langchain-extensions/tree/main/kendra_retriever_samples)
39 | * [Question answering using Retrieval Augmented Generation with foundation models in Amazon SageMaker JumpStart (2023-05-02)](https://aws.amazon.com/blogs/machine-learning/question-answering-using-retrieval-augmented-generation-with-foundation-models-in-amazon-sagemaker-jumpstart/)
40 | * [Use proprietary foundation models from Amazon SageMaker JumpStart in Amazon SageMaker Studio (2023-06-27)](https://aws.amazon.com/blogs/machine-learning/use-proprietary-foundation-models-from-amazon-sagemaker-jumpstart-in-amazon-sagemaker-studio/)
41 | * [LangChain](https://python.langchain.com/docs/get_started/introduction.html) - A framework for developing applications powered by language models.
42 | * [Streamlit](https://streamlit.io/) - A faster way to build and share data apps
43 | * [rag-with-amazon-kendra-and-sagemaker](https://github.com/aws-samples/aws-kr-startup-samples/tree/main/gen-ai/rag-with-amazon-kendra-and-sagemaker) - Question Answering application with Large Language Models (LLMs) and Amazon Kendra
44 | * [rag-with-amazon-opensearch-and-sagemaker](https://github.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker) - Question Answering application with Large Language Models (LLMs) and Amazon OpenSearch Service
45 | * [rag-with-amazon-opensearch-serverless](https://github.com/aws-samples/rag-with-amazon-opensearch-serverless) - Question Answering application with Large Language Models (LLMs) and Amazon OpenSearch Serverless Service
46 | * [Pgvector changelog - v0.4.0 (2023-01-11)](https://github.com/pgvector/pgvector/blob/master/CHANGELOG.md#040-2023-01-11)
47 | > Increased max dimensions for vector from `1024` to `16000`
48 | > Increased max dimensions for index from `1024` to `2000`
49 |
50 | ## Security
51 |
52 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
53 |
54 | ## License
55 |
56 | This library is licensed under the MIT-0 License. See the LICENSE file.
57 |
--------------------------------------------------------------------------------
/app/README.md:
--------------------------------------------------------------------------------
1 | ## Run the Streamlit application in Studio
2 |
3 | Now we’re ready to run the Streamlit web application for our question answering bot.
4 |
5 | SageMaker Studio provides a convenient platform to host the Streamlit web application. The following steps describe how to run the Streamlit app on SageMaker Studio. Alternatively, you can follow the same procedure to run the app on an Amazon EC2 instance or AWS Cloud9 in your AWS account.
6 |
7 | 1. Open JupyterLab and then open a new **Terminal**.
8 | 2. Run the following commands on the terminal to clone the code repository for this post and install the Python packages needed by the application:
9 | ```
10 | git clone --depth=1 https://github.com/aws-samples/rag-with-amazon-postgresql-using-pgvector-and-sagemaker.git
11 | cd rag-with-amazon-postgresql-using-pgvector-and-sagemaker/app
12 | python -m venv .env
13 | source .env/bin/activate
14 | pip install -r requirements.txt
15 | ```
16 | 3. In the shell, set the following environment variables with the values that are available from the CloudFormation stack output.
17 | ```
18 | export AWS_REGION=us-east-1
19 | export PGVECTOR_SECRET_ID="your-postgresql-secret-id"
20 | export COLLECTION_NAME="llm_rag_embeddings"
21 | export EMBEDDING_ENDPOINT_NAME="your-sagemaker-endpoint-for-embedding-model"
22 | export TEXT2TEXT_ENDPOINT_NAME="your-sagemaker-endpoint-for-text-generation-model"
23 | ```
24 | :information_source: `COLLECTION_NAME` can be found in the [data ingestion to vectordb](../data_ingestion_to_vectordb/data_ingestion_to_pgvector.ipynb) step.
25 | 4. Run the application with `streamlit run app.py`. When it runs successfully, you'll see output similar to the following (the IP addresses you see will differ from the ones in this example). Note the port number (typically `8501`) from the output; you'll use it as part of the app URL in the next step.
26 | ```
27 | sagemaker-user@studio$ streamlit run app.py
28 |
29 | Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.
30 |
31 | You can now view your Streamlit app in your browser.
32 |
33 | Network URL: http://169.255.255.2:8501
34 | External URL: http://52.4.240.77:8501
35 | ```
36 | 5. You can access the app in a new browser tab using a URL similar to your Studio domain URL. For example, if your Studio URL is `https://d-randomidentifier.studio.us-east-1.sagemaker.aws/jupyter/default/lab?`, then the URL for your Streamlit app will be `https://d-randomidentifier.studio.us-east-1.sagemaker.aws/jupyter/default/proxy/8501/app` (notice that `lab` is replaced with `proxy/8501/app`). If the port number noted in the previous step is different from `8501`, use that port instead of `8501` in the URL (see the sketch below).
37 |
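The URL rewrite can be sketched in a few lines of Python (purely illustrative; the Studio URL is the placeholder from the example above):

```
# derive the Streamlit app URL from the Studio URL (illustrative values)
studio_url = "https://d-randomidentifier.studio.us-east-1.sagemaker.aws/jupyter/default/lab?"
port = 8501  # use the port number reported by `streamlit run`
app_url = studio_url.split("/lab")[0] + f"/proxy/{port}/app"
print(app_url)  # .../jupyter/default/proxy/8501/app
```
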
38 | The following screenshot shows the app with a couple of user questions (e.g., `What are the versions of XGBoost supported by Amazon SageMaker?`).
39 |
40 | 
41 |
42 | ## References
43 |
44 | * [Leverage pgvector and Amazon Aurora PostgreSQL for Natural Language Processing, Chatbots and Sentiment Analysis (2023-07-13)](https://aws.amazon.com/blogs/database/leverage-pgvector-and-amazon-aurora-postgresql-for-natural-language-processing-chatbots-and-sentiment-analysis/)
45 | * [Building AI-powered search in PostgreSQL using Amazon SageMaker and pgvector (2023-05-03)](https://aws.amazon.com/blogs/database/building-ai-powered-search-in-postgresql-using-amazon-sagemaker-and-pgvector/)
46 | * [Use proprietary foundation models from Amazon SageMaker JumpStart in Amazon SageMaker Studio (2023-06-27)](https://aws.amazon.com/blogs/machine-learning/use-proprietary-foundation-models-from-amazon-sagemaker-jumpstart-in-amazon-sagemaker-studio/)
47 | * [Build Streamlit apps in Amazon SageMaker Studio (2023-04-11)](https://aws.amazon.com/blogs/machine-learning/build-streamlit-apps-in-amazon-sagemaker-studio/)
48 | * [Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models (2023-05-02)](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)
49 | * [sagemaker-huggingface-inference-toolkit](https://github.com/aws/sagemaker-huggingface-inference-toolkit) - SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers and Diffusers models on Amazon SageMaker.
50 | * [LangChain](https://python.langchain.com/docs/get_started/introduction.html) - A framework for developing applications powered by language models.
51 | * [Streamlit](https://streamlit.io/) - A faster way to build and share data apps
52 |
--------------------------------------------------------------------------------
/app/app.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import streamlit as st
7 | import uuid
8 |
9 | import pgvector_chat_flan_xl as flanxl
10 | import pgvector_chat_llama2 as llama2
11 |
12 | PROVIDER_NAME = os.environ.get('PROVIDER_NAME', 'llama2')
13 |
14 | USER_ICON = "images/user-icon.png"
15 | AI_ICON = "images/ai-icon.png"
16 | MAX_HISTORY_LENGTH = 5
17 | PROVIDER_MAP = {
18 | 'flanxl': 'Flan XL',
19 | 'llama2': 'Llama2 7B',
20 | }
21 |
22 | # Check if the user ID is already stored in the session state
23 | if 'user_id' in st.session_state:
24 | user_id = st.session_state['user_id']
25 |
26 | # If the user ID is not yet stored in the session state, generate a random UUID
27 | else:
28 | user_id = str(uuid.uuid4())
29 | st.session_state['user_id'] = user_id
30 |
31 |
32 | if 'llm_chain' not in st.session_state:
33 | llm_app = llama2 if PROVIDER_NAME == 'llama2' else flanxl
34 | st.session_state['llm_app'] = llm_app
35 | st.session_state['llm_chain'] = llm_app.build_chain()
36 |
37 | if 'chat_history' not in st.session_state:
38 | st.session_state['chat_history'] = []
39 |
40 | if "chats" not in st.session_state:
41 | st.session_state.chats = [
42 | {
43 | 'id': 0,
44 | 'question': '',
45 | 'answer': ''
46 | }
47 | ]
48 |
49 | if "questions" not in st.session_state:
50 | st.session_state.questions = []
51 |
52 | if "answers" not in st.session_state:
53 | st.session_state.answers = []
54 |
55 | if "input" not in st.session_state:
56 | st.session_state.input = ""
57 |
58 |
59 | st.markdown("""
60 | <style>
61 |     .block-container {
62 |         padding-top: 32px;
63 |         padding-bottom: 32px;
64 |         padding-left: 0;
65 |         padding-right: 0;
66 |     }
67 |     .element-container img {
68 |         background-color: #000000;
69 |     }
70 |
71 |     .main-header {
72 |         font-size: 24px;
73 |     }
74 | </style>
75 | """, unsafe_allow_html=True)
76 |
77 |
78 | def write_logo():
79 | col1, col2, col3 = st.columns([5, 1, 5])
80 | with col2:
81 | st.image(AI_ICON, use_column_width='always')
82 |
83 |
84 | def write_top_bar():
85 | col1, col2, col3 = st.columns([1,10,2])
86 | with col1:
87 | st.image(AI_ICON, use_column_width='always')
88 | with col2:
89 | selected_provider = PROVIDER_NAME
90 | if selected_provider in PROVIDER_MAP:
91 | provider = PROVIDER_MAP[selected_provider]
92 | else:
93 | provider = selected_provider.capitalize()
94 | header = f"An AI App powered by Amazon Aurora PostgreSQL with pgvector and {provider}!"
95 | st.write(f"<h3 class='main-header'>{header}</h3>", unsafe_allow_html=True)
96 | with col3:
97 | clear = st.button("Clear Chat")
98 | return clear
99 |
100 |
101 | clear = write_top_bar()
102 |
103 | if clear:
104 | st.session_state.questions = []
105 | st.session_state.answers = []
106 | st.session_state.input = ""
107 | st.session_state["chat_history"] = []
108 |
109 |
110 | def handle_input():
111 | input = st.session_state.input
112 | question_with_id = {
113 | 'question': input,
114 | 'id': len(st.session_state.questions)
115 | }
116 | st.session_state.questions.append(question_with_id)
117 |
118 | chat_history = st.session_state["chat_history"]
119 | if len(chat_history) >= MAX_HISTORY_LENGTH:
120 | chat_history.pop(0)  # drop the oldest turn in place so the stored history stays bounded
121 |
122 | llm_chain = st.session_state['llm_chain']
123 | chain = st.session_state['llm_app']
124 | result = chain.run_chain(llm_chain, input, chat_history)
125 | answer = result['answer']
126 | chat_history.append((input, answer))
127 |
128 | document_list = []
129 | if 'source_documents' in result:
130 | for d in result['source_documents']:
131 | if not (d.metadata['source'] in document_list):
132 | document_list.append((d.metadata['source']))
133 |
134 | st.session_state.answers.append({
135 | 'answer': result,
136 | 'sources': document_list,
137 | 'id': len(st.session_state.questions)
138 | })
139 | st.session_state.input = ""
140 |
141 |
142 | def write_user_message(md):
143 | col1, col2 = st.columns([1,12])
144 |
145 | with col1:
146 | st.image(USER_ICON, use_column_width='always')
147 | with col2:
148 | st.warning(md['question'])
149 |
150 |
151 | def render_result(result):
152 | answer, sources = st.tabs(['Answer', 'Sources'])
153 | with answer:
154 | render_answer(result['answer'])
155 | with sources:
156 | if 'source_documents' in result:
157 | render_sources(result['source_documents'])
158 | else:
159 | render_sources([])
160 |
161 |
162 | def render_answer(answer):
163 | col1, col2 = st.columns([1,12])
164 | with col1:
165 | st.image(AI_ICON, use_column_width='always')
166 | with col2:
167 | st.info(answer['answer'])
168 |
169 |
170 | def render_sources(sources):
171 | col1, col2 = st.columns([1,12])
172 | with col2:
173 | with st.expander("Sources"):
174 | for s in sources:
175 | st.write(s)
176 |
177 |
178 | # Each answer will have context of the question asked in order to associate the provided feedback with the respective question
179 | def write_chat_message(md, q):
180 | chat = st.container()
181 | with chat:
182 | render_answer(md['answer'])
183 | render_sources(md['sources'])
184 |
185 |
186 | with st.container():
187 | for (q, a) in zip(st.session_state.questions, st.session_state.answers):
188 | write_user_message(q)
189 | write_chat_message(a, q)
190 |
191 | st.markdown('---')
192 | input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
193 |
--------------------------------------------------------------------------------
/app/env_vars.sh:
--------------------------------------------------------------------------------
1 | export AWS_REGION="your-aws-region"
2 | export PGVECTOR_SECRET_ID="your-postgresql-secret"
3 | export COLLECTION_NAME="llm_rag_embeddings"
4 | export EMBEDDING_ENDPOINT_NAME="your-sagemaker-endpoint-for-embedding-model"
5 | export TEXT2TEXT_ENDPOINT_NAME="your-sagemaker-endpoint-for-text-generation-model"
--------------------------------------------------------------------------------
/app/images/ai-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/rag-with-amazon-postgresql-using-pgvector-and-sagemaker/1b5ca45eff14b162e8be28cb179338e1ad4d7bbd/app/images/ai-icon.png
--------------------------------------------------------------------------------
/app/images/user-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/rag-with-amazon-postgresql-using-pgvector-and-sagemaker/1b5ca45eff14b162e8be28cb179338e1ad4d7bbd/app/images/user-icon.png
--------------------------------------------------------------------------------
/app/pgvector_chat_flan_xl.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import json
7 | import logging
8 | import sys
9 | from typing import List
10 | import urllib
11 |
12 | import boto3
13 |
14 | from langchain_postgres import PGVector
15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings
16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
17 |
18 | from langchain_community.llms import SagemakerEndpoint
19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
20 |
21 | from langchain.prompts import PromptTemplate
22 | from langchain.chains import ConversationalRetrievalChain
23 |
24 | logger = logging.getLogger()
25 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
26 |
27 |
28 | class bcolors:
29 | HEADER = '\033[95m'
30 | OKBLUE = '\033[94m'
31 | OKCYAN = '\033[96m'
32 | OKGREEN = '\033[92m'
33 | WARNING = '\033[93m'
34 | FAIL = '\033[91m'
35 | ENDC = '\033[0m'
36 | BOLD = '\033[1m'
37 | UNDERLINE = '\033[4m'
38 |
39 |
40 | MAX_HISTORY_LENGTH = 5
41 |
42 |
43 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddings:
44 |
45 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler):
46 | """
47 | encode input string as utf-8 bytes, read the embeddings
48 | from the output
49 | """
50 |
51 | content_type = "application/json"
52 | accepts = "application/json"
53 |
54 | def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
55 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
56 | return input_str.encode('utf-8')
57 |
58 | def transform_output(self, output: bytes) -> List[List[float]]:
59 | response_json = json.loads(output.read().decode("utf-8"))
60 | embeddings = response_json["embedding"]
61 | if len(embeddings) == 1:
62 | return [embeddings[0]]
63 | return embeddings
64 |
65 | # create a content handler object which knows how to serialize
66 | # and deserialize communication with the model endpoint
67 | content_handler = ContentHandlerForEmbeddings()
68 |
69 | # now, to create the Sagemaker embeddings, we provide
70 | # the Sagemaker endpoint that will be used for generating the
71 | # embeddings to the class
72 | #
73 | embeddings = SagemakerEndpointEmbeddings(
74 | endpoint_name=endpoint_name,
75 | region_name=region,
76 | content_handler=content_handler
77 | )
78 | logger.info(f"embeddings type={type(embeddings)}")
79 |
80 | return embeddings
81 |
82 |
83 | def _get_credentials(secret_id: str, region_name: str = 'us-east-1') -> dict:
84 | client = boto3.client('secretsmanager', region_name=region_name)
85 | response = client.get_secret_value(SecretId=secret_id)
86 | secrets_value = json.loads(response['SecretString'])
87 | return secrets_value
88 |
89 |
90 | def build_chain():
91 | region = os.environ["AWS_REGION"]
92 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"]
93 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"]
94 |
95 | pgvector_secret_id = os.environ["PGVECTOR_SECRET_ID"]
96 | secret = _get_credentials(pgvector_secret_id, region)
97 | db_username = secret['username']
98 | db_password = urllib.parse.quote_plus(secret['password'])
99 | db_port = secret['port']
100 | db_host = secret['host']
101 |
102 | CONNECTION_STRING = PGVector.connection_string_from_db_params(
103 | driver = 'psycopg',
104 | user = db_username,
105 | password = db_password,
106 | host = db_host,
107 | port = db_port,
108 | database = ''
109 | )
110 |
111 | collection_name = os.environ["COLLECTION_NAME"]
112 |
113 | class ContentHandler(LLMContentHandler):
114 | content_type = "application/json"
115 | accepts = "application/json"
116 |
117 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
118 | input_str = json.dumps({"inputs": prompt, **model_kwargs})
119 | return input_str.encode('utf-8')
120 |
121 | def transform_output(self, output: bytes) -> str:
122 | response_json = json.loads(output.read().decode("utf-8"))
123 | return response_json[0]["generated_text"]
124 |
125 | content_handler = ContentHandler()
126 |
127 | model_kwargs = {
128 | "max_length": 500,
129 | "num_return_sequences": 1,
130 | "top_k": 250,
131 | "top_p": 0.95,
132 | "do_sample": False,
133 | "temperature": 1
134 | }
135 |
136 | llm = SagemakerEndpoint(
137 | endpoint_name=text2text_model_endpoint,
138 | region_name=region,
139 | model_kwargs=model_kwargs,
140 | content_handler=content_handler
141 | )
142 |
143 | vectorstore = PGVector(
144 | collection_name=collection_name,
145 | connection=CONNECTION_STRING,
146 | embeddings=_create_sagemaker_embeddings(embeddings_model_endpoint, region)
147 | )
148 | retriever = vectorstore.as_retriever()
149 |
150 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""
151 |
152 | PROMPT = PromptTemplate(
153 | template=prompt_template, input_variables=["context", "question"]
154 | )
155 |
156 | condense_qa_template = """
157 | Given the following conversation and a follow up question, rephrase the follow up question
158 | to be a standalone question.
159 |
160 | Chat History:
161 | {chat_history}
162 | Follow Up Input: {question}
163 | Standalone question:"""
164 | standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)
165 |
166 | qa = ConversationalRetrievalChain.from_llm(
167 | llm=llm,
168 | retriever=retriever,
169 | condense_question_prompt=standalone_question_prompt,
170 | return_source_documents=True,
171 | combine_docs_chain_kwargs={"prompt":PROMPT}
172 | )
173 |
174 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n")
175 | return qa
176 |
177 |
178 | def run_chain(chain, prompt: str, history=[]):
179 | return chain.invoke({"question": prompt, "chat_history": history})
180 |
181 |
182 | if __name__ == "__main__":
183 | chat_history = []
184 | qa = build_chain()
185 | print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
186 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
187 | print(">", end=" ", flush=True)
188 | for query in sys.stdin:
189 | if (query.strip().lower().startswith("new search:")):
190 | query = query.strip().lower().replace("new search:","")
191 | chat_history = []
192 | elif (len(chat_history) == MAX_HISTORY_LENGTH):
193 | chat_history.pop(0)
194 | result = run_chain(qa, query, chat_history)
195 | chat_history.append((query, result["answer"]))
196 | print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC)
197 | if 'source_documents' in result:
198 | print(bcolors.OKGREEN + 'Sources:')
199 | for d in result['source_documents']:
200 | print(d.metadata['source'])
201 | print(bcolors.ENDC)
202 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
203 | print(">", end=" ", flush=True)
204 | print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)
--------------------------------------------------------------------------------
/app/pgvector_chat_llama2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import json
7 | import logging
8 | import sys
9 | from typing import List
10 | import urllib
11 |
12 | import boto3
13 |
14 | from langchain_postgres import PGVector
15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings
16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
17 |
18 | from langchain_community.llms import SagemakerEndpoint
19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
20 |
21 | from langchain.prompts import PromptTemplate
22 | from langchain.chains import ConversationalRetrievalChain
23 |
24 | logger = logging.getLogger()
25 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
26 |
27 |
28 | class bcolors:
29 | HEADER = '\033[95m'
30 | OKBLUE = '\033[94m'
31 | OKCYAN = '\033[96m'
32 | OKGREEN = '\033[92m'
33 | WARNING = '\033[93m'
34 | FAIL = '\033[91m'
35 | ENDC = '\033[0m'
36 | BOLD = '\033[1m'
37 | UNDERLINE = '\033[4m'
38 |
39 |
40 | MAX_HISTORY_LENGTH = 5
41 |
42 |
43 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
44 | def embed_documents(
45 | self, texts: List[str], chunk_size: int = 5
46 | ) -> List[List[float]]:
47 | """Compute doc embeddings using a SageMaker Inference Endpoint.
48 |
49 | Args:
50 | texts: The list of texts to embed.
51 | chunk_size: The chunk size defines how many input texts will
52 | be grouped together as request. If None, will use the
53 | chunk size specified by the class.
54 |
55 | Returns:
56 | List of embeddings, one for each text.
57 | """
58 | results = []
59 |
60 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
61 | for i in range(0, len(texts), _chunk_size):
62 | response = self._embedding_func(texts[i : i + _chunk_size])
63 | results.extend(response)
64 | return results
65 |
66 |
67 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart:
68 |
69 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler):
70 | """
71 | encode input string as utf-8 bytes, read the embeddings
72 | from the output
73 | """
74 | content_type = "application/json"
75 | accepts = "application/json"
76 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes:
77 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
78 | return input_str.encode('utf-8')
79 |
80 | def transform_output(self, output: bytes) -> List[List[float]]:
81 | response_json = json.loads(output.read().decode("utf-8"))
82 | embeddings = response_json["embedding"]
83 | if len(embeddings) == 1:
84 | return [embeddings[0]]
85 | return embeddings
86 |
87 | # create a content handler object which knows how to serialize
88 | # and deserialize communication with the model endpoint
89 | content_handler = ContentHandlerForEmbeddings()
90 |
91 | # now, to create the Sagemaker embeddings, we provide
92 | # the Sagemaker endpoint that will be used for generating the
93 | # embeddings to the class
94 | embeddings = SagemakerEndpointEmbeddingsJumpStart(
95 | endpoint_name=endpoint_name,
96 | region_name=region,
97 | content_handler=content_handler
98 | )
99 | logger.info(f"embeddings type={type(embeddings)}")
100 |
101 | return embeddings
102 |
103 |
104 | def _get_credentials(secret_id: str, region_name: str) -> dict:
105 | client = boto3.client('secretsmanager', region_name=region_name)
106 | response = client.get_secret_value(SecretId=secret_id)
107 | secrets_value = json.loads(response['SecretString'])
108 | return secrets_value
109 |
110 |
111 | def build_chain():
112 | region = os.environ["AWS_REGION"]
113 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"]
114 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"]
115 |
116 | pgvector_secret_id = os.environ["PGVECTOR_SECRET_ID"]
117 | secret = _get_credentials(pgvector_secret_id, region)
118 | db_username = secret['username']
119 | db_password = urllib.parse.quote_plus(secret['password'])
120 | db_port = secret['port']
121 | db_host = secret['host']
122 |
123 | CONNECTION_STRING = PGVector.connection_string_from_db_params(
124 | driver = 'psycopg',
125 | user = db_username,
126 | password = db_password,
127 | host = db_host,
128 | port = db_port,
129 | database = ''
130 | )
131 |
132 | collection_name = os.environ["COLLECTION_NAME"]
133 |
134 | # https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-2-chat-completion.ipynb
135 | class ContentHandler(LLMContentHandler):
136 | content_type = "application/json"
137 | accepts = "application/json"
138 |
139 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
140 | system_prompt = "You are a helpful assistant. Always answer questions as helpfully as possible." \
141 | " If you don't know the answer to a question, say I don't know the answer"
142 |
143 | payload = {
144 | "inputs": [
145 | [
146 | {"role": "system", "content": system_prompt},
147 | {"role": "user", "content": prompt},
148 | ],
149 | ],
150 | "parameters": model_kwargs,
151 | }
152 | input_str = json.dumps(payload)
153 | return input_str.encode("utf-8")
154 |
155 | def transform_output(self, output: bytes) -> str:
156 | response_json = json.loads(output.read().decode("utf-8"))
157 | content = response_json[0]["generation"]["content"]
158 | return content
159 |
160 | content_handler = ContentHandler()
161 |
162 | # https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-2-text-completion.ipynb
163 | model_kwargs = {
164 | "max_new_tokens": 256,
165 | "top_p": 0.9,
166 | "temperature": 0.6,
167 | "return_full_text": False,
168 | }
169 |
170 | llm = SagemakerEndpoint(
171 | endpoint_name=text2text_model_endpoint,
172 | region_name=region,
173 | model_kwargs=model_kwargs,
174 | endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
175 | content_handler=content_handler
176 | )
177 |
178 | vectorstore = PGVector(
179 | collection_name=collection_name,
180 | connection=CONNECTION_STRING,
181 | embeddings=_create_sagemaker_embeddings(embeddings_model_endpoint, region)
182 | )
183 | retriever = vectorstore.as_retriever()
184 |
185 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""
186 |
187 | PROMPT = PromptTemplate(
188 | template=prompt_template, input_variables=["context", "question"]
189 | )
190 |
191 | condense_qa_template = """
192 | Given the following conversation and a follow up question, rephrase the follow up question
193 | to be a standalone question.
194 |
195 | Chat History:
196 | {chat_history}
197 | Follow Up Input: {question}
198 | Standalone question:"""
199 | standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)
200 |
201 | qa = ConversationalRetrievalChain.from_llm(
202 | llm=llm,
203 | retriever=retriever,
204 | condense_question_prompt=standalone_question_prompt,
205 | return_source_documents=True,
206 | combine_docs_chain_kwargs={"prompt":PROMPT},
207 | verbose=False
208 | )
209 |
210 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n")
211 | return qa
212 |
213 |
214 | def run_chain(chain, prompt: str, history=[]):
215 | return chain.invoke({"question": prompt, "chat_history": history})
216 |
217 |
218 | if __name__ == "__main__":
219 | chat_history = []
220 | qa = build_chain()
221 | print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
222 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
223 | print(">", end=" ", flush=True)
224 | for query in sys.stdin:
225 | if (query.strip().lower().startswith("new search:")):
226 | query = query.strip().lower().replace("new search:","")
227 | chat_history = []
228 | elif (len(chat_history) == MAX_HISTORY_LENGTH):
229 | chat_history.pop(0)
230 | result = run_chain(qa, query, chat_history)
231 | chat_history.append((query, result["answer"]))
232 | print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC)
233 | if 'source_documents' in result:
234 | print(bcolors.OKGREEN + '\nSources:')
235 | for d in result['source_documents']:
236 | print(d.metadata['source'])
237 | print(bcolors.ENDC)
238 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
239 | print(">", end=" ", flush=True)
240 | print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)
--------------------------------------------------------------------------------
/app/qa-with-llm-and-rag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/rag-with-amazon-postgresql-using-pgvector-and-sagemaker/1b5ca45eff14b162e8be28cb179338e1ad4d7bbd/app/qa-with-llm-and-rag.png
--------------------------------------------------------------------------------
/app/requirements.txt:
--------------------------------------------------------------------------------
1 | boto3>=1.26.159
2 | langchain>=0.3,<0.4
3 | langchain-community>=0.3,<0.4
4 | pgvector==0.2.5
5 | psycopg[binary]==3.1.19
6 | SQLAlchemy==2.0.28
7 | streamlit==1.37.0
8 |
--------------------------------------------------------------------------------
/cdk_stacks/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | package-lock.json
3 | __pycache__
4 | .pytest_cache
5 | .venv
6 | *.egg-info
7 |
8 | # CDK asset staging directory
9 | .cdk.staging
10 | cdk.out
11 |
--------------------------------------------------------------------------------
/cdk_stacks/README.md:
--------------------------------------------------------------------------------
1 |
2 | # RAG Application CDK Python project!
3 |
4 | 
5 |
6 | This is a QA application with LLMs and RAG, built as a CDK development project in Python.
7 |
8 | The `cdk.json` file tells the CDK Toolkit how to execute your app.
9 |
10 | This project is set up like a standard Python project. The initialization
11 | process also creates a virtualenv within this project, stored under the `.venv`
12 | directory. To create the virtualenv it assumes that there is a `python3`
13 | (or `python` for Windows) executable in your path with access to the `venv`
14 | package. If for any reason the automatic creation of the virtualenv fails,
15 | you can create the virtualenv manually.
16 |
17 | To manually create a virtualenv on MacOS and Linux:
18 |
19 | ```
20 | $ python3 -m venv .venv
21 | ```
22 |
23 | After the init process completes and the virtualenv is created, you can use the following
24 | step to activate your virtualenv.
25 |
26 | ```
27 | $ source .venv/bin/activate
28 | ```
29 |
30 | If you are on a Windows platform, you would activate the virtualenv like this:
31 |
32 | ```
33 | % .venv\Scripts\activate.bat
34 | ```
35 |
36 | Once the virtualenv is activated, you can install the required dependencies.
37 |
38 | ```
39 | (.venv) $ pip install -r requirements.txt
40 | ```
41 |
42 | To add additional dependencies, for example other CDK libraries, just add
43 | them to your `setup.py` file and rerun the `pip install -r requirements.txt`
44 | command.
45 |
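For example, one way to pull in an extra library (`cdk-nag` here, purely as an illustration) is to append it to `requirements.txt` and reinstall:

```
(.venv) $ echo "cdk-nag" >> requirements.txt
(.venv) $ pip install -r requirements.txt
```
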
46 | Before synthesizing the CloudFormation template, you should properly set up the CDK context configuration file, `cdk.context.json`.
47 |
48 | For example:
49 |
50 |