├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app
│   ├── Dockerfile
│   ├── README.md
│   ├── app.py
│   ├── env_vars.sh
│   ├── images
│   │   ├── ai-icon.png
│   │   └── user-icon.png
│   ├── opensearch_chat_flan_xl.py
│   ├── opensearch_chat_llama2.py
│   ├── opensearch_load_qa_chain_flan_xl.py
│   ├── opensearch_load_qa_chain_llama2.py
│   ├── opensearch_retriever_flan_xl.py
│   ├── opensearch_retriever_llama2.py
│   ├── qa-with-llm-and-rag.png
│   └── requirements.txt
├── cdk_stacks
│   ├── README.md
│   ├── app.py
│   ├── cdk.context.json
│   ├── cdk.json
│   ├── rag_with_aos
│   │   ├── __init__.py
│   │   ├── ecs_streamlit_app.py
│   │   ├── ops.py
│   │   ├── sm_custom_embedding_endpoint.py
│   │   ├── sm_jumpstart_llm_endpoint.py
│   │   ├── sm_studio.py
│   │   └── vpc.py
│   ├── rag_with_opensearch_arch.svg
│   ├── requirements.txt
│   └── source.bat
└── data_ingestion_to_vectordb
    ├── container
    │   ├── Dockerfile
    │   ├── credentials.py
    │   ├── load_data_into_opensearch.py
    │   └── sm_helper.py
    ├── data_ingestion_to_opensearch.ipynb
    └── scripts
        └── get_data.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | Untitled*.ipynb
75 | 
76 | # pyenv
77 | .python-version
78 | 
79 | # celery beat schedule file
80 | celerybeat-schedule
81 | 
82 | # SageMath parsed files
83 | *.sage.py
84 | 
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 | 
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 | 
98 | # Rope project settings
99 | .ropeproject
100 | 
101 | # mkdocs documentation
102 | /site
103 | 
104 | # mypy
105 | .mypy_cache/
106 | 
107 | .DS_Store
108 | .idea/
109 | bin/
110 | lib64
111 | pyvenv.cfg
112 | *.bak
113 | share/
114 | cdk.out/
115 | cdk.context.json*
116 | zap/
117 | 
118 | */.gitignore
119 | */setup.py
120 | */source.bat
121 | 
122 | */*/.gitignore
123 | */*/setup.py
124 | */*/source.bat
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 | 
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 | 
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 | 
9 | 
10 | ## Reporting Bugs/Feature Requests
11 | 
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 | 
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 | 
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 | 
22 | 
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 | 
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 | 
30 | To send us a pull request, please:
31 | 
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 | 
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 | 
42 | 
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 | 
46 | 
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 | 
52 | 
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 | 
56 | 
57 | ## Licensing
58 | 
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 | 
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 | 
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 | 
18 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QA with LLM and RAG (Retrieval Augmented Generation)
2 | 
3 | This project is a Question Answering application built with Large Language Models (LLMs) and Amazon OpenSearch Service. An application using the RAG (Retrieval Augmented Generation) approach retrieves the information most relevant to the user's request from the enterprise knowledge base or content, bundles it as context along with the user's request in a prompt, and then sends it to the LLM to get a GenAI response.
4 | 
5 | LLMs have limitations on the maximum word count of the input prompt; therefore, choosing the right passages from among thousands or millions of documents in the enterprise has a direct impact on the LLM's accuracy.
6 | 
7 | In this project, Amazon OpenSearch Service is used as the knowledge base.
8 | 
9 | The overall architecture looks like this:
10 | 
11 | ![rag_with_opensearch_arch](./cdk_stacks/rag_with_opensearch_arch.svg)
12 | 
13 | ### Overall Workflow
14 | 
15 | 1. Deploy the cdk stacks (For more information, see [here](./cdk_stacks/README.md)). This provisions:
16 |    - A SageMaker Endpoint for text generation.
17 |    - A SageMaker Endpoint for generating embeddings.
18 |    - An Amazon OpenSearch cluster for storing embeddings.
19 |    - The OpenSearch cluster's access credentials (username and password), stored in AWS Secrets Manager under a name such as `OpenSearchMasterUserSecret1-xxxxxxxxxxxx`.
20 | 2. Open JupyterLab in SageMaker Studio and then open a new terminal.
21 | 3. Run the following commands on the terminal to clone the code repository for this project:
22 |    ```
23 |    git clone --depth=1 https://github.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker.git
24 |    ```
25 | 4. Open the `data_ingestion_to_opensearch` notebook and run it.
(For more information, see [here](./data_ingestion_to_vectordb/data_ingestion_to_opensearch.ipynb))
26 | 5. Run the Streamlit application. (For more information, see [here](./app/README.md))
27 | 
28 | ### References
29 | 
30 |  * [Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain (2023-05-25)](https://aws.amazon.com/blogs/machine-learning/build-a-powerful-question-answering-bot-with-amazon-sagemaker-amazon-opensearch-service-streamlit-and-langchain/)
31 |  * [Use proprietary foundation models from Amazon SageMaker JumpStart in Amazon SageMaker Studio (2023-06-27)](https://aws.amazon.com/blogs/machine-learning/use-proprietary-foundation-models-from-amazon-sagemaker-jumpstart-in-amazon-sagemaker-studio/)
32 |  * [Build Streamlit apps in Amazon SageMaker Studio (2023-04-11)](https://aws.amazon.com/blogs/machine-learning/build-streamlit-apps-in-amazon-sagemaker-studio/)
33 |  * [Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models (2023-05-03)](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)
34 |  * [(github) Amazon Kendra Retriever Samples](https://github.com/aws-samples/amazon-kendra-langchain-extensions/tree/main/kendra_retriever_samples)
35 |  * [Question answering using Retrieval Augmented Generation with foundation models in Amazon SageMaker JumpStart (2023-05-02)](https://aws.amazon.com/blogs/machine-learning/question-answering-using-retrieval-augmented-generation-with-foundation-models-in-amazon-sagemaker-jumpstart/)
36 |  * [Amazon OpenSearch Service's vector database capabilities explained](https://aws.amazon.com/blogs/big-data/amazon-opensearch-services-vector-database-capabilities-explained/)
37 |  * [LangChain](https://python.langchain.com/docs/get_started/introduction.html) - A framework for developing applications powered by language models.
38 |  * [Streamlit](https://streamlit.io/) - A faster way to build and share data apps
39 |  * [Improve search relevance with ML in Amazon OpenSearch Service Workshop](https://catalog.workshops.aws/semantic-search/en-US) - Module 7. Retrieval Augmented Generation
40 |  * [rag-with-amazon-kendra](https://github.com/ksmin23/rag-with-amazon-kendra) - Question Answering application with Large Language Models (LLMs) and Amazon Kendra
41 |  * [rag-with-postgresql-pgvector](https://github.com/ksmin23/rag-with-postgresql-pgvector) - Question Answering application with Large Language Models (LLMs) and Amazon Aurora PostgreSQL
42 | 
43 | ## Security
44 | 
45 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
46 | 
47 | ## License
48 | 
49 | This library is licensed under the MIT-0 License. See the LICENSE file.
50 | 
51 | 
--------------------------------------------------------------------------------
/app/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM --platform=linux/amd64 python:3.10.13-slim
2 | 
3 | WORKDIR /app
4 | 
5 | COPY requirements.txt .
6 | COPY *.py .
7 | COPY images/*.png images/
8 | 
9 | RUN pip --no-cache-dir install -Uq pip
10 | RUN pip --no-cache-dir install -Uq -r requirements.txt
11 | 
12 | # Set some environment variables. PYTHONUNBUFFERED keeps Python from buffering our standard
13 | # output stream, which means that logs can be delivered to the user quickly. PYTHONDONTWRITEBYTECODE
14 | # keeps Python from writing the .pyc files, which are unnecessary in this case.
15 | 
16 | ENV PYTHONUNBUFFERED=TRUE
17 | ENV PYTHONDONTWRITEBYTECODE=TRUE
18 | 
19 | EXPOSE 8501
20 | CMD streamlit run app.py --server.address=0.0.0.0
21 | # ENTRYPOINT ["streamlit", "run", "app.py", "--server.address=0.0.0.0"]
22 | 
--------------------------------------------------------------------------------
/app/README.md:
--------------------------------------------------------------------------------
1 | ## Run the Streamlit application in Studio
2 | 
3 | Now we're ready to run the Streamlit web application for our question answering bot.
4 | 
5 | SageMaker Studio provides a convenient platform to host the Streamlit web application. The following steps describe how to run the Streamlit app on SageMaker Studio. Alternatively, you can follow the same procedure to run the app on an Amazon EC2 instance or Cloud9 in your AWS account, or deploy it as a container service on Amazon ECS Fargate.
6 | 
7 | 1. Open JupyterLab and then open a **Terminal**.
8 | 2. Run the following commands on the terminal to clone the code repository for this post and install the Python packages needed by the application:
9 |    ```
10 |    git clone --depth=1 https://github.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker.git
11 |    cd rag-with-amazon-opensearch-and-sagemaker/app
12 |    python -m venv .env
13 |    source .env/bin/activate
14 |    pip install -r requirements.txt
15 |    ```
16 | 3. In the shell, set the following environment variables with the values that are available from the CloudFormation stack output.
17 |    ```
18 |    export AWS_REGION=us-east-1
19 |    export OPENSEARCH_SECRET="your-opensearch-secret"
20 |    export OPENSEARCH_DOMAIN_ENDPOINT="your-opensearch-url"
21 |    export OPENSEARCH_INDEX="llm_rag_embeddings"
22 |    export EMBEDDING_ENDPOINT_NAME="your-sagemaker-endpoint-for-embedding-model"
23 |    export TEXT2TEXT_ENDPOINT_NAME="your-sagemaker-endpoint-for-text-generation-model"
24 |    ```
25 | 4. Run the application with `streamlit run app.py`. When the application runs successfully, you'll see output similar to the following (the IP addresses you see will be different from the ones shown in this example). Note the port number (typically `8501`) from the output; you will use it as part of the URL for the app in the next step.
26 |    ```
27 |    sagemaker-user@studio$ streamlit run app.py
28 | 
29 |    Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.
30 | 
31 |    You can now view your Streamlit app in your browser.
32 | 
33 |    Network URL: http://169.255.255.2:8501
34 |    External URL: http://52.4.240.77:8501
35 |    ```
36 | 5. You can access the app in a new browser tab using a URL that is similar to your Studio domain URL. For example, if your Studio URL is `https://d-randomidentifier.studio.us-east-1.sagemaker.aws/jupyter/default/lab?` then the URL for your Streamlit app will be `https://d-randomidentifier.studio.us-east-1.sagemaker.aws/jupyter/default/proxy/8501/app` (notice that `lab` is replaced with `proxy/8501/app`). If the port number noted in the previous step is different from `8501`, use that port instead of `8501` in the Streamlit app URL.
37 | 
38 | The following screenshot shows the app with a couple of user questions.
(e.g., `What are the versions of XGBoost supported by Amazon SageMaker?`)
39 | 
40 | ![qa-with-llm-and-rag](./qa-with-llm-and-rag.png)
41 | 
42 | 
43 | ## Deploy the Streamlit application on Amazon ECS Fargate with AWS CDK
44 | 
45 | To deploy the Streamlit application on Amazon ECS Fargate using AWS CDK, follow these steps:
46 | 
47 | 1. Ensure you have the AWS CDK and Docker or Finch installed and configured.
48 | 2. Deploy the ECS stack from `cdk_stacks/` using the command `cdk deploy --require-approval never StreamlitAppStack`.
49 | 3. Access the Streamlit application using the public URL provided by the provisioned load balancer.
50 |     1. You can find this value under the export named `{stack-name}-StreamlitEndpoint`.
51 | 4. Consider adding a security group ingress rule that restricts access to the application to your network, using a prefix list or a CIDR block.
52 | 5. Also consider enabling HTTPS by attaching an SSL certificate from IAM or ACM to the load balancer and adding a listener on port 443.
53 | 
54 | 
55 | ## References
56 | 
57 |  * [Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain (2023-05-25)](https://aws.amazon.com/blogs/machine-learning/build-a-powerful-question-answering-bot-with-amazon-sagemaker-amazon-opensearch-service-streamlit-and-langchain/)
58 |  * [Build Streamlit apps in Amazon SageMaker Studio (2023-04-11)](https://aws.amazon.com/blogs/machine-learning/build-streamlit-apps-in-amazon-sagemaker-studio/)
59 |  * [Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models (2023-05-03)](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)
60 |  * [(github) Amazon Kendra Retriever Samples](https://github.com/aws-samples/amazon-kendra-langchain-extensions/tree/main/kendra_retriever_samples)
61 |  * [Use proprietary foundation models from Amazon SageMaker JumpStart in Amazon SageMaker Studio (2023-06-27)](https://aws.amazon.com/blogs/machine-learning/use-proprietary-foundation-models-from-amazon-sagemaker-jumpstart-in-amazon-sagemaker-studio/)
62 |  * [sagemaker-huggingface-inference-toolkit](https://github.com/aws/sagemaker-huggingface-inference-toolkit) - SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers and Diffusers models on Amazon SageMaker.
63 |  * [LangChain](https://python.langchain.com/docs/get_started/introduction.html) - A framework for developing applications powered by language models.
64 | * [Streamlit](https://streamlit.io/) - A faster way to build and share data apps 65 | -------------------------------------------------------------------------------- /app/app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab 4 | 5 | import os 6 | import streamlit as st 7 | import uuid 8 | 9 | import opensearch_chat_flan_xl as flanxl 10 | import opensearch_chat_llama2 as llama2 11 | 12 | PROVIDER_NAME = os.environ.get('PROVIDER_NAME', 'llama2') 13 | 14 | USER_ICON = f"{os.path.dirname(__file__)}/images/user-icon.png" 15 | AI_ICON = f"{os.path.dirname(__file__)}/images/ai-icon.png" 16 | MAX_HISTORY_LENGTH = 5 17 | PROVIDER_MAP = { 18 | 'flanxl': 'Flan XL', 19 | 'llama2': 'Llama2 7B', 20 | } 21 | 22 | # Check if the user ID is already stored in the session state 23 | if 'user_id' in st.session_state: 24 | user_id = st.session_state['user_id'] 25 | 26 | # If the user ID is not yet stored in the session state, generate a random UUID 27 | else: 28 | user_id = str(uuid.uuid4()) 29 | st.session_state['user_id'] = user_id 30 | 31 | 32 | if 'llm_chain' not in st.session_state: 33 | llm_app = llama2 if PROVIDER_NAME == 'llama2' else flanxl 34 | st.session_state['llm_app'] = llm_app 35 | st.session_state['llm_chain'] = llm_app.build_chain() 36 | 37 | if 'chat_history' not in st.session_state: 38 | st.session_state['chat_history'] = [] 39 | 40 | if "chats" not in st.session_state: 41 | st.session_state.chats = [ 42 | { 43 | 'id': 0, 44 | 'question': '', 45 | 'answer': '' 46 | } 47 | ] 48 | 49 | if "questions" not in st.session_state: 50 | st.session_state.questions = [] 51 | 52 | if "answers" not in st.session_state: 53 | st.session_state.answers = [] 54 | 55 | if "input" not in st.session_state: 56 | st.session_state.input = "" 57 | 58 | 59 | st.markdown(""" 60 | 75 | """, unsafe_allow_html=True) 76 | 77 | 78 | def write_logo(): 79 | col1, col2, col3 = st.columns([5, 1, 5]) 80 | with col2: 81 | st.image(AI_ICON, use_column_width='always') 82 | 83 | 84 | def write_top_bar(): 85 | col1, col2, col3 = st.columns([1,10,2]) 86 | with col1: 87 | st.image(AI_ICON, use_column_width='always') 88 | with col2: 89 | selected_provider = PROVIDER_NAME 90 | if selected_provider in PROVIDER_MAP: 91 | provider = PROVIDER_MAP[selected_provider] 92 | else: 93 | provider = selected_provider.capitalize() 94 | header = f"An AI App powered by Amazon OpenSearch and {provider}!" 95 | st.write(f"

<h3>{header}</h3>

", unsafe_allow_html=True) 96 | with col3: 97 | clear = st.button("Clear Chat") 98 | return clear 99 | 100 | 101 | clear = write_top_bar() 102 | 103 | if clear: 104 | st.session_state.questions = [] 105 | st.session_state.answers = [] 106 | st.session_state.input = "" 107 | st.session_state["chat_history"] = [] 108 | 109 | 110 | def handle_input(): 111 | input = st.session_state.input 112 | question_with_id = { 113 | 'question': input, 114 | 'id': len(st.session_state.questions) 115 | } 116 | st.session_state.questions.append(question_with_id) 117 | 118 | chat_history = st.session_state["chat_history"] 119 | if len(chat_history) == MAX_HISTORY_LENGTH: 120 | chat_history = chat_history[:-1] 121 | 122 | llm_chain = st.session_state['llm_chain'] 123 | chain = st.session_state['llm_app'] 124 | with st.spinner(): 125 | result = chain.run_chain(llm_chain, input, chat_history) 126 | answer = result['answer'] 127 | chat_history.append((input, answer)) 128 | 129 | document_list = [] 130 | if 'source_documents' in result: 131 | for d in result['source_documents']: 132 | if not (d.metadata['source'] in document_list): 133 | document_list.append((d.metadata['source'])) 134 | 135 | st.session_state.answers.append({ 136 | 'answer': result, 137 | 'sources': document_list, 138 | 'id': len(st.session_state.questions) 139 | }) 140 | st.session_state.input = "" 141 | 142 | 143 | def write_user_message(md): 144 | col1, col2 = st.columns([1,12]) 145 | 146 | with col1: 147 | st.image(USER_ICON, use_column_width='always') 148 | with col2: 149 | st.warning(md['question']) 150 | 151 | 152 | def render_result(result): 153 | answer, sources = st.tabs(['Answer', 'Sources']) 154 | with answer: 155 | render_answer(result['answer']) 156 | with sources: 157 | if 'source_documents' in result: 158 | render_sources(result['source_documents']) 159 | else: 160 | render_sources([]) 161 | 162 | 163 | def render_answer(answer): 164 | col1, col2 = st.columns([1,12]) 165 | with col1: 166 | st.image(AI_ICON, use_column_width='always') 167 | with col2: 168 | st.info(answer['answer']) 169 | 170 | 171 | def render_sources(sources): 172 | col1, col2 = st.columns([1,12]) 173 | with col2: 174 | with st.expander("Sources"): 175 | for s in sources: 176 | st.write(s) 177 | 178 | 179 | # Each answer will have context of the question asked in order to associate the provided feedback with the respective question 180 | def write_chat_message(md, q): 181 | chat = st.container() 182 | with chat: 183 | render_answer(md['answer']) 184 | render_sources(md['sources']) 185 | 186 | 187 | with st.container(): 188 | for (q, a) in zip(st.session_state.questions, st.session_state.answers): 189 | write_user_message(q) 190 | write_chat_message(a, q) 191 | 192 | st.markdown('---') 193 | input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input) 194 | -------------------------------------------------------------------------------- /app/env_vars.sh: -------------------------------------------------------------------------------- 1 | export AWS_REGION="your-aws-region" 2 | export OPENSEARCH_SECRET="your-opensearch-secret" 3 | export OPENSEARCH_DOMAIN_ENDPOINT="your-opensearch-url" 4 | export OPENSEARCH_INDEX="llm_rag_embeddings" 5 | export EMBEDDING_ENDPOINT_NAME="your-sagemaker-endpoint-for-embedding-model" 6 | export TEXT2TEXT_ENDPOINT_NAME="your-sagemaker-endpoint-for-text-generation-model" -------------------------------------------------------------------------------- /app/images/ai-icon.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker/5d43e78f0af9a69c0306b4f1181f3cd4971b78f9/app/images/ai-icon.png -------------------------------------------------------------------------------- /app/images/user-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker/5d43e78f0af9a69c0306b4f1181f3cd4971b78f9/app/images/user-icon.png -------------------------------------------------------------------------------- /app/opensearch_chat_flan_xl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab 4 | 5 | import os 6 | import json 7 | import logging 8 | import sys 9 | from typing import List, Callable 10 | from urllib.parse import urlparse 11 | 12 | import boto3 13 | 14 | from langchain_community.vectorstores import OpenSearchVectorSearch 15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings 16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler 17 | 18 | from langchain_community.llms import SagemakerEndpoint 19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler 20 | 21 | from langchain.prompts import PromptTemplate 22 | from langchain.chains import ConversationalRetrievalChain 23 | 24 | logger = logging.getLogger() 25 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr) 26 | 27 | 28 | class bcolors: 29 | HEADER = '\033[95m' 30 | OKBLUE = '\033[94m' 31 | OKCYAN = '\033[96m' 32 | OKGREEN = '\033[92m' 33 | WARNING = '\033[93m' 34 | FAIL = '\033[91m' 35 | ENDC = '\033[0m' 36 | BOLD = '\033[1m' 37 | UNDERLINE = '\033[4m' 38 | 39 | 40 | MAX_HISTORY_LENGTH = 5 41 | 42 | 43 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings): 44 | def embed_documents( 45 | self, texts: List[str], chunk_size: int = 5 46 | ) -> List[List[float]]: 47 | """Compute doc embeddings using a SageMaker Inference Endpoint. 48 | 49 | Args: 50 | texts: The list of texts to embed. 51 | chunk_size: The chunk size defines how many input texts will 52 | be grouped together as request. If None, will use the 53 | chunk size specified by the class. 54 | 55 | Returns: 56 | List of embeddings, one for each text. 
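
        Example:
            An illustrative sketch (assumes this object was created with a
            valid SageMaker embeddings endpoint, e.g. via
            _create_sagemaker_embeddings() below)::

                # both texts fit into one batch, so a single endpoint request is made
                vectors = embeddings.embed_documents(
                    ["What is SageMaker?", "What is OpenSearch?"])
                assert len(vectors) == 2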
57 | """ 58 | results = [] 59 | 60 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size 61 | for i in range(0, len(texts), _chunk_size): 62 | response = self._embedding_func(texts[i : i + _chunk_size]) 63 | results.extend(response) 64 | return results 65 | 66 | 67 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart: 68 | 69 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler): 70 | """ 71 | encode input string as utf-8 bytes, read the embeddings 72 | from the output 73 | """ 74 | content_type = "application/json" 75 | accepts = "application/json" 76 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes: 77 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) 78 | return input_str.encode('utf-8') 79 | 80 | def transform_output(self, output: bytes) -> str: 81 | response_json = json.loads(output.read().decode("utf-8")) 82 | embeddings = response_json["embedding"] 83 | if len(embeddings) == 1: 84 | return [embeddings[0]] 85 | return embeddings 86 | 87 | # create a content handler object which knows how to serialize 88 | # and deserialize communication with the model endpoint 89 | content_handler = ContentHandlerForEmbeddings() 90 | 91 | # read to create the Sagemaker embeddings, we are providing 92 | # the Sagemaker endpoint that will be used for generating the 93 | # embeddings to the class 94 | embeddings = SagemakerEndpointEmbeddingsJumpStart( 95 | endpoint_name=endpoint_name, 96 | region_name=region, 97 | content_handler=content_handler 98 | ) 99 | logger.info(f"embeddings type={type(embeddings)}") 100 | 101 | return embeddings 102 | 103 | 104 | def _get_credentials(secret_id: str, region_name: str) -> str: 105 | client = boto3.client('secretsmanager', region_name=region_name) 106 | response = client.get_secret_value(SecretId=secret_id) 107 | secrets_value = json.loads(response['SecretString']) 108 | return secrets_value 109 | 110 | 111 | def build_chain(): 112 | region = os.environ["AWS_REGION"] 113 | opensearch_secret = os.environ["OPENSEARCH_SECRET"] 114 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"] 115 | opensearch_index = os.environ["OPENSEARCH_INDEX"] 116 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"] 117 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"] 118 | 119 | class ContentHandler(LLMContentHandler): 120 | content_type = "application/json" 121 | accepts = "application/json" 122 | 123 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: 124 | input_str = json.dumps({"inputs": prompt, **model_kwargs}) 125 | return input_str.encode('utf-8') 126 | 127 | def transform_output(self, output: bytes) -> str: 128 | response_json = json.loads(output.read().decode("utf-8")) 129 | return response_json[0]["generated_texts"] 130 | 131 | content_handler = ContentHandler() 132 | 133 | model_kwargs = { 134 | "max_length": 500, 135 | "num_return_sequences": 1, 136 | "top_k": 250, 137 | "top_p": 0.95, 138 | "do_sample": False, 139 | "temperature": 1 140 | } 141 | 142 | llm = SagemakerEndpoint( 143 | endpoint_name=text2text_model_endpoint, 144 | region_name=region, 145 | model_kwargs=model_kwargs, 146 | content_handler=content_handler 147 | ) 148 | 149 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint 150 | 151 | creds = _get_credentials(opensearch_secret, region) 152 | http_auth = (creds['username'], 
creds['password']) 153 | 154 | opensearch_vector_search = OpenSearchVectorSearch( 155 | opensearch_url=opensearch_url, 156 | index_name=opensearch_index, 157 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region), 158 | http_auth=http_auth 159 | ) 160 | 161 | retriever = opensearch_vector_search.as_retriever(search_kwargs={"k": 3}) 162 | 163 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}""" 164 | 165 | PROMPT = PromptTemplate( 166 | template=prompt_template, input_variables=["context", "question"] 167 | ) 168 | 169 | condense_qa_template = """ 170 | Given the following conversation and a follow up question, rephrase the follow up question 171 | to be a standalone question. 172 | 173 | Chat History: 174 | {chat_history} 175 | Follow Up Input: {question} 176 | Standalone question:""" 177 | standalone_question_prompt = PromptTemplate.from_template(condense_qa_template) 178 | 179 | qa = ConversationalRetrievalChain.from_llm( 180 | llm=llm, 181 | retriever=retriever, 182 | condense_question_prompt=standalone_question_prompt, 183 | return_source_documents=True, 184 | combine_docs_chain_kwargs={"prompt":PROMPT} 185 | ) 186 | 187 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n") 188 | return qa 189 | 190 | 191 | def run_chain(chain, prompt: str, history=[]): 192 | return chain.invoke({"question": prompt, "chat_history": history}) 193 | 194 | 195 | if __name__ == "__main__": 196 | chat_history = [] 197 | qa = build_chain() 198 | print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC) 199 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) 200 | print(">", end=" ", flush=True) 201 | for query in sys.stdin: 202 | if (query.strip().lower().startswith("new search:")): 203 | query = query.strip().lower().replace("new search:","") 204 | chat_history = [] 205 | elif (len(chat_history) == MAX_HISTORY_LENGTH): 206 | chat_history.pop(0) 207 | result = run_chain(qa, query, chat_history) 208 | chat_history.append((query, result["answer"])) 209 | print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC) 210 | if 'source_documents' in result: 211 | print(bcolors.OKGREEN + '\nSources:') 212 | for d in result['source_documents']: 213 | print(d.metadata['source']) 214 | print(bcolors.ENDC) 215 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." 
+ bcolors.ENDC) 216 | print(">", end=" ", flush=True) 217 | print(bcolors.OKBLUE + "Bye" + bcolors.ENDC) -------------------------------------------------------------------------------- /app/opensearch_chat_llama2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab 4 | 5 | import os 6 | import json 7 | import logging 8 | import sys 9 | from typing import List 10 | from urllib.parse import urlparse 11 | 12 | import boto3 13 | 14 | from langchain_community.vectorstores import OpenSearchVectorSearch 15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings 16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler 17 | 18 | from langchain_community.llms import SagemakerEndpoint 19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler 20 | 21 | from langchain.prompts import PromptTemplate 22 | from langchain.chains import ConversationalRetrievalChain 23 | 24 | logger = logging.getLogger() 25 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr) 26 | 27 | 28 | class bcolors: 29 | HEADER = '\033[95m' 30 | OKBLUE = '\033[94m' 31 | OKCYAN = '\033[96m' 32 | OKGREEN = '\033[92m' 33 | WARNING = '\033[93m' 34 | FAIL = '\033[91m' 35 | ENDC = '\033[0m' 36 | BOLD = '\033[1m' 37 | UNDERLINE = '\033[4m' 38 | 39 | 40 | MAX_HISTORY_LENGTH = 5 41 | 42 | 43 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings): 44 | def embed_documents( 45 | self, texts: List[str], chunk_size: int = 5 46 | ) -> List[List[float]]: 47 | """Compute doc embeddings using a SageMaker Inference Endpoint. 48 | 49 | Args: 50 | texts: The list of texts to embed. 51 | chunk_size: The chunk size defines how many input texts will 52 | be grouped together as request. If None, will use the 53 | chunk size specified by the class. 54 | 55 | Returns: 56 | List of embeddings, one for each text. 
57 | """ 58 | results = [] 59 | 60 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size 61 | for i in range(0, len(texts), _chunk_size): 62 | response = self._embedding_func(texts[i : i + _chunk_size]) 63 | results.extend(response) 64 | return results 65 | 66 | 67 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart: 68 | 69 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler): 70 | """ 71 | encode input string as utf-8 bytes, read the embeddings 72 | from the output 73 | """ 74 | content_type = "application/json" 75 | accepts = "application/json" 76 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes: 77 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) 78 | return input_str.encode('utf-8') 79 | 80 | def transform_output(self, output: bytes) -> str: 81 | response_json = json.loads(output.read().decode("utf-8")) 82 | embeddings = response_json["embedding"] 83 | if len(embeddings) == 1: 84 | return [embeddings[0]] 85 | return embeddings 86 | 87 | # create a content handler object which knows how to serialize 88 | # and deserialize communication with the model endpoint 89 | content_handler = ContentHandlerForEmbeddings() 90 | 91 | # read to create the Sagemaker embeddings, we are providing 92 | # the Sagemaker endpoint that will be used for generating the 93 | # embeddings to the class 94 | embeddings = SagemakerEndpointEmbeddingsJumpStart( 95 | endpoint_name=endpoint_name, 96 | region_name=region, 97 | content_handler=content_handler 98 | ) 99 | logger.info(f"embeddings type={type(embeddings)}") 100 | 101 | return embeddings 102 | 103 | 104 | def _get_credentials(secret_id: str, region_name: str) -> str: 105 | client = boto3.client('secretsmanager', region_name=region_name) 106 | response = client.get_secret_value(SecretId=secret_id) 107 | secrets_value = json.loads(response['SecretString']) 108 | return secrets_value 109 | 110 | 111 | def build_chain(): 112 | region = os.environ["AWS_REGION"] 113 | opensearch_secret = os.environ["OPENSEARCH_SECRET"] 114 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"] 115 | opensearch_index = os.environ["OPENSEARCH_INDEX"] 116 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"] 117 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"] 118 | 119 | # https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-2-chat-completion.ipynb 120 | class ContentHandler(LLMContentHandler): 121 | content_type = "application/json" 122 | accepts = "application/json" 123 | 124 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: 125 | system_prompt = "You are a helpful assistant. Always answer to questions as helpfully as possible." 
\ 126 | " If you don't know the answer to a question, say I don't know the answer" 127 | 128 | payload = { 129 | "inputs": [ 130 | [ 131 | {"role": "system", "content": system_prompt}, 132 | {"role": "user", "content": prompt}, 133 | ], 134 | ], 135 | "parameters": model_kwargs, 136 | } 137 | input_str = json.dumps(payload) 138 | return input_str.encode("utf-8") 139 | 140 | def transform_output(self, output: bytes) -> str: 141 | response_json = json.loads(output.read().decode("utf-8")) 142 | content = response_json[0]["generation"]["content"] 143 | return content 144 | 145 | content_handler = ContentHandler() 146 | 147 | # https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-2-text-completion.ipynb 148 | model_kwargs = { 149 | "max_new_tokens": 256, 150 | "top_p": 0.9, 151 | "temperature": 0.6, 152 | "return_full_text": False, 153 | } 154 | 155 | llm = SagemakerEndpoint( 156 | endpoint_name=text2text_model_endpoint, 157 | region_name=region, 158 | model_kwargs=model_kwargs, 159 | endpoint_kwargs={"CustomAttributes": "accept_eula=true"}, 160 | content_handler=content_handler 161 | ) 162 | 163 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint 164 | 165 | creds = _get_credentials(opensearch_secret, region) 166 | http_auth = (creds['username'], creds['password']) 167 | 168 | opensearch_vector_search = OpenSearchVectorSearch( 169 | opensearch_url=opensearch_url, 170 | index_name=opensearch_index, 171 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region), 172 | http_auth=http_auth 173 | ) 174 | 175 | retriever = opensearch_vector_search.as_retriever(search_kwargs={"k": 3}) 176 | 177 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}""" 178 | 179 | PROMPT = PromptTemplate( 180 | template=prompt_template, input_variables=["context", "question"] 181 | ) 182 | 183 | condense_qa_template = """ 184 | Given the following conversation and a follow up question, rephrase the follow up question 185 | to be a standalone question. 186 | 187 | Chat History: 188 | {chat_history} 189 | Follow Up Input: {question} 190 | Standalone question:""" 191 | standalone_question_prompt = PromptTemplate.from_template(condense_qa_template) 192 | 193 | qa = ConversationalRetrievalChain.from_llm( 194 | llm=llm, 195 | retriever=retriever, 196 | condense_question_prompt=standalone_question_prompt, 197 | return_source_documents=True, 198 | combine_docs_chain_kwargs={"prompt":PROMPT}, 199 | verbose=False 200 | ) 201 | 202 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n") 203 | return qa 204 | 205 | 206 | def run_chain(chain, prompt: str, history=[]): 207 | return chain.invoke({"question": prompt, "chat_history": history}) 208 | 209 | 210 | if __name__ == "__main__": 211 | chat_history = [] 212 | qa = build_chain() 213 | print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC) 214 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." 
+ bcolors.ENDC) 215 | print(">", end=" ", flush=True) 216 | for query in sys.stdin: 217 | if (query.strip().lower().startswith("new search:")): 218 | query = query.strip().lower().replace("new search:","") 219 | chat_history = [] 220 | elif (len(chat_history) == MAX_HISTORY_LENGTH): 221 | chat_history.pop(0) 222 | result = run_chain(qa, query, chat_history) 223 | chat_history.append((query, result["answer"])) 224 | print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC) 225 | if 'source_documents' in result: 226 | print(bcolors.OKGREEN + '\nSources:') 227 | for d in result['source_documents']: 228 | print(d.metadata['source']) 229 | print(bcolors.ENDC) 230 | print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) 231 | print(">", end=" ", flush=True) 232 | print(bcolors.OKBLUE + "Bye" + bcolors.ENDC) -------------------------------------------------------------------------------- /app/opensearch_load_qa_chain_flan_xl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab 4 | 5 | import os 6 | import json 7 | import logging 8 | import sys 9 | from typing import List, Callable 10 | from urllib.parse import urlparse 11 | 12 | import boto3 13 | 14 | from langchain_community.vectorstores import OpenSearchVectorSearch 15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings 16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler 17 | 18 | from langchain_community.llms import SagemakerEndpoint 19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler 20 | 21 | from langchain.prompts import PromptTemplate 22 | from langchain.chains.question_answering import load_qa_chain 23 | 24 | 25 | logger = logging.getLogger() 26 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr) 27 | 28 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings): 29 | def embed_documents( 30 | self, texts: List[str], chunk_size: int = 5 31 | ) -> List[List[float]]: 32 | """Compute doc embeddings using a SageMaker Inference Endpoint. 33 | 34 | Args: 35 | texts: The list of texts to embed. 36 | chunk_size: The chunk size defines how many input texts will 37 | be grouped together as request. If None, will use the 38 | chunk size specified by the class. 39 | 40 | Returns: 41 | List of embeddings, one for each text. 
42 | """ 43 | results = [] 44 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size 45 | 46 | for i in range(0, len(texts), _chunk_size): 47 | response = self._embedding_func(texts[i : i + _chunk_size]) 48 | results.extend(response) 49 | return results 50 | 51 | 52 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler): 53 | """ 54 | encode input string as utf-8 bytes, read the embeddings 55 | from the output 56 | """ 57 | content_type = "application/json" 58 | accepts = "application/json" 59 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes: 60 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) 61 | return input_str.encode('utf-8') 62 | 63 | def transform_output(self, output: bytes) -> str: 64 | response_json = json.loads(output.read().decode("utf-8")) 65 | embeddings = response_json["embedding"] 66 | if len(embeddings) == 1: 67 | return [embeddings[0]] 68 | return embeddings 69 | 70 | 71 | class ContentHandlerForTextGeneration(LLMContentHandler): 72 | content_type = "application/json" 73 | accepts = "application/json" 74 | 75 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes: 76 | input_str = json.dumps({"inputs": prompt, **model_kwargs}) 77 | return input_str.encode('utf-8') 78 | 79 | def transform_output(self, output: bytes) -> str: 80 | response_json = json.loads(output.read().decode("utf-8")) 81 | return response_json[0]["generated_texts"] 82 | 83 | 84 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart: 85 | # create a content handler object which knows how to serialize 86 | # and deserialize communication with the model endpoint 87 | content_handler = ContentHandlerForEmbeddings() 88 | 89 | # read to create the Sagemaker embeddings, we are providing 90 | # the Sagemaker endpoint that will be used for generating the 91 | # embeddings to the class 92 | embeddings = SagemakerEndpointEmbeddingsJumpStart( 93 | endpoint_name=endpoint_name, 94 | region_name=region, 95 | content_handler=content_handler 96 | ) 97 | logger.info(f"embeddings type={type(embeddings)}") 98 | 99 | return embeddings 100 | 101 | 102 | def _get_credentials(secret_id: str, region_name: str) -> str: 103 | 104 | client = boto3.client('secretsmanager', region_name=region_name) 105 | response = client.get_secret_value(SecretId=secret_id) 106 | secrets_value = json.loads(response['SecretString']) 107 | return secrets_value 108 | 109 | 110 | def load_vector_db_opensearch(secret_id: str, 111 | region: str, 112 | opensearch_domain_endpoint: str, 113 | opensearch_index: str, 114 | embeddings_model_endpoint: str) -> OpenSearchVectorSearch: 115 | logger.info(f"load_vector_db_opensearch, secret_id={secret_id}, region={region}, " 116 | f"opensearch_domain_endpoint={opensearch_domain_endpoint}, opensearch_index={opensearch_index}, " 117 | f"embeddings_model_endpoint={embeddings_model_endpoint}") 118 | 119 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint 120 | logger.info(f"embeddings_model_endpoint={embeddings_model_endpoint}, opensearch_url={opensearch_url}") 121 | 122 | creds = _get_credentials(secret_id, region) 123 | http_auth = (creds['username'], creds['password']) 124 | vector_db = OpenSearchVectorSearch(index_name=opensearch_index, 125 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region), 126 | opensearch_url=opensearch_url, 127 | http_auth=http_auth) 
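
    # Illustrative usage of the handle returned below (see main() for the real
    # call): similarity_search() embeds the query text via the SageMaker
    # embeddings endpoint and runs a k-NN search against the OpenSearch index,
    # e.g. docs = vector_db.similarity_search("What is SageMaker?", k=3)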
128 | 129 | logger.info(f"returning handle to OpenSearchVectorSearch, vector_db={vector_db}") 130 | return vector_db 131 | 132 | 133 | def setup_sagemaker_endpoint_for_text_generation(endpoint_name, region: str = "us-east-1") -> Callable: 134 | parameters = { 135 | "max_length": 500, 136 | "num_return_sequences": 1, 137 | "top_k": 250, 138 | "top_p": 0.95, 139 | "do_sample": False, 140 | "temperature": 1 141 | } 142 | 143 | content_handler = ContentHandlerForTextGeneration() 144 | sm_llm = SagemakerEndpoint( 145 | endpoint_name=endpoint_name, 146 | region_name=region, 147 | model_kwargs=parameters, 148 | content_handler=content_handler) 149 | return sm_llm 150 | 151 | 152 | def main(): 153 | region = os.environ["AWS_REGION"] 154 | opensearch_secret = os.environ["OPENSEARCH_SECRET"] 155 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"] 156 | opensearch_index = os.environ["OPENSEARCH_INDEX"] 157 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"] 158 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"] 159 | 160 | # initialize vector db and Sagemaker Endpoint 161 | os_creds_secretid_in_secrets_manager = opensearch_secret 162 | _vector_db = load_vector_db_opensearch(os_creds_secretid_in_secrets_manager, 163 | region, 164 | opensearch_domain_endpoint, 165 | opensearch_index, 166 | embeddings_model_endpoint) 167 | 168 | _sm_llm = setup_sagemaker_endpoint_for_text_generation(text2text_model_endpoint, region) 169 | 170 | # Use the vector db to find similar documents to the query 171 | # the vector db call would automatically convert the query text 172 | # into embeddings 173 | query = 'What is SageMaker model monitor? Write your answer in a nicely formatted way.' 174 | max_matching_docs = 3 175 | docs = _vector_db.similarity_search(query, k=max_matching_docs) 176 | logger.info(f"here are the {max_matching_docs} closest matching docs to the query=\"{query}\"") 177 | for d in docs: 178 | logger.info(f"---------") 179 | logger.info(d) 180 | logger.info(f"---------") 181 | 182 | # now that we have the matching docs, lets pack them as a context 183 | # into the prompt and ask the LLM to generate a response 184 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}""" 185 | 186 | prompt = PromptTemplate( 187 | template=prompt_template, input_variables=["context", "question"] 188 | ) 189 | logger.info(f"prompt sent to llm = \"{prompt}\"") 190 | 191 | chain = load_qa_chain(llm=_sm_llm, prompt=prompt, verbose=True) 192 | logger.info(f"\ntype('chain'): \"{type(chain)}\"\n") 193 | 194 | answer = chain.invoke({"input_documents": docs, "question": query}, return_only_outputs=True)['output_text'] 195 | 196 | logger.info(f"answer received from llm,\nquestion: \"{query}\"\nanswer: \"{answer}\"") 197 | 198 | resp = {'question': query, 'answer': answer} 199 | resp['docs'] = docs 200 | print(resp) 201 | 202 | 203 | if __name__ == "__main__": 204 | main() 205 | 206 | -------------------------------------------------------------------------------- /app/opensearch_load_qa_chain_llama2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab 4 | 5 | import os 6 | import json 7 | import logging 8 | import sys 9 | from typing import List, Callable 10 | from urllib.parse import urlparse 11 | 12 | import boto3 13 | 14 | from langchain_community.vectorstores import OpenSearchVectorSearch 15 | from 
langchain_community.embeddings import SagemakerEndpointEmbeddings 16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler 17 | 18 | from langchain_community.llms import SagemakerEndpoint 19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler 20 | 21 | from langchain.prompts import PromptTemplate 22 | from langchain.chains.question_answering import load_qa_chain 23 | 24 | 25 | logger = logging.getLogger() 26 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr) 27 | 28 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings): 29 | def embed_documents( 30 | self, texts: List[str], chunk_size: int = 5 31 | ) -> List[List[float]]: 32 | """Compute doc embeddings using a SageMaker Inference Endpoint. 33 | 34 | Args: 35 | texts: The list of texts to embed. 36 | chunk_size: The chunk size defines how many input texts will 37 | be grouped together as request. If None, will use the 38 | chunk size specified by the class. 39 | 40 | Returns: 41 | List of embeddings, one for each text. 42 | """ 43 | results = [] 44 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size 45 | for i in range(0, len(texts), _chunk_size): 46 | response = self._embedding_func(texts[i : i + _chunk_size]) 47 | #print(response) 48 | results.extend(response) 49 | return results 50 | 51 | 52 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler): 53 | """ 54 | encode input string as utf-8 bytes, read the embeddings 55 | from the output 56 | """ 57 | content_type = "application/json" 58 | accepts = "application/json" 59 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes: 60 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) 61 | return input_str.encode('utf-8') 62 | 63 | def transform_output(self, output: bytes) -> str: 64 | response_json = json.loads(output.read().decode("utf-8")) 65 | embeddings = response_json["embedding"] 66 | if len(embeddings) == 1: 67 | return [embeddings[0]] 68 | return embeddings 69 | 70 | 71 | class ContentHandlerForTextGeneration(LLMContentHandler): 72 | content_type = "application/json" 73 | accepts = "application/json" 74 | 75 | # https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-2-chat-completion.ipynb 76 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: 77 | system_prompt = "You are a helpful assistant. Always answer to questions as helpfully as possible." 
\ 78 | " If you don't know the answer to a question, say I don't know the answer" 79 | 80 | payload = { 81 | "inputs": [ 82 | [ 83 | {"role": "system", "content": system_prompt}, 84 | {"role": "user", "content": prompt}, 85 | ], 86 | ], 87 | "parameters": model_kwargs, 88 | } 89 | input_str = json.dumps(payload) 90 | return input_str.encode("utf-8") 91 | 92 | def transform_output(self, output: bytes) -> str: 93 | response_json = json.loads(output.read().decode("utf-8")) 94 | content = response_json[0]["generation"]["content"] 95 | return content 96 | 97 | 98 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart: 99 | # create a content handler object which knows how to serialize 100 | # and deserialize communication with the model endpoint 101 | content_handler = ContentHandlerForEmbeddings() 102 | 103 | # read to create the Sagemaker embeddings, we are providing 104 | # the Sagemaker endpoint that will be used for generating the 105 | # embeddings to the class 106 | embeddings = SagemakerEndpointEmbeddingsJumpStart( 107 | endpoint_name=endpoint_name, 108 | region_name=region, 109 | content_handler=content_handler 110 | ) 111 | logger.info(f"embeddings type={type(embeddings)}") 112 | 113 | return embeddings 114 | 115 | 116 | def _get_credentials(secret_id: str, region_name: str) -> str: 117 | 118 | client = boto3.client('secretsmanager', region_name=region_name) 119 | response = client.get_secret_value(SecretId=secret_id) 120 | secrets_value = json.loads(response['SecretString']) 121 | return secrets_value 122 | 123 | 124 | def load_vector_db_opensearch(secret_id: str, 125 | region: str, 126 | opensearch_domain_endpoint: str, 127 | opensearch_index: str, 128 | embeddings_model_endpoint: str) -> OpenSearchVectorSearch: 129 | logger.info(f"load_vector_db_opensearch, secret_id={secret_id}, region={region}, " 130 | f"opensearch_domain_endpoint={opensearch_domain_endpoint}, opensearch_index={opensearch_index}, " 131 | f"embeddings_model_endpoint={embeddings_model_endpoint}") 132 | 133 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint 134 | logger.info(f"embeddings_model_endpoint={embeddings_model_endpoint}, opensearch_url={opensearch_url}") 135 | 136 | creds = _get_credentials(secret_id, region) 137 | http_auth = (creds['username'], creds['password']) 138 | vector_db = OpenSearchVectorSearch(index_name=opensearch_index, 139 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region), 140 | opensearch_url=opensearch_url, 141 | http_auth=http_auth) 142 | 143 | logger.info(f"returning handle to OpenSearchVectorSearch, vector_db={vector_db}") 144 | return vector_db 145 | 146 | 147 | def setup_sagemaker_endpoint_for_text_generation(endpoint_name, region: str = "us-east-1") -> Callable: 148 | parameters = { 149 | "max_new_tokens": 256, 150 | "top_p": 0.9, 151 | "temperature": 0.6, 152 | "return_full_text": False, 153 | } 154 | 155 | content_handler = ContentHandlerForTextGeneration() 156 | sm_llm = SagemakerEndpoint( 157 | endpoint_name=endpoint_name, 158 | region_name=region, 159 | model_kwargs=parameters, 160 | endpoint_kwargs={"CustomAttributes": "accept_eula=true"}, 161 | content_handler=content_handler) 162 | return sm_llm 163 | 164 | 165 | def main(): 166 | region = os.environ["AWS_REGION"] 167 | opensearch_secret = os.environ["OPENSEARCH_SECRET"] 168 | opensearch_domain_endpoint = 
os.environ["OPENSEARCH_DOMAIN_ENDPOINT"] 169 | opensearch_index = os.environ["OPENSEARCH_INDEX"] 170 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"] 171 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"] 172 | 173 | # initialize vector db and Sagemaker Endpoint 174 | os_creds_secretid_in_secrets_manager = opensearch_secret 175 | _vector_db = load_vector_db_opensearch(os_creds_secretid_in_secrets_manager, 176 | region, 177 | opensearch_domain_endpoint, 178 | opensearch_index, 179 | embeddings_model_endpoint) 180 | 181 | _sm_llm = setup_sagemaker_endpoint_for_text_generation(text2text_model_endpoint, region) 182 | 183 | # Use the vector db to find similar documents to the query 184 | # the vector db call would automatically convert the query text 185 | # into embeddings 186 | query = 'What is SageMaker model monitor? Write your answer in a nicely formatted way.' 187 | max_matching_docs = 3 188 | docs = _vector_db.similarity_search(query, k=max_matching_docs) 189 | logger.info(f"here are the {max_matching_docs} closest matching docs to the query=\"{query}\"") 190 | for d in docs: 191 | logger.info(f"---------") 192 | logger.info(d) 193 | logger.info(f"---------") 194 | 195 | # now that we have the matching docs, lets pack them as a context 196 | # into the prompt and ask the LLM to generate a response 197 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}""" 198 | 199 | prompt = PromptTemplate( 200 | template=prompt_template, input_variables=["context", "question"] 201 | ) 202 | logger.info(f"prompt sent to llm = \"{prompt}\"") 203 | 204 | chain = load_qa_chain(llm=_sm_llm, prompt=prompt, verbose=True) 205 | logger.info(f"\ntype('chain'): \"{type(chain)}\"\n") 206 | 207 | answer = chain.invoke({"input_documents": docs, "question": query}, return_only_outputs=True)['output_text'] 208 | 209 | logger.info(f"answer received from llm,\nquestion: \"{query}\"\nanswer: \"{answer}\"") 210 | 211 | resp = {'question': query, 'answer': answer} 212 | resp['docs'] = docs 213 | print(resp) 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | 219 | -------------------------------------------------------------------------------- /app/opensearch_retriever_flan_xl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab 4 | 5 | import os 6 | import json 7 | import logging 8 | import sys 9 | from typing import List 10 | from urllib.parse import urlparse 11 | 12 | import boto3 13 | 14 | from langchain_community.vectorstores import OpenSearchVectorSearch 15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings 16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler 17 | 18 | from langchain_community.llms import SagemakerEndpoint 19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler 20 | 21 | from langchain.prompts import PromptTemplate 22 | from langchain.chains import RetrievalQA 23 | 24 | 25 | logger = logging.getLogger() 26 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr) 27 | 28 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings): 29 | def embed_documents( 30 | self, texts: List[str], chunk_size: int = 5 31 | ) -> List[List[float]]: 32 | """Compute doc embeddings using a SageMaker Inference Endpoint. 
33 | 34 | Args: 35 | texts: The list of texts to embed. 36 | chunk_size: The chunk size defines how many input texts will 37 | be grouped together as request. If None, will use the 38 | chunk size specified by the class. 39 | 40 | Returns: 41 | List of embeddings, one for each text. 42 | """ 43 | results = [] 44 | 45 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size 46 | for i in range(0, len(texts), _chunk_size): 47 | response = self._embedding_func(texts[i : i + _chunk_size]) 48 | results.extend(response) 49 | return results 50 | 51 | 52 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart: 53 | 54 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler): 55 | """ 56 | encode input string as utf-8 bytes, read the embeddings 57 | from the output 58 | """ 59 | content_type = "application/json" 60 | accepts = "application/json" 61 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes: 62 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) 63 | return input_str.encode('utf-8') 64 | 65 | def transform_output(self, output: bytes) -> str: 66 | response_json = json.loads(output.read().decode("utf-8")) 67 | embeddings = response_json["embedding"] 68 | if len(embeddings) == 1: 69 | return [embeddings[0]] 70 | return embeddings 71 | 72 | # create a content handler object which knows how to serialize 73 | # and deserialize communication with the model endpoint 74 | content_handler = ContentHandlerForEmbeddings() 75 | 76 | # read to create the Sagemaker embeddings, we are providing 77 | # the Sagemaker endpoint that will be used for generating the 78 | # embeddings to the class 79 | embeddings = SagemakerEndpointEmbeddingsJumpStart( 80 | endpoint_name=endpoint_name, 81 | region_name=region, 82 | content_handler=content_handler 83 | ) 84 | logger.info(f"embeddings type={type(embeddings)}") 85 | 86 | return embeddings 87 | 88 | 89 | def _get_credentials(secret_id: str, region_name: str) -> str: 90 | client = boto3.client('secretsmanager', region_name=region_name) 91 | response = client.get_secret_value(SecretId=secret_id) 92 | secrets_value = json.loads(response['SecretString']) 93 | return secrets_value 94 | 95 | 96 | def build_chain(): 97 | region = os.environ["AWS_REGION"] # us-east-1 98 | opensearch_secret = os.environ["OPENSEARCH_SECRET"] 99 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"] 100 | opensearch_index = os.environ["OPENSEARCH_INDEX"] 101 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"] 102 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"] 103 | 104 | class ContentHandler(LLMContentHandler): 105 | content_type = "application/json" 106 | accepts = "application/json" 107 | 108 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: 109 | input_str = json.dumps({"inputs": prompt, **model_kwargs}) 110 | return input_str.encode('utf-8') 111 | 112 | def transform_output(self, output: bytes) -> str: 113 | response_json = json.loads(output.read().decode("utf-8")) 114 | return response_json[0]["generated_texts"] 115 | 116 | content_handler = ContentHandler() 117 | 118 | model_kwargs = { 119 | "max_length": 500, 120 | "num_return_sequences": 1, 121 | "top_k": 250, 122 | "top_p": 0.95, 123 | "do_sample": False, 124 | "temperature": 1 125 | } 126 | 127 | llm = SagemakerEndpoint( 128 | endpoint_name=text2text_model_endpoint, 129 | region_name=region, 130 | model_kwargs=model_kwargs, 131 | 
content_handler=content_handler 132 | ) 133 | 134 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint 135 | 136 | creds = _get_credentials(opensearch_secret, region) 137 | http_auth = (creds['username'], creds['password']) 138 | 139 | opensearch_vector_search = OpenSearchVectorSearch( 140 | opensearch_url=opensearch_url, 141 | index_name=opensearch_index, 142 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region), 143 | http_auth=http_auth 144 | ) 145 | 146 | retriever = opensearch_vector_search.as_retriever(search_kwargs={"k": 3}) 147 | 148 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}""" 149 | 150 | PROMPT = PromptTemplate( 151 | template=prompt_template, input_variables=["context", "question"] 152 | ) 153 | 154 | chain_type_kwargs = {"prompt": PROMPT, "verbose": True} 155 | qa = RetrievalQA.from_chain_type( 156 | llm, 157 | chain_type="stuff", 158 | retriever=retriever, 159 | chain_type_kwargs=chain_type_kwargs, 160 | return_source_documents=True, 161 | verbose=True, #DEBUG 162 | ) 163 | 164 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n") 165 | return qa 166 | 167 | 168 | def run_chain(chain, prompt: str, history=[]): 169 | result = chain(prompt, include_run_info=True) 170 | # To make it compatible with chat samples 171 | return { 172 | "answer": result['result'], 173 | "source_documents": result['source_documents'] 174 | } 175 | 176 | 177 | if __name__ == "__main__": 178 | chain = build_chain() 179 | result = run_chain(chain, "What is SageMaker model monitor? Write your answer in a nicely formatted way.") 180 | print(result['answer']) 181 | if 'source_documents' in result: 182 | print('Sources:') 183 | for d in result['source_documents']: 184 | print(d.metadata['source']) -------------------------------------------------------------------------------- /app/opensearch_retriever_llama2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=4 shiftwidth=4 softtabstop=4 expandtab 4 | 5 | import os 6 | import json 7 | import logging 8 | import sys 9 | from typing import List 10 | from urllib.parse import urlparse 11 | 12 | import boto3 13 | 14 | from langchain_community.vectorstores import OpenSearchVectorSearch 15 | from langchain_community.embeddings import SagemakerEndpointEmbeddings 16 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler 17 | 18 | from langchain_community.llms import SagemakerEndpoint 19 | from langchain_community.llms.sagemaker_endpoint import LLMContentHandler 20 | 21 | from langchain.prompts import PromptTemplate 22 | from langchain.chains import RetrievalQA 23 | 24 | 25 | logger = logging.getLogger() 26 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr) 27 | 28 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings): 29 | def embed_documents( 30 | self, texts: List[str], chunk_size: int = 5 31 | ) -> List[List[float]]: 32 | """Compute doc embeddings using a SageMaker Inference Endpoint. 33 | 34 | Args: 35 | texts: The list of texts to embed. 36 | chunk_size: The chunk size defines how many input texts will 37 | be grouped together as request. If None, will use the 38 | chunk size specified by the class. 39 | 40 | Returns: 41 | List of embeddings, one for each text. 
42 | """ 43 | results = [] 44 | 45 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size 46 | for i in range(0, len(texts), _chunk_size): 47 | response = self._embedding_func(texts[i : i + _chunk_size]) 48 | results.extend(response) 49 | return results 50 | 51 | 52 | def _create_sagemaker_embeddings(endpoint_name: str, region: str = "us-east-1") -> SagemakerEndpointEmbeddingsJumpStart: 53 | 54 | class ContentHandlerForEmbeddings(EmbeddingsContentHandler): 55 | """ 56 | encode input string as utf-8 bytes, read the embeddings 57 | from the output 58 | """ 59 | content_type = "application/json" 60 | accepts = "application/json" 61 | def transform_input(self, prompt: str, model_kwargs = {}) -> bytes: 62 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) 63 | return input_str.encode('utf-8') 64 | 65 | def transform_output(self, output: bytes) -> str: 66 | response_json = json.loads(output.read().decode("utf-8")) 67 | embeddings = response_json["embedding"] 68 | if len(embeddings) == 1: 69 | return [embeddings[0]] 70 | return embeddings 71 | 72 | # create a content handler object which knows how to serialize 73 | # and deserialize communication with the model endpoint 74 | content_handler = ContentHandlerForEmbeddings() 75 | 76 | # read to create the Sagemaker embeddings, we are providing 77 | # the Sagemaker endpoint that will be used for generating the 78 | # embeddings to the class 79 | embeddings = SagemakerEndpointEmbeddingsJumpStart( 80 | endpoint_name=endpoint_name, 81 | region_name=region, 82 | content_handler=content_handler 83 | ) 84 | logger.info(f"embeddings type={type(embeddings)}") 85 | 86 | return embeddings 87 | 88 | 89 | def _get_credentials(secret_id: str, region_name: str) -> str: 90 | client = boto3.client('secretsmanager', region_name=region_name) 91 | response = client.get_secret_value(SecretId=secret_id) 92 | secrets_value = json.loads(response['SecretString']) 93 | return secrets_value 94 | 95 | 96 | def build_chain(): 97 | region = os.environ["AWS_REGION"] # us-east-1 98 | opensearch_secret = os.environ["OPENSEARCH_SECRET"] 99 | opensearch_domain_endpoint = os.environ["OPENSEARCH_DOMAIN_ENDPOINT"] 100 | opensearch_index = os.environ["OPENSEARCH_INDEX"] 101 | embeddings_model_endpoint = os.environ["EMBEDDING_ENDPOINT_NAME"] 102 | text2text_model_endpoint = os.environ["TEXT2TEXT_ENDPOINT_NAME"] 103 | 104 | class ContentHandler(LLMContentHandler): 105 | content_type = "application/json" 106 | accepts = "application/json" 107 | 108 | def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: 109 | system_prompt = "You are a helpful assistant. Always answer to questions as helpfully as possible." 
\ 110 | " If you don't know the answer to a question, say I don't know the answer" 111 | 112 | payload = { 113 | "inputs": [ 114 | [ 115 | {"role": "system", "content": system_prompt}, 116 | {"role": "user", "content": prompt}, 117 | ], 118 | ], 119 | "parameters": model_kwargs, 120 | } 121 | input_str = json.dumps(payload) 122 | return input_str.encode("utf-8") 123 | 124 | def transform_output(self, output: bytes) -> str: 125 | response_json = json.loads(output.read().decode("utf-8")) 126 | content = response_json[0]["generation"]["content"] 127 | return content 128 | 129 | content_handler = ContentHandler() 130 | 131 | model_kwargs = { 132 | "max_new_tokens": 256, 133 | "top_p": 0.9, 134 | "temperature": 0.6, 135 | "return_full_text": False, 136 | } 137 | 138 | llm = SagemakerEndpoint( 139 | endpoint_name=text2text_model_endpoint, 140 | region_name=region, 141 | model_kwargs=model_kwargs, 142 | endpoint_kwargs={"CustomAttributes": "accept_eula=true"}, 143 | content_handler=content_handler 144 | ) 145 | 146 | opensearch_url = f"https://{opensearch_domain_endpoint}" if not opensearch_domain_endpoint.startswith('https://') else opensearch_domain_endpoint 147 | 148 | creds = _get_credentials(opensearch_secret, region) 149 | http_auth = (creds['username'], creds['password']) 150 | 151 | opensearch_vector_search = OpenSearchVectorSearch( 152 | opensearch_url=opensearch_url, 153 | index_name=opensearch_index, 154 | embedding_function=_create_sagemaker_embeddings(embeddings_model_endpoint, region), 155 | http_auth=http_auth 156 | ) 157 | 158 | retriever = opensearch_vector_search.as_retriever(search_kwargs={"k": 3}) 159 | 160 | prompt_template = """Answer based on context:\n\n{context}\n\n{question}""" 161 | 162 | PROMPT = PromptTemplate( 163 | template=prompt_template, input_variables=["context", "question"] 164 | ) 165 | 166 | chain_type_kwargs = {"prompt": PROMPT, "verbose": True} 167 | qa = RetrievalQA.from_chain_type( 168 | llm, 169 | chain_type="stuff", 170 | retriever=retriever, 171 | chain_type_kwargs=chain_type_kwargs, 172 | return_source_documents=True, 173 | verbose=True, #DEBUG 174 | ) 175 | 176 | logger.info(f"\ntype('qa'): \"{type(qa)}\"\n") 177 | return qa 178 | 179 | 180 | def run_chain(chain, prompt: str, history=[]): 181 | result = chain(prompt, include_run_info=True) 182 | # To make it compatible with chat samples 183 | return { 184 | "answer": result['result'], 185 | "source_documents": result['source_documents'] 186 | } 187 | 188 | 189 | if __name__ == "__main__": 190 | chain = build_chain() 191 | result = run_chain(chain, "What is SageMaker model monitor? 
Write your answer in a nicely formatted way.")
192 |     print(result['answer'])
193 |     if 'source_documents' in result:
194 |         print('Sources:')
195 |         for d in result['source_documents']:
196 |             print(d.metadata['source'])
--------------------------------------------------------------------------------
/app/qa-with-llm-and-rag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/rag-with-amazon-opensearch-and-sagemaker/5d43e78f0af9a69c0306b4f1181f3cd4971b78f9/app/qa-with-llm-and-rag.png
--------------------------------------------------------------------------------
/app/requirements.txt:
--------------------------------------------------------------------------------
1 | boto3>=1.26.159
2 | langchain>=0.3,<0.4
3 | langchain-aws>=0.2,<0.3
4 | langchain-community>=0.3,<0.4
5 | SQLAlchemy==2.0.28
6 | opensearch-py==2.2.0
7 | streamlit==1.37.0
8 |
--------------------------------------------------------------------------------
/cdk_stacks/README.md:
--------------------------------------------------------------------------------
1 |
2 | # RAG Application CDK Python project!
3 |
4 | ![rag_with_opensearch_arch](./rag_with_opensearch_arch.svg)
5 |
6 | This is a question-answering (QA) application with LLMs and RAG, built as a CDK project in Python.
7 |
8 | The `cdk.json` file tells the CDK Toolkit how to execute your app.
9 |
10 | This project is set up like a standard Python project. The initialization
11 | process also creates a virtualenv within this project, stored under the `.venv`
12 | directory. To create the virtualenv it assumes that there is a `python3`
13 | (or `python` for Windows) executable in your path with access to the `venv`
14 | package. If for any reason the automatic creation of the virtualenv fails,
15 | you can create the virtualenv manually.
16 |
17 | To manually create a virtualenv on MacOS and Linux:
18 |
19 | ```
20 | $ python3 -m venv .venv
21 | ```
22 |
23 | After the init process completes and the virtualenv is created, you can use the following
24 | step to activate your virtualenv.
25 |
26 | ```
27 | $ source .venv/bin/activate
28 | ```
29 |
30 | If you are on a Windows platform, you can activate the virtualenv like this:
31 |
32 | ```
33 | % .venv\Scripts\activate.bat
34 | ```
35 |
36 | Once the virtualenv is activated, you can install the required dependencies.
37 |
38 | ```
39 | (.venv) $ pip install -r requirements.txt
40 | ```
41 |
42 | To add additional dependencies, for example other CDK libraries, just add
43 | them to your `requirements.txt` file and rerun the `pip install -r requirements.txt`
44 | command.
45 |
46 | Before synthesizing the CloudFormation template, you should properly set the CDK context configuration file, `cdk.context.json`.
47 |
48 | For example:
49 |
50 |
 51 | {
 52 |   "opensearch_domain_name": "open-search-domain-name",
 53 |   "jumpstart_model_info": {
 54 |     "model_id": "meta-textgeneration-llama-2-7b-f",
 55 |     "version": "2.0.1"
 56 |   }
 57 | }
 58 | 
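
For reference, the CDK stacks read these values through the CDK context API. The following is a minimal sketch of the consumption pattern (see `rag_with_aos/sm_jumpstart_llm_endpoint.py` for the actual code; it runs inside a stack's `__init__`, so `self` is the stack instance):

```
jumpstart_model = self.node.try_get_context('jumpstart_model_info')
model_id = jumpstart_model.get('model_id', 'meta-textgeneration-llama-2-7b-f')
model_version = jumpstart_model.get('version', '2.0.1')
```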
59 |
60 | :information_source: The `model_id` and `version` provided by SageMaker JumpStart can be found in the [**SageMaker Built-in Algorithms with pre-trained Model Table**](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html).
61 |
62 | > :warning: **Important**: Make sure the `docker daemon` is running.
63 | > Otherwise you will encounter the following errors:
64 |
65 | ```
66 | ERROR: Cannot connect to the Docker daemon at unix://$HOME/.docker/run/docker.sock. Is the docker daemon running?
67 | jsii.errors.JavaScriptError:
68 |   Error: docker exited with status 1
69 | ```
70 |
71 | At this point you can synthesize the CloudFormation template for this code.
72 |
73 | ```
74 | (.venv) $ export CDK_DEFAULT_ACCOUNT=$(aws sts get-caller-identity --query Account --output text)
75 | (.venv) $ export CDK_DEFAULT_REGION=$(aws configure get region)
76 | (.venv) $ cdk synth --all
77 | ```
78 |
79 | Now we can deploy all the CDK stacks at once like this:
80 |
81 | ```
82 | (.venv) $ cdk deploy --require-approval never --all
83 | ```
84 |
85 | Or, we can provision each CDK stack one at a time like this:
86 |
87 | #### Step 1: List all CDK Stacks
88 |
89 | ```
90 | (.venv) $ cdk list
91 | RAGVpcStack
92 | RAGOpenSearchStack
93 | RAGSageMakerStudioStack
94 | EmbeddingEndpointStack
95 | LLMEndpointStack
96 | StreamlitAppStack
97 | ```
98 |
99 | #### Step 2: Create OpenSearch cluster
100 |
101 | ```
102 | (.venv) $ cdk deploy --require-approval never RAGVpcStack RAGOpenSearchStack
103 | ```
104 |
105 | #### Step 3: Create SageMaker Studio
106 |
107 | ```
108 | (.venv) $ cdk deploy --require-approval never RAGSageMakerStudioStack
109 | ```
110 |
111 | #### Step 4: Deploy LLM Embedding Endpoint
112 |
113 | ```
114 | (.venv) $ cdk deploy --require-approval never EmbeddingEndpointStack
115 | ```
116 |
117 | #### Step 5: Deploy Text Generation LLM Endpoint
118 |
119 | ```
120 | (.venv) $ cdk deploy --require-approval never LLMEndpointStack
121 | ```
122 |
123 | #### Step 6 (Optional): Deploy the Streamlit app on ECS Fargate
124 |
125 | :warning: Before deploying the following CDK stack, make sure Docker is running on your machine.
126 |
127 | ```
128 | (.venv) $ cdk deploy --require-approval never StreamlitAppStack
129 | ```
130 |
131 | **Once all CDK stacks have been successfully created, proceed with the remaining steps of the [overall workflow](../README.md#overall-workflow).**
132 |
133 |
134 | ## Clean Up
135 |
136 | Delete the CloudFormation stacks by running the command below.
137 |
138 | ```
139 | (.venv) $ cdk destroy --all
140 | ```
141 |
142 | ## Useful commands
143 |
144 |  * `cdk ls` list all stacks in the app
145 |  * `cdk synth` emits the synthesized CloudFormation template
146 |  * `cdk deploy` deploy this stack to your default AWS account/region
147 |  * `cdk diff` compare deployed stack with current state
148 |  * `cdk docs` open CDK documentation
149 |
150 | Enjoy!
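
As a final check before you move on to data ingestion, the snippet below is one way to confirm that both SageMaker endpoints are up. It is a minimal sketch that assumes the endpoint names emitted by the stacks have been exported as the `EMBEDDING_ENDPOINT_NAME` and `TEXT2TEXT_ENDPOINT_NAME` environment variables (the same variables the Streamlit app reads):

```
# Check that both SageMaker endpoints are "InService" before using the app.
import os

import boto3

sm_client = boto3.client('sagemaker', region_name=os.environ.get('AWS_REGION', 'us-east-1'))
for env_var in ('EMBEDDING_ENDPOINT_NAME', 'TEXT2TEXT_ENDPOINT_NAME'):
    endpoint_name = os.environ[env_var]
    status = sm_client.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
    print(f'{endpoint_name}: {status}')  # expect "InService"
```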
151 | 152 | ## References 153 | 154 | * [Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain (2023-05-25)](https://aws.amazon.com/blogs/machine-learning/build-a-powerful-question-answering-bot-with-amazon-sagemaker-amazon-opensearch-service-streamlit-and-langchain/) 155 | * [Use proprietary foundation models from Amazon SageMaker JumpStart in Amazon SageMaker Studio (2023-06-27)](https://aws.amazon.com/blogs/machine-learning/use-proprietary-foundation-models-from-amazon-sagemaker-jumpstart-in-amazon-sagemaker-studio/) 156 | * [SageMaker Built-in Algorithms with pre-trained Model Table](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html) 157 | * [AWS Deep Learning Containers Images](https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/deep-learning-containers-images.html) 158 | * [OpenSearch Popular APIs](https://opensearch.org/docs/latest/opensearch/popular-api/) 159 | * [Using the Amazon SageMaker Studio Image Build CLI to build container images from your Studio notebooks (2020-09-14)](https://aws.amazon.com/blogs/machine-learning/using-the-amazon-sagemaker-studio-image-build-cli-to-build-container-images-from-your-studio-notebooks/) 160 | -------------------------------------------------------------------------------- /cdk_stacks/app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab 4 | 5 | import os 6 | 7 | from aws_cdk import App, Environment, Stack 8 | 9 | from rag_with_aos import ( 10 | VpcStack, 11 | OpenSearchStack, 12 | SageMakerStudioStack, 13 | EmbeddingEndpointStack, 14 | LLMEndpointStack, 15 | StreamlitAppStack 16 | ) 17 | 18 | APP_ENV = Environment( 19 | account=os.environ["CDK_DEFAULT_ACCOUNT"], 20 | region=os.environ["CDK_DEFAULT_REGION"] 21 | ) 22 | 23 | app = App() 24 | 25 | vpc_stack = VpcStack(app, 'RAGVpcStack', 26 | env=APP_ENV) 27 | 28 | ops_stack = OpenSearchStack(app, 'RAGOpenSearchStack', 29 | vpc_stack.vpc, 30 | env=APP_ENV 31 | ) 32 | ops_stack.add_dependency(vpc_stack) 33 | 34 | sm_studio_stack = SageMakerStudioStack(app, 'RAGSageMakerStudioStack', 35 | vpc_stack.vpc, 36 | env=APP_ENV 37 | ) 38 | sm_studio_stack.add_dependency(ops_stack) 39 | 40 | sm_embedding_endpoint = EmbeddingEndpointStack(app, 'EmbeddingEndpointStack', 41 | env=APP_ENV 42 | ) 43 | sm_embedding_endpoint.add_dependency(sm_studio_stack) 44 | 45 | sm_llm_endpoint = LLMEndpointStack(app, 'LLMEndpointStack', 46 | env=APP_ENV 47 | ) 48 | sm_llm_endpoint.add_dependency(sm_embedding_endpoint) 49 | 50 | ecs_app = StreamlitAppStack(app, "StreamlitAppStack", 51 | vpc_stack.vpc, 52 | ops_stack.master_user_secret, 53 | ops_stack.opensearch_domain, 54 | sm_llm_endpoint.llm_endpoint, 55 | sm_embedding_endpoint.embedding_endpoint, 56 | env=APP_ENV 57 | ) 58 | ecs_app.add_dependency(sm_llm_endpoint) 59 | 60 | app.synth() 61 | -------------------------------------------------------------------------------- /cdk_stacks/cdk.context.json: -------------------------------------------------------------------------------- 1 | { 2 | "opensearch_domain_name": "rag-vectordb", 3 | "jumpstart_model_info": { 4 | "model_id": "meta-textgeneration-llama-2-7b-f", 5 | "version": "2.0.1" 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /cdk_stacks/cdk.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "app": "python3 app.py", 3 | "watch": { 4 | "include": [ 5 | "**" 6 | ], 7 | "exclude": [ 8 | "README.md", 9 | "cdk*.json", 10 | "requirements*.txt", 11 | "source.bat", 12 | "**/__init__.py", 13 | "python/__pycache__", 14 | "tests" 15 | ] 16 | }, 17 | "context": { 18 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true, 19 | "@aws-cdk/core:checkSecretUsage": true, 20 | "@aws-cdk/core:target-partitions": [ 21 | "aws", 22 | "aws-cn" 23 | ], 24 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true, 25 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true, 26 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true, 27 | "@aws-cdk/aws-iam:minimizePolicies": true, 28 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true, 29 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true, 30 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true, 31 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true, 32 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true, 33 | "@aws-cdk/core:enablePartitionLiterals": true, 34 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true, 35 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true, 36 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true, 37 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true, 38 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true, 39 | "@aws-cdk/aws-route53-patters:useCertificate": true, 40 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false, 41 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true, 42 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true, 43 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true, 44 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true, 45 | "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true, 46 | "@aws-cdk/aws-redshift:columnId": true, 47 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true, 48 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true, 49 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true, 50 | "@aws-cdk/aws-kms:aliasNameRef": true, 51 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /cdk_stacks/rag_with_aos/__init__.py: -------------------------------------------------------------------------------- 1 | from .vpc import VpcStack 2 | from .ops import OpenSearchStack 3 | from .sm_studio import SageMakerStudioStack 4 | from .sm_custom_embedding_endpoint import SageMakerEmbeddingEndpointStack as EmbeddingEndpointStack 5 | from .sm_jumpstart_llm_endpoint import SageMakerJumpStartLLMEndpointStack as LLMEndpointStack 6 | from .ecs_streamlit_app import StreamlitAppStack -------------------------------------------------------------------------------- /cdk_stacks/rag_with_aos/ecs_streamlit_app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab 4 | 5 | from aws_cdk import ( 6 | CfnOutput, 7 | Stack, 8 | aws_ec2 as ec2, 9 | aws_ecs as ecs, 10 | aws_iam as iam, 11 | aws_ecs_patterns as ecs_patterns, 12 | ) 13 | from constructs import Construct 14 | 15 | 16 | class StreamlitAppStack(Stack): 17 | def __init__( 18 | self, 19 | scope: Construct, 20 | 
id: str, 21 | vpc, 22 | opensearch_master_user_secret, 23 | opensearch_domain, 24 | llm_endpoint, 25 | embedding_endpoint, 26 | **kwargs) -> None: 27 | 28 | super().__init__(scope, id, **kwargs) 29 | 30 | container_port = self.node.try_get_context("streamlit_container_port") or 8501 31 | 32 | # Create an IAM service role for ECS task execution 33 | execution_role = iam.Role( 34 | self, "ECSExecutionRole", 35 | assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com"), 36 | managed_policies=[ 37 | iam.ManagedPolicy.from_aws_managed_policy_name("service-role/AmazonECSTaskExecutionRolePolicy") 38 | ] 39 | ) 40 | 41 | # Create an IAM role for the Fargate task 42 | task_role = iam.Role( 43 | self, "ECSTaskRole", 44 | assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com"), 45 | description="Role for ECS Tasks to access Secrets Manager and SageMaker" 46 | ) 47 | 48 | # Policy to access Secrets Manager secrets 49 | secrets_policy = iam.PolicyStatement( 50 | actions=["secretsmanager:GetSecretValue"], 51 | resources=[ 52 | opensearch_master_user_secret.secret_arn 53 | ], 54 | effect=iam.Effect.ALLOW 55 | ) 56 | 57 | # Policy to invoke SageMaker endpoints 58 | sagemaker_policy = iam.PolicyStatement( 59 | actions=["sagemaker:InvokeEndpoint"], 60 | resources=[ 61 | embedding_endpoint.endpoint_arn, 62 | llm_endpoint.endpoint_arn 63 | ], 64 | effect=iam.Effect.ALLOW 65 | ) 66 | 67 | # Attach policies to the role 68 | task_role.add_to_policy(secrets_policy) 69 | task_role.add_to_policy(sagemaker_policy) 70 | 71 | # Set up ECS cluster and networking 72 | cluster = ecs.Cluster( 73 | self, "Cluster", 74 | vpc=vpc 75 | ) 76 | 77 | security_group = ec2.SecurityGroup( 78 | self, "SecurityGroup", 79 | vpc=vpc, 80 | description="Allow traffic on container port", 81 | allow_all_outbound=True 82 | ) 83 | 84 | security_group.add_ingress_rule( 85 | ec2.Peer.any_ipv4(), 86 | ec2.Port.tcp(container_port), 87 | "Allow inbound traffic" 88 | ) 89 | 90 | # Set up Fargate task definition and service 91 | task_definition = ecs.FargateTaskDefinition( 92 | self, "TaskDef", 93 | memory_limit_mib=512, 94 | execution_role=execution_role, 95 | task_role=task_role, 96 | cpu=256 97 | ) 98 | 99 | container = task_definition.add_container( 100 | "WebContainer", 101 | image=ecs.ContainerImage.from_asset("../app/"), 102 | logging=ecs.LogDrivers.aws_logs(stream_prefix="streamlitapp"), 103 | environment={ 104 | "AWS_REGION": self.region, 105 | "OPENSEARCH_SECRET": opensearch_master_user_secret.secret_name, 106 | "OPENSEARCH_DOMAIN_ENDPOINT": f"https://{opensearch_domain.domain_endpoint}", 107 | "OPENSEARCH_INDEX": "llm_rag_embeddings", 108 | "EMBEDDING_ENDPOINT_NAME": embedding_endpoint.cfn_endpoint.endpoint_name, 109 | "TEXT2TEXT_ENDPOINT_NAME": llm_endpoint.cfn_endpoint.endpoint_name 110 | } 111 | ) 112 | container.add_port_mappings(ecs.PortMapping(container_port=container_port)) 113 | 114 | fargate_service = ecs_patterns.ApplicationLoadBalancedFargateService( 115 | self, "FargateService", 116 | cluster=cluster, 117 | task_definition=task_definition, 118 | public_load_balancer=True, 119 | assign_public_ip=True 120 | ) 121 | 122 | 123 | CfnOutput(self, 'StreamlitEndpoint', 124 | value=fargate_service.load_balancer.load_balancer_dns_name, 125 | export_name=f'{self.stack_name}-StreamlitEndpoint' 126 | ) 127 | -------------------------------------------------------------------------------- /cdk_stacks/rag_with_aos/ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 
-*- encoding: utf-8 -*-
3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab
4 |
5 | import json
6 | import random
7 | import re
8 | import string
9 |
10 | import aws_cdk as cdk
11 |
12 | from aws_cdk import (
13 |   Stack,
14 |   aws_ec2,
15 |   aws_opensearchservice,
16 |   aws_secretsmanager
17 | )
18 | from constructs import Construct
19 |
20 | random.seed(47)
21 |
22 |
23 | class OpenSearchStack(Stack):
24 |
25 |   def __init__(self, scope: Construct, construct_id: str, vpc, **kwargs) -> None:
26 |     super().__init__(scope, construct_id, **kwargs)
27 |
28 |     #XXX: Amazon OpenSearch Service Domain naming restrictions
29 |     # https://docs.aws.amazon.com/opensearch-service/latest/developerguide/createupdatedomains.html#createdomains
30 |     OPENSEARCH_DEFAULT_DOMAIN_NAME = 'opensearch-{}'.format(''.join(random.sample(string.ascii_lowercase, k=5)))
31 |     opensearch_domain_name = self.node.try_get_context('opensearch_domain_name') or OPENSEARCH_DEFAULT_DOMAIN_NAME
32 |     assert re.fullmatch(r'[a-z][a-z0-9\-]{2,27}', opensearch_domain_name), 'Invalid domain name'
33 |
34 |     self.master_user_secret = aws_secretsmanager.Secret(self, "OpenSearchMasterUserSecret",
35 |       generate_secret_string=aws_secretsmanager.SecretStringGenerator(
36 |         secret_string_template=json.dumps({"username": "admin"}),
37 |         generate_string_key="password",
38 |         # Master password must be at least 8 characters long and contain at least one uppercase letter,
39 |         # one lowercase letter, one number, and one special character.
40 |         password_length=8
41 |       )
42 |     )
43 |
44 |     #XXX: aws cdk elasticsearch example - https://github.com/aws/aws-cdk/issues/2873
45 |     # You should camelCase the property names instead of PascalCase
46 |     self.opensearch_domain = aws_opensearchservice.Domain(self, "OpenSearch",
47 |       domain_name=opensearch_domain_name,
48 |       #XXX: Supported versions of OpenSearch and Elasticsearch
49 |       # https://docs.aws.amazon.com/opensearch-service/latest/developerguide/what-is.html#choosing-version
50 |       version=aws_opensearchservice.EngineVersion.OPENSEARCH_2_11,
51 |       #XXX: Amazon OpenSearch Service - Current generation instance types
52 |       # https://docs.aws.amazon.com/opensearch-service/latest/developerguide/supported-instance-types.html#latest-gen
53 |       # - The OR1 instance types require OpenSearch 2.11 or later.
54 |       # - OR1 instances are only compatible with Graviton master node instance types (C6g, M6g, R6g).
55 |       capacity={
56 |         "master_nodes": 3,
57 |         "master_node_instance_type": "m6g.large.search",
58 |         "data_nodes": 3,
59 |         "data_node_instance_type": "or1.large.search"
60 |       },
61 |       ebs={
62 |         # Volume size must be between 20 and 1536 for the or1.large.search instance type on OpenSearch_2.11
63 |         "volume_size": 20,
64 |         "volume_type": aws_ec2.EbsDeviceVolumeType.GP3
65 |       },
66 |       #XXX: az_count must be equal to the VPC subnet count.
67 |       zone_awareness={
68 |         "availability_zone_count": 3
69 |       },
70 |       logging={
71 |         "slow_search_log_enabled": True,
72 |         "app_log_enabled": True,
73 |         "slow_index_log_enabled": True
74 |       },
75 |       fine_grained_access_control=aws_opensearchservice.AdvancedSecurityOptions(
76 |         master_user_name=self.master_user_secret.secret_value_from_json("username").unsafe_unwrap(),
77 |         master_user_password=self.master_user_secret.secret_value_from_json("password")
78 |       ),
79 |       # Enforcing HTTPS is required when fine-grained access control is enabled.
80 | enforce_https=True, 81 | # Node-to-node encryption is required when fine-grained access control is enabled 82 | node_to_node_encryption=True, 83 | # Encryption-at-rest is required when fine-grained access control is enabled. 84 | encryption_at_rest={ 85 | "enabled": True 86 | }, 87 | use_unsigned_basic_auth=True, 88 | removal_policy=cdk.RemovalPolicy.DESTROY # default: cdk.RemovalPolicy.RETAIN 89 | ) 90 | 91 | cdk.Tags.of(self.opensearch_domain).add('Name', opensearch_domain_name) 92 | 93 | cdk.CfnOutput(self, 'OpenSourceDomainArn', 94 | value=self.opensearch_domain.domain_arn, 95 | export_name=f'{self.stack_name}-OpenSourceDomainArn') 96 | cdk.CfnOutput(self, 'OpenSearchDomainEndpoint', 97 | value=f"https://{self.opensearch_domain.domain_endpoint}", 98 | export_name=f'{self.stack_name}-OpenSearchDomainEndpoint') 99 | cdk.CfnOutput(self, 'OpenSearchDashboardsURL', 100 | value=f"https://{self.opensearch_domain.domain_endpoint}/_dashboards/", 101 | export_name=f'{self.stack_name}-OpenSearchDashboardsURL') 102 | cdk.CfnOutput(self, 'OpenSearchSecret', 103 | value=self.master_user_secret.secret_name, 104 | export_name=f'{self.stack_name}-MasterUserSecretId') 105 | -------------------------------------------------------------------------------- /cdk_stacks/rag_with_aos/sm_custom_embedding_endpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab 4 | 5 | import random 6 | import string 7 | 8 | import aws_cdk as cdk 9 | 10 | from aws_cdk import ( 11 | Stack 12 | ) 13 | from constructs import Construct 14 | 15 | from cdklabs.generative_ai_cdk_constructs import ( 16 | CustomSageMakerEndpoint, 17 | DeepLearningContainerImage, 18 | SageMakerInstanceType, 19 | ) 20 | 21 | random.seed(47) 22 | 23 | 24 | class SageMakerEmbeddingEndpointStack(Stack): 25 | 26 | def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None: 27 | super().__init__(scope, construct_id, **kwargs) 28 | 29 | bucket_name = f'jumpstart-cache-prod-{cdk.Aws.REGION}' 30 | key_name = 'huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-textembedding-gpt-j-6b-fp16.tar.gz' 31 | 32 | RANDOM_GUID = ''.join(random.sample(string.digits, k=7)) 33 | endpoint_name = f"gpt-j-6b-fp16-endpoint-{RANDOM_GUID}" 34 | 35 | #XXX: https://github.com/awslabs/generative-ai-cdk-constructs/blob/main/src/patterns/gen-ai/aws-model-deployment-sagemaker/README_custom_sagemaker_endpoint.md 36 | self.embedding_endpoint = CustomSageMakerEndpoint(self, 'EmbeddingEndpoint', 37 | model_id='gpt-j-6b-fp16', 38 | instance_type=SageMakerInstanceType.ML_G5_2_XLARGE, 39 | container=DeepLearningContainerImage.from_deep_learning_container_image( 40 | 'pytorch-inference', 41 | '1.12.0-gpu-py38' 42 | ), 43 | model_data_url=f's3://{bucket_name}/{key_name}', 44 | endpoint_name=endpoint_name, 45 | instance_count=1, 46 | # volume_size_in_gb=100 47 | ) 48 | 49 | cdk.CfnOutput(self, 'EmbeddingEndpointName', 50 | value=self.embedding_endpoint.cfn_endpoint.endpoint_name, 51 | export_name=f'{self.stack_name}-EmbeddingEndpointName') 52 | cdk.CfnOutput(self, 'EmbeddingEndpointArn', 53 | value=self.embedding_endpoint.endpoint_arn, 54 | export_name=f'{self.stack_name}-EmbeddingEndpointArn') -------------------------------------------------------------------------------- /cdk_stacks/rag_with_aos/sm_jumpstart_llm_endpoint.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab 4 | 5 | import random 6 | import string 7 | 8 | import aws_cdk as cdk 9 | 10 | from aws_cdk import ( 11 | Stack 12 | ) 13 | from constructs import Construct 14 | 15 | from cdklabs.generative_ai_cdk_constructs import ( 16 | JumpStartSageMakerEndpoint, 17 | JumpStartModel, 18 | SageMakerInstanceType 19 | ) 20 | 21 | random.seed(47) 22 | 23 | 24 | def name_from_base(base, max_length=63): 25 | unique = ''.join(random.sample(string.digits, k=7)) 26 | max_length = 63 27 | trimmed_base = base[: max_length - len(unique) - 1] 28 | return "{}-{}".format(trimmed_base, unique) 29 | 30 | 31 | class SageMakerJumpStartLLMEndpointStack(Stack): 32 | 33 | def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None: 34 | super().__init__(scope, construct_id, **kwargs) 35 | 36 | jumpstart_model = self.node.try_get_context('jumpstart_model_info') 37 | model_id, model_version = jumpstart_model.get('model_id', 'meta-textgeneration-llama-2-7b-f'), jumpstart_model.get('version', '2.0.1') 38 | model_name = f"{model_id.upper().replace('-', '_')}_{model_version.replace('.', '_')}" 39 | 40 | llm_endpoint_name = name_from_base(model_id.replace('/', '-').replace('.', '-')) 41 | 42 | #XXX: Available JumStart Model List 43 | # https://github.com/awslabs/generative-ai-cdk-constructs/blob/main/src/patterns/gen-ai/aws-model-deployment-sagemaker/jumpstart-model.ts 44 | self.llm_endpoint = JumpStartSageMakerEndpoint(self, 'LLMEndpoint', 45 | model=JumpStartModel.of(model_name), 46 | accept_eula=True, 47 | instance_type=SageMakerInstanceType.ML_G5_2_XLARGE, 48 | endpoint_name=llm_endpoint_name 49 | ) 50 | 51 | cdk.CfnOutput(self, 'LLMEndpointName', 52 | value=self.llm_endpoint.cfn_endpoint.endpoint_name, 53 | export_name=f'{self.stack_name}-LLMEndpointName') 54 | cdk.CfnOutput(self, 'LLMEndpointArn', 55 | value=self.llm_endpoint.endpoint_arn, 56 | export_name=f'{self.stack_name}-LLMEndpointArn') 57 | -------------------------------------------------------------------------------- /cdk_stacks/rag_with_aos/sm_studio.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab 4 | 5 | import aws_cdk as cdk 6 | 7 | from aws_cdk import ( 8 | Stack, 9 | aws_ec2, 10 | aws_iam, 11 | aws_sagemaker 12 | ) 13 | from constructs import Construct 14 | 15 | 16 | class SageMakerStudioStack(Stack): 17 | 18 | def __init__(self, scope: Construct, construct_id: str, vpc, **kwargs) -> None: 19 | super().__init__(scope, construct_id, **kwargs) 20 | 21 | sagemaker_execution_policy_doc = aws_iam.PolicyDocument() 22 | sagemaker_execution_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 23 | "effect": aws_iam.Effect.ALLOW, 24 | "resources": ["arn:aws:s3:::*"], 25 | "actions": [ 26 | "s3:GetObject", 27 | "s3:PutObject", 28 | "s3:DeleteObject", 29 | "s3:ListBucket" 30 | ] 31 | })) 32 | 33 | sagemaker_custom_access_policy_doc = aws_iam.PolicyDocument() 34 | sagemaker_custom_access_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 35 | "effect": aws_iam.Effect.ALLOW, 36 | "resources": [f"arn:aws:es:{cdk.Aws.REGION}:{cdk.Aws.ACCOUNT_ID}:domain/*"], 37 | "actions": ["es:ESHttp*"], 38 | "sid": "ReadFromOpenSearch" 39 | })) 40 | 41 | sagemaker_custom_access_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 42 | "effect": aws_iam.Effect.ALLOW, 43 | "resources": 
[f"arn:aws:secretsmanager:{cdk.Aws.REGION}:{cdk.Aws.ACCOUNT_ID}:secret:*"], 44 | "actions": ["secretsmanager:GetSecretValue"], 45 | "sid": "ReadSecretFromSecretsManager" 46 | })) 47 | 48 | sagemaker_docker_build_policy_doc = aws_iam.PolicyDocument() 49 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 50 | "effect": aws_iam.Effect.ALLOW, 51 | "resources": ["arn:aws:codebuild:*:*:project/sagemaker-studio*"], 52 | "actions": [ 53 | "codebuild:DeleteProject", 54 | "codebuild:CreateProject", 55 | "codebuild:BatchGetBuilds", 56 | "codebuild:StartBuild" 57 | ] 58 | })) 59 | 60 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 61 | "effect": aws_iam.Effect.ALLOW, 62 | "resources": ["arn:aws:logs:*:*:log-group:/aws/codebuild/sagemaker-studio*"], 63 | "actions": ["logs:CreateLogStream"], 64 | })) 65 | 66 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 67 | "effect": aws_iam.Effect.ALLOW, 68 | "resources": ["arn:aws:logs:*:*:log-group:/aws/codebuild/sagemaker-studio*:log-stream:*"], 69 | "actions": [ 70 | "logs:GetLogEvents", 71 | "logs:PutLogEvents" 72 | ] 73 | })) 74 | 75 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 76 | "effect": aws_iam.Effect.ALLOW, 77 | "resources": ["*"], 78 | "actions": ["logs:CreateLogGroup"] 79 | })) 80 | 81 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 82 | "effect": aws_iam.Effect.ALLOW, 83 | "resources": ["*"], 84 | "actions": [ 85 | "ecr:BatchGetImage", 86 | "ecr:BatchCheckLayerAvailability", 87 | "ecr:CompleteLayerUpload", 88 | "ecr:DescribeImages", 89 | "ecr:DescribeRepositories", 90 | "ecr:GetDownloadUrlForLayer", 91 | "ecr:InitiateLayerUpload", 92 | "ecr:ListImages", 93 | "ecr:PutImage", 94 | "ecr:UploadLayerPart", 95 | "ecr:CreateRepository", 96 | "ecr:GetAuthorizationToken", 97 | "ec2:DescribeAvailabilityZones" 98 | ], 99 | "sid": "ReadWriteFromECR" 100 | })) 101 | 102 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 103 | "effect": aws_iam.Effect.ALLOW, 104 | "resources": ["*"], 105 | "actions": ["ecr:GetAuthorizationToken"] 106 | })) 107 | 108 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 109 | "effect": aws_iam.Effect.ALLOW, 110 | "resources": ["arn:aws:s3:::sagemaker-*/*"], 111 | "actions": [ 112 | "s3:GetObject", 113 | "s3:DeleteObject", 114 | "s3:PutObject" 115 | ] 116 | })) 117 | 118 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 119 | "effect": aws_iam.Effect.ALLOW, 120 | "resources": ["arn:aws:s3:::sagemaker*"], 121 | "actions": ["s3:CreateBucket"], 122 | })) 123 | 124 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 125 | "effect": aws_iam.Effect.ALLOW, 126 | "resources": ["*"], 127 | "actions": [ 128 | "iam:GetRole", 129 | "iam:ListRoles" 130 | ] 131 | })) 132 | 133 | sagemaker_docker_build_policy_doc.add_statements(aws_iam.PolicyStatement(**{ 134 | "effect": aws_iam.Effect.ALLOW, 135 | "resources": ["arn:aws:iam::*:role/*"], 136 | "conditions": { 137 | "StringLikeIfExists": { 138 | "iam:PassedToService": [ 139 | "codebuild.amazonaws.com" 140 | ] 141 | } 142 | }, 143 | "actions": ["iam:PassRole"] 144 | })) 145 | 146 | sagemaker_execution_role = aws_iam.Role(self, 'SageMakerExecutionRole', 147 | role_name=f'AmazonSageMakerStudioExecutionRole-{self.stack_name.lower()}', 148 | assumed_by=aws_iam.ServicePrincipal('sagemaker.amazonaws.com'), 149 | path='/', 150 | inline_policies={ 
151 | 'sagemaker-execution-policy': sagemaker_execution_policy_doc, 152 | 'sagemaker-custom-access-policy': sagemaker_custom_access_policy_doc, 153 | 'sagemaker-docker-build-policy': sagemaker_docker_build_policy_doc, 154 | }, 155 | managed_policies=[ 156 | aws_iam.ManagedPolicy.from_aws_managed_policy_name('AmazonSageMakerFullAccess'), 157 | aws_iam.ManagedPolicy.from_aws_managed_policy_name('AmazonSageMakerCanvasFullAccess'), 158 | aws_iam.ManagedPolicy.from_aws_managed_policy_name('AWSCloudFormationReadOnlyAccess'), 159 | ] 160 | ) 161 | 162 | #XXX: To use the sm-docker CLI, the Amazon SageMaker execution role used by the Studio notebook 163 | # environment should have a trust policy with CodeBuild 164 | sagemaker_execution_role.assume_role_policy.add_statements(aws_iam.PolicyStatement(**{ 165 | "effect": aws_iam.Effect.ALLOW, 166 | "principals": [aws_iam.ServicePrincipal('codebuild.amazonaws.com')], 167 | "actions": ["sts:AssumeRole"] 168 | })) 169 | 170 | sm_studio_user_settings = aws_sagemaker.CfnDomain.UserSettingsProperty( 171 | execution_role=sagemaker_execution_role.role_arn 172 | ) 173 | 174 | sagemaker_studio_domain = aws_sagemaker.CfnDomain(self, 'SageMakerStudioDomain', 175 | auth_mode='IAM', # [SSO | IAM] 176 | default_user_settings=sm_studio_user_settings, 177 | domain_name='llm-app-rag-workshop', 178 | subnet_ids=vpc.select_subnets(subnet_type=aws_ec2.SubnetType.PUBLIC).subnet_ids, 179 | vpc_id=vpc.vpc_id, 180 | app_network_access_type='PublicInternetOnly' # [PublicInternetOnly | VpcOnly] 181 | ) 182 | 183 | #XXX: https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jl.html#studio-jl-set 184 | sagmaker_jupyerlab_arn = self.node.try_get_context('sagmaker_jupyterlab_arn') 185 | 186 | default_user_settings = aws_sagemaker.CfnUserProfile.UserSettingsProperty( 187 | jupyter_server_app_settings=aws_sagemaker.CfnUserProfile.JupyterServerAppSettingsProperty( 188 | default_resource_spec=aws_sagemaker.CfnUserProfile.ResourceSpecProperty( 189 | #XXX: JupyterServer apps only support the system value. 
190 | instance_type="system", 191 | sage_maker_image_arn=sagmaker_jupyerlab_arn 192 | ) 193 | ) 194 | ) 195 | 196 | sagemaker_user_profile = aws_sagemaker.CfnUserProfile(self, 'SageMakerStudioUserProfile', 197 | domain_id=sagemaker_studio_domain.attr_domain_id, 198 | user_profile_name='default-user', 199 | user_settings=default_user_settings 200 | ) 201 | 202 | self.sm_execution_role_arn = sagemaker_execution_role.role_arn 203 | 204 | cdk.CfnOutput(self, 'DomainUrl', 205 | value=sagemaker_studio_domain.attr_url, 206 | export_name=f'{self.stack_name}-DomainUrl') 207 | cdk.CfnOutput(self, 'DomainId', 208 | value=sagemaker_user_profile.domain_id, 209 | export_name=f'{self.stack_name}-DomainId') 210 | cdk.CfnOutput(self, 'UserProfileName', 211 | value=sagemaker_user_profile.user_profile_name, 212 | export_name=f'{self.stack_name}-UserProfileName') 213 | -------------------------------------------------------------------------------- /cdk_stacks/rag_with_aos/vpc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # vim: tabstop=2 shiftwidth=2 softtabstop=2 expandtab 4 | 5 | import os 6 | import aws_cdk as cdk 7 | 8 | from aws_cdk import ( 9 | Stack, 10 | aws_ec2, 11 | ) 12 | from constructs import Construct 13 | 14 | 15 | class VpcStack(Stack): 16 | 17 | def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None: 18 | super().__init__(scope, construct_id, **kwargs) 19 | 20 | #XXX: For creating this CDK Stack in the existing VPC, 21 | # remove comments from the below codes and 22 | # comments out vpc = aws_ec2.Vpc(..) codes, 23 | # then pass -c vpc_name=your-existing-vpc to cdk command 24 | # for example, 25 | # cdk -c vpc_name=your-existing-vpc syth 26 | # 27 | if str(os.environ.get('USE_DEFAULT_VPC', 'false')).lower() == 'true': 28 | vpc_name = self.node.try_get_context('vpc_name') or 'default' 29 | self.vpc = aws_ec2.Vpc.from_lookup(self, 'ExistingVPC', 30 | is_default=True, 31 | vpc_name=vpc_name 32 | ) 33 | else: 34 | #XXX: To use more than 2 AZs, be sure to specify the account and region on your stack. 35 | #XXX: https://docs.aws.amazon.com/cdk/api/latest/python/aws_cdk.aws_ec2/Vpc.html 36 | self.vpc = aws_ec2.Vpc(self, 'RAGAppVPC', 37 | ip_addresses=aws_ec2.IpAddresses.cidr("10.0.0.0/16"), 38 | max_azs=3, 39 | 40 | # 'subnetConfiguration' specifies the "subnet groups" to create. 41 | # Every subnet group will have a subnet for each AZ, so this 42 | # configuration will create `2 groups × 3 AZs = 6` subnets. 
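        # For example (a sketch only -- the exact CIDR ranges are assigned by
        # the CDK at synth time and may differ), the six subnets carved out of
        # 10.0.0.0/16 with a /20 mask per subnet would come out roughly as:
        #   Public:  10.0.0.0/20, 10.0.16.0/20, 10.0.32.0/20
        #   Private: 10.0.48.0/20, 10.0.64.0/20, 10.0.80.0/20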
43 | subnet_configuration=[ 44 | { 45 | "cidrMask": 20, 46 | "name": "Public", 47 | "subnetType": aws_ec2.SubnetType.PUBLIC, 48 | }, 49 | { 50 | "cidrMask": 20, 51 | "name": "Private", 52 | "subnetType": aws_ec2.SubnetType.PRIVATE_WITH_EGRESS 53 | } 54 | ], 55 | gateway_endpoints={ 56 | "S3": aws_ec2.GatewayVpcEndpointOptions( 57 | service=aws_ec2.GatewayVpcEndpointAwsService.S3 58 | ) 59 | } 60 | ) 61 | 62 | 63 | cdk.CfnOutput(self, 'VPCID', value=self.vpc.vpc_id, 64 | export_name=f'{self.stack_name}-VPCID') 65 | -------------------------------------------------------------------------------- /cdk_stacks/requirements.txt: -------------------------------------------------------------------------------- 1 | aws-cdk-lib==2.171.1 2 | constructs>=10.0.0,<11.0.0 3 | cdklabs.generative-ai-cdk-constructs==0.1.286 4 | -------------------------------------------------------------------------------- /cdk_stacks/source.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | rem The sole purpose of this script is to make the command 4 | rem 5 | rem source .venv/bin/activate 6 | rem 7 | rem (which activates a Python virtualenv on Linux or Mac OS X) work on Windows. 8 | rem On Windows, this command just runs this batch file (the argument is ignored). 9 | rem 10 | rem Now we don't need to document a Windows command for activating a virtualenv. 11 | 12 | echo Executing .venv\Scripts\activate.bat for you 13 | .venv\Scripts\activate.bat 14 | -------------------------------------------------------------------------------- /data_ingestion_to_vectordb/container/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10.13-slim 2 | 3 | # pip leaves the install caches populated which uses a 4 | # significant amount of space. These optimizations save a fair 5 | # amount of space in the image, which reduces start up time. 6 | RUN pip --no-cache-dir install -U pip 7 | RUN pip --no-cache-dir install boto3==1.33.9 \ 8 | langchain==0.1.11 \ 9 | langchain-community==0.0.29 \ 10 | SQLAlchemy==2.0.2 \ 11 | opensearch-py==2.2.0 \ 12 | beautifulsoup4==4.12.3 13 | 14 | # Include python script for retrieving credentials 15 | # from AWS SecretsManager and Sagemaker helper classes 16 | ADD credentials.py /code/ 17 | ADD sm_helper.py /code/ 18 | 19 | # Set some environment variables. PYTHONUNBUFFERED keeps Python from buffering our standard 20 | # output stream, which means that logs can be delivered to the user quickly. PYTHONDONTWRITEBYTECODE 21 | # keeps Python from writing the .pyc files which are unnecessary in this case. We also update 22 | # PATH so that the train and serve programs are found when the container is invoked. 
23 | ENV PYTHONUNBUFFERED=TRUE 24 | ENV PYTHONDONTWRITEBYTECODE=TRUE 25 | -------------------------------------------------------------------------------- /data_ingestion_to_vectordb/container/credentials.py: -------------------------------------------------------------------------------- 1 | """ 2 | Retrieve credentials password for given username from AWS SecretsManager 3 | """ 4 | import json 5 | import boto3 6 | 7 | def get_credentials(secret_id: str, region_name: str) -> str: 8 | 9 | client = boto3.client('secretsmanager', region_name=region_name) 10 | response = client.get_secret_value(SecretId=secret_id) 11 | secrets_value = json.loads(response['SecretString']) 12 | 13 | return secrets_value -------------------------------------------------------------------------------- /data_ingestion_to_vectordb/container/load_data_into_opensearch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # this is needed because the credentials.py and sm_helper.py 5 | # are in /code directory of the custom container we are going 6 | # to create for Sagemaker Processing Job 7 | sys.path.insert(1, '/code') 8 | 9 | import glob 10 | import time 11 | import logging 12 | import argparse 13 | import multiprocessing as mp 14 | from itertools import repeat 15 | from functools import partial 16 | from typing import ( 17 | List, 18 | Tuple 19 | ) 20 | 21 | import numpy as np 22 | 23 | from langchain_community.document_loaders import ReadTheDocsLoader 24 | from langchain_community.vectorstores import OpenSearchVectorSearch 25 | from langchain.text_splitter import RecursiveCharacterTextSplitter 26 | 27 | from opensearchpy import ( 28 | OpenSearch, 29 | RequestsHttpConnection, 30 | AWSV4SignerAuth 31 | ) 32 | 33 | from credentials import get_credentials 34 | from sm_helper import create_sagemaker_embeddings_from_js_model 35 | 36 | 37 | # global constants 38 | MAX_OS_DOCS_PER_PUT = 500 39 | TOTAL_INDEX_CREATION_WAIT_TIME = 60 40 | PER_ITER_SLEEP_TIME = 5 41 | 42 | logger = logging.getLogger() 43 | logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr) 44 | 45 | 46 | def check_if_index_exists(index_name: str, region: str, host: str, http_auth: Tuple[str, str]) -> OpenSearch: 47 | #update the region if you're working other than us-east-1 48 | 49 | aos_client = OpenSearch( 50 | hosts = [{'host': host.rstrip('/').replace("https://", ""), 'port': 443}], 51 | http_auth = http_auth, 52 | use_ssl = True, 53 | verify_certs = True, 54 | connection_class = RequestsHttpConnection 55 | ) 56 | exists = aos_client.indices.exists(index_name) 57 | logger.info(f"index_name={index_name}, exists={exists}") 58 | return exists 59 | 60 | 61 | def process_shard(shard, embeddings_model_endpoint_name, aws_region, os_index_name, os_domain_ep, os_http_auth) -> int: 62 | logger.info(f'Starting process_shard of {len(shard)} chunks.') 63 | st = time.time() 64 | embeddings = create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name, aws_region) 65 | docsearch = OpenSearchVectorSearch(index_name=os_index_name, 66 | embedding_function=embeddings, 67 | opensearch_url=os_domain_ep, 68 | http_auth=os_http_auth) 69 | docsearch.add_documents(documents=shard) 70 | et = time.time() - st 71 | logger.info(f'Shard completed in {et} seconds.') 72 | return 0 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | 78 | parser.add_argument("--opensearch-cluster-domain", 
type=str, default=None) 79 | parser.add_argument("--opensearch-secretid", type=str, default=None) 80 | parser.add_argument("--opensearch-index-name", type=str, default=None) 81 | parser.add_argument("--aws-region", type=str, default="us-east-1") 82 | parser.add_argument("--embeddings-model-endpoint-name", type=str, default=None) 83 | parser.add_argument("--chunk-size-for-doc-split", type=int, default=500) 84 | parser.add_argument("--chunk-overlap-for-doc-split", type=int, default=30) 85 | parser.add_argument("--input-data-dir", type=str, default="/opt/ml/processing/input_data") 86 | parser.add_argument("--process-count", type=int, default=1) 87 | parser.add_argument("--create-index-hint-file", type=str, default="_create_index_hint") 88 | args, _ = parser.parse_known_args() 89 | 90 | logger.info("Received arguments {}".format(args)) 91 | # list all the files 92 | files = glob.glob(os.path.join(args.input_data_dir, "*.*")) 93 | logger.info(f"there are {len(files)} files to process in the {args.input_data_dir} folder") 94 | 95 | # retrieve secret to talk to opensearch 96 | creds = get_credentials(args.opensearch_secretid, args.aws_region) 97 | http_auth = (creds['username'], creds['password']) 98 | 99 | loader = ReadTheDocsLoader(args.input_data_dir) 100 | text_splitter = RecursiveCharacterTextSplitter( 101 | # chunk size and overlap are taken from the command-line arguments above 102 | chunk_size=args.chunk_size_for_doc_split, 103 | chunk_overlap=args.chunk_overlap_for_doc_split, 104 | length_function=len, 105 | ) 106 | 107 | # Stage one: read all the docs, split them into chunks. 108 | st = time.time() 109 | logger.info('Loading documents ...') 110 | docs = loader.load() 111 | 112 | # add a custom metadata field, such as timestamp 113 | for doc in docs: 114 | doc.metadata['timestamp'] = time.time() 115 | doc.metadata['embeddings_model'] = args.embeddings_model_endpoint_name 116 | chunks = text_splitter.create_documents([doc.page_content for doc in docs], metadatas=[doc.metadata for doc in docs]) 117 | et = time.time() - st 118 | logger.info(f'Time taken: {et} seconds. {len(chunks)} chunks generated') 119 | 120 | 121 | db_shards = (len(chunks) // MAX_OS_DOCS_PER_PUT) + 1 122 | logger.info(f'Loading chunks into vector store ... using {db_shards} shards') 123 | st = time.time() 124 | shards = np.array_split(chunks, db_shards) 125 | 126 | t1 = time.time() 127 | 128 | # first check if the index exists, if it does then call the add_documents function 129 | # otherwise call the from_documents function which would first create the index 130 | # and then do a bulk add. 
Both add_documents and from_documents do a bulk add 131 | # but it is important to call from_documents first so that the index is created 132 | # correctly for k-NN 133 | index_exists = check_if_index_exists(args.opensearch_index_name, 134 | args.aws_region, 135 | args.opensearch_cluster_domain, 136 | http_auth) 137 | 138 | embeddings = create_sagemaker_embeddings_from_js_model(args.embeddings_model_endpoint_name, args.aws_region) 139 | 140 | if not index_exists: 141 | # create an index if the create index hint file exists 142 | path = os.path.join(args.input_data_dir, args.create_index_hint_file) 143 | if os.path.isfile(path): 144 | logger.info(f"index {args.opensearch_index_name} does not exist but {path} file is present so will create index") 145 | # by default langchain would create a k-NN index and the embeddings would be ingested as a k-NN vector type 146 | docsearch = OpenSearchVectorSearch.from_documents(index_name=args.opensearch_index_name, 147 | documents=shards[0], 148 | embedding=embeddings, 149 | opensearch_url=args.opensearch_cluster_domain, 150 | http_auth=http_auth) 151 | # we now need to start the loop below from the second shard 152 | shard_start_index = 1 153 | else: 154 | logger.info(f"index {args.opensearch_index_name} does not exist and {path} file is not present, " 155 | f"will wait for some other node to create the index") 156 | shard_start_index = 0 157 | # start a loop to wait for index creation by another node 158 | time_slept = 0 159 | while True: 160 | logger.info(f"index {args.opensearch_index_name} still does not exist, sleeping...") 161 | time.sleep(PER_ITER_SLEEP_TIME) 162 | index_exists = check_if_index_exists(args.opensearch_index_name, 163 | args.aws_region, 164 | args.opensearch_cluster_domain, 165 | http_auth) 166 | if index_exists: 167 | logger.info(f"index {args.opensearch_index_name} now exists") 168 | break 169 | time_slept += PER_ITER_SLEEP_TIME 170 | if time_slept >= TOTAL_INDEX_CREATION_WAIT_TIME: 171 | logger.error(f"time_slept={time_slept} >= {TOTAL_INDEX_CREATION_WAIT_TIME}, not waiting anymore for index creation") 172 | break 173 | 174 | else: 175 | logger.info(f"index={args.opensearch_index_name} already exists, going to call add_documents") 176 | shard_start_index = 0 177 | 178 | with mp.Pool(processes = args.process_count) as pool: 179 | results = pool.map(partial(process_shard, 180 | embeddings_model_endpoint_name=args.embeddings_model_endpoint_name, 181 | aws_region=args.aws_region, 182 | os_index_name=args.opensearch_index_name, 183 | os_domain_ep=args.opensearch_cluster_domain, 184 | os_http_auth=http_auth), 185 | shards[shard_start_index:]) 186 | 187 | t2 = time.time() 188 | logger.info(f'run time in seconds: {t2-t1:.2f}') 189 | logger.info("all done") 190 | 
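191 |  192 | # Example invocation (the values below are illustrative only; the real endpoint name, 193 | # secret id and domain come from the CDK stack outputs): 194 | #   python load_data_into_opensearch.py \ 195 | #     --opensearch-cluster-domain https://my-opensearch-domain.us-east-1.es.amazonaws.com \ 196 | #     --opensearch-secretid my-opensearch-secret \ 197 | #     --opensearch-index-name llm_rag_embeddings \ 198 | #     --embeddings-model-endpoint-name my-embeddings-endpoint \ 199 | #     --process-count 2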
-------------------------------------------------------------------------------- /data_ingestion_to_vectordb/container/sm_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions for using a SageMaker Endpoint via langchain 3 | """ 4 | import time 5 | import json 6 | import logging 7 | from typing import List 8 | 9 | from langchain_community.embeddings import SagemakerEndpointEmbeddings 10 | from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | # extend the SagemakerEndpointEmbeddings class from langchain to provide a custom embedding function 15 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings): 16 | def embed_documents( 17 | self, texts: List[str], chunk_size: int = 5 18 | ) -> List[List[float]]: 19 | """Compute doc embeddings using a SageMaker Inference Endpoint. 20 | 21 | Args: 22 | texts: The list of texts to embed. 23 | chunk_size: The chunk size defines how many input texts will 24 | be grouped together as one request. If None, will use the 25 | chunk size specified by the class. 26 | 27 | Returns: 28 | List of embeddings, one for each text. 29 | """ 30 | results = [] 31 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size 32 | st = time.time() 33 | for i in range(0, len(texts), _chunk_size): 34 | response = self._embedding_func(texts[i:i + _chunk_size]) 35 | results.extend(response) 36 | time_taken = time.time() - st 37 | logger.info(f"got results for {len(texts)} texts in {time_taken}s, length of embeddings list is {len(results)}") 38 | return results 39 | 40 | 41 | # class for serializing/deserializing requests/responses to/from the embeddings model 42 | class ContentHandler(EmbeddingsContentHandler): 43 | content_type = "application/json" 44 | accepts = "application/json" 45 | 46 | def transform_input(self, prompt: List[str], model_kwargs={}) -> bytes: 47 | 48 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) 49 | return input_str.encode('utf-8') 50 | 51 | def transform_output(self, output: bytes) -> List[List[float]]: 52 | 53 | response_json = json.loads(output.read().decode("utf-8")) 54 | embeddings = response_json["embedding"] 55 | if len(embeddings) == 1: 56 | return [embeddings[0]] 57 | return embeddings 58 | 59 | 60 | def create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name: str, aws_region: str) -> SagemakerEndpointEmbeddingsJumpStart: 61 | # all set to create the objects for the ContentHandler and 62 | # SagemakerEndpointEmbeddingsJumpStart classes 63 | content_handler = ContentHandler() 64 | 65 | # note the name of the SageMaker endpoint; this is the embeddings model that we will 66 | # be using for generating the embeddings 67 | embeddings = SagemakerEndpointEmbeddingsJumpStart( 68 | endpoint_name=embeddings_model_endpoint_name, 69 | region_name=aws_region, 70 | content_handler=content_handler 71 | ) 72 | return embeddings 73 | 74 | -------------------------------------------------------------------------------- /data_ingestion_to_vectordb/data_ingestion_to_opensearch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "93179240-9c5f-4ba6-a1c7-3a981624f794", 6 | "metadata": {}, 7 | "source": [ 8 | "# Ingest massive amounts of data to a Vector DB (Amazon OpenSearch)\n", 9 | "**_Use of Amazon OpenSearch as a vector database for storing embeddings_**\n", 10 | "\n", 11 | "This notebook works well on an `ml.t3.xlarge` instance with the `Python3` kernel from **JupyterLab** or the `Data Science 3.0` kernel from **SageMaker Studio Classic**.\n", 12 | "\n", 13 | "Here is a list of packages that are used in this notebook.\n", 14 | "\n", 15 | "```\n", 16 | "!pip list | grep -E -w \"sagemaker|ipywidgets|langchain|opensearch-py|faiss-cpu|numpy|sh|SQLAlchemy\"\n", 17 | "---------------------------------------------------------------------------------------------------\n", 18 | "faiss-cpu 1.7.4\n", 19 | "ipywidgets 7.6.5\n", 20 | "langchain 0.1.11\n", 21 | "langchain-community 0.0.29\n", 22 | "langchain-core 0.1.23\n", 23 | "numpy 1.26.2\n", 24 | "opensearch-py 2.2.0\n", 25 | "sagemaker 2.199.0\n", 26 | 
"sagemaker-studio-image-build 0.6.0\n", 27 | "sh 2.0.4\n", 28 | "SQLAlchemy 2.0.2\n", 29 | "```" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "79aae52c-cd7a-4637-a07d-9c0131dc7d0a", 35 | "metadata": {}, 36 | "source": [ 37 | "## Step 1: Setup\n", 38 | "Install the required packages." 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "87e64f84-b7ac-427d-b5a8-cf98b430be9b", 45 | "metadata": { 46 | "tags": [] 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "%%capture --no-stderr\n", 51 | "\n", 52 | "!pip install -U pip\n", 53 | "!pip install -U langchain==0.1.11\n", 54 | "!pip install -U langchain-community==0.0.29\n", 55 | "!pip install -U SQLAlchemy==2.0.2\n", 56 | "!pip install -U opensearch-py==2.2.0\n", 57 | "!pip install -U faiss_cpu==1.7.4\n", 58 | "!pip install -U sh==2.0.4\n", 59 | "!pip install -U sagemaker-studio-image-build==0.6.0" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "d88757ba-7ae1-4efb-9c02-ab17ec22e79a", 66 | "metadata": { 67 | "tags": [] 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "!pip list | grep -E -w \"sagemaker|ipywidgets|langchain|opensearch-py|faiss-cpu|numpy|sh|SQLAlchemy\"" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "c017bc3f-e507-4f0c-b640-ea774c5ea9c8", 77 | "metadata": {}, 78 | "source": [ 79 | "## Step 2: Download the data from the web and upload to S3\n", 80 | "\n", 81 | "In this step we use `wget` to crawl a Python documentation style website data. All files other than `html`, `txt` and `md` are removed. **This data download would take a few minutes**." 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "id": "5c2b8c14-0ffc-4090-adf1-c2a8a1bdebaa", 88 | "metadata": { 89 | "tags": [] 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "WEBSITE = \"https://sagemaker.readthedocs.io/en/stable/\"\n", 94 | "DOMAIN = \"sagemaker.readthedocs.io\"\n", 95 | "DATA_DIR = \"docs\"" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "0eb232ee-6b62-4718-9104-345fe7978703", 102 | "metadata": { 103 | "tags": [] 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "!python ./scripts/get_data.py --website {WEBSITE} --domain {DOMAIN} --output-dir {DATA_DIR}" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "8ee1fbb8-583a-4c41-a831-715e4250ff3c", 114 | "metadata": { 115 | "tags": [] 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "import boto3\n", 120 | "import sagemaker\n", 121 | "\n", 122 | "sagemaker_session = sagemaker.session.Session()\n", 123 | "aws_region = boto3.Session().region_name\n", 124 | "bucket = sagemaker_session.default_bucket()" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "6c127969-4abc-4a31-8829-c00bee321a95", 131 | "metadata": { 132 | "tags": [] 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "CREATE_OS_INDEX_HINT_FILE = \"_create_index_hint\"\n", 137 | "app_name = 'llm-app-rag'" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "25217f27-4995-4da5-8fc4-b1b9533185b5", 144 | "metadata": { 145 | "tags": [] 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "# create a dummy file called _create_index to provide a hint for opensearch index creation\n", 150 | "# this is needed for Sagemaker Processing Job when there are multiple instance nodes\n", 151 | "# all running the same code for data ingestion but only 
one node needs to create the index\n", 152 | "!touch {DATA_DIR}/{CREATE_OS_INDEX_HINT_FILE}\n", 153 | "\n", 154 | "# upload this data to S3, to be used when we run the Sagemaker Processing Job\n", 155 | "!aws s3 cp --recursive {DATA_DIR}/ s3://{bucket}/{app_name}/{DOMAIN}" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "af7dc134", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "from typing import Dict\n", 166 | "\n", 167 | "\n", 168 | "def get_cfn_outputs(stackname: str, region_name: str) -> Dict:\n", 169 | " cfn = boto3.client('cloudformation', region_name=region_name)\n", 170 | " outputs = {}\n", 171 | " for output in cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Outputs']:\n", 172 | " outputs[output['OutputKey']] = output['OutputValue']\n", 173 | " return outputs" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "b616e2c8", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "CFN_STACK_NAME = 'EmbeddingEndpointStack'\n", 184 | "\n", 185 | "cfn_stack_outputs = get_cfn_outputs(CFN_STACK_NAME, aws_region)\n", 186 | "embeddings_model_endpoint_name = cfn_stack_outputs['EmbeddingEndpointName']" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "cbef08c3", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "CFN_STACK_NAME = \"RAGOpenSearchStack\"\n", 197 | "cfn_stack_outputs = get_cfn_outputs(CFN_STACK_NAME, aws_region)\n", 198 | "\n", 199 | "opensearch_domain_endpoint = cfn_stack_outputs['OpenSearchDomainEndpoint']\n", 200 | "opensearch_secretid = cfn_stack_outputs['OpenSearchSecret']\n", 201 | "opensearch_index = 'llm_rag_embeddings'" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "fa041f3f", 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "CHUNK_SIZE_FOR_DOC_SPLIT = 600\n", 212 | "CHUNK_OVERLAP_FOR_DOC_SPLIT = 20" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "id": "15667be7-43b5-4954-95e0-885c9173f82c", 218 | "metadata": { 219 | "tags": [] 220 | }, 221 | "source": [ 222 | "## Step 3: Load data into OpenSearch\n", 223 | "\n", 224 | "- Option 1) Parallel loading data with SageMaker Processing Job\n", 225 | "- Option 2) Sequential loading data with Document Loader" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "id": "406c6dd1", 231 | "metadata": {}, 232 | "source": [ 233 | "### Option 1) Parallel loading data with SageMaker Processing Job\n", 234 | "\n", 235 | "We now have a working script that can ingest data into an OpenSearch index. But for this to work for massive amounts of data we need to scale up the processing by running this code in a distributed fashion. We will do this using a SageMaker Processing Job. This involves the following steps:\n", 236 | "\n", 237 | "1. Create a custom container in which we will install the `langchain` and `opensearch-py` packages and then upload this container image to Amazon Elastic Container Registry (ECR).\n", 238 | "2. 
Use the Sagemaker `ScriptProcessor` class to create a Sagemaker Processing job that will run on multiple nodes.\n", 239 | " - The data files available in S3 are automatically distributed across the Sagemaker Processing Job instances by setting `s3_data_distribution_type='ShardedByS3Key'` as part of the `ProcessingInput` provided to the processing job.\n", 240 | " - Each node processes a subset of the files and this brings down the overall time required to ingest the data into OpenSearch.\n", 241 | " - Each node also uses Python `multiprocessing` to parallelize the file processing internally. Thus, **there are two levels of parallelization happening, one at the cluster level where individual nodes are distributing the work (files) amongst themselves and another at the node level where the files in a node are also split between multiple processes running on the node**." 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "id": "b45d2938-994f-432d-8c50-92269f22f4b8", 247 | "metadata": {}, 248 | "source": [ 249 | "### Create custom container\n", 250 | "\n", 251 | "We will now create a container locally and push the container image to ECR. **The container creation process takes about 1 minute**.\n", 252 | "\n", 253 | "1. The container includes all the Python packages we need, i.e. `langchain`, `opensearch-py`, `sagemaker` and `beautifulsoup4`.\n", 254 | "1. The container also includes the `credentials.py` script for retrieving credentials from Secrets Manager and `sm_helper.py` for creating the SageMaker endpoint classes that langchain uses." 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "id": "01848619-8b49-48da-8cbf-c9cbbd8d1e40", 261 | "metadata": { 262 | "tags": [] 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "DOCKER_IMAGE = \"load-data-opensearch-custom\"\n", 267 | "DOCKER_IMAGE_TAG = \"latest\"" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "id": "97bd70a7-8f61-477c-a8ef-82c0b5cd7821", 274 | "metadata": { 275 | "tags": [] 276 | }, 277 | "outputs": [], 278 | "source": [ 279 | "!cd ./container && sm-docker build . --repository {DOCKER_IMAGE}:{DOCKER_IMAGE_TAG}" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "id": "1089064e-2099-4226-82c1-c7c406203c49", 285 | "metadata": {}, 286 | "source": [ 287 | "### Create and run the Sagemaker Processing Job\n", 288 | "\n", 289 | "Now we will run the Sagemaker Processing Job to ingest the data into OpenSearch." 
290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "afb7373c-e80f-4d1a-a8dc-0dc79fb28e8a", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "import sys\n", 300 | "import time\n", 301 | "import logging\n", 302 | "\n", 303 | "from sagemaker.processing import (\n", 304 | "    ProcessingInput,\n", 305 | "    ScriptProcessor,\n", 306 | ")\n", 307 | "\n", 308 | "logger = logging.getLogger()\n", 309 | "logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)\n", 310 | "\n", 311 | "\n", 312 | "# set up the parameters for the job\n", 313 | "base_job_name = f\"{app_name}-job\"\n", 314 | "tags = [{\"Key\": \"data\", \"Value\": \"embeddings-for-llm-apps\"}]\n", 315 | "\n", 316 | "account_id = boto3.client(\"sts\").get_caller_identity()[\"Account\"]\n", 317 | "aws_role = sagemaker_session.get_caller_identity_arn()\n", 318 | "\n", 319 | "# use the custom container we just created\n", 320 | "image_uri = f\"{account_id}.dkr.ecr.{aws_region}.amazonaws.com/{DOCKER_IMAGE}:{DOCKER_IMAGE_TAG}\"\n", 321 | "\n", 322 | "# instance type and count determined via trial and error: how much overall processing time\n", 323 | "# and what compute cost works best for your use case\n", 324 | "instance_type = \"ml.m5.xlarge\"\n", 325 | "instance_count = 3\n", 326 | "logger.info(f\"base_job_name={base_job_name}, tags={tags}, image_uri={image_uri}, instance_type={instance_type}, instance_count={instance_count}\")\n", 327 | "\n", 328 | "# set up the ScriptProcessor with the above parameters\n", 329 | "processor = ScriptProcessor(base_job_name=base_job_name,\n", 330 | "                            image_uri=image_uri,\n", 331 | "                            role=aws_role,\n", 332 | "                            instance_type=instance_type,\n", 333 | "                            instance_count=instance_count,\n", 334 | "                            command=[\"python3\"],\n", 335 | "                            tags=tags)\n", 336 | "\n", 337 | "# set up the input from S3; note the ShardedByS3Key setting: it ensures that\n", 338 | "# each instance gets a random and roughly equal subset of the files in S3.\n", 339 | "inputs = [ProcessingInput(source=f\"s3://{bucket}/{app_name}/{DOMAIN}\",\n", 340 | "                          destination='/opt/ml/processing/input_data',\n", 341 | "                          s3_data_distribution_type='ShardedByS3Key',\n", 342 | "                          s3_data_type='S3Prefix')]\n", 343 | "\n", 344 | "\n", 345 | "logger.info(f\"creating an opensearch index with name={opensearch_index}\")\n", 346 | "\n", 347 | "# ready to run the processing job\n", 348 | "st = time.time()\n", 349 | "processor.run(code=\"container/load_data_into_opensearch.py\",\n", 350 | "              inputs=inputs,\n", 351 | "              outputs=[],\n", 352 | "              arguments=[\"--opensearch-cluster-domain\", opensearch_domain_endpoint,\n", 353 | "                         \"--opensearch-secretid\", opensearch_secretid,\n", 354 | "                         \"--opensearch-index-name\", opensearch_index,\n", 355 | "                         \"--aws-region\", aws_region,\n", 356 | "                         \"--embeddings-model-endpoint-name\", embeddings_model_endpoint_name,\n", 357 | "                         \"--chunk-size-for-doc-split\", str(CHUNK_SIZE_FOR_DOC_SPLIT),\n", 358 | "                         \"--chunk-overlap-for-doc-split\", str(CHUNK_OVERLAP_FOR_DOC_SPLIT),\n", 359 | "                         \"--input-data-dir\", \"/opt/ml/processing/input_data\",\n", 360 | "                         \"--create-index-hint-file\", CREATE_OS_INDEX_HINT_FILE,\n", 361 | "                         \"--process-count\", \"2\"])\n", 362 | "\n", 363 | "time_taken = time.time() - st\n", 364 | "logger.info(f\"processing job completed, total time taken={time_taken}s\")\n", 365 | "\n", 366 | "preprocessing_job_description = processor.jobs[-1].describe()\n", 367 | "logger.info(preprocessing_job_description)" 368 | ] 369 
| }, 370 | { 371 | "cell_type": "markdown", 372 | "id": "bc4657c0", 373 | "metadata": {}, 374 | "source": [ 375 | "### Option 2) Sequential loading data with Document Loader" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "id": "c9ef2a96", 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [ 385 | "%%capture --no-stderr\n", 386 | "\n", 387 | "!pip install -Uq beautifulsoup4==4.12.3" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "id": "00b3c8e3", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "%%time\n", 398 | "\n", 399 | "from langchain_community.document_loaders import ReadTheDocsLoader\n", 400 | "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", 401 | "import time\n", 402 | "\n", 403 | "\n", 404 | "loader = ReadTheDocsLoader(DATA_DIR)\n", 405 | "text_splitter = RecursiveCharacterTextSplitter(\n", 406 | " chunk_size=CHUNK_SIZE_FOR_DOC_SPLIT,\n", 407 | " chunk_overlap=CHUNK_OVERLAP_FOR_DOC_SPLIT,\n", 408 | " length_function=len,\n", 409 | ")\n", 410 | "\n", 411 | "\n", 412 | "docs = loader.load()\n", 413 | "\n", 414 | "# add a custom metadata field, such as timestamp\n", 415 | "for doc in docs:\n", 416 | " doc.metadata['timestamp'] = time.time()\n", 417 | " doc.metadata['embeddings_model'] = embeddings_model_endpoint_name" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "id": "deb8e52e", 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "chunks = text_splitter.create_documents(\n", 428 | " [doc.page_content for doc in docs],\n", 429 | " metadatas=[doc.metadata for doc in docs]\n", 430 | ")" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "id": "a88606e0", 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "from container.sm_helper import create_sagemaker_embeddings_from_js_model\n", 441 | "\n", 442 | "\n", 443 | "embeddings = create_sagemaker_embeddings_from_js_model(\n", 444 | " embeddings_model_endpoint_name,\n", 445 | " aws_region\n", 446 | ")" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "id": "aa92edfd", 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "import numpy as np\n", 457 | "\n", 458 | "\n", 459 | "MAX_OS_DOCS_PER_PUT = 500\n", 460 | "\n", 461 | "db_shards = (len(chunks) // MAX_OS_DOCS_PER_PUT) + 1\n", 462 | "shards = np.array_split(chunks, db_shards)\n", 463 | "\n", 464 | "print(f'Loading chunks into vector store ... 
using {len(shards)} shards')" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": null, 470 | "id": "ed247419", 471 | "metadata": {}, 472 | "outputs": [], 473 | "source": [ 474 | "from container.credentials import get_credentials\n", 475 | "\n", 476 | "\n", 477 | "creds = get_credentials(opensearch_secretid, aws_region)\n", 478 | "http_auth = (creds['username'], creds['password'])\n", 479 | "http_auth" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": null, 485 | "id": "316802fd", 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [ 489 | "from langchain_community.vectorstores import OpenSearchVectorSearch\n", 490 | "\n", 491 | "\n", 492 | "docsearch = OpenSearchVectorSearch.from_documents(\n", 493 | "    index_name=opensearch_index,\n", 494 | "    documents=shards[0],\n", 495 | "    embedding=embeddings,\n", 496 | "    opensearch_url=opensearch_domain_endpoint,\n", 497 | "    http_auth=http_auth\n", 498 | ")" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "id": "ac1e0ea0", 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [ 508 | "%%time\n", 509 | "\n", 510 | "for i, shard in enumerate(shards[1:]):\n", 511 | "    docsearch.add_documents(documents=shard)\n", 512 | "    print(f\"[{i+1}] shard is added.\")" 513 | ] 514 | }, 515 | { 516 | "cell_type": "markdown", 517 | "id": "1e444161-262e-44e5-ad31-e490a763be4e", 518 | "metadata": {}, 519 | "source": [ 520 | "## Step 4: Do a similarity search of the user input against the documents (embeddings) in OpenSearch" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "id": "294a7292-8bdb-4d11-a23d-130a4a039cd2", 527 | "metadata": { 528 | "tags": [] 529 | }, 530 | "outputs": [], 531 | "source": [ 532 | "from container.credentials import get_credentials\n", 533 | "from langchain_community.vectorstores import OpenSearchVectorSearch\n", 534 | "from container.sm_helper import create_sagemaker_embeddings_from_js_model\n", 535 | "\n", 536 | "creds = get_credentials(opensearch_secretid, aws_region)\n", 537 | "http_auth = (creds['username'], creds['password'])\n", 538 | "\n", 539 | "docsearch = OpenSearchVectorSearch(index_name=opensearch_index,\n", 540 | "                                   embedding_function=create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name,\n", 541 | "                                                                                                aws_region),\n", 542 | "                                   opensearch_url=opensearch_domain_endpoint,\n", 543 | "                                   http_auth=http_auth)\n", 544 | "\n", 545 | "q = \"Which XGBoost versions does SageMaker support?\"\n", 546 | "docs = docsearch.similarity_search(q, k=3) #, search_type=\"script_scoring\", space_type=\"cosinesimil\"\n", 547 | "for doc in docs:\n", 548 | "    logger.info(\"----------\")\n", 549 | "    logger.info(f\"content=\\\"{doc.page_content}\\\",\\nmetadata=\\\"{doc.metadata}\\\"\")" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "id": "6e29eae5-c463-4153-9167-e4628c74d13c", 555 | "metadata": { 556 | "tags": [] 557 | }, 558 | "source": [ 559 | "## Cleanup\n", 560 | "\n", 561 | "To avoid incurring future charges, delete the resources. You can do this by deleting the CloudFormation stacks used to create the IAM role and SageMaker notebook." 
562 | ] 563 | }, 564 | { 565 | "cell_type": "markdown", 566 | "id": "59ce3fe8-bb71-4e22-a551-2475eb2d16b7", 567 | "metadata": {}, 568 | "source": [ 569 | "---\n", 570 | "\n", 571 | "## Conclusion\n", 572 | "In this notebook we saw how to use an embeddings model deployed on a SageMaker endpoint to generate embeddings, ingest those embeddings into OpenSearch, and finally run a similarity search for user input against the documents (embeddings) stored in OpenSearch. We used langchain as an abstraction layer to talk to both the SageMaker endpoint and OpenSearch." 573 | ] 574 | }, 575 | { 576 | "cell_type": "markdown", 577 | "id": "53386268-0cf9-4a37-b3d0-711fba1e5585", 578 | "metadata": {}, 579 | "source": [ 580 | "---\n", 581 | "\n", 582 | "## Appendix" 583 | ] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": null, 588 | "id": "7332274c-586d-4cf9-838f-ea7e0cbe6c0f", 589 | "metadata": { 590 | "tags": [] 591 | }, 592 | "outputs": [], 593 | "source": [ 594 | "import numpy as np\n", 595 | "from container.sm_helper import create_sagemaker_embeddings_from_js_model\n", 596 | "\n", 597 | "CFN_STACK_NAME = 'EmbeddingEndpointStack'\n", 598 | "cfn_stack_outputs = get_cfn_outputs(CFN_STACK_NAME, aws_region)\n", 599 | "embeddings_model_endpoint_name = cfn_stack_outputs['EmbeddingEndpointName']\n", 600 | "\n", 601 | "embeddings = create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name, aws_region)\n", 602 | "\n", 603 | "text = \"This is a sample query.\"\n", 604 | "query_result = embeddings.embed_query(text)\n", 605 | "\n", 606 | "print(np.array(query_result))\n", 607 | "print(f\"length: {len(query_result)}\")" 608 | ] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "id": "dd881bab", 613 | "metadata": {}, 614 | "source": [ 615 | "## References\n", 616 | "\n", 617 | " * [Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain](https://aws.amazon.com/blogs/machine-learning/build-a-powerful-question-answering-bot-with-amazon-sagemaker-amazon-opensearch-service-streamlit-and-langchain/)\n", 618 | " * [Using the Amazon SageMaker Studio Image Build CLI to build container images from your Studio notebooks](https://aws.amazon.com/blogs/machine-learning/using-the-amazon-sagemaker-studio-image-build-cli-to-build-container-images-from-your-studio-notebooks/)\n", 619 | " * [LangChain](https://python.langchain.com/docs/get_started/introduction.html) - A framework for developing applications powered by language models." 
620 | ] 621 | } 622 | ], 623 | "metadata": { 624 | "availableInstances": [ 625 | { 626 | "_defaultOrder": 0, 627 | "_isFastLaunch": true, 628 | "category": "General purpose", 629 | "gpuNum": 0, 630 | "hideHardwareSpecs": false, 631 | "memoryGiB": 4, 632 | "name": "ml.t3.medium", 633 | "vcpuNum": 2 634 | }, 635 | { 636 | "_defaultOrder": 1, 637 | "_isFastLaunch": false, 638 | "category": "General purpose", 639 | "gpuNum": 0, 640 | "hideHardwareSpecs": false, 641 | "memoryGiB": 8, 642 | "name": "ml.t3.large", 643 | "vcpuNum": 2 644 | }, 645 | { 646 | "_defaultOrder": 2, 647 | "_isFastLaunch": false, 648 | "category": "General purpose", 649 | "gpuNum": 0, 650 | "hideHardwareSpecs": false, 651 | "memoryGiB": 16, 652 | "name": "ml.t3.xlarge", 653 | "vcpuNum": 4 654 | }, 655 | { 656 | "_defaultOrder": 3, 657 | "_isFastLaunch": false, 658 | "category": "General purpose", 659 | "gpuNum": 0, 660 | "hideHardwareSpecs": false, 661 | "memoryGiB": 32, 662 | "name": "ml.t3.2xlarge", 663 | "vcpuNum": 8 664 | }, 665 | { 666 | "_defaultOrder": 4, 667 | "_isFastLaunch": true, 668 | "category": "General purpose", 669 | "gpuNum": 0, 670 | "hideHardwareSpecs": false, 671 | "memoryGiB": 8, 672 | "name": "ml.m5.large", 673 | "vcpuNum": 2 674 | }, 675 | { 676 | "_defaultOrder": 5, 677 | "_isFastLaunch": false, 678 | "category": "General purpose", 679 | "gpuNum": 0, 680 | "hideHardwareSpecs": false, 681 | "memoryGiB": 16, 682 | "name": "ml.m5.xlarge", 683 | "vcpuNum": 4 684 | }, 685 | { 686 | "_defaultOrder": 6, 687 | "_isFastLaunch": false, 688 | "category": "General purpose", 689 | "gpuNum": 0, 690 | "hideHardwareSpecs": false, 691 | "memoryGiB": 32, 692 | "name": "ml.m5.2xlarge", 693 | "vcpuNum": 8 694 | }, 695 | { 696 | "_defaultOrder": 7, 697 | "_isFastLaunch": false, 698 | "category": "General purpose", 699 | "gpuNum": 0, 700 | "hideHardwareSpecs": false, 701 | "memoryGiB": 64, 702 | "name": "ml.m5.4xlarge", 703 | "vcpuNum": 16 704 | }, 705 | { 706 | "_defaultOrder": 8, 707 | "_isFastLaunch": false, 708 | "category": "General purpose", 709 | "gpuNum": 0, 710 | "hideHardwareSpecs": false, 711 | "memoryGiB": 128, 712 | "name": "ml.m5.8xlarge", 713 | "vcpuNum": 32 714 | }, 715 | { 716 | "_defaultOrder": 9, 717 | "_isFastLaunch": false, 718 | "category": "General purpose", 719 | "gpuNum": 0, 720 | "hideHardwareSpecs": false, 721 | "memoryGiB": 192, 722 | "name": "ml.m5.12xlarge", 723 | "vcpuNum": 48 724 | }, 725 | { 726 | "_defaultOrder": 10, 727 | "_isFastLaunch": false, 728 | "category": "General purpose", 729 | "gpuNum": 0, 730 | "hideHardwareSpecs": false, 731 | "memoryGiB": 256, 732 | "name": "ml.m5.16xlarge", 733 | "vcpuNum": 64 734 | }, 735 | { 736 | "_defaultOrder": 11, 737 | "_isFastLaunch": false, 738 | "category": "General purpose", 739 | "gpuNum": 0, 740 | "hideHardwareSpecs": false, 741 | "memoryGiB": 384, 742 | "name": "ml.m5.24xlarge", 743 | "vcpuNum": 96 744 | }, 745 | { 746 | "_defaultOrder": 12, 747 | "_isFastLaunch": false, 748 | "category": "General purpose", 749 | "gpuNum": 0, 750 | "hideHardwareSpecs": false, 751 | "memoryGiB": 8, 752 | "name": "ml.m5d.large", 753 | "vcpuNum": 2 754 | }, 755 | { 756 | "_defaultOrder": 13, 757 | "_isFastLaunch": false, 758 | "category": "General purpose", 759 | "gpuNum": 0, 760 | "hideHardwareSpecs": false, 761 | "memoryGiB": 16, 762 | "name": "ml.m5d.xlarge", 763 | "vcpuNum": 4 764 | }, 765 | { 766 | "_defaultOrder": 14, 767 | "_isFastLaunch": false, 768 | "category": "General purpose", 769 | "gpuNum": 0, 770 | "hideHardwareSpecs": false, 771 | "memoryGiB": 
32, 772 | "name": "ml.m5d.2xlarge", 773 | "vcpuNum": 8 774 | }, 775 | { 776 | "_defaultOrder": 15, 777 | "_isFastLaunch": false, 778 | "category": "General purpose", 779 | "gpuNum": 0, 780 | "hideHardwareSpecs": false, 781 | "memoryGiB": 64, 782 | "name": "ml.m5d.4xlarge", 783 | "vcpuNum": 16 784 | }, 785 | { 786 | "_defaultOrder": 16, 787 | "_isFastLaunch": false, 788 | "category": "General purpose", 789 | "gpuNum": 0, 790 | "hideHardwareSpecs": false, 791 | "memoryGiB": 128, 792 | "name": "ml.m5d.8xlarge", 793 | "vcpuNum": 32 794 | }, 795 | { 796 | "_defaultOrder": 17, 797 | "_isFastLaunch": false, 798 | "category": "General purpose", 799 | "gpuNum": 0, 800 | "hideHardwareSpecs": false, 801 | "memoryGiB": 192, 802 | "name": "ml.m5d.12xlarge", 803 | "vcpuNum": 48 804 | }, 805 | { 806 | "_defaultOrder": 18, 807 | "_isFastLaunch": false, 808 | "category": "General purpose", 809 | "gpuNum": 0, 810 | "hideHardwareSpecs": false, 811 | "memoryGiB": 256, 812 | "name": "ml.m5d.16xlarge", 813 | "vcpuNum": 64 814 | }, 815 | { 816 | "_defaultOrder": 19, 817 | "_isFastLaunch": false, 818 | "category": "General purpose", 819 | "gpuNum": 0, 820 | "hideHardwareSpecs": false, 821 | "memoryGiB": 384, 822 | "name": "ml.m5d.24xlarge", 823 | "vcpuNum": 96 824 | }, 825 | { 826 | "_defaultOrder": 20, 827 | "_isFastLaunch": false, 828 | "category": "General purpose", 829 | "gpuNum": 0, 830 | "hideHardwareSpecs": true, 831 | "memoryGiB": 0, 832 | "name": "ml.geospatial.interactive", 833 | "supportedImageNames": [ 834 | "sagemaker-geospatial-v1-0" 835 | ], 836 | "vcpuNum": 0 837 | }, 838 | { 839 | "_defaultOrder": 21, 840 | "_isFastLaunch": true, 841 | "category": "Compute optimized", 842 | "gpuNum": 0, 843 | "hideHardwareSpecs": false, 844 | "memoryGiB": 4, 845 | "name": "ml.c5.large", 846 | "vcpuNum": 2 847 | }, 848 | { 849 | "_defaultOrder": 22, 850 | "_isFastLaunch": false, 851 | "category": "Compute optimized", 852 | "gpuNum": 0, 853 | "hideHardwareSpecs": false, 854 | "memoryGiB": 8, 855 | "name": "ml.c5.xlarge", 856 | "vcpuNum": 4 857 | }, 858 | { 859 | "_defaultOrder": 23, 860 | "_isFastLaunch": false, 861 | "category": "Compute optimized", 862 | "gpuNum": 0, 863 | "hideHardwareSpecs": false, 864 | "memoryGiB": 16, 865 | "name": "ml.c5.2xlarge", 866 | "vcpuNum": 8 867 | }, 868 | { 869 | "_defaultOrder": 24, 870 | "_isFastLaunch": false, 871 | "category": "Compute optimized", 872 | "gpuNum": 0, 873 | "hideHardwareSpecs": false, 874 | "memoryGiB": 32, 875 | "name": "ml.c5.4xlarge", 876 | "vcpuNum": 16 877 | }, 878 | { 879 | "_defaultOrder": 25, 880 | "_isFastLaunch": false, 881 | "category": "Compute optimized", 882 | "gpuNum": 0, 883 | "hideHardwareSpecs": false, 884 | "memoryGiB": 72, 885 | "name": "ml.c5.9xlarge", 886 | "vcpuNum": 36 887 | }, 888 | { 889 | "_defaultOrder": 26, 890 | "_isFastLaunch": false, 891 | "category": "Compute optimized", 892 | "gpuNum": 0, 893 | "hideHardwareSpecs": false, 894 | "memoryGiB": 96, 895 | "name": "ml.c5.12xlarge", 896 | "vcpuNum": 48 897 | }, 898 | { 899 | "_defaultOrder": 27, 900 | "_isFastLaunch": false, 901 | "category": "Compute optimized", 902 | "gpuNum": 0, 903 | "hideHardwareSpecs": false, 904 | "memoryGiB": 144, 905 | "name": "ml.c5.18xlarge", 906 | "vcpuNum": 72 907 | }, 908 | { 909 | "_defaultOrder": 28, 910 | "_isFastLaunch": false, 911 | "category": "Compute optimized", 912 | "gpuNum": 0, 913 | "hideHardwareSpecs": false, 914 | "memoryGiB": 192, 915 | "name": "ml.c5.24xlarge", 916 | "vcpuNum": 96 917 | }, 918 | { 919 | "_defaultOrder": 29, 920 | 
"_isFastLaunch": true, 921 | "category": "Accelerated computing", 922 | "gpuNum": 1, 923 | "hideHardwareSpecs": false, 924 | "memoryGiB": 16, 925 | "name": "ml.g4dn.xlarge", 926 | "vcpuNum": 4 927 | }, 928 | { 929 | "_defaultOrder": 30, 930 | "_isFastLaunch": false, 931 | "category": "Accelerated computing", 932 | "gpuNum": 1, 933 | "hideHardwareSpecs": false, 934 | "memoryGiB": 32, 935 | "name": "ml.g4dn.2xlarge", 936 | "vcpuNum": 8 937 | }, 938 | { 939 | "_defaultOrder": 31, 940 | "_isFastLaunch": false, 941 | "category": "Accelerated computing", 942 | "gpuNum": 1, 943 | "hideHardwareSpecs": false, 944 | "memoryGiB": 64, 945 | "name": "ml.g4dn.4xlarge", 946 | "vcpuNum": 16 947 | }, 948 | { 949 | "_defaultOrder": 32, 950 | "_isFastLaunch": false, 951 | "category": "Accelerated computing", 952 | "gpuNum": 1, 953 | "hideHardwareSpecs": false, 954 | "memoryGiB": 128, 955 | "name": "ml.g4dn.8xlarge", 956 | "vcpuNum": 32 957 | }, 958 | { 959 | "_defaultOrder": 33, 960 | "_isFastLaunch": false, 961 | "category": "Accelerated computing", 962 | "gpuNum": 4, 963 | "hideHardwareSpecs": false, 964 | "memoryGiB": 192, 965 | "name": "ml.g4dn.12xlarge", 966 | "vcpuNum": 48 967 | }, 968 | { 969 | "_defaultOrder": 34, 970 | "_isFastLaunch": false, 971 | "category": "Accelerated computing", 972 | "gpuNum": 1, 973 | "hideHardwareSpecs": false, 974 | "memoryGiB": 256, 975 | "name": "ml.g4dn.16xlarge", 976 | "vcpuNum": 64 977 | }, 978 | { 979 | "_defaultOrder": 35, 980 | "_isFastLaunch": false, 981 | "category": "Accelerated computing", 982 | "gpuNum": 1, 983 | "hideHardwareSpecs": false, 984 | "memoryGiB": 61, 985 | "name": "ml.p3.2xlarge", 986 | "vcpuNum": 8 987 | }, 988 | { 989 | "_defaultOrder": 36, 990 | "_isFastLaunch": false, 991 | "category": "Accelerated computing", 992 | "gpuNum": 4, 993 | "hideHardwareSpecs": false, 994 | "memoryGiB": 244, 995 | "name": "ml.p3.8xlarge", 996 | "vcpuNum": 32 997 | }, 998 | { 999 | "_defaultOrder": 37, 1000 | "_isFastLaunch": false, 1001 | "category": "Accelerated computing", 1002 | "gpuNum": 8, 1003 | "hideHardwareSpecs": false, 1004 | "memoryGiB": 488, 1005 | "name": "ml.p3.16xlarge", 1006 | "vcpuNum": 64 1007 | }, 1008 | { 1009 | "_defaultOrder": 38, 1010 | "_isFastLaunch": false, 1011 | "category": "Accelerated computing", 1012 | "gpuNum": 8, 1013 | "hideHardwareSpecs": false, 1014 | "memoryGiB": 768, 1015 | "name": "ml.p3dn.24xlarge", 1016 | "vcpuNum": 96 1017 | }, 1018 | { 1019 | "_defaultOrder": 39, 1020 | "_isFastLaunch": false, 1021 | "category": "Memory Optimized", 1022 | "gpuNum": 0, 1023 | "hideHardwareSpecs": false, 1024 | "memoryGiB": 16, 1025 | "name": "ml.r5.large", 1026 | "vcpuNum": 2 1027 | }, 1028 | { 1029 | "_defaultOrder": 40, 1030 | "_isFastLaunch": false, 1031 | "category": "Memory Optimized", 1032 | "gpuNum": 0, 1033 | "hideHardwareSpecs": false, 1034 | "memoryGiB": 32, 1035 | "name": "ml.r5.xlarge", 1036 | "vcpuNum": 4 1037 | }, 1038 | { 1039 | "_defaultOrder": 41, 1040 | "_isFastLaunch": false, 1041 | "category": "Memory Optimized", 1042 | "gpuNum": 0, 1043 | "hideHardwareSpecs": false, 1044 | "memoryGiB": 64, 1045 | "name": "ml.r5.2xlarge", 1046 | "vcpuNum": 8 1047 | }, 1048 | { 1049 | "_defaultOrder": 42, 1050 | "_isFastLaunch": false, 1051 | "category": "Memory Optimized", 1052 | "gpuNum": 0, 1053 | "hideHardwareSpecs": false, 1054 | "memoryGiB": 128, 1055 | "name": "ml.r5.4xlarge", 1056 | "vcpuNum": 16 1057 | }, 1058 | { 1059 | "_defaultOrder": 43, 1060 | "_isFastLaunch": false, 1061 | "category": "Memory Optimized", 1062 | "gpuNum": 0, 
1063 | "hideHardwareSpecs": false, 1064 | "memoryGiB": 256, 1065 | "name": "ml.r5.8xlarge", 1066 | "vcpuNum": 32 1067 | }, 1068 | { 1069 | "_defaultOrder": 44, 1070 | "_isFastLaunch": false, 1071 | "category": "Memory Optimized", 1072 | "gpuNum": 0, 1073 | "hideHardwareSpecs": false, 1074 | "memoryGiB": 384, 1075 | "name": "ml.r5.12xlarge", 1076 | "vcpuNum": 48 1077 | }, 1078 | { 1079 | "_defaultOrder": 45, 1080 | "_isFastLaunch": false, 1081 | "category": "Memory Optimized", 1082 | "gpuNum": 0, 1083 | "hideHardwareSpecs": false, 1084 | "memoryGiB": 512, 1085 | "name": "ml.r5.16xlarge", 1086 | "vcpuNum": 64 1087 | }, 1088 | { 1089 | "_defaultOrder": 46, 1090 | "_isFastLaunch": false, 1091 | "category": "Memory Optimized", 1092 | "gpuNum": 0, 1093 | "hideHardwareSpecs": false, 1094 | "memoryGiB": 768, 1095 | "name": "ml.r5.24xlarge", 1096 | "vcpuNum": 96 1097 | }, 1098 | { 1099 | "_defaultOrder": 47, 1100 | "_isFastLaunch": false, 1101 | "category": "Accelerated computing", 1102 | "gpuNum": 1, 1103 | "hideHardwareSpecs": false, 1104 | "memoryGiB": 16, 1105 | "name": "ml.g5.xlarge", 1106 | "vcpuNum": 4 1107 | }, 1108 | { 1109 | "_defaultOrder": 48, 1110 | "_isFastLaunch": false, 1111 | "category": "Accelerated computing", 1112 | "gpuNum": 1, 1113 | "hideHardwareSpecs": false, 1114 | "memoryGiB": 32, 1115 | "name": "ml.g5.2xlarge", 1116 | "vcpuNum": 8 1117 | }, 1118 | { 1119 | "_defaultOrder": 49, 1120 | "_isFastLaunch": false, 1121 | "category": "Accelerated computing", 1122 | "gpuNum": 1, 1123 | "hideHardwareSpecs": false, 1124 | "memoryGiB": 64, 1125 | "name": "ml.g5.4xlarge", 1126 | "vcpuNum": 16 1127 | }, 1128 | { 1129 | "_defaultOrder": 50, 1130 | "_isFastLaunch": false, 1131 | "category": "Accelerated computing", 1132 | "gpuNum": 1, 1133 | "hideHardwareSpecs": false, 1134 | "memoryGiB": 128, 1135 | "name": "ml.g5.8xlarge", 1136 | "vcpuNum": 32 1137 | }, 1138 | { 1139 | "_defaultOrder": 51, 1140 | "_isFastLaunch": false, 1141 | "category": "Accelerated computing", 1142 | "gpuNum": 1, 1143 | "hideHardwareSpecs": false, 1144 | "memoryGiB": 256, 1145 | "name": "ml.g5.16xlarge", 1146 | "vcpuNum": 64 1147 | }, 1148 | { 1149 | "_defaultOrder": 52, 1150 | "_isFastLaunch": false, 1151 | "category": "Accelerated computing", 1152 | "gpuNum": 4, 1153 | "hideHardwareSpecs": false, 1154 | "memoryGiB": 192, 1155 | "name": "ml.g5.12xlarge", 1156 | "vcpuNum": 48 1157 | }, 1158 | { 1159 | "_defaultOrder": 53, 1160 | "_isFastLaunch": false, 1161 | "category": "Accelerated computing", 1162 | "gpuNum": 4, 1163 | "hideHardwareSpecs": false, 1164 | "memoryGiB": 384, 1165 | "name": "ml.g5.24xlarge", 1166 | "vcpuNum": 96 1167 | }, 1168 | { 1169 | "_defaultOrder": 54, 1170 | "_isFastLaunch": false, 1171 | "category": "Accelerated computing", 1172 | "gpuNum": 8, 1173 | "hideHardwareSpecs": false, 1174 | "memoryGiB": 768, 1175 | "name": "ml.g5.48xlarge", 1176 | "vcpuNum": 192 1177 | }, 1178 | { 1179 | "_defaultOrder": 55, 1180 | "_isFastLaunch": false, 1181 | "category": "Accelerated computing", 1182 | "gpuNum": 8, 1183 | "hideHardwareSpecs": false, 1184 | "memoryGiB": 1152, 1185 | "name": "ml.p4d.24xlarge", 1186 | "vcpuNum": 96 1187 | }, 1188 | { 1189 | "_defaultOrder": 56, 1190 | "_isFastLaunch": false, 1191 | "category": "Accelerated computing", 1192 | "gpuNum": 8, 1193 | "hideHardwareSpecs": false, 1194 | "memoryGiB": 1152, 1195 | "name": "ml.p4de.24xlarge", 1196 | "vcpuNum": 96 1197 | } 1198 | ], 1199 | "instance_type": "ml.t3.medium", 1200 | "kernelspec": { 1201 | "display_name": "Python 3 (Data Science 
2.0)", 1202 | "language": "python", 1203 | "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:12345678012:image/sagemaker-data-science-38" 1204 | }, 1205 | "language_info": { 1206 | "codemirror_mode": { 1207 | "name": "ipython", 1208 | "version": 3 1209 | }, 1210 | "file_extension": ".py", 1211 | "mimetype": "text/x-python", 1212 | "name": "python", 1213 | "nbconvert_exporter": "python", 1214 | "pygments_lexer": "ipython3", 1215 | "version": "3.8.13" 1216 | } 1217 | }, 1218 | "nbformat": 4, 1219 | "nbformat_minor": 5 1220 | } 1221 | -------------------------------------------------------------------------------- /data_ingestion_to_vectordb/scripts/get_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import traceback 4 | 5 | from sh import cp, find, mkdir, wget 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument("--domain", type=str, default="sagemaker.readthedocs.io") 12 | parser.add_argument("--website", type=str, default="https://sagemaker.readthedocs.io/en/stable/") 13 | parser.add_argument("--output-dir", type=str, default="docs") 14 | parser.add_argument("--dryrun", action='store_true') 15 | args, _ = parser.parse_known_args() 16 | 17 | WEBSITE, DOMAIN, KB_DIR = (args.website, args.domain, args.output_dir) 18 | 19 | if args.dryrun: 20 | print(f"WEBSITE={WEBSITE}, DOMAIN={DOMAIN}, OUTPUT_DIR={KB_DIR}", file=sys.stderr) 21 | sys.exit(0) 22 | 23 | mkdir('-p', KB_DIR) 24 | 25 | try: 26 | WGET_ARGUMENTS = f"-e robots=off --recursive --no-clobber --page-requisites --html-extension --convert-links --restrict-file-names=windows --domains {DOMAIN} --no-parent {WEBSITE}" 27 | wget_argument_list = WGET_ARGUMENTS.split() 28 | wget(*wget_argument_list) 29 | except Exception as ex: 30 | traceback.print_exc() 31 | 32 | results = find(DOMAIN, '-name', '*.html') 33 | html_files = results.strip('\n').split('\n') 34 | for each in html_files: 35 | flat_i = each.replace('/', '-') 36 | cp(each, f"{KB_DIR}/{flat_i}") 37 | 38 | print(f"There are {len(html_files)} files in {KB_DIR} directory", file=sys.stderr) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() --------------------------------------------------------------------------------