├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── application ├── README.md ├── __init__.py ├── bedrock.py ├── requirements.txt ├── streamlit.py └── utils │ ├── __init__.py │ ├── bedrock.py │ ├── chat.py │ ├── chunk.py │ ├── common_utils.py │ ├── opensearch_summit.py │ ├── pymupdf.py │ ├── rag_summit.py │ ├── s3.py │ ├── ssm.py │ └── text_to_report.py ├── bin └── cdk.ts ├── cdk.json ├── jest.config.js ├── lambda └── index.py ├── lib ├── customResourceStack.ts ├── ec2Stack │ ├── ec2Stack.ts │ └── userdata.sh ├── openSearchStack.ts ├── rerankerStack │ ├── RerankerStack.assets.json │ └── RerankerStack.template.json └── sagemakerNotebookStack │ ├── install_packages.sh │ ├── install_tesseract.sh │ └── sagemakerNotebookStack.ts ├── package-lock.json ├── package.json └── tsconfig.json /.gitignore: -------------------------------------------------------------------------------- 1 | /node_modules 2 | /cdk.out 3 | cdk.context.json 4 | 5 | # DS_store 6 | .DS_Store 7 | **/.DS_Store 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. 
Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this 4 | software and associated documentation files (the "Software"), to deal in the Software 5 | without restriction, including without limitation the rights to use, copy, modify, 6 | merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 7 | permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 10 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 11 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 12 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 13 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 14 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Amazon Bedrock Q&A multi-modal chatbot with advanced RAG 2 | 3 | :link: AWS Workshop [Amazon Bedrock Q&A multi-modal chatbot with advanced RAG](https://catalog.us-east-1.prod.workshops.aws/workshops/a372f3ed-e99d-4c95-93b5-ee666375a387/en-US) 4 | 5 | :telephone_receiver: Issue Report [Bailey(Sohyeon) Cho](https://www.linkedin.com/in/csbailey/) 6 | 7 | :globe_with_meridians: Language 8 | * [English](#English) 9 | * [한국어](#한국어) 10 | 11 | # English 12 | ## :mega: What is Advanced RAG Workshop? 13 | > In this workshop, you'll understand several advanced RAG techniques and apply them to a multi-modal chatbot Q&A. This workshop aims to increase customer interest and engagement with AWS by introducing and practicing advanced Retrieval Augmented Generation (RAG) techniques that can be utilized in a real production environment, beyond the typical demo. 14 | 15 | ## :mega: Who needs this workshop? 
16 | > No prior knowledge is required to perform this workshop. 17 | 18 | This workshop is available to anyone who wants to: 19 | * Acquire production-level RAG skills 20 | * Increase technical understanding through hands-on practice 21 | * Increase customer engagement with AWS services 22 | 23 | ## :mega: How to deploy CDK stacks 24 | ```bash 25 | git clone https://github.com/aws-samples/multi-modal-chatbot-with-advanced-rag.git 26 | cd multi-modal-chatbot-with-advanced-rag 27 | npm i --save-dev @types/node 28 | cdk bootstrap 29 | cdk synth 30 | cdk deploy --all --require-approval never 31 | ``` 32 | 33 | ## :mega: Workshop Guide 34 | You can see the full workshop [here](https://catalog.us-east-1.prod.workshops.aws/workshops/a372f3ed-e99d-4c95-93b5-ee666375a387/en-US) 35 | 36 | # 한국어 37 | 38 | ## :mega: Advanced RAG Workshop이란? 39 | 이 워크샵에서는 여러 Advanced RAG 기법을 이해하고, Multi Modal 챗봇 Q&A에 여러 RAG 기법을 적용해 볼 수 있습니다. 일반적인 데모를 넘어 실제 프로덕션 환경에서 활용할 수 있는 고급 검색 증강 생성(RAG) 기술을 소개하고 실습함으로써 AWS에 대한 고객의 관심과 참여를 높이는 것을 목표로 합니다. 40 | 41 | ## :mega: 누구를 위한 워크샵 인가요? 42 | > 이 워크샵을 수행하기 위해 별도의 사전 지식은 필요하지 않습니다. 43 | 44 | 다음 효과를 기대하는 모두가 이 워크샵을 사용할 수 있습니다: 45 | * 프로덕션 수준의 RAG 기술 습득 46 | * 실습을 통한 기술적 이해도 향상 47 | * AWS 서비스 경험 48 | 49 | ## :mega: CDK 스택 배포 방법 50 | ```bash 51 | git clone https://github.com/aws-samples/multi-modal-chatbot-with-advanced-rag.git 52 | cd multi-modal-chatbot-with-advanced-rag 53 | npm i --save-dev @types/node 54 | cdk bootstrap 55 | cdk synth 56 | cdk deploy --all --require-approval never 57 | ``` 58 | 59 | ## :mega: 워크샵 가이드 60 | 이 [링크](https://catalog.us-east-1.prod.workshops.aws/workshops/a372f3ed-e99d-4c95-93b5-ee666375a387/ko-KR)에서 워크샵 콘텐츠를 확인할 수 있습니다. 
61 | -------------------------------------------------------------------------------- /application/README.md: -------------------------------------------------------------------------------- 1 | # How to run this application 2 | 3 | > If you clone this repo and run 'cdk deploy', you don't need to run the streamlit application yourself. 4 | 5 | ## Structure 6 | 7 | 1. `bedrock.py` 8 | 9 | - Amazon Bedrock 및 Reranker, Hybrid search, parent-document 등의 RAG 기술 구현 파일 10 | 11 | 2. `streamlit.py` 12 | 13 | - 애플리케이션의 front-end 파일, 실행 시 `bedrock.py`을 import해서 사용 14 | 15 | ## Start 16 | 1. web_ui 폴더 접근 17 | ``` 18 | cd multi-modal-chatbot-with-advanced-rag/application 19 | ``` 20 | 21 | 2. Python 종속 라이브러리 설치 22 | 23 | ``` 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | 3. Streamlit 애플리케이션 작동 28 | 29 | ``` 30 | streamlit run streamlit.py --server.port 8080 31 | ``` 32 | 33 | 3. 접속하기 34 | 35 | - Streamlit 작동 시 표시되는 External link 혹은 EC2 public ip로 접속 36 | -------------------------------------------------------------------------------- /application/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | -------------------------------------------------------------------------------- /application/bedrock.py: -------------------------------------------------------------------------------- 1 | import os, sys, boto3 2 | # module_path = "/" # "../../.." 
3 | # sys.path.append(os.path.abspath(module_path)) 4 | from utils.rag_summit import prompt_repo, OpenSearchHybridSearchRetriever, prompt_repo, qa_chain 5 | from utils.opensearch_summit import opensearch_utils 6 | from utils.ssm import parameter_store 7 | from langchain.embeddings import BedrockEmbeddings 8 | from langchain_aws import ChatBedrock 9 | from utils import bedrock 10 | from utils.bedrock import bedrock_info 11 | 12 | region = boto3.Session().region_name 13 | pm = parameter_store(region) 14 | secrets_manager = boto3.client('secretsmanager', region_name=region) 15 | 16 | # 텍스트 생성 LLM 가져오기, streaming_callback을 인자로 받아옴 17 | def get_llm(streaming_callback): 18 | boto3_bedrock = bedrock.get_bedrock_client( 19 | assumed_role=os.environ.get("BEDROCK_ASSUME_ROLE", None), 20 | endpoint_url=os.environ.get("BEDROCK_ENDPOINT_URL", None), 21 | region=os.environ.get("AWS_DEFAULT_REGION", None), 22 | ) 23 | llm = ChatBedrock( 24 | model_id=bedrock_info.get_model_id(model_name="Claude-V3-Sonnet"), 25 | client=boto3_bedrock, 26 | model_kwargs={ 27 | "max_tokens": 1024, 28 | "stop_sequences": ["\n\nHuman"], 29 | }, 30 | streaming=True, 31 | callbacks=[streaming_callback], 32 | ) 33 | return llm 34 | 35 | # 임베딩 모델 가져오기 36 | def get_embedding_model(document_type): 37 | model_id= 'amazon.titan-embed-text-v1' if document_type == 'Default' else 'amazon.titan-embed-text-v2:0' 38 | llm_emb = BedrockEmbeddings(model_id=model_id) 39 | return llm_emb 40 | 41 | # Opensearch vectorDB 가져오기 42 | def get_opensearch_client(): 43 | opensearch_domain_endpoint = pm.get_params(key='opensearch_domain_endpoint', enc=False) 44 | opensearch_user_id = pm.get_params(key='opensearch_user_id', enc=False) 45 | 46 | response = secrets_manager.get_secret_value(SecretId='opensearch_user_password') 47 | secrets_string = response.get('SecretString') 48 | secrets_dict = eval(secrets_string) 49 | opensearch_user_password = secrets_dict['pwkey'] 50 | 51 | opensearch_domain_endpoint = 
opensearch_domain_endpoint 52 | rag_user_name = opensearch_user_id 53 | rag_user_password = opensearch_user_password 54 | aws_region = os.environ.get("AWS_DEFAULT_REGION", None) 55 | http_auth = (rag_user_name, rag_user_password) 56 | os_client = opensearch_utils.create_aws_opensearch_client( 57 | aws_region, 58 | opensearch_domain_endpoint, 59 | http_auth 60 | ) 61 | return os_client 62 | 63 | # hybrid search retriever 만들기 64 | def get_retriever(streaming_callback, parent, reranker, hyde, ragfusion, alpha, document_type): 65 | os_client = get_opensearch_client() 66 | llm_text = get_llm(streaming_callback) 67 | llm_emb = get_embedding_model(document_type) 68 | reranker_endpoint_name = "reranker" 69 | index_name = "default_doc_index" if document_type == "Default" else "customer_doc_index" 70 | opensearch_hybrid_retriever = OpenSearchHybridSearchRetriever( 71 | os_client=os_client, 72 | index_name=index_name, 73 | llm_text=llm_text, # llm for query augmentation in both rag_fusion and HyDE 74 | llm_emb=llm_emb, # Used in semantic search based on opensearch 75 | # option for lexical 76 | minimum_should_match=0, 77 | filter=[], 78 | # option for search 79 | # ["RRF", "simple_weighted"], rank fusion 방식 정의 80 | fusion_algorithm="RRF", 81 | complex_doc=True, 82 | # [for lexical, for semantic], Lexical, Semantic search 결과에 대한 최종 반영 비율 정의 83 | ensemble_weights=[alpha, 1.0-alpha], 84 | reranker=reranker, # enable reranker with reranker model 85 | # endpoint name for reranking model 86 | reranker_endpoint_name=reranker_endpoint_name, 87 | parent_document=parent, # enable parent document 88 | rag_fusion=ragfusion, 89 | rag_fusion_prompt = prompt_repo.get_rag_fusion(), 90 | hyde=hyde, 91 | hyde_query=['web_search'], 92 | query_augmentation_size=3, 93 | # option for async search 94 | async_mode=True, 95 | # option for output 96 | k=7, # 최종 Document 수 정의 97 | verbose=True, 98 | ) 99 | return opensearch_hybrid_retriever 100 | 101 | # 모델에 query하기 102 | def 
formatting_output(contexts): 103 | formatted_contexts = [] 104 | for doc, score in contexts: 105 | lines = doc.page_content.split("\n") 106 | metadata = doc.metadata 107 | formatted_contexts.append((score, lines)) 108 | return formatted_contexts 109 | 110 | def invoke(query, streaming_callback, parent, reranker, hyde, ragfusion, alpha, document_type="Default"): 111 | # llm, retriever 가져오기 112 | llm_text = get_llm(streaming_callback) 113 | opensearch_hybrid_retriever = get_retriever(streaming_callback, parent, reranker, hyde, ragfusion, alpha, document_type) 114 | # context, tables, images = opensearch_hybrid_retriever._get_relevant_documents() 115 | # answer only 선택 116 | system_prompt = prompt_repo.get_system_prompt() 117 | qa = qa_chain( 118 | llm_text=llm_text, 119 | retriever=opensearch_hybrid_retriever, 120 | system_prompt=system_prompt, 121 | return_context=False, 122 | verbose=False 123 | ) 124 | response, pretty_contexts, similar_docs, augmentation = qa.invoke(query = query, complex_doc = True) 125 | print("-------> response") 126 | # print(response) 127 | print("-------> pretty_contexts -> 모든 컨텍스트 포함된 자료") 128 | 129 | def extract_elements_and_print(pretty_contexts): 130 | for context in pretty_contexts: 131 | print("context: \n") 132 | print(context) 133 | 134 | print("######### SEMANTIC #########") 135 | extract_elements_and_print(pretty_contexts[0]) 136 | print("######### KEYWORD #########") 137 | extract_elements_and_print(pretty_contexts[1]) 138 | print("######### WITHOUT_RERANKER #########") 139 | extract_elements_and_print(pretty_contexts[2]) 140 | print("######## SIMILAR_DOCS ##########") 141 | extract_elements_and_print(pretty_contexts[3]) 142 | if hyde or ragfusion: 143 | print("######## 중간답변 ##########") 144 | print(augmentation) 145 | if not hyde or ragfusion: 146 | if alpha == 0.0: 147 | pretty_contexts[0].clear() 148 | elif alpha == 1.0: 149 | pretty_contexts[1].clear() 150 | if hyde or ragfusion: 151 | return response, pretty_contexts, 
augmentation 152 | 153 | 154 | return response, pretty_contexts 155 | -------------------------------------------------------------------------------- /application/requirements.txt: -------------------------------------------------------------------------------- 1 | anthropic==0.38.0 2 | httpx==0.27.2 3 | tokenizers==0.20.3 4 | boto3==1.34.122 5 | botocore==1.34.122 6 | ipython==8.2.0 7 | langchain==0.2.5 8 | langchain-aws==0.1.6 9 | langchain-community==0.2.5 10 | matplotlib==3.7.0 11 | requests==2.32.3 12 | streamlit==1.35.0 13 | urllib3==1.26.18 14 | opensearch-py==2.6.0 -------------------------------------------------------------------------------- /application/streamlit.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import streamlit as st # 모든 streamlit 명령은 "st" alias로 사용할 수 있습니다. 3 | import bedrock as glib # 로컬 라이브러리 스크립트에 대한 참조 4 | from langchain.callbacks import StreamlitCallbackHandler 5 | 6 | ##################### Functions ######################## 7 | def parse_image(metadata, tag): 8 | if tag in metadata: 9 | st.image(base64.b64decode(metadata[tag])) 10 | 11 | def parse_table(metadata, tag): 12 | if tag in metadata: 13 | st.markdown(metadata[tag], unsafe_allow_html=True) 14 | 15 | def parse_metadata(metadata): 16 | # Image, Table 이 있을 경우 파싱해 출력 17 | category = "None" 18 | if "category" in metadata: 19 | category = metadata["category"] 20 | if category == "Table": 21 | # parse_table(metadata, "text_as_html") # 테이블 html은 이미지로 대체 22 | parse_image(metadata, "image_base64") 23 | elif category == "Image": 24 | parse_image(metadata, "image_base64") 25 | else: 26 | pass 27 | st.markdown(' - - - ') 28 | 29 | def show_document_info_label(): 30 | with st.container(border=True): 31 | if st.session_state.document_type == "Default": 32 | st.markdown('''📝 현재 기본 문서인 [**상록초등학교 교육 과정 문서**](https://d14ojpq4k4igb1.cloudfront.net/school_edu_guide.pdf)를 활용하고 있습니다.''') 33 | st.markdown('''다른 문서로 챗봇 서비스를 이용해보고 싶다면 왼쪽 
사이드바의 Document type에서 *'Custom'* 옵션을 클릭하고, 진행자의 안내에 따라 문서를 새로 인덱싱하여 사용해보세요.''') 34 | else: 35 | st.markdown('''**💁‍♀️ 새로운 문서로 챗봇 서비스를 이용하고 싶으신가요?**''') 36 | st.markdown('''- **진행자의 안내에 따라 SageMaker Notebook에서 인덱싱 스크립트를 실행한 뒤** 이용 가능합니다.''') 37 | st.markdown('''- 기존 문서 (상록초등학교 교육 과정)로 돌아가고 싶다면 사이드바의 Document type에서 *'Default'* 옵션을 선택하면 바로 변경할 수 있습니다.''') 38 | 39 | # 'Separately' 옵션 선택 시 나오는 중간 Context를 탭 형태로 보여주는 UI 40 | def show_context_with_tab(contexts): 41 | tab_category = ["Semantic", "Keyword", "Without Reranker", "Similar Docs"] 42 | tab_contents = { 43 | tab_category[0]: [], 44 | tab_category[1]: [], 45 | tab_category[2]: [], 46 | tab_category[3]: [] 47 | } 48 | for i, contexts_by_doctype in enumerate(contexts): 49 | tab_contents[tab_category[i]].append(contexts_by_doctype) 50 | tabs = st.tabs(tab_category) 51 | for i, tab in enumerate(tabs): 52 | category = tab_category[i] 53 | with tab: 54 | for contexts_by_doctype in tab_contents[category]: 55 | for context in contexts_by_doctype: 56 | st.markdown('##### `정확도`: {}'.format(context["score"])) 57 | for line in context["lines"]: 58 | st.write(line) 59 | parse_metadata(context["meta"]) 60 | 61 | # 'All at once' 옵션 선택 시 4개의 컬럼으로 나누어 결과 표시하는 UI 62 | # TODO: HyDE, RagFusion 추가 논의 필요 63 | def show_answer_with_multi_columns(answers): 64 | col1, col2, col3, col4 = st.columns(4) 65 | with col1: 66 | st.markdown('''### `Lexical search` ''') 67 | st.write(answers[0]) 68 | with col2: 69 | st.markdown('''### `Semantic search` ''') 70 | st.write(answers[1]) 71 | with col3: 72 | st.markdown('''### + `Reranker` ''') 73 | st.write(answers[2]) 74 | with col4: 75 | st.markdown('''### + `Parent_docs` ''') 76 | st.write(answers[3]) 77 | 78 | ####################### Application ############################### 79 | st.set_page_config(layout="wide") 80 | st.title("AWS Q&A Bot with Advanced RAG!") # page 제목 81 | 82 | st.markdown('''- 이 챗봇은 Amazon Bedrock과 Claude v3 Sonnet 모델로 구현되었습니다.''') 83 | st.markdown('''- 다음과 같은 Advanced RAG 
기술을 사용합니다: **Hybrid Search, ReRanker, and Parent Document, HyDE, Rag Fusion**''') 84 | st.markdown('''- 원본 데이터는 Amazon OpenSearch에 저장되어 있으며, Amazon Titan 임베딩 모델이 사용되었습니다.''') 85 | st.markdown(''' ''') 86 | 87 | # Store the initial value of widgets in session state 88 | if "document_type" not in st.session_state: 89 | st.session_state.document_type = "Default" 90 | if "showing_option" not in st.session_state: 91 | st.session_state.showing_option = "Separately" 92 | if "search_mode" not in st.session_state: 93 | st.session_state.search_mode = "Hybrid search" 94 | if "hyde_or_ragfusion" not in st.session_state: 95 | st.session_state.hyde_or_ragfusion = "None" 96 | disabled = st.session_state.showing_option=="All at once" 97 | 98 | with st.sidebar: # Sidebar 모델 옵션 99 | with st.container(border=True): 100 | st.radio( 101 | "Document type:", 102 | ["Default", "Custom"], 103 | captions = ["챗봇이 참고하는 자료로 기본 문서(상록초등학교 자료)가 사용됩니다.", "원하시는 문서를 직접 업로드해보세요."], 104 | key="document_type", 105 | ) 106 | with st.container(border=True): 107 | st.radio( 108 | "UI option:", 109 | ["Separately", "All at once"], 110 | captions = ["아래에서 설정한 파라미터 조합으로 하나의 검색 결과가 도출됩니다.", "여러 옵션들을 한 화면에서 한꺼번에 볼 수 있습니다."], 111 | key="showing_option", 112 | ) 113 | st.markdown('''### Set parameters for your Bot 👇''') 114 | with st.container(border=True): 115 | search_mode = st.radio( 116 | "Search mode:", 117 | ["Lexical search", "Semantic search", "Hybrid search"], 118 | captions = [ 119 | "키워드의 일치 여부를 기반으로 답변을 생성합니다.", 120 | "키워드의 일치 여부보다는 문맥의 의미적 유사도에 기반해 답변을 생성합니다.", 121 | "아래의 Alpha 값을 조정하여 Lexical/Semantic search의 비율을 조정합니다." 
122 | ], 123 | key="search_mode", 124 | disabled=disabled 125 | ) 126 | alpha = st.slider('Alpha value for Hybrid search ⬇️', 0.0, 1.0, 0.51, 127 | disabled=st.session_state.search_mode != "Hybrid search", 128 | help="""Alpha=0.0 이면 Lexical search, \nAlpha=1.0 이면 Semantic search 입니다.""" 129 | ) 130 | if search_mode == "Lexical search": 131 | alpha = 0.0 132 | elif search_mode == "Semantic search": 133 | alpha = 1.0 134 | 135 | col1, col2 = st.columns(2) 136 | with col1: 137 | reranker = st.toggle("Reranker", 138 | help="""초기 검색 결과를 재평가하여 순위를 재조정하는 모델입니다. 139 | 문맥 정보와 질의 관련성을 고려하여 적합한 결과를 상위에 올립니다.""", 140 | disabled=disabled) 141 | with col2: 142 | parent = st.toggle("Parent Docs", 143 | help="""답변 생성 모델이 질의에 대한 답변을 생성할 때 참조한 정보의 출처를 표시하는 옵션입니다.""", 144 | disabled=disabled) 145 | 146 | with st.container(border=True): 147 | hyde_or_ragfusion = st.radio( 148 | "Choose a RAG option:", 149 | ["None", "HyDE", "RAG-Fusion"], 150 | captions = [ 151 | "", 152 | "문서와 질의 간의 의미적 유사도를 측정하기 위한 임베딩 기법입니다. 하이퍼볼릭 공간에서 거리를 계산하여 유사도를 측정합니다.", 153 | "검색과 생성을 결합한 모델로, 검색 모듈이 관련 문서를 찾고 생성 모듈이 이를 참조하여 답변을 생성합니다. 두 모듈의 출력을 융합하여 최종 답변을 도출합니다." 
154 | ], 155 | key="hyde_or_ragfusion", 156 | disabled=disabled 157 | ) 158 | hyde = hyde_or_ragfusion == "HyDE" 159 | ragfusion = hyde_or_ragfusion == "RAG-Fusion" 160 | 161 | ###### 'Separately' 옵션 선택한 경우 ###### 162 | if st.session_state.showing_option == "Separately": 163 | show_document_info_label() 164 | 165 | if "messages" not in st.session_state: 166 | st.session_state["messages"] = [ 167 | {"role": "assistant", "content": "안녕하세요, 무엇이 궁금하세요?"} 168 | ] 169 | # 지난 답변 출력 170 | for msg in st.session_state.messages: 171 | # 지난 답변에 대한 컨텍스트 출력 172 | if msg["role"] == "assistant_context": 173 | with st.chat_message("assistant"): 174 | with st.expander("Context 확인하기 ⬇️"): 175 | show_context_with_tab(contexts=msg["content"]) 176 | 177 | elif msg["role"] == "hyde_or_fusion": 178 | with st.chat_message("assistant"): 179 | with st.expander("중간 답변 확인하기 ⬇️"): 180 | msg["content"] 181 | 182 | elif msg["role"] == "assistant_column": 183 | # 'Separately' 옵션일 경우 multi column 으로 보여주지 않고 첫 번째 답변만 출력 184 | st.chat_message(msg["role"]).write(msg["content"][0]) 185 | else: 186 | st.chat_message(msg["role"]).write(msg["content"]) 187 | 188 | # 유저가 쓴 chat을 query라는 변수에 담음 189 | query = st.chat_input("Search documentation") 190 | if query: 191 | # Session에 메세지 저장 192 | st.session_state.messages.append({"role": "user", "content": query}) 193 | 194 | # UI에 출력 195 | st.chat_message("user").write(query) 196 | 197 | # Streamlit callback handler로 bedrock streaming 받아오는 컨테이너 설정 198 | st_cb = StreamlitCallbackHandler( 199 | st.chat_message("assistant"), 200 | collapse_completed_thoughts=True 201 | ) 202 | # bedrock.py의 invoke 함수 사용 203 | response = glib.invoke( 204 | query=query, 205 | streaming_callback=st_cb, 206 | parent=parent, 207 | reranker=reranker, 208 | hyde = hyde, 209 | ragfusion = ragfusion, 210 | alpha = alpha, 211 | document_type=st.session_state.document_type 212 | ) 213 | # response 로 메세지, 링크, 레퍼런스(source_documents) 받아오게 설정된 것을 변수로 저장 214 | answer = response[0] 215 | contexts = 
response[1] 216 | if hyde or ragfusion: 217 | mid_answer = response[2] 218 | 219 | # UI 출력 220 | st.chat_message("assistant").write(answer) 221 | 222 | if hyde: 223 | with st.chat_message("assistant"): 224 | with st.expander("HyDE 중간 생성 답변 ⬇️"): 225 | mid_answer 226 | if ragfusion: 227 | with st.chat_message("assistant"): 228 | with st.expander("RAG-Fusion 중간 생성 쿼리 ⬇️"): 229 | mid_answer 230 | with st.chat_message("assistant"): 231 | with st.expander("정확도 별 컨텍스트 보기 ⬇️"): 232 | show_context_with_tab(contexts) 233 | 234 | # Session 메세지 저장 235 | st.session_state.messages.append({"role": "assistant", "content": answer}) 236 | 237 | if hyde or ragfusion: 238 | st.session_state.messages.append({"role": "hyde_or_fusion", "content": mid_answer}) 239 | 240 | st.session_state.messages.append({"role": "assistant_context", "content": contexts}) 241 | # Thinking을 complete로 수동으로 바꾸어 줌 242 | st_cb._complete_current_thought() 243 | 244 | ###### 2) 'All at once' 옵션 선택한 경우 ###### 245 | else: 246 | with st.container(border=True): 247 | st.markdown('''현재 기본 문서인 [상록초등학교 교육 과정 문서](https://file.notion.so/f/f/d82c0c1c-c239-4242-bd5e-320565fdc9d4/6057662b-2d01-4284-a65f-cc17d050a321/school_edu_guide.pdf?id=a2f7166b-f663-4740-aa06-ec559567011a&table=block&spaceId=d82c0c1c-c239-4242-bd5e-320565fdc9d4&expirationTimestamp=1718100000000&signature=wxS5AgYuK085mNvynkUZsRyqyMuqE_ucoCNfM4jRnU0&downloadName=school_edu_guide.pdf)를 활용하고 있습니다.''') 248 | st.markdown('''다른 문서로 챗봇 서비스를 이용해보고 싶다면 왼쪽 사이드바의 Document type에서 'Custom' 옵션을 클릭해 문서를 업로드해보세요.''') 249 | 250 | if "messages" not in st.session_state: 251 | st.session_state["messages"] = [ 252 | {"role": "assistant", "content": "안녕하세요, 무엇이 궁금하세요?"} 253 | ] 254 | # 지난 답변 출력 255 | for msg in st.session_state.messages: 256 | if msg["role"] == "assistant_column": 257 | answers = msg["content"] 258 | show_answer_with_multi_columns(answers) 259 | elif msg["role"] == "assistant_context": 260 | pass # 'All at once' 옵션 선택 시에는 context 로그를 출력하지 않음 261 | else: 262 
| st.chat_message(msg["role"]).write(msg["content"]) 263 | 264 | # 유저가 쓴 chat을 query라는 변수에 담음 265 | query = st.chat_input("Search documentation") 266 | if query: 267 | # Session에 메세지 저장 268 | st.session_state.messages.append({"role": "user", "content": query}) 269 | 270 | # UI에 출력 271 | st.chat_message("user").write(query) 272 | 273 | col1, col2, col3, col4 = st.columns(4) 274 | with col1: 275 | st.markdown('''### `Lexical search` ''') 276 | st.markdown(":green[: Alpha 값이 0.0]으로, 키워드의 정확한 일치 여부를 판단하는 Lexical search 결과입니다.") 277 | with col2: 278 | st.markdown('''### `Semantic search` ''') 279 | st.markdown(":green[: Alpha 값이 1.0]으로, 키워드 일치 여부보다는 문맥의 의미적 유사도에 기반한 Semantic search 결과입니다.") 280 | with col3: 281 | st.markdown('''### + `Reranker` ''') 282 | st.markdown(""": 초기 검색 결과를 재평가하여 순위를 재조정하는 모델입니다. 문맥 정보와 질의 관련성을 고려하여 적합한 결과를 상위에 올립니다. 283 | :green[Alpha 값은 왼쪽 사이드바에서 설정하신 값]으로 적용됩니다.""") 284 | with col4: 285 | st.markdown('''### + `Parent Docs` ''') 286 | st.markdown(""": 질의에 대한 답변을 생성할 때 참조하는 문서 집합입니다. 답변 생성 모델이 참조할 수 있는 관련 정보의 출처가 됩니다. 
287 | :green[Alpha 값은 왼쪽 사이드바에서 설정하신 값]으로 적용됩니다.""") 288 | 289 | with col1: 290 | # Streamlit callback handler로 bedrock streaming 받아오는 컨테이너 설정 291 | st_cb = StreamlitCallbackHandler( 292 | st.chat_message("assistant"), 293 | collapse_completed_thoughts=True 294 | ) 295 | answer1 = glib.invoke( 296 | query=query, 297 | streaming_callback=st_cb, 298 | parent=False, 299 | reranker=False, 300 | hyde = False, 301 | ragfusion = False, 302 | alpha = 0, # Lexical search 303 | document_type=st.session_state.document_type 304 | )[0] 305 | st.write(answer1) 306 | st_cb._complete_current_thought() # Thinking을 complete로 수동으로 바꾸어 줌 307 | with col2: 308 | st_cb = StreamlitCallbackHandler( 309 | st.chat_message("assistant"), 310 | collapse_completed_thoughts=True 311 | ) 312 | answer2 = glib.invoke( 313 | query=query, 314 | streaming_callback=st_cb, 315 | parent=False, 316 | reranker=False, 317 | hyde = False, 318 | ragfusion = False, 319 | alpha = 1.0, # Semantic search 320 | document_type=st.session_state.document_type 321 | )[0] 322 | st.write(answer2) 323 | st_cb._complete_current_thought() 324 | with col3: 325 | st_cb = StreamlitCallbackHandler( 326 | st.chat_message("assistant"), 327 | collapse_completed_thoughts=True 328 | ) 329 | answer3 = glib.invoke( 330 | query=query, 331 | streaming_callback=st_cb, 332 | parent=False, 333 | reranker=True, # Add Reranker option 334 | hyde = False, 335 | ragfusion = False, 336 | alpha = alpha, # Hybrid search 337 | document_type=st.session_state.document_type 338 | )[0] 339 | st.write(answer3) 340 | st_cb._complete_current_thought() 341 | with col4: 342 | st_cb = StreamlitCallbackHandler( 343 | st.chat_message("assistant"), 344 | collapse_completed_thoughts=True 345 | ) 346 | answer4 = glib.invoke( 347 | query=query, 348 | streaming_callback=st_cb, 349 | parent=True, # Add Parent_docs option 350 | reranker=True, # Add Reranker option 351 | hyde = False, 352 | ragfusion = False, 353 | alpha = alpha, # Hybrid search 354 | 
document_type=st.session_state.document_type 355 | )[0] 356 | st.write(answer4) 357 | st_cb._complete_current_thought() 358 | 359 | # Session 메세지 저장 360 | answers = [answer1, answer2, answer3, answer4] 361 | st.session_state.messages.append({"role": "assistant_column", "content": answers}) 362 | -------------------------------------------------------------------------------- /application/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """General helper utilities the workshop notebooks""" 4 | # Python Built-Ins: 5 | from io import StringIO 6 | import sys 7 | import textwrap 8 | 9 | 10 | def print_ww(*args, width: int = 100, **kwargs): 11 | """Like print(), but wraps output to `width` characters (default 100)""" 12 | buffer = StringIO() 13 | try: 14 | _stdout = sys.stdout 15 | sys.stdout = buffer 16 | print(*args, **kwargs) 17 | output = buffer.getvalue() 18 | finally: 19 | sys.stdout = _stdout 20 | for line in output.splitlines(): 21 | print("\n".join(textwrap.wrap(line, width=width))) 22 | -------------------------------------------------------------------------------- /application/utils/bedrock.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """Helper utilities for working with Amazon Bedrock from Python notebooks""" 4 | # Python Built-Ins: 5 | import os 6 | from typing import Optional 7 | 8 | # External Dependencies: 9 | import boto3 10 | from botocore.config import Config 11 | 12 | # Langchain 13 | from langchain.callbacks.base import BaseCallbackHandler 14 | 15 | def get_bedrock_client( 16 | assumed_role: Optional[str] = None, 17 | endpoint_url: Optional[str] = None, 18 | region: Optional[str] = None, 19 | ): 20 | """Create a boto3 client for Amazon Bedrock, with optional configuration overrides 21 | 22 | Parameters 23 | ---------- 24 | assumed_role : 25 | Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not 26 | specified, the current active credentials will be used. 27 | endpoint_url : 28 | Optional override for the Bedrock service API Endpoint. If setting this, it should usually 29 | include the protocol i.e. "https://..." 30 | region : 31 | Optional name of the AWS Region in which the service should be called (e.g. "us-east-1"). 32 | If not specified, AWS_REGION or AWS_DEFAULT_REGION environment variable will be used. 
33 | """ 34 | if region is None: 35 | target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")) 36 | else: 37 | target_region = region 38 | 39 | print(f"Create new client\n Using region: {target_region}") 40 | session_kwargs = {"region_name": target_region} 41 | client_kwargs = {**session_kwargs} 42 | 43 | profile_name = os.environ.get("AWS_PROFILE") 44 | print(f" Using profile: {profile_name}") 45 | if profile_name: 46 | print(f" Using profile: {profile_name}") 47 | session_kwargs["profile_name"] = profile_name 48 | 49 | retry_config = Config( 50 | region_name=target_region, 51 | retries={ 52 | "max_attempts": 10, 53 | "mode": "standard", 54 | }, 55 | ) 56 | session = boto3.Session(**session_kwargs) 57 | 58 | if assumed_role: 59 | print(f" Using role: {assumed_role}", end='') 60 | sts = session.client("sts") 61 | response = sts.assume_role( 62 | RoleArn=str(assumed_role), 63 | RoleSessionName="langchain-llm-1" 64 | ) 65 | print(" ... successful!") 66 | client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"] 67 | client_kwargs["aws_secret_access_key"] = response["Credentials"]["SecretAccessKey"] 68 | client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"] 69 | 70 | if endpoint_url: 71 | client_kwargs["endpoint_url"] = endpoint_url 72 | 73 | bedrock_client = session.client( 74 | service_name="bedrock-runtime", 75 | config=retry_config, 76 | **client_kwargs 77 | ) 78 | 79 | print("boto3 Bedrock client successfully created!") 80 | print(bedrock_client._endpoint) 81 | return bedrock_client 82 | 83 | 84 | class bedrock_info(): 85 | 86 | _BEDROCK_MODEL_INFO = { 87 | "Claude-Instant-V1": "anthropic.claude-instant-v1", 88 | "Claude-V1": "anthropic.claude-v1", 89 | "Claude-V2": "anthropic.claude-v2", 90 | "Claude-V2-1": "anthropic.claude-v2:1", 91 | "Claude-V3-Sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", 92 | "Claude-V3-Haiku": "anthropic.claude-3-haiku-20240307-v1:0", 93 | "Jurassic-2-Mid": 
"ai21.j2-mid-v1", 94 | "Jurassic-2-Ultra": "ai21.j2-ultra-v1", 95 | "Command": "cohere.command-text-v14", 96 | "Command-Light": "cohere.command-light-text-v14", 97 | "Cohere-Embeddings-En": "cohere.embed-english-v3", 98 | "Cohere-Embeddings-Multilingual": "cohere.embed-multilingual-v3", 99 | "Titan-Embeddings-G1": "amazon.titan-embed-text-v1", 100 | "Titan-Text-Embeddings-V2": "amazon.titan-embed-text-v2:0", 101 | "Titan-Text-G1": "amazon.titan-text-express-v1", 102 | "Titan-Text-G1-Light": "amazon.titan-text-lite-v1", 103 | "Titan-Text-G1-Premier": "amazon.titan-text-premier-v1:0", 104 | "Titan-Text-G1-Express": "amazon.titan-text-express-v1", 105 | "Llama2-13b-Chat": "meta.llama2-13b-chat-v1" 106 | } 107 | 108 | @classmethod 109 | def get_list_fm_models(cls, verbose=False): 110 | 111 | if verbose: 112 | bedrock = boto3.client(service_name='bedrock') 113 | model_list = bedrock.list_foundation_models() 114 | return model_list["modelSummaries"] 115 | else: 116 | return cls._BEDROCK_MODEL_INFO 117 | 118 | @classmethod 119 | def get_model_id(cls, model_name): 120 | 121 | assert model_name in cls._BEDROCK_MODEL_INFO.keys(), "Check model name" 122 | 123 | return cls._BEDROCK_MODEL_INFO[model_name] -------------------------------------------------------------------------------- /application/utils/chat.py: -------------------------------------------------------------------------------- 1 | import ipywidgets as ipw 2 | from utils import print_ww 3 | from langchain import PromptTemplate 4 | from IPython.display import display, clear_output 5 | from langchain.memory import ConversationBufferWindowMemory, ConversationSummaryBufferMemory, ConversationBufferMemory 6 | 7 | class chat_utils(): 8 | 9 | MEMORY_TYPE = [ 10 | "ConversationBufferMemory", 11 | "ConversationBufferWindowMemory", 12 | "ConversationSummaryBufferMemory" 13 | ] 14 | 15 | PROMPT_SUMMARY = PromptTemplate( 16 | input_variables=['summary', 'new_lines'], 17 | template=""" 18 | \n\nHuman: Progressively summarize 
the lines of conversation provided, adding onto the previous summary returning a new summary. 19 | 20 | EXAMPLE 21 | Current summary: 22 | The user asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good. 23 | 24 | New lines of conversation: 25 | User: Why do you think artificial intelligence is a force for good? 26 | AI: Because artificial intelligence will help users reach their full potential. 27 | 28 | New summary: 29 | The user asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help users reach their full potential. 30 | END OF EXAMPLE 31 | 32 | Current summary: 33 | {summary} 34 | 35 | New lines of conversation: 36 | {new_lines} 37 | 38 | \n\nAssistant:""" 39 | ) 40 | 41 | @classmethod 42 | def get_memory(cls, **kwargs): 43 | """ 44 | create memory for this chat session 45 | """ 46 | memory_type = kwargs["memory_type"] 47 | 48 | assert memory_type in cls.MEMORY_TYPE, f"Check Buffer Name, Lists: {cls.MEMORY_TYPE}" 49 | if memory_type == "ConversationBufferMemory": 50 | 51 | # This memory allows for storing messages and then extracts the messages in a variable. 52 | memory = ConversationBufferMemory( 53 | memory_key=kwargs.get("memory_key", "chat_history"), 54 | human_prefix=kwargs.get("human_prefix", "Human"), 55 | ai_prefix=kwargs.get("ai_prefix", "AI"), 56 | return_messages=kwargs.get("return_messages", True) 57 | ) 58 | elif memory_type == "ConversationBufferWindowMemory": 59 | 60 | # Maintains a history of previous messages 61 | # ConversationBufferWindowMemory keeps a list of the interactions of the conversation over time. 62 | # It only uses the last K interactions. 63 | # This can be useful for keeping a sliding window of the most recent interactions, 64 | # so the buffer does not get too large. 
65 | # https://python.langchain.com/docs/modules/memory/types/buffer_window 66 | memory = ConversationBufferWindowMemory( 67 | k=kwargs.get("k", 5), 68 | memory_key=kwargs.get("memory_key", "chat_history"), 69 | human_prefix=kwargs.get("human_prefix", "Human"), 70 | ai_prefix=kwargs.get("ai_prefix", "AI"), 71 | return_messages=kwargs.get("return_messages", True), 72 | ) 73 | elif memory_type == "ConversationSummaryBufferMemory": 74 | 75 | # Maintains a summary of previous messages 76 | # ConversationSummaryBufferMemory combines the two ideas. 77 | # It keeps a buffer of recent interactions in memory, but rather than just completely flushing old interactions it compiles them into a summary and uses both. 78 | # It uses token length rather than number of interactions to determine when to flush interactions. 79 | 80 | assert kwargs.get("llm", None) != None, "Give your LLM" 81 | memory = ConversationSummaryBufferMemory( 82 | llm=kwargs.get("llm", None), 83 | memory_key=kwargs.get("memory_key", "chat_history"), 84 | human_prefix=kwargs.get("human_prefix", "User"), 85 | ai_prefix=kwargs.get("ai_prefix", "AI"), 86 | return_messages=kwargs.get("return_messages", True), 87 | max_token_limit=kwargs.get("max_token_limit", 1024), 88 | prompt=cls.PROMPT_SUMMARY 89 | ) 90 | 91 | return memory 92 | 93 | @classmethod 94 | def get_tokens(cls, chain, prompt): 95 | token = chain.llm.get_num_tokens(prompt) 96 | print(f'# tokens: {token}') 97 | return token 98 | 99 | @classmethod 100 | def clear_memory(cls, chain): 101 | return chain.memory.clear() 102 | 103 | class ChatUX: 104 | """ A chat UX using IPWidgets 105 | """ 106 | def __init__(self, qa, retrievalChain = False): 107 | self.qa = qa 108 | self.name = None 109 | self.b=None 110 | self.retrievalChain = retrievalChain 111 | self.out = ipw.Output() 112 | 113 | if "ConversationChain" in str(type(self.qa)): 114 | self.streaming = self.qa.llm.streaming 115 | elif "ConversationalRetrievalChain" in str(type(self.qa)): 116 | 
self.streaming = self.qa.combine_docs_chain.llm_chain.llm.streaming 117 | 118 | def start_chat(self): 119 | print("Starting chat bot") 120 | display(self.out) 121 | self.chat(None) 122 | 123 | def chat(self, _): 124 | if self.name is None: 125 | prompt = "" 126 | else: 127 | prompt = self.name.value 128 | if 'q' == prompt or 'quit' == prompt or 'Q' == prompt: 129 | print("Thank you , that was a nice chat !!") 130 | return 131 | elif len(prompt) > 0: 132 | with self.out: 133 | thinking = ipw.Label(value="Thinking...") 134 | display(thinking) 135 | try: 136 | if self.retrievalChain: 137 | result = self.qa.run({'question': prompt }) 138 | else: 139 | result = self.qa.run({'input': prompt }) #, 'history':chat_history}) 140 | except: 141 | result = "No answer" 142 | thinking.value="" 143 | if self.streaming: 144 | response = f"AI:{result}" 145 | else: 146 | print_ww(f"AI:{result}") 147 | self.name.disabled = True 148 | self.b.disabled = True 149 | self.name = None 150 | 151 | if self.name is None: 152 | with self.out: 153 | self.name = ipw.Text(description="You:", placeholder='q to quit') 154 | self.b = ipw.Button(description="Send") 155 | self.b.on_click(self.chat) 156 | display(ipw.Box(children=(self.name, self.b))) -------------------------------------------------------------------------------- /application/utils/chunk.py: -------------------------------------------------------------------------------- 1 | from langchain.docstore.document import Document 2 | from langchain.text_splitter import RecursiveCharacterTextSplitter 3 | 4 | 5 | class parant_documents(): 6 | 7 | @classmethod 8 | def _create_chunk(cls, docs, chunk_size, chunk_overlap): 9 | 10 | ''' 11 | docs: list of docs 12 | chunk_size: int 13 | chunk_overlap: int 14 | return: list of chunk_docs 15 | ''' 16 | 17 | text_splitter = RecursiveCharacterTextSplitter( 18 | # Set a really small chunk size, just to show. 
19 | chunk_size=chunk_size, 20 | chunk_overlap=chunk_overlap, 21 | separators=["\n\n", "\n", ".", " ", ""], 22 | length_function=len, 23 | ) 24 | # print("doc: in create_chunk", docs ) 25 | chunk_docs = text_splitter.split_documents(docs) 26 | 27 | return chunk_docs 28 | 29 | @classmethod 30 | def create_parent_chunk(cls, docs, parent_id_key, family_tree_id_key, parent_chunk_size, parent_chunk_overlap): 31 | 32 | parent_chunks = cls._create_chunk(docs, parent_chunk_size, parent_chunk_overlap) 33 | for i, doc in enumerate(parent_chunks): 34 | doc.metadata[family_tree_id_key] = 'parent' 35 | doc.metadata[parent_id_key] = None 36 | 37 | return parent_chunks 38 | 39 | @classmethod 40 | def create_child_chunk(cls, child_chunk_size, child_chunk_overlap, docs, parent_ids_value, parent_id_key, family_tree_id_key): 41 | 42 | sub_docs = [] 43 | for i, doc in enumerate(docs): 44 | # print("doc: ", doc) 45 | parent_id = parent_ids_value[i] 46 | doc = [doc] 47 | _sub_docs = cls._create_chunk(doc, child_chunk_size, child_chunk_overlap) 48 | for _doc in _sub_docs: 49 | _doc.metadata[family_tree_id_key] = 'child' 50 | _doc.metadata[parent_id_key] = parent_id 51 | sub_docs.extend(_sub_docs) 52 | 53 | # if i == 0: 54 | # return sub_docs 55 | 56 | return sub_docs -------------------------------------------------------------------------------- /application/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | import random 4 | import logging 5 | import functools 6 | from IPython.display import Markdown, HTML, display 7 | 8 | logging.basicConfig() 9 | logger = logging.getLogger('retry-bedrock-invocation') 10 | logger.setLevel(logging.INFO) 11 | 12 | def retry(total_try_cnt=5, sleep_in_sec=5, retryable_exceptions=()): 13 | def decorator(func): 14 | @functools.wraps(func) 15 | def wrapper(*args, **kwargs): 16 | for cnt in range(total_try_cnt): 17 | logger.info(f"trying {func.__name__}() 
[{cnt+1}/{total_try_cnt}]") 18 | 19 | try: 20 | result = func(*args, **kwargs) 21 | logger.info(f"in retry(), {func.__name__}() returned '{result}'") 22 | 23 | if result: return result 24 | except retryable_exceptions as e: 25 | logger.info(f"in retry(), {func.__name__}() raised retryable exception '{e}'") 26 | pass 27 | except Exception as e: 28 | logger.info(f"in retry(), {func.__name__}() raised {e}") 29 | raise e 30 | 31 | time.sleep(sleep_in_sec) 32 | logger.info(f"{func.__name__} finally has been failed") 33 | return wrapper 34 | return decorator 35 | 36 | def to_pickle(obj, path): 37 | 38 | with open(file=path, mode="wb") as f: 39 | pickle.dump(obj, f) 40 | 41 | print (f'To_PICKLE: {path}') 42 | 43 | def load_pickle(path): 44 | 45 | with open(file=path, mode="rb") as f: 46 | obj=pickle.load(f) 47 | 48 | print (f'Load from {path}') 49 | 50 | return obj 51 | 52 | def to_markdown(obj, path): 53 | 54 | with open(file=path, mode="w") as f: 55 | f.write(obj) 56 | 57 | print (f'To_Markdown: {path}') 58 | 59 | def print_html(input_html): 60 | 61 | html_string="" 62 | html_string = html_string + input_html 63 | 64 | display(HTML(html_string)) 65 | 66 | 67 | -------------------------------------------------------------------------------- /application/utils/opensearch_summit.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List, Tuple 3 | from opensearchpy import OpenSearch, RequestsHttpConnection 4 | class opensearch_utils(): 5 | @classmethod 6 | def create_aws_opensearch_client(cls, region: str, host: str, http_auth: Tuple[str, str]) -> OpenSearch: 7 | client = OpenSearch( 8 | hosts=[ 9 | {'host': host.replace("https://", ""), 10 | 'port': 443 11 | } 12 | ], 13 | http_auth=http_auth, 14 | use_ssl=True, 15 | verify_certs=True, 16 | connection_class=RequestsHttpConnection 17 | ) 18 | return client 19 | @classmethod 20 | def create_index(cls, os_client, index_name, index_body): 21 | ''' 22 | 인덱스 
생성 23 | ''' 24 | response = os_client.indices.create( 25 | index_name, 26 | body=index_body 27 | ) 28 | print('\nCreating index:') 29 | print(response) 30 | @classmethod 31 | def check_if_index_exists(cls, os_client, index_name): 32 | ''' 33 | 인덱스가 존재하는지 확인 34 | ''' 35 | exists = os_client.indices.exists(index_name) 36 | print(f"index_name={index_name}, exists={exists}") 37 | return exists 38 | @classmethod 39 | def add_doc(cls, os_client, index_name, document, id): 40 | ''' 41 | # Add a document to the index. 42 | ''' 43 | response = os_client.index( 44 | index = index_name, 45 | body = document, 46 | id = id, 47 | refresh = True 48 | ) 49 | print('\nAdding document:') 50 | print(response) 51 | @classmethod 52 | def search_document(cls, os_client, query, index_name): 53 | response = os_client.search( 54 | body=query, 55 | index=index_name 56 | ) 57 | #print('\nKeyword Search results:') 58 | return response 59 | @classmethod 60 | def delete_index(cls, os_client, index_name): 61 | response = os_client.indices.delete( 62 | index=index_name 63 | ) 64 | print('\nDeleting index:') 65 | print(response) 66 | @classmethod 67 | def parse_keyword_response(cls, response, show_size=3): 68 | ''' 69 | 키워드 검색 결과를 보여 줌. 
70 | ''' 71 | length = len(response['hits']['hits']) 72 | if length >= 1: 73 | print("# of searched docs: ", length) 74 | print(f"# of display: {show_size}") 75 | print("---------------------") 76 | for idx, doc in enumerate(response['hits']['hits']): 77 | print("_id in index: " , doc['_id']) 78 | print(doc['_score']) 79 | print(doc['_source']['text']) 80 | print(doc['_source']['metadata']) 81 | print("---------------------") 82 | if idx == show_size-1: 83 | break 84 | else: 85 | print("There is no response") 86 | @classmethod 87 | def opensearch_pretty_print_documents(cls, response): 88 | ''' 89 | OpenSearch 결과인 LIST 를 파싱하는 함수 90 | ''' 91 | for doc, score in response: 92 | print(f'\nScore: {score}') 93 | print(f'Document Number: {doc.metadata["row"]}') 94 | # Split the page content into lines 95 | lines = doc.page_content.split("\n") 96 | # Extract and print each piece of information if it exists 97 | for line in lines: 98 | split_line = line.split(": ") 99 | if len(split_line) > 1: 100 | print(f'{split_line[0]}: {split_line[1]}') 101 | print("Metadata:") 102 | print(f'Type: {doc.metadata["type"]}') 103 | print(f'Source: {doc.metadata["source"]}') 104 | print('-' * 50) 105 | @classmethod 106 | def get_document(cls, os_client, doc_id, index_name): 107 | response = os_client.get( 108 | id= doc_id, 109 | index=index_name 110 | ) 111 | return response 112 | @classmethod 113 | def get_count(cls, os_client, index_name): 114 | response = os_client.count( 115 | index=index_name 116 | ) 117 | return response 118 | @classmethod 119 | def get_query(cls, **kwargs): 120 | # Reference: 121 | # OpenSearcj boolean query: 122 | # - https://opensearch.org/docs/latest/query-dsl/compound/bool/ 123 | # OpenSearch match qeury: 124 | # - https://opensearch.org/docs/latest/query-dsl/full-text/index/#match-boolean-prefix 125 | # OpenSearch Query Description (한글) 126 | # - https://esbook.kimjmin.net/05-search) 127 | search_type = kwargs.get("search_type", "lexical") 128 | if search_type == 
"lexical": 129 | min_shoud_match = 0 130 | if "minimum_should_match" in kwargs: 131 | min_shoud_match = kwargs["minimum_should_match"] 132 | QUERY_TEMPLATE = { 133 | "query": { 134 | "bool": { 135 | "must": [ 136 | { 137 | "match": { 138 | "text": { 139 | "query": f'{kwargs["query"]}', 140 | "minimum_should_match": f'{min_shoud_match}%', 141 | "operator": "or", 142 | # "fuzziness": "AUTO", 143 | # "fuzzy_transpositions": True, 144 | # "zero_terms_query": "none", 145 | # "lenient": False, 146 | # "prefix_length": 0, 147 | # "max_expansions": 50, 148 | # "boost": 1 149 | } 150 | } 151 | }, 152 | ], 153 | "filter": [ 154 | ] 155 | } 156 | } 157 | } 158 | if "filter" in kwargs: 159 | QUERY_TEMPLATE["query"]["bool"]["filter"].extend(kwargs["filter"]) 160 | elif search_type == "semantic": 161 | QUERY_TEMPLATE = { 162 | "query": { 163 | "bool": { 164 | "must": [ 165 | { 166 | "knn": { 167 | kwargs["vector_field"]: { 168 | "vector": kwargs["vector"], 169 | "k": kwargs["k"], 170 | } 171 | } 172 | }, 173 | ], 174 | "filter": [ 175 | ] 176 | } 177 | } 178 | } 179 | if "filter" in kwargs: 180 | QUERY_TEMPLATE["query"]["bool"]["filter"].extend(kwargs["filter"]) 181 | return QUERY_TEMPLATE 182 | @classmethod 183 | def get_filter(cls, **kwargs): 184 | BOOL_FILTER_TEMPLATE = { 185 | "bool": { 186 | "filter": [ 187 | ] 188 | } 189 | } 190 | if "filter" in kwargs: 191 | BOOL_FILTER_TEMPLATE["bool"]["filter"].extend(kwargs["filter"]) 192 | return BOOL_FILTER_TEMPLATE 193 | @staticmethod 194 | def get_documents_by_ids(os_client, ids, index_name): 195 | response = os_client.mget( 196 | body={"ids": ids}, 197 | index=index_name 198 | ) 199 | return response 200 | @staticmethod 201 | def opensearch_pretty_print_documents_with_score(doctype, response): 202 | ''' 203 | OpenSearch 결과인 LIST 를 파싱하는 함수 204 | ''' 205 | results = [] 206 | responses = copy.deepcopy(response) 207 | for doc, score in responses: 208 | result = {} 209 | result['doctype'] = doctype 210 | result['score'] = score 211 | 
result['lines'] = doc.page_content.split("\n") 212 | result['meta'] = doc.metadata 213 | results.append(result) 214 | return results -------------------------------------------------------------------------------- /application/utils/pymupdf.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script accepts a PDF document filename and converts it to a text file 3 | in Markdown format, compatible with the GitHub standard. 4 | 5 | It must be invoked with the filename like this: 6 | 7 | python pymupdf_rag.py input.pdf [-pages PAGES] 8 | 9 | The "PAGES" parameter is a string (containing no spaces) of comma-separated 10 | page numbers to consider. Each item is either a single page number or a 11 | number range "m-n". Use "N" to address the document's last page number. 12 | Example: "-pages 2-15,40,43-N" 13 | 14 | It will produce a markdown text file called "input.md". 15 | 16 | Text will be sorted in Western reading order. Any table will be included in 17 | the text in markdwn format as well. 18 | 19 | Use in some other script 20 | ------------------------- 21 | import fitz 22 | from to_markdown import to_markdown 23 | 24 | doc = fitz.open("input.pdf") 25 | page_list = [ list of 0-based page numbers ] 26 | md_text = to_markdown(doc, pages=page_list) 27 | 28 | Dependencies 29 | ------------- 30 | PyMuPDF v1.24.0 or later 31 | 32 | Copyright and License 33 | ---------------------- 34 | Copyright 2024 Artifex Software, Inc. 
35 | License GNU Affero GPL 3.0 36 | """ 37 | 38 | import string 39 | from pprint import pprint 40 | 41 | import fitz 42 | 43 | if fitz.pymupdf_version_tuple < (1, 24, 0): 44 | raise NotImplementedError("PyMuPDF version 1.24.0 or later is needed.") 45 | 46 | 47 | def to_markdown_pymupdf(doc: fitz.Document, pages: list = None) -> str: 48 | """Process the document and return the text of its selected pages.""" 49 | SPACES = set(string.whitespace) # used to check relevance of text pieces 50 | if not pages: # use all pages if argument not given 51 | pages = range(doc.page_count) 52 | 53 | class IdentifyHeaders: 54 | """Compute data for identifying header text.""" 55 | 56 | def __init__(self, doc, pages: list = None, body_limit: float = None): 57 | """Read all text and make a dictionary of fontsizes. 58 | 59 | Args: 60 | pages: optional list of pages to consider 61 | body_limit: consider text with larger font size as some header 62 | """ 63 | if pages is None: # use all pages if omitted 64 | pages = range(doc.page_count) 65 | fontsizes = {} 66 | for pno in pages: 67 | page = doc[pno] 68 | blocks = page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"] 69 | for span in [ # look at all non-empty horizontal spans 70 | s 71 | for b in blocks 72 | for l in b["lines"] 73 | for s in l["spans"] 74 | if not SPACES.issuperset(s["text"]) 75 | ]: 76 | fontsz = round(span["size"]) 77 | count = fontsizes.get(fontsz, 0) + len(span["text"].strip()) 78 | fontsizes[fontsz] = count 79 | 80 | # maps a fontsize to a string of multiple # header tag characters 81 | self.header_id = {} 82 | if body_limit is None: # body text fontsize if not provided 83 | body_limit = sorted( 84 | [(k, v) for k, v in fontsizes.items()], 85 | key=lambda i: i[1], 86 | reverse=True, 87 | )[0][0] 88 | 89 | sizes = sorted( 90 | [f for f in fontsizes.keys() if f > body_limit], reverse=True 91 | ) 92 | 93 | # make the header tag dictionary 94 | for i, size in enumerate(sizes): 95 | self.header_id[size] = "#" * (i + 
1) + " " 96 | 97 | def get_header_id(self, span): 98 | """Return appropriate markdown header prefix. 99 | 100 | Given a text span from a "dict"/"radict" extraction, determine the 101 | markdown header prefix string of 0 to many concatenated '#' characters. 102 | """ 103 | fontsize = round(span["size"]) # compute fontsize 104 | hdr_id = self.header_id.get(fontsize, "") 105 | return hdr_id 106 | 107 | def resolve_links(links, span): 108 | """Accept a span bbox and return a markdown link string.""" 109 | bbox = fitz.Rect(span["bbox"]) # span bbox 110 | # a link should overlap at least 70% of the span 111 | bbox_area = 0.7 * abs(bbox) 112 | for link in links: 113 | hot = link["from"] # the hot area of the link 114 | if not abs(hot & bbox) >= bbox_area: 115 | continue # does not touch the bbox 116 | text = f'[{span["text"].strip()}]({link["uri"]})' 117 | return text 118 | 119 | def write_text(page, clip, hdr_prefix): 120 | """Output the text found inside the given clip. 121 | 122 | This is an alternative for plain text in that it outputs 123 | text enriched with markdown styling. 124 | The logic is capable of recognizing headers, body text, code blocks, 125 | inline code, bold, italic and bold-italic styling. 126 | There is also some effort for list supported (ordered / unordered) in 127 | that typical characters are replaced by respective markdown characters. 
128 | """ 129 | out_string = "" 130 | code = False # mode indicator: outputting code 131 | 132 | # extract URL type links on page 133 | links = [l for l in page.get_links() if l["kind"] == 2] 134 | 135 | blocks = page.get_text( 136 | "dict", 137 | clip=clip, 138 | flags=fitz.TEXTFLAGS_TEXT, 139 | sort=True, 140 | )["blocks"] 141 | 142 | for block in blocks: # iterate textblocks 143 | previous_y = 0 144 | for line in block["lines"]: # iterate lines in block 145 | if line["dir"][1] != 0: # only consider horizontal lines 146 | continue 147 | spans = [s for s in line["spans"]] 148 | 149 | this_y = line["bbox"][3] # current bottom coord 150 | 151 | # check for still being on same line 152 | same_line = abs(this_y - previous_y) <= 3 and previous_y > 0 153 | 154 | if same_line and out_string.endswith("\n"): 155 | out_string = out_string[:-1] 156 | 157 | # are all spans in line in a mono-spaced font? 158 | all_mono = all([s["flags"] & 8 for s in spans]) 159 | 160 | # compute text of the line 161 | text = "".join([s["text"] for s in spans]) 162 | if not same_line: 163 | previous_y = this_y 164 | if not out_string.endswith("\n"): 165 | out_string += "\n" 166 | 167 | if all_mono: 168 | # compute approx. distance from left - assuming a width 169 | # of 0.5*fontsize. 170 | delta = int( 171 | (spans[0]["bbox"][0] - block["bbox"][0]) 172 | / (spans[0]["size"] * 0.5) 173 | ) 174 | if not code: # if not already in code output mode: 175 | out_string += "```" # switch on "code" mode 176 | code = True 177 | if not same_line: # new code line with left indentation 178 | out_string += "\n" + " " * delta + text + " " 179 | previous_y = this_y 180 | else: # same line, simply append 181 | out_string += text + " " 182 | continue # done with this line 183 | 184 | for i, s in enumerate(spans): # iterate spans of the line 185 | # this line is not all-mono, so switch off "code" mode 186 | if code: # still in code output mode? 
187 | out_string += "```\n" # switch of code mode 188 | code = False 189 | # decode font properties 190 | mono = s["flags"] & 8 191 | bold = s["flags"] & 16 192 | italic = s["flags"] & 2 193 | 194 | if mono: 195 | # this is text in some monospaced font 196 | out_string += f"`{s['text'].strip()}` " 197 | else: # not a mono text 198 | # for first span, get header prefix string if present 199 | if i == 0: 200 | hdr_string = hdr_prefix.get_header_id(s) 201 | else: 202 | hdr_string = "" 203 | prefix = "" 204 | suffix = "" 205 | if hdr_string == "": 206 | if bold: 207 | prefix = "**" 208 | suffix += "**" 209 | if italic: 210 | prefix += "_" 211 | suffix = "_" + suffix 212 | 213 | ltext = resolve_links(links, s) 214 | if ltext: 215 | text = f"{hdr_string}{prefix}{ltext}{suffix} " 216 | else: 217 | text = f"{hdr_string}{prefix}{s['text'].strip()}{suffix} " 218 | text = ( 219 | text.replace("<", "<") 220 | .replace(">", ">") 221 | .replace(chr(0xF0B7), "-") 222 | .replace(chr(0xB7), "-") 223 | .replace(chr(8226), "-") 224 | .replace(chr(9679), "-") 225 | ) 226 | out_string += text 227 | previous_y = this_y 228 | if not code: 229 | out_string += "\n" 230 | out_string += "\n" 231 | if code: 232 | out_string += "```\n" # switch of code mode 233 | code = False 234 | return out_string.replace(" \n", "\n") 235 | 236 | hdr_prefix = IdentifyHeaders(doc, pages=pages) 237 | md_string = "" 238 | 239 | for pno in pages: 240 | page = doc[pno] 241 | # 1. first locate all tables on page 242 | tabs = page.find_tables() 243 | 244 | # 2. make a list of table boundary boxes, sort by top-left corner. 245 | # Must include the header bbox, which may be external. 246 | tab_rects = sorted( 247 | [ 248 | (fitz.Rect(t.bbox) | fitz.Rect(t.header.bbox), i) 249 | for i, t in enumerate(tabs.tables) 250 | ], 251 | key=lambda r: (r[0].y0, r[0].x0), 252 | ) 253 | 254 | # 3. 
final list of all text and table rectangles 255 | text_rects = [] 256 | # compute rectangles outside tables and fill final rect list 257 | for i, (r, idx) in enumerate(tab_rects): 258 | if i == 0: # compute rect above all tables 259 | tr = page.rect 260 | tr.y1 = r.y0 261 | if not tr.is_empty: 262 | text_rects.append(("text", tr, 0)) 263 | text_rects.append(("table", r, idx)) 264 | continue 265 | # read previous rectangle in final list: always a table! 266 | _, r0, idx0 = text_rects[-1] 267 | 268 | # check if a non-empty text rect is fitting in between tables 269 | tr = page.rect 270 | tr.y0 = r0.y1 271 | tr.y1 = r.y0 272 | if not tr.is_empty: # empty if two tables overlap vertically! 273 | text_rects.append(("text", tr, 0)) 274 | 275 | text_rects.append(("table", r, idx)) 276 | 277 | # there may also be text below all tables 278 | if i == len(tab_rects) - 1: 279 | tr = page.rect 280 | tr.y0 = r.y1 281 | if not tr.is_empty: 282 | text_rects.append(("text", tr, 0)) 283 | 284 | if not text_rects: # this will happen for table-free pages 285 | text_rects.append(("text", page.rect, 0)) 286 | else: 287 | rtype, r, idx = text_rects[-1] 288 | if rtype == "table": 289 | tr = page.rect 290 | tr.y0 = r.y1 291 | if not tr.is_empty: 292 | text_rects.append(("text", tr, 0)) 293 | 294 | # we have all rectangles and can start outputting their contents 295 | for rtype, r, idx in text_rects: 296 | if rtype == "text": # a text rectangle 297 | md_string += write_text(page, r, hdr_prefix) # write MD content 298 | md_string += "\n" 299 | else: # a table rect 300 | md_string += tabs[idx].to_markdown(clean=False) 301 | 302 | md_string += "\n-----\n\n" 303 | 304 | return md_string 305 | 306 | 307 | if __name__ == "__main__": 308 | import os 309 | import sys 310 | import time 311 | 312 | try: 313 | filename = sys.argv[1] 314 | except IndexError: 315 | print(f"Usage:\npython {os.path.basename(__file__)} input.pdf") 316 | sys.exit() 317 | 318 | t0 = time.perf_counter() # start a time 319 | 
320 | doc = fitz.open(filename) # open input file 321 | parms = sys.argv[2:] # contains ["-pages", "PAGES"] or empty list 322 | pages = range(doc.page_count) # default page range 323 | if len(parms) == 2 and parms[0] == "-pages": # page sub-selection given 324 | pages = [] # list of desired page numbers 325 | 326 | # replace any variable "N" by page count 327 | pages_spec = parms[1].replace("N", f"{doc.page_count}") 328 | for spec in pages_spec.split(","): 329 | if "-" in spec: 330 | start, end = map(int, spec.split("-")) 331 | pages.extend(range(start - 1, end)) 332 | else: 333 | pages.append(int(spec) - 1) 334 | 335 | # make a set of invalid page numbers 336 | wrong_pages = set([n + 1 for n in pages if n >= doc.page_count][:4]) 337 | if wrong_pages != set(): # if any invalid numbers given, exit. 338 | sys.exit(f"Page number(s) {wrong_pages} not in '{doc}'.") 339 | 340 | # get the markdown string 341 | md_string = to_markdown(doc, pages=pages) 342 | 343 | # output to a text file with extension ".md" 344 | out = open(doc.name.replace(".pdf", ".md"), "w") 345 | out.write(md_string) 346 | out.close() 347 | t1 = time.perf_counter() # stop timer 348 | print(f"Markdown creation time for {doc.name=} {round(t1-t0,2)} sec.") 349 | -------------------------------------------------------------------------------- /application/utils/rag_summit.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | ############################################################ 3 | # RAG 관련 함수들 4 | ############################################################ 5 | ############################################################ 6 | 7 | import json 8 | import copy 9 | import boto3 10 | import numpy as np 11 | import pandas as pd 12 | from copy import deepcopy 13 | from pprint import pprint 14 | from operator import itemgetter 15 | from itertools import chain as ch 16 | from typing import Any, Dict, List, 
Optional, List, Tuple 17 | from opensearchpy import OpenSearch, RequestsHttpConnection 18 | 19 | import base64 20 | from PIL import Image 21 | from io import BytesIO 22 | import matplotlib.pyplot as plt 23 | 24 | from utils import print_ww 25 | from utils.opensearch_summit import opensearch_utils 26 | from utils.common_utils import print_html 27 | 28 | from langchain.schema import Document 29 | from langchain.chains import RetrievalQA 30 | from langchain.schema import BaseRetriever 31 | from langchain.prompts import PromptTemplate 32 | from langchain.retrievers import AmazonKendraRetriever 33 | from langchain_core.tracers import ConsoleCallbackHandler 34 | from langchain.schema.output_parser import StrOutputParser 35 | from langchain.embeddings import SagemakerEndpointEmbeddings 36 | from langchain.text_splitter import RecursiveCharacterTextSplitter 37 | from langchain.callbacks.manager import CallbackManagerForRetrieverRun 38 | from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler 39 | from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler 40 | from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate 41 | 42 | import threading 43 | from functools import partial 44 | from multiprocessing.pool import ThreadPool 45 | #pool = ThreadPool(processes=2) 46 | #rag_fusion_pool = ThreadPool(processes=5) 47 | 48 | ############################################################ 49 | # Prompt repo 50 | ############################################################ 51 | class prompt_repo(): 52 | 53 | template_types = ["web_search", "sci_fact", "fiqa", "trec_news"] 54 | prompt_types = ["answer_only", "answer_with_ref", "original", "ko_answer_only"] 55 | 56 | 57 | #First, find the paragraphs or sentences from the context that are most relevant to answering the question, 58 | #Then, answer the question within XML tags as much as you can. 
59 | # Answer the question within XML tags as much as you can. 60 | # Don't say "According to context" when answering. 61 | # Don't insert XML tag such as and when answering. 62 | 63 | 64 | @classmethod 65 | def get_system_prompt(cls, ): 66 | 67 | system_prompt = ''' 68 | You are a master answer bot designed to answer user's questions. 69 | I'm going to give you contexts which consist of texts, tables and images. 70 | Read the contexts carefully, because I'm going to ask you a question about it. 71 | ''' 72 | return system_prompt 73 | 74 | @classmethod 75 | def get_human_prompt(cls, images=None, tables=None): 76 | 77 | human_prompt = [] 78 | 79 | image_template = { 80 | "type": "image_url", 81 | "image_url": { 82 | "url": "data:image/png;base64," + "IMAGE_BASE64", 83 | }, 84 | } 85 | text_template = { 86 | "type": "text", 87 | "text": ''' 88 | Here is the contexts as texts: {contexts} 89 | TABLE_PROMPT 90 | 91 | First, find a few paragraphs or sentences from the contexts that are most relevant to answering the question. 92 | Then, answer the question as much as you can. 93 | 94 | Skip the preamble and go straight into the answer. 95 | Don't insert any XML tag such as and when answering. 96 | Answer in Korean. 97 | 98 | Here is the question: {question} 99 | 100 | If the question cannot be answered by the contexts, say "No relevant contexts". 
101 | ''' 102 | } 103 | 104 | table_prompt = ''' 105 | Here is the contexts as tables (table as text): {tables_text} 106 | Here is the contexts as tables (table as html): {tables_html} 107 | ''' 108 | if tables != None: 109 | text_template["text"] = text_template["text"].replace("TABLE_PROMPT", table_prompt) 110 | for table in tables: 111 | #if table.metadata["image_base64"]: 112 | if "image_base64" in table.metadata: 113 | image_template["image_url"]["url"] = image_template["image_url"]["url"].replace("IMAGE_BASE64", table.metadata["image_base64"]) 114 | human_prompt.append(image_template) 115 | else: text_template["text"] = text_template["text"].replace("TABLE_PROMPT", "") 116 | 117 | if images != None: 118 | for image in images: 119 | image_template["image_url"]["url"] = image_template["image_url"]["url"].replace("IMAGE_BASE64", image.page_content) 120 | human_prompt.append(image_template) 121 | 122 | human_prompt.append(text_template) 123 | 124 | return human_prompt 125 | 126 | # @classmethod 127 | # def get_qa(cls, prompt_type="answer_only"): 128 | 129 | # assert prompt_type in cls.prompt_types, "Check your prompt_type" 130 | 131 | # if prompt_type == "answer_only": 132 | 133 | # prompt = """ 134 | # \n\nHuman: 135 | # You are a master answer bot designed to answer software developer's questions. 136 | # I'm going to give you a context. Read the context carefully, because I'm going to ask you a question about it. 137 | 138 | # Here is the context: {context} 139 | 140 | # First, find a few paragraphs or sentences from the context that are most relevant to answering the question. 141 | # Then, answer the question as much as you can. 142 | 143 | # Skip the preamble and go straight into the answer. 144 | # Don't insert any XML tag such as and when answering. 145 | 146 | # Here is the question: {question} 147 | 148 | # If the question cannot be answered by the context, say "No relevant context". 149 | # \n\nAssistant: Here is the answer. 
""" 150 | 151 | # elif prompt_type == "answer_with_ref": 152 | 153 | # prompt = """ 154 | # \n\nHuman: 155 | # You are a master answer bot designed to answer software developer's questions. 156 | # I'm going to give you a context. Read the context carefully, because I'm going to ask you a question about it. 157 | 158 | # Here is the context: {context} 159 | 160 | # First, find the paragraphs or sentences from the context that are most relevant to answering the question, and then print them in numbered order. 161 | # The format of paragraphs or sentences to the question should look like what's shown between the tags. 162 | # Make sure to follow the formatting and spacing exactly. 163 | 164 | # 165 | # [Examples of question + answer pairs using parts of the given context, with answers written exactly like how Claude’s output should be structured] 166 | # 167 | 168 | # If there are no relevant paragraphs or sentences, write "No relevant context" instead. 169 | 170 | # Then, answer the question within XML tags. 171 | # Answer as much as you can. 172 | # Skip the preamble and go straight into the answer. 173 | # Don't say "According to context" when answering. 174 | # Don't insert XML tag such as and when answering. 175 | # If needed, answer using bulleted format. 176 | # If relevant paragraphs or sentences have code block, please show us that as code block. 177 | 178 | # Here is the question: {question} 179 | 180 | # If the question cannot be answered by the context, say "No relevant context". 181 | 182 | # \n\nAssistant: Here is the most relevant sentence in the context:""" 183 | 184 | # elif prompt_type == "original": 185 | # prompt = """ 186 | # \n\nHuman: Here is the context, inside XML tags. 187 | 188 | # 189 | # {context} 190 | # 191 | 192 | # Only using the context as above, answer the following question with the rules as below: 193 | # - Don't insert XML tag such as and when answering. 
194 | # - Write as much as you can 195 | # - Be courteous and polite 196 | # - Only answer the question if you can find the answer in the context with certainty. 197 | 198 | # Question: 199 | # {question} 200 | 201 | # If the answer is not in the context, just say "I don't know" 202 | # \n\nAssistant:""" 203 | 204 | # if prompt_type == "ko_answer_only": 205 | 206 | # prompt = """ 207 | # \n\nHuman: 208 | # You are a master answer bot designed to answer software developer's questions. 209 | # I'm going to give you a context. Read the context carefully, because I'm going to ask you a question about it. 210 | 211 | # Here is the context: {context} 212 | 213 | # First, find a few paragraphs or sentences from the context that are most relevant to answering the question. 214 | # Then, answer the question as much as you can. 215 | 216 | # Skip the preamble and go straight into the answer. 217 | # Don't insert any XML tag such as and when answering. 218 | 219 | # Here is the question: {question} 220 | 221 | # Answer in Korean. 222 | # If the question cannot be answered by the context, say "No relevant context". 223 | # \n\nAssistant: Here is the answer. """ 224 | 225 | # prompt_template = PromptTemplate( 226 | # template=prompt, input_variables=["context", "question"] 227 | # ) 228 | 229 | # return prompt_template 230 | 231 | @staticmethod 232 | def get_rag_fusion(): 233 | 234 | system_prompt = """ 235 | You are a helpful assistant that generates multiple search queries that is semantically simiar to a single input query. 236 | Skip the preamble and generate in Korean. 
237 | """ 238 | human_prompt = """ 239 | Generate multiple search queries related to: {query} 240 | OUTPUT ({query_augmentation_size} queries): 241 | """ 242 | 243 | system_message_template = SystemMessagePromptTemplate.from_template(system_prompt) 244 | human_message_template = HumanMessagePromptTemplate.from_template(human_prompt) 245 | prompt = ChatPromptTemplate.from_messages( 246 | [system_message_template, human_message_template] 247 | ) 248 | 249 | return prompt 250 | 251 | @classmethod 252 | def get_hyde(cls, template_type): 253 | 254 | assert template_type in cls.template_types, "Check your template_type" 255 | 256 | system_prompt = """ 257 | You are a master answer bot designed to answer user's questions. 258 | """ 259 | human_prompt = """ 260 | Here is the question: {query} 261 | 262 | HYDE_TEMPLATE 263 | Skip the preamble and generate in Korean. 264 | """ 265 | 266 | 267 | # There are a few different templates to choose from 268 | # These are just different ways to generate hypothetical documents 269 | hyde_template = { 270 | "web_search": "Please write a concise passage to answer the question.", 271 | "sci_fact": "Please write a concise scientific paper passage to support/refute the claim.", 272 | "fiqa": "Please write a concise financial article passage to answer the question.", 273 | "trec_news": "Please write a concise news passage about the topic." 
274 | } 275 | human_prompt = human_prompt.replace("HYDE_TEMPLATE", hyde_template[template_type]) 276 | 277 | system_message_template = SystemMessagePromptTemplate.from_template(system_prompt) 278 | human_message_template = HumanMessagePromptTemplate.from_template(human_prompt) 279 | prompt = ChatPromptTemplate.from_messages( 280 | [system_message_template, human_message_template] 281 | ) 282 | 283 | return prompt 284 | 285 | ############################################################ 286 | # RetrievalQA (Langchain) 287 | ############################################################ 288 | pretty_contexts = None 289 | augmentation = None 290 | 291 | class qa_chain(): 292 | 293 | def __init__(self, **kwargs): 294 | 295 | system_prompt = kwargs["system_prompt"] 296 | self.llm_text = kwargs["llm_text"] 297 | self.retriever = kwargs["retriever"] 298 | self.system_message_template = SystemMessagePromptTemplate.from_template(system_prompt) 299 | self.return_context = kwargs.get("return_context", False) 300 | self.verbose = kwargs.get("verbose", False) 301 | 302 | def invoke(self, **kwargs): 303 | global pretty_contexts 304 | global augmentation 305 | 306 | query, verbose = kwargs["query"], kwargs.get("verbose", self.verbose) 307 | tables, images = None, None 308 | if self.retriever.complex_doc: 309 | #retrieval, tables, images = self.retriever.get_relevant_documents(query) 310 | retrieval, tables, images = self.retriever.invoke(query) 311 | 312 | invoke_args = { 313 | "contexts": "\n\n".join([doc.page_content for doc in retrieval]), 314 | "tables_text": "\n\n".join([doc.page_content for doc in tables]), 315 | "tables_html": "\n\n".join([doc.metadata["text_as_html"] if "text_as_html" in doc.metadata else "" for doc in tables]), 316 | "question": query 317 | } 318 | else: 319 | #retrieval = self.retriever.get_relevant_documents(query) 320 | retrieval = self.retriever.invoke(query) 321 | invoke_args = { 322 | "contexts": "\n\n".join([doc.page_content for doc in retrieval]), 
                "question": query
            }

        # Build the (possibly multimodal) human prompt from the retrieved
        # tables/images, then compose it with the stored system message.
        human_prompt = prompt_repo.get_human_prompt(
            images=images,
            tables=tables
        )
        human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
        prompt = ChatPromptTemplate.from_messages(
            [self.system_message_template, human_message_template]
        )

        # LCEL pipeline: prompt -> LLM -> plain string.
        chain = prompt | self.llm_text | StrOutputParser()

        self.verbose = verbose
        response = chain.invoke(
            invoke_args,
            # Console tracing only when verbose is requested.
            config={'callbacks': [ConsoleCallbackHandler()]} if self.verbose else {}
        )
        # NOTE(review): pretty_contexts is a module-level global populated by the
        # retriever side; if nothing set it, it is still None here and
        # tuple(None) raises TypeError — confirm the retriever always assigns it.
        pretty_contexts = tuple(pretty_contexts)

        return response, pretty_contexts, retrieval, augmentation

