├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── env_vars.sh
│   ├── images
│   │   ├── ai-icon.png
│   │   └── user-icon.png
│   ├── opensearch_chat_flan_xl.py
│   ├── opensearch_chat_llama2.py
│   ├── opensearch_load_qa_chain_flan_xl.py
│   ├── opensearch_load_qa_chain_llama2.py
│   ├── opensearch_retriever_flan_xl.py
│   ├── opensearch_retriever_llama2.py
│   ├── qa-with-llm-and-rag.png
│   └── requirements.txt
├── cdk_stacks
│   ├── README.md
│   ├── app.py
│   ├── cdk.context.json
│   ├── cdk.json
│   ├── rag_with_aos
│   │   ├── __init__.py
│   │   ├── ecs_streamlit_app.py
│   │   ├── ops.py
│   │   ├── sm_custom_embedding_endpoint.py
│   │   ├── sm_jumpstart_llm_endpoint.py
│   │   ├── sm_studio.py
│   │   └── vpc.py
│   ├── rag_with_opensearch_arch.svg
│   ├── requirements.txt
│   └── source.bat
└── data_ingestion_to_vectordb
    ├── container
    │   ├── Dockerfile
    │   ├── credentials.py
    │   ├── load_data_into_opensearch.py
    │   └── sm_helper.py
    ├── data_ingestion_to_opensearch.ipynb
    └── scripts
        └── get_data.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | Untitled*.ipynb
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | .DS_Store
108 | .idea/
109 | bin/
110 | lib64
111 | pyvenv.cfg
112 | *.bak
113 | share/
114 | cdk.out/
115 | cdk.context.json*
116 | zap/
117 |
118 | */.gitignore
119 | */setup.py
120 | */source.bat
121 |
122 | */*/.gitignore
123 | */*/setup.py
124 | */*/source.bat
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 |
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QA with LLM and RAG (Retrieval Augmented Generation)
2 |
3 | This project is a Question Answering application built with Large Language Models (LLMs) and Amazon OpenSearch Service. An application using the RAG (Retrieval-Augmented Generation) approach retrieves the information most relevant to the user’s request from the enterprise knowledge base or content, bundles it as context along with the user’s request in a prompt, and then sends the prompt to the LLM to get a generative AI response.
4 |
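Conceptually, the flow is retrieve, then bundle, then generate. The following minimal Python sketch illustrates that flow; `vector_store` and `llm` are stand-ins for any LangChain-style vector store and LLM client, and all names are illustrative rather than this project’s exact code:

```
def answer_with_rag(user_question: str, vector_store, llm, k: int = 3) -> str:
    """Retrieve relevant passages, bundle them as context, and ask the LLM."""
    # 1. Retrieve: find the k passages most relevant to the question.
    docs = vector_store.similarity_search(user_question, k=k)
    # 2. Bundle: pack the retrieved passages into the prompt as context.
    context = "\n\n".join(doc.page_content for doc in docs)
    prompt = f"Answer based on context:\n\n{context}\n\n{user_question}"
    # 3. Generate: send the combined prompt to the LLM for the final answer.
    return llm.invoke(prompt)
```
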
5 | LLMs have limits on the maximum word count of the input prompt, so choosing the right passages from among thousands or millions of documents in the enterprise has a direct impact on the LLM’s accuracy.
6 |
7 | In this project, Amazon OpenSearch Service is used as the knowledge base.
8 |
9 | The overall architecture is like this:
10 |
11 | 
12 |
13 | ### Overall Workflow
14 |
15 | 1. Deploy the cdk stacks (For more information, see [here](./cdk_stacks/README.md)).
16 | - A SageMaker Endpoint for text generation.
17 | - A SageMaker Endpoint for generating embeddings.
18 | - An Amazon OpenSearch cluster for storing embeddings.
19 |    - The OpenSearch cluster’s access credentials (username and password), stored in AWS Secrets Manager under a name such as `OpenSearchMasterUserSecret1-xxxxxxxxxxxx`.
20 | 2. Open JupyterLab in SageMaker Studio and then open a new terminal.
21 | 3. Run the following commands on the terminal to clone the code repository for this project:
22 | ```
23 | git clone --depth=1 https://github.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker.git
24 | ```
25 | 4. Open the `data_ingestion_to_opensearch` notebook and run it. (For more information, see [here](./data_ingestion_to_vectordb/data_ingestion_to_opensearch.ipynb).)
26 | 5. Run the Streamlit application. (For more information, see [here](./app/README.md).)
27 |
28 | ### References
29 |
30 | * [Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain (2023-05-25)](https://aws.amazon.com/blogs/machine-learning/build-a-powerful-question-answering-bot-with-amazon-sagemaker-amazon-opensearch-service-streamlit-and-langchain/)
31 | * [Use proprietary foundation models from Amazon SageMaker JumpStart in Amazon SageMaker Studio (2023-06-27)](https://aws.amazon.com/blogs/machine-learning/use-proprietary-foundation-models-from-amazon-sagemaker-jumpstart-in-amazon-sagemaker-studio/)
32 | * [Build Streamlit apps in Amazon SageMaker Studio (2023-04-11)](https://aws.amazon.com/blogs/machine-learning/build-streamlit-apps-in-amazon-sagemaker-studio/)
33 | * [Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models (2023-05-03)](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)
34 | * [(github) Amazon Kendra Retriever Samples](https://github.com/aws-samples/amazon-kendra-langchain-extensions/tree/main/kendra_retriever_samples)
35 | * [Question answering using Retrieval Augmented Generation with foundation models in Amazon SageMaker JumpStart (2023-05-02)](https://aws.amazon.com/blogs/machine-learning/question-answering-using-retrieval-augmented-generation-with-foundation-models-in-amazon-sagemaker-jumpstart/)
36 | * [Amazon OpenSearch Service’s vector database capabilities explained](https://aws.amazon.com/blogs/big-data/amazon-opensearch-services-vector-database-capabilities-explained/)
37 | * [LangChain](https://python.langchain.com/docs/get_started/introduction.html) - A framework for developing applications powered by language models.
38 | * [Streamlit](https://streamlit.io/) - A faster way to build and share data apps
39 | * [Improve search relevance with ML in Amazon OpenSearch Service Workshop](https://catalog.workshops.aws/semantic-search/en-US) - Module 7. Retrieval Augmented Generation
40 | * [rag-with-amazon-kendra](https://github.com/ksmin23/rag-with-amazon-kendra) - Question Answering application with Large Language Models (LLMs) and Amazon Kendra
41 | * [rag-with-postgresql-pgvector](https://github.com/ksmin23/rag-with-postgresql-pgvector) - Question Answering application with Large Language Models (LLMs) and Amazon Aurora PostgreSQL
42 |
43 | ## Security
44 |
45 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
46 |
47 | ## License
48 |
49 | This library is licensed under the MIT-0 License. See the LICENSE file.
50 |
51 |
--------------------------------------------------------------------------------
/app/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM --platform=linux/amd64 python:3.10.13-slim
2 |
3 | WORKDIR /app
4 |
5 | COPY requirements.txt .
6 | COPY *.py .
7 | COPY images/*.png images/
8 |
9 | RUN pip --no-cache-dir install -Uq pip
10 | RUN pip --no-cache-dir install -Uq -r requirements.txt
11 |
12 | # Set some environment variables. PYTHONUNBUFFERED keeps Python from buffering our
13 | # standard output stream, which means that logs can be delivered to the user quickly.
14 | # PYTHONDONTWRITEBYTECODE keeps Python from writing .pyc files, which are
15 | # unnecessary in this case.
16 | ENV PYTHONUNBUFFERED=TRUE
17 | ENV PYTHONDONTWRITEBYTECODE=TRUE
18 |
19 | EXPOSE 8501
20 | CMD streamlit run app.py --server.address=0.0.0.0
21 | # ENTRYPOINT ["streamlit", "run", "app.py", "--server.address=0.0.0.0"]
22 |
--------------------------------------------------------------------------------
/app/README.md:
--------------------------------------------------------------------------------
1 | ## Run the Streamlit application in Studio
2 |
3 | Now we’re ready to run the Streamlit web application for our question answering bot.
4 |
5 | SageMaker Studio provides a convenient platform to host the Streamlit web application. The following steps describe how to run the Streamlit app on SageMaker Studio. Alternatively, you can follow the same procedure to run the app on an Amazon EC2 instance or AWS Cloud9 in your AWS account, or deploy it as a container service on Amazon ECS with AWS Fargate.
6 |
7 | 1. Open JupyterLab and then open a **Terminal**.
8 | 2. Run the following commands on the terminal to clone the code repository for this post and install the Python packages needed by the application:
9 | ```
10 | git clone --depth=1 https://github.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker.git
11 | cd rag-with-amazon-opensearch-and-sagemaker/app
12 | python -m venv .env
13 | source .env/bin/activate
14 | pip install -r requirements.txt
15 | ```
16 | 3. In the shell, set the following environment variables with the values that are available from the CloudFormation stack output.
17 | ```
18 | export AWS_REGION=us-east-1
19 | export OPENSEARCH_SECRET="your-opensearch-secret"
20 | export OPENSEARCH_DOMAIN_ENDPOINT="your-opensearch-url"
21 | export OPENSEARCH_INDEX="llm_rag_embeddings"
22 | export EMBEDDING_ENDPOINT_NAME="your-sagemaker-endpoint-for-embedding-model"
23 | export TEXT2TEXT_ENDPOINT_NAME="your-sagemaker-endpoint-for-text-generation-model"
24 | ```
25 | 4. Run the application with `streamlit run app.py`. When the application runs successfully, you’ll see output similar to the following (the IP addresses you see will differ from the ones in this example). Note the port number (typically `8501`) from the output; it forms part of the app URL in the next step.
26 | ```
27 | sagemaker-user@studio$ streamlit run app.py
28 |
29 | Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.
30 |
31 | You can now view your Streamlit app in your browser.
32 |
33 | Network URL: http://169.255.255.2:8501
34 | External URL: http://52.4.240.77:8501
35 | ```
36 | 5. You can access the app in a new browser tab using a URL similar to your Studio domain URL. For example, if your Studio URL is `https://d-randomidentifier.studio.us-east-1.sagemaker.aws/jupyter/default/lab?` then the URL for your Streamlit app will be `https://d-randomidentifier.studio.us-east-1.sagemaker.aws/jupyter/default/proxy/8501/app` (notice that `lab` is replaced with `proxy/8501/app`). If the port number noted in the previous step differs from `8501`, use that port in place of `8501` in the URL. A sketch of this URL rewrite follows below.
37 |
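The URL change in the previous step is a simple string substitution. As an illustrative Python one-liner (the domain identifier is a placeholder and `8501` is the port noted earlier):

```
# Turn a Studio URL into the Streamlit proxy URL by swapping "lab" for "proxy/<port>/app".
studio_url = "https://d-randomidentifier.studio.us-east-1.sagemaker.aws/jupyter/default/lab"
app_url = studio_url.replace("/lab", "/proxy/8501/app")
print(app_url)
```
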
38 | The following screenshot shows the app with a couple of user questions (e.g., `What are the versions of XGBoost supported by Amazon SageMaker?`).
39 |
40 | 
41 |
42 |
43 | ## Deploy Streamlit application on Amazon ECS Fargate with AWS CDK
44 |
45 | To deploy the Streamlit application on Amazon ECS Fargate using AWS CDK, follow these steps:
46 |
47 | 1. Ensure you have the AWS CDK installed and configured, along with Docker or Finch.
48 | 2. Deploy the ECS stack from `cdk_stacks/` using the command `cdk deploy --require-approval never StreamlitAppStack`.
49 | 3. Access the Streamlit application using the public URL provided by the provisioned load balancer.
50 |    1. You can find this value under the export named `{stack-name}-StreamlitEndpoint` (see the lookup sketch after this list).
51 | 4. Consider adding a security group ingress rule that restricts access to the application to your network via a prefix list or a CIDR block.
52 | 5. Also consider enabling HTTPS by uploading an SSL certificate from IAM or ACM to the load balancer and adding a listener on port 443 (a CDK sketch follows below).
53 |
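As mentioned in item 3, one way to look up the exported endpoint programmatically is through the CloudFormation API. Below is a minimal sketch, assuming default AWS credentials and a single export whose name ends with `-StreamlitEndpoint` (all other names are illustrative):

```
import boto3

# List the CloudFormation exports in the current account/region and pick the
# Streamlit endpoint by its name suffix. (list_exports is paginated; a single
# call is assumed to suffice here.)
cfn = boto3.client("cloudformation")
exports = cfn.list_exports()["Exports"]
endpoint = next(e["Value"] for e in exports if e["Name"].endswith("-StreamlitEndpoint"))
print(endpoint)
```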
54 |
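For item 5, here is a minimal CDK (Python) sketch of adding an HTTPS listener to the load balancer; `load_balancer`, `target_group`, and `certificate_arn` are illustrative names, not identifiers from this project’s stacks:

```
import aws_cdk.aws_elasticloadbalancingv2 as elbv2

def add_https_listener(load_balancer: elbv2.ApplicationLoadBalancer,
                       target_group: elbv2.ApplicationTargetGroup,
                       certificate_arn: str) -> elbv2.ApplicationListener:
    # Attach an existing ACM/IAM server certificate and listen on port 443,
    # forwarding requests to the target group that serves the Streamlit tasks.
    return load_balancer.add_listener("HttpsListener",
        port=443,
        certificates=[elbv2.ListenerCertificate.from_arn(certificate_arn)],
        default_target_groups=[target_group],
    )
```
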
55 | ## References
56 |
57 | * [Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain (2023-05-25)](https://aws.amazon.com/blogs/machine-learning/build-a-powerful-question-answering-bot-with-amazon-sagemaker-amazon-opensearch-service-streamlit-and-langchain/)
58 | * [Build Streamlit apps in Amazon SageMaker Studio (2023-04-11)](https://aws.amazon.com/blogs/machine-learning/build-streamlit-apps-in-amazon-sagemaker-studio/)
59 | * [Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models (2023-05-03)](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)
60 | * [(github) Amazon Kendra Retriver Samples](https://github.com/aws-samples/amazon-kendra-langchain-extensions/tree/main/kendra_retriever_samples)
61 | * [Use proprietary foundation models from Amazon SageMaker JumpStart in Amazon SageMaker Studio (2023-06-27)](https://aws.amazon.com/blogs/machine-learning/use-proprietary-foundation-models-from-amazon-sagemaker-jumpstart-in-amazon-sagemaker-studio/)
62 | * [sagemaker-huggingface-inference-toolkit](https://github.com/aws/sagemaker-huggingface-inference-toolkit) - SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers and Diffusers models on Amazon SageMaker.
63 | * [LangChain](https://python.langchain.com/docs/get_started/introduction.html) - A framework for developing applications powered by language models.
64 | * [Streamlit](https://streamlit.io/) - A faster way to build and share data apps
65 |
--------------------------------------------------------------------------------
/app/app.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import streamlit as st
7 | import uuid
8 |
9 | import opensearch_chat_flan_xl as flanxl
10 | import opensearch_chat_llama2 as llama2
11 |
12 | PROVIDER_NAME = os.environ.get('PROVIDER_NAME', 'llama2')
13 |
14 | USER_ICON = f"{os.path.dirname(__file__)}/images/user-icon.png"
15 | AI_ICON = f"{os.path.dirname(__file__)}/images/ai-icon.png"
16 | MAX_HISTORY_LENGTH = 5
17 | PROVIDER_MAP = {
18 | 'flanxl': 'Flan XL',
19 | 'llama2': 'Llama2 7B',
20 | }
21 |
22 | # Check if the user ID is already stored in the session state
23 | if 'user_id' in st.session_state:
24 | user_id = st.session_state['user_id']
25 |
26 | # If the user ID is not yet stored in the session state, generate a random UUID
27 | else:
28 | user_id = str(uuid.uuid4())
29 | st.session_state['user_id'] = user_id
30 |
31 |
32 | if 'llm_chain' not in st.session_state:
33 | llm_app = llama2 if PROVIDER_NAME == 'llama2' else flanxl
34 | st.session_state['llm_app'] = llm_app
35 | st.session_state['llm_chain'] = llm_app.build_chain()
36 |
37 | if 'chat_history' not in st.session_state:
38 | st.session_state['chat_history'] = []
39 |
40 | if "chats" not in st.session_state:
41 | st.session_state.chats = [
42 | {
43 | 'id': 0,
44 | 'question': '',
45 | 'answer': ''
46 | }
47 | ]
48 |
49 | if "questions" not in st.session_state:
50 | st.session_state.questions = []
51 |
52 | if "answers" not in st.session_state:
53 | st.session_state.answers = []
54 |
55 | if "input" not in st.session_state:
56 | st.session_state.input = ""
57 |
58 |
59 | st.markdown("""
60 |
75 | """, unsafe_allow_html=True)
76 |
77 |
78 | def write_logo():
79 | col1, col2, col3 = st.columns([5, 1, 5])
80 | with col2:
81 | st.image(AI_ICON, use_column_width='always')
82 |
83 |
84 | def write_top_bar():
85 | col1, col2, col3 = st.columns([1,10,2])
86 | with col1:
87 | st.image(AI_ICON, use_column_width='always')
88 | with col2:
89 | selected_provider = PROVIDER_NAME
90 | if selected_provider in PROVIDER_MAP:
91 | provider = PROVIDER_MAP[selected_provider]
92 | else:
93 | provider = selected_provider.capitalize()
94 | header = f"An AI App powered by Amazon OpenSearch and {provider}!"
95 |         st.write(f"<h3 class='main-header'>{header}</h3>", unsafe_allow_html=True)
96 | with col3:
97 | clear = st.button("Clear Chat")
98 | return clear
99 |
100 |
101 | clear = write_top_bar()
102 |
103 | if clear:
104 | st.session_state.questions = []
105 | st.session_state.answers = []
106 | st.session_state.input = ""
107 | st.session_state["chat_history"] = []
108 |
109 |
110 | def handle_input():
111 | input = st.session_state.input
112 | question_with_id = {
113 | 'question': input,
114 | 'id': len(st.session_state.questions)
115 | }
116 | st.session_state.questions.append(question_with_id)
117 |
118 | chat_history = st.session_state["chat_history"]
119 |     if len(chat_history) == MAX_HISTORY_LENGTH:
120 |         chat_history.pop(0)  # drop the oldest exchange in place (slicing created a copy detached from session state)
121 |
122 | llm_chain = st.session_state['llm_chain']
123 | chain = st.session_state['llm_app']
124 | with st.spinner():
125 | result = chain.run_chain(llm_chain, input, chat_history)
126 | answer = result['answer']
127 | chat_history.append((input, answer))
128 |
129 | document_list = []
130 | if 'source_documents' in result:
131 | for d in result['source_documents']:
132 | if not (d.metadata['source'] in document_list):
133 | document_list.append((d.metadata['source']))
134 |
135 | st.session_state.answers.append({
136 | 'answer': result,
137 | 'sources': document_list,
138 | 'id': len(st.session_state.questions)
139 | })
140 | st.session_state.input = ""
141 |
142 |
143 | def write_user_message(md):
144 | col1, col2 = st.columns([1,12])
145 |
146 | with col1:
147 | st.image(USER_ICON, use_column_width='always')
148 | with col2:
149 | st.warning(md['question'])
150 |
151 |
152 | def render_result(result):
153 | answer, sources = st.tabs(['Answer', 'Sources'])
154 | with answer:
155 | render_answer(result['answer'])
156 | with sources:
157 | if 'source_documents' in result:
158 | render_sources(result['source_documents'])
159 | else:
160 | render_sources([])
161 |
162 |
163 | def render_answer(answer):
164 | col1, col2 = st.columns([1,12])
165 | with col1:
166 | st.image(AI_ICON, use_column_width='always')
167 | with col2:
168 | st.info(answer['answer'])
169 |
170 |
171 | def render_sources(sources):
172 | col1, col2 = st.columns([1,12])
173 | with col2:
174 | with st.expander("Sources"):
175 | for s in sources:
176 | st.write(s)
177 |
178 |
179 | # Each answer will have context of the question asked in order to associate the provided feedback with the respective question
180 | def write_chat_message(md, q):
181 | chat = st.container()
182 | with chat:
183 | render_answer(md['answer'])
184 | render_sources(md['sources'])
185 |
186 |
187 | with st.container():
188 | for (q, a) in zip(st.session_state.questions, st.session_state.answers):
189 | write_user_message(q)
190 | write_chat_message(a, q)
191 |
192 | st.markdown('---')
193 | input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
194 |
--------------------------------------------------------------------------------
/app/env_vars.sh:
--------------------------------------------------------------------------------
1 | export AWS_REGION="your-aws-region"
2 | export OPENSEARCH_SECRET="your-opensearch-secret"
3 | export OPENSEARCH_DOMAIN_ENDPOINT="your-opensearch-url"
4 | export OPENSEARCH_INDEX="llm_rag_embeddings"
5 | export EMBEDDING_ENDPOINT_NAME="your-sagemaker-endpoint-for-embedding-model"
6 | export TEXT2TEXT_ENDPOINT_NAME="your-sagemaker-endpoint-for-text-generation-model"
--------------------------------------------------------------------------------
/app/images/ai-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker/5d43e78f0af9a69c0306b4f1181f3cd4971b78f9/app/images/ai-icon.png
--------------------------------------------------------------------------------
/app/images/user-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker/5d43e78f0af9a69c0306b4f1181f3cd4971b78f9/app/images/user-icon.png
--------------------------------------------------------------------------------
/app/opensearch_chat_flan_xl.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import json
7 | import logging
8 | import sys
9 | from typing import List, Callable
10 | from urllib.parse import urlparse
11 |
12 | import boto3
13 |
14 | from langchain_community.vectorstores import OpenSearchVectorSearch
15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings
16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
17 |
18 | from langchain_community.llms import SagemakerEndpoint
19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
20 |
21 | from langchain.prompts import PromptTemplate
22 | from langchain.chains import ConversationalRetrievalChain
23 |
24 | logger = logging.getLogger()
25 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
26 |
27 |
28 | class bcolors:
29 | HEADER = '\033[95m'
30 | OKBLUE = '\033[94m'
31 | OKCYAN = '\033[96m'
32 | OKGREEN = '\033[92m'
33 | WARNING = '\033[93m'
34 | FAIL = '\033[91m'
35 | ENDC = '\033[0m'
36 | BOLD = '\033[1m'
37 | UNDERLINE = '\033[4m'
38 |
39 |
40 | MAX_HISTORY_LENGTH = 5
41 |
42 |
43 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
44 | def embed_documents(
45 | self, texts: List[str], chunk_size: int = 5
46 | ) -> List[List[float]]:
47 | """Compute doc embeddings using a SageMaker Inference Endpoint.
48 |
49 | Args:
50 | texts: The list of texts to embed.
51 | chunk_size: The chunk size defines how many input texts will
52 |                 be grouped together as a request. If None, will use the
53 | chunk size specified by the class.
54 |
55 | Returns:
56 | List of embeddings, one for each text.
57 | """
58 | results = []
59 |
60 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
61 | for i in range(0, len(texts), _chunk_size):
62 | response = self._embedding_func(texts[i : i + _chunk_size])
63 | results.extend(response)
64 | return results
65 |
66 |
67 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart:
68 |
69 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler):
70 | """
71 | encode input string as utf-8 bytes, read the embeddings
72 | from the output
73 | """
74 | content_type = "application/json"
75 | accepts = "application/json"
76 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes:
77 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
78 | return input_str.encode('utf-8')
79 |
80 | def transform_output(self, output: bytes) -> str:
81 | response_json = json.loads(output.read().decode("utf-8"))
82 | embeddings = response_json["embedding"]
83 | if len(embeddings) == 1:
84 | return [embeddings[0]]
85 | return embeddings
86 |
87 | # create a content handler object which knows how to serialize
88 | # and deserialize communication with the model endpoint
89 | content_handler = ContentHandlerForEmbeddings()
90 |
91 |     # now create the SageMaker embeddings; we provide the SageMaker
92 |     # endpoint that will be used for generating the embeddings
93 |     # to the class
94 | embeddings = SagemakerEndpointEmbeddingsJumpStart(
95 | endpoint_name=endpoint_name,
96 | region_name=region,
97 | content_handler=content_handler
98 | )
99 | logger.info(f"embeddings type={type(embeddings)}")
100 |
101 | return embeddings
102 |
103 |
104 | def _get_credentials(secret_id: str, region_name: str) -> dict:
105 | client = boto3.client('secretsmanager', region_name=region_name)
106 | response = client.get_secret_value(SecretId=secret_id)
107 | secrets_value = json.loads(response['SecretString'])
108 | return secrets_value
109 |
110 |
111 | def build_chain():
112 | region = os.environ["AWS_REGION"]
113 | opensearch_secret = os.environ["OPENSEARCH_SECRET"]
114 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"]
115 | opensearch_index = os.environ["OPENSEARCH_INDEX"]
116 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"]
117 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"]
118 |
119 | class ContentHandler(LLMContentHandler):
120 | content_type = "application/json"
121 | accepts = "application/json"
122 |
123 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
124 | input_str = json.dumps({"inputs": prompt, **model_kwargs})
125 | return input_str.encode('utf-8')
126 |
127 | def transform_output(self, output: bytes) -> str:
128 | response_json = json.loads(output.read().decode("utf-8"))
129 | return response_json[0]["generated_texts"]
130 |
131 | content_handler = ContentHandler()
132 |
133 | model_kwargs = {
134 | "max_length": 500,
135 | "num_return_sequences": 1,
136 | "top_k": 250,
137 | "top_p": 0.95,
138 | "do_sample": False,
139 | "temperature": 1
140 | }
141 |
142 | llm = SagemakerEndpoint(
143 | endpoint_name=text2text_model_endpoint,
144 | region_name=region,
145 | model_kwargs=model_kwargs,
146 | content_handler=content_handler
147 | )
148 |
149 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint
150 |
151 | creds = _get_credentials(opensearch_secret, region)
152 | http_auth = (creds['username'], creds['password'])
153 |
154 | opensearch_vector_search = OpenSearchVectorSearch(
155 | opensearch_url=opensearch_url,
156 | index_name=opensearch_index,
157 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region),
158 | http_auth=http_auth
159 | )
160 |
161 | retriever = opensearch_vector_search.as_retriever(search_kwargs={"k": 3})
162 |
163 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""
164 |
165 | PROMPT = PromptTemplate(
166 | template=prompt_template, input_variables=["context", "question"]
167 | )
168 |
169 | condense_qa_template = """
170 | Given the following conversation and a follow up question, rephrase the follow up question
171 | to be a standalone question.
172 |
173 | Chat History:
174 | {chat_history}
175 | Follow Up Input: {question}
176 | Standalone question:"""
177 | standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)
178 |
179 | qa = ConversationalRetrievalChain.from_llm(
180 | llm=llm,
181 | retriever=retriever,
182 | condense_question_prompt=standalone_question_prompt,
183 | return_source_documents=True,
184 | combine_docs_chain_kwargs={"prompt":PROMPT}
185 | )
186 |
187 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n")
188 | return qa
189 |
190 |
191 | def run_chain(chain, prompt: str, history=[]):
192 | return chain.invoke({"question": prompt, "chat_history": history})
193 |
194 |
195 | if __name__ == "__main__":
196 | chat_history = []
197 | qa = build_chain()
198 | print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
199 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
200 | print(">", end=" ", flush=True)
201 | for query in sys.stdin:
202 | if (query.strip().lower().startswith("new search:")):
203 | query = query.strip().lower().replace("new search:","")
204 | chat_history = []
205 | elif (len(chat_history) == MAX_HISTORY_LENGTH):
206 | chat_history.pop(0)
207 | result = run_chain(qa, query, chat_history)
208 | chat_history.append((query, result["answer"]))
209 | print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC)
210 | if 'source_documents' in result:
211 | print(bcolors.OKGREEN + '\nSources:')
212 | for d in result['source_documents']:
213 | print(d.metadata['source'])
214 | print(bcolors.ENDC)
215 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
216 | print(">", end=" ", flush=True)
217 | print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)
--------------------------------------------------------------------------------
/app/opensearch_chat_llama2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import json
7 | import logging
8 | import sys
9 | from typing import List
10 | from urllib.parse import urlparse
11 |
12 | import boto3
13 |
14 | from langchain_community.vectorstores import OpenSearchVectorSearch
15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings
16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
17 |
18 | from langchain_community.llms import SagemakerEndpoint
19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
20 |
21 | from langchain.prompts import PromptTemplate
22 | from langchain.chains import ConversationalRetrievalChain
23 |
24 | logger = logging.getLogger()
25 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
26 |
27 |
28 | class bcolors:
29 | HEADER = '\033[95m'
30 | OKBLUE = '\033[94m'
31 | OKCYAN = '\033[96m'
32 | OKGREEN = '\033[92m'
33 | WARNING = '\033[93m'
34 | FAIL = '\033[91m'
35 | ENDC = '\033[0m'
36 | BOLD = '\033[1m'
37 | UNDERLINE = '\033[4m'
38 |
39 |
40 | MAX_HISTORY_LENGTH = 5
41 |
42 |
43 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
44 | def embed_documents(
45 | self, texts: List[str], chunk_size: int = 5
46 | ) -> List[List[float]]:
47 | """Compute doc embeddings using a SageMaker Inference Endpoint.
48 |
49 | Args:
50 | texts: The list of texts to embed.
51 | chunk_size: The chunk size defines how many input texts will
52 |                 be grouped together as a request. If None, will use the
53 | chunk size specified by the class.
54 |
55 | Returns:
56 | List of embeddings, one for each text.
57 | """
58 | results = []
59 |
60 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
61 | for i in range(0, len(texts), _chunk_size):
62 | response = self._embedding_func(texts[i : i + _chunk_size])
63 | results.extend(response)
64 | return results
65 |
66 |
67 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart:
68 |
69 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler):
70 | """
71 | encode input string as utf-8 bytes, read the embeddings
72 | from the output
73 | """
74 | content_type = "application/json"
75 | accepts = "application/json"
76 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes:
77 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
78 | return input_str.encode('utf-8')
79 |
80 | def transform_output(self, output: bytes) -> str:
81 | response_json = json.loads(output.read().decode("utf-8"))
82 | embeddings = response_json["embedding"]
83 | if len(embeddings) == 1:
84 | return [embeddings[0]]
85 | return embeddings
86 |
87 | # create a content handler object which knows how to serialize
88 | # and deserialize communication with the model endpoint
89 | content_handler = ContentHandlerForEmbeddings()
90 |
91 |     # now create the SageMaker embeddings; we provide the SageMaker
92 |     # endpoint that will be used for generating the embeddings
93 |     # to the class
94 | embeddings = SagemakerEndpointEmbeddingsJumpStart(
95 | endpoint_name=endpoint_name,
96 | region_name=region,
97 | content_handler=content_handler
98 | )
99 | logger.info(f"embeddings type={type(embeddings)}")
100 |
101 | return embeddings
102 |
103 |
104 | def _get_credentials(secret_id: str, region_name: str) -> dict:
105 | client = boto3.client('secretsmanager', region_name=region_name)
106 | response = client.get_secret_value(SecretId=secret_id)
107 | secrets_value = json.loads(response['SecretString'])
108 | return secrets_value
109 |
110 |
111 | def build_chain():
112 | region = os.environ["AWS_REGION"]
113 | opensearch_secret = os.environ["OPENSEARCH_SECRET"]
114 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"]
115 | opensearch_index = os.environ["OPENSEARCH_INDEX"]
116 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"]
117 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"]
118 |
119 | # https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-2-chat-completion.ipynb
120 | class ContentHandler(LLMContentHandler):
121 | content_type = "application/json"
122 | accepts = "application/json"
123 |
124 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
125 |             system_prompt = "You are a helpful assistant. Always answer questions as helpfully as possible." \
126 | " If you don't know the answer to a question, say I don't know the answer"
127 |
128 | payload = {
129 | "inputs": [
130 | [
131 | {"role": "system", "content": system_prompt},
132 | {"role": "user", "content": prompt},
133 | ],
134 | ],
135 | "parameters": model_kwargs,
136 | }
137 | input_str = json.dumps(payload)
138 | return input_str.encode("utf-8")
139 |
140 | def transform_output(self, output: bytes) -> str:
141 | response_json = json.loads(output.read().decode("utf-8"))
142 | content = response_json[0]["generation"]["content"]
143 | return content
144 |
145 | content_handler = ContentHandler()
146 |
147 | # https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-2-text-completion.ipynb
148 | model_kwargs = {
149 | "max_new_tokens": 256,
150 | "top_p": 0.9,
151 | "temperature": 0.6,
152 | "return_full_text": False,
153 | }
154 |
155 | llm = SagemakerEndpoint(
156 | endpoint_name=text2text_model_endpoint,
157 | region_name=region,
158 | model_kwargs=model_kwargs,
159 | endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
160 | content_handler=content_handler
161 | )
162 |
163 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint
164 |
165 | creds = _get_credentials(opensearch_secret, region)
166 | http_auth = (creds['username'], creds['password'])
167 |
168 | opensearch_vector_search = OpenSearchVectorSearch(
169 | opensearch_url=opensearch_url,
170 | index_name=opensearch_index,
171 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region),
172 | http_auth=http_auth
173 | )
174 |
175 | retriever = opensearch_vector_search.as_retriever(search_kwargs={"k": 3})
176 |
177 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""
178 |
179 | PROMPT = PromptTemplate(
180 | template=prompt_template, input_variables=["context", "question"]
181 | )
182 |
183 | condense_qa_template = """
184 | Given the following conversation and a follow up question, rephrase the follow up question
185 | to be a standalone question.
186 |
187 | Chat History:
188 | {chat_history}
189 | Follow Up Input: {question}
190 | Standalone question:"""
191 | standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)
192 |
193 | qa = ConversationalRetrievalChain.from_llm(
194 | llm=llm,
195 | retriever=retriever,
196 | condense_question_prompt=standalone_question_prompt,
197 | return_source_documents=True,
198 | combine_docs_chain_kwargs={"prompt":PROMPT},
199 | verbose=False
200 | )
201 |
202 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n")
203 | return qa
204 |
205 |
206 | def run_chain(chain, prompt: str, history=[]):
207 | return chain.invoke({"question": prompt, "chat_history": history})
208 |
209 |
210 | if __name__ == "__main__":
211 | chat_history = []
212 | qa = build_chain()
213 | print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
214 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
215 | print(">", end=" ", flush=True)
216 | for query in sys.stdin:
217 | if (query.strip().lower().startswith("new search:")):
218 | query = query.strip().lower().replace("new search:","")
219 | chat_history = []
220 | elif (len(chat_history) == MAX_HISTORY_LENGTH):
221 | chat_history.pop(0)
222 | result = run_chain(qa, query, chat_history)
223 | chat_history.append((query, result["answer"]))
224 | print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC)
225 | if 'source_documents' in result:
226 | print(bcolors.OKGREEN + '\nSources:')
227 | for d in result['source_documents']:
228 | print(d.metadata['source'])
229 | print(bcolors.ENDC)
230 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
231 | print(">", end=" ", flush=True)
232 | print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)
--------------------------------------------------------------------------------
/app/opensearch_load_qa_chain_flan_xl.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import json
7 | import logging
8 | import sys
9 | from typing import List, Callable
10 | from urllib.parse import urlparse
11 |
12 | import boto3
13 |
14 | from langchain_community.vectorstores import OpenSearchVectorSearch
15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings
16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
17 |
18 | from langchain_community.llms import SagemakerEndpoint
19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
20 |
21 | from langchain.prompts import PromptTemplate
22 | from langchain.chains.question_answering import load_qa_chain
23 |
24 |
25 | logger = logging.getLogger()
26 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
27 |
28 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
29 | def embed_documents(
30 | self, texts: List[str], chunk_size: int = 5
31 | ) -> List[List[float]]:
32 | """Compute doc embeddings using a SageMaker Inference Endpoint.
33 |
34 | Args:
35 | texts: The list of texts to embed.
36 | chunk_size: The chunk size defines how many input texts will
37 |                 be grouped together as a request. If None, will use the
38 | chunk size specified by the class.
39 |
40 | Returns:
41 | List of embeddings, one for each text.
42 | """
43 | results = []
44 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
45 |
46 | for i in range(0, len(texts), _chunk_size):
47 | response = self._embedding_func(texts[i : i + _chunk_size])
48 | results.extend(response)
49 | return results
50 |
51 |
52 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler):
53 | """
54 | encode input string as utf-8 bytes, read the embeddings
55 | from the output
56 | """
57 | content_type = "application/json"
58 | accepts = "application/json"
59 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes:
60 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
61 | return input_str.encode('utf-8')
62 |
63 | def transform_output(self, output: bytes) -> str:
64 | response_json = json.loads(output.read().decode("utf-8"))
65 | embeddings = response_json["embedding"]
66 | if len(embeddings) == 1:
67 | return [embeddings[0]]
68 | return embeddings
69 |
70 |
71 | class ContentHandlerForTextGeneration(LLMContentHandler):
72 | content_type = "application/json"
73 | accepts = "application/json"
74 |
75 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes:
76 | input_str = json.dumps({"inputs": prompt, **model_kwargs})
77 | return input_str.encode('utf-8')
78 |
79 | def transform_output(self, output: bytes) -> str:
80 | response_json = json.loads(output.read().decode("utf-8"))
81 | return response_json[0]["generated_texts"]
82 |
83 |
84 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart:
85 | # create a content handler object which knows how to serialize
86 | # and deserialize communication with the model endpoint
87 | content_handler = ContentHandlerForEmbeddings()
88 |
89 |     # now create the SageMaker embeddings; we provide the SageMaker
90 |     # endpoint that will be used for generating the embeddings
91 |     # to the class
92 | embeddings = SagemakerEndpointEmbeddingsJumpStart(
93 | endpoint_name=endpoint_name,
94 | region_name=region,
95 | content_handler=content_handler
96 | )
97 | logger.info(f"embeddings type={type(embeddings)}")
98 |
99 | return embeddings
100 |
101 |
102 | def _get_credentials(secret_id: str, region_name: str) -> dict:
103 |
104 | client = boto3.client('secretsmanager', region_name=region_name)
105 | response = client.get_secret_value(SecretId=secret_id)
106 | secrets_value = json.loads(response['SecretString'])
107 | return secrets_value
108 |
109 |
110 | def load_vector_db_opensearch(secret_id: str,
111 | region: str,
112 | opensearch_domain_endpoint: str,
113 | opensearch_index: str,
114 | embeddings_model_endpoint: str) -> OpenSearchVectorSearch:
115 | logger.info(f"load_vector_db_opensearch, secret_id={secret_id}, region={region}, "
116 | f"opensearch_domain_endpoint={opensearch_domain_endpoint}, opensearch_index={opensearch_index}, "
117 | f"embeddings_model_endpoint={embeddings_model_endpoint}")
118 |
119 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint
120 | logger.info(f"embeddings_model_endpoint={embeddings_model_endpoint}, opensearch_url={opensearch_url}")
121 |
122 | creds = _get_credentials(secret_id, region)
123 | http_auth = (creds['username'], creds['password'])
124 | vector_db = OpenSearchVectorSearch(index_name=opensearch_index,
125 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region),
126 | opensearch_url=opensearch_url,
127 | http_auth=http_auth)
128 |
129 | logger.info(f"returning handle to OpenSearchVectorSearch, vector_db={vector_db}")
130 | return vector_db
131 |
132 |
133 | def setup_sagemaker_endpoint_for_text_generation(endpoint_name, region: str = "us-east-1") -> Callable:
134 | parameters = {
135 | "max_length": 500,
136 | "num_return_sequences": 1,
137 | "top_k": 250,
138 | "top_p": 0.95,
139 | "do_sample": False,
140 | "temperature": 1
141 | }
142 |
143 | content_handler = ContentHandlerForTextGeneration()
144 | sm_llm = SagemakerEndpoint(
145 | endpoint_name=endpoint_name,
146 | region_name=region,
147 | model_kwargs=parameters,
148 | content_handler=content_handler)
149 | return sm_llm
150 |
151 |
152 | def main():
153 | region = os.environ["AWS_REGION"]
154 | opensearch_secret = os.environ["OPENSEARCH_SECRET"]
155 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"]
156 | opensearch_index = os.environ["OPENSEARCH_INDEX"]
157 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"]
158 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"]
159 |
160 | # initialize vector db and Sagemaker Endpoint
161 | os_creds_secretid_in_secrets_manager = opensearch_secret
162 | _vector_db = load_vector_db_opensearch(os_creds_secretid_in_secrets_manager,
163 | region,
164 | opensearch_domain_endpoint,
165 | opensearch_index,
166 | embeddings_model_endpoint)
167 |
168 | _sm_llm = setup_sagemaker_endpoint_for_text_generation(text2text_model_endpoint, region)
169 |
170 | # Use the vector db to find similar documents to the query
171 | # the vector db call would automatically convert the query text
172 | # into embeddings
173 | query = 'What is SageMaker model monitor? Write your answer in a nicely formatted way.'
174 | max_matching_docs = 3
175 | docs = _vector_db.similarity_search(query, k=max_matching_docs)
176 | logger.info(f"here are the {max_matching_docs} closest matching docs to the query=\"{query}\"")
177 | for d in docs:
178 | logger.info(f"---------")
179 | logger.info(d)
180 | logger.info(f"---------")
181 |
182 | # now that we have the matching docs, lets pack them as a context
183 | # into the prompt and ask the LLM to generate a response
184 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""
185 |
186 | prompt = PromptTemplate(
187 | template=prompt_template, input_variables=["context", "question"]
188 | )
189 | logger.info(f"prompt sent to llm = \"{prompt}\"")
190 |
191 | chain = load_qa_chain(llm=_sm_llm, prompt=prompt, verbose=True)
192 | logger.info(f"\ntype('chain'): \"{type(chain)}\"\n")
193 |
194 | answer = chain.invoke({"input_documents": docs, "question": query}, return_only_outputs=True)['output_text']
195 |
196 | logger.info(f"answer received from llm,\nquestion: \"{query}\"\nanswer: \"{answer}\"")
197 |
198 | resp = {'question': query, 'answer': answer}
199 | resp['docs'] = docs
200 | print(resp)
201 |
202 |
203 | if __name__ == "__main__":
204 | main()
205 |
206 |
--------------------------------------------------------------------------------
/app/opensearch_load_qa_chain_llama2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import json
7 | import logging
8 | import sys
9 | from typing import List, Callable
10 | from urllib.parse import urlparse
11 |
12 | import boto3
13 |
14 | from langchain_community.vectorstores import OpenSearchVectorSearch
15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings
16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
17 |
18 | from langchain_community.llms import SagemakerEndpoint
19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
20 |
21 | from langchain.prompts import PromptTemplate
22 | from langchain.chains.question_answering import load_qa_chain
23 |
24 |
25 | logger = logging.getLogger()
26 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
27 |
28 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
29 | def embed_documents(
30 | self, texts: List[str], chunk_size: int = 5
31 | ) -> List[List[float]]:
32 | """Compute doc embeddings using a SageMaker Inference Endpoint.
33 |
34 | Args:
35 | texts: The list of texts to embed.
36 | chunk_size: The chunk size defines how many input texts will
37 |                 be grouped together as a request. If None, will use the
38 | chunk size specified by the class.
39 |
40 | Returns:
41 | List of embeddings, one for each text.
42 | """
43 | results = []
44 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
45 | for i in range(0, len(texts), _chunk_size):
46 | response = self._embedding_func(texts[i : i + _chunk_size])
47 | #print(response)
48 | results.extend(response)
49 | return results
50 |
51 |
52 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler):
53 | """
54 | encode input string as utf-8 bytes, read the embeddings
55 | from the output
56 | """
57 | content_type = "application/json"
58 | accepts = "application/json"
59 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes:
60 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
61 | return input_str.encode('utf-8')
62 |
63 | def transform_output(self, output: bytes) -> str:
64 | response_json = json.loads(output.read().decode("utf-8"))
65 | embeddings = response_json["embedding"]
66 | if len(embeddings) == 1:
67 | return [embeddings[0]]
68 | return embeddings
69 |
70 |
71 | class ContentHandlerForTextGeneration(LLMContentHandler):
72 | content_type = "application/json"
73 | accepts = "application/json"
74 |
75 | # https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-2-chat-completion.ipynb
76 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
77 |         system_prompt = "You are a helpful assistant. Always answer questions as helpfully as possible." \
78 | " If you don't know the answer to a question, say I don't know the answer"
79 |
80 | payload = {
81 | "inputs": [
82 | [
83 | {"role": "system", "content": system_prompt},
84 | {"role": "user", "content": prompt},
85 | ],
86 | ],
87 | "parameters": model_kwargs,
88 | }
89 | input_str = json.dumps(payload)
90 | return input_str.encode("utf-8")
91 |
92 | def transform_output(self, output: bytes) -> str:
93 | response_json = json.loads(output.read().decode("utf-8"))
94 | content = response_json[0]["generation"]["content"]
95 | return content
96 |
97 |
98 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart:
99 | # create a content handler object which knows how to serialize
100 | # and deserialize communication with the model endpoint
101 | content_handler = ContentHandlerForEmbeddings()
102 |
103 |     # now create the SageMaker embeddings; we provide the SageMaker
104 |     # endpoint that will be used for generating the embeddings
105 |     # to the class
106 | embeddings = SagemakerEndpointEmbeddingsJumpStart(
107 | endpoint_name=endpoint_name,
108 | region_name=region,
109 | content_handler=content_handler
110 | )
111 | logger.info(f"embeddings type={type(embeddings)}")
112 |
113 | return embeddings
114 |
115 |
116 | def _get_credentials(secret_id: str, region_name: str) -> dict:
117 |
118 | client = boto3.client('secretsmanager', region_name=region_name)
119 | response = client.get_secret_value(SecretId=secret_id)
120 | secrets_value = json.loads(response['SecretString'])
121 | return secrets_value
122 |
123 |
124 | def load_vector_db_opensearch(secret_id: str,
125 | region: str,
126 | opensearch_domain_endpoint: str,
127 | opensearch_index: str,
128 | embeddings_model_endpoint: str) -> OpenSearchVectorSearch:
129 | logger.info(f"load_vector_db_opensearch, secret_id={secret_id}, region={region}, "
130 | f"opensearch_domain_endpoint={opensearch_domain_endpoint}, opensearch_index={opensearch_index}, "
131 | f"embeddings_model_endpoint={embeddings_model_endpoint}")
132 |
133 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint
134 | logger.info(f"embeddings_model_endpoint={embeddings_model_endpoint}, opensearch_url={opensearch_url}")
135 |
136 | creds = _get_credentials(secret_id, region)
137 | http_auth = (creds['username'], creds['password'])
138 | vector_db = OpenSearchVectorSearch(index_name=opensearch_index,
139 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region),
140 | opensearch_url=opensearch_url,
141 | http_auth=http_auth)
142 |
143 | logger.info(f"returning handle to OpenSearchVectorSearch, vector_db={vector_db}")
144 | return vector_db
145 |
146 |
147 | def setup_sagemaker_endpoint_for_text_generation(endpoint_name, region: str = "us-east-1") -> Callable:
148 | parameters = {
149 | "max_new_tokens": 256,
150 | "top_p": 0.9,
151 | "temperature": 0.6,
152 | "return_full_text": False,
153 | }
154 |
155 | content_handler = ContentHandlerForTextGeneration()
156 | sm_llm = SagemakerEndpoint(
157 | endpoint_name=endpoint_name,
158 | region_name=region,
159 | model_kwargs=parameters,
160 | endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
161 | content_handler=content_handler)
162 | return sm_llm
163 |
164 |
165 | def main():
166 | region = os.environ["AWS_REGION"]
167 | opensearch_secret = os.environ["OPENSEARCH_SECRET"]
168 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"]
169 | opensearch_index = os.environ["OPENSEARCH_INDEX"]
170 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"]
171 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"]
172 |
173 | # initialize vector db and Sagemaker Endpoint
174 | os_creds_secretid_in_secrets_manager = opensearch_secret
175 | _vector_db = load_vector_db_opensearch(os_creds_secretid_in_secrets_manager,
176 | region,
177 | opensearch_domain_endpoint,
178 | opensearch_index,
179 | embeddings_model_endpoint)
180 |
181 | _sm_llm = setup_sagemaker_endpoint_for_text_generation(text2text_model_endpoint, region)
182 |
183 | # Use the vector db to find similar documents to the query
184 | # the vector db call would automatically convert the query text
185 | # into embeddings
186 | query = 'What is SageMaker model monitor? Write your answer in a nicely formatted way.'
187 | max_matching_docs = 3
188 | docs = _vector_db.similarity_search(query, k=max_matching_docs)
189 | logger.info(f"here are the {max_matching_docs} closest matching docs to the query=\"{query}\"")
190 | for d in docs:
191 | logger.info(f"---------")
192 | logger.info(d)
193 | logger.info(f"---------")
194 |
195 |     # now that we have the matching docs, let's pack them as a context
196 | # into the prompt and ask the LLM to generate a response
197 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""
198 |
199 | prompt = PromptTemplate(
200 | template=prompt_template, input_variables=["context", "question"]
201 | )
202 | logger.info(f"prompt sent to llm = \"{prompt}\"")
203 |
204 | chain = load_qa_chain(llm=_sm_llm, prompt=prompt, verbose=True)
205 | logger.info(f"\ntype('chain'): \"{type(chain)}\"\n")
206 |
207 | answer = chain.invoke({"input_documents": docs, "question": query}, return_only_outputs=True)['output_text']
208 |
209 | logger.info(f"answer received from llm,\nquestion: \"{query}\"\nanswer: \"{answer}\"")
210 |
211 | resp = {'question': query, 'answer': answer}
212 | resp['docs'] = docs
213 | print(resp)
214 |
215 |
216 | if __name__ == "__main__":
217 | main()
218 |
219 |
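For reference, a minimal sketch of the request and response shapes that `ContentHandlerForTextGeneration` above serializes and parses; the response content here is hypothetical:

```python
import json

# Request body produced by transform_input() for the Llama-2 chat endpoint.
payload = {
    "inputs": [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is SageMaker model monitor?"},
        ],
    ],
    "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
}
body = json.dumps(payload).encode("utf-8")

# Response body parsed by transform_output(); the content value is made up.
response_json = [{"generation": {"role": "assistant",
                                 "content": "Amazon SageMaker Model Monitor is ..."}}]
answer = response_json[0]["generation"]["content"]
print(answer)
```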
--------------------------------------------------------------------------------
/app/opensearch_retriever_flan_xl.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import json
7 | import logging
8 | import sys
9 | from typing import List
10 | from urllib.parse import urlparse
11 |
12 | import boto3
13 |
14 | from langchain_community.vectorstores import OpenSearchVectorSearch
15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings
16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
17 |
18 | from langchain_community.llms import SagemakerEndpoint
19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
20 |
21 | from langchain.prompts import PromptTemplate
22 | from langchain.chains import RetrievalQA
23 |
24 |
25 | logger = logging.getLogger()
26 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
27 |
28 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
29 | def embed_documents(
30 | self, texts: List[str], chunk_size: int = 5
31 | ) -> List[List[float]]:
32 | """Compute doc embeddings using a SageMaker Inference Endpoint.
33 |
34 | Args:
35 | texts: The list of texts to embed.
36 | chunk_size: The chunk size defines how many input texts will
37 | be grouped together as request. If None, will use the
38 | chunk size specified by the class.
39 |
40 | Returns:
41 | List of embeddings, one for each text.
42 | """
43 | results = []
44 |
45 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
46 | for i in range(0, len(texts), _chunk_size):
47 | response = self._embedding_func(texts[i : i + _chunk_size])
48 | results.extend(response)
49 | return results
50 |
51 |
52 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart:
53 |
54 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler):
55 | """
56 | encode input string as utf-8 bytes, read the embeddings
57 | from the output
58 | """
59 | content_type = "application/json"
60 | accepts = "application/json"
61 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes:
62 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
63 | return input_str.encode('utf-8')
64 |
65 |         def transform_output(self, output: bytes) -> List[List[float]]:
66 | response_json = json.loads(output.read().decode("utf-8"))
67 | embeddings = response_json["embedding"]
68 | if len(embeddings) == 1:
69 | return [embeddings[0]]
70 | return embeddings
71 |
72 | # create a content handler object which knows how to serialize
73 | # and deserialize communication with the model endpoint
74 | content_handler = ContentHandlerForEmbeddings()
75 |
76 |     # to create the Sagemaker embeddings, we provide the
77 |     # Sagemaker endpoint that will be used for generating the
78 |     # embeddings to the class
79 | embeddings = SagemakerEndpointEmbeddingsJumpStart(
80 | endpoint_name=endpoint_name,
81 | region_name=region,
82 | content_handler=content_handler
83 | )
84 | logger.info(f"embeddings type={type(embeddings)}")
85 |
86 | return embeddings
87 |
88 |
89 | def _get_credentials(secret_id: str, region_name: str) -> dict:
90 | client = boto3.client('secretsmanager', region_name=region_name)
91 | response = client.get_secret_value(SecretId=secret_id)
92 | secrets_value = json.loads(response['SecretString'])
93 | return secrets_value
94 |
95 |
96 | def build_chain():
97 | region = os.environ["AWS_REGION"] # us-east-1
98 | opensearch_secret = os.environ["OPENSEARCH_SECRET"]
99 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"]
100 | opensearch_index = os.environ["OPENSEARCH_INDEX"]
101 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"]
102 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"]
103 |
104 | class ContentHandler(LLMContentHandler):
105 | content_type = "application/json"
106 | accepts = "application/json"
107 |
108 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
109 | input_str = json.dumps({"inputs": prompt, **model_kwargs})
110 | return input_str.encode('utf-8')
111 |
112 | def transform_output(self, output: bytes) -> str:
113 | response_json = json.loads(output.read().decode("utf-8"))
114 | return response_json[0]["generated_texts"]
115 |
116 | content_handler = ContentHandler()
117 |
118 | model_kwargs = {
119 | "max_length": 500,
120 | "num_return_sequences": 1,
121 | "top_k": 250,
122 | "top_p": 0.95,
123 | "do_sample": False,
124 | "temperature": 1
125 | }
126 |
127 | llm = SagemakerEndpoint(
128 | endpoint_name=text2text_model_endpoint,
129 | region_name=region,
130 | model_kwargs=model_kwargs,
131 | content_handler=content_handler
132 | )
133 |
134 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint
135 |
136 | creds = _get_credentials(opensearch_secret, region)
137 | http_auth = (creds['username'], creds['password'])
138 |
139 | opensearch_vector_search = OpenSearchVectorSearch(
140 | opensearch_url=opensearch_url,
141 | index_name=opensearch_index,
142 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region),
143 | http_auth=http_auth
144 | )
145 |
146 | retriever = opensearch_vector_search.as_retriever(search_kwargs={"k": 3})
147 |
148 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""
149 |
150 | PROMPT = PromptTemplate(
151 | template=prompt_template, input_variables=["context", "question"]
152 | )
153 |
154 | chain_type_kwargs = {"prompt": PROMPT, "verbose": True}
155 | qa = RetrievalQA.from_chain_type(
156 | llm,
157 | chain_type="stuff",
158 | retriever=retriever,
159 | chain_type_kwargs=chain_type_kwargs,
160 | return_source_documents=True,
161 | verbose=True, #DEBUG
162 | )
163 |
164 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n")
165 | return qa
166 |
167 |
168 | def run_chain(chain, prompt: str, history=[]):
169 |     result = chain.invoke(prompt, include_run_info=True)
170 | # To make it compatible with chat samples
171 | return {
172 | "answer": result['result'],
173 | "source_documents": result['source_documents']
174 | }
175 |
176 |
177 | if __name__ == "__main__":
178 | chain = build_chain()
179 | result = run_chain(chain, "What is SageMaker model monitor? Write your answer in a nicely formatted way.")
180 | print(result['answer'])
181 | if 'source_documents' in result:
182 | print('Sources:')
183 | for d in result['source_documents']:
184 | print(d.metadata['source'])
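`build_chain()` and `run_chain()` are structured so that a front end (presumably the Streamlit app in this directory) can drive them. A minimal usage sketch, assuming the environment variables read by `build_chain()` are already set:

```python
import os

from opensearch_retriever_flan_xl import build_chain, run_chain

# build_chain() requires these environment variables; fail fast if any is missing.
for var in ("AWS_REGION", "OPENSEARCH_SECRET", "OPENSEARCH_DOMAIN_ENDPOINT",
            "OPENSEARCH_INDEX", "EMBEDDING_ENDPOINT_NAME", "TEXT2TEXT_ENDPOINT_NAME"):
    assert var in os.environ, f"{var} is not set"

chain = build_chain()
result = run_chain(chain, "What is SageMaker model monitor?")
print(result["answer"])
for doc in result["source_documents"]:
    print(doc.metadata["source"])
```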
--------------------------------------------------------------------------------
/app/opensearch_retriever_llama2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab
4 |
5 | import os
6 | import json
7 | import logging
8 | import sys
9 | from typing import List
10 | from urllib.parse import urlparse
11 |
12 | import boto3
13 |
14 | from langchain_community.vectorstores import OpenSearchVectorSearch
15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings
16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
17 |
18 | from langchain_community.llms import SagemakerEndpoint
19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
20 |
21 | from langchain.prompts import PromptTemplate
22 | from langchain.chains import RetrievalQA
23 |
24 |
25 | logger = logging.getLogger()
26 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
27 |
28 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
29 | def embed_documents(
30 | self, texts: List[str], chunk_size: int = 5
31 | ) -> List[List[float]]:
32 | """Compute doc embeddings using a SageMaker Inference Endpoint.
33 |
34 | Args:
35 | texts: The list of texts to embed.
36 | chunk_size: The chunk size defines how many input texts will
37 | be grouped together as request. If None, will use the
38 | chunk size specified by the class.
39 |
40 | Returns:
41 | List of embeddings, one for each text.
42 | """
43 | results = []
44 |
45 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
46 | for i in range(0, len(texts), _chunk_size):
47 | response = self._embedding_func(texts[i : i + _chunk_size])
48 | results.extend(response)
49 | return results
50 |
51 |
52 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart:
53 |
54 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler):
55 | """
56 | encode input string as utf-8 bytes, read the embeddings
57 | from the output
58 | """
59 | content_type = "application/json"
60 | accepts = "application/json"
61 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes:
62 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
63 | return input_str.encode('utf-8')
64 |
65 |         def transform_output(self, output: bytes) -> List[List[float]]:
66 | response_json = json.loads(output.read().decode("utf-8"))
67 | embeddings = response_json["embedding"]
68 | if len(embeddings) == 1:
69 | return [embeddings[0]]
70 | return embeddings
71 |
72 | # create a content handler object which knows how to serialize
73 | # and deserialize communication with the model endpoint
74 | content_handler = ContentHandlerForEmbeddings()
75 |
76 |     # to create the Sagemaker embeddings, we provide the
77 |     # Sagemaker endpoint that will be used for generating the
78 |     # embeddings to the class
79 | embeddings = SagemakerEndpointEmbeddingsJumpStart(
80 | endpoint_name=endpoint_name,
81 | region_name=region,
82 | content_handler=content_handler
83 | )
84 | logger.info(f"embeddings type={type(embeddings)}")
85 |
86 | return embeddings
87 |
88 |
89 | def _get_credentials(secret_id: str, region_name: str) -> dict:
90 | client = boto3.client('secretsmanager', region_name=region_name)
91 | response = client.get_secret_value(SecretId=secret_id)
92 | secrets_value = json.loads(response['SecretString'])
93 | return secrets_value
94 |
95 |
96 | def build_chain():
97 | region = os.environ["AWS_REGION"] # us-east-1
98 | opensearch_secret = os.environ["OPENSEARCH_SECRET"]
99 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"]
100 | opensearch_index = os.environ["OPENSEARCH_INDEX"]
101 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"]
102 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"]
103 |
104 | class ContentHandler(LLMContentHandler):
105 | content_type = "application/json"
106 | accepts = "application/json"
107 |
108 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
109 |             system_prompt = "You are a helpful assistant. Always answer questions as helpfully as possible." \
110 |                 " If you don't know the answer to a question, say 'I don't know the answer.'"
111 |
112 | payload = {
113 | "inputs": [
114 | [
115 | {"role": "system", "content": system_prompt},
116 | {"role": "user", "content": prompt},
117 | ],
118 | ],
119 | "parameters": model_kwargs,
120 | }
121 | input_str = json.dumps(payload)
122 | return input_str.encode("utf-8")
123 |
124 | def transform_output(self, output: bytes) -> str:
125 | response_json = json.loads(output.read().decode("utf-8"))
126 | content = response_json[0]["generation"]["content"]
127 | return content
128 |
129 | content_handler = ContentHandler()
130 |
131 | model_kwargs = {
132 | "max_new_tokens": 256,
133 | "top_p": 0.9,
134 | "temperature": 0.6,
135 | "return_full_text": False,
136 | }
137 |
138 | llm = SagemakerEndpoint(
139 | endpoint_name=text2text_model_endpoint,
140 | region_name=region,
141 | model_kwargs=model_kwargs,
142 | endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
143 | content_handler=content_handler
144 | )
145 |
146 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint
147 |
148 | creds = _get_credentials(opensearch_secret, region)
149 | http_auth = (creds['username'], creds['password'])
150 |
151 | opensearch_vector_search = OpenSearchVectorSearch(
152 | opensearch_url=opensearch_url,
153 | index_name=opensearch_index,
154 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region),
155 | http_auth=http_auth
156 | )
157 |
158 | retriever = opensearch_vector_search.as_retriever(search_kwargs={"k": 3})
159 |
160 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""
161 |
162 | PROMPT = PromptTemplate(
163 | template=prompt_template, input_variables=["context", "question"]
164 | )
165 |
166 | chain_type_kwargs = {"prompt": PROMPT, "verbose": True}
167 | qa = RetrievalQA.from_chain_type(
168 | llm,
169 | chain_type="stuff",
170 | retriever=retriever,
171 | chain_type_kwargs=chain_type_kwargs,
172 | return_source_documents=True,
173 | verbose=True, #DEBUG
174 | )
175 |
176 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n")
177 | return qa
178 |
179 |
180 | def run_chain(chain, prompt: str, history=[]):
181 |     result = chain.invoke(prompt, include_run_info=True)
182 | # To make it compatible with chat samples
183 | return {
184 | "answer": result['result'],
185 | "source_documents": result['source_documents']
186 | }
187 |
188 |
189 | if __name__ == "__main__":
190 | chain = build_chain()
191 | result = run_chain(chain, "What is SageMaker model monitor? Write your answer in a nicely formatted way.")
192 | print(result['answer'])
193 | if 'source_documents' in result:
194 | print('Sources:')
195 | for d in result['source_documents']:
196 | print(d.metadata['source'])
--------------------------------------------------------------------------------
/app/qa-with-llm-and-rag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker/5d43e78f0af9a69c0306b4f1181f3cd4971b78f9/app/qa-with-llm-and-rag.png
--------------------------------------------------------------------------------
/app/requirements.txt:
--------------------------------------------------------------------------------
1 | boto3>=1.26.159
2 | langchain>=0.3,<0.4
3 | langchain-aws>=0.2,<0.3
4 | langchain-community>=0.3,<0.4
5 | SQLAlchemy==2.0.28
6 | opensearch-py==2.2.0
7 | streamlit==1.37.0
8 |
--------------------------------------------------------------------------------
/cdk_stacks/README.md:
--------------------------------------------------------------------------------
1 |
2 | # RAG Application CDK Python project!
3 |
4 | 
5 |
6 | This is a CDK Python project for a QA application built with LLMs and RAG.
7 |
8 | The `cdk.json` file tells the CDK Toolkit how to execute your app.
9 |
10 | This project is set up like a standard Python project. The initialization
11 | process also creates a virtualenv within this project, stored under the `.venv`
12 | directory. To create the virtualenv it assumes that there is a `python3`
13 | (or `python` for Windows) executable in your path with access to the `venv`
14 | package. If for any reason the automatic creation of the virtualenv fails,
15 | you can create the virtualenv manually.
16 |
17 | To manually create a virtualenv on MacOS and Linux:
18 |
19 | ```
20 | $ python3 -m venv .venv
21 | ```
22 |
23 | After the init process completes and the virtualenv is created, you can use the following
24 | step to activate your virtualenv.
25 |
26 | ```
27 | $ source .venv/bin/activate
28 | ```
29 |
30 | If you are on a Windows platform, you can activate the virtualenv like this:
31 |
32 | ```
33 | % .venv\Scripts\activate.bat
34 | ```
35 |
36 | Once the virtualenv is activated, you can install the required dependencies.
37 |
38 | ```
39 | (.venv) $ pip install -r requirements.txt
40 | ```
41 |
42 | To add additional dependencies, for example other CDK libraries, just add
43 | them to your `setup.py` file and rerun the `pip install -r requirements.txt`
44 | command.
45 |
46 | Before synthesizing the CloudFormation template, you should properly set the CDK context configuration file, `cdk.context.json`.
47 |
48 | For example:
49 |
50 |
51 | {
52 | "opensearch_domain_name": "open-search-domain-name",
53 | "jumpstart_model_info": {
54 | "model_id": "meta-textgeneration-llama-2-7b-f",
55 | "version": "2.0.1"
56 | }
57 | }
58 |
59 |
60 | :information_source: The `model_id` and `version` provided by SageMaker JumpStart can be found in the [**SageMaker Built-in Algorithms with pre-trained Model Table**](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html).
61 |
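If you prefer to look up the model id and version programmatically, the SageMaker Python SDK can list the available JumpStart model ids. A minimal sketch, assuming the `sagemaker` package is installed (the filter value is only an example):

```
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

# "task == textgeneration" is an example filter; adjust to the model family you need.
model_ids = list_jumpstart_models(filter="task == textgeneration")
print([m for m in model_ids if "llama-2" in m])
```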
62 | > :warning: **Important**: Make sure the Docker daemon is running.
63 | > Otherwise you will encounter the following error:
64 |
65 | ```
66 | ERROR: Cannot connect to the Docker daemon at unix://$HOME/.docker/run/docker.sock. Is the docker daemon running?
67 | jsii.errors.JavaScriptError:
68 | Error: docker exited with status 1
69 | ```
70 |
71 | At this point you can synthesize the CloudFormation template for this code.
72 |
73 | ```
74 | (.venv) $ export CDK_DEFAULT_ACCOUNT=$(aws sts get-caller-identity --query Account --output text)
75 | (.venv) $ export CDK_DEFAULT_REGION=$(aws configure get region)
76 | (.venv) $ cdk synth --all
77 | ```
78 |
79 | Now you can deploy all the CDK stacks at once like this:
80 |
81 | ```
82 | (.venv) $ cdk deploy --require-approval never --all
83 | ```
84 |
85 | Or, you can provision each CDK stack one at a time like this:
86 |
87 | #### Step 1: List all CDK Stacks
88 |
89 | ```
90 | (.venv) $ cdk list
91 | RAGVpcStack
92 | RAGOpenSearchStack
93 | RAGSageMakerStudioStack
94 | EmbeddingEndpointStack
95 | LLMEndpointStack
96 | StreamlitAppStack
97 | ```
98 |
99 | #### Step 2: Create OpenSearch cluster
100 |
101 | ```
102 | (.venv) $ cdk deploy --require-approval never RAGVpcStack RAGOpenSearchStack
103 | ```
104 |
105 | #### Step 3: Create SageMaker Studio
106 |
107 | ```
108 | (.venv) $ cdk deploy --require-approval never RAGSageMakerStudioStack
109 | ```
110 |
111 | #### Step 4: Deploy LLM Embedding Endpoint
112 |
113 | ```
114 | (.venv) $ cdk deploy --require-approval never EmbeddingEndpointStack
115 | ```
116 |
117 | #### Step 5: Deploy Text Generation LLM Endpoint
118 |
119 | ```
120 | (.venv) $ cdk deploy --require-approval never LLMEndpointStack
121 | ```
122 |
123 | #### Step 6 (Optional): Deploy the Streamlit app on ECS Fargate
124 |
125 | :warning: Before deploying the following CDK stack, make sure Docker is running on your machine.
126 |
127 | ```
128 | (.venv) $ cdk deploy --require-approval never StreamlitAppStack
129 | ```
130 |
131 | **Once all CDK stacks have been successfully created, proceed with the remaining steps of the [overall workflow](../README.md#overall-workflow).**
132 |
133 |
134 | ## Clean Up
135 |
136 | Delete the CloudFormation stacks by running the below command.
137 |
138 | ```
139 | (.venv) $ cdk destroy --all
140 | ```
141 |
142 | ## Useful commands
143 |
144 | * `cdk ls` list all stacks in the app
145 | * `cdk synth` emits the synthesized CloudFormation template
146 | * `cdk deploy` deploy this stack to your default AWS account/region
147 | * `cdk diff` compare deployed stack with current state
148 | * `cdk docs` open CDK documentation
149 |
150 | Enjoy!
151 |
152 | ## References
153 |
154 | * [Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain (2023-05-25)](https://aws.amazon.com/blogs/machine-learning/build-a-powerful-question-answering-bot-with-amazon-sagemaker-amazon-opensearch-service-streamlit-and-langchain/)
155 | * [Use proprietary foundation models from Amazon SageMaker JumpStart in Amazon SageMaker Studio (2023-06-27)](https://aws.amazon.com/blogs/machine-learning/use-proprietary-foundation-models-from-amazon-sagemaker-jumpstart-in-amazon-sagemaker-studio/)
156 | * [SageMaker Built-in Algorithms with pre-trained Model Table](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html)
157 | * [AWS Deep Learning Containers Images](https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/deep-learning-containers-images.html)
158 | * [OpenSearch Popular APIs](https://opensearch.org/docs/latest/opensearch/popular-api/)
159 | * [Using the Amazon SageMaker Studio Image Build CLI to build container images from your Studio notebooks (2020-09-14)](https://aws.amazon.com/blogs/machine-learning/using-the-amazon-sagemaker-studio-image-build-cli-to-build-container-images-from-your-studio-notebooks/)
160 |
--------------------------------------------------------------------------------
/cdk_stacks/app.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab
4 |
5 | import os
6 |
7 | from aws_cdk import App, Environment
8 |
9 | from rag_with_aos import (
10 | VpcStack,
11 | OpenSearchStack,
12 | SageMakerStudioStack,
13 | EmbeddingEndpointStack,
14 | LLMEndpointStack,
15 | StreamlitAppStack
16 | )
17 |
18 | APP_ENV = Environment(
19 | account=os.environ["CDK_DEFAULT_ACCOUNT"],
20 | region=os.environ["CDK_DEFAULT_REGION"]
21 | )
22 |
23 | app = App()
24 |
25 | vpc_stack = VpcStack(app, 'RAGVpcStack',
26 | env=APP_ENV)
27 |
28 | ops_stack = OpenSearchStack(app, 'RAGOpenSearchStack',
29 | vpc_stack.vpc,
30 | env=APP_ENV
31 | )
32 | ops_stack.add_dependency(vpc_stack)
33 |
34 | sm_studio_stack = SageMakerStudioStack(app, 'RAGSageMakerStudioStack',
35 | vpc_stack.vpc,
36 | env=APP_ENV
37 | )
38 | sm_studio_stack.add_dependency(ops_stack)
39 |
40 | sm_embedding_endpoint = EmbeddingEndpointStack(app, 'EmbeddingEndpointStack',
41 | env=APP_ENV
42 | )
43 | sm_embedding_endpoint.add_dependency(sm_studio_stack)
44 |
45 | sm_llm_endpoint = LLMEndpointStack(app, 'LLMEndpointStack',
46 | env=APP_ENV
47 | )
48 | sm_llm_endpoint.add_dependency(sm_embedding_endpoint)
49 |
50 | ecs_app = StreamlitAppStack(app, "StreamlitAppStack",
51 | vpc_stack.vpc,
52 | ops_stack.master_user_secret,
53 | ops_stack.opensearch_domain,
54 | sm_llm_endpoint.llm_endpoint,
55 | sm_embedding_endpoint.embedding_endpoint,
56 | env=APP_ENV
57 | )
58 | ecs_app.add_dependency(sm_llm_endpoint)
59 |
60 | app.synth()
61 |
--------------------------------------------------------------------------------
/cdk_stacks/cdk.context.json:
--------------------------------------------------------------------------------
1 | {
2 | "opensearch_domain_name": "rag-vectordb",
3 | "jumpstart_model_info": {
4 | "model_id": "meta-textgeneration-llama-2-7b-f",
5 | "version": "2.0.1"
6 | }
7 | }
8 |
--------------------------------------------------------------------------------
/cdk_stacks/cdk.json:
--------------------------------------------------------------------------------
1 | {
2 | "app": "python3 app.py",
3 | "watch": {
4 | "include": [
5 | "**"
6 | ],
7 | "exclude": [
8 | "README.md",
9 | "cdk*.json",
10 | "requirements*.txt",
11 | "source.bat",
12 | "**/__init__.py",
13 | "python/__pycache__",
14 | "tests"
15 | ]
16 | },
17 | "context": {
18 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true,
19 | "@aws-cdk/core:checkSecretUsage": true,
20 | "@aws-cdk/core:target-partitions": [
21 | "aws",
22 | "aws-cn"
23 | ],
24 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true,
25 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true,
26 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true,
27 | "@aws-cdk/aws-iam:minimizePolicies": true,
28 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true,
29 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true,
30 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true,
31 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true,
32 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true,
33 | "@aws-cdk/core:enablePartitionLiterals": true,
34 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true,
35 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true,
36 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true,
37 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true,
38 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true,
39 | "@aws-cdk/aws-route53-patters:useCertificate": true,
40 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false,
41 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true,
42 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true,
43 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true,
44 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true,
45 | "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true,
46 | "@aws-cdk/aws-redshift:columnId": true,
47 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true,
48 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true,
49 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true,
50 | "@aws-cdk/aws-kms:aliasNameRef": true,
51 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/cdk_stacks/rag_with_aos/__init__.py:
--------------------------------------------------------------------------------
1 | from .vpc import VpcStack
2 | from .ops import OpenSearchStack
3 | from .sm_studio import SageMakerStudioStack
4 | from .sm_custom_embedding_endpoint import SageMakerEmbeddingEndpointStack as EmbeddingEndpointStack
5 | from .sm_jumpstart_llm_endpoint import SageMakerJumpStartLLMEndpointStack as LLMEndpointStack
6 | from .ecs_streamlit_app import StreamlitAppStack
--------------------------------------------------------------------------------
/cdk_stacks/rag_with_aos/ecs_streamlit_app.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab
4 |
5 | from aws_cdk import (
6 | CfnOutput,
7 | Stack,
8 | aws_ec2 as ec2,
9 | aws_ecs as ecs,
10 | aws_iam as iam,
11 | aws_ecs_patterns as ecs_patterns,
12 | )
13 | from constructs import Construct
14 |
15 |
16 | class StreamlitAppStack(Stack):
17 | def __init__(
18 | self,
19 | scope: Construct,
20 | id: str,
21 | vpc,
22 | opensearch_master_user_secret,
23 | opensearch_domain,
24 | llm_endpoint,
25 | embedding_endpoint,
26 | **kwargs) -> None:
27 |
28 | super().__init__(scope, id, **kwargs)
29 |
30 | container_port = self.node.try_get_context("streamlit_container_port") or 8501
31 |
32 | # Create an IAM service role for ECS task execution
33 | execution_role = iam.Role(
34 | self, "ECSExecutionRole",
35 | assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com"),
36 | managed_policies=[
37 | iam.ManagedPolicy.from_aws_managed_policy_name("service-role/AmazonECSTaskExecutionRolePolicy")
38 | ]
39 | )
40 |
41 | # Create an IAM role for the Fargate task
42 | task_role = iam.Role(
43 | self, "ECSTaskRole",
44 | assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com"),
45 | description="Role for ECS Tasks to access Secrets Manager and SageMaker"
46 | )
47 |
48 | # Policy to access Secrets Manager secrets
49 | secrets_policy = iam.PolicyStatement(
50 | actions=["secretsmanager:GetSecretValue"],
51 | resources=[
52 | opensearch_master_user_secret.secret_arn
53 | ],
54 | effect=iam.Effect.ALLOW
55 | )
56 |
57 | # Policy to invoke SageMaker endpoints
58 | sagemaker_policy = iam.PolicyStatement(
59 | actions=["sagemaker:InvokeEndpoint"],
60 | resources=[
61 | embedding_endpoint.endpoint_arn,
62 | llm_endpoint.endpoint_arn
63 | ],
64 | effect=iam.Effect.ALLOW
65 | )
66 |
67 | # Attach policies to the role
68 | task_role.add_to_policy(secrets_policy)
69 | task_role.add_to_policy(sagemaker_policy)
70 |
71 | # Set up ECS cluster and networking
72 | cluster = ecs.Cluster(
73 | self, "Cluster",
74 | vpc=vpc
75 | )
76 |
77 | security_group = ec2.SecurityGroup(
78 | self, "SecurityGroup",
79 | vpc=vpc,
80 | description="Allow traffic on container port",
81 | allow_all_outbound=True
82 | )
83 |
84 | security_group.add_ingress_rule(
85 | ec2.Peer.any_ipv4(),
86 | ec2.Port.tcp(container_port),
87 | "Allow inbound traffic"
88 | )
89 |
90 | # Set up Fargate task definition and service
91 | task_definition = ecs.FargateTaskDefinition(
92 | self, "TaskDef",
93 | memory_limit_mib=512,
94 | execution_role=execution_role,
95 | task_role=task_role,
96 | cpu=256
97 | )
98 |
99 | container = task_definition.add_container(
100 | "WebContainer",
101 | image=ecs.ContainerImage.from_asset("../app/"),
102 | logging=ecs.LogDrivers.aws_logs(stream_prefix="streamlitapp"),
103 | environment={
104 | "AWS_REGION": self.region,
105 | "OPENSEARCH_SECRET": opensearch_master_user_secret.secret_name,
106 | "OPENSEARCH_DOMAIN_ENDPOINT": f"https://{opensearch_domain.domain_endpoint}",
107 | "OPENSEARCH_INDEX": "llm_rag_embeddings",
108 | "EMBEDDING_ENDPOINT_NAME": embedding_endpoint.cfn_endpoint.endpoint_name,
109 | "TEXT2TEXT_ENDPOINT_NAME": llm_endpoint.cfn_endpoint.endpoint_name
110 | }
111 | )
112 | container.add_port_mappings(ecs.PortMapping(container_port=container_port))
113 |
114 | fargate_service = ecs_patterns.ApplicationLoadBalancedFargateService(
115 | self, "FargateService",
116 | cluster=cluster,
117 | task_definition=task_definition,
118 |             public_load_balancer=True,
119 |             assign_public_ip=True, security_groups=[security_group]
120 |         )
121 |
122 |
123 | CfnOutput(self, 'StreamlitEndpoint',
124 | value=fargate_service.load_balancer.load_balancer_dns_name,
125 | export_name=f'{self.stack_name}-StreamlitEndpoint'
126 | )
127 |
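The `environment` map above is the contract between this stack and the application code under `app/`: the container is expected to read exactly these variable names at startup. A minimal sketch of the consuming side, mirroring what the app scripts in this repository do:

```python
import os

# Configuration injected into the container by StreamlitAppStack.
region = os.environ["AWS_REGION"]
opensearch_secret = os.environ["OPENSEARCH_SECRET"]
opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"]
opensearch_index = os.environ["OPENSEARCH_INDEX"]          # e.g. "llm_rag_embeddings"
embedding_endpoint_name = os.environ["EMBEDDING_ENDPOINT_NAME"]
text2text_endpoint_name = os.environ["TEXT2TEXT_ENDPOINT_NAME"]
```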
--------------------------------------------------------------------------------
/cdk_stacks/rag_with_aos/ops.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab
4 |
5 | import json
6 | import random
7 | import re
8 | import string
9 |
10 | import aws_cdk as cdk
11 |
12 | from aws_cdk import (
13 | Stack,
14 | aws_ec2,
15 | aws_opensearchservice,
16 | aws_secretsmanager
17 | )
18 | from constructs import Construct
19 |
20 | random.seed(47)
21 |
22 |
23 | class OpenSearchStack(Stack):
24 |
25 | def __init__(self, scope: Construct, construct_id: str, vpc, **kwargs) -> None:
26 | super().__init__(scope, construct_id, **kwargs)
27 |
28 | #XXX: Amazon OpenSearch Service Domain naming restrictions
29 | # https://docs.aws.amazon.com/opensearch-service/latest/developerguide/createupdatedomains.html#createdomains
30 | OPENSEARCH_DEFAULT_DOMAIN_NAME = 'opensearch-{}'.format(''.join(random.sample((string.ascii_letters), k=5)))
31 | opensearch_domain_name = self.node.try_get_context('opensearch_domain_name') or OPENSEARCH_DEFAULT_DOMAIN_NAME
32 |     assert re.fullmatch(r'[a-z][a-z0-9\-]{2,27}', opensearch_domain_name), 'Invalid domain name'
33 |
34 | self.master_user_secret = aws_secretsmanager.Secret(self, "OpenSearchMasterUserSecret",
35 | generate_secret_string=aws_secretsmanager.SecretStringGenerator(
36 | secret_string_template=json.dumps({"username": "admin"}),
37 | generate_string_key="password",
38 | # Master password must be at least 8 characters long and contain at least one uppercase letter,
39 | # one lowercase letter, one number, and one special character.
40 | password_length=8
41 | )
42 | )
43 |
44 |     #XXX: aws cdk elasticsearch example - https://github.com/aws/aws-cdk/issues/2873
45 | # You should camelCase the property names instead of PascalCase
46 | self.opensearch_domain = aws_opensearchservice.Domain(self, "OpenSearch",
47 | domain_name=opensearch_domain_name,
48 | #XXX: Supported versions of OpenSearch and Elasticsearch
49 | # https://docs.aws.amazon.com/opensearch-service/latest/developerguide/what-is.html#choosing-version
50 | version=aws_opensearchservice.EngineVersion.OPENSEARCH_2_11,
51 | #XXX: Amazon OpenSearch Service - Current generation instance types
52 | # https://docs.aws.amazon.com/opensearch-service/latest/developerguide/supported-instance-types.html#latest-gen
53 | # - The OR1 instance types require OpenSearch 2.11 or later.
54 | # - OR1 instances are only compatible with other Graviton instance types master nodes (C6g, M6g, R6g)
55 | capacity={
56 | "master_nodes": 3,
57 | "master_node_instance_type": "m6g.large.search",
58 | "data_nodes": 3,
59 | "data_node_instance_type": "or1.large.search"
60 | },
61 | ebs={
62 | # Volume size must be between 20 and 1536 for or1.large.search instance type and version OpenSearch_2.11
63 | "volume_size": 20,
64 | "volume_type": aws_ec2.EbsDeviceVolumeType.GP3
65 | },
66 | #XXX: az_count must be equal to vpc subnets count.
67 | zone_awareness={
68 | "availability_zone_count": 3
69 | },
70 | logging={
71 | "slow_search_log_enabled": True,
72 | "app_log_enabled": True,
73 | "slow_index_log_enabled": True
74 | },
75 | fine_grained_access_control=aws_opensearchservice.AdvancedSecurityOptions(
76 | master_user_name=self.master_user_secret.secret_value_from_json("username").unsafe_unwrap(),
77 | master_user_password=self.master_user_secret.secret_value_from_json("password")
78 | ),
79 | # Enforce HTTPS is required when fine-grained access control is enabled.
80 | enforce_https=True,
81 | # Node-to-node encryption is required when fine-grained access control is enabled
82 | node_to_node_encryption=True,
83 | # Encryption-at-rest is required when fine-grained access control is enabled.
84 | encryption_at_rest={
85 | "enabled": True
86 | },
87 | use_unsigned_basic_auth=True,
88 | removal_policy=cdk.RemovalPolicy.DESTROY # default: cdk.RemovalPolicy.RETAIN
89 | )
90 |
91 | cdk.Tags.of(self.opensearch_domain).add('Name', opensearch_domain_name)
92 |
93 |     cdk.CfnOutput(self, 'OpenSearchDomainArn',
94 |       value=self.opensearch_domain.domain_arn,
95 |       export_name=f'{self.stack_name}-OpenSearchDomainArn')
96 | cdk.CfnOutput(self, 'OpenSearchDomainEndpoint',
97 | value=f"https://{self.opensearch_domain.domain_endpoint}",
98 | export_name=f'{self.stack_name}-OpenSearchDomainEndpoint')
99 | cdk.CfnOutput(self, 'OpenSearchDashboardsURL',
100 | value=f"https://{self.opensearch_domain.domain_endpoint}/_dashboards/",
101 | export_name=f'{self.stack_name}-OpenSearchDashboardsURL')
102 | cdk.CfnOutput(self, 'OpenSearchSecret',
103 | value=self.master_user_secret.secret_name,
104 | export_name=f'{self.stack_name}-MasterUserSecretId')
105 |
--------------------------------------------------------------------------------
/cdk_stacks/rag_with_aos/sm_custom_embedding_endpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab
4 |
5 | import random
6 | import string
7 |
8 | import aws_cdk as cdk
9 |
10 | from aws_cdk import (
11 | Stack
12 | )
13 | from constructs import Construct
14 |
15 | from cdklabs.generative_ai_cdk_constructs import (
16 | CustomSageMakerEndpoint,
17 | DeepLearningContainerImage,
18 | SageMakerInstanceType,
19 | )
20 |
21 | random.seed(47)
22 |
23 |
24 | class SageMakerEmbeddingEndpointStack(Stack):
25 |
26 | def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
27 | super().__init__(scope, construct_id, **kwargs)
28 |
29 | bucket_name = f'jumpstart-cache-prod-{cdk.Aws.REGION}'
30 | key_name = 'huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-textembedding-gpt-j-6b-fp16.tar.gz'
31 |
32 | RANDOM_GUID = ''.join(random.sample(string.digits, k=7))
33 | endpoint_name = f"gpt-j-6b-fp16-endpoint-{RANDOM_GUID}"
34 |
35 | #XXX: https://github.com/awslabs/generative-ai-cdk-constructs/blob/main/src/patterns/gen-ai/aws-model-deployment-sagemaker/README_custom_sagemaker_endpoint.md
36 | self.embedding_endpoint = CustomSageMakerEndpoint(self, 'EmbeddingEndpoint',
37 | model_id='gpt-j-6b-fp16',
38 | instance_type=SageMakerInstanceType.ML_G5_2_XLARGE,
39 | container=DeepLearningContainerImage.from_deep_learning_container_image(
40 | 'pytorch-inference',
41 | '1.12.0-gpu-py38'
42 | ),
43 | model_data_url=f's3://{bucket_name}/{key_name}',
44 | endpoint_name=endpoint_name,
45 | instance_count=1,
46 | # volume_size_in_gb=100
47 | )
48 |
49 | cdk.CfnOutput(self, 'EmbeddingEndpointName',
50 | value=self.embedding_endpoint.cfn_endpoint.endpoint_name,
51 | export_name=f'{self.stack_name}-EmbeddingEndpointName')
52 | cdk.CfnOutput(self, 'EmbeddingEndpointArn',
53 | value=self.embedding_endpoint.endpoint_arn,
54 | export_name=f'{self.stack_name}-EmbeddingEndpointArn')
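Once deployed, this endpoint accepts the JSON payload that the `ContentHandlerForEmbeddings` classes in `app/` produce. A minimal sketch of invoking it directly with boto3; the endpoint name is a placeholder:

```python
import json
import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

# Payload shape matches ContentHandlerForEmbeddings.transform_input() in app/.
body = json.dumps({"text_inputs": ["What is SageMaker model monitor?"]})

response = runtime.invoke_endpoint(
    EndpointName="gpt-j-6b-fp16-endpoint-0123456",  # placeholder endpoint name
    ContentType="application/json",
    Accept="application/json",
    Body=body,
)
embeddings = json.loads(response["Body"].read().decode("utf-8"))["embedding"]
print(len(embeddings), len(embeddings[0]))  # number of inputs, embedding dimension
```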
--------------------------------------------------------------------------------
/cdk_stacks/rag_with_aos/sm_jumpstart_llm_endpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab
4 |
5 | import random
6 | import string
7 |
8 | import aws_cdk as cdk
9 |
10 | from aws_cdk import (
11 | Stack
12 | )
13 | from constructs import Construct
14 |
15 | from cdklabs.generative_ai_cdk_constructs import (
16 | JumpStartSageMakerEndpoint,
17 | JumpStartModel,
18 | SageMakerInstanceType
19 | )
20 |
21 | random.seed(47)
22 |
23 |
24 | def name_from_base(base, max_length=63):
25 | unique = ''.join(random.sample(string.digits, k=7))
26 |     # SageMaker endpoint names are limited to 63 characters
27 | trimmed_base = base[: max_length - len(unique) - 1]
28 | return "{}-{}".format(trimmed_base, unique)
29 |
30 |
31 | class SageMakerJumpStartLLMEndpointStack(Stack):
32 |
33 | def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
34 | super().__init__(scope, construct_id, **kwargs)
35 |
36 | jumpstart_model = self.node.try_get_context('jumpstart_model_info')
37 | model_id, model_version = jumpstart_model.get('model_id', 'meta-textgeneration-llama-2-7b-f'), jumpstart_model.get('version', '2.0.1')
38 | model_name = f"{model_id.upper().replace('-', '_')}_{model_version.replace('.', '_')}"
39 |
40 | llm_endpoint_name = name_from_base(model_id.replace('/', '-').replace('.', '-'))
41 |
42 |     #XXX: Available JumpStart Model List
43 | # https://github.com/awslabs/generative-ai-cdk-constructs/blob/main/src/patterns/gen-ai/aws-model-deployment-sagemaker/jumpstart-model.ts
44 | self.llm_endpoint = JumpStartSageMakerEndpoint(self, 'LLMEndpoint',
45 | model=JumpStartModel.of(model_name),
46 | accept_eula=True,
47 | instance_type=SageMakerInstanceType.ML_G5_2_XLARGE,
48 | endpoint_name=llm_endpoint_name
49 | )
50 |
51 | cdk.CfnOutput(self, 'LLMEndpointName',
52 | value=self.llm_endpoint.cfn_endpoint.endpoint_name,
53 | export_name=f'{self.stack_name}-LLMEndpointName')
54 | cdk.CfnOutput(self, 'LLMEndpointArn',
55 | value=self.llm_endpoint.endpoint_arn,
56 | export_name=f'{self.stack_name}-LLMEndpointArn')
57 |
--------------------------------------------------------------------------------
/cdk_stacks/rag_with_aos/sm_studio.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab
4 |
5 | import aws_cdk as cdk
6 |
7 | from aws_cdk import (
8 | Stack,
9 | aws_ec2,
10 | aws_iam,
11 | aws_sagemaker
12 | )
13 | from constructs import Construct
14 |
15 |
16 | class SageMakerStudioStack(Stack):
17 |
18 | def __init__(self, scope: Construct, construct_id: str, vpc, **kwargs) -> None:
19 | super().__init__(scope, construct_id, **kwargs)
20 |
21 | sagemaker_execution_policy_doc = aws_iam.PolicyDocument()
22 | sagemaker_execution_policy_doc.add_statements(aws_iam.PolicyStatement(**{
23 | "effect": aws_iam.Effect.ALLOW,
24 | "resources": ["arn:aws:s3:::*"],
25 | "actions": [
26 | "s3:GetObject",
27 | "s3:PutObject",
28 | "s3:DeleteObject",
29 | "s3:ListBucket"
30 | ]
31 | }))
32 |
33 | sagemaker_custom_access_policy_doc = aws_iam.PolicyDocument()
34 | sagemaker_custom_access_policy_doc.add_statements(aws_iam.PolicyStatement(**{
35 | "effect": aws_iam.Effect.ALLOW,
36 | "resources": [f"arn:aws:es:{cdk.Aws.REGION}:{cdk.Aws.ACCOUNT_ID}:domain/*"],
37 | "actions": ["es:ESHttp*"],
38 | "sid": "ReadFromOpenSearch"
39 | }))
40 |
41 | sagemaker_custom_access_policy_doc.add_statements(aws_iam.PolicyStatement(**{
42 | "effect": aws_iam.Effect.ALLOW,
43 | "resources": [f"arn:aws:secretsmanager:{cdk.Aws.REGION}:{cdk.Aws.ACCOUNT_ID}:secret:*"],
44 | "actions": ["secretsmanager:GetSecretValue"],
45 | "sid": "ReadSecretFromSecretsManager"
46 | }))
47 |
48 | sagemaker_docker_build_policy_doc = aws_iam.PolicyDocument()
49 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
50 | "effect": aws_iam.Effect.ALLOW,
51 | "resources": ["arn:aws:codebuild:*:*:project/sagemaker-studio*"],
52 | "actions": [
53 | "codebuild:DeleteProject",
54 | "codebuild:CreateProject",
55 | "codebuild:BatchGetBuilds",
56 | "codebuild:StartBuild"
57 | ]
58 | }))
59 |
60 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
61 | "effect": aws_iam.Effect.ALLOW,
62 | "resources": ["arn:aws:logs:*:*:log-group:/aws/codebuild/sagemaker-studio*"],
63 | "actions": ["logs:CreateLogStream"],
64 | }))
65 |
66 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
67 | "effect": aws_iam.Effect.ALLOW,
68 | "resources": ["arn:aws:logs:*:*:log-group:/aws/codebuild/sagemaker-studio*:log-stream:*"],
69 | "actions": [
70 | "logs:GetLogEvents",
71 | "logs:PutLogEvents"
72 | ]
73 | }))
74 |
75 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
76 | "effect": aws_iam.Effect.ALLOW,
77 | "resources": ["*"],
78 | "actions": ["logs:CreateLogGroup"]
79 | }))
80 |
81 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
82 | "effect": aws_iam.Effect.ALLOW,
83 | "resources": ["*"],
84 | "actions": [
85 | "ecr:BatchGetImage",
86 | "ecr:BatchCheckLayerAvailability",
87 | "ecr:CompleteLayerUpload",
88 | "ecr:DescribeImages",
89 | "ecr:DescribeRepositories",
90 | "ecr:GetDownloadUrlForLayer",
91 | "ecr:InitiateLayerUpload",
92 | "ecr:ListImages",
93 | "ecr:PutImage",
94 | "ecr:UploadLayerPart",
95 | "ecr:CreateRepository",
96 | "ecr:GetAuthorizationToken",
97 | "ec2:DescribeAvailabilityZones"
98 | ],
99 | "sid": "ReadWriteFromECR"
100 | }))
101 |
102 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
103 | "effect": aws_iam.Effect.ALLOW,
104 | "resources": ["*"],
105 | "actions": ["ecr:GetAuthorizationToken"]
106 | }))
107 |
108 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
109 | "effect": aws_iam.Effect.ALLOW,
110 | "resources": ["arn:aws:s3:::sagemaker-*/*"],
111 | "actions": [
112 | "s3:GetObject",
113 | "s3:DeleteObject",
114 | "s3:PutObject"
115 | ]
116 | }))
117 |
118 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
119 | "effect": aws_iam.Effect.ALLOW,
120 | "resources": ["arn:aws:s3:::sagemaker*"],
121 | "actions": ["s3:CreateBucket"],
122 | }))
123 |
124 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
125 | "effect": aws_iam.Effect.ALLOW,
126 | "resources": ["*"],
127 | "actions": [
128 | "iam:GetRole",
129 | "iam:ListRoles"
130 | ]
131 | }))
132 |
133 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{
134 | "effect": aws_iam.Effect.ALLOW,
135 | "resources": ["arn:aws:iam::*:role/*"],
136 | "conditions": {
137 | "StringLikeIfExists": {
138 | "iam:PassedToService": [
139 | "codebuild.amazonaws.com"
140 | ]
141 | }
142 | },
143 | "actions": ["iam:PassRole"]
144 | }))
145 |
146 | sagemaker_execution_role = aws_iam.Role(self, 'SageMakerExecutionRole',
147 | role_name=f'AmazonSageMakerStudioExecutionRole-{self.stack_name.lower()}',
148 | assumed_by=aws_iam.ServicePrincipal('sagemaker.amazonaws.com'),
149 | path='/',
150 | inline_policies={
151 | 'sagemaker-execution-policy': sagemaker_execution_policy_doc,
152 | 'sagemaker-custom-access-policy': sagemaker_custom_access_policy_doc,
153 | 'sagemaker-docker-build-policy': sagemaker_docker_build_policy_doc,
154 | },
155 | managed_policies=[
156 | aws_iam.ManagedPolicy.from_aws_managed_policy_name('AmazonSageMakerFullAccess'),
157 | aws_iam.ManagedPolicy.from_aws_managed_policy_name('AmazonSageMakerCanvasFullAccess'),
158 | aws_iam.ManagedPolicy.from_aws_managed_policy_name('AWSCloudFormationReadOnlyAccess'),
159 | ]
160 | )
161 |
162 | #XXX: To use the sm-docker CLI, the Amazon SageMaker execution role used by the Studio notebook
163 | # environment should have a trust policy with CodeBuild
164 | sagemaker_execution_role.assume_role_policy.add_statements(aws_iam.PolicyStatement(**{
165 | "effect": aws_iam.Effect.ALLOW,
166 | "principals": [aws_iam.ServicePrincipal('codebuild.amazonaws.com')],
167 | "actions": ["sts:AssumeRole"]
168 | }))
169 |
170 | sm_studio_user_settings = aws_sagemaker.CfnDomain.UserSettingsProperty(
171 | execution_role=sagemaker_execution_role.role_arn
172 | )
173 |
174 | sagemaker_studio_domain = aws_sagemaker.CfnDomain(self, 'SageMakerStudioDomain',
175 | auth_mode='IAM', # [SSO | IAM]
176 | default_user_settings=sm_studio_user_settings,
177 | domain_name='llm-app-rag-workshop',
178 | subnet_ids=vpc.select_subnets(subnet_type=aws_ec2.SubnetType.PUBLIC).subnet_ids,
179 | vpc_id=vpc.vpc_id,
180 | app_network_access_type='PublicInternetOnly' # [PublicInternetOnly | VpcOnly]
181 | )
182 |
183 | #XXX: https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jl.html#studio-jl-set
184 |     sagemaker_jupyterlab_arn = self.node.try_get_context('sagmaker_jupyterlab_arn')
185 |
186 | default_user_settings = aws_sagemaker.CfnUserProfile.UserSettingsProperty(
187 | jupyter_server_app_settings=aws_sagemaker.CfnUserProfile.JupyterServerAppSettingsProperty(
188 | default_resource_spec=aws_sagemaker.CfnUserProfile.ResourceSpecProperty(
189 | #XXX: JupyterServer apps only support the system value.
190 | instance_type="system",
191 |           sage_maker_image_arn=sagemaker_jupyterlab_arn
192 | )
193 | )
194 | )
195 |
196 | sagemaker_user_profile = aws_sagemaker.CfnUserProfile(self, 'SageMakerStudioUserProfile',
197 | domain_id=sagemaker_studio_domain.attr_domain_id,
198 | user_profile_name='default-user',
199 | user_settings=default_user_settings
200 | )
201 |
202 | self.sm_execution_role_arn = sagemaker_execution_role.role_arn
203 |
204 | cdk.CfnOutput(self, 'DomainUrl',
205 | value=sagemaker_studio_domain.attr_url,
206 | export_name=f'{self.stack_name}-DomainUrl')
207 | cdk.CfnOutput(self, 'DomainId',
208 | value=sagemaker_user_profile.domain_id,
209 | export_name=f'{self.stack_name}-DomainId')
210 | cdk.CfnOutput(self, 'UserProfileName',
211 | value=sagemaker_user_profile.user_profile_name,
212 | export_name=f'{self.stack_name}-UserProfileName')
213 |
--------------------------------------------------------------------------------
/cdk_stacks/rag_with_aos/vpc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab
4 |
5 | import os
6 | import aws_cdk as cdk
7 |
8 | from aws_cdk import (
9 | Stack,
10 | aws_ec2,
11 | )
12 | from constructs import Construct
13 |
14 |
15 | class VpcStack(Stack):
16 |
17 | def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
18 | super().__init__(scope, construct_id, **kwargs)
19 |
20 |     #XXX: To create this CDK stack in an existing VPC,
21 |     # set the environment variable USE_DEFAULT_VPC to 'true'
22 |     # and pass -c vpc_name=your-existing-vpc to the cdk command.
23 |     # For example:
24 |     #
25 |     #   USE_DEFAULT_VPC=true cdk -c vpc_name=your-existing-vpc synth
26 | #
27 | if str(os.environ.get('USE_DEFAULT_VPC', 'false')).lower() == 'true':
28 | vpc_name = self.node.try_get_context('vpc_name') or 'default'
29 | self.vpc = aws_ec2.Vpc.from_lookup(self, 'ExistingVPC',
30 | is_default=True,
31 | vpc_name=vpc_name
32 | )
33 | else:
34 | #XXX: To use more than 2 AZs, be sure to specify the account and region on your stack.
35 | #XXX: https://docs.aws.amazon.com/cdk/api/latest/python/aws_cdk.aws_ec2/Vpc.html
36 | self.vpc = aws_ec2.Vpc(self, 'RAGAppVPC',
37 | ip_addresses=aws_ec2.IpAddresses.cidr("10.0.0.0/16"),
38 | max_azs=3,
39 |
40 | # 'subnetConfiguration' specifies the "subnet groups" to create.
41 | # Every subnet group will have a subnet for each AZ, so this
42 | # configuration will create `2 groups × 3 AZs = 6` subnets.
43 | subnet_configuration=[
44 | {
45 | "cidrMask": 20,
46 | "name": "Public",
47 | "subnetType": aws_ec2.SubnetType.PUBLIC,
48 | },
49 | {
50 | "cidrMask": 20,
51 | "name": "Private",
52 | "subnetType": aws_ec2.SubnetType.PRIVATE_WITH_EGRESS
53 | }
54 | ],
55 | gateway_endpoints={
56 | "S3": aws_ec2.GatewayVpcEndpointOptions(
57 | service=aws_ec2.GatewayVpcEndpointAwsService.S3
58 | )
59 | }
60 | )
61 |
62 |
63 | cdk.CfnOutput(self, 'VPCID', value=self.vpc.vpc_id,
64 | export_name=f'{self.stack_name}-VPCID')
65 |
--------------------------------------------------------------------------------
/cdk_stacks/requirements.txt:
--------------------------------------------------------------------------------
1 | aws-cdk-lib==2.171.1
2 | constructs>=10.0.0,<11.0.0
3 | cdklabs.generative-ai-cdk-constructs==0.1.286
4 |
--------------------------------------------------------------------------------
/cdk_stacks/source.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | rem The sole purpose of this script is to make the command
4 | rem
5 | rem source .venv/bin/activate
6 | rem
7 | rem (which activates a Python virtualenv on Linux or Mac OS X) work on Windows.
8 | rem On Windows, this command just runs this batch file (the argument is ignored).
9 | rem
10 | rem Now we don't need to document a Windows command for activating a virtualenv.
11 |
12 | echo Executing .venv\Scripts\activate.bat for you
13 | .venv\Scripts\activate.bat
14 |
--------------------------------------------------------------------------------
/data_ingestion_to_vectordb/container/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10.13-slim
2 |
3 | # pip leaves the install caches populated which uses a
4 | # significant amount of space. These optimizations save a fair
5 | # amount of space in the image, which reduces start up time.
6 | RUN pip --no-cache-dir install -U pip
7 | RUN pip --no-cache-dir install boto3==1.33.9 \
8 | langchain==0.1.11 \
9 | langchain-community==0.0.29 \
10 | SQLAlchemy==2.0.2 \
11 | opensearch-py==2.2.0 \
12 | beautifulsoup4==4.12.3
13 |
14 | # Include python script for retrieving credentials
15 | # from AWS SecretsManager and Sagemaker helper classes
16 | ADD credentials.py /code/
17 | ADD sm_helper.py /code/
18 |
19 | # Set some environment variables. PYTHONUNBUFFERED keeps Python from buffering our standard
20 | # output stream, which means that logs can be delivered to the user quickly.
21 | # PYTHONDONTWRITEBYTECODE keeps Python from writing .pyc files, which are
22 | # unnecessary in this case.
23 | ENV PYTHONUNBUFFERED=TRUE
24 | ENV PYTHONDONTWRITEBYTECODE=TRUE
25 |
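This image is presumably run as a SageMaker Processing job (driven from the `data_ingestion_to_opensearch.ipynb` notebook). A minimal sketch of launching the ingestion script with the SageMaker Python SDK; every resource name below is a placeholder:

```python
from sagemaker.processing import ProcessingInput, ScriptProcessor

# Assumes the image built from this Dockerfile has been pushed to ECR.
processor = ScriptProcessor(
    image_uri="<account>.dkr.ecr.us-east-1.amazonaws.com/load-data-opensearch:latest",
    command=["python3"],
    role="arn:aws:iam::<account>:role/<sagemaker-execution-role>",
    instance_type="ml.m5.xlarge",
    instance_count=1,
)
processor.run(
    code="container/load_data_into_opensearch.py",
    inputs=[ProcessingInput(
        source="s3://<bucket>/docs/",
        destination="/opt/ml/processing/input_data",  # matches --input-data-dir default
    )],
    arguments=[
        "--opensearch-cluster-domain", "<opensearch-domain-endpoint>",
        "--opensearch-secretid", "<opensearch-secret-id>",
        "--opensearch-index-name", "llm_rag_embeddings",
        "--aws-region", "us-east-1",
        "--embeddings-model-endpoint-name", "<embedding-endpoint-name>",
    ],
)
```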
--------------------------------------------------------------------------------
/data_ingestion_to_vectordb/container/credentials.py:
--------------------------------------------------------------------------------
1 | """
2 | Retrieve the credentials (username and password) for a given secret from AWS Secrets Manager
3 | """
4 | import json
5 | import boto3
6 |
7 | def get_credentials(secret_id: str, region_name: str) -> dict:
8 |
9 | client = boto3.client('secretsmanager', region_name=region_name)
10 | response = client.get_secret_value(SecretId=secret_id)
11 | secrets_value = json.loads(response['SecretString'])
12 |
13 | return secrets_value
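The secret created by the OpenSearch CDK stack stores a JSON document with `username` and `password` keys, so the value returned here is a dict. A small usage sketch with placeholder arguments:

```python
# creds is a dict such as {"username": "admin", "password": "..."}
creds = get_credentials("my-opensearch-secret", "us-east-1")  # placeholder values
http_auth = (creds["username"], creds["password"])
```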
--------------------------------------------------------------------------------
/data_ingestion_to_vectordb/container/load_data_into_opensearch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # this is needed because the credentials.py and sm_helper.py
5 | # are in /code directory of the custom container we are going
6 | # to create for Sagemaker Processing Job
7 | sys.path.insert(1, '/code')
8 |
9 | import glob
10 | import time
11 | import logging
12 | import argparse
13 | import multiprocessing as mp
14 | from itertools import repeat
15 | from functools import partial
16 | from typing import (
17 | List,
18 | Tuple
19 | )
20 |
21 | import numpy as np
22 |
23 | from langchain_community.document_loaders import ReadTheDocsLoader
24 | from langchain_community.vectorstores import OpenSearchVectorSearch
25 | from langchain.text_splitter import RecursiveCharacterTextSplitter
26 |
27 | from opensearchpy import (
28 | OpenSearch,
29 | RequestsHttpConnection,
30 | AWSV4SignerAuth
31 | )
32 |
33 | from credentials import get_credentials
34 | from sm_helper import create_sagemaker_embeddings_from_js_model
35 |
36 |
37 | # global constants
38 | MAX_OS_DOCS_PER_PUT = 500
39 | TOTAL_INDEX_CREATION_WAIT_TIME = 60
40 | PER_ITER_SLEEP_TIME = 5
41 |
42 | logger = logging.getLogger()
43 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
44 |
45 |
46 | def check_if_index_exists(index_name: str, region: str, host: str, http_auth: Tuple[str, str]) -> bool:
47 |     # update the region if you are working in a region other than us-east-1
48 |
49 | aos_client = OpenSearch(
50 | hosts = [{'host': host.rstrip('/').replace("https://", ""), 'port': 443}],
51 | http_auth = http_auth,
52 | use_ssl = True,
53 | verify_certs = True,
54 | connection_class = RequestsHttpConnection
55 | )
56 | exists = aos_client.indices.exists(index_name)
57 | logger.info(f"index_name={index_name}, exists={exists}")
58 | return exists
59 |
60 |
61 | def process_shard(shard, embeddings_model_endpoint_name, aws_region, os_index_name, os_domain_ep, os_http_auth) -> int:
62 | logger.info(f'Starting process_shard of {len(shard)} chunks.')
63 | st = time.time()
64 | embeddings = create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name, aws_region)
65 | docsearch = OpenSearchVectorSearch(index_name=os_index_name,
66 | embedding_function=embeddings,
67 | opensearch_url=os_domain_ep,
68 | http_auth=os_http_auth)
69 | docsearch.add_documents(documents=shard)
70 | et = time.time() - st
71 | logger.info(f'Shard completed in {et} seconds.')
72 | return 0
73 |
74 |
75 | if __name__ == "__main__":
76 | parser = argparse.ArgumentParser()
77 |
78 | parser.add_argument("--opensearch-cluster-domain", type=str, default=None)
79 | parser.add_argument("--opensearch-secretid", type=str, default=None)
80 | parser.add_argument("--opensearch-index-name", type=str, default=None)
81 | parser.add_argument("--aws-region", type=str, default="us-east-1")
82 | parser.add_argument("--embeddings-model-endpoint-name", type=str, default=None)
83 | parser.add_argument("--chunk-size-for-doc-split", type=int, default=500)
84 | parser.add_argument("--chunk-overlap-for-doc-split", type=int, default=30)
85 | parser.add_argument("--input-data-dir", type=str, default="/opt/ml/processing/input_data")
86 | parser.add_argument("--process-count", type=int, default=1)
87 | parser.add_argument("--create-index-hint-file", type=str, default="_create_index_hint")
88 | args, _ = parser.parse_known_args()
89 |
90 | logger.info("Received arguments {}".format(args))
91 | # list all the files
92 | files = glob.glob(os.path.join(args.input_data_dir, "*.*"))
93 | logger.info(f"there are {len(files)} files to process in the {args.input_data_dir} folder")
94 |
95 | # retrieve secret to talk to opensearch
96 | creds = get_credentials(args.opensearch_secretid, args.aws_region)
97 | http_auth = (creds['username'], creds['password'])
98 |
99 | loader = ReadTheDocsLoader(args.input_data_dir)
100 | text_splitter = RecursiveCharacterTextSplitter(
101 |         # chunk size and overlap come from the command-line arguments
102 | chunk_size=args.chunk_size_for_doc_split,
103 | chunk_overlap=args.chunk_overlap_for_doc_split,
104 | length_function=len,
105 | )
106 |
107 | # Stage one: read all the docs, split them into chunks.
108 | st = time.time()
109 | logger.info('Loading documents ...')
110 | docs = loader.load()
111 |
112 | # add a custom metadata field, such as timestamp
113 | for doc in docs:
114 | doc.metadata['timestamp'] = time.time()
115 | doc.metadata['embeddings_model'] = args.embeddings_model_endpoint_name
116 | chunks = text_splitter.create_documents([doc.page_content for doc in docs], metadatas=[doc.metadata for doc in docs])
117 | et = time.time() - st
118 | logger.info(f'Time taken: {et} seconds. {len(chunks)} chunks generated')
119 |
120 |
121 | db_shards = (len(chunks) // MAX_OS_DOCS_PER_PUT) + 1
122 |     logger.info(f'Loading chunks into vector store ... using {db_shards} shards')
123 | st = time.time()
124 | shards = np.array_split(chunks, db_shards)
125 |
126 | t1 = time.time()
127 |
128 |     # First check if the index exists; if it does, call the add_documents function,
129 |     # otherwise call the from_documents function, which first creates the index
130 |     # and then does a bulk add. Both add_documents and from_documents do a bulk add,
131 |     # but it is important to call from_documents first so that the index is created
132 |     # correctly for k-NN.
133 | index_exists = check_if_index_exists(args.opensearch_index_name,
134 | args.aws_region,
135 | args.opensearch_cluster_domain,
136 | http_auth)
137 |
138 | embeddings = create_sagemaker_embeddings_from_js_model(args.embeddings_model_endpoint_name, args.aws_region)
139 |
140 | if index_exists is False:
141 | # create an index if the create index hint file exists
142 | path = os.path.join(args.input_data_dir, args.create_index_hint_file)
143 | if os.path.isfile(path) is True:
144 | logger.info(f"index {args.opensearch_index_name} does not exist but {path} file is present so will create index")
145 | # by default langchain would create a k-NN index and the embeddings would be ingested as a k-NN vector type
146 | docsearch = OpenSearchVectorSearch.from_documents(index_name=args.opensearch_index_name,
147 | documents=shards[0],
148 | embedding=embeddings,
149 | opensearch_url=args.opensearch_cluster_domain,
150 | http_auth=http_auth)
151 | # we now need to start the loop below for the second shard
152 | shard_start_index = 1
153 | else:
154 | logger.info(f"index {args.opensearch_index_name} does not exist and {path} file is not present, "
155 | f"will wait for some other node to create the index")
156 | shard_start_index = 0
157 | # start a loop to wait for index creation by another node
158 | time_slept = 0
159 | while True:
160 | logger.info(f"index {args.opensearch_index_name} still does not exist, sleeping...")
161 | time.sleep(PER_ITER_SLEEP_TIME)
162 | index_exists = check_if_index_exists(args.opensearch_index_name,
163 | args.aws_region,
164 | args.opensearch_cluster_domain,
165 | http_auth)
166 | if index_exists is True:
167 | logger.info(f"index {args.opensearch_index_name} now exists")
168 | break
169 | time_slept += PER_ITER_SLEEP_TIME
170 | if time_slept >= TOTAL_INDEX_CREATION_WAIT_TIME:
171 | logger.error(f"time_slept={time_slept} >= {TOTAL_INDEX_CREATION_WAIT_TIME}, not waiting anymore for index creation")
172 | break
173 |
174 | else:
175 |         logger.info(f"index={args.opensearch_index_name} already exists, going to call add_documents")
176 | shard_start_index = 0
177 |
178 | with mp.Pool(processes = args.process_count) as pool:
179 | results = pool.map(partial(process_shard,
180 | embeddings_model_endpoint_name=args.embeddings_model_endpoint_name,
181 | aws_region=args.aws_region,
182 | os_index_name=args.opensearch_index_name,
183 | os_domain_ep=args.opensearch_cluster_domain,
184 | os_http_auth=http_auth),
185 | shards[shard_start_index:])
186 |
187 | t2 = time.time()
188 | logger.info(f'run time in seconds: {t2-t1:.2f}')
189 | logger.info("all done")
190 |
--------------------------------------------------------------------------------
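
To make the sharding arithmetic in `load_data_into_opensearch.py` concrete, here is a small self-contained sketch; the chunk count of 1,234 is made up:

```python
import numpy as np

MAX_OS_DOCS_PER_PUT = 500
chunks = list(range(1234))  # stand-ins for langchain Document chunks

# Same formula as the script: enough shards so that no single
# bulk add exceeds MAX_OS_DOCS_PER_PUT documents.
db_shards = (len(chunks) // MAX_OS_DOCS_PER_PUT) + 1
shards = np.array_split(chunks, db_shards)
print([len(s) for s in shards])  # [412, 411, 411]
```
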
/data_ingestion_to_vectordb/container/sm_helper.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper functions for using a SageMaker Endpoint via langchain
3 | """
4 | import time
5 | import json
6 | import logging
7 | from typing import List
8 |
9 | from langchain_community.embeddings import SagemakerEndpointEmbeddings
10 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | # extend the SagemakerEndpointEmbeddings class from langchain to provide a custom embedding function
15 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
16 | def embed_documents(
17 | self, texts: List[str], chunk_size: int = 5
18 | ) -> List[List[float]]:
19 | """Compute doc embeddings using a SageMaker Inference Endpoint.
20 |
21 | Args:
22 | texts: The list of texts to embed.
23 | chunk_size: The chunk size defines how many input texts will
24 | be grouped together as request. If None, will use the
25 | chunk size specified by the class.
26 |
27 | Returns:
28 | List of embeddings, one for each text.
29 | """
30 | results = []
31 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
32 | st = time.time()
33 | for i in range(0, len(texts), _chunk_size):
34 | response = self._embedding_func(texts[i:i + _chunk_size])
35 | results.extend(response)
36 | time_taken = time.time() - st
37 |         logger.info(f"got results for {len(texts)} texts in {time_taken}s, length of embeddings list is {len(results)}")
38 | return results
39 |
40 |
41 | # class for serializing/deserializing requests/responses to/from the embeddings model
42 | class ContentHandler(EmbeddingsContentHandler):
43 | content_type = "application/json"
44 | accepts = "application/json"
45 |
46 | def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
47 |
48 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
49 | return input_str.encode('utf-8')
50 |
51 |     def transform_output(self, output: bytes) -> List[List[float]]:
52 |
53 | response_json = json.loads(output.read().decode("utf-8"))
54 | embeddings = response_json["embedding"]
55 | if len(embeddings) == 1:
56 | return [embeddings[0]]
57 | return embeddings
58 |
59 |
60 | def create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name: str, aws_region: str) -> SagemakerEndpointEmbeddingsJumpStart:
61 | # all set to create the objects for the ContentHandler and
62 | # SagemakerEndpointEmbeddingsJumpStart classes
63 | content_handler = ContentHandler()
64 |
65 |     # note the name of the SageMaker embeddings endpoint; this is the model
66 |     # that we use for generating the embeddings
67 | embeddings = SagemakerEndpointEmbeddingsJumpStart(
68 | endpoint_name=embeddings_model_endpoint_name,
69 | region_name=aws_region,
70 | content_handler=content_handler
71 | )
72 | return embeddings
73 |
74 |
--------------------------------------------------------------------------------
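
A minimal sketch of the request batching that `embed_documents` performs; the endpoint name is illustrative and must point at a live SageMaker embeddings endpoint:

```python
from sm_helper import create_sagemaker_embeddings_from_js_model

embeddings = create_sagemaker_embeddings_from_js_model('my-embeddings-endpoint',  # illustrative
                                                       'us-east-1')

# 12 texts with chunk_size=5 are sent as three requests of 5, 5 and 2 texts.
vectors = embeddings.embed_documents([f'text {i}' for i in range(12)], chunk_size=5)
print(len(vectors))  # 12 -- one embedding per input text
```
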
/data_ingestion_to_vectordb/data_ingestion_to_opensearch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "93179240-9c5f-4ba6-a1c7-3a981624f794",
6 | "metadata": {},
7 | "source": [
8 | "# Ingest massive amounts of data to a Vector DB (Amazon OpenSearch)\n",
9 | "**_Use of Amazon OpenSearch as a vector database for storing embeddings_**\n",
10 | "\n",
11 |     "This notebook works well on an `ml.t3.xlarge` instance with the `Python3` kernel from **JupyterLab** or the `Data Science 3.0` kernel from **SageMaker Studio Classic**.\n",
12 | "\n",
13 | "Here is a list of packages that are used in this notebook.\n",
14 | "\n",
15 | "```\n",
16 | "!pip list | grep -E -w \"sagemaker|ipywidgets|langchain|opensearch-py|faiss-cpu|numpy|sh|SQLAlchemy\"\n",
17 | "---------------------------------------------------------------------------------------------------\n",
18 | "faiss-cpu 1.7.4\n",
19 | "ipywidgets 7.6.5\n",
20 | "langchain 0.1.11\n",
21 | "langchain-community 0.0.29\n",
22 | "langchain-core 0.1.23\n",
23 | "numpy 1.26.2\n",
24 | "opensearch-py 2.2.0\n",
25 | "sagemaker 2.199.0\n",
26 | "sagemaker-studio-image-build 0.6.0\n",
27 | "sh 2.0.4\n",
28 | "SQLAlchemy 2.0.2\n",
29 | "```"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "id": "79aae52c-cd7a-4637-a07d-9c0131dc7d0a",
35 | "metadata": {},
36 | "source": [
37 | "## Step 1: Setup\n",
38 | "Install the required packages."
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "87e64f84-b7ac-427d-b5a8-cf98b430be9b",
45 | "metadata": {
46 | "tags": []
47 | },
48 | "outputs": [],
49 | "source": [
50 | "%%capture --no-stderr\n",
51 | "\n",
52 | "!pip install -U pip\n",
53 | "!pip install -U langchain==0.1.11\n",
54 | "!pip install -U langchain-community==0.0.29\n",
55 | "!pip install -U SQLAlchemy==2.0.2\n",
56 | "!pip install -U opensearch-py==2.2.0\n",
57 | "!pip install -U faiss_cpu==1.7.4\n",
58 | "!pip install -U sh==2.0.4\n",
59 | "!pip install -U sagemaker-studio-image-build==0.6.0"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "d88757ba-7ae1-4efb-9c02-ab17ec22e79a",
66 | "metadata": {
67 | "tags": []
68 | },
69 | "outputs": [],
70 | "source": [
71 | "!pip list | grep -E -w \"sagemaker|ipywidgets|langchain|opensearch-py|faiss-cpu|numpy|sh|SQLAlchemy\""
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "id": "c017bc3f-e507-4f0c-b640-ea774c5ea9c8",
77 | "metadata": {},
78 | "source": [
79 | "## Step 2: Download the data from the web and upload to S3\n",
80 | "\n",
81 |     "In this step we use `wget` to crawl a Python-documentation-style website. All files other than `html`, `txt` and `md` are removed. **This data download will take a few minutes**."
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "id": "5c2b8c14-0ffc-4090-adf1-c2a8a1bdebaa",
88 | "metadata": {
89 | "tags": []
90 | },
91 | "outputs": [],
92 | "source": [
93 | "WEBSITE = \"https://sagemaker.readthedocs.io/en/stable/\"\n",
94 | "DOMAIN = \"sagemaker.readthedocs.io\"\n",
95 | "DATA_DIR = \"docs\""
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "id": "0eb232ee-6b62-4718-9104-345fe7978703",
102 | "metadata": {
103 | "tags": []
104 | },
105 | "outputs": [],
106 | "source": [
107 | "!python ./scripts/get_data.py --website {WEBSITE} --domain {DOMAIN} --output-dir {DATA_DIR}"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "id": "8ee1fbb8-583a-4c41-a831-715e4250ff3c",
114 | "metadata": {
115 | "tags": []
116 | },
117 | "outputs": [],
118 | "source": [
119 | "import boto3\n",
120 | "import sagemaker\n",
121 | "\n",
122 | "sagemaker_session = sagemaker.session.Session()\n",
123 | "aws_region = boto3.Session().region_name\n",
124 | "bucket = sagemaker_session.default_bucket()"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "id": "6c127969-4abc-4a31-8829-c00bee321a95",
131 | "metadata": {
132 | "tags": []
133 | },
134 | "outputs": [],
135 | "source": [
136 | "CREATE_OS_INDEX_HINT_FILE = \"_create_index_hint\"\n",
137 | "app_name = 'llm-app-rag'"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "id": "25217f27-4995-4da5-8fc4-b1b9533185b5",
144 | "metadata": {
145 | "tags": []
146 | },
147 | "outputs": [],
148 | "source": [
149 | "# create a dummy file called _create_index to provide a hint for opensearch index creation\n",
150 | "# this is needed for Sagemaker Processing Job when there are multiple instance nodes\n",
151 | "# all running the same code for data ingestion but only one node needs to create the index\n",
152 | "!touch {DATA_DIR}/{CREATE_OS_INDEX_HINT_FILE}\n",
153 | "\n",
154 | "# upload this data to S3, to be used when we run the Sagemaker Processing Job\n",
155 | "!aws s3 cp --recursive {DATA_DIR}/ s3://{bucket}/{app_name}/{DOMAIN}"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "id": "af7dc134",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "from typing import List\n",
166 | "\n",
167 | "\n",
168 | "def get_cfn_outputs(stackname: str, region_name: str) -> List:\n",
169 | " cfn = boto3.client('cloudformation', region_name=region_name)\n",
170 | " outputs = {}\n",
171 | " for output in cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Outputs']:\n",
172 | " outputs[output['OutputKey']] = output['OutputValue']\n",
173 | " return outputs"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "id": "b616e2c8",
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "CFN_STACK_NAME = 'EmbeddingEndpointStack'\n",
184 | "\n",
185 | "cfn_stack_outputs = get_cfn_outputs(CFN_STACK_NAME, aws_region)\n",
186 | "embeddings_model_endpoint_name = cfn_stack_outputs['EmbeddingEndpointName']"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "id": "cbef08c3",
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "CFN_STACK_NAME = \"RAGOpenSearchStack\"\n",
197 | "cfn_stack_outputs = get_cfn_outputs(CFN_STACK_NAME, aws_region)\n",
198 | "\n",
199 | "opensearch_domain_endpoint = cfn_stack_outputs['OpenSearchDomainEndpoint']\n",
200 | "opensearch_secretid = cfn_stack_outputs['OpenSearchSecret']\n",
201 | "opensearch_index = 'llm_rag_embeddings'"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "id": "fa041f3f",
208 | "metadata": {},
209 | "outputs": [],
210 | "source": [
211 | "CHUNK_SIZE_FOR_DOC_SPLIT = 600\n",
212 | "CHUNK_OVERLAP_FOR_DOC_SPLIT = 20"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "id": "15667be7-43b5-4954-95e0-885c9173f82c",
218 | "metadata": {
219 | "tags": []
220 | },
221 | "source": [
222 | "## Step 3: Load data into OpenSearch\n",
223 | "\n",
224 | "- Option 1) Parallel loading data with SageMaker Processing Job\n",
225 | "- Option 2) Sequential loading data with Document Loader"
226 | ]
227 | },
228 | {
229 | "cell_type": "markdown",
230 | "id": "406c6dd1",
231 | "metadata": {},
232 | "source": [
233 | "### Option 1) Parallel loading data with SageMaker Processing Job\n",
234 | "\n",
235 |     "We now have a working script that ingests data into an OpenSearch index. But for this to work for massive amounts of data, we need to scale up the processing by running this code in a distributed fashion. We will do this using a SageMaker Processing Job. This involves the following steps:\n",
236 |     "\n",
237 |     "1. Create a custom container in which we install the `langchain` and `opensearch-py` packages, and then upload this container image to Amazon Elastic Container Registry (ECR).\n",
238 |     "2. Use the SageMaker `ScriptProcessor` class to create a SageMaker Processing Job that runs on multiple nodes.\n",
239 |     "    - The data files in S3 are automatically distributed across the SageMaker Processing Job instances by setting `s3_data_distribution_type='ShardedByS3Key'` as part of the `ProcessingInput` provided to the processing job.\n",
240 |     "    - Each node processes a subset of the files, which brings down the overall time required to ingest the data into OpenSearch.\n",
241 |     "    - Each node also uses Python `multiprocessing` to parallelize the file processing internally. Thus, **there are two levels of parallelization: one at the cluster level, where individual nodes distribute the work (files) amongst themselves, and another at the node level, where the files on a node are also split between multiple processes running on the node**."
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "id": "b45d2938-994f-432d-8c50-92269f22f4b8",
247 | "metadata": {},
248 | "source": [
249 | "### Create custom container\n",
250 | "\n",
251 | "We will now create a container locally and push the container image to ECR. **The container creation process takes about 1 minute**.\n",
252 | "\n",
253 |     "1. The container includes all the Python packages we need, i.e., `langchain`, `opensearch-py`, `boto3` and `beautifulsoup4`.\n",
254 |     "1. The container also includes the `credentials.py` script for retrieving credentials from Secrets Manager and `sm_helper.py` for helping to create the SageMaker endpoint classes that langchain uses."
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "id": "01848619-8b49-48da-8cbf-c9cbbd8d1e40",
261 | "metadata": {
262 | "tags": []
263 | },
264 | "outputs": [],
265 | "source": [
266 | "DOCKER_IMAGE = \"load-data-opensearch-custom\"\n",
267 | "DOCKER_IMAGE_TAG = \"latest\""
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "id": "97bd70a7-8f61-477c-a8ef-82c0b5cd7821",
274 | "metadata": {
275 | "tags": []
276 | },
277 | "outputs": [],
278 | "source": [
279 | "!cd ./container && sm-docker build . --repository {DOCKER_IMAGE}:{DOCKER_IMAGE_TAG}"
280 | ]
281 | },
282 | {
283 | "cell_type": "markdown",
284 | "id": "1089064e-2099-4226-82c1-c7c406203c49",
285 | "metadata": {},
286 | "source": [
287 | "### Create and run the Sagemaker Processing Job\n",
288 | "\n",
289 | "Now we will run the Sagemaker Processing Job to ingest the data into OpenSearch."
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "id": "afb7373c-e80f-4d1a-a8dc-0dc79fb28e8a",
296 | "metadata": {},
297 | "outputs": [],
298 | "source": [
299 | "import sys\n",
300 | "import time\n",
301 | "import logging\n",
302 | "\n",
303 | "from sagemaker.processing import (\n",
304 | " ProcessingInput,\n",
305 | " ScriptProcessor,\n",
306 | ")\n",
307 | "\n",
308 | "logger = logging.getLogger()\n",
309 | "logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)\n",
310 | "\n",
311 | "\n",
312 | "# setup the parameters for the job\n",
313 | "base_job_name = f\"{app_name}-job\"\n",
314 | "tags = [{\"Key\": \"data\", \"Value\": \"embeddings-for-llm-apps\"}]\n",
315 | "\n",
316 | "account_id = boto3.client(\"sts\").get_caller_identity()[\"Account\"]\n",
317 | "aws_role = sagemaker_session.get_caller_identity_arn()\n",
318 | "\n",
319 | "# use the custom container we just created\n",
320 | "image_uri = f\"{account_id}.dkr.ecr.{aws_region}.amazonaws.com/{DOCKER_IMAGE}:{DOCKER_IMAGE_TAG}\"\n",
321 | "\n",
322 | "# instance type and count determined via trial and error: how much overall processing time\n",
323 | "# and what compute cost works best for your use-case\n",
324 | "instance_type = \"ml.m5.xlarge\"\n",
325 | "instance_count = 3\n",
326 | "logger.info(f\"base_job_name={base_job_name}, tags={tags}, image_uri={image_uri}, instance_type={instance_type}, instance_count={instance_count}\")\n",
327 | "\n",
328 | "# setup the ScriptProcessor with the above parameters\n",
329 | "processor = ScriptProcessor(base_job_name=base_job_name,\n",
330 | " image_uri=image_uri,\n",
331 | " role=aws_role,\n",
332 | " instance_type=instance_type,\n",
333 | " instance_count=instance_count,\n",
334 | " command=[\"python3\"],\n",
335 | " tags=tags)\n",
336 | "\n",
337 |     "# setup input from S3; note the ShardedByS3Key, which ensures that\n",
338 |     "# each instance gets its own roughly equal subset of the files in S3.\n",
339 | "inputs = [ProcessingInput(source=f\"s3://{bucket}/{app_name}/{DOMAIN}\",\n",
340 | " destination='/opt/ml/processing/input_data',\n",
341 | " s3_data_distribution_type='ShardedByS3Key',\n",
342 | " s3_data_type='S3Prefix')]\n",
343 | "\n",
344 | "\n",
345 | "logger.info(f\"creating an opensearch index with name={opensearch_index}\")\n",
346 | "\n",
347 | "# ready to run the processing job\n",
348 | "st = time.time()\n",
349 | "processor.run(code=\"container/load_data_into_opensearch.py\",\n",
350 | " inputs=inputs,\n",
351 | " outputs=[],\n",
352 | " arguments=[\"--opensearch-cluster-domain\", opensearch_domain_endpoint,\n",
353 | " \"--opensearch-secretid\", opensearch_secretid,\n",
354 | " \"--opensearch-index-name\", opensearch_index,\n",
355 | " \"--aws-region\", aws_region,\n",
356 | " \"--embeddings-model-endpoint-name\", embeddings_model_endpoint_name,\n",
357 | " \"--chunk-size-for-doc-split\", str(CHUNK_SIZE_FOR_DOC_SPLIT),\n",
358 | " \"--chunk-overlap-for-doc-split\", str(CHUNK_OVERLAP_FOR_DOC_SPLIT),\n",
359 | " \"--input-data-dir\", \"/opt/ml/processing/input_data\",\n",
360 | " \"--create-index-hint-file\", CREATE_OS_INDEX_HINT_FILE,\n",
361 | " \"--process-count\", \"2\"])\n",
362 | "\n",
363 | "time_taken = time.time() - st\n",
364 | "logger.info(f\"processing job completed, total time taken={time_taken}s\")\n",
365 | "\n",
366 | "preprocessing_job_description = processor.jobs[-1].describe()\n",
367 | "logger.info(preprocessing_job_description)"
368 | ]
369 | },
370 | {
371 | "cell_type": "markdown",
372 | "id": "bc4657c0",
373 | "metadata": {},
374 | "source": [
375 | "### Option 2) Sequential loading data with Document Loader"
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "execution_count": null,
381 | "id": "c9ef2a96",
382 | "metadata": {},
383 | "outputs": [],
384 | "source": [
385 | "%%capture --no-stderr\n",
386 | "\n",
387 | "!pip install -Uq beautifulsoup4==4.12.3"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": null,
393 | "id": "00b3c8e3",
394 | "metadata": {},
395 | "outputs": [],
396 | "source": [
397 | "%%time\n",
398 | "\n",
399 | "from langchain_community.document_loaders import ReadTheDocsLoader\n",
400 | "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
401 | "import time\n",
402 | "\n",
403 | "\n",
404 | "loader = ReadTheDocsLoader(DATA_DIR)\n",
405 | "text_splitter = RecursiveCharacterTextSplitter(\n",
406 | " chunk_size=CHUNK_SIZE_FOR_DOC_SPLIT,\n",
407 | " chunk_overlap=CHUNK_OVERLAP_FOR_DOC_SPLIT,\n",
408 | " length_function=len,\n",
409 | ")\n",
410 | "\n",
411 | "\n",
412 | "docs = loader.load()\n",
413 | "\n",
414 | "# add a custom metadata field, such as timestamp\n",
415 | "for doc in docs:\n",
416 | " doc.metadata['timestamp'] = time.time()\n",
417 | " doc.metadata['embeddings_model'] = embeddings_model_endpoint_name"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "id": "deb8e52e",
424 | "metadata": {},
425 | "outputs": [],
426 | "source": [
427 | "chunks = text_splitter.create_documents(\n",
428 | " [doc.page_content for doc in docs],\n",
429 | " metadatas=[doc.metadata for doc in docs]\n",
430 | ")"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": null,
436 | "id": "a88606e0",
437 | "metadata": {},
438 | "outputs": [],
439 | "source": [
440 | "from container.sm_helper import create_sagemaker_embeddings_from_js_model\n",
441 | "\n",
442 | "\n",
443 | "embeddings = create_sagemaker_embeddings_from_js_model(\n",
444 | " embeddings_model_endpoint_name,\n",
445 | " aws_region\n",
446 | ")"
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": null,
452 | "id": "aa92edfd",
453 | "metadata": {},
454 | "outputs": [],
455 | "source": [
456 | "import numpy as np\n",
457 | "\n",
458 | "\n",
459 | "MAX_OS_DOCS_PER_PUT = 500\n",
460 | "\n",
461 | "db_shards = (len(chunks) // MAX_OS_DOCS_PER_PUT) + 1\n",
462 | "shards = np.array_split(chunks, db_shards)\n",
463 | "\n",
464 | "print(f'Loading chunks into vector store ... using {len(shards)} shards')"
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": null,
470 | "id": "ed247419",
471 | "metadata": {},
472 | "outputs": [],
473 | "source": [
474 | "from container.credentials import get_credentials\n",
475 | "\n",
476 | "\n",
477 | "creds = get_credentials(opensearch_secretid, aws_region)\n",
478 | "http_auth = (creds['username'], creds['password'])\n",
479 | "http_auth"
480 | ]
481 | },
482 | {
483 | "cell_type": "code",
484 | "execution_count": null,
485 | "id": "316802fd",
486 | "metadata": {},
487 | "outputs": [],
488 | "source": [
489 | "from langchain_community.vectorstores import OpenSearchVectorSearch\n",
490 | "\n",
491 | "\n",
492 | "docsearch = OpenSearchVectorSearch.from_documents(\n",
493 | " index_name=opensearch_index,\n",
494 | " documents=shards[0],\n",
495 | " embedding=embeddings,\n",
496 | " opensearch_url=opensearch_domain_endpoint,\n",
497 | " http_auth=http_auth\n",
498 | ")"
499 | ]
500 | },
501 | {
502 | "cell_type": "code",
503 | "execution_count": null,
504 | "id": "ac1e0ea0",
505 | "metadata": {},
506 | "outputs": [],
507 | "source": [
508 | "%%time\n",
509 | "\n",
510 | "for i, shard in enumerate(shards[1:]):\n",
511 | " docsearch.add_documents(documents=shard)\n",
512 | " print(f\"[{i+1}] shard is added.\")"
513 | ]
514 | },
515 | {
516 | "cell_type": "markdown",
517 | "id": "1e444161-262e-44e5-ad31-e490a763be4e",
518 | "metadata": {},
519 | "source": [
520 | "## Step 4: Do a similarity search for user input to documents (embeddings) in OpenSearch"
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": null,
526 | "id": "294a7292-8bdb-4d11-a23d-130a4a039cd2",
527 | "metadata": {
528 | "tags": []
529 | },
530 | "outputs": [],
531 | "source": [
532 | "from container.credentials import get_credentials\n",
533 | "from langchain.vectorstores import OpenSearchVectorSearch\n",
534 | "from container.sm_helper import create_sagemaker_embeddings_from_js_model\n",
535 | "\n",
536 | "creds = get_credentials(opensearch_secretid, aws_region)\n",
537 | "http_auth = (creds['username'], creds['password'])\n",
538 | "\n",
539 | "docsearch = OpenSearchVectorSearch(index_name=opensearch_index,\n",
540 | " embedding_function=create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name,\n",
541 | " aws_region),\n",
542 | " opensearch_url=opensearch_domain_endpoint,\n",
543 | " http_auth=http_auth)\n",
544 | "\n",
545 | "q = \"Which XGBoost versions does SageMaker support?\"\n",
546 | "docs = docsearch.similarity_search(q, k=3) #, search_type=\"script_scoring\", space_type=\"cosinesimil\"\n",
547 | "for doc in docs:\n",
548 | " logger.info(\"----------\")\n",
549 | " logger.info(f\"content=\\\"{doc.page_content}\\\",\\nmetadata=\\\"{doc.metadata}\\\"\")"
550 | ]
551 | },
552 | {
553 | "cell_type": "markdown",
554 | "id": "6e29eae5-c463-4153-9167-e4628c74d13c",
555 | "metadata": {
556 | "tags": []
557 | },
558 | "source": [
559 | "## Cleanup\n",
560 | "\n",
561 |     "To avoid incurring future charges, delete the resources. You can do this by deleting the CloudFormation stack used to create the IAM role and the SageMaker notebook."
562 | ]
563 | },
564 | {
565 | "cell_type": "markdown",
566 | "id": "59ce3fe8-bb71-4e22-a551-2475eb2d16b7",
567 | "metadata": {},
568 | "source": [
569 | "---\n",
570 | "\n",
571 | "## Conclusion\n",
572 |     "In this notebook we saw how to use an LLM deployed on a SageMaker Endpoint to generate embeddings, ingest those embeddings into OpenSearch, and finally run a similarity search matching user input against the documents (embeddings) stored in OpenSearch. We used langchain as an abstraction layer to talk to both the SageMaker Endpoint and OpenSearch."
573 | ]
574 | },
575 | {
576 | "cell_type": "markdown",
577 | "id": "53386268-0cf9-4a37-b3d0-711fba1e5585",
578 | "metadata": {},
579 | "source": [
580 | "---\n",
581 | "\n",
582 | "## Appendix"
583 | ]
584 | },
585 | {
586 | "cell_type": "code",
587 | "execution_count": null,
588 | "id": "7332274c-586d-4cf9-838f-ea7e0cbe6c0f",
589 | "metadata": {
590 | "tags": []
591 | },
592 | "outputs": [],
593 | "source": [
594 | "import numpy as np\n",
595 | "from container.sm_helper import create_sagemaker_embeddings_from_js_model\n",
596 | "\n",
597 | "CFN_STACK_NAME = 'EmbeddingEndpointStack'\n",
598 | "cfn_stack_outputs = get_cfn_outputs(CFN_STACK_NAME, aws_region)\n",
599 | "embeddings_model_endpoint_name = cfn_stack_outputs['EmbeddingEndpointName']\n",
600 | "\n",
601 | "embeddings = create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name, aws_region)\n",
602 | "\n",
603 | "text = \"This is a sample query.\"\n",
604 | "query_result = embeddings.embed_query(text)\n",
605 | "\n",
606 | "print(np.array(query_result))\n",
607 | "print(f\"length: {len(query_result)}\")"
608 | ]
609 | },
610 | {
611 | "cell_type": "markdown",
612 | "id": "dd881bab",
613 | "metadata": {},
614 | "source": [
615 | "## References\n",
616 | "\n",
617 | " * [Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain](https://aws.amazon.com/blogs/machine-learning/build-a-powerful-question-answering-bot-with-amazon-sagemaker-amazon-opensearch-service-streamlit-and-langchain/)\n",
618 | " * [Using the Amazon SageMaker Studio Image Build CLI to build container images from your Studio notebooks](https://aws.amazon.com/blogs/machine-learning/using-the-amazon-sagemaker-studio-image-build-cli-to-build-container-images-from-your-studio-notebooks/)\n",
619 | " * [LangChain](https://python.langchain.com/docs/get_started/introduction.html) - A framework for developing applications powered by language models."
620 | ]
621 | }
622 | ],
623 | "metadata": {
624 | "availableInstances": [
625 | {
626 | "_defaultOrder": 0,
627 | "_isFastLaunch": true,
628 | "category": "General purpose",
629 | "gpuNum": 0,
630 | "hideHardwareSpecs": false,
631 | "memoryGiB": 4,
632 | "name": "ml.t3.medium",
633 | "vcpuNum": 2
634 | },
635 | {
636 | "_defaultOrder": 1,
637 | "_isFastLaunch": false,
638 | "category": "General purpose",
639 | "gpuNum": 0,
640 | "hideHardwareSpecs": false,
641 | "memoryGiB": 8,
642 | "name": "ml.t3.large",
643 | "vcpuNum": 2
644 | },
645 | {
646 | "_defaultOrder": 2,
647 | "_isFastLaunch": false,
648 | "category": "General purpose",
649 | "gpuNum": 0,
650 | "hideHardwareSpecs": false,
651 | "memoryGiB": 16,
652 | "name": "ml.t3.xlarge",
653 | "vcpuNum": 4
654 | },
655 | {
656 | "_defaultOrder": 3,
657 | "_isFastLaunch": false,
658 | "category": "General purpose",
659 | "gpuNum": 0,
660 | "hideHardwareSpecs": false,
661 | "memoryGiB": 32,
662 | "name": "ml.t3.2xlarge",
663 | "vcpuNum": 8
664 | },
665 | {
666 | "_defaultOrder": 4,
667 | "_isFastLaunch": true,
668 | "category": "General purpose",
669 | "gpuNum": 0,
670 | "hideHardwareSpecs": false,
671 | "memoryGiB": 8,
672 | "name": "ml.m5.large",
673 | "vcpuNum": 2
674 | },
675 | {
676 | "_defaultOrder": 5,
677 | "_isFastLaunch": false,
678 | "category": "General purpose",
679 | "gpuNum": 0,
680 | "hideHardwareSpecs": false,
681 | "memoryGiB": 16,
682 | "name": "ml.m5.xlarge",
683 | "vcpuNum": 4
684 | },
685 | {
686 | "_defaultOrder": 6,
687 | "_isFastLaunch": false,
688 | "category": "General purpose",
689 | "gpuNum": 0,
690 | "hideHardwareSpecs": false,
691 | "memoryGiB": 32,
692 | "name": "ml.m5.2xlarge",
693 | "vcpuNum": 8
694 | },
695 | {
696 | "_defaultOrder": 7,
697 | "_isFastLaunch": false,
698 | "category": "General purpose",
699 | "gpuNum": 0,
700 | "hideHardwareSpecs": false,
701 | "memoryGiB": 64,
702 | "name": "ml.m5.4xlarge",
703 | "vcpuNum": 16
704 | },
705 | {
706 | "_defaultOrder": 8,
707 | "_isFastLaunch": false,
708 | "category": "General purpose",
709 | "gpuNum": 0,
710 | "hideHardwareSpecs": false,
711 | "memoryGiB": 128,
712 | "name": "ml.m5.8xlarge",
713 | "vcpuNum": 32
714 | },
715 | {
716 | "_defaultOrder": 9,
717 | "_isFastLaunch": false,
718 | "category": "General purpose",
719 | "gpuNum": 0,
720 | "hideHardwareSpecs": false,
721 | "memoryGiB": 192,
722 | "name": "ml.m5.12xlarge",
723 | "vcpuNum": 48
724 | },
725 | {
726 | "_defaultOrder": 10,
727 | "_isFastLaunch": false,
728 | "category": "General purpose",
729 | "gpuNum": 0,
730 | "hideHardwareSpecs": false,
731 | "memoryGiB": 256,
732 | "name": "ml.m5.16xlarge",
733 | "vcpuNum": 64
734 | },
735 | {
736 | "_defaultOrder": 11,
737 | "_isFastLaunch": false,
738 | "category": "General purpose",
739 | "gpuNum": 0,
740 | "hideHardwareSpecs": false,
741 | "memoryGiB": 384,
742 | "name": "ml.m5.24xlarge",
743 | "vcpuNum": 96
744 | },
745 | {
746 | "_defaultOrder": 12,
747 | "_isFastLaunch": false,
748 | "category": "General purpose",
749 | "gpuNum": 0,
750 | "hideHardwareSpecs": false,
751 | "memoryGiB": 8,
752 | "name": "ml.m5d.large",
753 | "vcpuNum": 2
754 | },
755 | {
756 | "_defaultOrder": 13,
757 | "_isFastLaunch": false,
758 | "category": "General purpose",
759 | "gpuNum": 0,
760 | "hideHardwareSpecs": false,
761 | "memoryGiB": 16,
762 | "name": "ml.m5d.xlarge",
763 | "vcpuNum": 4
764 | },
765 | {
766 | "_defaultOrder": 14,
767 | "_isFastLaunch": false,
768 | "category": "General purpose",
769 | "gpuNum": 0,
770 | "hideHardwareSpecs": false,
771 | "memoryGiB": 32,
772 | "name": "ml.m5d.2xlarge",
773 | "vcpuNum": 8
774 | },
775 | {
776 | "_defaultOrder": 15,
777 | "_isFastLaunch": false,
778 | "category": "General purpose",
779 | "gpuNum": 0,
780 | "hideHardwareSpecs": false,
781 | "memoryGiB": 64,
782 | "name": "ml.m5d.4xlarge",
783 | "vcpuNum": 16
784 | },
785 | {
786 | "_defaultOrder": 16,
787 | "_isFastLaunch": false,
788 | "category": "General purpose",
789 | "gpuNum": 0,
790 | "hideHardwareSpecs": false,
791 | "memoryGiB": 128,
792 | "name": "ml.m5d.8xlarge",
793 | "vcpuNum": 32
794 | },
795 | {
796 | "_defaultOrder": 17,
797 | "_isFastLaunch": false,
798 | "category": "General purpose",
799 | "gpuNum": 0,
800 | "hideHardwareSpecs": false,
801 | "memoryGiB": 192,
802 | "name": "ml.m5d.12xlarge",
803 | "vcpuNum": 48
804 | },
805 | {
806 | "_defaultOrder": 18,
807 | "_isFastLaunch": false,
808 | "category": "General purpose",
809 | "gpuNum": 0,
810 | "hideHardwareSpecs": false,
811 | "memoryGiB": 256,
812 | "name": "ml.m5d.16xlarge",
813 | "vcpuNum": 64
814 | },
815 | {
816 | "_defaultOrder": 19,
817 | "_isFastLaunch": false,
818 | "category": "General purpose",
819 | "gpuNum": 0,
820 | "hideHardwareSpecs": false,
821 | "memoryGiB": 384,
822 | "name": "ml.m5d.24xlarge",
823 | "vcpuNum": 96
824 | },
825 | {
826 | "_defaultOrder": 20,
827 | "_isFastLaunch": false,
828 | "category": "General purpose",
829 | "gpuNum": 0,
830 | "hideHardwareSpecs": true,
831 | "memoryGiB": 0,
832 | "name": "ml.geospatial.interactive",
833 | "supportedImageNames": [
834 | "sagemaker-geospatial-v1-0"
835 | ],
836 | "vcpuNum": 0
837 | },
838 | {
839 | "_defaultOrder": 21,
840 | "_isFastLaunch": true,
841 | "category": "Compute optimized",
842 | "gpuNum": 0,
843 | "hideHardwareSpecs": false,
844 | "memoryGiB": 4,
845 | "name": "ml.c5.large",
846 | "vcpuNum": 2
847 | },
848 | {
849 | "_defaultOrder": 22,
850 | "_isFastLaunch": false,
851 | "category": "Compute optimized",
852 | "gpuNum": 0,
853 | "hideHardwareSpecs": false,
854 | "memoryGiB": 8,
855 | "name": "ml.c5.xlarge",
856 | "vcpuNum": 4
857 | },
858 | {
859 | "_defaultOrder": 23,
860 | "_isFastLaunch": false,
861 | "category": "Compute optimized",
862 | "gpuNum": 0,
863 | "hideHardwareSpecs": false,
864 | "memoryGiB": 16,
865 | "name": "ml.c5.2xlarge",
866 | "vcpuNum": 8
867 | },
868 | {
869 | "_defaultOrder": 24,
870 | "_isFastLaunch": false,
871 | "category": "Compute optimized",
872 | "gpuNum": 0,
873 | "hideHardwareSpecs": false,
874 | "memoryGiB": 32,
875 | "name": "ml.c5.4xlarge",
876 | "vcpuNum": 16
877 | },
878 | {
879 | "_defaultOrder": 25,
880 | "_isFastLaunch": false,
881 | "category": "Compute optimized",
882 | "gpuNum": 0,
883 | "hideHardwareSpecs": false,
884 | "memoryGiB": 72,
885 | "name": "ml.c5.9xlarge",
886 | "vcpuNum": 36
887 | },
888 | {
889 | "_defaultOrder": 26,
890 | "_isFastLaunch": false,
891 | "category": "Compute optimized",
892 | "gpuNum": 0,
893 | "hideHardwareSpecs": false,
894 | "memoryGiB": 96,
895 | "name": "ml.c5.12xlarge",
896 | "vcpuNum": 48
897 | },
898 | {
899 | "_defaultOrder": 27,
900 | "_isFastLaunch": false,
901 | "category": "Compute optimized",
902 | "gpuNum": 0,
903 | "hideHardwareSpecs": false,
904 | "memoryGiB": 144,
905 | "name": "ml.c5.18xlarge",
906 | "vcpuNum": 72
907 | },
908 | {
909 | "_defaultOrder": 28,
910 | "_isFastLaunch": false,
911 | "category": "Compute optimized",
912 | "gpuNum": 0,
913 | "hideHardwareSpecs": false,
914 | "memoryGiB": 192,
915 | "name": "ml.c5.24xlarge",
916 | "vcpuNum": 96
917 | },
918 | {
919 | "_defaultOrder": 29,
920 | "_isFastLaunch": true,
921 | "category": "Accelerated computing",
922 | "gpuNum": 1,
923 | "hideHardwareSpecs": false,
924 | "memoryGiB": 16,
925 | "name": "ml.g4dn.xlarge",
926 | "vcpuNum": 4
927 | },
928 | {
929 | "_defaultOrder": 30,
930 | "_isFastLaunch": false,
931 | "category": "Accelerated computing",
932 | "gpuNum": 1,
933 | "hideHardwareSpecs": false,
934 | "memoryGiB": 32,
935 | "name": "ml.g4dn.2xlarge",
936 | "vcpuNum": 8
937 | },
938 | {
939 | "_defaultOrder": 31,
940 | "_isFastLaunch": false,
941 | "category": "Accelerated computing",
942 | "gpuNum": 1,
943 | "hideHardwareSpecs": false,
944 | "memoryGiB": 64,
945 | "name": "ml.g4dn.4xlarge",
946 | "vcpuNum": 16
947 | },
948 | {
949 | "_defaultOrder": 32,
950 | "_isFastLaunch": false,
951 | "category": "Accelerated computing",
952 | "gpuNum": 1,
953 | "hideHardwareSpecs": false,
954 | "memoryGiB": 128,
955 | "name": "ml.g4dn.8xlarge",
956 | "vcpuNum": 32
957 | },
958 | {
959 | "_defaultOrder": 33,
960 | "_isFastLaunch": false,
961 | "category": "Accelerated computing",
962 | "gpuNum": 4,
963 | "hideHardwareSpecs": false,
964 | "memoryGiB": 192,
965 | "name": "ml.g4dn.12xlarge",
966 | "vcpuNum": 48
967 | },
968 | {
969 | "_defaultOrder": 34,
970 | "_isFastLaunch": false,
971 | "category": "Accelerated computing",
972 | "gpuNum": 1,
973 | "hideHardwareSpecs": false,
974 | "memoryGiB": 256,
975 | "name": "ml.g4dn.16xlarge",
976 | "vcpuNum": 64
977 | },
978 | {
979 | "_defaultOrder": 35,
980 | "_isFastLaunch": false,
981 | "category": "Accelerated computing",
982 | "gpuNum": 1,
983 | "hideHardwareSpecs": false,
984 | "memoryGiB": 61,
985 | "name": "ml.p3.2xlarge",
986 | "vcpuNum": 8
987 | },
988 | {
989 | "_defaultOrder": 36,
990 | "_isFastLaunch": false,
991 | "category": "Accelerated computing",
992 | "gpuNum": 4,
993 | "hideHardwareSpecs": false,
994 | "memoryGiB": 244,
995 | "name": "ml.p3.8xlarge",
996 | "vcpuNum": 32
997 | },
998 | {
999 | "_defaultOrder": 37,
1000 | "_isFastLaunch": false,
1001 | "category": "Accelerated computing",
1002 | "gpuNum": 8,
1003 | "hideHardwareSpecs": false,
1004 | "memoryGiB": 488,
1005 | "name": "ml.p3.16xlarge",
1006 | "vcpuNum": 64
1007 | },
1008 | {
1009 | "_defaultOrder": 38,
1010 | "_isFastLaunch": false,
1011 | "category": "Accelerated computing",
1012 | "gpuNum": 8,
1013 | "hideHardwareSpecs": false,
1014 | "memoryGiB": 768,
1015 | "name": "ml.p3dn.24xlarge",
1016 | "vcpuNum": 96
1017 | },
1018 | {
1019 | "_defaultOrder": 39,
1020 | "_isFastLaunch": false,
1021 | "category": "Memory Optimized",
1022 | "gpuNum": 0,
1023 | "hideHardwareSpecs": false,
1024 | "memoryGiB": 16,
1025 | "name": "ml.r5.large",
1026 | "vcpuNum": 2
1027 | },
1028 | {
1029 | "_defaultOrder": 40,
1030 | "_isFastLaunch": false,
1031 | "category": "Memory Optimized",
1032 | "gpuNum": 0,
1033 | "hideHardwareSpecs": false,
1034 | "memoryGiB": 32,
1035 | "name": "ml.r5.xlarge",
1036 | "vcpuNum": 4
1037 | },
1038 | {
1039 | "_defaultOrder": 41,
1040 | "_isFastLaunch": false,
1041 | "category": "Memory Optimized",
1042 | "gpuNum": 0,
1043 | "hideHardwareSpecs": false,
1044 | "memoryGiB": 64,
1045 | "name": "ml.r5.2xlarge",
1046 | "vcpuNum": 8
1047 | },
1048 | {
1049 | "_defaultOrder": 42,
1050 | "_isFastLaunch": false,
1051 | "category": "Memory Optimized",
1052 | "gpuNum": 0,
1053 | "hideHardwareSpecs": false,
1054 | "memoryGiB": 128,
1055 | "name": "ml.r5.4xlarge",
1056 | "vcpuNum": 16
1057 | },
1058 | {
1059 | "_defaultOrder": 43,
1060 | "_isFastLaunch": false,
1061 | "category": "Memory Optimized",
1062 | "gpuNum": 0,
1063 | "hideHardwareSpecs": false,
1064 | "memoryGiB": 256,
1065 | "name": "ml.r5.8xlarge",
1066 | "vcpuNum": 32
1067 | },
1068 | {
1069 | "_defaultOrder": 44,
1070 | "_isFastLaunch": false,
1071 | "category": "Memory Optimized",
1072 | "gpuNum": 0,
1073 | "hideHardwareSpecs": false,
1074 | "memoryGiB": 384,
1075 | "name": "ml.r5.12xlarge",
1076 | "vcpuNum": 48
1077 | },
1078 | {
1079 | "_defaultOrder": 45,
1080 | "_isFastLaunch": false,
1081 | "category": "Memory Optimized",
1082 | "gpuNum": 0,
1083 | "hideHardwareSpecs": false,
1084 | "memoryGiB": 512,
1085 | "name": "ml.r5.16xlarge",
1086 | "vcpuNum": 64
1087 | },
1088 | {
1089 | "_defaultOrder": 46,
1090 | "_isFastLaunch": false,
1091 | "category": "Memory Optimized",
1092 | "gpuNum": 0,
1093 | "hideHardwareSpecs": false,
1094 | "memoryGiB": 768,
1095 | "name": "ml.r5.24xlarge",
1096 | "vcpuNum": 96
1097 | },
1098 | {
1099 | "_defaultOrder": 47,
1100 | "_isFastLaunch": false,
1101 | "category": "Accelerated computing",
1102 | "gpuNum": 1,
1103 | "hideHardwareSpecs": false,
1104 | "memoryGiB": 16,
1105 | "name": "ml.g5.xlarge",
1106 | "vcpuNum": 4
1107 | },
1108 | {
1109 | "_defaultOrder": 48,
1110 | "_isFastLaunch": false,
1111 | "category": "Accelerated computing",
1112 | "gpuNum": 1,
1113 | "hideHardwareSpecs": false,
1114 | "memoryGiB": 32,
1115 | "name": "ml.g5.2xlarge",
1116 | "vcpuNum": 8
1117 | },
1118 | {
1119 | "_defaultOrder": 49,
1120 | "_isFastLaunch": false,
1121 | "category": "Accelerated computing",
1122 | "gpuNum": 1,
1123 | "hideHardwareSpecs": false,
1124 | "memoryGiB": 64,
1125 | "name": "ml.g5.4xlarge",
1126 | "vcpuNum": 16
1127 | },
1128 | {
1129 | "_defaultOrder": 50,
1130 | "_isFastLaunch": false,
1131 | "category": "Accelerated computing",
1132 | "gpuNum": 1,
1133 | "hideHardwareSpecs": false,
1134 | "memoryGiB": 128,
1135 | "name": "ml.g5.8xlarge",
1136 | "vcpuNum": 32
1137 | },
1138 | {
1139 | "_defaultOrder": 51,
1140 | "_isFastLaunch": false,
1141 | "category": "Accelerated computing",
1142 | "gpuNum": 1,
1143 | "hideHardwareSpecs": false,
1144 | "memoryGiB": 256,
1145 | "name": "ml.g5.16xlarge",
1146 | "vcpuNum": 64
1147 | },
1148 | {
1149 | "_defaultOrder": 52,
1150 | "_isFastLaunch": false,
1151 | "category": "Accelerated computing",
1152 | "gpuNum": 4,
1153 | "hideHardwareSpecs": false,
1154 | "memoryGiB": 192,
1155 | "name": "ml.g5.12xlarge",
1156 | "vcpuNum": 48
1157 | },
1158 | {
1159 | "_defaultOrder": 53,
1160 | "_isFastLaunch": false,
1161 | "category": "Accelerated computing",
1162 | "gpuNum": 4,
1163 | "hideHardwareSpecs": false,
1164 | "memoryGiB": 384,
1165 | "name": "ml.g5.24xlarge",
1166 | "vcpuNum": 96
1167 | },
1168 | {
1169 | "_defaultOrder": 54,
1170 | "_isFastLaunch": false,
1171 | "category": "Accelerated computing",
1172 | "gpuNum": 8,
1173 | "hideHardwareSpecs": false,
1174 | "memoryGiB": 768,
1175 | "name": "ml.g5.48xlarge",
1176 | "vcpuNum": 192
1177 | },
1178 | {
1179 | "_defaultOrder": 55,
1180 | "_isFastLaunch": false,
1181 | "category": "Accelerated computing",
1182 | "gpuNum": 8,
1183 | "hideHardwareSpecs": false,
1184 | "memoryGiB": 1152,
1185 | "name": "ml.p4d.24xlarge",
1186 | "vcpuNum": 96
1187 | },
1188 | {
1189 | "_defaultOrder": 56,
1190 | "_isFastLaunch": false,
1191 | "category": "Accelerated computing",
1192 | "gpuNum": 8,
1193 | "hideHardwareSpecs": false,
1194 | "memoryGiB": 1152,
1195 | "name": "ml.p4de.24xlarge",
1196 | "vcpuNum": 96
1197 | }
1198 | ],
1199 | "instance_type": "ml.t3.medium",
1200 | "kernelspec": {
1201 | "display_name": "Python 3 (Data Science 2.0)",
1202 | "language": "python",
1203 | "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:12345678012:image/sagemaker-data-science-38"
1204 | },
1205 | "language_info": {
1206 | "codemirror_mode": {
1207 | "name": "ipython",
1208 | "version": 3
1209 | },
1210 | "file_extension": ".py",
1211 | "mimetype": "text/x-python",
1212 | "name": "python",
1213 | "nbconvert_exporter": "python",
1214 | "pygments_lexer": "ipython3",
1215 | "version": "3.8.13"
1216 | }
1217 | },
1218 | "nbformat": 4,
1219 | "nbformat_minor": 5
1220 | }
1221 |
--------------------------------------------------------------------------------
/data_ingestion_to_vectordb/scripts/get_data.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import traceback
4 |
5 | from sh import cp, find, mkdir, wget
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 |
11 | parser.add_argument("--domain", type=str, default="sagemaker.readthedocs.io")
12 | parser.add_argument("--website", type=str, default="https://sagemaker.readthedocs.io/en/stable/")
13 | parser.add_argument("--output-dir", type=str, default="docs")
14 | parser.add_argument("--dryrun", action='store_true')
15 | args, _ = parser.parse_known_args()
16 |
17 | WEBSITE, DOMAIN, KB_DIR = (args.website, args.domain, args.output_dir)
18 |
19 | if args.dryrun:
20 | print(f"WEBSITE={WEBSITE}, DOMAIN={DOMAIN}, OUTPUT_DIR={KB_DIR}", file=sys.stderr)
21 | sys.exit(0)
22 |
23 | mkdir('-p', KB_DIR)
24 |
25 | try:
26 | WGET_ARGUMENTS = f"-e robots=off --recursive --no-clobber --page-requisites --html-extension --convert-links --restrict-file-names=windows --domains {DOMAIN} --no-parent {WEBSITE}"
27 | wget_argument_list = WGET_ARGUMENTS.split()
28 | wget(*wget_argument_list)
29 |     except Exception:
30 | traceback.print_exc()
31 |
32 | results = find(DOMAIN, '-name', '*.html')
33 | html_files = results.strip('\n').split('\n')
34 | for each in html_files:
35 | flat_i = each.replace('/', '-')
36 | cp(each, f"{KB_DIR}/{flat_i}")
37 |
38 |     print(f"There are {len(html_files)} files in the {KB_DIR} directory", file=sys.stderr)
39 |
40 |
41 | if __name__ == "__main__":
42 | main()
--------------------------------------------------------------------------------
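
The one non-obvious step in `get_data.py` is the filename flattening before the copy; a quick illustration (the path is illustrative):

```python
# wget mirrors the site into nested directories; the script copies every
# HTML file into a single flat directory by replacing '/' with '-'.
path = "sagemaker.readthedocs.io/en/stable/index.html"
print(path.replace('/', '-'))  # sagemaker.readthedocs.io-en-stable-index.html
```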