├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── application
│   ├── README.md
│   ├── __init__.py
│   ├── bedrock.py
│   ├── requirements.txt
│   ├── streamlit.py
│   └── utils
│       ├── __init__.py
│       ├── bedrock.py
│       ├── chat.py
│       ├── chunk.py
│       ├── common_utils.py
│       ├── opensearch_summit.py
│       ├── pymupdf.py
│       ├── rag_summit.py
│       ├── s3.py
│       ├── ssm.py
│       └── text_to_report.py
├── bin
│   └── cdk.ts
├── cdk.json
├── jest.config.js
├── lambda
│   └── index.py
├── lib
│   ├── customResourceStack.ts
│   ├── ec2Stack
│   │   ├── ec2Stack.ts
│   │   └── userdata.sh
│   ├── openSearchStack.ts
│   ├── rerankerStack
│   │   ├── RerankerStack.assets.json
│   │   └── RerankerStack.template.json
│   └── sagemakerNotebookStack
│       ├── install_packages.sh
│       ├── install_tesseract.sh
│       └── sagemakerNotebookStack.ts
├── package-lock.json
├── package.json
└── tsconfig.json
/.gitignore:
--------------------------------------------------------------------------------
1 | /node_modules
2 | /cdk.out
3 | cdk.context.json
4 |
5 | # DS_store
6 | .DS_Store
7 | **/.DS_Store
8 |
9 | .ipynb_checkpoints
10 | */.ipynb_checkpoints/*
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. Our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), so any 'help wanted' issues are a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this
4 | software and associated documentation files (the "Software"), to deal in the Software
5 | without restriction, including without limitation the rights to use, copy, modify,
6 | merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
7 | permit persons to whom the Software is furnished to do so.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
10 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
11 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
12 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
13 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
14 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Amazon Bedrock Q&A multi-modal chatbot with advanced RAG
2 |
3 | :link: AWS Workshop [Amazon Bedrock Q&A multi-modal chatbot with advanced RAG](https://catalog.us-east-1.prod.workshops.aws/workshops/a372f3ed-e99d-4c95-93b5-ee666375a387/en-US)
4 |
5 | :telephone_receiver: Issue Report [Bailey(Sohyeon) Cho](https://www.linkedin.com/in/csbailey/)
6 |
7 | :globe_with_meridians: Language
8 | * [English](#English)
9 | * [한국어](#한국어)
10 |
11 | # English
12 | ## :mega: What is the Advanced RAG Workshop?
13 | > In this workshop, you'll understand several advanced RAG techniques and apply them to a multi-modal chatbot Q&A. This workshop aims to increase customer interest and engagement with AWS by introducing and practicing advanced Retrieval Augmented Generation (RAG) techniques that can be utilized in a real production environment, beyond the typical demo.
14 |
15 | ## :mega: Who needs this workshop?
16 | > No prior knowledge is required to perform this workshop.
17 |
18 | This workshop is available to anyone who wants to:
19 | * Acquire production-level RAG skills
20 | * Increase technical understanding through hands-on practice
21 | * Gain hands-on experience with AWS services
22 |
23 | ## :mega: How to deploy CDK stacks
24 | ```bash
25 | git clone https://github.com/aws-samples/multi-modal-chatbot-with-advanced-rag.git
26 | cd multi-modal-chatbot-with-advanced-rag
27 | npm i --save-dev @types/node
28 | cdk bootstrap
29 | cdk synth
30 | cdk deploy --all --require-approval never
31 | ```
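
If the `cdk` CLI is not installed yet, you can install it globally first (a common prerequisite, not part of the workshop snippet above; this assumes Node.js and npm are already available):

```bash
npm install -g aws-cdk
```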
32 |
33 | ## :mega: Workshop Guide
34 | You can find the full workshop guide [here](https://catalog.us-east-1.prod.workshops.aws/workshops/a372f3ed-e99d-4c95-93b5-ee666375a387/en-US).
35 |
36 | # 한국어
37 |
38 | ## :mega: Advanced RAG Workshop이란?
39 | 이 워크샵에서는 여러 Advanced RAG 기법을 이해하고, Multi Modal 챗봇 Q&A에 여러 RAG 기법을 적용해 볼 수 있습니다. 일반적인 데모를 넘어 실제 프로덕션 환경에서 활용할 수 있는 고급 검색 증강 생성(RAG) 기술을 소개하고 실습함으로써 AWS에 대한 고객의 관심과 참여를 높이는 것을 목표로 합니다.
40 |
41 | ## :mega: 누구를 위한 워크샵 인가요?
42 | > 이 워크샵을 수행하기 위해 별도의 사전 지식은 필요하지 않습니다.
43 |
44 | 다음 효과를 기대하는 모두가 이 워크샵을 사용할 수 있습니다:
45 | * 프로덕션 수준의 RAG 기술 습득
46 | * 실습을 통한 기술적 이해도 향상
47 | * AWS 서비스 경험
48 |
49 | ## :mega: CDK 스택 배포 방법
50 | ```bash
51 | git clone https://github.com/aws-samples/multi-modal-chatbot-with-advanced-rag.git
52 | cd multi-modal-chatbot-with-advanced-rag
53 | npm i --save-dev @types/node
54 | cdk bootstrap
55 | cdk synth
56 | cdk deploy --all --require-approval never
57 | ```
58 |
59 | ## :mega: 워크샵 가이드
60 | 이 [링크](https://catalog.us-east-1.prod.workshops.aws/workshops/a372f3ed-e99d-4c95-93b5-ee666375a387/ko-KR)에서 워크샵 콘텐츠를 확인할 수 있습니다.
61 |
--------------------------------------------------------------------------------
/application/README.md:
--------------------------------------------------------------------------------
1 | # How to run this application
2 |
3 | > If you clone this repo and run `cdk deploy`, you don't need to start the Streamlit application yourself.
4 |
5 | ## Structure
6 |
7 | 1. `bedrock.py`
8 |
9 | - Implements Amazon Bedrock access and the RAG techniques: Reranker, Hybrid search, parent-document retrieval, and more
10 |
11 | 2. `streamlit.py`
12 |
13 | - The application's front-end file; at runtime it imports and uses `bedrock.py` (see the sketch below)
14 |
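For reference, here is a simplified sketch (condensed from `streamlit.py`; the query string and parameter values are illustrative) of how the front-end drives `bedrock.py`:

```python
import streamlit as st
import bedrock as glib
from langchain.callbacks import StreamlitCallbackHandler

# stream tokens into the assistant chat bubble
st_cb = StreamlitCallbackHandler(st.chat_message("assistant"), collapse_completed_thoughts=True)

response = glib.invoke(
    query="What does the curriculum cover?",  # illustrative user question
    streaming_callback=st_cb,
    parent=False, reranker=True, hyde=False, ragfusion=False,
    alpha=0.51,  # lexical/semantic blend for hybrid search
    document_type="Default",
)
answer, contexts = response[0], response[1]
```
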
15 | ## Start
16 | 1. Go to the application folder
17 | ```
18 | cd multi-modal-chatbot-with-advanced-rag/application
19 | ```
20 |
21 | 2. Install the Python dependencies
22 |
23 | ```
24 | pip install -r requirements.txt
25 | ```
26 |
27 | 3. Run the Streamlit application
28 |
29 | ```
30 | streamlit run streamlit.py --server.port 8080
31 | ```
32 |
33 | 4. Access the app
34 |
35 | - Open the External URL printed when Streamlit starts, or the EC2 public IP
36 |
--------------------------------------------------------------------------------
/application/__init__.py:
--------------------------------------------------------------------------------
1 | # __init__.py
2 |
--------------------------------------------------------------------------------
/application/bedrock.py:
--------------------------------------------------------------------------------
1 | import os, sys, json, boto3
2 | # module_path = "/" # "../../.."
3 | # sys.path.append(os.path.abspath(module_path))
4 | from utils.rag_summit import prompt_repo, OpenSearchHybridSearchRetriever, qa_chain
5 | from utils.opensearch_summit import opensearch_utils
6 | from utils.ssm import parameter_store
7 | from langchain.embeddings import BedrockEmbeddings
8 | from langchain_aws import ChatBedrock
9 | from utils import bedrock
10 | from utils.bedrock import bedrock_info
11 |
12 | region = boto3.Session().region_name
13 | pm = parameter_store(region)
14 | secrets_manager = boto3.client('secretsmanager', region_name=region)
15 |
16 | # Get the text-generation LLM; receives streaming_callback as an argument
17 | def get_llm(streaming_callback):
18 | boto3_bedrock = bedrock.get_bedrock_client(
19 | assumed_role=os.environ.get("BEDROCK_ASSUME_ROLE", None),
20 | endpoint_url=os.environ.get("BEDROCK_ENDPOINT_URL", None),
21 | region=os.environ.get("AWS_DEFAULT_REGION", None),
22 | )
23 | llm = ChatBedrock(
24 | model_id=bedrock_info.get_model_id(model_name="Claude-V3-Sonnet"),
25 | client=boto3_bedrock,
26 | model_kwargs={
27 | "max_tokens": 1024,
28 | "stop_sequences": ["\n\nHuman"],
29 | },
30 | streaming=True,
31 | callbacks=[streaming_callback],
32 | )
33 | return llm
34 |
35 | # Get the embedding model
36 | def get_embedding_model(document_type):
37 |     model_id = 'amazon.titan-embed-text-v1' if document_type == 'Default' else 'amazon.titan-embed-text-v2:0'
38 | llm_emb = BedrockEmbeddings(model_id=model_id)
39 | return llm_emb
40 |
41 | # Get the OpenSearch vector DB client
42 | def get_opensearch_client():
43 | opensearch_domain_endpoint = pm.get_params(key='opensearch_domain_endpoint', enc=False)
44 | opensearch_user_id = pm.get_params(key='opensearch_user_id', enc=False)
45 |
46 | response = secrets_manager.get_secret_value(SecretId='opensearch_user_password')
47 | secrets_string = response.get('SecretString')
48 |     secrets_dict = json.loads(secrets_string)  # SecretString is JSON; safer than eval()
49 | opensearch_user_password = secrets_dict['pwkey']
50 |
51 |     # opensearch_domain_endpoint already holds the Parameter Store value loaded above
52 | rag_user_name = opensearch_user_id
53 | rag_user_password = opensearch_user_password
54 | aws_region = os.environ.get("AWS_DEFAULT_REGION", None)
55 | http_auth = (rag_user_name, rag_user_password)
56 | os_client = opensearch_utils.create_aws_opensearch_client(
57 | aws_region,
58 | opensearch_domain_endpoint,
59 | http_auth
60 | )
61 | return os_client
62 |
63 | # Build the hybrid search retriever
64 | def get_retriever(streaming_callback, parent, reranker, hyde, ragfusion, alpha, document_type):
65 | os_client = get_opensearch_client()
66 | llm_text = get_llm(streaming_callback)
67 | llm_emb = get_embedding_model(document_type)
68 | reranker_endpoint_name = "reranker"
69 | index_name = "default_doc_index" if document_type == "Default" else "customer_doc_index"
70 | opensearch_hybrid_retriever = OpenSearchHybridSearchRetriever(
71 | os_client=os_client,
72 | index_name=index_name,
73 | llm_text=llm_text, # llm for query augmentation in both rag_fusion and HyDE
74 | llm_emb=llm_emb, # Used in semantic search based on opensearch
75 | # option for lexical
76 | minimum_should_match=0,
77 | filter=[],
78 | # option for search
79 |         # ["RRF", "simple_weighted"]: choice of rank-fusion algorithm
80 | fusion_algorithm="RRF",
81 | complex_doc=True,
82 |         # [lexical weight, semantic weight]: final blend ratio between lexical and semantic results
83 | ensemble_weights=[alpha, 1.0-alpha],
84 | reranker=reranker, # enable reranker with reranker model
85 | # endpoint name for reranking model
86 | reranker_endpoint_name=reranker_endpoint_name,
87 | parent_document=parent, # enable parent document
88 | rag_fusion=ragfusion,
89 |         rag_fusion_prompt=prompt_repo.get_rag_fusion(),
90 | hyde=hyde,
91 | hyde_query=['web_search'],
92 | query_augmentation_size=3,
93 | # option for async search
94 | async_mode=True,
95 | # option for output
96 |         k=7, # number of documents to return
97 | verbose=True,
98 | )
99 | return opensearch_hybrid_retriever
100 |
101 | # Query the model
102 | def formatting_output(contexts):
103 | formatted_contexts = []
104 | for doc, score in contexts:
105 | lines = doc.page_content.split("\n")
106 | metadata = doc.metadata
107 | formatted_contexts.append((score, lines))
108 | return formatted_contexts
109 |
110 | def invoke(query, streaming_callback, parent, reranker, hyde, ragfusion, alpha, document_type="Default"):
111 |     # Fetch the LLM and retriever
112 | llm_text = get_llm(streaming_callback)
113 | opensearch_hybrid_retriever = get_retriever(streaming_callback, parent, reranker, hyde, ragfusion, alpha, document_type)
114 | # context, tables, images = opensearch_hybrid_retriever._get_relevant_documents()
115 |     # choose answer-only output
116 | system_prompt = prompt_repo.get_system_prompt()
117 | qa = qa_chain(
118 | llm_text=llm_text,
119 | retriever=opensearch_hybrid_retriever,
120 | system_prompt=system_prompt,
121 | return_context=False,
122 | verbose=False
123 | )
124 | response, pretty_contexts, similar_docs, augmentation = qa.invoke(query = query, complex_doc = True)
125 | print("-------> response")
126 | # print(response)
127 |     print("-------> pretty_contexts -> all retrieved contexts")
128 |
129 | def extract_elements_and_print(pretty_contexts):
130 | for context in pretty_contexts:
131 | print("context: \n")
132 | print(context)
133 |
134 | print("######### SEMANTIC #########")
135 | extract_elements_and_print(pretty_contexts[0])
136 | print("######### KEYWORD #########")
137 | extract_elements_and_print(pretty_contexts[1])
138 | print("######### WITHOUT_RERANKER #########")
139 | extract_elements_and_print(pretty_contexts[2])
140 | print("######## SIMILAR_DOCS ##########")
141 | extract_elements_and_print(pretty_contexts[3])
142 | if hyde or ragfusion:
143 |         print("######## INTERMEDIATE ANSWER ##########")
144 | print(augmentation)
145 |     if not (hyde or ragfusion):
146 | if alpha == 0.0:
147 | pretty_contexts[0].clear()
148 | elif alpha == 1.0:
149 | pretty_contexts[1].clear()
150 | if hyde or ragfusion:
151 | return response, pretty_contexts, augmentation
152 |
153 |
154 | return response, pretty_contexts
155 |
--------------------------------------------------------------------------------
/application/requirements.txt:
--------------------------------------------------------------------------------
1 | anthropic==0.38.0
2 | httpx==0.27.2
3 | tokenizers==0.20.3
4 | boto3==1.34.122
5 | botocore==1.34.122
6 | ipython==8.2.0
7 | langchain==0.2.5
8 | langchain-aws==0.1.6
9 | langchain-community==0.2.5
10 | matplotlib==3.7.0
11 | requests==2.32.3
12 | streamlit==1.35.0
13 | urllib3==1.26.18
14 | opensearch-py==2.6.0
--------------------------------------------------------------------------------
/application/streamlit.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import streamlit as st # every Streamlit command is available through the "st" alias
3 | import bedrock as glib # reference to the local library script bedrock.py
4 | from langchain.callbacks import StreamlitCallbackHandler
5 |
6 | ##################### Functions ########################
7 | def parse_image(metadata, tag):
8 | if tag in metadata:
9 | st.image(base64.b64decode(metadata[tag]))
10 |
11 | def parse_table(metadata, tag):
12 | if tag in metadata:
13 | st.markdown(metadata[tag], unsafe_allow_html=True)
14 |
15 | def parse_metadata(metadata):
16 |     # If the metadata contains an Image or Table, parse and render it
17 | category = "None"
18 | if "category" in metadata:
19 | category = metadata["category"]
20 | if category == "Table":
21 |         # parse_table(metadata, "text_as_html") # the table HTML is replaced by its image
22 | parse_image(metadata, "image_base64")
23 | elif category == "Image":
24 | parse_image(metadata, "image_base64")
25 | else:
26 | pass
27 | st.markdown(' - - - ')
28 |
29 | def show_document_info_label():
30 | with st.container(border=True):
31 | if st.session_state.document_type == "Default":
32 | st.markdown('''📝 현재 기본 문서인 [**상록초등학교 교육 과정 문서**](https://d14ojpq4k4igb1.cloudfront.net/school_edu_guide.pdf)를 활용하고 있습니다.''')
33 | st.markdown('''다른 문서로 챗봇 서비스를 이용해보고 싶다면 왼쪽 사이드바의 Document type에서 *'Custom'* 옵션을 클릭하고, 진행자의 안내에 따라 문서를 새로 인덱싱하여 사용해보세요.''')
34 | else:
35 | st.markdown('''**💁♀️ 새로운 문서로 챗봇 서비스를 이용하고 싶으신가요?**''')
36 | st.markdown('''- **진행자의 안내에 따라 SageMaker Notebook에서 인덱싱 스크립트를 실행한 뒤** 이용 가능합니다.''')
37 | st.markdown('''- 기존 문서 (상록초등학교 교육 과정)로 돌아가고 싶다면 사이드바의 Document type에서 *'Default'* 옵션을 선택하면 바로 변경할 수 있습니다.''')
38 |
39 | # UI that shows the intermediate contexts in tabs when the 'Separately' option is selected
40 | def show_context_with_tab(contexts):
41 | tab_category = ["Semantic", "Keyword", "Without Reranker", "Similar Docs"]
42 | tab_contents = {
43 | tab_category[0]: [],
44 | tab_category[1]: [],
45 | tab_category[2]: [],
46 | tab_category[3]: []
47 | }
48 | for i, contexts_by_doctype in enumerate(contexts):
49 | tab_contents[tab_category[i]].append(contexts_by_doctype)
50 | tabs = st.tabs(tab_category)
51 | for i, tab in enumerate(tabs):
52 | category = tab_category[i]
53 | with tab:
54 | for contexts_by_doctype in tab_contents[category]:
55 | for context in contexts_by_doctype:
56 | st.markdown('##### `정확도`: {}'.format(context["score"]))
57 | for line in context["lines"]:
58 | st.write(line)
59 | parse_metadata(context["meta"])
60 |
61 | # UI that shows the results in four columns when the 'All at once' option is selected
62 | # TODO: adding HyDE and RAG-Fusion here needs further discussion
63 | def show_answer_with_multi_columns(answers):
64 | col1, col2, col3, col4 = st.columns(4)
65 | with col1:
66 | st.markdown('''### `Lexical search` ''')
67 | st.write(answers[0])
68 | with col2:
69 | st.markdown('''### `Semantic search` ''')
70 | st.write(answers[1])
71 | with col3:
72 | st.markdown('''### + `Reranker` ''')
73 | st.write(answers[2])
74 | with col4:
75 | st.markdown('''### + `Parent_docs` ''')
76 | st.write(answers[3])
77 |
78 | ####################### Application ###############################
79 | st.set_page_config(layout="wide")
80 | st.title("AWS Q&A Bot with Advanced RAG!") # page title
81 |
82 | st.markdown('''- 이 챗봇은 Amazon Bedrock과 Claude v3 Sonnet 모델로 구현되었습니다.''')
83 | st.markdown('''- 다음과 같은 Advanced RAG 기술을 사용합니다: **Hybrid Search, ReRanker, and Parent Document, HyDE, Rag Fusion**''')
84 | st.markdown('''- 원본 데이터는 Amazon OpenSearch에 저장되어 있으며, Amazon Titan 임베딩 모델이 사용되었습니다.''')
85 | st.markdown(''' ''')
86 |
87 | # Store the initial value of widgets in session state
88 | if "document_type" not in st.session_state:
89 | st.session_state.document_type = "Default"
90 | if "showing_option" not in st.session_state:
91 | st.session_state.showing_option = "Separately"
92 | if "search_mode" not in st.session_state:
93 | st.session_state.search_mode = "Hybrid search"
94 | if "hyde_or_ragfusion" not in st.session_state:
95 | st.session_state.hyde_or_ragfusion = "None"
96 | disabled = st.session_state.showing_option=="All at once"
97 |
98 | with st.sidebar: # sidebar model options
99 | with st.container(border=True):
100 | st.radio(
101 | "Document type:",
102 | ["Default", "Custom"],
103 | captions = ["챗봇이 참고하는 자료로 기본 문서(상록초등학교 자료)가 사용됩니다.", "원하시는 문서를 직접 업로드해보세요."],
104 | key="document_type",
105 | )
106 | with st.container(border=True):
107 | st.radio(
108 | "UI option:",
109 | ["Separately", "All at once"],
110 | captions = ["아래에서 설정한 파라미터 조합으로 하나의 검색 결과가 도출됩니다.", "여러 옵션들을 한 화면에서 한꺼번에 볼 수 있습니다."],
111 | key="showing_option",
112 | )
113 | st.markdown('''### Set parameters for your Bot 👇''')
114 | with st.container(border=True):
115 | search_mode = st.radio(
116 | "Search mode:",
117 | ["Lexical search", "Semantic search", "Hybrid search"],
118 | captions = [
119 | "키워드의 일치 여부를 기반으로 답변을 생성합니다.",
120 | "키워드의 일치 여부보다는 문맥의 의미적 유사도에 기반해 답변을 생성합니다.",
121 | "아래의 Alpha 값을 조정하여 Lexical/Semantic search의 비율을 조정합니다."
122 | ],
123 | key="search_mode",
124 | disabled=disabled
125 | )
126 | alpha = st.slider('Alpha value for Hybrid search ⬇️', 0.0, 1.0, 0.51,
127 | disabled=st.session_state.search_mode != "Hybrid search",
128 | help="""Alpha=0.0 이면 Lexical search, \nAlpha=1.0 이면 Semantic search 입니다."""
129 | )
130 | if search_mode == "Lexical search":
131 | alpha = 0.0
132 | elif search_mode == "Semantic search":
133 | alpha = 1.0
134 |
135 | col1, col2 = st.columns(2)
136 | with col1:
137 | reranker = st.toggle("Reranker",
138 | help="""초기 검색 결과를 재평가하여 순위를 재조정하는 모델입니다.
139 | 문맥 정보와 질의 관련성을 고려하여 적합한 결과를 상위에 올립니다.""",
140 | disabled=disabled)
141 | with col2:
142 | parent = st.toggle("Parent Docs",
143 |                       help="""검색은 작은 child chunk 단위로 수행하고, 답변 생성 시에는 해당 chunk가 속한 더 큰 parent 문서를 컨텍스트로 전달하는 옵션입니다.""",
144 | disabled=disabled)
145 |
146 | with st.container(border=True):
147 | hyde_or_ragfusion = st.radio(
148 | "Choose a RAG option:",
149 | ["None", "HyDE", "RAG-Fusion"],
150 | captions = [
151 | "",
152 |         "질의에 대한 가상의 답변(hypothetical document)을 먼저 생성한 뒤, 그 임베딩으로 검색해 문서와 질의 간 의미적 유사도를 높이는 기법입니다.",
153 |         "하나의 질의로부터 여러 변형 쿼리를 생성해 각각 검색한 뒤, 검색 결과의 순위를 융합(RRF)하여 최종 컨텍스트를 도출하는 기법입니다."
154 | ],
155 | key="hyde_or_ragfusion",
156 | disabled=disabled
157 | )
158 | hyde = hyde_or_ragfusion == "HyDE"
159 | ragfusion = hyde_or_ragfusion == "RAG-Fusion"
160 |
161 | ###### 1) When the 'Separately' option is selected ######
162 | if st.session_state.showing_option == "Separately":
163 | show_document_info_label()
164 |
165 | if "messages" not in st.session_state:
166 | st.session_state["messages"] = [
167 | {"role": "assistant", "content": "안녕하세요, 무엇이 궁금하세요?"}
168 | ]
169 |     # Replay previous answers
170 | for msg in st.session_state.messages:
171 |         # Show the contexts attached to a previous answer
172 | if msg["role"] == "assistant_context":
173 | with st.chat_message("assistant"):
174 | with st.expander("Context 확인하기 ⬇️"):
175 | show_context_with_tab(contexts=msg["content"])
176 |
177 | elif msg["role"] == "hyde_or_fusion":
178 | with st.chat_message("assistant"):
179 | with st.expander("중간 답변 확인하기 ⬇️"):
180 | msg["content"]
181 |
182 | elif msg["role"] == "assistant_column":
183 |             # In 'Separately' mode, show only the first answer instead of multiple columns
184 | st.chat_message(msg["role"]).write(msg["content"][0])
185 | else:
186 | st.chat_message(msg["role"]).write(msg["content"])
187 |
188 |     # Store the user's chat input in the query variable
189 | query = st.chat_input("Search documentation")
190 | if query:
191 |         # Save the message to the session
192 | st.session_state.messages.append({"role": "user", "content": query})
193 |
194 |         # Render in the UI
195 | st.chat_message("user").write(query)
196 |
197 |         # Container that receives the Bedrock stream via the Streamlit callback handler
198 | st_cb = StreamlitCallbackHandler(
199 | st.chat_message("assistant"),
200 | collapse_completed_thoughts=True
201 | )
202 |         # Call the invoke function from bedrock.py
203 | response = glib.invoke(
204 | query=query,
205 | streaming_callback=st_cb,
206 | parent=parent,
207 | reranker=reranker,
208 | hyde = hyde,
209 | ragfusion = ragfusion,
210 | alpha = alpha,
211 | document_type=st.session_state.document_type
212 | )
213 |         # Unpack the response: the answer, the contexts, and (for HyDE/RAG-Fusion) the intermediate output
214 | answer = response[0]
215 | contexts = response[1]
216 | if hyde or ragfusion:
217 | mid_answer = response[2]
218 |
219 |         # Render the answer
220 | st.chat_message("assistant").write(answer)
221 |
222 | if hyde:
223 | with st.chat_message("assistant"):
224 | with st.expander("HyDE 중간 생성 답변 ⬇️"):
225 | mid_answer
226 | if ragfusion:
227 | with st.chat_message("assistant"):
228 | with st.expander("RAG-Fusion 중간 생성 쿼리 ⬇️"):
229 | mid_answer
230 | with st.chat_message("assistant"):
231 | with st.expander("정확도 별 컨텍스트 보기 ⬇️"):
232 | show_context_with_tab(contexts)
233 |
234 |         # Save messages to the session
235 | st.session_state.messages.append({"role": "assistant", "content": answer})
236 |
237 | if hyde or ragfusion:
238 | st.session_state.messages.append({"role": "hyde_or_fusion", "content": mid_answer})
239 |
240 | st.session_state.messages.append({"role": "assistant_context", "content": contexts})
241 |         # Manually mark the 'Thinking' state as complete
242 | st_cb._complete_current_thought()
243 |
244 | ###### 2) When the 'All at once' option is selected ######
245 | else:
246 | with st.container(border=True):
247 |         st.markdown('''현재 기본 문서인 [상록초등학교 교육 과정 문서](https://d14ojpq4k4igb1.cloudfront.net/school_edu_guide.pdf)를 활용하고 있습니다.''')
248 | st.markdown('''다른 문서로 챗봇 서비스를 이용해보고 싶다면 왼쪽 사이드바의 Document type에서 'Custom' 옵션을 클릭해 문서를 업로드해보세요.''')
249 |
250 | if "messages" not in st.session_state:
251 | st.session_state["messages"] = [
252 | {"role": "assistant", "content": "안녕하세요, 무엇이 궁금하세요?"}
253 | ]
254 |     # Replay previous answers
255 | for msg in st.session_state.messages:
256 | if msg["role"] == "assistant_column":
257 | answers = msg["content"]
258 | show_answer_with_multi_columns(answers)
259 | elif msg["role"] == "assistant_context":
260 |             pass # context logs are not shown in 'All at once' mode
261 | else:
262 | st.chat_message(msg["role"]).write(msg["content"])
263 |
264 |     # Store the user's chat input in the query variable
265 | query = st.chat_input("Search documentation")
266 | if query:
267 |         # Save the message to the session
268 | st.session_state.messages.append({"role": "user", "content": query})
269 |
270 |         # Render in the UI
271 | st.chat_message("user").write(query)
272 |
273 | col1, col2, col3, col4 = st.columns(4)
274 | with col1:
275 | st.markdown('''### `Lexical search` ''')
276 | st.markdown(":green[: Alpha 값이 0.0]으로, 키워드의 정확한 일치 여부를 판단하는 Lexical search 결과입니다.")
277 | with col2:
278 | st.markdown('''### `Semantic search` ''')
279 | st.markdown(":green[: Alpha 값이 1.0]으로, 키워드 일치 여부보다는 문맥의 의미적 유사도에 기반한 Semantic search 결과입니다.")
280 | with col3:
281 | st.markdown('''### + `Reranker` ''')
282 | st.markdown(""": 초기 검색 결과를 재평가하여 순위를 재조정하는 모델입니다. 문맥 정보와 질의 관련성을 고려하여 적합한 결과를 상위에 올립니다.
283 | :green[Alpha 값은 왼쪽 사이드바에서 설정하신 값]으로 적용됩니다.""")
284 | with col4:
285 | st.markdown('''### + `Parent Docs` ''')
286 |             st.markdown(""": 검색은 작은 child chunk 단위로 수행하고, 답변 생성 시에는 해당 chunk가 속한 더 큰 parent 문서를 컨텍스트로 전달합니다.
287 | :green[Alpha 값은 왼쪽 사이드바에서 설정하신 값]으로 적용됩니다.""")
288 |
289 | with col1:
290 |             # Container that receives the Bedrock stream via the Streamlit callback handler
291 | st_cb = StreamlitCallbackHandler(
292 | st.chat_message("assistant"),
293 | collapse_completed_thoughts=True
294 | )
295 | answer1 = glib.invoke(
296 | query=query,
297 | streaming_callback=st_cb,
298 | parent=False,
299 | reranker=False,
300 | hyde = False,
301 | ragfusion = False,
302 | alpha = 0, # Lexical search
303 | document_type=st.session_state.document_type
304 | )[0]
305 | st.write(answer1)
306 |             st_cb._complete_current_thought() # manually mark 'Thinking' as complete
307 | with col2:
308 | st_cb = StreamlitCallbackHandler(
309 | st.chat_message("assistant"),
310 | collapse_completed_thoughts=True
311 | )
312 | answer2 = glib.invoke(
313 | query=query,
314 | streaming_callback=st_cb,
315 | parent=False,
316 | reranker=False,
317 | hyde = False,
318 | ragfusion = False,
319 | alpha = 1.0, # Semantic search
320 | document_type=st.session_state.document_type
321 | )[0]
322 | st.write(answer2)
323 | st_cb._complete_current_thought()
324 | with col3:
325 | st_cb = StreamlitCallbackHandler(
326 | st.chat_message("assistant"),
327 | collapse_completed_thoughts=True
328 | )
329 | answer3 = glib.invoke(
330 | query=query,
331 | streaming_callback=st_cb,
332 | parent=False,
333 | reranker=True, # Add Reranker option
334 | hyde = False,
335 | ragfusion = False,
336 | alpha = alpha, # Hybrid search
337 | document_type=st.session_state.document_type
338 | )[0]
339 | st.write(answer3)
340 | st_cb._complete_current_thought()
341 | with col4:
342 | st_cb = StreamlitCallbackHandler(
343 | st.chat_message("assistant"),
344 | collapse_completed_thoughts=True
345 | )
346 | answer4 = glib.invoke(
347 | query=query,
348 | streaming_callback=st_cb,
349 | parent=True, # Add Parent_docs option
350 | reranker=True, # Add Reranker option
351 | hyde = False,
352 | ragfusion = False,
353 | alpha = alpha, # Hybrid search
354 | document_type=st.session_state.document_type
355 | )[0]
356 | st.write(answer4)
357 | st_cb._complete_current_thought()
358 |
359 |         # Save messages to the session
360 | answers = [answer1, answer2, answer3, answer4]
361 | st.session_state.messages.append({"role": "assistant_column", "content": answers})
362 |
--------------------------------------------------------------------------------
/application/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """General helper utilities for the workshop notebooks"""
4 | # Python Built-Ins:
5 | from io import StringIO
6 | import sys
7 | import textwrap
8 |
9 |
10 | def print_ww(*args, width: int = 100, **kwargs):
11 | """Like print(), but wraps output to `width` characters (default 100)"""
12 | buffer = StringIO()
13 | try:
14 | _stdout = sys.stdout
15 | sys.stdout = buffer
16 | print(*args, **kwargs)
17 | output = buffer.getvalue()
18 | finally:
19 | sys.stdout = _stdout
20 | for line in output.splitlines():
21 | print("\n".join(textwrap.wrap(line, width=width)))
22 |
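# Usage sketch (illustrative): wrap a long model response at 80 characters
# print_ww("some very long text returned by the model ...", width=80)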
--------------------------------------------------------------------------------
/application/utils/bedrock.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | """Helper utilities for working with Amazon Bedrock from Python notebooks"""
4 | # Python Built-Ins:
5 | import os
6 | from typing import Optional
7 |
8 | # External Dependencies:
9 | import boto3
10 | from botocore.config import Config
11 |
12 | # Langchain
13 | from langchain.callbacks.base import BaseCallbackHandler
14 |
15 | def get_bedrock_client(
16 | assumed_role: Optional[str] = None,
17 | endpoint_url: Optional[str] = None,
18 | region: Optional[str] = None,
19 | ):
20 | """Create a boto3 client for Amazon Bedrock, with optional configuration overrides
21 |
22 | Parameters
23 | ----------
24 | assumed_role :
25 | Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not
26 | specified, the current active credentials will be used.
27 | endpoint_url :
28 | Optional override for the Bedrock service API Endpoint. If setting this, it should usually
29 | include the protocol i.e. "https://..."
30 | region :
31 | Optional name of the AWS Region in which the service should be called (e.g. "us-east-1").
32 | If not specified, AWS_REGION or AWS_DEFAULT_REGION environment variable will be used.
33 | """
34 | if region is None:
35 | target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
36 | else:
37 | target_region = region
38 |
39 | print(f"Create new client\n Using region: {target_region}")
40 | session_kwargs = {"region_name": target_region}
41 | client_kwargs = {**session_kwargs}
42 |
43 |     profile_name = os.environ.get("AWS_PROFILE")
44 |     # announce the profile only when one is actually set
45 |     if profile_name:
46 |         print(f"  Using profile: {profile_name}")
47 |         session_kwargs["profile_name"] = profile_name
48 |
49 | retry_config = Config(
50 | region_name=target_region,
51 | retries={
52 | "max_attempts": 10,
53 | "mode": "standard",
54 | },
55 | )
56 | session = boto3.Session(**session_kwargs)
57 |
58 | if assumed_role:
59 | print(f" Using role: {assumed_role}", end='')
60 | sts = session.client("sts")
61 | response = sts.assume_role(
62 | RoleArn=str(assumed_role),
63 | RoleSessionName="langchain-llm-1"
64 | )
65 | print(" ... successful!")
66 | client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
67 | client_kwargs["aws_secret_access_key"] = response["Credentials"]["SecretAccessKey"]
68 | client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]
69 |
70 | if endpoint_url:
71 | client_kwargs["endpoint_url"] = endpoint_url
72 |
73 | bedrock_client = session.client(
74 | service_name="bedrock-runtime",
75 | config=retry_config,
76 | **client_kwargs
77 | )
78 |
79 | print("boto3 Bedrock client successfully created!")
80 | print(bedrock_client._endpoint)
81 | return bedrock_client
82 |
83 |
84 | class bedrock_info():
85 |
86 | _BEDROCK_MODEL_INFO = {
87 | "Claude-Instant-V1": "anthropic.claude-instant-v1",
88 | "Claude-V1": "anthropic.claude-v1",
89 | "Claude-V2": "anthropic.claude-v2",
90 | "Claude-V2-1": "anthropic.claude-v2:1",
91 | "Claude-V3-Sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
92 | "Claude-V3-Haiku": "anthropic.claude-3-haiku-20240307-v1:0",
93 | "Jurassic-2-Mid": "ai21.j2-mid-v1",
94 | "Jurassic-2-Ultra": "ai21.j2-ultra-v1",
95 | "Command": "cohere.command-text-v14",
96 | "Command-Light": "cohere.command-light-text-v14",
97 | "Cohere-Embeddings-En": "cohere.embed-english-v3",
98 | "Cohere-Embeddings-Multilingual": "cohere.embed-multilingual-v3",
99 | "Titan-Embeddings-G1": "amazon.titan-embed-text-v1",
100 | "Titan-Text-Embeddings-V2": "amazon.titan-embed-text-v2:0",
101 | "Titan-Text-G1": "amazon.titan-text-express-v1",
102 | "Titan-Text-G1-Light": "amazon.titan-text-lite-v1",
103 | "Titan-Text-G1-Premier": "amazon.titan-text-premier-v1:0",
104 | "Titan-Text-G1-Express": "amazon.titan-text-express-v1",
105 | "Llama2-13b-Chat": "meta.llama2-13b-chat-v1"
106 | }
107 |
108 | @classmethod
109 | def get_list_fm_models(cls, verbose=False):
110 |
111 | if verbose:
112 | bedrock = boto3.client(service_name='bedrock')
113 | model_list = bedrock.list_foundation_models()
114 | return model_list["modelSummaries"]
115 | else:
116 | return cls._BEDROCK_MODEL_INFO
117 |
118 | @classmethod
119 | def get_model_id(cls, model_name):
120 |
121 | assert model_name in cls._BEDROCK_MODEL_INFO.keys(), "Check model name"
122 |
123 | return cls._BEDROCK_MODEL_INFO[model_name]
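
# Usage sketch (mapping taken from the table above):
# bedrock_info.get_model_id(model_name="Claude-V3-Sonnet")
# -> "anthropic.claude-3-sonnet-20240229-v1:0"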
--------------------------------------------------------------------------------
/application/utils/chat.py:
--------------------------------------------------------------------------------
1 | import ipywidgets as ipw
2 | from utils import print_ww
3 | from langchain import PromptTemplate
4 | from IPython.display import display, clear_output
5 | from langchain.memory import ConversationBufferWindowMemory, ConversationSummaryBufferMemory, ConversationBufferMemory
6 |
7 | class chat_utils():
8 |
9 | MEMORY_TYPE = [
10 | "ConversationBufferMemory",
11 | "ConversationBufferWindowMemory",
12 | "ConversationSummaryBufferMemory"
13 | ]
14 |
15 | PROMPT_SUMMARY = PromptTemplate(
16 | input_variables=['summary', 'new_lines'],
17 | template="""
18 | \n\nHuman: Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
19 |
20 | EXAMPLE
21 | Current summary:
22 | The user asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
23 |
24 | New lines of conversation:
25 | User: Why do you think artificial intelligence is a force for good?
26 | AI: Because artificial intelligence will help users reach their full potential.
27 |
28 | New summary:
29 | The user asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help users reach their full potential.
30 | END OF EXAMPLE
31 |
32 | Current summary:
33 | {summary}
34 |
35 | New lines of conversation:
36 | {new_lines}
37 |
38 | \n\nAssistant:"""
39 | )
40 |
41 | @classmethod
42 | def get_memory(cls, **kwargs):
43 | """
44 | create memory for this chat session
45 | """
46 | memory_type = kwargs["memory_type"]
47 |
48 | assert memory_type in cls.MEMORY_TYPE, f"Check Buffer Name, Lists: {cls.MEMORY_TYPE}"
49 | if memory_type == "ConversationBufferMemory":
50 |
51 | # This memory allows for storing messages and then extracts the messages in a variable.
52 | memory = ConversationBufferMemory(
53 | memory_key=kwargs.get("memory_key", "chat_history"),
54 | human_prefix=kwargs.get("human_prefix", "Human"),
55 | ai_prefix=kwargs.get("ai_prefix", "AI"),
56 | return_messages=kwargs.get("return_messages", True)
57 | )
58 | elif memory_type == "ConversationBufferWindowMemory":
59 |
60 | # Maintains a history of previous messages
61 | # ConversationBufferWindowMemory keeps a list of the interactions of the conversation over time.
62 | # It only uses the last K interactions.
63 | # This can be useful for keeping a sliding window of the most recent interactions,
64 | # so the buffer does not get too large.
65 | # https://python.langchain.com/docs/modules/memory/types/buffer_window
66 | memory = ConversationBufferWindowMemory(
67 | k=kwargs.get("k", 5),
68 | memory_key=kwargs.get("memory_key", "chat_history"),
69 | human_prefix=kwargs.get("human_prefix", "Human"),
70 | ai_prefix=kwargs.get("ai_prefix", "AI"),
71 | return_messages=kwargs.get("return_messages", True),
72 | )
73 | elif memory_type == "ConversationSummaryBufferMemory":
74 |
75 | # Maintains a summary of previous messages
76 | # ConversationSummaryBufferMemory combines the two ideas.
77 | # It keeps a buffer of recent interactions in memory, but rather than just completely flushing old interactions it compiles them into a summary and uses both.
78 | # It uses token length rather than number of interactions to determine when to flush interactions.
79 |
80 | assert kwargs.get("llm", None) != None, "Give your LLM"
81 | memory = ConversationSummaryBufferMemory(
82 | llm=kwargs.get("llm", None),
83 | memory_key=kwargs.get("memory_key", "chat_history"),
84 | human_prefix=kwargs.get("human_prefix", "User"),
85 | ai_prefix=kwargs.get("ai_prefix", "AI"),
86 | return_messages=kwargs.get("return_messages", True),
87 | max_token_limit=kwargs.get("max_token_limit", 1024),
88 | prompt=cls.PROMPT_SUMMARY
89 | )
90 |
91 | return memory
92 |
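    # Usage sketch (illustrative): a sliding-window memory that keeps the last 5 turns
    # memory = chat_utils.get_memory(memory_type="ConversationBufferWindowMemory", k=5)
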
93 | @classmethod
94 | def get_tokens(cls, chain, prompt):
95 | token = chain.llm.get_num_tokens(prompt)
96 | print(f'# tokens: {token}')
97 | return token
98 |
99 | @classmethod
100 | def clear_memory(cls, chain):
101 | return chain.memory.clear()
102 |
103 | class ChatUX:
104 | """ A chat UX using IPWidgets
105 | """
106 | def __init__(self, qa, retrievalChain = False):
107 | self.qa = qa
108 | self.name = None
109 | self.b=None
110 | self.retrievalChain = retrievalChain
111 | self.out = ipw.Output()
112 |
113 | if "ConversationChain" in str(type(self.qa)):
114 | self.streaming = self.qa.llm.streaming
115 | elif "ConversationalRetrievalChain" in str(type(self.qa)):
116 | self.streaming = self.qa.combine_docs_chain.llm_chain.llm.streaming
117 |
118 | def start_chat(self):
119 | print("Starting chat bot")
120 | display(self.out)
121 | self.chat(None)
122 |
123 | def chat(self, _):
124 | if self.name is None:
125 | prompt = ""
126 | else:
127 | prompt = self.name.value
128 |         if prompt in ('q', 'Q', 'quit'):
129 |             print("Thank you, that was a nice chat!!")
130 | return
131 | elif len(prompt) > 0:
132 | with self.out:
133 | thinking = ipw.Label(value="Thinking...")
134 | display(thinking)
135 | try:
136 | if self.retrievalChain:
137 | result = self.qa.run({'question': prompt })
138 | else:
139 | result = self.qa.run({'input': prompt }) #, 'history':chat_history})
140 |             except Exception:
141 | result = "No answer"
142 | thinking.value=""
143 | if self.streaming:
144 | response = f"AI:{result}"
145 | else:
146 | print_ww(f"AI:{result}")
147 | self.name.disabled = True
148 | self.b.disabled = True
149 | self.name = None
150 |
151 | if self.name is None:
152 | with self.out:
153 | self.name = ipw.Text(description="You:", placeholder='q to quit')
154 | self.b = ipw.Button(description="Send")
155 | self.b.on_click(self.chat)
156 | display(ipw.Box(children=(self.name, self.b)))
--------------------------------------------------------------------------------
/application/utils/chunk.py:
--------------------------------------------------------------------------------
1 | from langchain.docstore.document import Document
2 | from langchain.text_splitter import RecursiveCharacterTextSplitter
3 |
4 |
5 | class parant_documents():
6 |
7 | @classmethod
8 | def _create_chunk(cls, docs, chunk_size, chunk_overlap):
9 |
10 | '''
11 | docs: list of docs
12 | chunk_size: int
13 | chunk_overlap: int
14 | return: list of chunk_docs
15 | '''
16 |
17 | text_splitter = RecursiveCharacterTextSplitter(
18 |             # chunk size and overlap are supplied by the caller
19 | chunk_size=chunk_size,
20 | chunk_overlap=chunk_overlap,
21 | separators=["\n\n", "\n", ".", " ", ""],
22 | length_function=len,
23 | )
24 | # print("doc: in create_chunk", docs )
25 | chunk_docs = text_splitter.split_documents(docs)
26 |
27 | return chunk_docs
28 |
29 | @classmethod
30 | def create_parent_chunk(cls, docs, parent_id_key, family_tree_id_key, parent_chunk_size, parent_chunk_overlap):
31 |
32 | parent_chunks = cls._create_chunk(docs, parent_chunk_size, parent_chunk_overlap)
33 | for i, doc in enumerate(parent_chunks):
34 | doc.metadata[family_tree_id_key] = 'parent'
35 | doc.metadata[parent_id_key] = None
36 |
37 | return parent_chunks
38 |
39 | @classmethod
40 | def create_child_chunk(cls, child_chunk_size, child_chunk_overlap, docs, parent_ids_value, parent_id_key, family_tree_id_key):
41 |
42 | sub_docs = []
43 | for i, doc in enumerate(docs):
44 | # print("doc: ", doc)
45 | parent_id = parent_ids_value[i]
46 | doc = [doc]
47 | _sub_docs = cls._create_chunk(doc, child_chunk_size, child_chunk_overlap)
48 | for _doc in _sub_docs:
49 | _doc.metadata[family_tree_id_key] = 'child'
50 | _doc.metadata[parent_id_key] = parent_id
51 | sub_docs.extend(_sub_docs)
52 |
53 | # if i == 0:
54 | # return sub_docs
55 |
56 | return sub_docs
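
# Usage sketch (illustrative sizes; 'docs' is a list of LangChain Documents and
# 'parent_ids' holds the index ids assigned to the parent chunks):
# parents = parant_documents.create_parent_chunk(docs, "parent_id", "family_tree", 1024, 64)
# children = parant_documents.create_child_chunk(256, 32, parents, parent_ids, "parent_id", "family_tree")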
--------------------------------------------------------------------------------
/application/utils/common_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import pickle
3 | import random
4 | import logging
5 | import functools
6 | from IPython.display import Markdown, HTML, display
7 |
8 | logging.basicConfig()
9 | logger = logging.getLogger('retry-bedrock-invocation')
10 | logger.setLevel(logging.INFO)
11 |
12 | def retry(total_try_cnt=5, sleep_in_sec=5, retryable_exceptions=()):
13 | def decorator(func):
14 | @functools.wraps(func)
15 | def wrapper(*args, **kwargs):
16 | for cnt in range(total_try_cnt):
17 | logger.info(f"trying {func.__name__}() [{cnt+1}/{total_try_cnt}]")
18 |
19 | try:
20 | result = func(*args, **kwargs)
21 | logger.info(f"in retry(), {func.__name__}() returned '{result}'")
22 |
23 | if result: return result
24 | except retryable_exceptions as e:
25 | logger.info(f"in retry(), {func.__name__}() raised retryable exception '{e}'")
26 | pass
27 | except Exception as e:
28 | logger.info(f"in retry(), {func.__name__}() raised {e}")
29 | raise e
30 |
31 | time.sleep(sleep_in_sec)
32 |             logger.info(f"{func.__name__} has finally failed after {total_try_cnt} attempts")
33 | return wrapper
34 | return decorator
35 |
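# Usage sketch (illustrative; pick exception types that match your client):
# @retry(total_try_cnt=5, sleep_in_sec=5, retryable_exceptions=(ConnectionError, TimeoutError))
# def invoke_model():
#     ...
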
36 | def to_pickle(obj, path):
37 |
38 | with open(file=path, mode="wb") as f:
39 | pickle.dump(obj, f)
40 |
41 | print (f'To_PICKLE: {path}')
42 |
43 | def load_pickle(path):
44 |
45 | with open(file=path, mode="rb") as f:
46 | obj=pickle.load(f)
47 |
48 | print (f'Load from {path}')
49 |
50 | return obj
51 |
52 | def to_markdown(obj, path):
53 |
54 | with open(file=path, mode="w") as f:
55 | f.write(obj)
56 |
57 | print (f'To_Markdown: {path}')
58 |
59 | def print_html(input_html):
60 |
61 | html_string=""
62 | html_string = html_string + input_html
63 |
64 | display(HTML(html_string))
65 |
66 |
67 |
--------------------------------------------------------------------------------
/application/utils/opensearch_summit.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import List, Tuple
3 | from opensearchpy import OpenSearch, RequestsHttpConnection
4 | class opensearch_utils():
5 | @classmethod
6 | def create_aws_opensearch_client(cls, region: str, host: str, http_auth: Tuple[str, str]) -> OpenSearch:
7 | client = OpenSearch(
8 | hosts=[
9 | {'host': host.replace("https://", ""),
10 | 'port': 443
11 | }
12 | ],
13 | http_auth=http_auth,
14 | use_ssl=True,
15 | verify_certs=True,
16 | connection_class=RequestsHttpConnection
17 | )
18 | return client
19 | @classmethod
20 | def create_index(cls, os_client, index_name, index_body):
21 | '''
22 |         Create an index.
23 | '''
24 | response = os_client.indices.create(
25 | index_name,
26 | body=index_body
27 | )
28 | print('\nCreating index:')
29 | print(response)
30 | @classmethod
31 | def check_if_index_exists(cls, os_client, index_name):
32 | '''
33 |         Check whether the index exists.
34 | '''
35 | exists = os_client.indices.exists(index_name)
36 | print(f"index_name={index_name}, exists={exists}")
37 | return exists
38 | @classmethod
39 | def add_doc(cls, os_client, index_name, document, id):
40 | '''
41 | # Add a document to the index.
42 | '''
43 | response = os_client.index(
44 | index = index_name,
45 | body = document,
46 | id = id,
47 | refresh = True
48 | )
49 | print('\nAdding document:')
50 | print(response)
51 | @classmethod
52 | def search_document(cls, os_client, query, index_name):
53 | response = os_client.search(
54 | body=query,
55 | index=index_name
56 | )
57 | #print('\nKeyword Search results:')
58 | return response
59 | @classmethod
60 | def delete_index(cls, os_client, index_name):
61 | response = os_client.indices.delete(
62 | index=index_name
63 | )
64 | print('\nDeleting index:')
65 | print(response)
66 | @classmethod
67 | def parse_keyword_response(cls, response, show_size=3):
68 | '''
69 |         Display keyword search results.
70 | '''
71 | length = len(response['hits']['hits'])
72 | if length >= 1:
73 | print("# of searched docs: ", length)
74 | print(f"# of display: {show_size}")
75 | print("---------------------")
76 | for idx, doc in enumerate(response['hits']['hits']):
77 | print("_id in index: " , doc['_id'])
78 | print(doc['_score'])
79 | print(doc['_source']['text'])
80 | print(doc['_source']['metadata'])
81 | print("---------------------")
82 | if idx == show_size-1:
83 | break
84 | else:
85 | print("There is no response")
86 | @classmethod
87 | def opensearch_pretty_print_documents(cls, response):
88 | '''
89 |         Parse the list of results returned by OpenSearch.
90 | '''
91 | for doc, score in response:
92 | print(f'\nScore: {score}')
93 | print(f'Document Number: {doc.metadata["row"]}')
94 | # Split the page content into lines
95 | lines = doc.page_content.split("\n")
96 | # Extract and print each piece of information if it exists
97 | for line in lines:
98 | split_line = line.split(": ")
99 | if len(split_line) > 1:
100 | print(f'{split_line[0]}: {split_line[1]}')
101 | print("Metadata:")
102 | print(f'Type: {doc.metadata["type"]}')
103 | print(f'Source: {doc.metadata["source"]}')
104 | print('-' * 50)
105 | @classmethod
106 | def get_document(cls, os_client, doc_id, index_name):
107 | response = os_client.get(
108 | id= doc_id,
109 | index=index_name
110 | )
111 | return response
112 | @classmethod
113 | def get_count(cls, os_client, index_name):
114 | response = os_client.count(
115 | index=index_name
116 | )
117 | return response
118 | @classmethod
119 | def get_query(cls, **kwargs):
120 | # Reference:
121 |         # OpenSearch boolean query:
122 |         # - https://opensearch.org/docs/latest/query-dsl/compound/bool/
123 |         # OpenSearch match query:
124 |         # - https://opensearch.org/docs/latest/query-dsl/full-text/index/#match-boolean-prefix
125 |         # OpenSearch query description (Korean):
126 |         # - https://esbook.kimjmin.net/05-search
127 | search_type = kwargs.get("search_type", "lexical")
128 | if search_type == "lexical":
129 |             min_should_match = 0
130 |             if "minimum_should_match" in kwargs:
131 |                 min_should_match = kwargs["minimum_should_match"]
132 | QUERY_TEMPLATE = {
133 | "query": {
134 | "bool": {
135 | "must": [
136 | {
137 | "match": {
138 | "text": {
139 | "query": f'{kwargs["query"]}',
140 |                             "minimum_should_match": f'{min_should_match}%',
141 | "operator": "or",
142 | # "fuzziness": "AUTO",
143 | # "fuzzy_transpositions": True,
144 | # "zero_terms_query": "none",
145 | # "lenient": False,
146 | # "prefix_length": 0,
147 | # "max_expansions": 50,
148 | # "boost": 1
149 | }
150 | }
151 | },
152 | ],
153 | "filter": [
154 | ]
155 | }
156 | }
157 | }
158 | if "filter" in kwargs:
159 | QUERY_TEMPLATE["query"]["bool"]["filter"].extend(kwargs["filter"])
160 | elif search_type == "semantic":
161 | QUERY_TEMPLATE = {
162 | "query": {
163 | "bool": {
164 | "must": [
165 | {
166 | "knn": {
167 | kwargs["vector_field"]: {
168 | "vector": kwargs["vector"],
169 | "k": kwargs["k"],
170 | }
171 | }
172 | },
173 | ],
174 | "filter": [
175 | ]
176 | }
177 | }
178 | }
179 | if "filter" in kwargs:
180 | QUERY_TEMPLATE["query"]["bool"]["filter"].extend(kwargs["filter"])
181 | return QUERY_TEMPLATE
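
    # Usage sketch (illustrative): a lexical query where 30% of the terms must match
    # q = opensearch_utils.get_query(search_type="lexical", query="school curriculum", minimum_should_match=30)
    # response = opensearch_utils.search_document(os_client, q, index_name="default_doc_index")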
182 | @classmethod
183 | def get_filter(cls, **kwargs):
184 | BOOL_FILTER_TEMPLATE = {
185 | "bool": {
186 | "filter": [
187 | ]
188 | }
189 | }
190 | if "filter" in kwargs:
191 | BOOL_FILTER_TEMPLATE["bool"]["filter"].extend(kwargs["filter"])
192 | return BOOL_FILTER_TEMPLATE
193 | @staticmethod
194 | def get_documents_by_ids(os_client, ids, index_name):
195 | response = os_client.mget(
196 | body={"ids": ids},
197 | index=index_name
198 | )
199 | return response
200 | @staticmethod
201 | def opensearch_pretty_print_documents_with_score(doctype, response):
202 | '''
203 |         Parse the list of results returned by OpenSearch, keeping scores.
204 | '''
205 | results = []
206 | responses = copy.deepcopy(response)
207 | for doc, score in responses:
208 | result = {}
209 | result['doctype'] = doctype
210 | result['score'] = score
211 | result['lines'] = doc.page_content.split("\n")
212 | result['meta'] = doc.metadata
213 | results.append(result)
214 | return results
--------------------------------------------------------------------------------
/application/utils/pymupdf.py:
--------------------------------------------------------------------------------
1 | """
2 | This script accepts a PDF document filename and converts it to a text file
3 | in Markdown format, compatible with the GitHub standard.
4 |
5 | It must be invoked with the filename like this:
6 |
7 | python pymupdf_rag.py input.pdf [-pages PAGES]
8 |
9 | The "PAGES" parameter is a string (containing no spaces) of comma-separated
10 | page numbers to consider. Each item is either a single page number or a
11 | number range "m-n". Use "N" to address the document's last page number.
12 | Example: "-pages 2-15,40,43-N"
13 |
14 | It will produce a markdown text file called "input.md".
15 |
16 | Text will be sorted in Western reading order. Any table will be included in
17 | the text in markdown format as well.
18 |
19 | Use in some other script
20 | -------------------------
21 | import fitz
22 | from to_markdown import to_markdown
23 |
24 | doc = fitz.open("input.pdf")
25 | page_list = [ list of 0-based page numbers ]
26 | md_text = to_markdown(doc, pages=page_list)
27 |
28 | Dependencies
29 | -------------
30 | PyMuPDF v1.24.0 or later
31 |
32 | Copyright and License
33 | ----------------------
34 | Copyright 2024 Artifex Software, Inc.
35 | License GNU Affero GPL 3.0
36 | """
37 |
38 | import string
39 | from pprint import pprint
40 |
41 | import fitz
42 |
43 | if fitz.pymupdf_version_tuple < (1, 24, 0):
44 | raise NotImplementedError("PyMuPDF version 1.24.0 or later is needed.")
45 |
46 |
47 | def to_markdown_pymupdf(doc: fitz.Document, pages: list = None) -> str:
48 | """Process the document and return the text of its selected pages."""
49 | SPACES = set(string.whitespace) # used to check relevance of text pieces
50 | if not pages: # use all pages if argument not given
51 | pages = range(doc.page_count)
52 |
53 | class IdentifyHeaders:
54 | """Compute data for identifying header text."""
55 |
56 | def __init__(self, doc, pages: list = None, body_limit: float = None):
57 | """Read all text and make a dictionary of fontsizes.
58 |
59 | Args:
60 | pages: optional list of pages to consider
61 | body_limit: consider text with larger font size as some header
62 | """
63 | if pages is None: # use all pages if omitted
64 | pages = range(doc.page_count)
65 | fontsizes = {}
66 | for pno in pages:
67 | page = doc[pno]
68 | blocks = page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
69 | for span in [ # look at all non-empty horizontal spans
70 | s
71 | for b in blocks
72 | for l in b["lines"]
73 | for s in l["spans"]
74 | if not SPACES.issuperset(s["text"])
75 | ]:
76 | fontsz = round(span["size"])
77 | count = fontsizes.get(fontsz, 0) + len(span["text"].strip())
78 | fontsizes[fontsz] = count
79 |
80 | # maps a fontsize to a string of multiple # header tag characters
81 | self.header_id = {}
82 | if body_limit is None: # body text fontsize if not provided
83 | body_limit = sorted(
84 | [(k, v) for k, v in fontsizes.items()],
85 | key=lambda i: i[1],
86 | reverse=True,
87 | )[0][0]
88 |
89 | sizes = sorted(
90 | [f for f in fontsizes.keys() if f > body_limit], reverse=True
91 | )
92 |
93 | # make the header tag dictionary
94 | for i, size in enumerate(sizes):
95 | self.header_id[size] = "#" * (i + 1) + " "
96 |
97 | def get_header_id(self, span):
98 | """Return appropriate markdown header prefix.
99 |
100 | Given a text span from a "dict"/"radict" extraction, determine the
101 | markdown header prefix string of 0 to many concatenated '#' characters.
102 | """
103 | fontsize = round(span["size"]) # compute fontsize
104 | hdr_id = self.header_id.get(fontsize, "")
105 | return hdr_id
106 |
107 | def resolve_links(links, span):
108 | """Accept a span bbox and return a markdown link string."""
109 | bbox = fitz.Rect(span["bbox"]) # span bbox
110 | # a link should overlap at least 70% of the span
111 | bbox_area = 0.7 * abs(bbox)
112 | for link in links:
113 | hot = link["from"] # the hot area of the link
114 | if not abs(hot & bbox) >= bbox_area:
115 | continue # does not touch the bbox
116 | text = f'[{span["text"].strip()}]({link["uri"]})'
117 | return text
118 |
119 | def write_text(page, clip, hdr_prefix):
120 | """Output the text found inside the given clip.
121 |
122 | This is an alternative for plain text in that it outputs
123 | text enriched with markdown styling.
124 | The logic is capable of recognizing headers, body text, code blocks,
125 | inline code, bold, italic and bold-italic styling.
126 |         There is also some effort to support lists (ordered / unordered), in
127 | that typical characters are replaced by respective markdown characters.
128 | """
129 | out_string = ""
130 | code = False # mode indicator: outputting code
131 |
132 | # extract URL type links on page
133 | links = [l for l in page.get_links() if l["kind"] == 2]
134 |
135 | blocks = page.get_text(
136 | "dict",
137 | clip=clip,
138 | flags=fitz.TEXTFLAGS_TEXT,
139 | sort=True,
140 | )["blocks"]
141 |
142 | for block in blocks: # iterate textblocks
143 | previous_y = 0
144 | for line in block["lines"]: # iterate lines in block
145 | if line["dir"][1] != 0: # only consider horizontal lines
146 | continue
147 | spans = [s for s in line["spans"]]
148 |
149 | this_y = line["bbox"][3] # current bottom coord
150 |
151 | # check for still being on same line
152 | same_line = abs(this_y - previous_y) <= 3 and previous_y > 0
153 |
154 | if same_line and out_string.endswith("\n"):
155 | out_string = out_string[:-1]
156 |
157 | # are all spans in line in a mono-spaced font?
158 | all_mono = all([s["flags"] & 8 for s in spans])
159 |
160 | # compute text of the line
161 | text = "".join([s["text"] for s in spans])
162 | if not same_line:
163 | previous_y = this_y
164 | if not out_string.endswith("\n"):
165 | out_string += "\n"
166 |
167 | if all_mono:
168 | # compute approx. distance from left - assuming a width
169 | # of 0.5*fontsize.
170 | delta = int(
171 | (spans[0]["bbox"][0] - block["bbox"][0])
172 | / (spans[0]["size"] * 0.5)
173 | )
174 | if not code: # if not already in code output mode:
175 | out_string += "```" # switch on "code" mode
176 | code = True
177 | if not same_line: # new code line with left indentation
178 | out_string += "\n" + " " * delta + text + " "
179 | previous_y = this_y
180 | else: # same line, simply append
181 | out_string += text + " "
182 | continue # done with this line
183 |
184 | for i, s in enumerate(spans): # iterate spans of the line
185 | # this line is not all-mono, so switch off "code" mode
186 | if code: # still in code output mode?
187 |                         out_string += "```\n"  # switch off code mode
188 | code = False
189 | # decode font properties
190 | mono = s["flags"] & 8
191 | bold = s["flags"] & 16
192 | italic = s["flags"] & 2
193 |
194 | if mono:
195 | # this is text in some monospaced font
196 | out_string += f"`{s['text'].strip()}` "
197 | else: # not a mono text
198 | # for first span, get header prefix string if present
199 | if i == 0:
200 | hdr_string = hdr_prefix.get_header_id(s)
201 | else:
202 | hdr_string = ""
203 | prefix = ""
204 | suffix = ""
205 | if hdr_string == "":
206 | if bold:
207 | prefix = "**"
208 | suffix += "**"
209 | if italic:
210 | prefix += "_"
211 | suffix = "_" + suffix
212 |
213 | ltext = resolve_links(links, s)
214 | if ltext:
215 | text = f"{hdr_string}{prefix}{ltext}{suffix} "
216 | else:
217 | text = f"{hdr_string}{prefix}{s['text'].strip()}{suffix} "
218 | text = (
219 |                         text.replace("<", "&lt;")
220 |                         .replace(">", "&gt;")
221 | .replace(chr(0xF0B7), "-")
222 | .replace(chr(0xB7), "-")
223 | .replace(chr(8226), "-")
224 | .replace(chr(9679), "-")
225 | )
226 | out_string += text
227 | previous_y = this_y
228 | if not code:
229 | out_string += "\n"
230 | out_string += "\n"
231 | if code:
232 |         out_string += "```\n"  # switch off code mode
233 | code = False
234 | return out_string.replace(" \n", "\n")
235 |
236 | hdr_prefix = IdentifyHeaders(doc, pages=pages)
237 | md_string = ""
238 |
239 | for pno in pages:
240 | page = doc[pno]
241 | # 1. first locate all tables on page
242 | tabs = page.find_tables()
243 |
244 | # 2. make a list of table boundary boxes, sort by top-left corner.
245 | # Must include the header bbox, which may be external.
246 | tab_rects = sorted(
247 | [
248 | (fitz.Rect(t.bbox) | fitz.Rect(t.header.bbox), i)
249 | for i, t in enumerate(tabs.tables)
250 | ],
251 | key=lambda r: (r[0].y0, r[0].x0),
252 | )
253 |
254 | # 3. final list of all text and table rectangles
255 | text_rects = []
256 | # compute rectangles outside tables and fill final rect list
257 | for i, (r, idx) in enumerate(tab_rects):
258 | if i == 0: # compute rect above all tables
259 | tr = page.rect
260 | tr.y1 = r.y0
261 | if not tr.is_empty:
262 | text_rects.append(("text", tr, 0))
263 | text_rects.append(("table", r, idx))
264 | continue
265 | # read previous rectangle in final list: always a table!
266 | _, r0, idx0 = text_rects[-1]
267 |
268 | # check if a non-empty text rect is fitting in between tables
269 | tr = page.rect
270 | tr.y0 = r0.y1
271 | tr.y1 = r.y0
272 | if not tr.is_empty: # empty if two tables overlap vertically!
273 | text_rects.append(("text", tr, 0))
274 |
275 | text_rects.append(("table", r, idx))
276 |
277 | # there may also be text below all tables
278 | if i == len(tab_rects) - 1:
279 | tr = page.rect
280 | tr.y0 = r.y1
281 | if not tr.is_empty:
282 | text_rects.append(("text", tr, 0))
283 |
284 | if not text_rects: # this will happen for table-free pages
285 | text_rects.append(("text", page.rect, 0))
286 | else:
287 | rtype, r, idx = text_rects[-1]
288 | if rtype == "table":
289 | tr = page.rect
290 | tr.y0 = r.y1
291 | if not tr.is_empty:
292 | text_rects.append(("text", tr, 0))
293 |
294 | # we have all rectangles and can start outputting their contents
295 | for rtype, r, idx in text_rects:
296 | if rtype == "text": # a text rectangle
297 | md_string += write_text(page, r, hdr_prefix) # write MD content
298 | md_string += "\n"
299 | else: # a table rect
300 | md_string += tabs[idx].to_markdown(clean=False)
301 |
302 | md_string += "\n-----\n\n"
303 |
304 | return md_string
305 |
306 |
307 | if __name__ == "__main__":
308 | import os
309 | import sys
310 | import time
311 |
312 | try:
313 | filename = sys.argv[1]
314 | except IndexError:
315 | print(f"Usage:\npython {os.path.basename(__file__)} input.pdf")
316 | sys.exit()
317 |
318 |     t0 = time.perf_counter()  # start a timer
319 |
320 | doc = fitz.open(filename) # open input file
321 | parms = sys.argv[2:] # contains ["-pages", "PAGES"] or empty list
322 | pages = range(doc.page_count) # default page range
323 | if len(parms) == 2 and parms[0] == "-pages": # page sub-selection given
324 | pages = [] # list of desired page numbers
325 |
326 | # replace any variable "N" by page count
327 | pages_spec = parms[1].replace("N", f"{doc.page_count}")
328 | for spec in pages_spec.split(","):
329 | if "-" in spec:
330 | start, end = map(int, spec.split("-"))
331 | pages.extend(range(start - 1, end))
332 | else:
333 | pages.append(int(spec) - 1)
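334 |         # e.g. "-pages 1-3,5,N" on a 10-page file selects the 0-based page
335 |         # numbers [0, 1, 2, 4, 9] ("N" stands for the page count).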
334 |
335 | # make a set of invalid page numbers
336 | wrong_pages = set([n + 1 for n in pages if n >= doc.page_count][:4])
337 |     if wrong_pages:  # if any invalid numbers were given, exit
338 | sys.exit(f"Page number(s) {wrong_pages} not in '{doc}'.")
339 |
340 | # get the markdown string
341 |     md_string = to_markdown_pymupdf(doc, pages=pages)
342 |
343 |     # output to a text file with extension ".md"
344 |     with open(doc.name.replace(".pdf", ".md"), "w") as out:
345 |         out.write(md_string)
346 | 
347 | t1 = time.perf_counter() # stop timer
348 | print(f"Markdown creation time for {doc.name=} {round(t1-t0,2)} sec.")
349 |
--------------------------------------------------------------------------------
/application/utils/rag_summit.py:
--------------------------------------------------------------------------------
1 | ############################################################
2 | ############################################################
3 | # RAG-related functions
4 | ############################################################
5 | ############################################################
6 |
7 | import json
8 | import copy
9 | import boto3
10 | import numpy as np
11 | import pandas as pd
12 | from copy import deepcopy
13 | from pprint import pprint
14 | from operator import itemgetter
15 | from itertools import chain as ch
16 | from typing import Any, Dict, List, Optional, Tuple
17 | from opensearchpy import OpenSearch, RequestsHttpConnection
18 |
19 | import base64
20 | from PIL import Image
21 | from io import BytesIO
22 | import matplotlib.pyplot as plt
23 |
24 | from utils import print_ww
25 | from utils.opensearch_summit import opensearch_utils
26 | from utils.common_utils import print_html
27 |
28 | from langchain.schema import Document
29 | from langchain.chains import RetrievalQA
30 | from langchain.schema import BaseRetriever
31 | from langchain.prompts import PromptTemplate
32 | from langchain.retrievers import AmazonKendraRetriever
33 | from langchain_core.tracers import ConsoleCallbackHandler
34 | from langchain.schema.output_parser import StrOutputParser
35 | from langchain.embeddings import SagemakerEndpointEmbeddings
36 | from langchain.text_splitter import RecursiveCharacterTextSplitter
37 | from langchain.callbacks.manager import CallbackManagerForRetrieverRun
38 | from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
39 | from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
40 | from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
41 |
42 | import threading
43 | from functools import partial
44 | from multiprocessing.pool import ThreadPool
45 | #pool = ThreadPool(processes=2)
46 | #rag_fusion_pool = ThreadPool(processes=5)
47 |
48 | ############################################################
49 | # Prompt repo
50 | ############################################################
51 | class prompt_repo():
52 |
53 | template_types = ["web_search", "sci_fact", "fiqa", "trec_news"]
54 | prompt_types = ["answer_only", "answer_with_ref", "original", "ko_answer_only"]
55 |
56 |
57 | #First, find the paragraphs or sentences from the context that are most relevant to answering the question,
58 | #Then, answer the question within XML tags as much as you can.
59 | # Answer the question within XML tags as much as you can.
60 | # Don't say "According to context" when answering.
61 | # Don't insert XML tag such as and when answering.
62 |
63 |
64 | @classmethod
65 | def get_system_prompt(cls, ):
66 |
67 | system_prompt = '''
68 | You are a master answer bot designed to answer user's questions.
69 | I'm going to give you contexts which consist of texts, tables and images.
70 | Read the contexts carefully, because I'm going to ask you a question about it.
71 | '''
72 | return system_prompt
73 |
74 | @classmethod
75 | def get_human_prompt(cls, images=None, tables=None):
76 |
77 | human_prompt = []
78 |
79 | image_template = {
80 | "type": "image_url",
81 | "image_url": {
82 | "url": "data:image/png;base64," + "IMAGE_BASE64",
83 | },
84 | }
85 | text_template = {
86 | "type": "text",
87 | "text": '''
88 | Here is the contexts as texts: {contexts}
89 | TABLE_PROMPT
90 |
91 | First, find a few paragraphs or sentences from the contexts that are most relevant to answering the question.
92 | Then, answer the question as much as you can.
93 |
94 | Skip the preamble and go straight into the answer.
95 | Don't insert any XML tag such as and when answering.
96 | Answer in Korean.
97 |
98 | Here is the question: {question}
99 |
100 | If the question cannot be answered by the contexts, say "No relevant contexts".
101 | '''
102 | }
103 |
104 | table_prompt = '''
105 | Here is the contexts as tables (table as text): {tables_text}
106 | Here is the contexts as tables (table as html): {tables_html}
107 | '''
108 |         if tables is not None:
109 |             text_template["text"] = text_template["text"].replace("TABLE_PROMPT", table_prompt)
110 |             for table in tables:
111 |                 #if table.metadata["image_base64"]:
112 |                 if "image_base64" in table.metadata:
113 |                     # copy the template so each table image gets its own message part
114 |                     table_image = deepcopy(image_template)
115 |                     table_image["image_url"]["url"] = table_image["image_url"]["url"].replace("IMAGE_BASE64", table.metadata["image_base64"])
116 |                     human_prompt.append(table_image)
117 |         else: text_template["text"] = text_template["text"].replace("TABLE_PROMPT", "")
116 |
117 |         if images is not None:
118 |             for image in images:
119 |                 # copy the template so each image gets its own message part
120 |                 image_part = deepcopy(image_template)
121 |                 image_part["image_url"]["url"] = image_part["image_url"]["url"].replace("IMAGE_BASE64", image.page_content)
122 |                 human_prompt.append(image_part)
121 |
122 | human_prompt.append(text_template)
123 |
124 | return human_prompt
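125 |     # The returned list follows the multimodal chat-message "content" format:
126 |     # zero or more base64 PNG image parts, followed by a single text part
127 |     # that carries the contexts, instructions and the question.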
125 |
126 | # @classmethod
127 | # def get_qa(cls, prompt_type="answer_only"):
128 |
129 | # assert prompt_type in cls.prompt_types, "Check your prompt_type"
130 |
131 | # if prompt_type == "answer_only":
132 |
133 | # prompt = """
134 | # \n\nHuman:
135 | # You are a master answer bot designed to answer software developer's questions.
136 | # I'm going to give you a context. Read the context carefully, because I'm going to ask you a question about it.
137 |
138 | # Here is the context: {context}
139 |
140 | # First, find a few paragraphs or sentences from the context that are most relevant to answering the question.
141 | # Then, answer the question as much as you can.
142 |
143 | # Skip the preamble and go straight into the answer.
144 | # Don't insert any XML tag such as and when answering.
145 |
146 | # Here is the question: {question}
147 |
148 | # If the question cannot be answered by the context, say "No relevant context".
149 | # \n\nAssistant: Here is the answer. """
150 |
151 | # elif prompt_type == "answer_with_ref":
152 |
153 | # prompt = """
154 | # \n\nHuman:
155 | # You are a master answer bot designed to answer software developer's questions.
156 | # I'm going to give you a context. Read the context carefully, because I'm going to ask you a question about it.
157 |
158 | # Here is the context: {context}
159 |
160 | # First, find the paragraphs or sentences from the context that are most relevant to answering the question, and then print them in numbered order.
161 | # The format of paragraphs or sentences to the question should look like what's shown between the tags.
162 | # Make sure to follow the formatting and spacing exactly.
163 |
164 | #
165 | # [Examples of question + answer pairs using parts of the given context, with answers written exactly like how Claude’s output should be structured]
166 | #
167 |
168 | # If there are no relevant paragraphs or sentences, write "No relevant context" instead.
169 |
170 | # Then, answer the question within XML tags.
171 | # Answer as much as you can.
172 | # Skip the preamble and go straight into the answer.
173 | # Don't say "According to context" when answering.
174 | # Don't insert XML tag such as and when answering.
175 | # If needed, answer using bulleted format.
176 | # If relevant paragraphs or sentences have code block, please show us that as code block.
177 |
178 | # Here is the question: {question}
179 |
180 | # If the question cannot be answered by the context, say "No relevant context".
181 |
182 | # \n\nAssistant: Here is the most relevant sentence in the context:"""
183 |
184 | # elif prompt_type == "original":
185 | # prompt = """
186 | # \n\nHuman: Here is the context, inside XML tags.
187 |
188 | #
189 | # {context}
190 | #
191 |
192 | # Only using the context as above, answer the following question with the rules as below:
193 | # - Don't insert XML tag such as and when answering.
194 | # - Write as much as you can
195 | # - Be courteous and polite
196 | # - Only answer the question if you can find the answer in the context with certainty.
197 |
198 | # Question:
199 | # {question}
200 |
201 | # If the answer is not in the context, just say "I don't know"
202 | # \n\nAssistant:"""
203 |
204 | # if prompt_type == "ko_answer_only":
205 |
206 | # prompt = """
207 | # \n\nHuman:
208 | # You are a master answer bot designed to answer software developer's questions.
209 | # I'm going to give you a context. Read the context carefully, because I'm going to ask you a question about it.
210 |
211 | # Here is the context: {context}
212 |
213 | # First, find a few paragraphs or sentences from the context that are most relevant to answering the question.
214 | # Then, answer the question as much as you can.
215 |
216 | # Skip the preamble and go straight into the answer.
217 | # Don't insert any XML tag such as and when answering.
218 |
219 | # Here is the question: {question}
220 |
221 | # Answer in Korean.
222 | # If the question cannot be answered by the context, say "No relevant context".
223 | # \n\nAssistant: Here is the answer. """
224 |
225 | # prompt_template = PromptTemplate(
226 | # template=prompt, input_variables=["context", "question"]
227 | # )
228 |
229 | # return prompt_template
230 |
231 | @staticmethod
232 | def get_rag_fusion():
233 |
234 | system_prompt = """
235 |         You are a helpful assistant that generates multiple search queries that are semantically similar to a single input query.
236 | Skip the preamble and generate in Korean.
237 | """
238 | human_prompt = """
239 | Generate multiple search queries related to: {query}
240 | OUTPUT ({query_augmentation_size} queries):
241 | """
242 |
243 | system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)
244 | human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
245 | prompt = ChatPromptTemplate.from_messages(
246 | [system_message_template, human_message_template]
247 | )
248 |
249 | return prompt
250 |
251 | @classmethod
252 | def get_hyde(cls, template_type):
253 |
254 | assert template_type in cls.template_types, "Check your template_type"
255 |
256 | system_prompt = """
257 | You are a master answer bot designed to answer user's questions.
258 | """
259 | human_prompt = """
260 | Here is the question: {query}
261 |
262 | HYDE_TEMPLATE
263 | Skip the preamble and generate in Korean.
264 | """
265 |
266 |
267 | # There are a few different templates to choose from
268 | # These are just different ways to generate hypothetical documents
269 | hyde_template = {
270 | "web_search": "Please write a concise passage to answer the question.",
271 | "sci_fact": "Please write a concise scientific paper passage to support/refute the claim.",
272 | "fiqa": "Please write a concise financial article passage to answer the question.",
273 | "trec_news": "Please write a concise news passage about the topic."
274 | }
275 | human_prompt = human_prompt.replace("HYDE_TEMPLATE", hyde_template[template_type])
276 |
277 | system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)
278 | human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
279 | prompt = ChatPromptTemplate.from_messages(
280 | [system_message_template, human_message_template]
281 | )
282 |
283 | return prompt
284 |
285 | ############################################################
286 | # RetrievalQA (Langchain)
287 | ############################################################
288 | pretty_contexts = None
289 | augmentation = None
290 |
291 | class qa_chain():
292 |
293 | def __init__(self, **kwargs):
294 |
295 | system_prompt = kwargs["system_prompt"]
296 | self.llm_text = kwargs["llm_text"]
297 | self.retriever = kwargs["retriever"]
298 | self.system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)
299 | self.return_context = kwargs.get("return_context", False)
300 | self.verbose = kwargs.get("verbose", False)
301 |
302 | def invoke(self, **kwargs):
303 | global pretty_contexts
304 | global augmentation
305 |
306 | query, verbose = kwargs["query"], kwargs.get("verbose", self.verbose)
307 | tables, images = None, None
308 | if self.retriever.complex_doc:
309 | #retrieval, tables, images = self.retriever.get_relevant_documents(query)
310 | retrieval, tables, images = self.retriever.invoke(query)
311 |
312 | invoke_args = {
313 | "contexts": "\n\n".join([doc.page_content for doc in retrieval]),
314 | "tables_text": "\n\n".join([doc.page_content for doc in tables]),
315 | "tables_html": "\n\n".join([doc.metadata["text_as_html"] if "text_as_html" in doc.metadata else "" for doc in tables]),
316 | "question": query
317 | }
318 | else:
319 | #retrieval = self.retriever.get_relevant_documents(query)
320 | retrieval = self.retriever.invoke(query)
321 | invoke_args = {
322 | "contexts": "\n\n".join([doc.page_content for doc in retrieval]),
323 | "question": query
324 | }
325 |
326 | human_prompt = prompt_repo.get_human_prompt(
327 | images=images,
328 | tables=tables
329 | )
330 | human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
331 | prompt = ChatPromptTemplate.from_messages(
332 | [self.system_message_template, human_message_template]
333 | )
334 |
335 | chain = prompt | self.llm_text | StrOutputParser()
336 |
337 | self.verbose = verbose
338 | response = chain.invoke(
339 | invoke_args,
340 | config={'callbacks': [ConsoleCallbackHandler()]} if self.verbose else {}
341 | )
342 |         if pretty_contexts is not None: pretty_contexts = tuple(pretty_contexts)
343 |
344 | return response, pretty_contexts, retrieval, augmentation
345 |
346 | def run_RetrievalQA(**kwargs):
347 |
348 | chain_types = ["stuff", "map_reduce", "refine"]
349 |
350 | assert "llm" in kwargs, "Check your llm"
351 | assert "query" in kwargs, "Check your query"
352 | assert "prompt" in kwargs, "Check your prompt"
353 | assert "vector_db" in kwargs, "Check your vector_db"
354 | assert kwargs.get("chain_type", "stuff") in chain_types, f'Check your chain_type, {chain_types}'
355 |
356 | qa = RetrievalQA.from_chain_type(
357 | llm=kwargs["llm"],
358 | chain_type=kwargs.get("chain_type", "stuff"),
359 | retriever=kwargs["vector_db"].as_retriever(
360 | search_type="similarity",
361 | search_kwargs={
362 | "k": kwargs.get("k", 5),
363 | "boolean_filter": opensearch_utils.get_filter(
364 | filter=kwargs.get("boolean_filter", [])
365 | ),
366 | }
367 | ),
368 | return_source_documents=True,
369 | chain_type_kwargs={
370 | "prompt": kwargs["prompt"],
371 | "verbose": kwargs.get("verbose", False),
372 | },
373 | verbose=kwargs.get("verbose", False)
374 | )
375 |
376 | return qa(kwargs["query"])
377 |
378 | def run_RetrievalQA_kendra(query, llm_text, PROMPT, kendra_index_id, k, aws_region, verbose):
379 | qa = RetrievalQA.from_chain_type(
380 | llm=llm_text,
381 | chain_type="stuff",
382 | retriever=AmazonKendraRetriever(
383 | index_id=kendra_index_id,
384 | region_name=aws_region,
385 | top_k=k,
386 | attribute_filter = {
387 | "EqualsTo": {
388 | "Key": "_language_code",
389 | "Value": {
390 | "StringValue": "ko"
391 | }
392 | },
393 | }
394 | ),
395 | return_source_documents=True,
396 | chain_type_kwargs={
397 | "prompt": PROMPT,
398 | "verbose": verbose,
399 | },
400 | verbose=verbose
401 | )
402 |
403 | result = qa(query)
404 |
405 | return result
406 |
407 | #################################################################
408 | # Document Retriever with custom function: return List(documents)
409 | #################################################################
410 | def list_up(similar_docs_semantic, similar_docs_keyword, similar_docs_wo_reranker, similar_docs):
411 | combined_list = []
412 | combined_list.append(similar_docs_semantic)
413 | combined_list.append(similar_docs_keyword)
414 | combined_list.append(similar_docs_wo_reranker)
415 | combined_list.append(similar_docs)
416 |
417 | return combined_list
418 |
419 | class retriever_utils():
420 |
421 | runtime_client = boto3.Session().client('sagemaker-runtime')
422 | pool = ThreadPool(processes=2)
423 | rag_fusion_pool = ThreadPool(processes=5)
424 | hyde_pool = ThreadPool(processes=4)
425 | text_splitter = RecursiveCharacterTextSplitter(
426 |         # chunk size used when splitting contexts that exceed the reranker token limit
427 | chunk_size=512,
428 | chunk_overlap=0,
429 | separators=["\n\n", "\n", ".", " ", ""],
430 | length_function=len,
431 | )
432 | token_limit = 300
433 |
434 | @classmethod
435 | # semantic search based
436 | def get_semantic_similar_docs_by_langchain(cls, **kwargs):
437 |
438 | #print(f"Thread={threading.get_ident()}, Process={os.getpid()}")
439 | search_types = ["approximate_search", "script_scoring", "painless_scripting"]
440 | space_types = ["l2", "l1", "linf", "cosinesimil", "innerproduct", "hammingbit"]
441 |
442 | assert "vector_db" in kwargs, "Check your vector_db"
443 | assert "query" in kwargs, "Check your query"
444 | assert kwargs.get("search_type", "approximate_search") in search_types, f'Check your search_type: {search_types}'
445 | assert kwargs.get("space_type", "l2") in space_types, f'Check your space_type: {space_types}'
446 |
447 | results = kwargs["vector_db"].similarity_search_with_score(
448 | query=kwargs["query"],
449 | k=kwargs.get("k", 5),
450 | search_type=kwargs.get("search_type", "approximate_search"),
451 | space_type=kwargs.get("space_type", "l2"),
452 | boolean_filter=opensearch_utils.get_filter(
453 | filter=kwargs.get("boolean_filter", [])
454 | ),
455 | )
456 |
457 | if kwargs.get("hybrid", False) and results:
458 | max_score = results[0][1]
459 | new_results = []
460 | for doc in results:
461 |                 normalized_score = float(doc[1]/max_score)
462 |                 new_results.append((doc[0], normalized_score))
463 | results = deepcopy(new_results)
464 |
465 | return results
466 |
467 | @classmethod
468 | def control_streaming_mode(cls, llm, stream=True):
469 |
470 | if stream:
471 | llm.streaming = True
472 | llm.callbacks = [StreamingStdOutCallbackHandler()]
473 | else:
474 | llm.streaming = False
475 | llm.callbacks = None
476 |
477 | return llm
478 |
479 | @classmethod
480 | # semantic search based
481 | def get_semantic_similar_docs(cls, **kwargs):
482 |
483 | assert "query" in kwargs, "Check your query"
484 | assert "k" in kwargs, "Check your k"
485 | assert "os_client" in kwargs, "Check your os_client"
486 | assert "index_name" in kwargs, "Check your index_name"
487 |
488 | def normalize_search_results(search_results):
489 |
490 | hits = (search_results["hits"]["hits"])
491 | max_score = float(search_results["hits"]["max_score"])
492 | for hit in hits:
493 | hit["_score"] = float(hit["_score"]) / max_score
494 | search_results["hits"]["max_score"] = hits[0]["_score"]
495 | search_results["hits"]["hits"] = hits
496 | return search_results
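497 |         # e.g. raw hit scores [2.0, 1.0, 0.5] are rescaled to [1.0, 0.5, 0.25],
498 |         # so semantic and lexical scores become comparable for hybrid fusion.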
497 |
498 | query = opensearch_utils.get_query(
499 | query=kwargs["query"],
500 | filter=kwargs.get("boolean_filter", []),
501 | search_type="semantic", # enable semantic search
502 | vector_field="vector_field", # for semantic search check by using index_info = os_client.indices.get(index=index_name)
503 | vector=kwargs["llm_emb"].embed_query(kwargs["query"]),
504 | k=kwargs["k"]
505 | )
506 | query["size"] = kwargs["k"]
507 |
508 | #print ("\nsemantic search query: ")
509 | #pprint (query)
510 |
511 | search_results = opensearch_utils.search_document(
512 | os_client=kwargs["os_client"],
513 | query=query,
514 | index_name=kwargs["index_name"]
515 | )
516 |
517 | results = []
518 | if search_results["hits"]["hits"]:
519 | search_results = normalize_search_results(search_results)
520 | for res in search_results["hits"]["hits"]:
521 |
522 | metadata = res["_source"]["metadata"]
523 | metadata["id"] = res["_id"]
524 |
525 | doc = Document(
526 | page_content=res["_source"]["text"],
527 | metadata=metadata
528 | )
529 | if kwargs.get("hybrid", False):
530 | results.append((doc, res["_score"]))
531 | else:
532 | results.append((doc))
533 |
534 | return results
535 |
536 | @classmethod
537 | # lexical(keyword) search based (using Amazon OpenSearch)
538 | def get_lexical_similar_docs(cls, **kwargs):
539 |
540 | assert "query" in kwargs, "Check your query"
541 | assert "k" in kwargs, "Check your k"
542 | assert "os_client" in kwargs, "Check your os_client"
543 | assert "index_name" in kwargs, "Check your index_name"
544 |
545 | def normalize_search_results(search_results):
546 |
547 | hits = (search_results["hits"]["hits"])
548 | max_score = float(search_results["hits"]["max_score"])
549 | for hit in hits:
550 | hit["_score"] = float(hit["_score"]) / max_score
551 | search_results["hits"]["max_score"] = hits[0]["_score"]
552 | search_results["hits"]["hits"] = hits
553 | return search_results
554 |
555 | query = opensearch_utils.get_query(
556 | query=kwargs["query"],
557 | minimum_should_match=kwargs.get("minimum_should_match", 0),
558 | filter=kwargs["filter"]
559 | )
560 | query["size"] = kwargs["k"]
561 |
562 | #print ("\nlexical search query: ")
563 | #pprint (query)
564 |
565 | search_results = opensearch_utils.search_document(
566 | os_client=kwargs["os_client"],
567 | query=query,
568 | index_name=kwargs["index_name"]
569 | )
570 |
571 | results = []
572 | if search_results["hits"]["hits"]:
573 | search_results = normalize_search_results(search_results)
574 | for res in search_results["hits"]["hits"]:
575 |
576 | metadata = res["_source"]["metadata"]
577 | metadata["id"] = res["_id"]
578 |
579 | doc = Document(
580 | page_content=res["_source"]["text"],
581 | metadata=metadata
582 | )
583 | if kwargs.get("hybrid", False):
584 | results.append((doc, res["_score"]))
585 | else:
586 | results.append((doc))
587 |
588 | return results
589 |
590 | @classmethod
591 | # rag-fusion based
592 | def get_rag_fusion_similar_docs(cls, **kwargs):
593 | global augmentation
594 |
595 | search_types = ["approximate_search", "script_scoring", "painless_scripting"]
596 | space_types = ["l2", "l1", "linf", "cosinesimil", "innerproduct", "hammingbit"]
597 |
598 | assert "llm_emb" in kwargs, "Check your llm_emb"
599 | assert "query" in kwargs, "Check your query"
600 | assert "query_transformation_prompt" in kwargs, "Check your query_transformation_prompt"
601 | assert kwargs.get("search_type", "approximate_search") in search_types, f'Check your search_type: {search_types}'
602 | assert kwargs.get("space_type", "l2") in space_types, f'Check your space_type: {space_types}'
603 | assert kwargs.get("llm_text", None) != None, "Check your llm_text"
604 |
605 | llm_text = kwargs["llm_text"]
606 | query_augmentation_size = kwargs["query_augmentation_size"]
607 | query_transformation_prompt = kwargs["query_transformation_prompt"]
608 |
609 |         llm_text = cls.control_streaming_mode(llm_text, stream=False)  ## turn off llm streaming
610 | generate_queries = query_transformation_prompt | llm_text | StrOutputParser() | (lambda x: x.split("\n"))
611 |
612 | rag_fusion_query = generate_queries.invoke(
613 | {
614 | "query": kwargs["query"],
615 | "query_augmentation_size": kwargs["query_augmentation_size"]
616 | }
617 | )
618 |
619 | rag_fusion_query = [query for query in rag_fusion_query if query != ""]
620 | if len(rag_fusion_query) > query_augmentation_size: rag_fusion_query = rag_fusion_query[-query_augmentation_size:]
621 | rag_fusion_query.insert(0, kwargs["query"])
622 | augmentation = rag_fusion_query
623 |
624 | if kwargs["verbose"]:
625 | print("\n")
626 | print("===== RAG-Fusion Queries =====")
627 | print(rag_fusion_query)
628 |
629 |         llm_text = cls.control_streaming_mode(llm_text, stream=True)  ## turn on llm streaming
630 |
631 | tasks = []
632 | for query in rag_fusion_query:
633 | semantic_search = partial(
634 | cls.get_semantic_similar_docs,
635 | os_client=kwargs["os_client"],
636 | index_name=kwargs["index_name"],
637 | query=query,
638 | k=kwargs["k"],
639 | boolean_filter=kwargs.get("boolean_filter", []),
640 | llm_emb=kwargs["llm_emb"],
641 | hybrid=True
642 | )
643 | tasks.append(cls.rag_fusion_pool.apply_async(semantic_search,))
644 | rag_fusion_docs = [task.get() for task in tasks]
645 |
646 | similar_docs = cls.get_ensemble_results(
647 | doc_lists=rag_fusion_docs,
648 | weights=[1/(query_augmentation_size+1)]*(query_augmentation_size+1), #query_augmentation_size + original query
649 | algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"]
650 | c=60,
651 | k=kwargs["k"],
652 | )
653 |
654 | return similar_docs
655 |
656 | @classmethod
657 | # HyDE based
658 | def get_hyde_similar_docs(cls, **kwargs):
659 | global augmentation
660 |
661 | def _get_hyde_response(query, prompt, llm_text):
662 |
663 | chain = prompt | llm_text | StrOutputParser()
664 |
665 | return chain.invoke({"query": query})
666 |
667 | search_types = ["approximate_search", "script_scoring", "painless_scripting"]
668 | space_types = ["l2", "l1", "linf", "cosinesimil", "innerproduct", "hammingbit"]
669 |
670 | assert "llm_emb" in kwargs, "Check your llm_emb"
671 | assert "query" in kwargs, "Check your query"
672 | assert "hyde_query" in kwargs, "Check your hyde_query"
673 | assert kwargs.get("search_type", "approximate_search") in search_types, f'Check your search_type: {search_types}'
674 | assert kwargs.get("space_type", "l2") in space_types, f'Check your space_type: {space_types}'
675 | assert kwargs.get("llm_text", None) != None, "Check your llm_text"
676 |
677 | query = kwargs["query"]
678 | llm_text = kwargs["llm_text"]
679 | hyde_query = kwargs["hyde_query"]
680 |
681 | tasks = []
682 |         llm_text = cls.control_streaming_mode(llm_text, stream=False)  ## turn off llm streaming
683 | for template_type in hyde_query:
684 | hyde_response = partial(
685 | _get_hyde_response,
686 | query=query,
687 | prompt=prompt_repo.get_hyde(template_type),
688 | llm_text=llm_text
689 | )
690 | tasks.append(cls.hyde_pool.apply_async(hyde_response,))
691 | hyde_answers = [task.get() for task in tasks]
692 | hyde_answers.insert(0, query)
693 |
694 | tasks = []
695 |         llm_text = cls.control_streaming_mode(llm_text, stream=True)  ## turn on llm streaming
696 | for hyde_answer in hyde_answers:
697 | semantic_search = partial(
698 | cls.get_semantic_similar_docs,
699 | os_client=kwargs["os_client"],
700 | index_name=kwargs["index_name"],
701 | query=hyde_answer,
702 | k=kwargs["k"],
703 | boolean_filter=kwargs.get("boolean_filter", []),
704 | llm_emb=kwargs["llm_emb"],
705 | hybrid=True
706 | )
707 | tasks.append(cls.hyde_pool.apply_async(semantic_search,))
708 | hyde_docs = [task.get() for task in tasks]
709 | hyde_doc_size = len(hyde_docs)
710 |
711 | similar_docs = cls.get_ensemble_results(
712 | doc_lists=hyde_docs,
713 |             weights=[1/(hyde_doc_size)]*(hyde_doc_size), # hyde answers + original query
714 | algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"]
715 | c=60,
716 | k=kwargs["k"],
717 | )
718 | augmentation = hyde_answers[1]
719 | if kwargs["verbose"]:
720 | print("\n")
721 | print("===== HyDE Answers =====")
722 | print(hyde_answers)
723 |
724 | return similar_docs
725 |
726 | @classmethod
727 | # ParentDocument based
728 | def get_parent_document_similar_docs(cls, **kwargs):
729 |
730 | child_search_results = kwargs["similar_docs"]
731 |
732 | parent_info, similar_docs = {}, []
733 | for rank, (doc, score) in enumerate(child_search_results):
734 | parent_id = doc.metadata["parent_id"]
735 |             if parent_id != "NA":  ## parent_id is "NA" for tables and images
736 | if parent_id not in parent_info:
737 | parent_info[parent_id] = (rank+1, score)
738 | else:
739 | if kwargs["hybrid"]:
740 | similar_docs.append((doc, score))
741 | else:
742 | similar_docs.append((doc))
743 |
744 | parent_ids = sorted(parent_info.items(), key=lambda x: x[1], reverse=False)
745 | parent_ids = list(map(lambda x:x[0], parent_ids))
746 |
747 | if parent_ids:
748 | parent_docs = opensearch_utils.get_documents_by_ids(
749 | os_client=kwargs["os_client"],
750 | ids=parent_ids,
751 | index_name=kwargs["index_name"],
752 | )
753 |
754 | if parent_docs["docs"]:
755 | for res in parent_docs["docs"]:
756 | doc_id = res["_id"]
757 | doc = Document(
758 | page_content=res["_source"]["text"],
759 | metadata=res["_source"]["metadata"]
760 | )
761 | if kwargs["hybrid"]:
762 | similar_docs.append((doc, parent_info[doc_id][1]))
763 | else:
764 | similar_docs.append((doc))
765 |
766 | if kwargs["hybrid"]:
767 | similar_docs = sorted(
768 | similar_docs,
769 | key=lambda x: x[1],
770 | reverse=True
771 | )
772 |
773 | if kwargs["verbose"]:
774 | print("===== ParentDocument =====")
775 | print (f'filter: {kwargs["boolean_filter"]}')
776 | print (f'# child_docs: {len(child_search_results)}')
777 | print (f'# parent docs: {len(similar_docs)}')
778 | print (f'# duplicates: {len(child_search_results)-len(similar_docs)}')
779 |
780 | return similar_docs
781 |
782 | @classmethod
783 | def get_rerank_docs(cls, **kwargs):
784 |
785 | assert "reranker_endpoint_name" in kwargs, "Check your reranker_endpoint_name"
786 | assert "k" in kwargs, "Check your k"
787 |
788 | contexts, query, llm_text, rerank_queries = kwargs["context"], kwargs["query"], kwargs["llm_text"], {"inputs":[]}
789 |
790 | exceed_info = []
791 | for idx, (context, score) in enumerate(contexts):
792 | page_content = context.page_content
793 | token_size = llm_text.get_num_tokens(query+page_content)
794 | exceed_flag = False
795 |
796 | if token_size > cls.token_limit:
797 | exceed_flag = True
798 |                 split_docs = cls.text_splitter.split_documents([context])
799 |                 if kwargs["verbose"]:
800 |                     print(f"\n[Exceeds reranker token limit] Number of chunk docs after splitting = {len(split_docs)}\n")
801 | 
802 |                 partial_set, length = [], []
803 |                 for split_doc in split_docs:
804 |                     rerank_queries["inputs"].append({"text": query, "text_pair": split_doc.page_content})
805 |                     length.append(llm_text.get_num_tokens(split_doc.page_content))
806 | partial_set.append(len(rerank_queries["inputs"])-1)
807 | else:
808 | rerank_queries["inputs"].append({"text": query, "text_pair": page_content})
809 |
810 | if exceed_flag:
811 | exceed_info.append([idx, exceed_flag, partial_set, length])
812 | else:
813 | exceed_info.append([idx, exceed_flag, len(rerank_queries["inputs"])-1, None])
814 |
815 | rerank_queries = json.dumps(rerank_queries)
816 |
817 | response = cls.runtime_client.invoke_endpoint(
818 | EndpointName=kwargs["reranker_endpoint_name"],
819 | ContentType="application/json",
820 | Accept="application/json",
821 | Body=rerank_queries
822 | )
823 | outs = json.loads(response['Body'].read().decode()) ## for json
824 |
825 | rerank_contexts = []
826 | for idx, exceed_flag, partial_set, length in exceed_info:
827 | if not exceed_flag:
828 | rerank_contexts.append((contexts[idx][0], outs[partial_set]["score"]))
829 | else:
830 | partial_scores = [outs[partial_idx]["score"] for partial_idx in partial_set]
831 | partial_scores = np.average(partial_scores, axis=0, weights=length)
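832 |                 # e.g. two chunks scoring [0.9, 0.5] with token lengths
833 |                 # [300, 100] average to (0.9*300 + 0.5*100) / 400 = 0.8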
832 | rerank_contexts.append((contexts[idx][0], partial_scores))
833 |
834 | #rerank_contexts = [(contexts[idx][0], out["score"]) for idx, out in enumerate(outs)]
835 | rerank_contexts = sorted(
836 | rerank_contexts,
837 | key=lambda x: x[1],
838 | reverse=True
839 | )
840 |
841 | return rerank_contexts[:kwargs["k"]]
842 |
843 | @classmethod
844 | def get_element(cls, **kwargs):
845 |
846 | similar_docs = copy.deepcopy(kwargs["similar_docs"])
847 | tables, images = [], []
848 |
849 | for doc in similar_docs:
850 |
851 | category = doc.metadata.get("category", None)
852 | if category != None:
853 | if category == "Table":
854 | doc.page_content = doc.metadata["origin_table"]
855 | tables.append(doc)
856 | elif category == "Image":
857 | doc.page_content = doc.metadata["image_base64"]
858 | images.append(doc)
859 |
860 | return tables, images
861 |
862 |
863 | @classmethod
864 | # hybrid (lexical + semantic) search based
865 | def search_hybrid(cls, **kwargs):
866 |
867 | assert "query" in kwargs, "Check your query"
868 | assert "llm_emb" in kwargs, "Check your llm_emb"
869 | assert "index_name" in kwargs, "Check your index_name"
870 | assert "os_client" in kwargs, "Check your os_client"
871 |
872 | rag_fusion = kwargs.get("rag_fusion", False)
873 | hyde = kwargs.get("hyde", False)
874 | parent_document = kwargs.get("parent_document", False)
875 | hybrid_search_debugger = kwargs.get("hybrid_search_debugger", "None")
876 |
877 |
878 |
879 | assert (rag_fusion + hyde) <= 1, "choose only one between RAG-FUSION and HyDE"
880 | if rag_fusion:
881 | assert "query_augmentation_size" in kwargs, "if you use RAG-FUSION, Check your query_augmentation_size"
882 | if hyde:
883 | assert "hyde_query" in kwargs, "if you use HyDE, Check your hyde_query"
884 |
885 | verbose = kwargs.get("verbose", False)
886 | async_mode = kwargs.get("async_mode", True)
887 | reranker = kwargs.get("reranker", False)
888 | complex_doc = kwargs.get("complex_doc", False)
889 | search_filter = deepcopy(kwargs.get("filter", []))
890 |
891 | #search_filter.append({"term": {"metadata.family_tree": "child"}})
892 |         ## regardless of parent_document, restrict search to child and
893 |         ## parent_table / parent_image chunks
894 |         parent_doc_filter = {
895 |             "bool":{
896 |                 "should":[ ## or condition
897 |                     {"term": {"metadata.family_tree": "child"}},
898 |                     {"term": {"metadata.family_tree": "parent_table"}},
899 |                     {"term": {"metadata.family_tree": "parent_image"}},
900 |                 ]
901 |             }
902 |         }
903 |         search_filter.append(parent_doc_filter)
914 |
915 |
916 | def do_sync():
917 |
918 | if rag_fusion:
919 | similar_docs_semantic = cls.get_rag_fusion_similar_docs(
920 | index_name=kwargs["index_name"],
921 | os_client=kwargs["os_client"],
922 | llm_emb=kwargs["llm_emb"],
923 |
924 | query=kwargs["query"],
925 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5),
926 | boolean_filter=search_filter,
927 | hybrid=True,
928 |
929 | llm_text=kwargs.get("llm_text", None),
930 | query_augmentation_size=kwargs["query_augmentation_size"],
931 | query_transformation_prompt=kwargs.get("query_transformation_prompt", None),
932 | fusion_algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"]
933 |
934 | verbose=kwargs.get("verbose", False),
935 | )
936 | elif hyde:
937 | similar_docs_semantic = cls.get_hyde_similar_docs(
938 | index_name=kwargs["index_name"],
939 | os_client=kwargs["os_client"],
940 | llm_emb=kwargs["llm_emb"],
941 |
942 | query=kwargs["query"],
943 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5),
944 | boolean_filter=search_filter,
945 | hybrid=True,
946 |
947 | llm_text=kwargs.get("llm_text", None),
948 | hyde_query=kwargs["hyde_query"],
949 | fusion_algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"]
950 |
951 | verbose=kwargs.get("verbose", False),
952 | )
953 | else:
954 | similar_docs_semantic = cls.get_semantic_similar_docs(
955 | index_name=kwargs["index_name"],
956 | os_client=kwargs["os_client"],
957 | llm_emb=kwargs["llm_emb"],
958 |
959 | query=kwargs["query"],
960 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5),
961 | boolean_filter=search_filter,
962 | hybrid=True
963 | )
964 |
965 | similar_docs_keyword = cls.get_lexical_similar_docs(
966 | index_name=kwargs["index_name"],
967 | os_client=kwargs["os_client"],
968 |
969 | query=kwargs["query"],
970 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5),
971 | minimum_should_match=kwargs.get("minimum_should_match", 0),
972 | filter=search_filter,
973 | hybrid=True
974 | )
975 |
976 | if hybrid_search_debugger == "semantic": similar_docs_keyword = []
977 | elif hybrid_search_debugger == "lexical": similar_docs_semantic = []
978 |
979 | return similar_docs_semantic, similar_docs_keyword
980 |
981 | def do_async():
982 |
983 | if rag_fusion:
984 | semantic_search = partial(
985 | cls.get_rag_fusion_similar_docs,
986 | index_name=kwargs["index_name"],
987 | os_client=kwargs["os_client"],
988 | llm_emb=kwargs["llm_emb"],
989 |
990 | query=kwargs["query"],
991 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5),
992 | boolean_filter=search_filter,
993 | hybrid=True,
994 |
995 | llm_text=kwargs.get("llm_text", None),
996 | query_augmentation_size=kwargs["query_augmentation_size"],
997 | query_transformation_prompt=kwargs.get("query_transformation_prompt", None),
998 | fusion_algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"]
999 |
1000 | verbose=kwargs.get("verbose", False),
1001 | )
1002 | elif hyde:
1003 | semantic_search = partial(
1004 | cls.get_hyde_similar_docs,
1005 | index_name=kwargs["index_name"],
1006 | os_client=kwargs["os_client"],
1007 | llm_emb=kwargs["llm_emb"],
1008 |
1009 | query=kwargs["query"],
1010 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5),
1011 | boolean_filter=search_filter,
1012 | hybrid=True,
1013 |
1014 | llm_text=kwargs.get("llm_text", None),
1015 | hyde_query=kwargs["hyde_query"],
1016 | fusion_algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"]
1017 |
1018 | verbose=kwargs.get("verbose", False),
1019 | )
1020 | else:
1021 | semantic_search = partial(
1022 | cls.get_semantic_similar_docs,
1023 | index_name=kwargs["index_name"],
1024 | os_client=kwargs["os_client"],
1025 | llm_emb=kwargs["llm_emb"],
1026 |
1027 | query=kwargs["query"],
1028 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5),
1029 | boolean_filter=search_filter,
1030 | hybrid=True
1031 | )
1032 |
1033 | lexical_search = partial(
1034 | cls.get_lexical_similar_docs,
1035 | index_name=kwargs["index_name"],
1036 | os_client=kwargs["os_client"],
1037 |
1038 | query=kwargs["query"],
1039 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5),
1040 | minimum_should_match=kwargs.get("minimum_should_match", 0),
1041 | filter=search_filter,
1042 | hybrid=True
1043 | )
1044 | semantic_pool = cls.pool.apply_async(semantic_search,)
1045 | lexical_pool = cls.pool.apply_async(lexical_search,)
1046 | similar_docs_semantic, similar_docs_keyword = semantic_pool.get(), lexical_pool.get()
1047 |
1048 | if hybrid_search_debugger == "semantic": similar_docs_keyword = []
1049 | elif hybrid_search_debugger == "lexical": similar_docs_semantic = []
1050 |
1051 | return similar_docs_semantic, similar_docs_keyword
1052 |
1053 | if async_mode:
1054 | similar_docs_semantic, similar_docs_keyword = do_async()
1055 | else:
1056 | similar_docs_semantic, similar_docs_keyword = do_sync()
1057 |
1058 | similar_docs = cls.get_ensemble_results(
1059 | doc_lists=[similar_docs_semantic, similar_docs_keyword],
1060 | weights=kwargs.get("ensemble_weights", [.51, .49]),
1061 | algorithm=kwargs.get("fusion_algorithm", "RRF"), # ["RRF", "simple_weighted"]
1062 | c=60,
1063 | k=kwargs.get("k", 5) if not reranker else int(kwargs["k"]*1.5),
1064 | )
1065 | #print (len(similar_docs_keyword), len(similar_docs_semantic), len(similar_docs))
1066 | #print ("1-similar_docs")
1067 | #for i, doc in enumerate(similar_docs): print (i, doc)
1068 |
1069 | if verbose:
1070 | similar_docs_wo_reranker = copy.deepcopy(similar_docs)
1071 |
1072 | if reranker:
1073 | reranker_endpoint_name = kwargs["reranker_endpoint_name"]
1074 | similar_docs = cls.get_rerank_docs(
1075 | llm_text=kwargs["llm_text"],
1076 | query=kwargs["query"],
1077 | context=similar_docs,
1078 | k=kwargs.get("k", 5),
1079 | reranker_endpoint_name=reranker_endpoint_name,
1080 | verbose=verbose
1081 | )
1082 |
1083 | #print ("2-similar_docs")
1084 | #for i, doc in enumerate(similar_docs): print (i, doc)
1085 |
1086 | if parent_document:
1087 | similar_docs = cls.get_parent_document_similar_docs(
1088 | index_name=kwargs["index_name"],
1089 | os_client=kwargs["os_client"],
1090 | similar_docs=similar_docs,
1091 | hybrid=True,
1092 | boolean_filter=search_filter,
1093 | verbose=verbose
1094 | )
1095 |
1096 | if complex_doc:
1097 | tables, images = cls.get_element(
1098 | similar_docs=list(map(lambda x:x[0], similar_docs))
1099 | )
1100 |
1101 | if verbose:
1102 | similar_docs_semantic_pretty = opensearch_utils.opensearch_pretty_print_documents_with_score("semantic", similar_docs_semantic)
1103 | similar_docs_keyword_pretty = opensearch_utils.opensearch_pretty_print_documents_with_score("keyword", similar_docs_keyword)
1104 | similar_docs_wo_reranker_pretty = []
1105 | if reranker:
1106 | similar_docs_wo_reranker_pretty = opensearch_utils.opensearch_pretty_print_documents_with_score("wo_reranker", similar_docs_wo_reranker)
1107 |
1108 | similar_docs_pretty = opensearch_utils.opensearch_pretty_print_documents_with_score("similar_docs", similar_docs)
1109 |
1110 | similar_docs = list(map(lambda x:x[0], similar_docs))
1111 | global pretty_contexts
1112 | pretty_contexts = list_up(similar_docs_semantic_pretty, similar_docs_keyword_pretty, similar_docs_wo_reranker_pretty, similar_docs_pretty)
1113 |
1116 |
1117 | if complex_doc: return similar_docs, tables, images
1118 | else:
1119 | similar_docs_filtered = []
1120 | for doc in similar_docs:
1121 | category = "None"
1122 | if "category" in doc.metadata:
1123 | category = doc.metadata["category"]
1124 |
1125 | if category not in {"Table", "Image"}:
1126 | similar_docs_filtered.append(doc)
1127 | return similar_docs_filtered
1128 |
1129 |
1130 |
1131 | @classmethod
1132 | # Score fusion and re-rank (lexical + semantic)
1133 | def get_ensemble_results(cls, doc_lists: List[List[Document]], weights, algorithm="RRF", c=60, k=5) -> List[Document]:
1134 |
1135 | assert algorithm in ["RRF", "simple_weighted"]
1136 |
1137 | # Create a union of all unique documents in the input doc_lists
1138 | all_documents = set()
1139 |
1140 | for doc_list in doc_lists:
1141 | for (doc, _) in doc_list:
1142 | all_documents.add(doc.page_content)
1143 |
1144 | # Initialize the score dictionary for each document
1145 | hybrid_score_dic = {doc: 0.0 for doc in all_documents}
1146 |
1147 | # Calculate RRF scores for each document
1148 | for doc_list, weight in zip(doc_lists, weights):
1149 | for rank, (doc, score) in enumerate(doc_list, start=1):
1150 | if algorithm == "RRF": # RRF (Reciprocal Rank Fusion)
1151 | score = weight * (1 / (rank + c))
1152 | elif algorithm == "simple_weighted":
1153 | score *= weight
1154 | hybrid_score_dic[doc.page_content] += score
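1155 |         # e.g. RRF with c=60 and weights [0.51, 0.49]: a document ranked 1st
1156 |         # in one list and 3rd in the other scores 0.51/(1+60) + 0.49/(3+60)
1157 |         # ≈ 0.0161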
1155 |
1156 | # Sort documents by their scores in descending order
1157 | sorted_documents = sorted(
1158 | hybrid_score_dic.items(), key=lambda x: x[1], reverse=True
1159 | )
1160 |
1161 | # Map the sorted page_content back to the original document objects
1162 | page_content_to_doc_map = {
1163 | doc.page_content: doc for doc_list in doc_lists for (doc, orig_score) in doc_list
1164 | }
1165 |
1166 | sorted_docs = [
1167 | (page_content_to_doc_map[page_content], hybrid_score) for (page_content, hybrid_score) in sorted_documents
1168 | ]
1169 |
1170 | return sorted_docs[:k]
1171 |
1172 |
1173 | #################################################################
1174 | # Document Retriever with Langchain(BaseRetriever): return List(documents)
1175 | #################################################################
1176 |
1177 | # lexical(keyword) search based (using Amazon OpenSearch)
1178 | class OpenSearchLexicalSearchRetriever(BaseRetriever):
1179 |
1180 | os_client: Any
1181 | index_name: str
1182 | k = 3
1183 | minimum_should_match = 0
1184 | filter = []
1185 |
1186 | def normalize_search_results(self, search_results):
1187 |
1188 | hits = (search_results["hits"]["hits"])
1189 | max_score = float(search_results["hits"]["max_score"])
1190 | for hit in hits:
1191 | hit["_score"] = float(hit["_score"]) / max_score
1192 | search_results["hits"]["max_score"] = hits[0]["_score"]
1193 | search_results["hits"]["hits"] = hits
1194 | return search_results
1195 |
1196 | def update_search_params(self, **kwargs):
1197 |
1198 | self.k = kwargs.get("k", 3)
1199 | self.minimum_should_match = kwargs.get("minimum_should_match", 0)
1200 | self.filter = kwargs.get("filter", [])
1201 | self.index_name = kwargs.get("index_name", self.index_name)
1202 |
1203 | def _reset_search_params(self, ):
1204 |
1205 | self.k = 3
1206 | self.minimum_should_match = 0
1207 | self.filter = []
1208 |
1209 | def _get_relevant_documents(
1210 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
1211 |
1212 | query = opensearch_utils.get_query(
1213 | query=query,
1214 | minimum_should_match=self.minimum_should_match,
1215 | filter=self.filter
1216 | )
1217 | query["size"] = self.k
1218 |
1219 | print ("lexical search query: ")
1220 | pprint(query)
1221 |
1222 | search_results = opensearch_utils.search_document(
1223 | os_client=self.os_client,
1224 | query=query,
1225 | index_name=self.index_name
1226 | )
1227 |
1228 | results = []
1229 | if search_results["hits"]["hits"]:
1230 | search_results = self.normalize_search_results(search_results)
1231 | for res in search_results["hits"]["hits"]:
1232 |
1233 | metadata = res["_source"]["metadata"]
1234 | metadata["id"] = res["_id"]
1235 |
1236 | doc = Document(
1237 | page_content=res["_source"]["text"],
1238 | metadata=metadata
1239 | )
1240 | results.append((doc))
1241 |
1242 | self._reset_search_params()
1243 |
1244 | return results[:self.k]
1245 |
1246 | # hybrid (lexical + semantic) search based
1247 | class OpenSearchHybridSearchRetriever(BaseRetriever):
1248 |
1249 | os_client: Any
1250 | vector_db: Any
1251 | index_name: str
1252 | k = 3
1253 | minimum_should_match = 0
1254 | filter = []
1255 | fusion_algorithm: str
1256 | ensemble_weights = [0.51, 0.49]
1257 | verbose = False
1258 | async_mode = True
1259 | reranker = False
1260 | reranker_endpoint_name = ""
1261 | rag_fusion = False
1262 | query_augmentation_size: Any
1263 | rag_fusion_prompt = prompt_repo.get_rag_fusion()
1264 | llm_text: Any
1265 | llm_emb: Any
1266 | hyde = False
1267 | hyde_query: Any
1268 | parent_document = False
1269 | complex_doc = False
1270 | hybrid_search_debugger = "None"
1271 |
1272 | def update_search_params(self, **kwargs):
1273 |
1274 | self.k = kwargs.get("k", 3)
1275 | self.minimum_should_match = kwargs.get("minimum_should_match", 0)
1276 | self.filter = kwargs.get("filter", [])
1277 | self.index_name = kwargs.get("index_name", self.index_name)
1278 | self.fusion_algorithm = kwargs.get("fusion_algorithm", self.fusion_algorithm)
1279 | self.ensemble_weights = kwargs.get("ensemble_weights", self.ensemble_weights)
1280 | self.verbose = kwargs.get("verbose", self.verbose)
1281 | self.async_mode = kwargs.get("async_mode", self.async_mode)
1282 | self.reranker = kwargs.get("reranker", self.reranker)
1283 | self.reranker_endpoint_name = kwargs.get("reranker_endpoint_name", self.reranker_endpoint_name)
1284 | self.rag_fusion = kwargs.get("rag_fusion", self.rag_fusion)
1285 | self.query_augmentation_size = kwargs.get("query_augmentation_size", 3)
1286 | self.hyde = kwargs.get("hyde", self.hyde)
1287 | self.hyde_query = kwargs.get("hyde_query", ["web_search"])
1288 | self.parent_document = kwargs.get("parent_document", self.parent_document)
1289 | self.complex_doc = kwargs.get("complex_doc", self.complex_doc)
1290 | self.hybrid_search_debugger = kwargs.get("hybrid_search_debugger", self.hybrid_search_debugger)
1291 |
1292 | def _reset_search_params(self, ):
1293 |
1294 | self.k = 3
1295 | self.minimum_should_match = 0
1296 | self.filter = []
1297 |
1298 | def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
1299 |
1300 | '''
1301 | It can be called by "retriever.invoke" statements
1302 | '''
1303 | search_hybrid_result = retriever_utils.search_hybrid(
1304 | query=query,
1305 | k=self.k,
1306 | index_name=self.index_name,
1307 | os_client=self.os_client,
1308 | filter=self.filter,
1309 | minimum_should_match=self.minimum_should_match,
1310 | fusion_algorithm=self.fusion_algorithm, # ["RRF", "simple_weighted"]
1311 |             ensemble_weights=self.ensemble_weights, # weights for semantic vs. keyword search (defaults: 0.51 / 0.49)
1312 | async_mode=self.async_mode,
1313 | reranker=self.reranker,
1314 | reranker_endpoint_name=self.reranker_endpoint_name,
1315 | rag_fusion=self.rag_fusion,
1316 | query_augmentation_size=self.query_augmentation_size,
1317 | query_transformation_prompt=self.rag_fusion_prompt if self.rag_fusion else "",
1318 | hyde=self.hyde,
1319 | hyde_query=self.hyde_query if self.hyde else [],
1320 | parent_document = self.parent_document,
1321 | complex_doc = self.complex_doc,
1322 | llm_text=self.llm_text,
1323 | llm_emb=self.llm_emb,
1324 | verbose=self.verbose,
1325 | hybrid_search_debugger=self.hybrid_search_debugger
1326 | )
1327 | #self._reset_search_params()
1328 |
1329 | return search_hybrid_result
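1330 | # Usage sketch (illustrative only; os_client, vector_db, llm_text, llm_emb
1331 | # and the index name are placeholders supplied by the caller):
1332 | #
1333 | #   retriever = OpenSearchHybridSearchRetriever(
1334 | #       os_client=os_client, vector_db=vector_db, index_name="my-index",
1335 | #       fusion_algorithm="RRF", llm_text=llm_text, llm_emb=llm_emb,
1336 | #       query_augmentation_size=3, hyde_query=["web_search"],
1337 | #   )
1338 | #   retriever.update_search_params(k=5, reranker=False)
1339 | #   docs = retriever.invoke("user question here")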
1330 |
1331 | #################################################################
1332 | # Document visualization
1333 | #################################################################
1334 | def show_context_used(context_list, limit=10):
1335 |
1336 | context_list = copy.deepcopy(context_list)
1337 |
1338 |     if isinstance(context_list, tuple): context_list = context_list[0]
1339 | for idx, context in enumerate(context_list):
1340 |
1341 | if idx < limit:
1342 |
1343 | category = "None"
1344 | if "category" in context.metadata:
1345 | category = context.metadata["category"]
1346 |
1347 | print("\n-----------------------------------------------")
1348 | if category != "None":
1349 | print(f"{idx+1}. Category: {category}, Chunk: {len(context.page_content)} Characters")
1350 | else:
1351 | print(f"{idx+1}. Chunk: {len(context.page_content)} Characters")
1352 | print("-----------------------------------------------")
1353 |
1354 | if category == "Image" or (category == "Table" and "image_base64" in context.metadata):
1355 | img = Image.open(BytesIO(base64.b64decode(context.metadata["image_base64"])))
1356 | plt.imshow(img)
1357 | plt.show()
1358 | context.metadata["image_base64"], context.metadata["origin_image"] = "", ""
1359 |
1360 | context.metadata["orig_elements"] = ""
1361 | print_ww(context.page_content)
1362 | if "text_as_html" in context.metadata: print_html(context.metadata["text_as_html"])
1363 | print_ww("metadata: \n", context.metadata)
1364 | else:
1365 | break
1366 |
1367 | def show_chunk_stat(documents):
1368 |
1369 | doc_len_list = [len(doc.page_content) for doc in documents]
1370 | print(pd.DataFrame(doc_len_list).describe())
1371 | avg_doc_length = lambda documents: sum([len(doc.page_content) for doc in documents])//len(documents)
1372 | avg_char_count_pre = avg_doc_length(documents)
1373 | print(f'Average length among {len(documents)} documents loaded is {avg_char_count_pre} characters.')
1374 |
1375 | max_idx = doc_len_list.index(max(doc_len_list))
1376 | print("\nShow document at maximum size")
1377 | print(documents[max_idx].page_content)
1378 |
1379 | #################################################################
1380 | # JumpStart Embeddings
1381 | #################################################################
1382 |
1383 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
1384 | def embed_documents(self, texts: List[str], chunk_size: int=1) -> List[List[float]]:
1385 | """Compute doc embeddings using a SageMaker Inference Endpoint.
1386 |
1387 | Args:
1388 | texts: The list of texts to embed.
1389 | chunk_size: The chunk size defines how many input texts will
1390 | be grouped together as request. If None, will use the
1391 | chunk size specified by the class.
1392 |
1393 | Returns:
1394 | List of embeddings, one for each text.
1395 | """
1396 | results = []
1397 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
1398 |
1399 | print("text size: ", len(texts))
1400 | print("_chunk_size: ", _chunk_size)
1401 |
1402 | for i in range(0, len(texts), _chunk_size):
1403 |
1404 | #print (i, texts[i : i + _chunk_size])
1405 | response = self._embedding_func(texts[i : i + _chunk_size])
1406 | #print (i, response, len(response[0].shape))
1407 |
1408 | results.extend(response)
1409 | return results
1410 |
1411 | class KoSimCSERobertaContentHandler(EmbeddingsContentHandler):
1412 |
1413 | content_type = "application/json"
1414 | accepts = "application/json"
1415 |
1416 | def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
1417 |
1418 | input_str = json.dumps({"inputs": prompt, **model_kwargs})
1419 |
1420 | return input_str.encode("utf-8")
1421 |
1422 |     def transform_output(self, output: bytes) -> List[List[float]]:
1423 |
1424 | response_json = json.loads(output.read().decode("utf-8"))
1425 | ndim = np.array(response_json).ndim
1426 |
1427 | if ndim == 4:
1428 | # Original shape (1, 1, n, 768)
1429 | emb = response_json[0][0][0]
1430 | emb = np.expand_dims(emb, axis=0).tolist()
1431 | elif ndim == 2:
1432 | # Original shape (n, 1)
1433 | emb = []
1434 | for ele in response_json:
1435 | e = ele[0][0]
1436 | emb.append(e)
1437 | else:
1438 | print(f"Other # of dimension: {ndim}")
1439 | emb = None
1440 | return emb
1441 |
1442 |
--------------------------------------------------------------------------------
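A minimal usage sketch for the two classes above, assuming a KoSimCSE embedding model is already deployed behind a SageMaker endpoint and the working directory is `application/`; the endpoint name and region below are illustrative assumptions, not values from this repository:

```python
# Hypothetical wiring of the JumpStart embeddings wrapper above; endpoint name
# and region are assumptions for illustration only.
from utils.rag_summit import (
    SagemakerEndpointEmbeddingsJumpStart,
    KoSimCSERobertaContentHandler,
)

llm_emb = SagemakerEndpointEmbeddingsJumpStart(
    endpoint_name="KoSimCSE-roberta",        # assumed endpoint name
    region_name="us-west-2",
    content_handler=KoSimCSERobertaContentHandler(),
)

# embed_documents() batches the texts by chunk_size and concatenates the results.
vectors = llm_emb.embed_documents(["첫 번째 문서", "두 번째 문서"], chunk_size=1)
print(len(vectors), len(vectors[0]))
```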
/application/utils/s3.py:
--------------------------------------------------------------------------------
1 | import os
2 | import boto3
3 | import logging
4 | from sagemaker.s3 import S3Uploader
5 | from botocore.exceptions import ClientError
6 |
7 | class s3_handler():
8 |
9 | def __init__(self, region_name=None):
10 |
11 | self.region_name = region_name
12 | #self.resource = boto3.resource('s3', region_name=self.region_name)
13 | #self.client = boto3.client('s3', region_name=self.region_name)
14 |
15 | self.resource = boto3.resource('s3')
16 | self.client = boto3.client('s3')
17 |
18 |         print (f"This is an S3 handler for the [{self.region_name}] region.")
19 |
20 | def create_bucket(self, bucket_name):
21 | """Create an S3 bucket in a specified region
22 |
23 | If a region is not specified, the bucket is created in the S3 default
24 | region (us-east-1).
25 |
26 | :param bucket_name: Bucket to create
27 | :return: True if bucket created, else False
28 | """
29 |
30 | try:
31 | if self.region_name is None:
32 | self.client.create_bucket(Bucket=bucket_name)
33 | else:
34 | location = {'LocationConstraint': self.region_name}
35 | self.client.create_bucket(
36 | Bucket=bucket_name,
37 | CreateBucketConfiguration=location
38 | )
39 | print (f"CREATE:[{bucket_name}] Bucket was created successfully")
40 |
41 | except ClientError as e:
42 | logging.error(e)
43 | print (f"ERROR: {e}")
44 | return False
45 |
46 | return True
47 |
48 | def copy_object(self, source_obj, target_bucket, target_obj):
49 |
50 | '''
51 | Copy S3 to S3
52 | '''
53 |
54 | try:
55 | response = self.client.copy_object(
56 | Bucket=target_bucket,#'destinationbucket',
57 | CopySource=source_obj,#'/sourcebucket/HappyFacejpg',
58 | Key=target_obj,#'HappyFaceCopyjpg',
59 | )
60 |
61 | except ClientError as e:
62 | logging.error(e)
63 | print (f"ERROR: {e}")
64 | return False
65 |
66 | def download_obj(self, source_bucket, source_obj, target_file):
67 |
68 | '''
69 |         Copy S3 to Local
70 | '''
71 |
72 | self.client.download_file(source_bucket, source_obj, target_file)
73 |
74 | def upload_dir(self, source_dir, target_bucket, target_dir):
75 |
76 | inputs = S3Uploader.upload(source_dir, "s3://{}/{}".format(target_bucket, target_dir))
77 |
78 |         print (f"Upload: [{source_dir}] was uploaded to [{inputs}] successfully")
79 |
80 | def upload_file(self, source_file, target_bucket, target_obj=None):
81 | """Upload a file to an S3 bucket
82 |
83 |         :param source_file: File to upload
84 |         :param target_bucket: Bucket to upload to
85 |         :param target_obj: S3 object name. If not specified, the basename of source_file is used
86 |         :return: S3 URI of the uploaded object if successful, else False
87 | """
88 |
89 | # If S3 object_name was not specified, use file_name
90 | if target_obj is None:
91 | target_obj = os.path.basename(source_file)
92 |
93 | # Upload the file
94 | #s3_client = boto3.client('s3')
95 | try:
96 | response = self.client.upload_file(source_file, target_bucket, target_obj)
97 | except ClientError as e:
98 | logging.error(e)
99 | return False
100 |
101 | obj_s3_path = f"s3://{target_bucket}/{target_obj}"
102 |
103 | return obj_s3_path
104 |
105 | def delete_bucket(self, bucket_name):
106 |
107 | try:
108 | self._delete_all_object(bucket_name=bucket_name)
109 | response = self.client.delete_bucket(
110 | Bucket=bucket_name,
111 | )
112 |
113 | print (f"DELETE: [{bucket_name}] Bucket was deleted successfully")
114 |
115 | except ClientError as e:
116 | logging.error(e)
117 | print (f"ERROR: {e}")
118 | return False
119 |
120 | return True
121 |
122 | def _delete_all_object(self, bucket_name):
123 |
124 | bucket = self.resource.Bucket(bucket_name)
125 | bucket.object_versions.delete() ## delete versioning
126 | bucket.objects.all().delete() ## delete all objects in the bucket
--------------------------------------------------------------------------------
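For reference, a short usage sketch of `s3_handler`; the bucket and file names below are made up:

```python
from utils.s3 import s3_handler

s3 = s3_handler(region_name="us-west-2")
s3.create_bucket("advanced-rag-workshop-demo-bucket")            # hypothetical name
obj_path = s3.upload_file("results.png", "advanced-rag-workshop-demo-bucket")
print(obj_path)  # -> s3://advanced-rag-workshop-demo-bucket/results.png
s3.delete_bucket("advanced-rag-workshop-demo-bucket")            # empties, then deletes
```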
/application/utils/ssm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import boto3
3 |
4 | class parameter_store():
5 |
6 | def __init__(self, region_name="ap-northeast-2"):
7 |
8 | self.ssm = boto3.client('ssm', region_name=region_name)
9 |
10 | def put_params(self, key, value, dtype="String", overwrite=False, enc=False):
11 |
12 | # Specify the parameter name, value, and type
13 | if enc: dtype="SecureString"
14 |
15 | try:
16 | # Put the parameter
17 | response = self.ssm.put_parameter(
18 | Name=key,
19 | Value=value,
20 | Type=dtype,
21 | Overwrite=overwrite # Set to True if you want to overwrite an existing parameter
22 | )
23 |
24 | # Print the response
25 | print('Parameter stored successfully.')
26 | #print(response)
27 |
28 | except Exception as e:
29 | print('Error storing parameter:', str(e))
30 |
31 | def get_params(self, key, enc=False):
32 |
33 | if enc: WithDecryption = True
34 | else: WithDecryption = False
35 | response = self.ssm.get_parameters(
36 | Names=[key,],
37 | WithDecryption=WithDecryption
38 | )
39 |
40 | return response['Parameters'][0]['Value']
41 |
42 | def get_all_params(self, ):
43 |
44 | response = self.ssm.describe_parameters(MaxResults=50)
45 |
46 | return [dicParam["Name"] for dicParam in response["Parameters"]]
47 |
48 | def delete_param(self, listParams):
49 |
50 | response = self.ssm.delete_parameters(
51 | Names=listParams
52 | )
53 |         print (f"Parameters {listParams} were deleted successfully")
54 |
55 |
56 |
--------------------------------------------------------------------------------
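A usage sketch of `parameter_store`; the key/value pair mirrors what the CDK stacks below store, but the calls themselves are illustrative:

```python
from utils.ssm import parameter_store

pm = parameter_store(region_name="us-west-2")
pm.put_params(key="opensearch_user_id", value="raguser", overwrite=True)
print(pm.get_params(key="opensearch_user_id"))    # -> raguser
print(pm.get_all_params())                        # first 50 parameter names
```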
/application/utils/text_to_report.py:
--------------------------------------------------------------------------------
1 | import re
2 | from textwrap import dedent
3 | from langchain_core.tracers import ConsoleCallbackHandler
4 | from langchain.schema.output_parser import StrOutputParser
5 | from langchain_experimental.tools.python.tool import PythonAstREPLTool
6 | from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, MessagesPlaceholder
7 |
8 | class prompt_repo():
9 |
10 | @classmethod
11 | def get_system_prompt(cls, role=None):
12 |
13 | if role == "text2chart":
14 |
15 | system_prompt = dedent(
16 | '''
17 | You are a pandas master bot designed to generate Python code for plotting a chart based on the given dataset and user question.
18 |             I'm going to give you a dataset.
19 | Read the given dataset carefully, because I'm going to ask you a question about it.
20 | '''
21 | )
22 | else:
23 | system_prompt = ""
24 |
25 | return system_prompt
26 |
27 | @classmethod
28 | def get_human_prompt(cls, role=None):
29 |
30 | if role == "text2chart":
31 |
32 | human_prompt = dedent(
33 | '''
34 | This is the result of `print(df.head())`: {dataset}
35 | You should execute code as commanded to either provide information to answer the question or to do the transformations required.
36 | You should not assign any variables; you should return a one-liner in Pandas.
37 |
38 | Update this initial code:
39 |
40 | ```python
41 | # TODO: import the required dependencies
42 | import pandas as pd
43 |
44 | # Write code here
45 |
46 | ```
47 |
48 | Here is the question: {question}
49 |
50 | Variable `df: pd.DataFrame` is already declared.
51 | At the end, declare "result" variable as a dictionary of type and value.
52 |             If you are asked to plot a chart, use "matplotlib" and save the chart as "results.png".
53 |             Explain in Korean.
54 |             Do not use Korean in the plot legend or title.
55 |
56 |             Generate Python code and return the full updated code within:
57 |
58 |
59 | '''
60 | )
61 |
62 | else:
63 | human_prompt = ""
64 |
65 | return human_prompt
66 |
67 | class text2chart_chain():
68 |
69 | def __init__(self, **kwargs):
70 |
71 | system_prompt = kwargs["system_prompt"]
72 | self.llm_text = kwargs["llm_text"]
73 | self.system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)
74 | self.num_rows = kwargs["num_rows"]
75 | #self.return_context = kwargs.get("return_context", False)
76 | self.verbose = kwargs.get("verbose", False)
77 | self.parsing_pattern = kwargs["parsing_pattern"]
78 |         self.show_chart = kwargs.get("show_chart", False)
79 |
80 | def query(self, **kwargs):
81 |
82 | df, query, verbose = kwargs["df"], kwargs["query"], kwargs.get("verbose", self.verbose)
83 | show_chart = kwargs.get("show_chart", self.show_chart)
84 |
85 | if len(df) < self.num_rows: dataset = str(df.to_csv())
86 | else: dataset = str(df.sample(self.num_rows, random_state=0).to_csv())
87 |
88 | invoke_args = {
89 | "dataset": dataset,
90 | "question": query
91 | }
92 |
93 | human_prompt = prompt_repo.get_human_prompt(
94 | role="text2chart"
95 | )
96 | human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
97 | prompt = ChatPromptTemplate.from_messages(
98 | [self.system_message_template, human_message_template]
99 | )
100 |
101 | code_generation_chain = prompt | self.llm_text | StrOutputParser()
102 |
103 | self.verbose = verbose
104 | response = code_generation_chain.invoke(
105 | invoke_args,
106 | config={'callbacks': [ConsoleCallbackHandler()]} if self.verbose else {}
107 | )
108 |
109 | if show_chart:
110 | results = self.code_execution(
111 | df=df,
112 | response=response
113 | )
114 | return results
115 | else:
116 | return response
117 |
118 | def code_execution(self, **kwargs):
119 |
120 | df, code = kwargs["df"], self._code_parser(response=kwargs["response"])
121 | tool = PythonAstREPLTool(locals={"df": df})
122 |
123 | results = tool.invoke(code)
124 |
125 | return results
126 |
127 | def _code_parser(self, **kwargs):
128 |
129 | parsed_code, response = "", kwargs["response"]
130 | match = re.search(self.parsing_pattern, response, re.DOTALL)
131 |
132 | if match: parsed_code = match.group(1)
133 | else: print("No match found.")
134 |
135 | return parsed_code
136 |
--------------------------------------------------------------------------------
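A sketch of driving `text2chart_chain` end to end. The Bedrock model id and the fence-parsing regex are assumptions: any LangChain chat model works, and any pattern whose first capture group is the generated code satisfies `_code_parser`:

```python
import pandas as pd
from langchain_aws import ChatBedrock
from utils.text_to_report import prompt_repo, text2chart_chain

chain = text2chart_chain(
    system_prompt=prompt_repo.get_system_prompt(role="text2chart"),
    llm_text=ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0"),  # assumed model id
    num_rows=20,                              # rows sampled into the prompt
    parsing_pattern=r"```python\n(.*?)```",   # assumed: extract the fenced code block
)

df = pd.DataFrame({"month": [1, 2, 3], "sales": [10, 20, 15]})
# show_chart=False returns the raw LLM response; True parses and executes the code.
print(chain.query(df=df, query="월별 매출 추이를 차트로 그려줘", show_chart=False))
```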
/bin/cdk.ts:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env node
2 | import "source-map-support/register";
3 | import { App } from "aws-cdk-lib";
4 | import { EC2Stack } from "../lib/ec2Stack/ec2Stack";
5 | import { OpensearchStack } from "../lib/openSearchStack";
6 | import { CustomResourceStack } from "../lib/customResourceStack";
7 | import { SagemakerNotebookStack } from "../lib/sagemakerNotebookStack/sagemakerNotebookStack";
8 | import { CfnInclude } from 'aws-cdk-lib/cloudformation-include';
9 |
10 | const STACK_PREFIX = "AdvancedRAG";
11 | const DEFAULT_REGION = "us-west-2";
12 | const envSetting = {
13 | env: {
14 | account: process.env.CDK_DEPLOY_ACCOUNT || process.env.CDK_DEFAULT_ACCOUNT,
15 | region: DEFAULT_REGION,
16 | },
17 | };
18 |
19 | const app = new App();
20 |
21 | // Deploy Sagemaker stack
22 | const sagemakerNotebookStack = new SagemakerNotebookStack(app, `${STACK_PREFIX}-SagemakerNotebookStack`, envSetting);
23 |
24 | // Deploy OpenSearch stack
25 | const opensearchStack = new OpensearchStack(app, `${STACK_PREFIX}-OpensearchStack`, envSetting);
26 | opensearchStack.addDependency(sagemakerNotebookStack);
27 |
28 | // Deploy Reranker stack using a CloudFormation template
29 | const rerankerStack = new CfnInclude(opensearchStack, `${STACK_PREFIX}-RerankerStack`, {
30 | templateFile: 'lib/rerankerStack/RerankerStack.template.json'
31 | });
32 |
33 | const customResourceStack = new CustomResourceStack(app, `${STACK_PREFIX}-CustomResourceStack`, envSetting)
34 | customResourceStack.addDependency(opensearchStack);
35 |
36 | // Deploy EC2 stack
37 | const ec2Stack = new EC2Stack(app, `${STACK_PREFIX}-EC2Stack`, envSetting);
38 | ec2Stack.node.addDependency(customResourceStack);
39 |
40 | app.synth();
41 |
--------------------------------------------------------------------------------
/cdk.json:
--------------------------------------------------------------------------------
1 | {
2 | "app": "npx ts-node --prefer-ts-exts bin/cdk.ts",
3 | "watch": {
4 | "include": [
5 | "**"
6 | ],
7 | "exclude": [
8 | "README.md",
9 | "cdk*.json",
10 | "**/*.d.ts",
11 | "**/*.js",
12 | "tsconfig.json",
13 | "package*.json",
14 | "yarn.lock",
15 | "node_modules",
16 | "test"
17 | ]
18 | },
19 | "context": {
20 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true,
21 | "@aws-cdk/core:checkSecretUsage": true,
22 | "@aws-cdk/core:target-partitions": [
23 | "aws",
24 | "aws-cn"
25 | ],
26 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true,
27 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true,
28 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true,
29 | "@aws-cdk/aws-iam:minimizePolicies": true,
30 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true,
31 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true,
32 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true,
33 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true,
34 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true,
35 | "@aws-cdk/core:enablePartitionLiterals": true,
36 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true,
37 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true,
38 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true,
39 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true,
40 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true,
41 | "@aws-cdk/aws-route53-patters:useCertificate": true,
42 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false,
43 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true,
44 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true,
45 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true,
46 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true,
47 | "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true,
48 | "@aws-cdk/aws-redshift:columnId": true,
49 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true,
50 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true,
51 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true,
52 | "@aws-cdk/aws-kms:aliasNameRef": true,
53 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true,
54 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true,
55 | "@aws-cdk/aws-efs:denyAnonymousAccess": true,
56 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true,
57 | "@aws-cdk/aws-lambda-nodejs:useLatestRuntimeVersion": true,
58 | "@aws-cdk/aws-efs:mountTargetOrderInsensitiveLogicalId": true,
59 | "@aws-cdk/aws-rds:auroraClusterChangeScopeOfInstanceParameterGroupWithEachParameters": true,
60 | "@aws-cdk/aws-appsync:useArnForSourceApiAssociationIdentifier": true,
61 | "@aws-cdk/aws-rds:preventRenderingDeprecatedCredentials": true,
62 | "@aws-cdk/aws-codepipeline-actions:useNewDefaultBranchForCodeCommitSource": true,
63 | "@aws-cdk/aws-cloudwatch-actions:changeLambdaPermissionLogicalIdForLambdaAction": true,
64 | "@aws-cdk/aws-codepipeline:crossAccountKeysDefaultValueToFalse": true,
65 | "@aws-cdk/aws-codepipeline:defaultPipelineTypeToV2": true,
66 | "@aws-cdk/aws-kms:reduceCrossAccountRegionPolicyScope": true,
67 | "@aws-cdk/aws-eks:nodegroupNameAttribute": true,
68 | "@aws-cdk/aws-ec2:ebsDefaultGp3Volume": true
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/jest.config.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 | testEnvironment: 'node',
3 |   roots: ['<rootDir>/test'],
4 | testMatch: ['**/*.test.ts'],
5 | transform: {
6 | '^.+\\.tsx?$': 'ts-jest'
7 | }
8 | };
9 |
--------------------------------------------------------------------------------
/lambda/index.py:
--------------------------------------------------------------------------------
1 | import json
2 | import boto3
3 | import logging
4 | import time
5 |
6 | logger = logging.getLogger()
7 | logger.setLevel(logging.INFO)
8 |
9 | opensearch = boto3.client('opensearch')
10 |
11 | def on_event(event, context):
12 | print(event)
13 | request_type = event['RequestType']
14 | if request_type == 'Create': return on_create(event)
15 | if request_type == 'Update': return on_update(event)
16 | if request_type == 'Delete': return on_delete(event)
17 | raise Exception("Invalid request type: %s" % request_type)
18 |
19 | def on_create(event):
20 | props = event["ResourceProperties"]
21 | print("create new resource with props %s" % props)
22 |
23 | domain_arn = props['DomainArn']
24 | domain_name = domain_arn.split('/')[-1]
25 | DEFAULT_REGION = props['DEFAULT_REGION']
26 | VERSION = props['VERSION']
27 |
28 | nori_pkg_id = {
29 | 'us-east-1': {
30 | '2.3': 'G196105221',
31 | '2.5': 'G240285063',
32 | '2.7': 'G16029449',
33 | '2.9': 'G60209291',
34 | '2.11': 'G181660338'
35 | },
36 | 'us-west-2': {
37 | '2.3': 'G94047474',
38 | '2.5': 'G138227316',
39 | '2.7': 'G182407158',
40 | '2.9': 'G226587000',
41 | '2.11': 'G79602591'
42 | }
43 | }
44 |
45 | package_id = nori_pkg_id[DEFAULT_REGION][VERSION]
46 | print(domain_arn, domain_name, package_id)
47 |
48 | try:
49 | response = opensearch.associate_package(
50 | PackageID=package_id,
51 | DomainName=domain_name
52 | )
53 | print(f"Successfully initiated association of package {package_id} with domain {domain_name}")
54 | except opensearch.exceptions.BaseException as e:
55 | logger.error(f"Failed to associate package: {e}")
56 | raise e
57 |
58 | physical_id = f"AssociatePackage-{domain_name}-{package_id}"
59 | return { 'PhysicalResourceId': physical_id }
60 |
61 | def on_update(event):
62 | physical_id = event["PhysicalResourceId"]
63 | props = event["ResourceProperties"]
64 | print("update resource %s with props %s" % (physical_id, props))
65 | return { 'PhysicalResourceId': physical_id }
66 |
67 | def on_delete(event):
68 | physical_id = event["PhysicalResourceId"]
69 | print("delete resource %s" % physical_id)
70 | # Optionally add dissociation logic if required
71 | return { 'PhysicalResourceId': physical_id }
72 |
73 | """
74 | def is_complete(event, context):
75 | props = event["ResourceProperties"]
76 | domain_arn = props['DomainArn']
77 | domain_name = domain_arn.split('/')[-1]
78 | DEFAULT_REGION = props['DEFAULT_REGION']
79 | VERSION = props['VERSION']
80 |
81 | nori_pkg_id = {
82 | 'us-east-1': {
83 | '2.3': 'G196105221',
84 | '2.5': 'G240285063',
85 | '2.7': 'G16029449',
86 | '2.9': 'G60209291',
87 | '2.11': 'G181660338'
88 | },
89 | 'us-west-2': {
90 | '2.3': 'G94047474',
91 | '2.5': 'G138227316',
92 | '2.7': 'G182407158',
93 | '2.9': 'G226587000',
94 | '2.11': 'G79602591'
95 | }
96 | }
97 |
98 | package_id = nori_pkg_id[DEFAULT_REGION][VERSION]
99 | print(f"Checking association status for package {package_id} on domain {domain_name}")
100 |
101 | response = opensearch.list_packages_for_domain(
102 | DomainName=domain_name,
103 | MaxResults=1
104 | )
105 |
106 | if response['DomainPackageDetailsList'][0]['DomainPackageStatus'] == "ACTIVE":
107 | is_ready = True
108 | else:
109 |         is_ready = False
110 |
111 | print(f"Is package {package_id} active on domain {domain_name}? {is_ready}")
112 |
113 | return { 'IsComplete': is_ready }
114 | """
--------------------------------------------------------------------------------
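For clarity, this is roughly the Create event shape the `on_event` handler above receives from CloudFormation; the domain ARN and account id are fabricated:

```python
# Illustrative custom-resource event; values are fabricated except the keys,
# which match what customResourceStack.ts passes as properties.
event = {
    "RequestType": "Create",
    "ResourceProperties": {
        "DomainArn": "arn:aws:es:us-west-2:123456789012:domain/rag-hol-mydomain",  # fabricated
        "DEFAULT_REGION": "us-west-2",
        "VERSION": "2.11",
    },
}
# on_create(event) resolves nori_pkg_id["us-west-2"]["2.11"] -> "G79602591" and calls
# opensearch.associate_package(PackageID="G79602591", DomainName="rag-hol-mydomain").
```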
/lib/customResourceStack.ts:
--------------------------------------------------------------------------------
1 | import * as cdk from 'aws-cdk-lib';
2 | import * as lambda from 'aws-cdk-lib/aws-lambda';
3 | import * as cr from 'aws-cdk-lib/custom-resources';
4 | import * as iam from 'aws-cdk-lib/aws-iam';
5 | import { Construct } from 'constructs';
6 |
7 | interface CustomResourceStackProps extends cdk.StackProps {
8 | env: {
9 | account: string | undefined;
10 | region: string;
11 | };
12 | }
13 |
14 | export class CustomResourceStack extends cdk.Stack {
15 | constructor(scope: Construct, id: string, props: CustomResourceStackProps) {
16 | super(scope, id, props);
17 |
18 | const domainArn = cdk.Fn.importValue('DomainArn');
19 | //const DEFAULT_REGION = this.node.tryGetContext('DEFAULT_REGION');
20 | const DEFAULT_REGION = props.env.region;
21 |
22 | const lambdaRole = new iam.Role(this, 'LambdaRole', {
23 | assumedBy: new iam.ServicePrincipal('lambda.amazonaws.com'),
24 | });
25 |
26 | lambdaRole.addManagedPolicy(
27 | iam.ManagedPolicy.fromAwsManagedPolicyName('service-role/AWSLambdaBasicExecutionRole')
28 | );
29 |
30 | lambdaRole.addToPolicy(new iam.PolicyStatement({
31 | actions: ['es:AssociatePackage', 'es:DescribePackages', 'es:DescribeDomain', 'logs:CreateLogGroup', 'logs:CreateLogStream', 'logs:PutLogEvents'],
32 | resources: ["*"], // domainArn
33 | }));
34 |
35 |
36 | const customResourceLambda = new lambda.Function(this, 'CustomResourceLambda', {
37 | runtime: lambda.Runtime.PYTHON_3_9,
38 | handler: 'index.on_event',
39 | code: lambda.Code.fromAsset('lambda'),
40 | timeout: cdk.Duration.minutes(15),
41 | role: lambdaRole
42 | });
43 |
44 | //const isCompleteLambda = new lambda.Function(this, 'IsCompleteLambda', {
45 | // runtime: lambda.Runtime.PYTHON_3_9,
46 | // handler: 'index.is_complete',
47 | // code: lambda.Code.fromAsset('lambda'),
48 | // timeout: cdk.Duration.minutes(15),
49 | // role: lambdaRole
50 | //});
51 |
52 |
53 | const customResourceProvider = new cr.Provider(this, 'CustomResourceProvider', {
54 | onEventHandler: customResourceLambda,
55 | //isCompleteHandler: isCompleteLambda,
56 | //queryInterval: cdk.Duration.minutes(1),
57 | //totalTimeout: cdk.Duration.hours(1),
58 | });
59 |
60 |
61 | const customResource = new cdk.CustomResource(this, 'AssociateNoriPackage', {
62 | serviceToken: customResourceProvider.serviceToken,
63 | properties: {
64 | DomainArn: domainArn,
65 | DEFAULT_REGION: DEFAULT_REGION,
66 | VERSION: "2.11",
67 | },
68 | });
69 | }
70 | }
--------------------------------------------------------------------------------
/lib/ec2Stack/ec2Stack.ts:
--------------------------------------------------------------------------------
1 | import { Stack, StackProps, RemovalPolicy, aws_s3 as s3, } from 'aws-cdk-lib';
2 | import { Construct } from 'constructs';
3 | import * as cdk from 'aws-cdk-lib';
4 | import * as ec2 from 'aws-cdk-lib/aws-ec2';
5 | import * as iam from 'aws-cdk-lib/aws-iam';
6 | import * as fs from 'fs';
7 | import * as path from 'path';
8 |
9 | export class EC2Stack extends Stack {
10 | constructor(scope: Construct, id: string, props?: StackProps) {
11 | super(scope, id, props);
12 |
13 | // IAM Role to access EC2
14 | const instanceRole = new iam.Role(this, 'InstanceRole', {
15 | assumedBy: new iam.ServicePrincipal('ec2.amazonaws.com'),
16 | managedPolicies: [
17 | iam.ManagedPolicy.fromAwsManagedPolicyName('AdministratorAccess'),
18 | ],
19 | });
20 |
21 | // Network setting for EC2
22 | const defaultVpc = ec2.Vpc.fromLookup(this, 'VPC', {
23 | isDefault: true,
24 | });
25 |
26 | const chatbotAppSecurityGroup = new ec2.SecurityGroup(this, 'chatbotAppSecurityGroup', {
27 | vpc: defaultVpc,
28 | });
29 | chatbotAppSecurityGroup.addIngressRule(
30 | ec2.Peer.anyIpv4(),
31 | ec2.Port.tcp(80),
32 | 'httpIpv4',
33 | );
34 | chatbotAppSecurityGroup.addIngressRule(
35 | ec2.Peer.anyIpv4(),
36 | ec2.Port.tcp(22),
37 | 'sshIpv4',
38 | );
39 |
40 | // set AMI
41 | const machineImage = ec2.MachineImage.fromSsmParameter(
42 | '/aws/service/canonical/ubuntu/server/focal/stable/current/amd64/hvm/ebs-gp2/ami-id'
43 | );
44 |
45 | // set User Data
46 | const userData = ec2.UserData.forLinux();
47 | const userDataScript = fs.readFileSync(path.join(__dirname, 'userdata.sh'), 'utf8');
48 | userData.addCommands(userDataScript);
49 |
50 | // EC2 instance
51 | const chatbotAppInstance = new ec2.Instance(this, 'chatbotAppInstance', {
52 | instanceType: new ec2.InstanceType('m5.large'),
53 | machineImage: machineImage,
54 | vpc: defaultVpc,
55 | securityGroup: chatbotAppSecurityGroup,
56 | role: instanceRole,
57 | userData: userData,
58 | });
59 |
60 | new cdk.CfnOutput(this, 'chatbotAppUrl', {
61 | value: `http://${chatbotAppInstance.instancePublicIp}/`,
62 | description: 'The URL of chatbot instance generated by AWS Advanced RAG Workshop',
63 | exportName: 'chatbotAppUrl',
64 | });
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/lib/ec2Stack/userdata.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Update packages
4 | sudo apt-get update -y
5 |
6 | # Install required packages
7 | sudo apt-get install -y ec2-instance-connect
8 | sudo apt-get install -y git
9 | sudo apt-get install -y python3-pip
10 | sudo apt-get install -y python3-venv
11 |
12 | # Clone repository
13 | cd /home/ubuntu
14 | sudo git clone https://github.com/aws-samples/multi-modal-chatbot-with-advanced-rag.git
15 |
16 | # Create virtual environment
17 | sudo python3 -m venv --copies /home/ubuntu/my_env
18 | sudo chown -R ubuntu:ubuntu /home/ubuntu/my_env
19 | source /home/ubuntu/my_env/bin/activate
20 |
21 | # Install dependencies
22 | cd multi-modal-chatbot-with-advanced-rag/application
23 | sudo apt install -y cargo
24 | pip3 install -r requirements.txt
25 |
26 | # Create systemd service
27 | sudo sh -c "cat <<EOF > /etc/systemd/system/streamlit.service
28 | [Unit]
29 | Description=Streamlit App
30 | After=network.target
31 |
32 | [Service]
33 | User=ubuntu
34 | Environment='AWS_DEFAULT_REGION=us-west-2'
35 | WorkingDirectory=/home/ubuntu/multi-modal-chatbot-with-advanced-rag/application
36 | ExecStartPre=/bin/bash -c 'sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8501'
37 | ExecStart=/bin/bash -c 'source /home/ubuntu/my_env/bin/activate && streamlit run streamlit.py --server.port 8501'
38 | Restart=always
39 |
40 | [Install]
41 | WantedBy=multi-user.target
42 | EOF"
43 |
44 | # Reload systemd daemon and start the service
45 | sudo systemctl daemon-reload
46 | sudo systemctl enable streamlit
47 | sudo systemctl start streamlit
--------------------------------------------------------------------------------
/lib/openSearchStack.ts:
--------------------------------------------------------------------------------
1 | import { Duration, Stack, StackProps, SecretValue } from "aws-cdk-lib";
2 | import { Construct } from "constructs";
3 |
4 | import * as fs from "fs";
5 | import * as cdk from "aws-cdk-lib";
6 | import * as opensearch from "aws-cdk-lib/aws-opensearchservice";
7 | import * as ec2 from "aws-cdk-lib/aws-ec2";
8 | import * as ssm from "aws-cdk-lib/aws-ssm";
9 | import { PolicyStatement, AnyPrincipal } from "aws-cdk-lib/aws-iam";
10 | import * as secretsmanager from "aws-cdk-lib/aws-secretsmanager";
11 |
12 | export class OpensearchStack extends Stack {
13 | constructor(scope: Construct, id: string, props?: StackProps) {
14 | super(scope, id, props);
15 |
16 | const domainName = `rag-hol-mydomain`;
17 |
18 | const opensearch_user_id = "raguser";
19 |
20 | const user_id_pm = new ssm.StringParameter(this, "opensearch_user_id", {
21 | parameterName: "opensearch_user_id",
22 | stringValue: "raguser",
23 | });
24 |
25 | const opensearch_user_password = "pwkey";
26 |
27 | const secret = new secretsmanager.Secret(this, "domain-creds", {
28 | generateSecretString: {
29 | secretStringTemplate: JSON.stringify({
30 | "es.net.http.auth.user": opensearch_user_id,
31 | }),
32 | generateStringKey: opensearch_user_password,
33 | excludeCharacters: '"\'',
34 | },
35 | secretName: "opensearch_user_password",
36 | });
37 |
38 | const domain = new opensearch.Domain(this, "Domain", {
39 | version: opensearch.EngineVersion.OPENSEARCH_2_11,
40 | domainName: domainName,
41 | capacity: {
42 | masterNodes: 2,
43 | multiAzWithStandbyEnabled: false,
44 | },
45 | ebs: {
46 | volumeSize: 100,
47 | volumeType: ec2.EbsDeviceVolumeType.GP3,
48 | enabled: true,
49 | },
50 | enforceHttps: true,
51 | nodeToNodeEncryption: true,
52 | encryptionAtRest: { enabled: true },
53 | fineGrainedAccessControl: {
54 | masterUserName: opensearch_user_id,
55 | masterUserPassword: secret.secretValueFromJson(
56 | opensearch_user_password
57 | ),
58 | },
59 | });
60 |
61 | domain.addAccessPolicies(
62 | new PolicyStatement({
63 | actions: ["es:*"],
64 | principals: [new AnyPrincipal()],
65 | resources: [domain.domainArn + "/*"],
66 | })
67 | );
68 |
69 | const domain_endpoint_pm = new ssm.StringParameter(
70 | this,
71 | "opensearch_domain_endpoint",
72 | {
73 | parameterName: "opensearch_domain_endpoint",
74 | stringValue: domain.domainEndpoint,
75 | }
76 | );
77 |
78 | new cdk.CfnOutput(this, "OpensearchDomainEndpoint", {
79 | value: domain.domainEndpoint,
80 | description: "OpenSearch Domain Endpoint",
81 | });
82 |
83 | new cdk.CfnOutput(this, "parameter store user id", {
84 | value: user_id_pm.parameterArn,
85 | description: "parameter store user id",
86 | });
87 |
88 | new cdk.CfnOutput(this, "secrets manager user pw", {
89 | value: secret.secretName,
90 | description: "secrets manager user pw",
91 | });
92 |
93 | new cdk.CfnOutput(this, 'DomainArn', {
94 | value: domain.domainArn,
95 | exportName: 'DomainArn'
96 | });
97 |
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
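The application side can read back what this stack provisions; a boto3 sketch (parameter and secret names match the stack above, the region is an assumption):

```python
import json
import boto3

ssm = boto3.client("ssm", region_name="us-west-2")
endpoint = ssm.get_parameter(Name="opensearch_domain_endpoint")["Parameter"]["Value"]
user = ssm.get_parameter(Name="opensearch_user_id")["Parameter"]["Value"]

sm = boto3.client("secretsmanager", region_name="us-west-2")
creds = json.loads(sm.get_secret_value(SecretId="opensearch_user_password")["SecretString"])
password = creds["pwkey"]  # "pwkey" is the generateStringKey used in the stack above
```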
/lib/rerankerStack/RerankerStack.assets.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "36.0.0",
3 | "files": {
4 | "6eee82909dec1440393d6ae4b11ef1c1191320d48cb02287c623f05e28bf7caa": {
5 | "source": {
6 | "path": "RerankerStack.template.json",
7 | "packaging": "file"
8 | },
9 | "destinations": {
10 | "current_account-current_region": {
11 | "bucketName": "cdk-hnb659fds-assets-${AWS::AccountId}-${AWS::Region}",
12 | "objectKey": "6eee82909dec1440393d6ae4b11ef1c1191320d48cb02287c623f05e28bf7caa.json",
13 | "assumeRoleArn": "arn:${AWS::Partition}:iam::${AWS::AccountId}:role/cdk-hnb659fds-file-publishing-role-${AWS::AccountId}-${AWS::Region}"
14 | }
15 | }
16 | }
17 | },
18 | "dockerImages": {}
19 | }
--------------------------------------------------------------------------------
/lib/rerankerStack/RerankerStack.template.json:
--------------------------------------------------------------------------------
1 | {
2 | "Description": "Description: (uksb-1tupboc45) (version:0.1.198) (tag:C1:0,C2:0,C3:0,C4:0,C5:0,C6:1,C7:0,C8:0) ",
3 | "Resources": {
4 | "RerankerRoleDAAED19A": {
5 | "Type": "AWS::IAM::Role",
6 | "Properties": {
7 | "AssumeRolePolicyDocument": {
8 | "Statement": [
9 | {
10 | "Action": "sts:AssumeRole",
11 | "Effect": "Allow",
12 | "Principal": {
13 | "Service": "sagemaker.amazonaws.com"
14 | }
15 | }
16 | ],
17 | "Version": "2012-10-17"
18 | },
19 | "ManagedPolicyArns": [
20 | {
21 | "Fn::Join": [
22 | "",
23 | [
24 | "arn:",
25 | {
26 | "Ref": "AWS::Partition"
27 | },
28 | ":iam::aws:policy/AmazonSageMakerFullAccess"
29 | ]
30 | ]
31 | }
32 | ]
33 | },
34 | "Metadata": {
35 | "aws:cdk:path": "RerankerStack/Reranker/Role/Resource"
36 | }
37 | },
38 | "RerankerRoleDefaultPolicy6BB7CA84": {
39 | "Type": "AWS::IAM::Policy",
40 | "Properties": {
41 | "PolicyDocument": {
42 | "Statement": [
43 | {
44 | "Action": [
45 | "ecr:BatchCheckLayerAvailability",
46 | "ecr:BatchGetImage",
47 | "ecr:GetDownloadUrlForLayer"
48 | ],
49 | "Effect": "Allow",
50 | "Resource": {
51 | "Fn::Join": [
52 | "",
53 | [
54 | "arn:",
55 | {
56 | "Ref": "AWS::Partition"
57 | },
58 | ":ecr:",
59 | {
60 | "Ref": "AWS::Region"
61 | },
62 | ":",
63 | {
64 | "Fn::FindInMap": [
65 | "DlcRepositoryAccountMap",
66 | {
67 | "Ref": "AWS::Region"
68 | },
69 | "value"
70 | ]
71 | },
72 | ":repository/huggingface-pytorch-inference"
73 | ]
74 | ]
75 | }
76 | },
77 | {
78 | "Action": "ecr:GetAuthorizationToken",
79 | "Effect": "Allow",
80 | "Resource": "*"
81 | }
82 | ],
83 | "Version": "2012-10-17"
84 | },
85 | "PolicyName": "RerankerRoleDefaultPolicy6BB7CA84",
86 | "Roles": [
87 | {
88 | "Ref": "RerankerRoleDAAED19A"
89 | }
90 | ]
91 | },
92 | "Metadata": {
93 | "aws:cdk:path": "RerankerStack/Reranker/Role/DefaultPolicy/Resource"
94 | }
95 | },
96 | "DongjinkrkorerankermodelReranker": {
97 | "Type": "AWS::SageMaker::Model",
98 | "Properties": {
99 | "ExecutionRoleArn": {
100 | "Fn::GetAtt": [
101 | "RerankerRoleDAAED19A",
102 | "Arn"
103 | ]
104 | },
105 | "PrimaryContainer": {
106 | "Environment": {
107 | "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
108 | "SAGEMAKER_REGION": {
109 | "Ref": "AWS::Region"
110 | },
111 | "HF_MODEL_ID": "Dongjin-kr/ko-reranker",
112 | "HF_TASK": "text-classification"
113 | },
114 | "Image": {
115 | "Fn::Join": [
116 | "",
117 | [
118 | {
119 | "Fn::FindInMap": [
120 | "DlcRepositoryAccountMap",
121 | {
122 | "Ref": "AWS::Region"
123 | },
124 | "value"
125 | ]
126 | },
127 | ".dkr.ecr.",
128 | {
129 | "Ref": "AWS::Region"
130 | },
131 | ".",
132 | {
133 | "Ref": "AWS::URLSuffix"
134 | },
135 | "/huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"
136 | ]
137 | ]
138 | },
139 | "Mode": "SingleModel"
140 | },
141 | "Tags": [
142 | {
143 | "Key": "modelId",
144 | "Value": "Dongjin-kr/ko-reranker"
145 | }
146 | ]
147 | },
148 | "Metadata": {
149 | "aws:cdk:path": "RerankerStack/Dongjin-kr-ko-reranker-model-Reranker"
150 | }
151 | },
152 | "EndpointConfigReranker": {
153 | "Type": "AWS::SageMaker::EndpointConfig",
154 | "Properties": {
155 | "ProductionVariants": [
156 | {
157 | "ContainerStartupHealthCheckTimeoutInSeconds": 600,
158 | "InitialInstanceCount": 1,
159 | "InitialVariantWeight": 1,
160 | "InstanceType": "ml.g5.xlarge",
161 | "ModelName": {
162 | "Fn::GetAtt": [
163 | "DongjinkrkorerankermodelReranker",
164 | "ModelName"
165 | ]
166 | },
167 | "VariantName": "AllTraffic"
168 | }
169 | ]
170 | },
171 | "DependsOn": [
172 | "DongjinkrkorerankermodelReranker"
173 | ],
174 | "Metadata": {
175 | "aws:cdk:path": "RerankerStack/EndpointConfig-Reranker"
176 | }
177 | },
178 | "DongjinkrkorerankerendpointReranker": {
179 | "Type": "AWS::SageMaker::Endpoint",
180 | "Properties": {
181 | "EndpointConfigName": {
182 | "Fn::GetAtt": [
183 | "EndpointConfigReranker",
184 | "EndpointConfigName"
185 | ]
186 | },
187 | "EndpointName": "reranker",
188 | "Tags": [
189 | {
190 | "Key": "modelId",
191 | "Value": "Dongjin-kr/ko-reranker"
192 | }
193 | ]
194 | },
195 | "DependsOn": [
196 | "EndpointConfigReranker"
197 | ],
198 | "Metadata": {
199 | "aws:cdk:path": "RerankerStack/Dongjin-kr-ko-reranker-endpoint-Reranker"
200 | }
201 | }
202 | },
203 | "Mappings": {
204 | "DlcRepositoryAccountMap": {
205 | "ap-east-1": {
206 | "value": "871362719292"
207 | },
208 | "ap-northeast-1": {
209 | "value": "763104351884"
210 | },
211 | "ap-northeast-2": {
212 | "value": "763104351884"
213 | },
214 | "ap-south-1": {
215 | "value": "763104351884"
216 | },
217 | "ap-south-2": {
218 | "value": "772153158452"
219 | },
220 | "ap-southeast-1": {
221 | "value": "763104351884"
222 | },
223 | "ap-southeast-2": {
224 | "value": "763104351884"
225 | },
226 | "ap-southeast-3": {
227 | "value": "907027046896"
228 | },
229 | "ap-southeast-4": {
230 | "value": "457447274322"
231 | },
232 | "ca-central-1": {
233 | "value": "763104351884"
234 | },
235 | "cn-north-1": {
236 | "value": "727897471807"
237 | },
238 | "cn-northwest-1": {
239 | "value": "727897471807"
240 | },
241 | "eu-central-1": {
242 | "value": "763104351884"
243 | },
244 | "eu-central-2": {
245 | "value": "380420809688"
246 | },
247 | "eu-north-1": {
248 | "value": "763104351884"
249 | },
250 | "eu-south-1": {
251 | "value": "692866216735"
252 | },
253 | "eu-south-2": {
254 | "value": "503227376785"
255 | },
256 | "eu-west-1": {
257 | "value": "763104351884"
258 | },
259 | "eu-west-2": {
260 | "value": "763104351884"
261 | },
262 | "eu-west-3": {
263 | "value": "763104351884"
264 | },
265 | "me-central-1": {
266 | "value": "914824155844"
267 | },
268 | "me-south-1": {
269 | "value": "217643126080"
270 | },
271 | "sa-east-1": {
272 | "value": "763104351884"
273 | },
274 | "us-east-1": {
275 | "value": "763104351884"
276 | },
277 | "us-east-2": {
278 | "value": "763104351884"
279 | },
280 | "us-west-1": {
281 | "value": "763104351884"
282 | },
283 | "us-west-2": {
284 | "value": "763104351884"
285 | }
286 | }
287 | },
288 | "Conditions": {
289 | "CDKMetadataAvailable": {
290 | "Fn::Or": [
291 | {
292 | "Fn::Or": [
293 | {
294 | "Fn::Equals": [
295 | {
296 | "Ref": "AWS::Region"
297 | },
298 | "af-south-1"
299 | ]
300 | },
301 | {
302 | "Fn::Equals": [
303 | {
304 | "Ref": "AWS::Region"
305 | },
306 | "ap-east-1"
307 | ]
308 | },
309 | {
310 | "Fn::Equals": [
311 | {
312 | "Ref": "AWS::Region"
313 | },
314 | "ap-northeast-1"
315 | ]
316 | },
317 | {
318 | "Fn::Equals": [
319 | {
320 | "Ref": "AWS::Region"
321 | },
322 | "ap-northeast-2"
323 | ]
324 | },
325 | {
326 | "Fn::Equals": [
327 | {
328 | "Ref": "AWS::Region"
329 | },
330 | "ap-south-1"
331 | ]
332 | },
333 | {
334 | "Fn::Equals": [
335 | {
336 | "Ref": "AWS::Region"
337 | },
338 | "ap-southeast-1"
339 | ]
340 | },
341 | {
342 | "Fn::Equals": [
343 | {
344 | "Ref": "AWS::Region"
345 | },
346 | "ap-southeast-2"
347 | ]
348 | },
349 | {
350 | "Fn::Equals": [
351 | {
352 | "Ref": "AWS::Region"
353 | },
354 | "ca-central-1"
355 | ]
356 | },
357 | {
358 | "Fn::Equals": [
359 | {
360 | "Ref": "AWS::Region"
361 | },
362 | "cn-north-1"
363 | ]
364 | },
365 | {
366 | "Fn::Equals": [
367 | {
368 | "Ref": "AWS::Region"
369 | },
370 | "cn-northwest-1"
371 | ]
372 | }
373 | ]
374 | },
375 | {
376 | "Fn::Or": [
377 | {
378 | "Fn::Equals": [
379 | {
380 | "Ref": "AWS::Region"
381 | },
382 | "eu-central-1"
383 | ]
384 | },
385 | {
386 | "Fn::Equals": [
387 | {
388 | "Ref": "AWS::Region"
389 | },
390 | "eu-north-1"
391 | ]
392 | },
393 | {
394 | "Fn::Equals": [
395 | {
396 | "Ref": "AWS::Region"
397 | },
398 | "eu-south-1"
399 | ]
400 | },
401 | {
402 | "Fn::Equals": [
403 | {
404 | "Ref": "AWS::Region"
405 | },
406 | "eu-west-1"
407 | ]
408 | },
409 | {
410 | "Fn::Equals": [
411 | {
412 | "Ref": "AWS::Region"
413 | },
414 | "eu-west-2"
415 | ]
416 | },
417 | {
418 | "Fn::Equals": [
419 | {
420 | "Ref": "AWS::Region"
421 | },
422 | "eu-west-3"
423 | ]
424 | },
425 | {
426 | "Fn::Equals": [
427 | {
428 | "Ref": "AWS::Region"
429 | },
430 | "il-central-1"
431 | ]
432 | },
433 | {
434 | "Fn::Equals": [
435 | {
436 | "Ref": "AWS::Region"
437 | },
438 | "me-central-1"
439 | ]
440 | },
441 | {
442 | "Fn::Equals": [
443 | {
444 | "Ref": "AWS::Region"
445 | },
446 | "me-south-1"
447 | ]
448 | },
449 | {
450 | "Fn::Equals": [
451 | {
452 | "Ref": "AWS::Region"
453 | },
454 | "sa-east-1"
455 | ]
456 | }
457 | ]
458 | },
459 | {
460 | "Fn::Or": [
461 | {
462 | "Fn::Equals": [
463 | {
464 | "Ref": "AWS::Region"
465 | },
466 | "us-east-1"
467 | ]
468 | },
469 | {
470 | "Fn::Equals": [
471 | {
472 | "Ref": "AWS::Region"
473 | },
474 | "us-east-2"
475 | ]
476 | },
477 | {
478 | "Fn::Equals": [
479 | {
480 | "Ref": "AWS::Region"
481 | },
482 | "us-west-1"
483 | ]
484 | },
485 | {
486 | "Fn::Equals": [
487 | {
488 | "Ref": "AWS::Region"
489 | },
490 | "us-west-2"
491 | ]
492 | }
493 | ]
494 | }
495 | ]
496 | }
497 | }
498 | }
--------------------------------------------------------------------------------
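Once the template above is deployed, the "reranker" endpoint can be invoked through the SageMaker runtime. The query/passage payload schema below follows the Hugging Face text-classification convention for cross-encoders, which is an assumption here rather than something this template defines:

```python
import json
import boto3

smr = boto3.client("sagemaker-runtime", region_name="us-west-2")
payload = {"inputs": [{"text": "질문", "text_pair": "후보 문단"}]}  # assumed schema

resp = smr.invoke_endpoint(
    EndpointName="reranker",              # EndpointName from the template above
    ContentType="application/json",
    Body=json.dumps(payload),
)
print(json.loads(resp["Body"].read()))    # e.g. [{"label": ..., "score": ...}]
```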
/lib/sagemakerNotebookStack/install_packages.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 |
6 | sudo -u ec2-user -i <<'EOF'
7 |
8 | #source /home/ec2-user/anaconda3/bin/deactivate
9 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pip
10 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U awscli==1.33.16
11 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U botocore==1.34.134
12 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U boto3==1.34.134
13 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U sagemaker==2.224.1
14 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U langchain==0.2.6
15 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U langchain-community==0.2.6
16 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U langchain_aws==0.1.8
17 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U termcolor==2.4.0
18 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U transformers==4.41.2
19 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U librosa==0.10.2.post1
20 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U opensearch-py==2.6.0
21 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U sqlalchemy #==2.0.1
22 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pypdf==4.2.0
23 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U ipython==8.25.0
24 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U ipywidgets==8.1.3
25 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U anthropic==0.30.0
26 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U faiss-cpu==1.8.0.post1
27 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U jq==1.7.0
28 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pydantic==2.7.4
29 |
30 | sudo rpm -Uvh https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm
31 | sudo yum -y update
32 | sudo yum install -y poppler-utils
33 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U lxml==5.2.2
34 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U kaleido==0.2.1
35 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U uvicorn==0.30.1
36 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pandas==2.2.2
37 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U numexpr==2.10.1
38 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pdf2image==1.17.0
39 |
40 | sudo amazon-linux-extras install libreoffice -y
41 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U "unstructured[all-docs]==0.13.2"
42 |
43 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U python-dotenv==1.0.1
44 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U llama-parse==0.4.4
45 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U pymupdf==1.24.7
46 |
47 | # Uninstall nltk 3.8.2 and install 3.8.1
48 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip uninstall -y nltk
49 | /home/ec2-user/anaconda3/envs/python3/bin/python -m pip install -U nltk==3.8.1
50 |
51 | EOF
52 |
--------------------------------------------------------------------------------
/lib/sagemakerNotebookStack/install_tesseract.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | yum-config-manager --disable centos-extras
5 |
6 | nohup bash <<'EOF' &
7 |
8 | echo "## Step 1"
9 | yum -y update
10 | yum -y upgrade
11 | yum install clang -y
12 | yum install libpng-devel libtiff-devel zlib-devel libwebp-devel libjpeg-turbo-devel wget tar gzip -y
13 | wget https://github.com/DanBloomberg/leptonica/releases/download/1.84.1/leptonica-1.84.1.tar.gz
14 | tar -zxvf leptonica-1.84.1.tar.gz
15 | cd leptonica-1.84.1
16 | ./configure
17 | make
18 | make install
19 |
20 | echo "## Step 2"
21 | cd ~
22 | yum install git-core libtool pkgconfig -y
23 | wget https://github.com/tesseract-ocr/tesseract/archive/5.3.1.tar.gz
24 | tar xzvf 5.3.1.tar.gz
25 | cd tesseract-5.3.1
26 | #git clone --depth 1 https://github.com/tesseract-ocr/tesseract.git tesseract-ocr
27 | #cd tesseract-ocr
28 | export PKG_CONFIG_PATH=/usr/local/lib/pkgconfig
29 | ./autogen.sh
30 | ./configure
31 | make
32 | make install
33 | ldconfig
34 |
35 | echo "## Step 3"
36 | cd /usr/local/share/tessdata
37 | wget https://github.com/tesseract-ocr/tessdata/raw/main/osd.traineddata
38 | wget https://github.com/tesseract-ocr/tessdata/raw/main/eng.traineddata
39 | wget https://github.com/tesseract-ocr/tessdata/raw/main/hin.traineddata
40 | wget https://github.com/tesseract-ocr/tessdata/raw/main/kor.traineddata
41 | wget https://github.com/tesseract-ocr/tessdata/raw/main/kor_vert.traineddata
42 | #wget https://github.com/tesseract-ocr/tessdata_best/raw/main/kor.traineddata
43 | #wget https://github.com/tesseract-ocr/tessdata_best/raw/main/kor_vert.traineddata
44 |
45 | echo "## Step 4"
46 | echo "export TESSDATA_PREFIX=/usr/local/share/tessdata" >> ~/.bash_profile
47 | echo "export TESSDATA_PREFIX=/usr/local/share/tessdata" >> /home/ec2-user/.bash_profile
48 |
49 | EOF
--------------------------------------------------------------------------------
/lib/sagemakerNotebookStack/sagemakerNotebookStack.ts:
--------------------------------------------------------------------------------
1 | import * as cdk from 'aws-cdk-lib';
2 | import { Construct } from 'constructs';
3 | import * as iam from 'aws-cdk-lib/aws-iam';
4 | import * as sagemaker from 'aws-cdk-lib/aws-sagemaker';
5 | import * as fs from 'fs';
6 | import * as path from 'path';
7 |
8 |
9 | export class SagemakerNotebookStack extends cdk.Stack {
10 | constructor(scope: Construct, id: string, props?: cdk.StackProps) {
11 | super(scope, id, props);
12 |
13 | // The code that defines your stack goes here
14 |
15 | // IAM Role
16 | const SageMakerNotebookinstanceRole = new iam.Role(this, 'SageMakerNotebookInstanceRole', {
17 | assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
18 | managedPolicies: [
19 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonBedrockFullAccess'),
20 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonOpenSearchServiceFullAccess'),
21 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'),
22 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSSMFullAccess'),
23 | iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonS3FullAccess'),
24 | iam.ManagedPolicy.fromAwsManagedPolicyName('SecretsManagerReadWrite')
25 | ],
26 | });
27 |
28 |
29 | // SageMaker Notebook Instance Lifecycle Configuration
30 |
31 | const onCreateScriptPath1 = path.join(__dirname, 'install_packages.sh')
32 | const onCreateScriptPath2 = path.join(__dirname, 'install_tesseract.sh')
33 | const onCreateScriptContent1 = fs.readFileSync(onCreateScriptPath1, 'utf-8')
34 | const onCreateScriptContent2 = fs.readFileSync(onCreateScriptPath2, 'utf-8')
35 |
36 | const combinedScriptContent = `${onCreateScriptContent1}\n${onCreateScriptContent2}`;
37 |
38 | const cfnNotebookInstanceLifecycleConfig = new sagemaker.CfnNotebookInstanceLifecycleConfig(this, 'MyCfnNotebookInstanceLifecycleConfig', /* all optional props */ {
39 | notebookInstanceLifecycleConfigName: 'notebookInstanceLifecycleConfig',
40 | onCreate: [{
41 | content: cdk.Fn.base64(combinedScriptContent),
42 | }],
43 | onStart: [],
44 | });
45 |
46 |
47 | // SageMaker Notebook Instance
48 |
49 | const cfnNotebookInstance = new sagemaker.CfnNotebookInstance(this, 'MyCfnNotebookInstance', {
50 | instanceType: 'ml.m5.xlarge',
51 | roleArn: SageMakerNotebookinstanceRole.roleArn,
52 |
53 | // the properties below are optional
54 | //acceleratorTypes: ['acceleratorTypes'],
55 | //additionalCodeRepositories: ['additionalCodeRepositories'],
56 | defaultCodeRepository: 'https://github.com/Jiyu-Kim/advanced-rag-workshop.git',
57 | directInternetAccess: 'Enabled',
58 | //instanceMetadataServiceConfiguration: {
59 | // minimumInstanceMetadataServiceVersion: 'minimumInstanceMetadataServiceVersion',
60 | //},
61 | //kmsKeyId: 'kmsKeyId',
62 | lifecycleConfigName: 'notebookInstanceLifecycleConfig',
63 | notebookInstanceName: 'advanced-rag-workshop-notebook-instance',
64 | //platformIdentifier: 'platformIdentifier',
65 | //rootAccess: 'rootAccess',
66 | //securityGroupIds: ['securityGroupIds'],
67 | //subnetId: 'subnetId',
68 | //tags: [{
69 | // key: 'key',
70 | // value: 'value',
71 | //}],
72 | volumeSizeInGb: 10,
73 | });
74 |
75 |
76 | }
77 | }
78 |
79 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "cdk",
3 | "version": "0.1.0",
4 | "bin": {
5 | "cdk": "bin/cdk.js"
6 | },
7 | "scripts": {
8 | "build": "tsc",
9 | "watch": "tsc -w",
10 | "test": "jest",
11 | "cdk": "cdk"
12 | },
13 | "devDependencies": {
14 | "@types/jest": "^29.5.12",
15 | "@types/node": "^20.17.1",
16 | "aws-cdk": "2.91.0",
17 | "jest": "^29.7.0",
18 | "ts-jest": "^29.1.2",
19 | "ts-node": "^10.9.2",
20 | "typescript": "~5.4.5"
21 | },
22 | "dependencies": {
23 | "aws-cdk-lib": "^2.91.0",
24 | "constructs": "^10.0.0",
25 | "source-map-support": "^0.5.21"
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2020",
4 | "module": "commonjs",
5 | "lib": [
6 | "es2020",
7 | "dom"
8 | ],
9 | "declaration": true,
10 | "strict": true,
11 | "noImplicitAny": true,
12 | "strictNullChecks": true,
13 | "noImplicitThis": true,
14 | "alwaysStrict": true,
15 | "noUnusedLocals": false,
16 | "noUnusedParameters": false,
17 | "noImplicitReturns": true,
18 | "noFallthroughCasesInSwitch": false,
19 | "inlineSourceMap": true,
20 | "inlineSources": true,
21 | "experimentalDecorators": true,
22 | "strictPropertyInitialization": false,
23 | "typeRoots": [
24 | "./node_modules/@types"
25 | ]
26 | },
27 | "exclude": [
28 | "node_modules",
29 | "cdk.out"
30 | ]
31 | }
32 |
--------------------------------------------------------------------------------