def run_RetrievalQA(**kwargs):
    """Run a LangChain RetrievalQA chain over an OpenSearch-backed vector store.

    Required kwargs: ``llm``, ``query``, ``prompt``, ``vector_db``.
    Optional kwargs: ``chain_type`` (default "stuff"), ``k`` (default 5),
    ``boolean_filter`` (default []), ``verbose`` (default False).

    Returns the chain's result dict (includes source documents).
    """
    chain_types = ["stuff", "map_reduce", "refine"]

    assert "llm" in kwargs, "Check your llm"
    assert "query" in kwargs, "Check your query"
    assert "prompt" in kwargs, "Check your prompt"
    assert "vector_db" in kwargs, "Check your vector_db"
    assert kwargs.get("chain_type", "stuff") in chain_types, f'Check your chain_type, {chain_types}'

    qa = RetrievalQA.from_chain_type(
        llm=kwargs["llm"],
        chain_type=kwargs.get("chain_type", "stuff"),
        retriever=kwargs["vector_db"].as_retriever(
            search_type="similarity",
            search_kwargs={
                "k": kwargs.get("k", 5),
                # Translate the caller-supplied filter list into an
                # OpenSearch boolean filter.
                "boolean_filter": opensearch_utils.get_filter(
                    filter=kwargs.get("boolean_filter", [])
                ),
            }
        ),
        return_source_documents=True,
        chain_type_kwargs={
            "prompt": kwargs["prompt"],
            "verbose": kwargs.get("verbose", False),
        },
        verbose=kwargs.get("verbose", False)
    )

    # NOTE(review): qa(query) is the legacy __call__ style; newer LangChain
    # prefers qa.invoke(...) — confirm before migrating.
    return qa(kwargs["query"])

def run_RetrievalQA_kendra(query, llm_text, PROMPT, kendra_index_id, k, aws_region, verbose):
    # Same RetrievalQA flow but retrieving from an Amazon Kendra index,
    # restricted to Korean-language documents.
    qa = RetrievalQA.from_chain_type(
        llm=llm_text,
        chain_type="stuff",
        retriever=AmazonKendraRetriever(
            index_id=kendra_index_id,
            region_name=aws_region,
            top_k=k,
            # Only return documents tagged as Korean in Kendra.
            attribute_filter = {
                "EqualsTo": {
                    "Key": "_language_code",
                    "Value": {
                        "StringValue": "ko"
                    }
                },
            }
        ),
        return_source_documents=True,
        chain_type_kwargs={
            "prompt": PROMPT,
            "verbose": verbose,
        },
        verbose=verbose
    )

    result = qa(query)

    return result

#################################################################
# Document Retriever with custom function: return List(documents)
#################################################################
def list_up(similar_docs_semantic, similar_docs_keyword, similar_docs_wo_reranker, similar_docs):
    """Bundle the four retrieval result lists into one list-of-lists.

    Order is fixed: [semantic, keyword, without-reranker, final]; callers
    rely on that positional order.
    """
    combined_list = []
    combined_list.append(similar_docs_semantic)
    combined_list.append(similar_docs_keyword)
    combined_list.append(similar_docs_wo_reranker)
    combined_list.append(similar_docs)

    return combined_list

class retriever_utils():
    # Class-level (shared) resources: one SageMaker runtime client and three
    # thread pools reused across all calls — created once at import time.
    runtime_client = boto3.Session().client('sagemaker-runtime')
    pool = ThreadPool(processes=2)
    rag_fusion_pool = ThreadPool(processes=5)
    hyde_pool = ThreadPool(processes=4)
    text_splitter = RecursiveCharacterTextSplitter(
        # Set a really small chunk size, just to show.
        chunk_size=512,
        chunk_overlap=0,
        separators=["\n\n", "\n", ".", " ", ""],
        length_function=len,
    )
    # Max tokens (query + passage) accepted by the reranker endpoint before
    # a context is split into chunks.
    token_limit = 300

    @classmethod
    # semantic search based
    def get_semantic_similar_docs_by_langchain(cls, **kwargs):
        """k-NN search via the LangChain OpenSearch vector store.

        Required kwargs: ``vector_db``, ``query``. Optional: ``k`` (default 5),
        ``search_type``, ``space_type``, ``boolean_filter``, ``hybrid``.
        Returns a list of (Document, score) tuples; when ``hybrid`` is True the
        scores are normalized by the top score (results are score-descending,
        so the first score is the max).
        """
        #print(f"Thread={threading.get_ident()}, Process={os.getpid()}")
        search_types = ["approximate_search", "script_scoring", "painless_scripting"]
        space_types = ["l2", "l1", "linf", "cosinesimil", "innerproduct", "hammingbit"]

        assert "vector_db" in kwargs, "Check your vector_db"
        assert "query" in kwargs, "Check your query"
        assert kwargs.get("search_type", "approximate_search") in search_types, f'Check your search_type: {search_types}'
        assert kwargs.get("space_type", "l2") in space_types, f'Check your space_type: {space_types}'

        results = kwargs["vector_db"].similarity_search_with_score(
            query=kwargs["query"],
            k=kwargs.get("k", 5),
            search_type=kwargs.get("search_type", "approximate_search"),
            space_type=kwargs.get("space_type", "l2"),
            boolean_filter=opensearch_utils.get_filter(
                filter=kwargs.get("boolean_filter", [])
            ),
        )

        if kwargs.get("hybrid", False) and results:
            # Normalize scores into [0, 1] relative to the best hit so they
            # can be fused with lexical scores later.
            max_score = results[0][1]
            new_results = []
            for doc in results:
                nomalized_score = float(doc[1]/max_score)
                new_results.append((doc[0], nomalized_score))
            results = deepcopy(new_results)

        return results

    @classmethod
    def control_streaming_mode(cls, llm, stream=True):
        """Toggle token streaming on the given LLM in place and return it.

        Streaming is disabled for background query-generation calls and
        re-enabled for the user-facing answer call.
        """
        if stream:
            llm.streaming = True
            llm.callbacks = [StreamingStdOutCallbackHandler()]
        else:
            llm.streaming = False
            llm.callbacks = None

        return llm

    @classmethod
    # semantic search based
    def get_semantic_similar_docs(cls, **kwargs):
        # Semantic (vector) search issued directly against the OpenSearch
        # client, bypassing the LangChain vector-store wrapper.
        assert "query" in kwargs, "Check your query"
        assert "k" in kwargs, "Check your k"
485 | assert "os_client" in kwargs, "Check your os_client" 486 | assert "index_name" in kwargs, "Check your index_name" 487 | 488 | def normalize_search_results(search_results): 489 | 490 | hits = (search_results["hits"]["hits"]) 491 | max_score = float(search_results["hits"]["max_score"]) 492 | for hit in hits: 493 | hit["_score"] = float(hit["_score"]) / max_score 494 | search_results["hits"]["max_score"] = hits[0]["_score"] 495 | search_results["hits"]["hits"] = hits 496 | return search_results 497 | 498 | query = opensearch_utils.get_query( 499 | query=kwargs["query"], 500 | filter=kwargs.get("boolean_filter", []), 501 | search_type="semantic", # enable semantic search 502 | vector_field="vector_field", # for semantic search check by using index_info = os_client.indices.get(index=index_name) 503 | vector=kwargs["llm_emb"].embed_query(kwargs["query"]), 504 | k=kwargs["k"] 505 | ) 506 | query["size"] = kwargs["k"] 507 | 508 | #print ("\nsemantic search query: ") 509 | #pprint (query) 510 | 511 | search_results = opensearch_utils.search_document( 512 | os_client=kwargs["os_client"], 513 | query=query, 514 | index_name=kwargs["index_name"] 515 | ) 516 | 517 | results = [] 518 | if search_results["hits"]["hits"]: 519 | search_results = normalize_search_results(search_results) 520 | for res in search_results["hits"]["hits"]: 521 | 522 | metadata = res["_source"]["metadata"] 523 | metadata["id"] = res["_id"] 524 | 525 | doc = Document( 526 | page_content=res["_source"]["text"], 527 | metadata=metadata 528 | ) 529 | if kwargs.get("hybrid", False): 530 | results.append((doc, res["_score"])) 531 | else: 532 | results.append((doc)) 533 | 534 | return results 535 | 536 | @classmethod 537 | # lexical(keyword) search based (using Amazon OpenSearch) 538 | def get_lexical_similar_docs(cls, **kwargs): 539 | 540 | assert "query" in kwargs, "Check your query" 541 | assert "k" in kwargs, "Check your k" 542 | assert "os_client" in kwargs, "Check your os_client" 543 | assert 
"index_name" in kwargs, "Check your index_name" 544 | 545 | def normalize_search_results(search_results): 546 | 547 | hits = (search_results["hits"]["hits"]) 548 | max_score = float(search_results["hits"]["max_score"]) 549 | for hit in hits: 550 | hit["_score"] = float(hit["_score"]) / max_score 551 | search_results["hits"]["max_score"] = hits[0]["_score"] 552 | search_results["hits"]["hits"] = hits 553 | return search_results 554 | 555 | query = opensearch_utils.get_query( 556 | query=kwargs["query"], 557 | minimum_should_match=kwargs.get("minimum_should_match", 0), 558 | filter=kwargs["filter"] 559 | ) 560 | query["size"] = kwargs["k"] 561 | 562 | #print ("\nlexical search query: ") 563 | #pprint (query) 564 | 565 | search_results = opensearch_utils.search_document( 566 | os_client=kwargs["os_client"], 567 | query=query, 568 | index_name=kwargs["index_name"] 569 | ) 570 | 571 | results = [] 572 | if search_results["hits"]["hits"]: 573 | search_results = normalize_search_results(search_results) 574 | for res in search_results["hits"]["hits"]: 575 | 576 | metadata = res["_source"]["metadata"] 577 | metadata["id"] = res["_id"] 578 | 579 | doc = Document( 580 | page_content=res["_source"]["text"], 581 | metadata=metadata 582 | ) 583 | if kwargs.get("hybrid", False): 584 | results.append((doc, res["_score"])) 585 | else: 586 | results.append((doc)) 587 | 588 | return results 589 | 590 | @classmethod 591 | # rag-fusion based 592 | def get_rag_fusion_similar_docs(cls, **kwargs): 593 | global augmentation 594 | 595 | search_types = ["approximate_search", "script_scoring", "painless_scripting"] 596 | space_types = ["l2", "l1", "linf", "cosinesimil", "innerproduct", "hammingbit"] 597 | 598 | assert "llm_emb" in kwargs, "Check your llm_emb" 599 | assert "query" in kwargs, "Check your query" 600 | assert "query_transformation_prompt" in kwargs, "Check your query_transformation_prompt" 601 | assert kwargs.get("search_type", "approximate_search") in search_types, f'Check your 
search_type: {search_types}' 602 | assert kwargs.get("space_type", "l2") in space_types, f'Check your space_type: {space_types}' 603 | assert kwargs.get("llm_text", None) != None, "Check your llm_text" 604 | 605 | llm_text = kwargs["llm_text"] 606 | query_augmentation_size = kwargs["query_augmentation_size"] 607 | query_transformation_prompt = kwargs["query_transformation_prompt"] 608 | 609 | llm_text = cls.control_streaming_mode(llm_text, stream=False) ## trun off llm streaming 610 | generate_queries = query_transformation_prompt | llm_text | StrOutputParser() | (lambda x: x.split("\n")) 611 | 612 | rag_fusion_query = generate_queries.invoke( 613 | { 614 | "query": kwargs["query"], 615 | "query_augmentation_size": kwargs["query_augmentation_size"] 616 | } 617 | ) 618 | 619 | rag_fusion_query = [query for query in rag_fusion_query if query != ""] 620 | if len(rag_fusion_query) > query_augmentation_size: rag_fusion_query = rag_fusion_query[-query_augmentation_size:] 621 | rag_fusion_query.insert(0, kwargs["query"]) 622 | augmentation = rag_fusion_query 623 | 624 | if kwargs["verbose"]: 625 | print("\n") 626 | print("===== RAG-Fusion Queries =====") 627 | print(rag_fusion_query) 628 | 629 | llm_text = cls.control_streaming_mode(llm_text, stream=True)## trun on llm streaming 630 | 631 | tasks = [] 632 | for query in rag_fusion_query: 633 | semantic_search = partial( 634 | cls.get_semantic_similar_docs, 635 | os_client=kwargs["os_client"], 636 | index_name=kwargs["index_name"], 637 | query=query, 638 | k=kwargs["k"], 639 | boolean_filter=kwargs.get("boolean_filter", []), 640 | llm_emb=kwargs["llm_emb"], 641 | hybrid=True 642 | ) 643 | tasks.append(cls.rag_fusion_pool.apply_async(semantic_search,)) 644 | rag_fusion_docs = [task.get() for task in tasks] 645 | 646 | similar_docs = cls.get_ensemble_results( 647 | doc_lists=rag_fusion_docs, 648 | weights=[1/(query_augmentation_size+1)]*(query_augmentation_size+1), #query_augmentation_size + original query 649 | 
algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"] 650 | c=60, 651 | k=kwargs["k"], 652 | ) 653 | 654 | return similar_docs 655 | 656 | @classmethod 657 | # HyDE based 658 | def get_hyde_similar_docs(cls, **kwargs): 659 | global augmentation 660 | 661 | def _get_hyde_response(query, prompt, llm_text): 662 | 663 | chain = prompt | llm_text | StrOutputParser() 664 | 665 | return chain.invoke({"query": query}) 666 | 667 | search_types = ["approximate_search", "script_scoring", "painless_scripting"] 668 | space_types = ["l2", "l1", "linf", "cosinesimil", "innerproduct", "hammingbit"] 669 | 670 | assert "llm_emb" in kwargs, "Check your llm_emb" 671 | assert "query" in kwargs, "Check your query" 672 | assert "hyde_query" in kwargs, "Check your hyde_query" 673 | assert kwargs.get("search_type", "approximate_search") in search_types, f'Check your search_type: {search_types}' 674 | assert kwargs.get("space_type", "l2") in space_types, f'Check your space_type: {space_types}' 675 | assert kwargs.get("llm_text", None) != None, "Check your llm_text" 676 | 677 | query = kwargs["query"] 678 | llm_text = kwargs["llm_text"] 679 | hyde_query = kwargs["hyde_query"] 680 | 681 | tasks = [] 682 | llm_text = cls.control_streaming_mode(llm_text, stream=False) ## trun off llm streaming 683 | for template_type in hyde_query: 684 | hyde_response = partial( 685 | _get_hyde_response, 686 | query=query, 687 | prompt=prompt_repo.get_hyde(template_type), 688 | llm_text=llm_text 689 | ) 690 | tasks.append(cls.hyde_pool.apply_async(hyde_response,)) 691 | hyde_answers = [task.get() for task in tasks] 692 | hyde_answers.insert(0, query) 693 | 694 | tasks = [] 695 | llm_text = cls.control_streaming_mode(llm_text, stream=True) ## trun on llm streaming 696 | for hyde_answer in hyde_answers: 697 | semantic_search = partial( 698 | cls.get_semantic_similar_docs, 699 | os_client=kwargs["os_client"], 700 | index_name=kwargs["index_name"], 701 | query=hyde_answer, 702 | 
k=kwargs["k"], 703 | boolean_filter=kwargs.get("boolean_filter", []), 704 | llm_emb=kwargs["llm_emb"], 705 | hybrid=True 706 | ) 707 | tasks.append(cls.hyde_pool.apply_async(semantic_search,)) 708 | hyde_docs = [task.get() for task in tasks] 709 | hyde_doc_size = len(hyde_docs) 710 | 711 | similar_docs = cls.get_ensemble_results( 712 | doc_lists=hyde_docs, 713 | weights=[1/(hyde_doc_size)]*(hyde_doc_size), #query_augmentation_size + original query 714 | algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"] 715 | c=60, 716 | k=kwargs["k"], 717 | ) 718 | augmentation = hyde_answers[1] 719 | if kwargs["verbose"]: 720 | print("\n") 721 | print("===== HyDE Answers =====") 722 | print(hyde_answers) 723 | 724 | return similar_docs 725 | 726 | @classmethod 727 | # ParentDocument based 728 | def get_parent_document_similar_docs(cls, **kwargs): 729 | 730 | child_search_results = kwargs["similar_docs"] 731 | 732 | parent_info, similar_docs = {}, [] 733 | for rank, (doc, score) in enumerate(child_search_results): 734 | parent_id = doc.metadata["parent_id"] 735 | if parent_id != "NA": ## For Tables and Images 736 | if parent_id not in parent_info: 737 | parent_info[parent_id] = (rank+1, score) 738 | else: 739 | if kwargs["hybrid"]: 740 | similar_docs.append((doc, score)) 741 | else: 742 | similar_docs.append((doc)) 743 | 744 | parent_ids = sorted(parent_info.items(), key=lambda x: x[1], reverse=False) 745 | parent_ids = list(map(lambda x:x[0], parent_ids)) 746 | 747 | if parent_ids: 748 | parent_docs = opensearch_utils.get_documents_by_ids( 749 | os_client=kwargs["os_client"], 750 | ids=parent_ids, 751 | index_name=kwargs["index_name"], 752 | ) 753 | 754 | if parent_docs["docs"]: 755 | for res in parent_docs["docs"]: 756 | doc_id = res["_id"] 757 | doc = Document( 758 | page_content=res["_source"]["text"], 759 | metadata=res["_source"]["metadata"] 760 | ) 761 | if kwargs["hybrid"]: 762 | similar_docs.append((doc, parent_info[doc_id][1])) 763 | else: 764 
| similar_docs.append((doc)) 765 | 766 | if kwargs["hybrid"]: 767 | similar_docs = sorted( 768 | similar_docs, 769 | key=lambda x: x[1], 770 | reverse=True 771 | ) 772 | 773 | if kwargs["verbose"]: 774 | print("===== ParentDocument =====") 775 | print (f'filter: {kwargs["boolean_filter"]}') 776 | print (f'# child_docs: {len(child_search_results)}') 777 | print (f'# parent docs: {len(similar_docs)}') 778 | print (f'# duplicates: {len(child_search_results)-len(similar_docs)}') 779 | 780 | return similar_docs 781 | 782 | @classmethod 783 | def get_rerank_docs(cls, **kwargs): 784 | 785 | assert "reranker_endpoint_name" in kwargs, "Check your reranker_endpoint_name" 786 | assert "k" in kwargs, "Check your k" 787 | 788 | contexts, query, llm_text, rerank_queries = kwargs["context"], kwargs["query"], kwargs["llm_text"], {"inputs":[]} 789 | 790 | exceed_info = [] 791 | for idx, (context, score) in enumerate(contexts): 792 | page_content = context.page_content 793 | token_size = llm_text.get_num_tokens(query+page_content) 794 | exceed_flag = False 795 | 796 | if token_size > cls.token_limit: 797 | exceed_flag = True 798 | splited_docs = cls.text_splitter.split_documents([context]) 799 | if kwargs["verbose"]: 800 | print(f"\n[Exeeds ReRanker token limit] Number of chunk_docs after split and chunking= {len(splited_docs)}\n") 801 | 802 | partial_set, length = [], [] 803 | for splited_doc in splited_docs: 804 | rerank_queries["inputs"].append({"text": query, "text_pair": splited_doc.page_content}) 805 | length.append(llm_text.get_num_tokens(splited_doc.page_content)) 806 | partial_set.append(len(rerank_queries["inputs"])-1) 807 | else: 808 | rerank_queries["inputs"].append({"text": query, "text_pair": page_content}) 809 | 810 | if exceed_flag: 811 | exceed_info.append([idx, exceed_flag, partial_set, length]) 812 | else: 813 | exceed_info.append([idx, exceed_flag, len(rerank_queries["inputs"])-1, None]) 814 | 815 | rerank_queries = json.dumps(rerank_queries) 816 | 817 | response 
= cls.runtime_client.invoke_endpoint( 818 | EndpointName=kwargs["reranker_endpoint_name"], 819 | ContentType="application/json", 820 | Accept="application/json", 821 | Body=rerank_queries 822 | ) 823 | outs = json.loads(response['Body'].read().decode()) ## for json 824 | 825 | rerank_contexts = [] 826 | for idx, exceed_flag, partial_set, length in exceed_info: 827 | if not exceed_flag: 828 | rerank_contexts.append((contexts[idx][0], outs[partial_set]["score"])) 829 | else: 830 | partial_scores = [outs[partial_idx]["score"] for partial_idx in partial_set] 831 | partial_scores = np.average(partial_scores, axis=0, weights=length) 832 | rerank_contexts.append((contexts[idx][0], partial_scores)) 833 | 834 | #rerank_contexts = [(contexts[idx][0], out["score"]) for idx, out in enumerate(outs)] 835 | rerank_contexts = sorted( 836 | rerank_contexts, 837 | key=lambda x: x[1], 838 | reverse=True 839 | ) 840 | 841 | return rerank_contexts[:kwargs["k"]] 842 | 843 | @classmethod 844 | def get_element(cls, **kwargs): 845 | 846 | similar_docs = copy.deepcopy(kwargs["similar_docs"]) 847 | tables, images = [], [] 848 | 849 | for doc in similar_docs: 850 | 851 | category = doc.metadata.get("category", None) 852 | if category != None: 853 | if category == "Table": 854 | doc.page_content = doc.metadata["origin_table"] 855 | tables.append(doc) 856 | elif category == "Image": 857 | doc.page_content = doc.metadata["image_base64"] 858 | images.append(doc) 859 | 860 | return tables, images 861 | 862 | 863 | @classmethod 864 | # hybrid (lexical + semantic) search based 865 | def search_hybrid(cls, **kwargs): 866 | 867 | assert "query" in kwargs, "Check your query" 868 | assert "llm_emb" in kwargs, "Check your llm_emb" 869 | assert "index_name" in kwargs, "Check your index_name" 870 | assert "os_client" in kwargs, "Check your os_client" 871 | 872 | rag_fusion = kwargs.get("rag_fusion", False) 873 | hyde = kwargs.get("hyde", False) 874 | parent_document = kwargs.get("parent_document", False) 
875 | hybrid_search_debugger = kwargs.get("hybrid_search_debugger", "None") 876 | 877 | 878 | 879 | assert (rag_fusion + hyde) <= 1, "choose only one between RAG-FUSION and HyDE" 880 | if rag_fusion: 881 | assert "query_augmentation_size" in kwargs, "if you use RAG-FUSION, Check your query_augmentation_size" 882 | if hyde: 883 | assert "hyde_query" in kwargs, "if you use HyDE, Check your hyde_query" 884 | 885 | verbose = kwargs.get("verbose", False) 886 | async_mode = kwargs.get("async_mode", True) 887 | reranker = kwargs.get("reranker", False) 888 | complex_doc = kwargs.get("complex_doc", False) 889 | search_filter = deepcopy(kwargs.get("filter", [])) 890 | 891 | #search_filter.append({"term": {"metadata.family_tree": "child"}}) 892 | if parent_document: 893 | parent_doc_filter = { 894 | "bool":{ 895 | "should":[ ## or condition 896 | {"term": {"metadata.family_tree": "child"}}, 897 | {"term": {"metadata.family_tree": "parent_table"}}, 898 | {"term": {"metadata.family_tree": "parent_image"}}, 899 | ] 900 | } 901 | } 902 | search_filter.append(parent_doc_filter) 903 | else: 904 | parent_doc_filter = { 905 | "bool":{ 906 | "should":[ ## or condition 907 | {"term": {"metadata.family_tree": "child"}}, 908 | {"term": {"metadata.family_tree": "parent_table"}}, 909 | {"term": {"metadata.family_tree": "parent_image"}}, 910 | ] 911 | } 912 | } 913 | search_filter.append(parent_doc_filter) 914 | 915 | 916 | def do_sync(): 917 | 918 | if rag_fusion: 919 | similar_docs_semantic = cls.get_rag_fusion_similar_docs( 920 | index_name=kwargs["index_name"], 921 | os_client=kwargs["os_client"], 922 | llm_emb=kwargs["llm_emb"], 923 | 924 | query=kwargs["query"], 925 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5), 926 | boolean_filter=search_filter, 927 | hybrid=True, 928 | 929 | llm_text=kwargs.get("llm_text", None), 930 | query_augmentation_size=kwargs["query_augmentation_size"], 931 | query_transformation_prompt=kwargs.get("query_transformation_prompt", None), 932 
| fusion_algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"] 933 | 934 | verbose=kwargs.get("verbose", False), 935 | ) 936 | elif hyde: 937 | similar_docs_semantic = cls.get_hyde_similar_docs( 938 | index_name=kwargs["index_name"], 939 | os_client=kwargs["os_client"], 940 | llm_emb=kwargs["llm_emb"], 941 | 942 | query=kwargs["query"], 943 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5), 944 | boolean_filter=search_filter, 945 | hybrid=True, 946 | 947 | llm_text=kwargs.get("llm_text", None), 948 | hyde_query=kwargs["hyde_query"], 949 | fusion_algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"] 950 | 951 | verbose=kwargs.get("verbose", False), 952 | ) 953 | else: 954 | similar_docs_semantic = cls.get_semantic_similar_docs( 955 | index_name=kwargs["index_name"], 956 | os_client=kwargs["os_client"], 957 | llm_emb=kwargs["llm_emb"], 958 | 959 | query=kwargs["query"], 960 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5), 961 | boolean_filter=search_filter, 962 | hybrid=True 963 | ) 964 | 965 | similar_docs_keyword = cls.get_lexical_similar_docs( 966 | index_name=kwargs["index_name"], 967 | os_client=kwargs["os_client"], 968 | 969 | query=kwargs["query"], 970 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5), 971 | minimum_should_match=kwargs.get("minimum_should_match", 0), 972 | filter=search_filter, 973 | hybrid=True 974 | ) 975 | 976 | if hybrid_search_debugger == "semantic": similar_docs_keyword = [] 977 | elif hybrid_search_debugger == "lexical": similar_docs_semantic = [] 978 | 979 | return similar_docs_semantic, similar_docs_keyword 980 | 981 | def do_async(): 982 | 983 | if rag_fusion: 984 | semantic_search = partial( 985 | cls.get_rag_fusion_similar_docs, 986 | index_name=kwargs["index_name"], 987 | os_client=kwargs["os_client"], 988 | llm_emb=kwargs["llm_emb"], 989 | 990 | query=kwargs["query"], 991 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5), 
992 | boolean_filter=search_filter, 993 | hybrid=True, 994 | 995 | llm_text=kwargs.get("llm_text", None), 996 | query_augmentation_size=kwargs["query_augmentation_size"], 997 | query_transformation_prompt=kwargs.get("query_transformation_prompt", None), 998 | fusion_algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"] 999 | 1000 | verbose=kwargs.get("verbose", False), 1001 | ) 1002 | elif hyde: 1003 | semantic_search = partial( 1004 | cls.get_hyde_similar_docs, 1005 | index_name=kwargs["index_name"], 1006 | os_client=kwargs["os_client"], 1007 | llm_emb=kwargs["llm_emb"], 1008 | 1009 | query=kwargs["query"], 1010 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5), 1011 | boolean_filter=search_filter, 1012 | hybrid=True, 1013 | 1014 | llm_text=kwargs.get("llm_text", None), 1015 | hyde_query=kwargs["hyde_query"], 1016 | fusion_algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"] 1017 | 1018 | verbose=kwargs.get("verbose", False), 1019 | ) 1020 | else: 1021 | semantic_search = partial( 1022 | cls.get_semantic_similar_docs, 1023 | index_name=kwargs["index_name"], 1024 | os_client=kwargs["os_client"], 1025 | llm_emb=kwargs["llm_emb"], 1026 | 1027 | query=kwargs["query"], 1028 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5), 1029 | boolean_filter=search_filter, 1030 | hybrid=True 1031 | ) 1032 | 1033 | lexical_search = partial( 1034 | cls.get_lexical_similar_docs, 1035 | index_name=kwargs["index_name"], 1036 | os_client=kwargs["os_client"], 1037 | 1038 | query=kwargs["query"], 1039 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5), 1040 | minimum_should_match=kwargs.get("minimum_should_match", 0), 1041 | filter=search_filter, 1042 | hybrid=True 1043 | ) 1044 | semantic_pool = cls.pool.apply_async(semantic_search,) 1045 | lexical_pool = cls.pool.apply_async(lexical_search,) 1046 | similar_docs_semantic, similar_docs_keyword = semantic_pool.get(), lexical_pool.get() 1047 | 1048 | 
if hybrid_search_debugger == "semantic": similar_docs_keyword = [] 1049 | elif hybrid_search_debugger == "lexical": similar_docs_semantic = [] 1050 | 1051 | return similar_docs_semantic, similar_docs_keyword 1052 | 1053 | if async_mode: 1054 | similar_docs_semantic, similar_docs_keyword = do_async() 1055 | else: 1056 | similar_docs_semantic, similar_docs_keyword = do_sync() 1057 | 1058 | similar_docs = cls.get_ensemble_results( 1059 | doc_lists=[similar_docs_semantic, similar_docs_keyword], 1060 | weights=kwargs.get("ensemble_weights", [.51, .49]), 1061 | algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"] 1062 | c=60, 1063 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5), 1064 | ) 1065 | #print (len(similar_docs_keyword), len(similar_docs_semantic), len(similar_docs)) 1066 | #print ("1-similar_docs") 1067 | #for i, doc in enumerate(similar_docs): print (i, doc) 1068 | 1069 | if verbose: 1070 | similar_docs_wo_reranker = copy.deepcopy(similar_docs) 1071 | 1072 | if reranker: 1073 | reranker_endpoint_name = kwargs["reranker_endpoint_name"] 1074 | similar_docs = cls.get_rerank_docs( 1075 | llm_text=kwargs["llm_text"], 1076 | query=kwargs["query"], 1077 | context=similar_docs, 1078 | k=kwargs.get("k", 5), 1079 | reranker_endpoint_name=reranker_endpoint_name, 1080 | verbose=verbose 1081 | ) 1082 | 1083 | #print ("2-similar_docs") 1084 | #for i, doc in enumerate(similar_docs): print (i, doc) 1085 | 1086 | if parent_document: 1087 | similar_docs = cls.get_parent_document_similar_docs( 1088 | index_name=kwargs["index_name"], 1089 | os_client=kwargs["os_client"], 1090 | similar_docs=similar_docs, 1091 | hybrid=True, 1092 | boolean_filter=search_filter, 1093 | verbose=verbose 1094 | ) 1095 | 1096 | if complex_doc: 1097 | tables, images = cls.get_element( 1098 | similar_docs=list(map(lambda x:x[0], similar_docs)) 1099 | ) 1100 | 1101 | if verbose: 1102 | similar_docs_semantic_pretty = 
@classmethod
def get_ensemble_results(cls, doc_lists: List[List[Document]], weights, algorithm="RRF", c=60, k=5) -> List[Document]:
    """Fuse several ranked result lists (e.g. semantic + lexical) into one.

    Args:
        doc_lists: one list per searcher; each entry is a (Document, score) pair,
            ordered best-first.
        weights: per-list fusion weight, aligned with doc_lists.
        algorithm: "RRF" (Reciprocal Rank Fusion, uses ranks only) or
            "simple_weighted" (uses the raw scores scaled by weight).
        c: RRF smoothing constant (standard default 60).
        k: number of fused results to return.

    Returns:
        Top-k list of (Document, fused_score) pairs, best first.
        NOTE: despite the historical ``List[Document]`` annotation, the
        elements are (doc, score) tuples — callers already rely on this.
    """
    assert algorithm in ["RRF", "simple_weighted"]

    # Accumulate fused scores keyed by page_content.
    # BUG FIX: previously the key universe was built via a ``set``, which made
    # the ordering of tied scores depend on hash randomization (nondeterministic
    # across runs). A dict preserves first-seen order, so ties break stably.
    fused_scores = {}
    for ranked_list in doc_lists:
        for document, _ in ranked_list:
            fused_scores.setdefault(document.page_content, 0.0)

    for ranked_list, list_weight in zip(doc_lists, weights):
        for position, (document, raw_score) in enumerate(ranked_list, start=1):
            if algorithm == "RRF":  # rank-based: 1/(rank + c), weighted per list
                contribution = list_weight * (1 / (position + c))
            elif algorithm == "simple_weighted":  # score-based: raw score * weight
                contribution = raw_score * list_weight
            fused_scores[document.page_content] += contribution

    # Sort by fused score, best first (ties keep first-seen insertion order).
    ranked = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)

    # Map page_content back to an actual Document object.
    content_to_doc = {
        document.page_content: document
        for ranked_list in doc_lists
        for document, _ in ranked_list
    }

    return [(content_to_doc[content], fused) for content, fused in ranked][:k]
# Hybrid retriever: fuses OpenSearch lexical (BM25) and semantic (k-NN) results,
# with optional RAG-fusion, HyDE, reranking and parent-document expansion.
# All heavy lifting is delegated to retriever_utils.search_hybrid; this class is
# mostly a parameter holder. NOTE(review): BaseRetriever is pydantic-based, so
# un-annotated attributes below are class-level defaults, not pydantic fields.
class OpenSearchHybridSearchRetriever(BaseRetriever):

    # Required handles, injected at construction time.
    os_client: Any
    vector_db: Any
    index_name: str
    # Per-call search parameters (restored to defaults by _reset_search_params).
    k = 3
    minimum_should_match = 0
    filter = []
    fusion_algorithm: str  # "RRF" or "simple_weighted"
    ensemble_weights = [0.51, 0.49]  # [semantic weight, lexical weight]
    verbose = False
    async_mode = True  # run semantic and lexical searches concurrently
    # Optional post-processing / query-expansion features.
    reranker = False
    reranker_endpoint_name = ""
    rag_fusion = False
    query_augmentation_size: Any
    rag_fusion_prompt = prompt_repo.get_rag_fusion()
    llm_text: Any
    llm_emb: Any
    hyde = False
    hyde_query: Any
    parent_document = False
    complex_doc = False
    # "None" | "semantic" | "lexical" — debug switch dropping one result set.
    hybrid_search_debugger = "None"

    def update_search_params(self, **kwargs):
        # Overwrite the retriever's search settings before an invoke() call.
        # Unspecified keys keep their current values (except k / minimum_should_match /
        # filter / query_augmentation_size / hyde_query, which fall back to fixed defaults).
        self.k = kwargs.get("k", 3)
        self.minimum_should_match = kwargs.get("minimum_should_match", 0)
        self.filter = kwargs.get("filter", [])
        self.index_name = kwargs.get("index_name", self.index_name)
        self.fusion_algorithm = kwargs.get("fusion_algorithm", self.fusion_algorithm)
        self.ensemble_weights = kwargs.get("ensemble_weights", self.ensemble_weights)
        self.verbose = kwargs.get("verbose", self.verbose)
        self.async_mode = kwargs.get("async_mode", self.async_mode)
        self.reranker = kwargs.get("reranker", self.reranker)
        self.reranker_endpoint_name = kwargs.get("reranker_endpoint_name", self.reranker_endpoint_name)
        self.rag_fusion = kwargs.get("rag_fusion", self.rag_fusion)
        self.query_augmentation_size = kwargs.get("query_augmentation_size", 3)
        self.hyde = kwargs.get("hyde", self.hyde)
        self.hyde_query = kwargs.get("hyde_query", ["web_search"])
        self.parent_document = kwargs.get("parent_document", self.parent_document)
        self.complex_doc = kwargs.get("complex_doc", self.complex_doc)
        self.hybrid_search_debugger = kwargs.get("hybrid_search_debugger", self.hybrid_search_debugger)

    def _reset_search_params(self, ):
        # Restore the per-call parameters to their defaults.
        self.k = 3
        self.minimum_should_match = 0
        self.filter = []

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        '''
        Run the hybrid search for *query* using the currently configured
        parameters. Invoked by "retriever.invoke" statements.
        '''
        search_hybrid_result = retriever_utils.search_hybrid(
            query=query,
            k=self.k,
            index_name=self.index_name,
            os_client=self.os_client,
            filter=self.filter,
            minimum_should_match=self.minimum_should_match,
            fusion_algorithm=self.fusion_algorithm, # ["RRF", "simple_weighted"]
            ensemble_weights=self.ensemble_weights, # e.g. weight 0.5 to semantic search, 0.5 to keyword search
            async_mode=self.async_mode,
            reranker=self.reranker,
            reranker_endpoint_name=self.reranker_endpoint_name,
            rag_fusion=self.rag_fusion,
            query_augmentation_size=self.query_augmentation_size,
            query_transformation_prompt=self.rag_fusion_prompt if self.rag_fusion else "",
            hyde=self.hyde,
            hyde_query=self.hyde_query if self.hyde else [],
            parent_document = self.parent_document,
            complex_doc = self.complex_doc,
            llm_text=self.llm_text,
            llm_emb=self.llm_emb,
            verbose=self.verbose,
            hybrid_search_debugger=self.hybrid_search_debugger
        )
        #self._reset_search_params()

        return search_hybrid_result
def show_chunk_stat(documents):
    """Print length statistics for a list of chunked documents.

    Shows pandas describe() stats over per-document character counts, the
    average length, and the content of the largest document.

    Args:
        documents: objects exposing a ``page_content`` string attribute.
    """
    # Robustness: the original crashed with ZeroDivisionError on an empty list.
    if not documents:
        print("No documents to summarize.")
        return

    doc_len_list = [len(doc.page_content) for doc in documents]
    print(pd.DataFrame(doc_len_list).describe())

    # Reuse the lengths computed above (the original re-scanned all documents
    # with a second lambda); integer division kept for identical output.
    avg_char_count_pre = sum(doc_len_list) // len(doc_len_list)
    print(f'Average length among {len(documents)} documents loaded is {avg_char_count_pre} characters.')

    max_idx = doc_len_list.index(max(doc_len_list))
    print("\nShow document at maximum size")
    print(documents[max_idx].page_content)
class KoSimCSERobertaContentHandler(EmbeddingsContentHandler):
    """SageMaker endpoint content handler for KoSimCSE-RoBERTa embeddings.

    Serializes the input text(s) to the endpoint's JSON request format and
    deserializes the JSON response into embedding vectors.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs=None) -> bytes:
        """Build the UTF-8 encoded JSON request body.

        Args:
            prompt: text (or list of texts) to embed.
            model_kwargs: extra keys merged into the request payload.
                BUG FIX: previously a mutable default ``{}``; ``None`` avoids
                the shared-mutable-default pitfall (behavior unchanged).
        """
        input_str = json.dumps({"inputs": prompt, **(model_kwargs or {})})

        return input_str.encode("utf-8")

    def transform_output(self, output: bytes):
        """Parse the endpoint response into embedding vector(s).

        Returns:
            A list of embedding vectors, or None when the response has an
            unexpected dimensionality. (The original ``-> str`` annotation
            was incorrect and has been removed.)
        """
        response_json = json.loads(output.read().decode("utf-8"))
        ndim = np.array(response_json).ndim

        if ndim == 4:
            # Original shape (1, 1, n, 768): keep the first row, re-wrapped
            # as a single-vector batch.
            emb = response_json[0][0][0]
            emb = np.expand_dims(emb, axis=0).tolist()
        elif ndim == 2:
            # Original shape (n, 1).
            # NOTE(review): ele[0][0] assumes each row's first element is
            # itself indexable — confirm against the endpoint's real payload.
            emb = []
            for ele in response_json:
                e = ele[0][0]
                emb.append(e)
        else:
            # Unexpected payload: log and signal failure with None.
            print(f"Other # of dimension: {ndim}")
            emb = None
        return emb
region_name=self.region_name) 13 | #self.client = boto3.client('s3', region_name=self.region_name) 14 | 15 | self.resource = boto3.resource('s3') 16 | self.client = boto3.client('s3') 17 | 18 | print (f"This is a S3 handler with [{self.region_name}] region.") 19 | 20 | def create_bucket(self, bucket_name): 21 | """Create an S3 bucket in a specified region 22 | 23 | If a region is not specified, the bucket is created in the S3 default 24 | region (us-east-1). 25 | 26 | :param bucket_name: Bucket to create 27 | :return: True if bucket created, else False 28 | """ 29 | 30 | try: 31 | if self.region_name is None: 32 | self.client.create_bucket(Bucket=bucket_name) 33 | else: 34 | location = {'LocationConstraint': self.region_name} 35 | self.client.create_bucket( 36 | Bucket=bucket_name, 37 | CreateBucketConfiguration=location 38 | ) 39 | print (f"CREATE:[{bucket_name}] Bucket was created successfully") 40 | 41 | except ClientError as e: 42 | logging.error(e) 43 | print (f"ERROR: {e}") 44 | return False 45 | 46 | return True 47 | 48 | def copy_object(self, source_obj, target_bucket, target_obj): 49 | 50 | ''' 51 | Copy S3 to S3 52 | ''' 53 | 54 | try: 55 | response = self.client.copy_object( 56 | Bucket=target_bucket,#'destinationbucket', 57 | CopySource=source_obj,#'/sourcebucket/HappyFacejpg', 58 | Key=target_obj,#'HappyFaceCopyjpg', 59 | ) 60 | 61 | except ClientError as e: 62 | logging.error(e) 63 | print (f"ERROR: {e}") 64 | return False 65 | 66 | def download_obj(self, source_bucket, source_obj, target_file): 67 | 68 | ''' 69 | Copy S3 to Locl 70 | ''' 71 | 72 | self.client.download_file(source_bucket, source_obj, target_file) 73 | 74 | def upload_dir(self, source_dir, target_bucket, target_dir): 75 | 76 | inputs = S3Uploader.upload(source_dir, "s3://{}/{}".format(target_bucket, target_dir)) 77 | 78 | print (f"Upload:[{source_dir}] was uploaded to [{inputs}]successfully") 79 | 80 | def upload_file(self, source_file, target_bucket, target_obj=None): 81 | """Upload a 
file to an S3 bucket 82 | 83 | :param file_name: File to upload 84 | :param bucket: Bucket to upload to 85 | :param object_name: S3 object name. If not specified then file_name is used 86 | :return: True if file was uploaded, else False 87 | """ 88 | 89 | # If S3 object_name was not specified, use file_name 90 | if target_obj is None: 91 | target_obj = os.path.basename(source_file) 92 | 93 | # Upload the file 94 | #s3_client = boto3.client('s3') 95 | try: 96 | response = self.client.upload_file(source_file, target_bucket, target_obj) 97 | except ClientError as e: 98 | logging.error(e) 99 | return False 100 | 101 | obj_s3_path = f"s3://{target_bucket}/{target_obj}" 102 | 103 | return obj_s3_path 104 | 105 | def delete_bucket(self, bucket_name): 106 | 107 | try: 108 | self._delete_all_object(bucket_name=bucket_name) 109 | response = self.client.delete_bucket( 110 | Bucket=bucket_name, 111 | ) 112 | 113 | print (f"DELETE: [{bucket_name}] Bucket was deleted successfully") 114 | 115 | except ClientError as e: 116 | logging.error(e) 117 | print (f"ERROR: {e}") 118 | return False 119 | 120 | return True 121 | 122 | def _delete_all_object(self, bucket_name): 123 | 124 | bucket = self.resource.Bucket(bucket_name) 125 | bucket.object_versions.delete() ## delete versioning 126 | bucket.objects.all().delete() ## delete all objects in the bucket -------------------------------------------------------------------------------- /application/utils/ssm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import boto3 3 | 4 | class parameter_store(): 5 | 6 | def __init__(self, region_name="ap-northeast-2"): 7 | 8 | self.ssm = boto3.client('ssm', region_name=region_name) 9 | 10 | def put_params(self, key, value, dtype="String", overwrite=False, enc=False): 11 | 12 | # Specify the parameter name, value, and type 13 | if enc: dtype="SecureString" 14 | 15 | try: 16 | # Put the parameter 17 | response = self.ssm.put_parameter( 18 | 
Name=key, 19 | Value=value, 20 | Type=dtype, 21 | Overwrite=overwrite # Set to True if you want to overwrite an existing parameter 22 | ) 23 | 24 | # Print the response 25 | print('Parameter stored successfully.') 26 | #print(response) 27 | 28 | except Exception as e: 29 | print('Error storing parameter:', str(e)) 30 | 31 | def get_params(self, key, enc=False): 32 | 33 | if enc: WithDecryption = True 34 | else: WithDecryption = False 35 | response = self.ssm.get_parameters( 36 | Names=[key,], 37 | WithDecryption=WithDecryption 38 | ) 39 | 40 | return response['Parameters'][0]['Value'] 41 | 42 | def get_all_params(self, ): 43 | 44 | response = self.ssm.describe_parameters(MaxResults=50) 45 | 46 | return [dicParam["Name"] for dicParam in response["Parameters"]] 47 | 48 | def delete_param(self, listParams): 49 | 50 | response = self.ssm.delete_parameters( 51 | Names=listParams 52 | ) 53 | print (f" parameters: {listParams} is deleted successfully") 54 | 55 | 56 | -------------------------------------------------------------------------------- /application/utils/text_to_report.py: -------------------------------------------------------------------------------- 1 | import re 2 | from textwrap import dedent 3 | from langchain_core.tracers import ConsoleCallbackHandler 4 | from langchain.schema.output_parser import StrOutputParser 5 | from langchain_experimental.tools.python.tool import PythonAstREPLTool 6 | from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, MessagesPlaceholder 7 | 8 | class prompt_repo(): 9 | 10 | @classmethod 11 | def get_system_prompt(cls, role=None): 12 | 13 | if role == "text2chart": 14 | 15 | system_prompt = dedent( 16 | ''' 17 | You are a pandas master bot designed to generate Python code for plotting a chart based on the given dataset and user question. 18 | I'm going to give you dataset. 19 | Read the given dataset carefully, because I'm going to ask you a question about it. 
@classmethod
def get_human_prompt(cls, role=None):
    """Return the human-message prompt template for *role*.

    Args:
        role: currently only "text2chart" is supported; any other value
            yields an empty string.

    Returns:
        Template string containing {dataset} and {question} placeholders.
    """
    if role == "text2chart":

        human_prompt = dedent(
            '''
            This is the result of `print(df.head())`: {dataset}
            You should execute code as commanded to either provide information to answer the question or to do the transformations required.
            You should not assign any variables; you should return a one-liner in Pandas.

            Update this initial code:

            ```python
            # TODO: import the required dependencies
            import pandas as pd

            # Write code here

            ```

            Here is the question: {question}

            Variable `df: pd.DataFrame` is already declared.
            At the end, declare "result" variable as a dictionary of type and value.
            If you are asked to plot a chart, use "matplotlib" for charts, save as "results.png".
            Explanation in Korean.
            Do not use legend and title in plot in Korean.

            Generate python code and return full updated code within :


            '''
        )

    else:
        human_prompt = ""

    return human_prompt

class text2chart_chain():
    """Text-to-chart chain: asks an LLM to generate pandas/matplotlib code for
    a question about a DataFrame and optionally executes the generated code."""

    def __init__(self, **kwargs):

        system_prompt = kwargs["system_prompt"]
        self.llm_text = kwargs["llm_text"]
        self.system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)
        self.num_rows = kwargs["num_rows"]  # max rows inlined into the prompt
        self.verbose = kwargs.get("verbose", False)
        self.parsing_pattern = kwargs["parsing_pattern"]  # regex extracting code from the LLM answer
        # BUG FIX: previously read kwargs.get("verbose", ...), so show_chart
        # could never be enabled independently of verbose.
        self.show_chart = kwargs.get("show_chart", False)

    def query(self, **kwargs):
        """Generate (and optionally execute) chart code for *query* over *df*.

        Returns the raw LLM response, or the execution result when
        show_chart is truthy.
        """
        df, query, verbose = kwargs["df"], kwargs["query"], kwargs.get("verbose", self.verbose)
        show_chart = kwargs.get("show_chart", self.show_chart)

        # Inline the whole frame when small; otherwise a deterministic sample.
        if len(df) < self.num_rows: dataset = str(df.to_csv())
        else: dataset = str(df.sample(self.num_rows, random_state=0).to_csv())

        invoke_args = {
            "dataset": dataset,
            "question": query
        }

        human_prompt = prompt_repo.get_human_prompt(
            role="text2chart"
        )
        human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
        prompt = ChatPromptTemplate.from_messages(
            [self.system_message_template, human_message_template]
        )

        code_generation_chain = prompt | self.llm_text | StrOutputParser()

        self.verbose = verbose
        response = code_generation_chain.invoke(
            invoke_args,
            config={'callbacks': [ConsoleCallbackHandler()]} if self.verbose else {}
        )

        if show_chart:
            results = self.code_execution(
                df=df,
                response=response
            )
            return results
        else:
            return response

    def code_execution(self, **kwargs):
        """Parse the generated code and run it in a REPL with ``df`` bound."""
        df, code = kwargs["df"], self._code_parser(response=kwargs["response"])
        tool = PythonAstREPLTool(locals={"df": df})

        results = tool.invoke(code)

        return results

    def _code_parser(self, **kwargs):
        """Extract group(1) of parsing_pattern from the response ("" if absent)."""
        parsed_code, response = "", kwargs["response"]
        match = re.search(self.parsing_pattern, response, re.DOTALL)

        if match: parsed_code = match.group(1)
        else: print("No match found.")

        return parsed_code
122 | 123 | results = tool.invoke(code) 124 | 125 | return results 126 | 127 | def _code_parser(self, **kwargs): 128 | 129 | parsed_code, response = "", kwargs["response"] 130 | match = re.search(self.parsing_pattern, response, re.DOTALL) 131 | 132 | if match: parsed_code = match.group(1) 133 | else: print("No match found.") 134 | 135 | return parsed_code 136 | -------------------------------------------------------------------------------- /bin/cdk.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | import "source-map-support/register"; 3 | import { App } from "aws-cdk-lib"; 4 | import { EC2Stack } from "../lib/ec2Stack/ec2Stack"; 5 | import { OpensearchStack } from "../lib/openSearchStack"; 6 | import { CustomResourceStack } from "../lib/customResourceStack"; 7 | import { SagemakerNotebookStack } from "../lib/sagemakerNotebookStack/sagemakerNotebookStack"; 8 | import { CfnInclude } from 'aws-cdk-lib/cloudformation-include'; 9 | 10 | const STACK_PREFIX = "AdvancedRAG"; 11 | const DEFAULT_REGION = "us-west-2"; 12 | const envSetting = { 13 | env: { 14 | account: process.env.CDK_DEPLOY_ACCOUNT || process.env.CDK_DEFAULT_ACCOUNT, 15 | region: DEFAULT_REGION, 16 | }, 17 | }; 18 | 19 | const app = new App(); 20 | 21 | // Deploy Sagemaker stack 22 | const sagemakerNotebookStack = new SagemakerNotebookStack(app, `${STACK_PREFIX}-SagemakerNotebookStack`, envSetting); 23 | 24 | // Deploy OpenSearch stack 25 | const opensearchStack = new OpensearchStack(app, `${STACK_PREFIX}-OpensearchStack`, envSetting); 26 | opensearchStack.addDependency(sagemakerNotebookStack); 27 | 28 | // Deploy Reranker stack using cloudformation template 29 | const rerankerStack = new CfnInclude(opensearchStack, `${STACK_PREFIX}-RerankerStack`, { 30 | templateFile: 'lib/rerankerStack/RerankerStack.template.json' 31 | }); 32 | 33 | const customResourceStack = new CustomResourceStack(app, `${STACK_PREFIX}-CustomResourceStack`, envSetting) 34 | 
customResourceStack.addDependency(opensearchStack); 35 | 36 | // Deploy EC2 stack 37 | const ec2Stack = new EC2Stack(app, `${STACK_PREFIX}-EC2Stack`, envSetting); 38 | ec2Stack.node.addDependency(customResourceStack); 39 | 40 | app.synth(); 41 | -------------------------------------------------------------------------------- /cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "npx ts-node --prefer-ts-exts bin/cdk.ts", 3 | "watch": { 4 | "include": [ 5 | "**" 6 | ], 7 | "exclude": [ 8 | "README.md", 9 | "cdk*.json", 10 | "**/*.d.ts", 11 | "**/*.js", 12 | "tsconfig.json", 13 | "package*.json", 14 | "yarn.lock", 15 | "node_modules", 16 | "test" 17 | ] 18 | }, 19 | "context": { 20 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true, 21 | "@aws-cdk/core:checkSecretUsage": true, 22 | "@aws-cdk/core:target-partitions": [ 23 | "aws", 24 | "aws-cn" 25 | ], 26 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true, 27 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true, 28 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true, 29 | "@aws-cdk/aws-iam:minimizePolicies": true, 30 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true, 31 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true, 32 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true, 33 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true, 34 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true, 35 | "@aws-cdk/core:enablePartitionLiterals": true, 36 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true, 37 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true, 38 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true, 39 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true, 40 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true, 41 | "@aws-cdk/aws-route53-patters:useCertificate": true, 42 | "@aws-cdk/customresources:installLatestAwsSdkDefault": 
false, 43 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true, 44 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true, 45 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true, 46 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true, 47 | "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true, 48 | "@aws-cdk/aws-redshift:columnId": true, 49 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true, 50 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true, 51 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true, 52 | "@aws-cdk/aws-kms:aliasNameRef": true, 53 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true, 54 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true, 55 | "@aws-cdk/aws-efs:denyAnonymousAccess": true, 56 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true, 57 | "@aws-cdk/aws-lambda-nodejs:useLatestRuntimeVersion": true, 58 | "@aws-cdk/aws-efs:mountTargetOrderInsensitiveLogicalId": true, 59 | "@aws-cdk/aws-rds:auroraClusterChangeScopeOfInstanceParameterGroupWithEachParameters": true, 60 | "@aws-cdk/aws-appsync:useArnForSourceApiAssociationIdentifier": true, 61 | "@aws-cdk/aws-rds:preventRenderingDeprecatedCredentials": true, 62 | "@aws-cdk/aws-codepipeline-actions:useNewDefaultBranchForCodeCommitSource": true, 63 | "@aws-cdk/aws-cloudwatch-actions:changeLambdaPermissionLogicalIdForLambdaAction": true, 64 | "@aws-cdk/aws-codepipeline:crossAccountKeysDefaultValueToFalse": true, 65 | "@aws-cdk/aws-codepipeline:defaultPipelineTypeToV2": true, 66 | "@aws-cdk/aws-kms:reduceCrossAccountRegionPolicyScope": true, 67 | "@aws-cdk/aws-eks:nodegroupNameAttribute": true, 68 | "@aws-cdk/aws-ec2:ebsDefaultGp3Volume": true 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /jest.config.js: 
import json
import boto3
import logging
import time

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Module-level client so it is reused across warm Lambda invocations.
opensearch = boto3.client('opensearch')

# OpenSearch package IDs for the Nori (Korean morphological analyzer) plugin,
# keyed by AWS region, then by OpenSearch engine version.  Hoisted to module
# scope so every handler shares a single table instead of duplicating it.
NORI_PKG_ID = {
    'us-east-1': {
        '2.3': 'G196105221',
        '2.5': 'G240285063',
        '2.7': 'G16029449',
        '2.9': 'G60209291',
        '2.11': 'G181660338',
    },
    'us-west-2': {
        '2.3': 'G94047474',
        '2.5': 'G138227316',
        '2.7': 'G182407158',
        '2.9': 'G226587000',
        '2.11': 'G79602591',
    },
}


def _lookup_package_id(region, version):
    """Return the Nori package ID for (region, version).

    Raises:
        ValueError: if the region/version pair is not in NORI_PKG_ID, with a
            message naming the bad inputs (clearer than the bare KeyError the
            previous inline lookup produced).
    """
    try:
        return NORI_PKG_ID[region][version]
    except KeyError:
        raise ValueError(
            f"No Nori package ID known for region {region!r} / version {version!r}"
        ) from None


def on_event(event, context):
    """CloudFormation custom-resource entry point; dispatches on RequestType."""
    print(event)
    request_type = event['RequestType']
    if request_type == 'Create': return on_create(event)
    if request_type == 'Update': return on_update(event)
    if request_type == 'Delete': return on_delete(event)
    raise Exception("Invalid request type: %s" % request_type)


def on_create(event):
    """Associate the Nori plugin package with the OpenSearch domain.

    Returns a PhysicalResourceId that encodes the domain/package pair so
    CloudFormation can track the association as one logical resource.
    """
    props = event["ResourceProperties"]
    print("create new resource with props %s" % props)

    domain_arn = props['DomainArn']
    # Domain ARNs end with ".../<domain-name>".
    domain_name = domain_arn.split('/')[-1]
    region = props['DEFAULT_REGION']
    version = props['VERSION']

    package_id = _lookup_package_id(region, version)
    print(domain_arn, domain_name, package_id)

    try:
        opensearch.associate_package(
            PackageID=package_id,
            DomainName=domain_name
        )
        print(f"Successfully initiated association of package {package_id} with domain {domain_name}")
    except Exception as e:
        # BUGFIX: the original caught `opensearch.exceptions.BaseException`,
        # which boto3 clients do not define -- evaluating that attribute while
        # handling a real failure would raise AttributeError and mask the
        # underlying error.  Catch broadly, log, and re-raise so the custom
        # resource still fails visibly.
        logger.error(f"Failed to associate package: {e}")
        raise

    physical_id = f"AssociatePackage-{domain_name}-{package_id}"
    return { 'PhysicalResourceId': physical_id }


def on_update(event):
    """No-op update: keep the existing physical resource ID."""
    physical_id = event["PhysicalResourceId"]
    props = event["ResourceProperties"]
    print("update resource %s with props %s" % (physical_id, props))
    return { 'PhysicalResourceId': physical_id }


def on_delete(event):
    """No-op delete: package dissociation is intentionally not performed."""
    physical_id = event["PhysicalResourceId"]
    print("delete resource %s" % physical_id)
    # Optionally add dissociation logic if required
    return { 'PhysicalResourceId': physical_id }


# Disabled isComplete handler for the CDK provider framework.  If re-enabled,
# wire it up via `isCompleteHandler` in customResourceStack.ts.
"""
def is_complete(event, context):
    props = event["ResourceProperties"]
    domain_arn = props['DomainArn']
    domain_name = domain_arn.split('/')[-1]

    package_id = _lookup_package_id(props['DEFAULT_REGION'], props['VERSION'])
    print(f"Checking association status for package {package_id} on domain {domain_name}")

    # NOTE(review): MaxResults=1 inspects only the first associated package;
    # if the domain ever has more than one, match on PackageID instead.
    response = opensearch.list_packages_for_domain(
        DomainName=domain_name,
        MaxResults=1
    )

    # BUGFIX (applied while commented out): the False branch previously
    # assigned `in_ready`, leaving `is_ready` undefined on that path.
    is_ready = response['DomainPackageDetailsList'][0]['DomainPackageStatus'] == "ACTIVE"

    print(f"Is package {package_id} active on domain {domain_name}? {is_ready}")

    return { 'IsComplete': is_ready }
"""
import { Stack, StackProps, RemovalPolicy, aws_s3 as s3, } from 'aws-cdk-lib';
import { Construct } from 'constructs';
import * as cdk from 'aws-cdk-lib';
import * as ec2 from 'aws-cdk-lib/aws-ec2';
import * as iam from 'aws-cdk-lib/aws-iam';
import * as fs from 'fs';
import * as path from 'path';

/**
 * Provisions the Streamlit chatbot host: a single m5.large Ubuntu 20.04
 * instance in the account's default VPC, bootstrapped by userdata.sh, with
 * its public URL exported as `chatbotAppUrl`.
 */
export class EC2Stack extends Stack {
  constructor(scope: Construct, id: string, props?: StackProps) {
    super(scope, id, props);

    // IAM Role to access EC2
    // NOTE(review): AdministratorAccess is far broader than the app should
    // need -- acceptable for a workshop, not for production.  Scope down to
    // the Bedrock/OpenSearch/SSM permissions the app actually uses.
    const instanceRole = new iam.Role(this, 'InstanceRole', {
      assumedBy: new iam.ServicePrincipal('ec2.amazonaws.com'),
      managedPolicies: [
        iam.ManagedPolicy.fromAwsManagedPolicyName('AdministratorAccess'),
      ],
    });

    // Network setting for EC2: reuses the account's default VPC, so the
    // stack must be deployed with an explicit env (account/region) for the
    // context lookup to resolve.
    const defaultVpc = ec2.Vpc.fromLookup(this, 'VPC', {
      isDefault: true,
    });

    const chatbotAppSecurityGroup = new ec2.SecurityGroup(this, 'chatbotAppSecurityGroup', {
      vpc: defaultVpc,
    });
    // HTTP open to the world; userdata.sh redirects :80 to the Streamlit
    // port (8501) via iptables.
    chatbotAppSecurityGroup.addIngressRule(
      ec2.Peer.anyIpv4(),
      ec2.Port.tcp(80),
      'httpIpv4',
    );
    // NOTE(review): SSH is open to 0.0.0.0/0 -- restrict to a trusted CIDR
    // or use SSM Session Manager outside of workshop settings.
    chatbotAppSecurityGroup.addIngressRule(
      ec2.Peer.anyIpv4(),
      ec2.Port.tcp(22),
      'sshIpv4',
    );

    // set AMI: the latest Ubuntu 20.04 (focal) amd64 image, resolved at
    // deploy time from Canonical's public SSM parameter.
    const machineImage = ec2.MachineImage.fromSsmParameter(
      '/aws/service/canonical/ubuntu/server/focal/stable/current/amd64/hvm/ebs-gp2/ami-id'
    );

    // set User Data: ship the bootstrap script from disk verbatim.
    const userData = ec2.UserData.forLinux();
    const userDataScript = fs.readFileSync(path.join(__dirname, 'userdata.sh'), 'utf8');
    userData.addCommands(userDataScript);

    // EC2 instance
    const chatbotAppInstance = new ec2.Instance(this, 'chatbotAppInstance', {
      instanceType: new ec2.InstanceType('m5.large'),
      machineImage: machineImage,
      vpc: defaultVpc,
      securityGroup: chatbotAppSecurityGroup,
      role: instanceRole,
      userData: userData,
    });

    // Public URL of the app; instancePublicIp assumes the instance lands in
    // a public subnet of the default VPC (true for unmodified default VPCs).
    new cdk.CfnOutput(this, 'chatbotAppUrl', {
      value: `http://${chatbotAppInstance.instancePublicIp}/`,
      description: 'The URL of chatbot instance generated by AWS Advanced RAG Workshop',
      exportName: 'chatbotAppUrl',
    });
  }
}
import { Duration, Stack, StackProps, SecretValue } from "aws-cdk-lib";
import { Construct } from "constructs";

import * as fs from "fs";
import * as cdk from "aws-cdk-lib";
import * as opensearch from "aws-cdk-lib/aws-opensearchservice";
import * as ec2 from "aws-cdk-lib/aws-ec2";
import * as ssm from "aws-cdk-lib/aws-ssm";
import { PolicyStatement, AnyPrincipal } from "aws-cdk-lib/aws-iam";
import * as secretsmanager from "aws-cdk-lib/aws-secretsmanager";

/**
 * Provisions the OpenSearch 2.11 domain for the RAG workshop with
 * fine-grained access control, and publishes its connection details:
 * master user name to SSM Parameter Store, generated password to Secrets
 * Manager, endpoint to SSM, and the domain ARN as a CloudFormation export
 * (consumed by CustomResourceStack).
 */
export class OpensearchStack extends Stack {
  constructor(scope: Construct, id: string, props?: StackProps) {
    super(scope, id, props);

    const domainName = `rag-hol-mydomain`;

    const opensearch_user_id = "raguser";

    // Master user name, published for the chatbot app to read at runtime.
    const user_id_pm = new ssm.StringParameter(this, "opensearch_user_id", {
      parameterName: "opensearch_user_id",
      stringValue: "raguser",
    });

    // NOTE: this is the JSON *key* under which the generated password is
    // stored inside the secret -- not the password itself.
    const opensearch_user_password = "pwkey";

    const secret = new secretsmanager.Secret(this, "domain-creds", {
      generateSecretString: {
        secretStringTemplate: JSON.stringify({
          "es.net.http.auth.user": opensearch_user_id,
        }),
        generateStringKey: opensearch_user_password,
        // Avoid quote characters so the password stays shell/JSON friendly.
        excludeCharacters: '"\'',
      },
      secretName: "opensearch_user_password",
    });

    const domain = new opensearch.Domain(this, "Domain", {
      version: opensearch.EngineVersion.OPENSEARCH_2_11,
      domainName: domainName,
      capacity: {
        // NOTE(review): 2 dedicated masters is an unusual (even) count; AWS
        // recommends 3 for quorum -- confirm this is intentional for the
        // workshop.
        masterNodes: 2,
        multiAzWithStandbyEnabled: false,
      },
      ebs: {
        volumeSize: 100,
        volumeType: ec2.EbsDeviceVolumeType.GP3,
        enabled: true,
      },
      // HTTPS + node-to-node + at-rest encryption are all required for
      // fine-grained access control below.
      enforceHttps: true,
      nodeToNodeEncryption: true,
      encryptionAtRest: { enabled: true },
      fineGrainedAccessControl: {
        masterUserName: opensearch_user_id,
        masterUserPassword: secret.secretValueFromJson(
          opensearch_user_password
        ),
      },
    });

    // NOTE(review): es:* for AnyPrincipal makes the domain reachable from
    // anywhere; callers must still authenticate via fine-grained access
    // control, but consider restricting principals outside a workshop.
    domain.addAccessPolicies(
      new PolicyStatement({
        actions: ["es:*"],
        principals: [new AnyPrincipal()],
        resources: [domain.domainArn + "/*"],
      })
    );

    // Endpoint published for the chatbot app to read at runtime.
    const domain_endpoint_pm = new ssm.StringParameter(
      this,
      "opensearch_domain_endpoint",
      {
        parameterName: "opensearch_domain_endpoint",
        stringValue: domain.domainEndpoint,
      }
    );

    new cdk.CfnOutput(this, "OpensearchDomainEndpoint", {
      value: domain.domainEndpoint,
      description: "OpenSearch Domain Endpoint",
    });

    new cdk.CfnOutput(this, "parameter store user id", {
      value: user_id_pm.parameterArn,
      description: "parameter store user id",
    });

    new cdk.CfnOutput(this, "secrets manager user pw", {
      value: secret.secretName,
      description: "secrets manager user pw",
    });

    // Cross-stack export consumed by CustomResourceStack
    // (cdk.Fn.importValue('DomainArn')).
    new cdk.CfnOutput(this, 'DomainArn', {
      value: domain.domainArn,
      exportName: 'DomainArn'
    });

  }
}
-------------------------------------------------------------------------------- /lib/rerankerStack/RerankerStack.assets.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "36.0.0", 3 | "files": { 4 | "6eee82909dec1440393d6ae4b11ef1c1191320d48cb02287c623f05e28bf7caa": { 5 | "source": { 6 | "path": "RerankerStack.template.json", 7 | "packaging": "file" 8 | }, 9 | "destinations": { 10 | "current_account-current_region": { 11 | "bucketName": "cdk-hnb659fds-assets-${AWS::AccountId}-${AWS::Region}", 12 | "objectKey": "6eee82909dec1440393d6ae4b11ef1c1191320d48cb02287c623f05e28bf7caa.json", 13 | "assumeRoleArn": "arn:${AWS::Partition}:iam::${AWS::AccountId}:role/cdk-hnb659fds-file-publishing-role-${AWS::AccountId}-${AWS::Region}" 14 | } 15 | } 16 | } 17 | }, 18 | "dockerImages": {} 19 | } -------------------------------------------------------------------------------- /lib/rerankerStack/RerankerStack.template.json: -------------------------------------------------------------------------------- 1 | { 2 | "Description": "Description: (uksb-1tupboc45) (version:0.1.198) (tag:C1:0,C2:0,C3:0,C4:0,C5:0,C6:1,C7:0,C8:0) ", 3 | "Resources": { 4 | "RerankerRoleDAAED19A": { 5 | "Type": "AWS::IAM::Role", 6 | "Properties": { 7 | "AssumeRolePolicyDocument": { 8 | "Statement": [ 9 | { 10 | "Action": "sts:AssumeRole", 11 | "Effect": "Allow", 12 | "Principal": { 13 | "Service": "sagemaker.amazonaws.com" 14 | } 15 | } 16 | ], 17 | "Version": "2012-10-17" 18 | }, 19 | "ManagedPolicyArns": [ 20 | { 21 | "Fn::Join": [ 22 | "", 23 | [ 24 | "arn:", 25 | { 26 | "Ref": "AWS::Partition" 27 | }, 28 | ":iam::aws:policy/AmazonSageMakerFullAccess" 29 | ] 30 | ] 31 | } 32 | ] 33 | }, 34 | "Metadata": { 35 | "aws:cdk:path": "RerankerStack/Reranker/Role/Resource" 36 | } 37 | }, 38 | "RerankerRoleDefaultPolicy6BB7CA84": { 39 | "Type": "AWS::IAM::Policy", 40 | "Properties": { 41 | "PolicyDocument": { 42 | "Statement": [ 43 | { 44 | "Action": [ 45 | 
"ecr:BatchCheckLayerAvailability", 46 | "ecr:BatchGetImage", 47 | "ecr:GetDownloadUrlForLayer" 48 | ], 49 | "Effect": "Allow", 50 | "Resource": { 51 | "Fn::Join": [ 52 | "", 53 | [ 54 | "arn:", 55 | { 56 | "Ref": "AWS::Partition" 57 | }, 58 | ":ecr:", 59 | { 60 | "Ref": "AWS::Region" 61 | }, 62 | ":", 63 | { 64 | "Fn::FindInMap": [ 65 | "DlcRepositoryAccountMap", 66 | { 67 | "Ref": "AWS::Region" 68 | }, 69 | "value" 70 | ] 71 | }, 72 | ":repository/huggingface-pytorch-inference" 73 | ] 74 | ] 75 | } 76 | }, 77 | { 78 | "Action": "ecr:GetAuthorizationToken", 79 | "Effect": "Allow", 80 | "Resource": "*" 81 | } 82 | ], 83 | "Version": "2012-10-17" 84 | }, 85 | "PolicyName": "RerankerRoleDefaultPolicy6BB7CA84", 86 | "Roles": [ 87 | { 88 | "Ref": "RerankerRoleDAAED19A" 89 | } 90 | ] 91 | }, 92 | "Metadata": { 93 | "aws:cdk:path": "RerankerStack/Reranker/Role/DefaultPolicy/Resource" 94 | } 95 | }, 96 | "DongjinkrkorerankermodelReranker": { 97 | "Type": "AWS::SageMaker::Model", 98 | "Properties": { 99 | "ExecutionRoleArn": { 100 | "Fn::GetAtt": [ 101 | "RerankerRoleDAAED19A", 102 | "Arn" 103 | ] 104 | }, 105 | "PrimaryContainer": { 106 | "Environment": { 107 | "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", 108 | "SAGEMAKER_REGION": { 109 | "Ref": "AWS::Region" 110 | }, 111 | "HF_MODEL_ID": "Dongjin-kr/ko-reranker", 112 | "HF_TASK": "text-classification" 113 | }, 114 | "Image": { 115 | "Fn::Join": [ 116 | "", 117 | [ 118 | { 119 | "Fn::FindInMap": [ 120 | "DlcRepositoryAccountMap", 121 | { 122 | "Ref": "AWS::Region" 123 | }, 124 | "value" 125 | ] 126 | }, 127 | ".dkr.ecr.", 128 | { 129 | "Ref": "AWS::Region" 130 | }, 131 | ".", 132 | { 133 | "Ref": "AWS::URLSuffix" 134 | }, 135 | "/huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" 136 | ] 137 | ] 138 | }, 139 | "Mode": "SingleModel" 140 | }, 141 | "Tags": [ 142 | { 143 | "Key": "modelId", 144 | "Value": "Dongjin-kr/ko-reranker" 145 | } 146 | ] 147 | }, 148 | "Metadata": { 149 | "aws:cdk:path": 
"RerankerStack/Dongjin-kr-ko-reranker-model-Reranker" 150 | } 151 | }, 152 | "EndpointConfigReranker": { 153 | "Type": "AWS::SageMaker::EndpointConfig", 154 | "Properties": { 155 | "ProductionVariants": [ 156 | { 157 | "ContainerStartupHealthCheckTimeoutInSeconds": 600, 158 | "InitialInstanceCount": 1, 159 | "InitialVariantWeight": 1, 160 | "InstanceType": "ml.g5.xlarge", 161 | "ModelName": { 162 | "Fn::GetAtt": [ 163 | "DongjinkrkorerankermodelReranker", 164 | "ModelName" 165 | ] 166 | }, 167 | "VariantName": "AllTraffic" 168 | } 169 | ] 170 | }, 171 | "DependsOn": [ 172 | "DongjinkrkorerankermodelReranker" 173 | ], 174 | "Metadata": { 175 | "aws:cdk:path": "RerankerStack/EndpointConfig-Reranker" 176 | } 177 | }, 178 | "DongjinkrkorerankerendpointReranker": { 179 | "Type": "AWS::SageMaker::Endpoint", 180 | "Properties": { 181 | "EndpointConfigName": { 182 | "Fn::GetAtt": [ 183 | "EndpointConfigReranker", 184 | "EndpointConfigName" 185 | ] 186 | }, 187 | "EndpointName": "reranker", 188 | "Tags": [ 189 | { 190 | "Key": "modelId", 191 | "Value": "Dongjin-kr/ko-reranker" 192 | } 193 | ] 194 | }, 195 | "DependsOn": [ 196 | "EndpointConfigReranker" 197 | ], 198 | "Metadata": { 199 | "aws:cdk:path": "RerankerStack/Dongjin-kr-ko-reranker-endpoint-Reranker" 200 | } 201 | } 202 | }, 203 | "Mappings": { 204 | "DlcRepositoryAccountMap": { 205 | "ap-east-1": { 206 | "value": "871362719292" 207 | }, 208 | "ap-northeast-1": { 209 | "value": "763104351884" 210 | }, 211 | "ap-northeast-2": { 212 | "value": "763104351884" 213 | }, 214 | "ap-south-1": { 215 | "value": "763104351884" 216 | }, 217 | "ap-south-2": { 218 | "value": "772153158452" 219 | }, 220 | "ap-southeast-1": { 221 | "value": "763104351884" 222 | }, 223 | "ap-southeast-2": { 224 | "value": "763104351884" 225 | }, 226 | "ap-southeast-3": { 227 | "value": "907027046896" 228 | }, 229 | "ap-southeast-4": { 230 | "value": "457447274322" 231 | }, 232 | "ca-central-1": { 233 | "value": "763104351884" 234 | }, 235 | 
"cn-north-1": { 236 | "value": "727897471807" 237 | }, 238 | "cn-northwest-1": { 239 | "value": "727897471807" 240 | }, 241 | "eu-central-1": { 242 | "value": "763104351884" 243 | }, 244 | "eu-central-2": { 245 | "value": "380420809688" 246 | }, 247 | "eu-north-1": { 248 | "value": "763104351884" 249 | }, 250 | "eu-south-1": { 251 | "value": "692866216735" 252 | }, 253 | "eu-south-2": { 254 | "value": "503227376785" 255 | }, 256 | "eu-west-1": { 257 | "value": "763104351884" 258 | }, 259 | "eu-west-2": { 260 | "value": "763104351884" 261 | }, 262 | "eu-west-3": { 263 | "value": "763104351884" 264 | }, 265 | "me-central-1": { 266 | "value": "914824155844" 267 | }, 268 | "me-south-1": { 269 | "value": "217643126080" 270 | }, 271 | "sa-east-1": { 272 | "value": "763104351884" 273 | }, 274 | "us-east-1": { 275 | "value": "763104351884" 276 | }, 277 | "us-east-2": { 278 | "value": "763104351884" 279 | }, 280 | "us-west-1": { 281 | "value": "763104351884" 282 | }, 283 | "us-west-2": { 284 | "value": "763104351884" 285 | } 286 | } 287 | }, 288 | "Conditions": { 289 | "CDKMetadataAvailable": { 290 | "Fn::Or": [ 291 | { 292 | "Fn::Or": [ 293 | { 294 | "Fn::Equals": [ 295 | { 296 | "Ref": "AWS::Region" 297 | }, 298 | "af-south-1" 299 | ] 300 | }, 301 | { 302 | "Fn::Equals": [ 303 | { 304 | "Ref": "AWS::Region" 305 | }, 306 | "ap-east-1" 307 | ] 308 | }, 309 | { 310 | "Fn::Equals": [ 311 | { 312 | "Ref": "AWS::Region" 313 | }, 314 | "ap-northeast-1" 315 | ] 316 | }, 317 | { 318 | "Fn::Equals": [ 319 | { 320 | "Ref": "AWS::Region" 321 | }, 322 | "ap-northeast-2" 323 | ] 324 | }, 325 | { 326 | "Fn::Equals": [ 327 | { 328 | "Ref": "AWS::Region" 329 | }, 330 | "ap-south-1" 331 | ] 332 | }, 333 | { 334 | "Fn::Equals": [ 335 | { 336 | "Ref": "AWS::Region" 337 | }, 338 | "ap-southeast-1" 339 | ] 340 | }, 341 | { 342 | "Fn::Equals": [ 343 | { 344 | "Ref": "AWS::Region" 345 | }, 346 | "ap-southeast-2" 347 | ] 348 | }, 349 | { 350 | "Fn::Equals": [ 351 | { 352 | "Ref": "AWS::Region" 
353 | }, 354 | "ca-central-1" 355 | ] 356 | }, 357 | { 358 | "Fn::Equals": [ 359 | { 360 | "Ref": "AWS::Region" 361 | }, 362 | "cn-north-1" 363 | ] 364 | }, 365 | { 366 | "Fn::Equals": [ 367 | { 368 | "Ref": "AWS::Region" 369 | }, 370 | "cn-northwest-1" 371 | ] 372 | } 373 | ] 374 | }, 375 | { 376 | "Fn::Or": [ 377 | { 378 | "Fn::Equals": [ 379 | { 380 | "Ref": "AWS::Region" 381 | }, 382 | "eu-central-1" 383 | ] 384 | }, 385 | { 386 | "Fn::Equals": [ 387 | { 388 | "Ref": "AWS::Region" 389 | }, 390 | "eu-north-1" 391 | ] 392 | }, 393 | { 394 | "Fn::Equals": [ 395 | { 396 | "Ref": "AWS::Region" 397 | }, 398 | "eu-south-1" 399 | ] 400 | }, 401 | { 402 | "Fn::Equals": [ 403 | { 404 | "Ref": "AWS::Region" 405 | }, 406 | "eu-west-1" 407 | ] 408 | }, 409 | { 410 | "Fn::Equals": [ 411 | { 412 | "Ref": "AWS::Region" 413 | }, 414 | "eu-west-2" 415 | ] 416 | }, 417 | { 418 | "Fn::Equals": [ 419 | { 420 | "Ref": "AWS::Region" 421 | }, 422 | "eu-west-3" 423 | ] 424 | }, 425 | { 426 | "Fn::Equals": [ 427 | { 428 | "Ref": "AWS::Region" 429 | }, 430 | "il-central-1" 431 | ] 432 | }, 433 | { 434 | "Fn::Equals": [ 435 | { 436 | "Ref": "AWS::Region" 437 | }, 438 | "me-central-1" 439 | ] 440 | }, 441 | { 442 | "Fn::Equals": [ 443 | { 444 | "Ref": "AWS::Region" 445 | }, 446 | "me-south-1" 447 | ] 448 | }, 449 | { 450 | "Fn::Equals": [ 451 | { 452 | "Ref": "AWS::Region" 453 | }, 454 | "sa-east-1" 455 | ] 456 | } 457 | ] 458 | }, 459 | { 460 | "Fn::Or": [ 461 | { 462 | "Fn::Equals": [ 463 | { 464 | "Ref": "AWS::Region" 465 | }, 466 | "us-east-1" 467 | ] 468 | }, 469 | { 470 | "Fn::Equals": [ 471 | { 472 | "Ref": "AWS::Region" 473 | }, 474 | "us-east-2" 475 | ] 476 | }, 477 | { 478 | "Fn::Equals": [ 479 | { 480 | "Ref": "AWS::Region" 481 | }, 482 | "us-west-1" 483 | ] 484 | }, 485 | { 486 | "Fn::Equals": [ 487 | { 488 | "Ref": "AWS::Region" 489 | }, 490 | "us-west-2" 491 | ] 492 | } 493 | ] 494 | } 495 | ] 496 | } 497 | } 498 | } 
-------------------------------------------------------------------------------- /lib/sagemakerNotebookStack/install_packages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | 6 | sudo -u ec2-user -i <<'EOF' 7 | 8 | #source /home/ec2-user/anaconda3/bin/deactivate 9 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pip 10 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U awscli==1.33.16 11 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U botocore==1.34.134 12 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U boto3==1.34.134 13 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U sagemaker==2.224.1 14 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U langchain==0.2.6 15 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U langchain-community==0.2.6 16 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U langchain_aws==0.1.8 17 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U termcolor==2.4.0 18 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U transformers==4.41.2 19 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U librosa==0.10.2.post1 20 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U opensearch-py==2.6.0 21 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U sqlalchemy #==2.0.1 22 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pypdf==4.2.0 23 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U ipython==8.25.0 24 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U ipywidgets==8.1.3 25 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U anthropic==0.30.0 26 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U faiss-cpu==1.8.0.post1 27 | 
/home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U jq==1.7.0 28 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pydantic==2.7.4 29 | 30 | sudo rpm -Uvh https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm 31 | sudo yum -y update 32 | sudo yum install -y poppler-utils 33 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U lxml==5.2.2 34 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U kaleido==0.2.1 35 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U uvicorn==0.30.1 36 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pandas==2.2.2 37 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U numexpr==2.10.1 38 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pdf2image==1.17.0 39 | 40 | sudo amazon-linux-extras install libreoffice -y 41 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U "unstructured[all-docs]==0.13.2" 42 | 43 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U python-dotenv==1.0.1 44 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U llama-parse==0.4.4 45 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pymupdf==1.24.7 46 | 47 | # Uninstall nltk 3.8.2 and install 3.8.1 48 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip uninstall -y nltk 49 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U nltk==3.8.1 50 | 51 | EOF 52 | -------------------------------------------------------------------------------- /lib/sagemakerNotebookStack/install_tesseract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | 4 | yum-config-manager --disable centos-extras 5 | 6 | nohup bash <<'EOF' & 7 | 8 | echo "## Step 1" 9 | yum -y update 10 | yum -y upgrade 11 | yum install clang -y 12 | yum install libpng-devel libtiff-devel zlib-devel libwebp-devel 
libjpeg-turbo-devel wget tar gzip -y 13 | wget https://github.com/DanBloomberg/leptonica/releases/download/1.84.1/leptonica-1.84.1.tar.gz 14 | tar -zxvf leptonica-1.84.1.tar.gz 15 | cd leptonica-1.84.1 16 | ./configure 17 | make 18 | make install 19 | 20 | echo "## Step 2" 21 | cd ~ 22 | yum install git-core libtool pkgconfig -y 23 | wget https://github.com/tesseract-ocr/tesseract/archive/5.3.1.tar.gz 24 | tar xzvf 5.3.1.tar.gz 25 | cd tesseract-5.3.1 26 | #git clone --depth 1 https://github.com/tesseract-ocr/tesseract.git tesseract-ocr 27 | #cd tesseract-ocr 28 | export PKG_CONFIG_PATH=/usr/local/lib/pkgconfig 29 | ./autogen.sh 30 | ./configure 31 | make 32 | make install 33 | ldconfig 34 | 35 | echo "## Step 3" 36 | cd /usr/local/share/tessdata 37 | wget https://github.com/tesseract-ocr/tessdata/raw/main/osd.traineddata 38 | wget https://github.com/tesseract-ocr/tessdata/raw/main/eng.traineddata 39 | wget https://github.com/tesseract-ocr/tessdata/raw/main/hin.traineddata 40 | wget https://github.com/tesseract-ocr/tessdata/raw/main/kor.traineddata 41 | wget https://github.com/tesseract-ocr/tessdata/raw/main/kor_vert.traineddata 42 | #wget https://github.com/tesseract-ocr/tessdata_best/raw/main/kor.traineddata 43 | #wget https://github.com/tesseract-ocr/tessdata_best/raw/main/kor_vert.traineddata 44 | 45 | echo "## Step 4" 46 | echo "export TESSDATA_PREFIX=/usr/local/share/tessdata" >> ~/.bash_profile 47 | echo "export TESSDATA_PREFIX=/usr/local/share/tessdata" >> /home/ec2-user/.bash_profile 48 | 49 | EOF -------------------------------------------------------------------------------- /lib/sagemakerNotebookStack/sagemakerNotebookStack.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as iam from 'aws-cdk-lib/aws-iam'; 4 | import * as sagemaker from 'aws-cdk-lib/aws-sagemaker'; 5 | import * as fs from 'fs'; 6 | import * as path from 
'path'; 7 | 8 | 9 | export class SagemakerNotebookStack extends cdk.Stack { 10 | constructor(scope: Construct, id: string, props?: cdk.StackProps) { 11 | super(scope, id, props); 12 | 13 | // The code that defines your stack goes here 14 | 15 | // IAM Role 16 | const SageMakerNotebookinstanceRole = new iam.Role(this, 'SageMakerNotebookInstanceRole', { 17 | assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), 18 | managedPolicies: [ 19 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonBedrockFullAccess'), 20 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonOpenSearchServiceFullAccess'), 21 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), 22 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSSMFullAccess'), 23 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonS3FullAccess'), 24 | iam.ManagedPolicy.fromAwsManagedPolicyName('SecretsManagerReadWrite') 25 | ], 26 | }); 27 | 28 | 29 | // SageMaker Notebook Instance Lifecycle Configuration 30 | 31 | const onCreateScriptPath1 = path.join(__dirname, 'install_packages.sh') 32 | const onCreateScriptPath2 = path.join(__dirname, 'install_tesseract.sh') 33 | const onCreateScriptContent1 = fs.readFileSync(onCreateScriptPath1, 'utf-8') 34 | const onCreateScriptContent2 = fs.readFileSync(onCreateScriptPath2, 'utf-8') 35 | 36 | const combinedScriptContent = `${onCreateScriptContent1}\n${onCreateScriptContent2}`; 37 | 38 | const cfnNotebookInstanceLifecycleConfig = new sagemaker.CfnNotebookInstanceLifecycleConfig(this, 'MyCfnNotebookInstanceLifecycleConfig', /* all optional props */ { 39 | notebookInstanceLifecycleConfigName: 'notebookInstanceLifecycleConfig', 40 | onCreate: [{ 41 | content: cdk.Fn.base64(combinedScriptContent), 42 | }], 43 | onStart: [], 44 | }); 45 | 46 | 47 | // SageMaker Notebook Instance 48 | 49 | const cfnNotebookInstance = new sagemaker.CfnNotebookInstance(this, 'MyCfnNotebookInstance', { 50 | instanceType: 'ml.m5.xlarge', 51 | roleArn: 
SageMakerNotebookinstanceRole.roleArn, 52 | 53 | // the properties below are optional 54 | //acceleratorTypes: ['acceleratorTypes'], 55 | //additionalCodeRepositories: ['additionalCodeRepositories'], 56 | defaultCodeRepository: 'https://github.com/Jiyu-Kim/advanced-rag-workshop.git', 57 | directInternetAccess: 'Enabled', 58 | //instanceMetadataServiceConfiguration: { 59 | // minimumInstanceMetadataServiceVersion: 'minimumInstanceMetadataServiceVersion', 60 | //}, 61 | //kmsKeyId: 'kmsKeyId', 62 | lifecycleConfigName: 'notebookInstanceLifecycleConfig', 63 | notebookInstanceName: 'advanced-rag-workshop-notebook-instance', 64 | //platformIdentifier: 'platformIdentifier', 65 | //rootAccess: 'rootAccess', 66 | //securityGroupIds: ['securityGroupIds'], 67 | //subnetId: 'subnetId', 68 | //tags: [{ 69 | // key: 'key', 70 | // value: 'value', 71 | //}], 72 | volumeSizeInGb: 10, 73 | }); 74 | 75 | 76 | } 77 | } 78 | 79 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "cdk", 3 | "version": "0.1.0", 4 | "bin": { 5 | "cdk": "bin/cdk.js" 6 | }, 7 | "scripts": { 8 | "build": "tsc", 9 | "watch": "tsc -w", 10 | "test": "jest", 11 | "cdk": "cdk" 12 | }, 13 | "devDependencies": { 14 | "@types/jest": "^29.5.12", 15 | "@types/node": "^20.17.1", 16 | "aws-cdk": "2.91.0", 17 | "jest": "^29.7.0", 18 | "ts-jest": "^29.1.2", 19 | "ts-node": "^10.9.2", 20 | "typescript": "~5.4.5" 21 | }, 22 | "dependencies": { 23 | "aws-cdk-lib": "^2.91.0", 24 | "constructs": "^10.0.0", 25 | "source-map-support": "^0.5.21" 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2020", 4 | "module": "commonjs", 5 | "lib": [ 6 | "es2020", 7 | "dom" 8 | ], 9 | "declaration": true, 
10 | "strict": true, 11 | "noImplicitAny": true, 12 | "strictNullChecks": true, 13 | "noImplicitThis": true, 14 | "alwaysStrict": true, 15 | "noUnusedLocals": false, 16 | "noUnusedParameters": false, 17 | "noImplicitReturns": true, 18 | "noFallthroughCasesInSwitch": false, 19 | "inlineSourceMap": true, 20 | "inlineSources": true, 21 | "experimentalDecorators": true, 22 | "strictPropertyInitialization": false, 23 | "typeRoots": [ 24 | "./node_modules/@types" 25 | ] 26 | }, 27 | "exclude": [ 28 | "node_modules", 29 | "cdk.out" 30 | ] 31 | } 32 | --------------------------------------------------------------------------------