├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── img
│   ├── gen-ai-demo-architecture.png
│   └── streamlit-web-ui.png
└── lab
    ├── app.py
    ├── lab-code.ipynb
    └── utils
        ├── knowledge_base.py
        └── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .ipynb_checkpoints
3 | __pycache__
4 | **.pyc
5 |
6 | .aws-sam
7 | **tmp*
8 |
9 | # Data files (same convention across exercises)
10 | data/
11 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 |
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Build Generative AI applications with Bedrock Knowledge Base and DeepSeek R1 Models
2 |
3 | This project demonstrates how to build a multi-functional chatbot that leverages AWS Bedrock Knowledge Base and multiple AI models including DeepSeek R1, Amazon Titan Nova Pro, and Claude 3.5 Sonnet.
4 |
5 |
6 | ### 1. Architecture Overview
7 |
8 | 
9 |
10 |
11 | ### 2. Application Components
12 |
13 | **2.1 Knowledge Base System**
14 | - AWS Bedrock Knowledge Base for document storage and retrieval
15 | - Amazon Titan Text Embeddings V2 for text vectorization
16 | - OpenSearch Serverless for vector search capabilities
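At query time the app retrieves relevant passages from the knowledge base before generation. A minimal sketch of the retrieval request, mirroring the `retrieve()` call in `lab/app.py` (the knowledge base ID and query here are placeholders):

```python
def build_retrieve_params(kb_id: str, query: str, k: int = 3) -> dict:
    """Keyword arguments for the Bedrock Agent Runtime retrieve() call."""
    return {
        "knowledgeBaseId": kb_id,
        "retrievalQuery": {"text": query},
        "retrievalConfiguration": {
            "vectorSearchConfiguration": {"numberOfResults": k}
        },
    }

# Usage (requires AWS credentials and a real knowledge base ID):
# import boto3
# client = boto3.client("bedrock-agent-runtime")
# response = client.retrieve(**build_retrieve_params("<kb-id>", "What is entropy?"))
# passages = [r["content"]["text"] for r in response["retrievalResults"]]
```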
17 |
18 | **2.2 Generation Models**
19 | - Amazon Titan Nova Pro - Supports text, images, files and knowledge base integration
20 | - Claude 3.5 Sonnet - Supports text, images, files and knowledge base integration
21 | - DeepSeek R1 Distill Qwen 1.5B - Supports text, files and knowledge base integration
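Each model expects a different request payload: Claude uses the Anthropic messages format, Nova uses the `messages-v1` schema, and the DeepSeek SageMaker endpoint takes a plain `inputs`/`parameters` body. A simplified sketch of the payload shapes, following the `generate_request()` logic in `lab/app.py`:

```python
def build_request(model_type: str, query: str,
                  max_new_tokens: int = 1000, temperature: float = 0.7) -> dict:
    """Model-specific request body (text-only), per lab/app.py's generate_request()."""
    if model_type == "claude":
        return {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "messages": [{"role": "user",
                          "content": [{"type": "text", "text": query}]}],
        }
    if model_type == "nova":
        return {
            "schemaVersion": "messages-v1",
            "messages": [{"role": "user", "content": [{"text": query}]}],
            "inferenceConfig": {"max_new_tokens": max_new_tokens,
                                "temperature": temperature},
        }
    # DeepSeek hosted on a SageMaker endpoint
    return {"inputs": query,
            "parameters": {"max_new_tokens": max_new_tokens,
                           "temperature": temperature}}
```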
22 |
23 | **2.3 Web Interface**
24 | - Built with Streamlit
25 | - Supports file uploads (images, PDFs, text files, etc)
26 | - Interactive query interface
27 | - Model parameter controls
28 |
29 | ### 3. Application Features
30 |
31 | **3.1 Multi-Modal Input Support**
32 | - Text queries
33 | - Image uploads with automatic compression
34 | - Document uploads (PDF, TXT, DOCX, JSON, etc)
35 | - Preview capabilities
36 |
37 |
38 | **3.2 Knowledge Base Integration**
39 | - Document ingestion and chunking
40 | - Vector embeddings generation
41 | - Semantic search capabilities
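Before embedding, documents are split into overlapping chunks; Bedrock Knowledge Bases does this automatically during ingestion. A simplified illustration of fixed-size chunking with overlap (the sizes here are arbitrary examples, not the service defaults):

```python
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> list[str]:
    """Split text into word chunks of chunk_size that overlap by `overlap` words."""
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):  # last chunk reached the end
            break
    return chunks
```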
42 |
43 | **3.3 Model Selection**
44 | - Choice between different AI models
45 | - Configurable model parameters
46 | - Model-specific optimizations
47 |
48 |
49 | ### 4. Setup and Usage
50 |
51 | *This demo project was developed with Amazon SageMaker AI Studio: https://docs.aws.amazon.com/sagemaker/latest/dg/studio-updated-jl.html*
52 |
53 | 4.1. If you don't use Amazon SageMaker AI Studio yet, please follow the [guidance steps](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-updated-jl-user-guide.html) to get set up.
54 |
55 | 4.2. Clone this repo or upload the code to the JupyterLab notebook, then start with `lab/lab-code.ipynb`.
56 |
57 | 4.3. Run the notebook cells to initialize all the required resources in your AWS account.
58 |
59 | 4.4. Once the resources are initialized, execute the following steps to access the Streamlit web app:
60 |
61 | ```bash
62 | # Start a terminal session in JupyterLab, then run:
63 | 
64 | pip install streamlit
65 | 
66 | streamlit run lab/app.py
67 | 
68 | # To access the Streamlit web application:
69 | # 1. Copy the URL of your SageMaker Studio JupyterLab tab, e.g.:
70 | #    https://xxxxxxxxxxxxx.studio.us-west-2.sagemaker.aws/jupyterlab/default/lab/.../lab-code.ipynb
71 | 
72 | # 2. Change the path to the format below and open it in a new browser tab:
73 | #    https://xxxxxxxxxxxxx.studio.us-west-2.sagemaker.aws/jupyterlab/default/proxy/8501/
74 | 
75 | ```
76 | 4.5. Clean up by executing the last cell of the notebook.
77 |
78 |
79 | ### 5. Test the Application
80 |
81 | 5.1. To verify that the knowledge base is working as expected, use any model in the app with the query below:
82 | ```
83 | Please find the next half sentence of the below sentence:
84 |
85 | In macroscopic closed systems, nonconservative forces act to change the internal energies of the system
86 | ```
87 | Expected answer should contain:
88 | ```
89 | and are often associated with the irreversible transformation of energy into heat.
90 | ```
91 |
92 |
93 | 5.2. You can disable `Use Knowledge Base` and then test general Q&A with the different models.
94 |
95 | 5.3. You can upload images and/or files to test the Application.
96 |
97 | ```
98 | Please be aware:
99 | - When any attachment is uploaded, the app disables the KB by default.
100 | - DeepSeek is a text-based model and currently does not support image input.
101 | ```
102 |
103 | 5.4. Streamlit Web UI Screenshot
104 |
105 | 
106 |
107 |
108 |
109 | ## License
110 | This library is licensed under the MIT-0 License. See the LICENSE file.
111 |
112 | ## Security
113 |
114 | See [CONTRIBUTING](https://github.com/aws-samples/generative-ai-workshop-build-a-multifunctional-chatbot-on-sagemaker/blob/main/CONTRIBUTING.md) for more information.
115 |
--------------------------------------------------------------------------------
/img/gen-ai-demo-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/generative-ai-workshop-build-a-multifunctional-chatbot-on-sagemaker/54b856b2a67d88b7f656a1bd23fcb80df8e3774c/img/gen-ai-demo-architecture.png
--------------------------------------------------------------------------------
/img/streamlit-web-ui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/generative-ai-workshop-build-a-multifunctional-chatbot-on-sagemaker/54b856b2a67d88b7f656a1bd23fcb80df8e3774c/img/streamlit-web-ui.png
--------------------------------------------------------------------------------
/lab/app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import boto3
3 | import json
4 | import os
5 | import base64
6 | from PIL import Image
7 | import io
8 | from datetime import datetime
9 |
10 | # Default compression settings
11 | DEFAULT_MAX_SIZE_KB = 2048
12 | DEFAULT_MAX_DIMENSION = 800
13 |
14 | class ImageProcessor:
15 | """Handle image processing operations"""
16 | @staticmethod
17 | def compress_image(image_bytes, max_size_kb=DEFAULT_MAX_SIZE_KB, max_dimension=DEFAULT_MAX_DIMENSION):
18 | """Compress image to specified size"""
19 | img = Image.open(io.BytesIO(image_bytes))
20 |
21 | # Convert to RGB if needed
22 | if img.mode in ('RGBA', 'P'):
23 | img = img.convert('RGB')
24 |
25 | # Resize image
26 | img.thumbnail((max_dimension, max_dimension), Image.Resampling.LANCZOS)
27 |
28 | # Compress with quality adjustment
29 | output = io.BytesIO()
30 | quality = 95
31 | while quality > 5:
32 | output.seek(0)
33 | output.truncate()
34 | img.save(output, format='JPEG', quality=quality, optimize=True)
35 | if len(output.getvalue()) <= max_size_kb * 1024:
36 | break
37 | quality -= 5
38 |
39 | return output.getvalue()
40 |
41 | class FileProcessor:
42 | """Handle different types of file processing"""
43 |
44 | SUPPORTED_IMAGE_TYPES = ['png', 'jpg', 'jpeg']
45 | SUPPORTED_TEXT_TYPES = ['txt', 'pdf', 'docx', 'json']
46 |
47 | @staticmethod
48 | def process_file(file, file_type):
49 | """Process different types of files"""
50 | if file_type in FileProcessor.SUPPORTED_IMAGE_TYPES:
51 | return FileProcessor.process_image(file)
52 | elif file_type in FileProcessor.SUPPORTED_TEXT_TYPES:
53 | return FileProcessor.process_text(file)
54 | else:
55 | raise ValueError(f"Unsupported file type: {file_type}")
56 |
57 | @staticmethod
58 | def process_image(file):
59 | """Process image files"""
60 | return {
61 | 'type': 'image',
62 | 'content': ImageProcessor.compress_image(file.getvalue()),
63 | 'original_name': file.name
64 | }
65 |
66 | @staticmethod
67 | def process_text(file):
68 | """Process text files"""
69 | if file.name.endswith('.pdf'):
70 | # Add PDF processing logic here
71 | import PyPDF2
72 | pdf_reader = PyPDF2.PdfReader(file)
73 | text = ""
74 | for page in pdf_reader.pages:
75 | text += page.extract_text()
76 |
77 | elif file.name.endswith('.docx'):
78 | # Add DOCX processing logic here
79 | from docx import Document
80 | document = Document(file)
81 | text = "\n".join([paragraph.text for paragraph in document.paragraphs])
82 |
83 | elif file.name.endswith('.json'):
84 | # Add JSON processing logic here
85 | text = json.load(file)
86 | text = json.dumps(text, indent=2)
87 |
88 | else: # txt files
89 | text = file.getvalue().decode('utf-8')
90 |
91 | return {
92 | 'type': 'text',
93 | 'content': text,
94 | 'original_name': file.name
95 | }
96 |
97 | class ConfigManager:
98 | """Manage application configuration"""
99 | @staticmethod
100 | def load_config():
101 | """Load configuration from file"""
102 | try:
103 | with open('lab/utils/tmp_config.json', 'r') as f:
104 | config = json.load(f)
105 | return (
106 | config['kb_id'],
107 | config['nova_pro_profile_arn'],
108 | config['nova_pro_model_id'],
109 | config['region_name'],
110 | config['sagemaker_endpoint'],
111 | config['sagemaker_ep_arn']
112 | )
113 | except Exception as e:
114 | st.error(f"Configuration error: {str(e)}")
115 | st.stop()
116 |
117 | class ModelManager:
118 | """Handle model operations and configurations"""
119 | def __init__(self):
120 |         # Load configuration first so the configured region is used for the clients
121 |         self.KB_ID, self.NOVA_PRO_PROFILE_ARN, self.NOVA_PRO_MODEL_ID, \
122 |         self.REGION_NAME, self.SAGEMAKER_ENDPOINT, self.SAGEMAKER_EP_ARN = ConfigManager.load_config()
123 | 
124 |         # Initialize AWS clients
125 |         self.bedrock_agent = boto3.client('bedrock-agent-runtime', region_name=self.REGION_NAME)
126 |         self.bedrock_runtime = boto3.client('bedrock-runtime', region_name=self.REGION_NAME)
127 |         self.sagemaker_runtime = boto3.client('sagemaker-runtime', region_name=self.REGION_NAME)
128 |
129 | # Define model configurations
130 | self.MODELS = {
131 | "Amazon Titan Nova Pro": {
132 | "model_arn": self.NOVA_PRO_PROFILE_ARN,
133 | "model_id": self.NOVA_PRO_MODEL_ID,
134 | "type": "nova",
135 | "supports_image": True
136 | },
137 | "Claude 3.5 Sonnet": {
138 | "model_arn": "arn:aws:bedrock:us-west-2:010117700078:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
139 |                 "model_id": "anthropic.claude-3-5-sonnet-20240620-v1:0",
140 | "type": "claude",
141 | "supports_image": True
142 | },
143 | "DeepSeek R1 Distill Qwen 1.5B": {
144 | "model_arn": self.SAGEMAKER_EP_ARN,
145 | "endpoint_name": self.SAGEMAKER_ENDPOINT,
146 | "type": "deepseek",
147 | "supports_image": False,
148 | "supports_kb": True
149 | }
150 | }
151 |
152 | def generate_request(self, model_type, query, image_data=None, model_params=None):
153 | """Generate model-specific request body"""
154 | if model_type == "claude":
155 | content = [{"type": "text", "text": query}]
156 | if image_data:
157 | content.append({
158 | "type": "image",
159 | "source": {
160 | "type": "base64",
161 | "media_type": "image/jpeg",
162 | "data": image_data
163 | }
164 | })
165 | return {
166 | "anthropic_version": "bedrock-2023-05-31",
167 | "max_tokens": model_params.get("max_new_tokens", 1000),
168 | "temperature": model_params.get("temperature", 0.7),
169 | "messages": [{"role": "user", "content": content}]
170 | }
171 | elif model_type == "nova":
172 | content = []
173 | if image_data:
174 | content.append({
175 | "image": {
176 | "format": "jpeg",
177 | "source": {"bytes": image_data}
178 | }
179 | })
180 | content.append({"text": query})
181 | return {
182 | "schemaVersion": "messages-v1",
183 | "messages": [{"role": "user", "content": content}],
184 | "inferenceConfig": {
185 | "max_new_tokens": model_params.get("max_new_tokens", 1000),
186 | "temperature": model_params.get("temperature", 0.7),
187 | "top_k": model_params.get("top_k", 50),
188 | "top_p": model_params.get("top_p", 0.9)
189 | }
190 | }
191 | else: # deepseek
192 | return {
193 | "inputs": query,
194 | "parameters": {
195 | "max_new_tokens": model_params.get("max_new_tokens", 1000),
196 | "temperature": model_params.get("temperature", 0.7),
197 | "top_k": model_params.get("top_k", 50),
198 | "top_p": model_params.get("top_p", 0.9)
199 | }
200 | }
201 |
202 | def process_query(self, model_config, query, processed_files=None, use_kb=False, model_params=None):
203 | """Process query with selected model"""
204 | try:
205 | # For Deepseek model, check if there are any image files
206 | if model_config["type"] == "deepseek" and processed_files:
207 | has_image = any(file['type'] == 'image' for file in processed_files)
208 | if has_image:
209 | return ("The DeepSeek model only supports text processing. "
210 | "For queries involving images, please use either Claude or Nova Pro. "
211 | "If you wish to use DeepSeek, please upload text files only.", None)
212 |
213 | # Handle multiple files
214 | if processed_files:
215 | # For Claude model
216 | if model_config["type"] == "claude":
217 | content = [{"type": "text", "text": query}]
218 | text_contents = []
219 |
220 | for file in processed_files:
221 | if file['type'] == 'image':
222 | if not model_config["supports_image"]:
223 | continue
224 | # Convert bytes to base64 string
225 | image_base64 = base64.b64encode(file['content']).decode('utf-8')
226 | content.append({
227 | "type": "image",
228 | "source": {
229 | "type": "base64",
230 | "media_type": "image/jpeg",
231 | "data": image_base64
232 | }
233 | })
234 | else: # text files
235 | text_contents.append(f"\n=== Content from {file['original_name']} ===\n")
236 | content_preview = file['content'][:1000] + ("..." if len(file['content']) > 1000 else "")
237 | text_contents.append(content_preview)
238 |
239 | if text_contents:
240 | prompt = f"""Please analyze the following file content and answer the question:
241 |
242 | File Content:
243 | {' '.join(text_contents)}
244 |
245 | Question: {query}
246 |
247 | Please provide a clear and concise answer focusing on the main points of the content."""
248 | content[0]["text"] = prompt
249 |
250 | request_body = {
251 | "anthropic_version": "bedrock-2023-05-31",
252 | "max_tokens": model_params.get("max_new_tokens", 1000),
253 | "temperature": model_params.get("temperature", 0.7),
254 | "messages": [{"role": "user", "content": content}]
255 | }
256 |
257 | # For Nova model
258 | elif model_config["type"] == "nova":
259 | content = []
260 | text_contents = []
261 |
262 | for file in processed_files:
263 | if file['type'] == 'image':
264 | if not model_config["supports_image"]:
265 | continue
266 | # Convert bytes to base64 string for Nova
267 | image_bytes_base64 = base64.b64encode(file['content']).decode('utf-8')
268 | content.append({
269 | "image": {
270 | "format": "jpeg",
271 | "source": {"bytes": image_bytes_base64}
272 | }
273 | })
274 | else: # text files
275 | text_contents.append(f"\n=== Content from {file['original_name']} ===\n")
276 | content_preview = file['content'][:1000] + ("..." if len(file['content']) > 1000 else "")
277 | text_contents.append(content_preview)
278 |
279 | if text_contents:
280 | prompt = f"""Please analyze the following file content and answer the question:
281 |
282 | File Content:
283 | {' '.join(text_contents)}
284 |
285 | Question: {query}
286 |
287 | Please provide a clear and concise answer focusing on the main points of the content."""
288 | else:
289 | prompt = query
290 |
291 | content.append({"text": prompt})
292 | request_body = {
293 | "schemaVersion": "messages-v1",
294 | "messages": [{"role": "user", "content": content}],
295 | "inferenceConfig": {
296 | "max_new_tokens": model_params.get("max_new_tokens", 1000),
297 | "temperature": model_params.get("temperature", 0.7),
298 | "top_k": model_params.get("top_k", 50),
299 | "top_p": model_params.get("top_p", 0.9)
300 | }
301 | }
302 |
303 | # For Deepseek model
304 | else:
305 | # Combine all text content with the query
306 | text_contents = []
307 | for file in processed_files:
308 | if file['type'] == 'text':
309 | text_contents.append(f"\n=== Content from {file['original_name']} ===\n")
310 | content_preview = file['content'][:1000] + ("..." if len(file['content']) > 1000 else "")
311 | text_contents.append(content_preview)
312 |
313 | if text_contents:
314 | prompt = f"""Please analyze the following file content and answer the question:
315 |
316 | File Content:
317 | {' '.join(text_contents)}
318 |
319 | Question: {query}
320 |
321 | Please provide a clear and concise answer focusing on the main points of the content."""
322 | else:
323 | prompt = query
324 |
325 | request_body = {
326 | "inputs": prompt,
327 | "parameters": {
328 | "max_new_tokens": model_params.get("max_new_tokens", 1000),
329 | "temperature": model_params.get("temperature", 0.7),
330 | "top_k": model_params.get("top_k", 50),
331 | "top_p": model_params.get("top_p", 0.9)
332 | }
333 | }
334 |
335 | # Special handling for Deepseek with KB
336 | if model_config["type"] == "deepseek" and use_kb:
337 | try:
338 | kb_response = self.bedrock_agent.retrieve(
339 | knowledgeBaseId=self.KB_ID,
340 | retrievalQuery={"text": query},
341 | retrievalConfiguration={
342 | "vectorSearchConfiguration": {
343 | "numberOfResults": 3
344 | }
345 | }
346 | )
347 |
348 | contexts = []
349 | citations = []
350 |
351 | if 'retrievalResults' in kb_response:
352 | for result in kb_response['retrievalResults']:
353 | if 'content' in result:
354 | content = result['content']
355 | if isinstance(content, dict):
356 | text = content.get('text', '')
357 | else:
358 | text = str(content)
359 | contexts.append(text)
360 | citations.append({
361 | 'retrievedReferences': [{'content': text}]
362 | })
363 |
364 | if not contexts:
365 | return "No relevant content found in knowledge base.", None
366 |
367 | prompt = f"""Review the following AWS documentation and answer the user's question.
368 | Make sure your answer is:
369 | 1. Accurate according to the provided documentation
370 | 2. Directly addresses the user's question
371 | 3. Clear and easy to understand
372 |
373 | AWS Documentation:
374 | {' '.join(contexts)}
375 |
376 | User Question:
377 | {query}"""
378 |
379 | request_body = self.generate_request(
380 | model_config["type"],
381 | prompt,
382 | None,
383 | model_params
384 | )
385 |
386 | generated_text, _ = self.process_deepseek_query(model_config, request_body)
387 | return generated_text, citations
388 |
389 | except Exception as e:
390 | print(f"KB retrieval error: {str(e)}")
391 | raise
392 |
393 | # Process the request based on model type
394 | if model_config["type"] == "deepseek":
395 | return self.process_deepseek_query(model_config, request_body)
396 | else:
397 | return self.process_bedrock_query(model_config, request_body)
398 |
399 | # If no files, process as normal query
400 | elif use_kb and model_config.get("supports_kb", True):
401 | return self.process_kb_query(model_config, query)
402 | else:
403 | request_body = self.generate_request(
404 | model_config["type"],
405 | query,
406 | None,
407 | model_params
408 | )
409 |
410 | if model_config["type"] == "deepseek":
411 | return self.process_deepseek_query(model_config, request_body)
412 | else:
413 | return self.process_bedrock_query(model_config, request_body)
414 |
415 | except Exception as e:
416 | raise Exception(f"Query processing error: {str(e)}")
417 |
418 |
419 | def process_kb_query(self, model_config, query):
420 | """Process knowledge base query"""
421 | if model_config["type"] == "deepseek":
422 | # For Deepseek model, first retrieve from KB then process separately
423 | try:
424 | kb_response = self.bedrock_agent.retrieve(
425 | knowledgeBaseId=self.KB_ID,
426 | retrievalQuery={"text": query},
427 | retrievalConfiguration={
428 | "vectorSearchConfiguration": {
429 | "numberOfResults": 3
430 | }
431 | }
432 | )
433 |
434 | contexts = []
435 | citations = []
436 |
437 | if 'retrievalResults' in kb_response:
438 | for result in kb_response['retrievalResults']:
439 | if 'content' in result:
440 | content = result['content']
441 | if isinstance(content, dict):
442 | text = content.get('text', '')
443 | else:
444 | text = str(content)
445 | contexts.append(text)
446 | citations.append({
447 | 'retrievedReferences': [{'content': text}]
448 | })
449 |
450 | if not contexts:
451 | return "No relevant content found in knowledge base.", None
452 |
453 | prompt = f"""Review the following AWS documentation and answer the user's question.
454 | Make sure your answer is:
455 | 1. Accurate according to the provided documentation
456 | 2. Directly addresses the user's question
457 | 3. Clear and easy to understand
458 |
459 | AWS Documentation:
460 | {' '.join(contexts)}
461 |
462 | User Question:
463 | {query}"""
464 |
465 | request_body = {
466 | "inputs": prompt,
467 | "parameters": {
468 | "max_new_tokens": 1000,
469 | "temperature": 0.7,
470 | "top_k": 50,
471 | "top_p": 0.9
472 | }
473 | }
474 |
475 | generated_text, _ = self.process_deepseek_query(model_config, request_body)
476 | return generated_text, citations
477 |
478 | except Exception as e:
479 | print(f"KB retrieval error: {str(e)}")
480 | raise
481 |
482 | else:
483 | # For other models, use the original retrieve_and_generate
484 | response = self.bedrock_agent.retrieve_and_generate(
485 | input={"text": query},
486 | retrieveAndGenerateConfiguration={
487 | "type": "KNOWLEDGE_BASE",
488 | "knowledgeBaseConfiguration": {
489 | "knowledgeBaseId": self.KB_ID,
490 | "modelArn": model_config["model_arn"],
491 | "retrievalConfiguration": {
492 | "vectorSearchConfiguration": {
493 | "numberOfResults": 3
494 | }
495 | }
496 | }
497 | }
498 | )
499 |         # retrieve_and_generate returns output as {'text': ...}; unwrap it so all
500 |         # code paths return a plain string
501 |         return response['output']['text'], response.get('citations')
500 |
501 | def process_deepseek_query(self, model_config, request_body):
502 | """Process Deepseek model query"""
503 | response = self.sagemaker_runtime.invoke_endpoint(
504 | EndpointName=model_config["endpoint_name"],
505 | ContentType='application/json',
506 | Body=json.dumps(request_body)
507 | )
508 | response_text = response['Body'].read().decode('utf-8')
509 | response_body = json.loads(response_text)
510 |
511 | if isinstance(response_body, list):
512 | return response_body[0]['generated_text'], None
513 | return response_body.get('generated_text', str(response_body)), None
514 |
515 | def process_bedrock_query(self, model_config, request_body):
516 | """Process Bedrock model query"""
517 | response = self.bedrock_runtime.invoke_model(
518 | modelId=model_config["model_arn"],
519 | body=json.dumps(request_body)
520 | )
521 | response_body = json.loads(response['body'].read())
522 |
523 | if model_config["type"] == "claude":
524 | return response_body['content'][0]['text'], None
525 | return response_body["output"]["message"]["content"][0]["text"], None
526 |
527 | class UI:
528 | """Handle UI components and interactions"""
529 | def __init__(self, model_manager):
530 | self.model_manager = model_manager
531 | self.setup_page()
532 |
533 | def setup_page(self):
534 | """Setup main page layout"""
535 | st.set_page_config(page_title="MultiFunctional Chatbot with Bedrock KB and Deepseek Models", layout="wide")
536 | st.title("MultiFunctional Chatbot with Bedrock KB and Deepseek Models")
537 |
538 | # Main content
539 | self.selected_model = st.selectbox(
540 | "Select Generation Model",
541 | options=list(self.model_manager.MODELS.keys()),
542 | index=0
543 | )
544 | self.MODEL_CONFIG = self.model_manager.MODELS[self.selected_model]
545 |
546 | # Sidebar settings
547 | with st.sidebar:
548 | self.setup_sidebar()
549 |
550 | self.setup_main_content()
551 |
552 | def setup_sidebar(self):
553 | """Setup sidebar components"""
554 | st.markdown("### Image Compression Settings")
555 | self.max_size_kb = st.number_input(
556 | "Max Image Size (KB)",
557 | min_value=30,
558 | max_value=2048,
559 | value=DEFAULT_MAX_SIZE_KB
560 | )
561 | self.max_dimension = st.number_input(
562 | "Max Image Dimension",
563 | min_value=200,
564 | max_value=1600,
565 | value=DEFAULT_MAX_DIMENSION
566 | )
567 |
568 | # Add model parameter controls with both slider and number input
569 | st.markdown("### Model Parameters")
570 | self.model_params = {}
571 |
572 | # Max New Tokens
573 | self.model_params["max_new_tokens"] = st.slider(
574 | "Max New Tokens ",
575 | min_value=1,
576 | max_value=4096,
577 | value=1000
578 | )
579 | self.model_params["max_new_tokens"] = st.number_input(
580 | "Max New Tokens ",
581 | min_value=1,
582 | max_value=4096,
583 | value=self.model_params["max_new_tokens"]
584 | )
585 |
586 | # Temperature
587 | self.model_params["temperature"] = st.slider(
588 | "Temperature ",
589 | min_value=0.0,
590 | max_value=1.0,
591 | value=0.7,
592 | step=0.1
593 | )
594 | self.model_params["temperature"] = st.number_input(
595 | "Temperature ",
596 | min_value=0.0,
597 | max_value=1.0,
598 | value=self.model_params["temperature"],
599 | format="%.1f"
600 | )
601 |
602 | # Top K
603 | self.model_params["top_k"] = st.slider(
604 | "Top K ",
605 | min_value=1,
606 | max_value=500,
607 | value=50
608 | )
609 | self.model_params["top_k"] = st.number_input(
610 | "Top K ",
611 | min_value=1,
612 | max_value=500,
613 | value=self.model_params["top_k"]
614 | )
615 |
616 | # Top P
617 | self.model_params["top_p"] = st.slider(
618 | "Top P ",
619 | min_value=0.0,
620 | max_value=1.0,
621 | value=0.9,
622 | step=0.1
623 | )
624 | self.model_params["top_p"] = st.number_input(
625 | "Top P ",
626 | min_value=0.0,
627 | max_value=1.0,
628 | value=self.model_params["top_p"],
629 | format="%.1f"
630 | )
631 |
632 | def setup_main_content(self):
633 | """Setup main content area"""
634 | st.markdown("---")
635 |
636 | # File upload section
637 | st.subheader("File Upload")
638 | uploaded_files = st.file_uploader(
639 | "Upload files (images and/or documents)",
640 | type=FileProcessor.SUPPORTED_IMAGE_TYPES + FileProcessor.SUPPORTED_TEXT_TYPES,
641 | accept_multiple_files=True
642 | )
643 |
644 | self.processed_files = []
645 | if uploaded_files:
646 | for file in uploaded_files:
647 | try:
648 | file_type = file.name.split('.')[-1].lower()
649 | processed_file = FileProcessor.process_file(file, file_type)
650 | self.processed_files.append(processed_file)
651 |
652 | # Display preview based on file type
653 | if processed_file['type'] == 'image':
654 | st.image(file, caption=f"Original Image: {file.name}", use_container_width=True)
655 | compressed_size_kb = len(processed_file['content'])/1024
656 | st.info(f"Image compressed to {compressed_size_kb:.1f} KB")
657 | else:
658 | with st.expander(f"Preview: {file.name}"):
659 | st.text(processed_file['content'][:1000] + "..." if len(processed_file['content']) > 1000 else processed_file['content'])
660 |
661 | except Exception as e:
662 | st.error(f"Error processing file {file.name}: {str(e)}")
663 |
664 | self.query = st.text_area(
665 | "Enter your query:",
666 | height=150,
667 | placeholder="Enter your question here..."
668 | )
669 |
670 | st.markdown("---")
671 |
672 | self.use_kb = st.checkbox("Use Knowledge Base", value=True) if not self.processed_files else False
673 |
674 | # Create placeholders for button and status
675 | button_placeholder = st.empty()
676 | status_placeholder = st.empty()
677 |
678 | if button_placeholder.button("Search", type="primary", use_container_width=True, key="search_button"):
679 | # Clear the button and show status
680 | button_placeholder.empty()
681 | status_placeholder.info("Searching...")
682 |
683 | try:
684 | self.handle_search()
685 | finally:
686 | # Clear the status message
687 | status_placeholder.empty()
688 | # Show the button again
689 | button_placeholder.button(
690 | "Search",
691 | type="primary",
692 | use_container_width=True,
693 | key="search_button_after"
694 | )
695 |
696 | def handle_image_upload(self):
697 | """Handle image upload and processing"""
698 | try:
699 | if not self.MODEL_CONFIG["supports_image"]:
700 | st.error(f"Selected model ({self.selected_model}) does not support image analysis")
701 | st.stop()
702 |
703 | st.image(self.uploaded_file, caption="Original Image", use_container_width=True)
704 | self.image_bytes = ImageProcessor.compress_image(
705 | self.uploaded_file.getvalue(),
706 | max_size_kb=self.max_size_kb,
707 | max_dimension=self.max_dimension
708 | )
709 |
710 | compressed_size_kb = len(self.image_bytes)/1024
711 | if compressed_size_kb > self.max_size_kb:
712 | st.error(f"Image is still too large ({compressed_size_kb:.1f}KB)")
713 | st.stop()
714 |
715 | self.image_base64 = base64.b64encode(self.image_bytes).decode('utf-8')
716 | st.info(f"Image compressed to {compressed_size_kb:.1f} KB")
717 |
718 | except Exception as e:
719 | st.error(f"Error processing image: {str(e)}")
720 | st.stop()
721 |
722 | def handle_search(self):
723 | """Handle search button click"""
724 | if not self.query.strip():
725 | st.warning("Please enter a query")
726 | return
727 |
728 | try:
729 | generated_text, citations = self.model_manager.process_query(
730 | self.MODEL_CONFIG,
731 | self.query,
732 | getattr(self, 'processed_files', None),
733 | self.use_kb if not getattr(self, 'processed_files', None) else False,
734 | self.model_params
735 | )
736 |
737 | st.subheader("Generated Answer:")
738 | st.write(generated_text)
739 | self.display_citations(citations)
740 |
741 | except Exception as e:
742 | st.error(f"Error: {str(e)}")
743 | st.write("Full error details:")
744 | st.write(e)
745 |
746 | @staticmethod
747 | def display_citations(citations):
748 | """Display citation information"""
749 | if citations:
750 | st.subheader("Retrieved References:")
751 | for i, citation in enumerate(citations, 1):
752 | with st.expander(f"Reference {i}", expanded=False):
753 | if 'retrievedReferences' in citation:
754 | for ref in citation['retrievedReferences']:
755 | st.write(ref['content'])
756 | st.write("---")
757 |
758 | def main():
759 | """Main application entry point"""
760 | model_manager = ModelManager()
761 | ui = UI(model_manager)
762 |
763 | if __name__ == "__main__":
764 | main()
--------------------------------------------------------------------------------
/lab/lab-code.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "c3f87dc2-e36d-4d71-856e-d1e74034bbb8",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stdout",
11 | "output_type": "stream",
12 | "text": [
13 | "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
14 | "autogluon-multimodal 1.2 requires nvidia-ml-py3==7.352.0, which is not installed.\n",
15 | "dash 2.18.1 requires dash-core-components==2.0.0, which is not installed.\n",
16 | "dash 2.18.1 requires dash-html-components==2.0.0, which is not installed.\n",
17 | "dash 2.18.1 requires dash-table==5.0.0, which is not installed.\n",
18 | "jupyter-ai 2.29.0 requires faiss-cpu!=1.8.0.post0,<2.0.0,>=1.8.0, which is not installed.\n",
19 | "aiobotocore 2.19.0 requires botocore<1.36.4,>=1.36.0, but you have botocore 1.36.22 which is incompatible.\n",
20 | "amazon-sagemaker-sql-magic 0.1.3 requires sqlparse==0.5.0, but you have sqlparse 0.5.3 which is incompatible.\n",
21 | "autogluon-common 1.2 requires psutil<7.0.0,>=5.7.3, but you have psutil 7.0.0 which is incompatible.\n",
22 | "autogluon-multimodal 1.2 requires jsonschema<4.22,>=4.18, but you have jsonschema 4.23.0 which is incompatible.\n",
23 | "autogluon-multimodal 1.2 requires nltk<3.9,>=3.4.5, but you have nltk 3.9.1 which is incompatible.\n",
24 | "blis 1.0.1 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.\n",
25 | "dash 2.18.1 requires Flask<3.1,>=1.0.4, but you have flask 3.1.0 which is incompatible.\n",
26 | "dash 2.18.1 requires Werkzeug<3.1, but you have werkzeug 3.1.3 which is incompatible.\n",
27 | "jupyter-scheduler 2.10.0 requires fsspec<=2024.10.0,>=2023.6.0, but you have fsspec 2024.12.0 which is incompatible.\n",
28 | "jupyter-scheduler 2.10.0 requires psutil~=5.9, but you have psutil 7.0.0 which is incompatible.\n",
29 | "jupyter-scheduler 2.10.0 requires pytz<=2024.2,>=2023.3, but you have pytz 2025.1 which is incompatible.\n",
30 | "mlflow 2.20.0 requires pyarrow<19,>=4.0.0, but you have pyarrow 19.0.0 which is incompatible.\n",
31 | "sparkmagic 0.21.0 requires pandas<2.0.0,>=0.17.1, but you have pandas 2.2.3 which is incompatible.\n",
32 | "tensorflow 2.17.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 5.29.3 which is incompatible.\u001b[0m\u001b[31m\n",
33 | "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
34 | ]
35 | }
36 | ],
37 | "source": [
38 | "%pip install --force-reinstall -q -r ./utils/requirements.txt"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 2,
44 | "id": "08447914-9c03-483f-83bc-223173f77db9",
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "data": {
49 | "text/html": [
50 | ""
51 | ],
52 | "text/plain": [
53 | ""
54 | ]
55 | },
56 | "execution_count": 2,
57 | "metadata": {},
58 | "output_type": "execute_result"
59 | }
60 | ],
61 | "source": [
62 | "# restart kernel\n",
63 | "from IPython.core.display import HTML\n",
64 | "HTML(\"\")"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "id": "4460e029-dc0b-4733-bdac-fedabf04c348",
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "import os\n",
75 | "import sys\n",
76 | "import time\n",
77 | "import boto3\n",
78 | "import logging\n",
79 | "import requests\n",
80 | "import pprint\n",
81 | "import json\n",
83 | "import warnings\n",
84 | "warnings.filterwarnings('ignore')\n",
85 | "\n",
86 | "# Set the path to import module\n",
87 | "from pathlib import Path\n",
88 | "current_path = Path().resolve()\n",
89 | "current_path = current_path.parent\n",
90 | "if str(current_path) not in sys.path:\n",
91 | " sys.path.append(str(current_path))\n",
92 | "# Print sys.path to verify\n",
93 | "# print(sys.path)"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 4,
99 | "id": "53122abc-a8d2-4578-a467-c4840ac73236",
100 | "metadata": {},
101 | "outputs": [
102 | {
103 | "data": {
104 | "text/plain": [
105 | "('us-west-2', '010117700078')"
106 | ]
107 | },
108 | "execution_count": 4,
109 | "metadata": {},
110 | "output_type": "execute_result"
111 | }
112 | ],
113 | "source": [
114 |     "# Clients\n",
115 | "s3_client = boto3.client('s3')\n",
116 | "sts_client = boto3.client('sts')\n",
117 | "session = boto3.session.Session()\n",
118 | "region = session.region_name\n",
119 | "account_id = sts_client.get_caller_identity()[\"Account\"]\n",
120 | "bedrock_agent_client = boto3.client('bedrock-agent')\n",
121 | "bedrock_agent_runtime_client = boto3.client('bedrock-agent-runtime') \n",
122 | "logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)\n",
123 | "logger = logging.getLogger(__name__)\n",
124 | "region, account_id"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 5,
130 | "id": "837c1486-d4b7-4322-9d48-48796371813d",
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "# Get the current timestamp\n",
135 | "current_time = time.time()\n",
136 | "\n",
137 | "# Format the timestamp as a string\n",
138 | "timestamp_str = time.strftime(\"%Y%m%d%H%M%S\", time.localtime(current_time))[-7:]\n",
139 | "# Create the suffix using the timestamp\n",
140 | "suffix = f\"{timestamp_str}\"\n",
141 | "\n",
142 | "knowledge_base_name = f\"bedrock-multifunctional-chatbot-kb-{suffix}\"\n",
143 | "knowledge_base_description = \"Multifunctional Chatbot Knowledge Base.\"\n",
144 | "\n",
145 | "bucket_name = f'{knowledge_base_name}-{account_id}'\n",
146 | "intermediate_bucket_name = f'{knowledge_base_name}-intermediate-{account_id}'\n"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 6,
152 | "id": "94c288ae-e2d0-4b5a-8ae8-5588f754e56e",
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 |     "data_bucket_name = f'bedrock-kb-{suffix}-1'  # replace with your own bucket name if needed\n",
157 | "\n",
158 | "data_sources=[{\"type\": \"S3\", \"bucket_name\": data_bucket_name}]"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 7,
164 | "id": "946d915f-a2e3-4583-b58d-208cb26d7503",
165 | "metadata": {},
166 | "outputs": [
167 | {
168 | "name": "stdout",
169 | "output_type": "stream",
170 | "text": [
171 | "7220345\n"
172 | ]
173 | }
174 | ],
175 | "source": [
176 | "import importlib\n",
177 | "import utils.knowledge_base\n",
178 | "importlib.reload(utils.knowledge_base)\n",
179 | "from utils.knowledge_base import BedrockKnowledgeBase\n",
180 | "\n",
181 | "print(suffix)"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 8,
187 | "id": "286d63bf-6f0d-4df5-8df6-54eb41ca050e",
188 | "metadata": {
189 | "scrolled": true
190 | },
191 | "outputs": [
192 | {
193 | "name": "stdout",
194 | "output_type": "stream",
195 | "text": [
196 | "========================================================================================\n",
197 | "Step 1 - Creating or retrieving S3 bucket(s) for Knowledge Base documents\n",
198 | "Creating bucket bedrock-kb-7220345-1\n",
199 | "========================================================================================\n",
200 | "Step 2 - Creating Knowledge Base Execution Role and Policies\n",
201 | "========================================================================================\n",
202 | "Step 3 - Creating OSS encryption, network and data access policies\n",
203 | "========================================================================================\n",
204 | "Step 4 - Creating OSS Collection (this step takes a couple of minutes to complete)\n",
205 | "Creating collection...\n",
206 | "Creating collection...........\n",
207 | "Creating collection...........\n",
208 | "Creating collection...........\n",
209 | "Creating collection...........\n",
210 | "Creating collection...........\n",
211 | "Creating collection...........\n",
212 | "Creating collection...........\n",
213 | "Creating collection...........\n",
214 | "Sleeping for a minute to ensure data access rules have been enforced\n",
215 | "========================================================================================\n",
216 | "Step 5 - Creating OSS Vector Index\n"
217 | ]
218 | },
219 | {
220 | "name": "stderr",
221 | "output_type": "stream",
222 | "text": [
223 | "[2025-02-17 22:09:20,975] p135 {base.py:258} INFO - PUT https://3m5htfwhxx9i1hcl2f3j.us-west-2.aoss.amazonaws.com:443/bedrock-sample-rag-index-7220345-f [status:200 request:1.245s]\n"
224 | ]
225 | },
226 | {
227 | "name": "stdout",
228 | "output_type": "stream",
229 | "text": [
230 | "\n",
231 | "Creating index:\n",
232 | "{ 'acknowledged': True,\n",
233 | " 'index': 'bedrock-sample-rag-index-7220345-f',\n",
234 | " 'shards_acknowledged': True}\n",
235 | "========================================================================================\n",
236 | "Step 6 - Creating Knowledge Base\n",
237 | "{ 'createdAt': datetime.datetime(2025, 2, 17, 22, 10, 21, 105567, tzinfo=tzlocal()),\n",
238 | " 'description': 'Multifunctional Chatbot Knowledge Base.',\n",
239 | " 'knowledgeBaseArn': 'arn:aws:bedrock:us-west-2:010117700078:knowledge-base/EHSH1Q38GZ',\n",
240 | " 'knowledgeBaseConfiguration': { 'type': 'VECTOR',\n",
241 | " 'vectorKnowledgeBaseConfiguration': { 'embeddingModelArn': 'arn:aws:bedrock:us-west-2::foundation-model/amazon.titan-embed-text-v2:0'}},\n",
242 | " 'knowledgeBaseId': 'EHSH1Q38GZ',\n",
243 | " 'name': 'bedrock-multifunctional-chatbot-kb-7220345',\n",
244 | " 'roleArn': 'arn:aws:iam::010117700078:role/BedrockExecutionRoleForKnowledgeBase_7220345-f',\n",
245 | " 'status': 'CREATING',\n",
246 | " 'storageConfiguration': { 'opensearchServerlessConfiguration': { 'collectionArn': 'arn:aws:aoss:us-west-2:010117700078:collection/3m5htfwhxx9i1hcl2f3j',\n",
247 | " 'fieldMapping': { 'metadataField': 'text-metadata',\n",
248 | " 'textField': 'text',\n",
249 | " 'vectorField': 'vector'},\n",
250 | " 'vectorIndexName': 'bedrock-sample-rag-index-7220345-f'},\n",
251 | " 'type': 'OPENSEARCH_SERVERLESS'},\n",
252 | " 'updatedAt': datetime.datetime(2025, 2, 17, 22, 10, 21, 105567, tzinfo=tzlocal())}\n",
253 | "Creating Data Sources\n",
254 | "{ 'createdAt': datetime.datetime(2025, 2, 17, 22, 10, 21, 791415, tzinfo=tzlocal()),\n",
255 | " 'dataDeletionPolicy': 'DELETE',\n",
256 | " 'dataSourceConfiguration': { 's3Configuration': { 'bucketArn': 'arn:aws:s3:::bedrock-kb-7220345-1'},\n",
257 | " 'type': 'S3'},\n",
258 | " 'dataSourceId': 'LSDYTJ4PPP',\n",
259 | " 'description': 'Multifunctional Chatbot Knowledge Base.',\n",
260 | " 'knowledgeBaseId': 'EHSH1Q38GZ',\n",
261 | " 'name': 'EHSH1Q38GZ-s3',\n",
262 | " 'status': 'AVAILABLE',\n",
263 | " 'updatedAt': datetime.datetime(2025, 2, 17, 22, 10, 21, 791415, tzinfo=tzlocal()),\n",
264 | " 'vectorIngestionConfiguration': { 'chunkingConfiguration': { 'chunkingStrategy': 'FIXED_SIZE',\n",
265 | " 'fixedSizeChunkingConfiguration': { 'maxTokens': 300,\n",
266 | " 'overlapPercentage': 20}}}}\n",
267 | "========================================================================================\n"
268 | ]
269 | }
270 | ],
271 | "source": [
272 |     "# Create the Knowledge Base; this may take a few minutes.\n",
273 | "\n",
274 | "knowledge_base = BedrockKnowledgeBase(\n",
275 | " kb_name=f'{knowledge_base_name}',\n",
276 | " kb_description=knowledge_base_description,\n",
277 | " data_sources=data_sources,\n",
278 |     "    chunking_strategy=\"FIXED_SIZE\",\n",
279 |     "    suffix=f'{suffix}-f'\n",
280 | ")"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": 9,
286 | "id": "e3f2f1bf-ee90-4262-85a1-08b8746cebe5",
287 | "metadata": {},
288 | "outputs": [],
289 | "source": [
290 | "def download_and_upload_squad_sample(bucket_name):\n",
291 |     "    # Download the SQuAD v2.0 dev dataset\n",
292 | " url = \"https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-v2.0.json\"\n",
293 | " response = requests.get(url)\n",
294 | " data = response.json()\n",
295 | " \n",
296 |     "    # Keep the first 100 articles as a sample\n",
297 | " sample_data = {\n",
298 | " \"data\": data[\"data\"][:100]\n",
299 | " }\n",
300 | " \n",
301 |     "    # Write the sample to a local temp file\n",
302 | " with open(\"/tmp/squad_sample.json\", \"w\") as f:\n",
303 | " json.dump(sample_data, f)\n",
304 | " \n",
305 |     "    # Upload the sample to S3\n",
306 | " s3_client = boto3.client('s3')\n",
307 | " s3_client.upload_file(\"/tmp/squad_sample.json\", bucket_name, \"squad_sample.json\")\n",
308 | "\n",
309 | "download_and_upload_squad_sample(data_bucket_name)"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 11,
315 | "id": "52a0780a-2843-4f81-8c6a-fd92baca165f",
316 | "metadata": {
317 | "scrolled": true
318 | },
319 | "outputs": [
320 | {
321 | "name": "stdout",
322 | "output_type": "stream",
323 | "text": [
324 | "job 1 started successfully\n",
325 | "\n",
326 | "{ 'dataSourceId': 'LSDYTJ4PPP',\n",
327 | " 'ingestionJobId': '9SUPBWDEI9',\n",
328 | " 'knowledgeBaseId': 'EHSH1Q38GZ',\n",
329 | " 'startedAt': datetime.datetime(2025, 2, 17, 22, 10, 25, 848380, tzinfo=tzlocal()),\n",
330 | " 'statistics': { 'numberOfDocumentsDeleted': 0,\n",
331 | " 'numberOfDocumentsFailed': 0,\n",
332 | " 'numberOfDocumentsScanned': 1,\n",
333 | " 'numberOfMetadataDocumentsModified': 0,\n",
334 | " 'numberOfMetadataDocumentsScanned': 0,\n",
335 | " 'numberOfModifiedDocumentsIndexed': 0,\n",
336 | " 'numberOfNewDocumentsIndexed': 1},\n",
337 | " 'status': 'COMPLETE',\n",
338 | " 'updatedAt': datetime.datetime(2025, 2, 17, 22, 12, 23, 768670, tzinfo=tzlocal())}\n",
339 | "........................................\r"
340 | ]
341 | }
342 | ],
343 | "source": [
344 |     "## Start the ingestion job: embed the S3 data sources into the OpenSearch Serverless vector index.\n",
345 | "knowledge_base.start_ingestion_job()"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": 12,
351 | "id": "8168569f-da53-4909-bc68-22310ca05adc",
352 | "metadata": {
353 | "scrolled": true
354 | },
355 | "outputs": [
356 | {
357 | "name": "stdout",
358 | "output_type": "stream",
359 | "text": [
360 | "\n",
361 | "Query: What is the major context of the SQuAD dataset?\n",
362 | "\n",
363 | "Retrieved results:\n",
364 | "\n",
365 | "Result 1:\n",
366 | "Score: 0.37877417\n",
367 | "Content: {'text': ', \"id\": \"57284b904b864d19001648e2\", \"answers\": [{\"text\": \"the Main Quadrangles\", \"answer_start\": 92}, {\"text\": \"Main Quadrangles\", \"answer_start\": 96}, {\"text\": \"the Main Quadrangles\", \"answer_start\": 92}, {\"text\": \"the Main Quadrangles\", \"answer_start\": 92}], \"is_impossible\": false}, {\"question\": \"How many quadrangles does the Main Quadrangles have?\", \"id\": \"57284b904b864d19001648e3\", \"answers\": [{\"text\": \"six\", \"answer_start\": 273}, {\"text\": \"six quadrangles\", \"answer_start\": 273}, {\"text\": \"six\", \"answer_start\": 273}, {\"text\": \"six\", \"answer_start\": 273}], \"is_impossible\": false}, {\"question\": \"Who helped designed the Main Quadrangles?\"', 'type': 'TEXT'}\n",
368 | "\n",
369 | "Result 2:\n",
370 | "Score: 0.3762661\n",
371 | "Content: {'text': '\": \"explore computer networking\", \"answer_start\": 190}], \"is_impossible\": false}, {\"question\": \"What completed the triad \", \"id\": \"5726414e271a42140099d7e6\", \"answers\": [{\"text\": \"an interactive host to host connection was made between the IBM mainframe computer systems at the University of Michigan in Ann Arbor and Wayne State\", \"answer_start\": 499}, {\"text\": \"the CDC mainframe at Michigan State University in East Lansing\", \"answer_start\": 703}, {\"text\": \"1972 connections\", \"answer_start\": 683}], \"is_impossible\": false}, {\"question\": \"What set the stage for Merits role in NSFNET\", \"id\": \"5726414e271a42140099d7e7\", \"answers\": [{\"text\": \"Ethernet attached hosts, and eventually TCP/IP and additional public universities in Michigan join the network\", \"answer_start\": 1166}, {\"text\": \"the network was enhanced\", \"answer_start\": 867}, {\"text\": \"TCP/IP\", \"answer_start\": 1206}], \"is_impossible\": false}, {\"plausible_answers\": [{\"text\": \"Merit Network, Inc\", \"answer_start\": 0}], \"question\": \"State educational and economic development where helped by what?\"', 'type': 'TEXT'}\n",
372 | "\n",
373 | "Result 3:\n",
374 | "Score: 0.37598163\n",
375 | "Content: {'text': 'Although this use of the name was incorrect all these services were managed by the same people within one department of KPN contributed to the confusion.\"}, {\"qas\": [{\"question\": \"What is CSNET\", \"id\": \"5726462b708984140094c117\", \"answers\": [{\"text\": \"The Computer Science Network\", \"answer_start\": 0}, {\"text\": \"a computer network funded by the U.S.', 'type': 'TEXT'}\n",
376 | "\n",
377 | "Result 4:\n",
378 | "Score: 0.3758176\n",
379 | "Content: {'text': ', \"id\": \"5a5929d33e1742001a15cfc6\", \"answers\": [], \"is_impossible\": true}], \"context\": \"In the laboratory, stratigraphers analyze samples of stratigraphic sections that can be returned from the field, such as those from drill cores. Stratigraphers also analyze data from geophysical surveys that show the locations of stratigraphic units in the subsurface. Geophysical data and well logs can be combined to produce a better view of the subsurface, and stratigraphers often use computer programs to do this in three dimensions. Stratigraphers can then use these data to reconstruct ancient processes occurring on the surface of the Earth, interpret past environments, and locate areas for water, coal, and hydrocarbon extraction.\"}, {\"qas\": [{\"question\": \"Who analyzes rock samples from drill cores in the lab? \", \"id\": \"57268220f1498d1400e8e216\", \"answers\": [{\"text\": \"biostratigraphers\", \"answer_start\": 19}, {\"text\": \"biostratigraphers\", \"answer_start\": 19}, {\"text\": \"biostratigraphers\", \"answer_start\": 19}], \"is_impossible\": false}, {\"question\": \"Who dates rocks, precisely, within the stratigraphic section?\"', 'type': 'TEXT'}\n",
380 | "\n",
381 | "Result 5:\n",
382 | "Score: 0.37568116\n",
383 | "Content: {'text': ', \"id\": \"5a581597770dc0001aeeffe3\", \"answers\": [], \"is_impossible\": true}, {\"plausible_answers\": [{\"text\": \"the locations of stratigraphic units\", \"answer_start\": 213}], \"question\": \"What do drill cores show about water location?\", \"id\": \"5a581597770dc0001aeeffe4\", \"answers\": [], \"is_impossible\": true}, {\"plausible_answers\": [{\"text\": \"a better view of the subsurface\", \"answer_start\": 327}], \"question\": \"What can drill cores and ancient processes be combined to show?\", \"id\": \"5a581597770dc0001aeeffe5\", \"answers\": [], \"is_impossible\": true}, {\"plausible_answers\": [{\"text\": \"ancient processes occurring on the surface of the Earth\", \"answer_start\": 493}], \"question\": \"What do computers use coal to reconstruct?\", \"id\": \"5a581597770dc0001aeeffe6\", \"answers\": [], \"is_impossible\": true}, {\"plausible_answers\": [{\"text\": \"drill cores\", \"answer_start\": 132}], \"question\": \"What are taken from the laboratory into the field?\"', 'type': 'TEXT'}\n"
384 | ]
385 | }
386 | ],
387 | "source": [
388 | "## Testing the Knowledge Base:\n",
389 | "\n",
390 | "bedrock_agent = boto3.client('bedrock-agent-runtime')\n",
391 | "kb_id = knowledge_base.knowledge_base['knowledgeBaseId']\n",
392 | "\n",
393 | "def simple_kb_test(kb_id, query_text):\n",
394 | " try:\n",
395 | " query = {\n",
396 | " \"text\": query_text \n",
397 | " }\n",
398 | "\n",
399 | " response = bedrock_agent.retrieve(\n",
400 | " knowledgeBaseId=kb_id,\n",
401 |     "        retrievalQuery=query,  # pass the query dict\n",
402 | " retrievalConfiguration={\n",
403 | " \"vectorSearchConfiguration\": {\n",
404 | " \"numberOfResults\": 5,\n",
405 | " } \n",
406 | " }\n",
407 | " )\n",
408 | " \n",
409 | " print(f\"\\nQuery: {query_text}\")\n",
410 | " print(\"\\nRetrieved results:\")\n",
411 | " for i, result in enumerate(response['retrievalResults'], 1):\n",
412 | " print(f\"\\nResult {i}:\")\n",
413 | " print(f\"Score: {result['score']}\")\n",
414 | " print(f\"Content: {result['content']}\")\n",
415 | " \n",
416 | " except Exception as e:\n",
417 | " print(f\"Error: {e}\")\n",
418 | "\n",
419 | "test_query = \"What is the major context of the SQuAD dataset?\"\n",
420 | "\n",
421 | "simple_kb_test(kb_id, test_query)"
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": 13,
427 | "id": "ac456d68-df14-4519-8599-cf1febc77b04",
428 | "metadata": {},
429 | "outputs": [
430 | {
431 | "name": "stdout",
432 | "output_type": "stream",
433 | "text": [
434 | "Current status: ACTIVE\n",
435 | "Inference profile created successfully: arn:aws:bedrock:us-west-2:010117700078:application-inference-profile/732bbsotu6s5\n"
436 | ]
437 | }
438 | ],
439 | "source": [
440 |     "## Creating an inference profile for the Amazon Nova Pro model.\n",
441 | "\n",
442 | "nova_pro_profile_name = f'bedrock-kb-nova-pro-profile-{suffix}' \n",
443 | "profile_arn = knowledge_base.create_nova_inference_profile(nova_pro_profile_name, throughput=1)"
444 | ]
445 | },
446 | {
447 | "cell_type": "code",
448 | "execution_count": 14,
449 | "id": "9d4978f2-dbf8-4df7-8cd9-226d840a03e4",
450 | "metadata": {},
451 | "outputs": [
452 | {
453 | "name": "stdout",
454 | "output_type": "stream",
455 | "text": [
456 | "sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n",
457 | "sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml\n"
458 | ]
459 | },
460 | {
461 | "name": "stderr",
462 | "output_type": "stream",
463 | "text": [
464 | "Using model 'deepseek-llm-r1-distill-qwen-1-5b' with wildcard version identifier '*'. You can pin to version '1.0.0' for more stable results. Note that models may have different input/output signatures after a major version upgrade.\n",
465 | "[2025-02-17 22:13:08,069] p135 {cache.py:625} WARNING - Using model 'deepseek-llm-r1-distill-qwen-1-5b' with wildcard version identifier '*'. You can pin to version '1.0.0' for more stable results. Note that models may have different input/output signatures after a major version upgrade.\n",
466 | "[2025-02-17 22:13:08,118] p135 {session.py:4094} INFO - Creating model with name: deepseek-llm-r1-distill-qwen-1-5b-2025-02-17-22-13-08-116\n",
467 | "[2025-02-17 22:13:08,911] p135 {session.py:5889} INFO - Creating endpoint-config with name deepseek-llm-r1-distill-qwen-1-5b-2025-02-17-22-13-08-117\n",
468 | "[2025-02-17 22:13:09,247] p135 {session.py:4711} INFO - Creating endpoint with name deepseek-llm-r1-distill-qwen-1-5b-2025-02-17-22-13-08-117\n"
469 | ]
470 | },
471 | {
472 | "name": "stdout",
473 | "output_type": "stream",
474 | "text": [
475 | "----------!"
476 | ]
477 | }
478 | ],
479 | "source": [
480 |     "## Creating the DeepSeek model with SageMaker JumpStart\n",
481 | "\n",
482 | "from sagemaker.jumpstart.model import JumpStartModel\n",
483 | "\n",
484 | "model_id = \"deepseek-llm-r1-distill-qwen-1-5b\"\n",
485 | "my_model = JumpStartModel(model_id=model_id, instance_type='ml.g5.2xlarge')\n",
486 | "\n",
487 | "predictor = my_model.deploy()\n",
488 | "\n",
489 | "deepseek_sagemaker_endpoint = predictor.endpoint_name"
490 | ]
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": 16,
495 | "id": "02d04dfb-bde7-4e46-accd-12d6131fc6ba",
496 | "metadata": {},
497 | "outputs": [],
498 | "source": [
499 |     "## Write the configuration to a file that the Streamlit app will read\n",
500 | "\n",
501 | "config = {\n",
502 | " \"kb_id\": kb_id,\n",
503 | " \"nova_pro_profile_arn\": profile_arn,\n",
504 | " \"nova_pro_model_id\": \"amazon.nova-pro-v1:0\",\n",
505 | " \"sagemaker_endpoint\": deepseek_sagemaker_endpoint,\n",
506 | " \"sagemaker_ep_arn\" : f\"arn:aws:sagemaker:{region}:{account_id}:endpoint/{deepseek_sagemaker_endpoint}\",\n",
507 | " \"region_name\": region\n",
508 | "}\n",
509 | "\n",
510 | "with open('utils/tmp_config.json', 'w') as f:\n",
511 | " json.dump(config, f, indent=4)\n",
512 | "\n",
513 | "\n"
514 | ]
515 | },
516 | {
517 | "cell_type": "code",
518 | "execution_count": null,
519 | "id": "3dfea279",
520 | "metadata": {},
521 | "outputs": [],
522 | "source": [
523 |     "print(\"=============================== All resources have been created; you can now start the demo ==============================\\n\")"
524 | ]
525 | },
526 | {
527 | "cell_type": "markdown",
528 | "id": "8ec68277-54b8-41a6-8690-af030a71dda0",
529 | "metadata": {},
530 | "source": [
531 |     "#### Start a terminal session in JupyterLab, then run the following commands:\n",
532 | "\n",
533 | "```bash\n",
534 | "pip install streamlit\n",
535 | "\n",
536 |     "streamlit run app.py\n",
537 | "```"
538 | ]
539 | },
540 | {
541 | "cell_type": "markdown",
542 | "id": "35dc4ae8-2a85-4e8d-8453-9a5ed922b90e",
543 | "metadata": {},
544 | "source": [
545 |     "#### To access the Streamlit web application:\n",
546 |     "\n",
547 |     "1. Copy the URL of your SageMaker Studio JupyterLab session, e.g.:\n",
548 |     "\n",
549 |     "https://xxxxxxxxxxxxx.studio.us-west-2.sagemaker.aws/jupyterlab/default/lab/.../lab-code.ipynb\n",
550 |     "\n",
551 |     "\n",
552 |     "2. Rewrite the URL into the format below, then open it in a new browser tab:\n",
553 |     "\n",
554 |     "https://xxxxxxxxxxxxx.studio.us-west-2.sagemaker.aws/jupyterlab/default/proxy/8501/\n"
555 | ]
556 | },
557 | {
558 | "cell_type": "code",
559 | "execution_count": 26,
560 | "id": "8c01f6e3-748e-4e28-9a5c-c3403f69bf41",
561 | "metadata": {},
562 | "outputs": [
563 | {
564 | "name": "stdout",
565 | "output_type": "stream",
566 | "text": [
567 | "===============================Knowledge base with fixed chunking==============================\n",
568 | "\n",
569 | "File utils/tmp_config.json has been deleted successfully.\n"
570 | ]
571 | }
572 | ],
573 | "source": [
574 | "## Clean up\n",
575 | "\n",
576 | "\n",
577 | "# print(\"===============================Starting Clean up==============================\\n\")\n",
578 | "# predictor.delete_predictor()\n",
579 | "# knowledge_base.delete_kb(delete_s3_bucket=True)\n",
580 | "\n",
581 | "# file_path = 'utils/tmp_config.json'\n",
582 | "# if os.path.exists(file_path):\n",
583 | "# os.remove(file_path)\n",
584 | "# print(f\"File {file_path} has been deleted successfully.\")\n",
585 | "# else:\n",
586 | "# print(f\"File {file_path} does not exist.\")"
587 | ]
588 | },
589 | {
590 | "cell_type": "code",
591 | "execution_count": null,
592 | "id": "735201cf-054f-4d0d-8d3e-7b6b743feb62",
593 | "metadata": {},
594 | "outputs": [],
595 | "source": []
596 | }
597 | ],
598 | "metadata": {
599 | "kernelspec": {
600 | "display_name": "Python 3 (ipykernel)",
601 | "language": "python",
602 | "name": "python3"
603 | },
604 | "language_info": {
605 | "codemirror_mode": {
606 | "name": "ipython",
607 | "version": 3
608 | },
609 | "file_extension": ".py",
610 | "mimetype": "text/x-python",
611 | "name": "python",
612 | "nbconvert_exporter": "python",
613 | "pygments_lexer": "ipython3",
614 | "version": "3.11.11"
615 | }
616 | },
617 | "nbformat": 4,
618 | "nbformat_minor": 5
619 | }
620 |
--------------------------------------------------------------------------------
/lab/utils/knowledge_base.py:
--------------------------------------------------------------------------------
1 | import json
2 | import boto3
3 | import time
4 | from botocore.exceptions import ClientError
5 | from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, RequestError
6 | import pprint
7 | from retrying import retry
8 | import traceback
9 |
10 | import warnings
11 | warnings.filterwarnings('ignore')
12 |
13 | valid_generation_models = ["amazon.nova-pro-v1:0", "amazon.nova-lite-v1:0", "amazon.nova-micro-v1:0"]
14 | valid_reranking_models = ["cohere.rerank-v3-5:0"]
15 | valid_embedding_models = ["amazon.titan-embed-text-v2:0", "amazon.titan-embed-image-v1:0"]
16 |
17 | embedding_context_dimensions = {
18 | "amazon.titan-embed-text-v2:0": 1024
19 | }
20 |
21 | pp = pprint.PrettyPrinter(indent=2)
22 |
23 | def interactive_sleep(seconds: int):
24 | dots = ''
25 | for i in range(seconds):
26 | dots += '.'
27 | print(dots, end='\r')
28 | time.sleep(1)
29 |
30 | class BedrockKnowledgeBase:
31 | def __init__(
32 | self,
33 | kb_name=None,
34 | kb_description=None,
35 | data_sources=None,
36 | embedding_model="amazon.titan-embed-text-v2:0",
37 | generation_model="amazon.nova-pro-v1:0",
38 | reranking_model="cohere.rerank-v3-5:0",
39 | chunking_strategy="FIXED_SIZE",
40 | suffix=None,
41 | ):
42 | boto3_session = boto3.session.Session()
43 | self.region_name = boto3_session.region_name
44 | self.iam_client = boto3_session.client('iam')
45 |         caller_identity = boto3_session.client('sts').get_caller_identity()
46 |         self.account_number, self.identity = caller_identity['Account'], caller_identity['Arn']
47 |         self.suffix = suffix or f'{self.region_name}-{self.account_number}'
48 |         self.aoss_client = boto3_session.client('opensearchserverless')
49 |         self.s3_client = boto3_session.client('s3')
50 |         self.bedrock_agent_client = boto3_session.client('bedrock-agent')
51 |         credentials = boto3_session.get_credentials()
52 |         self.awsauth = AWSV4SignerAuth(credentials, self.region_name, 'aoss')
53 |
54 | self.kb_name = kb_name or f"default-knowledge-base-{self.suffix}"
55 | self.kb_description = kb_description or "Default Knowledge Base"
56 |         self.data_sources = data_sources or []
57 |         self.bucket_names = [d["bucket_name"] for d in self.data_sources if d['type'] == 'S3']
58 | self.chunking_strategy = chunking_strategy
59 |
60 | self.embedding_model = embedding_model
61 | self.generation_model = generation_model
62 | self.reranking_model = reranking_model
63 |
64 | self._validate_models()
65 |
66 | # Set policy names
67 | self.encryption_policy_name = f"bedrock-sample-rag-sp-{self.suffix}"
68 | self.network_policy_name = f"bedrock-sample-rag-np-{self.suffix}"
69 | self.access_policy_name = f'bedrock-sample-rag-ap-{self.suffix}'
70 | self.kb_execution_role_name = f'BedrockExecutionRoleForKnowledgeBase_{self.suffix}'
71 | self.fm_policy_name = f'BedrockFoundationModelPolicyForKnowledgeBase_{self.suffix}'
72 | self.s3_policy_name = f'BedrockS3PolicyForKnowledgeBase_{self.suffix}'
73 | self.oss_policy_name = f'BedrockOSSPolicyForKnowledgeBase_{self.suffix}'
74 | self.bda_policy_name = f'BedrockBDAPolicyForKnowledgeBase_{self.suffix}'
75 |
76 | self.vector_store_name = f'bedrock-sample-rag-{self.suffix}'
77 | self.index_name = f"bedrock-sample-rag-index-{self.suffix}"
78 |
79 | self._setup_resources()
80 |
81 | def _validate_models(self):
82 | if self.embedding_model not in valid_embedding_models:
83 | raise ValueError(f"Invalid embedding model. Your embedding model should be one of {valid_embedding_models}")
84 | if self.generation_model not in valid_generation_models:
85 | raise ValueError(f"Invalid Generation model. Your generation model should be one of {valid_generation_models}")
86 | if self.reranking_model not in valid_reranking_models:
87 | raise ValueError(f"Invalid Reranking model. Your reranking model should be one of {valid_reranking_models}")
88 |
89 | def _setup_resources(self):
90 | print("========================================================================================")
91 |         print("Step 1 - Creating or retrieving S3 bucket(s) for Knowledge Base documents")
92 | self.create_s3_bucket()
93 |
94 | print("========================================================================================")
95 |         print("Step 2 - Creating Knowledge Base Execution Role and Policies")
96 | self.bedrock_kb_execution_role = self.create_bedrock_execution_role()
97 |
98 | print("========================================================================================")
99 |         print("Step 3 - Creating OSS encryption, network and data access policies")
100 | self.encryption_policy, self.network_policy, self.access_policy = self.create_policies_in_oss()
101 |
102 | print("========================================================================================")
103 |         print("Step 4 - Creating OSS Collection (this step takes a couple of minutes to complete)")
104 | self.host, self.collection, self.collection_id, self.collection_arn = self.create_oss()
105 | self.oss_client = OpenSearch(
106 | hosts=[{'host': self.host, 'port': 443}],
107 | http_auth=self.awsauth,
108 | use_ssl=True,
109 | verify_certs=True,
110 | connection_class=RequestsHttpConnection,
111 | timeout=300
112 | )
113 |
114 | print("========================================================================================")
115 |         print("Step 5 - Creating OSS Vector Index")
116 | self.create_vector_index()
117 |
118 | print("========================================================================================")
119 |         print("Step 6 - Creating Knowledge Base")
120 | self.knowledge_base, self.data_source = self.create_knowledge_base()
121 | print("========================================================================================")
122 |
123 | def create_s3_bucket(self):
124 | for bucket_name in self.bucket_names:
125 | try:
126 | self.s3_client.head_bucket(Bucket=bucket_name)
127 | print(f'Bucket {bucket_name} already exists - retrieving it!')
128 | except ClientError:
129 | print(f'Creating bucket {bucket_name}')
130 | if self.region_name == "us-east-1":
131 | self.s3_client.create_bucket(Bucket=bucket_name)
132 | else:
133 | self.s3_client.create_bucket(
134 | Bucket=bucket_name,
135 | CreateBucketConfiguration={'LocationConstraint': self.region_name}
136 | )
137 |
138 | def create_bedrock_execution_role(self):
139 | # Create foundation model policy
140 | foundation_model_policy_document = {
141 | "Version": "2012-10-17",
142 | "Statement": [
143 | {
144 | "Effect": "Allow",
145 | "Action": ["bedrock:InvokeModel"],
146 | "Resource": [
147 | f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.embedding_model}",
148 | f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.generation_model}",
149 | f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.reranking_model}"
150 | ]
151 | }
152 | ]
153 | }
154 |
155 | # Create S3 policy
156 | s3_policy_document = {
157 | "Version": "2012-10-17",
158 | "Statement": [
159 | {
160 | "Effect": "Allow",
161 | "Action": [
162 | "s3:GetObject",
163 | "s3:ListBucket"
164 | ],
165 | "Resource": [item for sublist in [[f'arn:aws:s3:::{bucket}', f'arn:aws:s3:::{bucket}/*']
166 | for bucket in self.bucket_names] for item in sublist],
167 | "Condition": {
168 | "StringEquals": {
169 | "aws:ResourceAccount": f"{self.account_number}"
170 | }
171 | }
172 | }
173 | ]
174 | }
175 |
176 | # Create BDA policy
177 | bda_policy_document = {
178 | "Version": "2012-10-17",
179 | "Statement": [
180 | {
181 | "Effect": "Allow",
182 | "Action": [
183 | "bedrock:GetDataAutomationStatus",
184 | "bedrock:InvokeDataAutomationAsync"
185 | ],
186 | "Resource": [
187 | f"arn:aws:bedrock:{self.region_name}:{self.account_number}:data-automation-invocation/*",
188 | f"arn:aws:bedrock:{self.region_name}:aws:data-automation-project/public-rag-default"
189 | ]
190 | }
191 | ]
192 | }
193 |
194 | # Create role
195 | assume_role_policy_document = {
196 | "Version": "2012-10-17",
197 | "Statement": [
198 | {
199 | "Effect": "Allow",
200 | "Principal": {
201 | "Service": "bedrock.amazonaws.com"
202 | },
203 | "Action": "sts:AssumeRole"
204 | }
205 | ]
206 | }
207 |
208 | try:
209 | bedrock_kb_execution_role = self.iam_client.create_role(
210 | RoleName=self.kb_execution_role_name,
211 | AssumeRolePolicyDocument=json.dumps(assume_role_policy_document)
212 | )
213 | except self.iam_client.exceptions.EntityAlreadyExistsException:
214 | bedrock_kb_execution_role = self.iam_client.get_role(RoleName=self.kb_execution_role_name)
215 |
216 | # Create and attach policies
217 | policies = [
218 | (self.fm_policy_name, foundation_model_policy_document),
219 | (self.s3_policy_name, s3_policy_document),
220 | (self.bda_policy_name, bda_policy_document)
221 | ]
222 |
223 | for policy_name, policy_document in policies:
224 | try:
225 | policy = self.iam_client.create_policy(
226 | PolicyName=policy_name,
227 | PolicyDocument=json.dumps(policy_document)
228 | )
229 | self.iam_client.attach_role_policy(
230 | RoleName=self.kb_execution_role_name,
231 | PolicyArn=policy["Policy"]["Arn"]
232 | )
233 | except self.iam_client.exceptions.EntityAlreadyExistsException:
234 | policy_arn = f"arn:aws:iam::{self.account_number}:policy/{policy_name}"
235 | self.iam_client.attach_role_policy(
236 | RoleName=self.kb_execution_role_name,
237 | PolicyArn=policy_arn
238 | )
239 |
240 | return bedrock_kb_execution_role
241 |
242 | def create_policies_in_oss(self):
243 | try:
244 | encryption_policy = self.aoss_client.create_security_policy(
245 | name=self.encryption_policy_name,
246 | policy=json.dumps(
247 | {
248 | 'Rules': [{'Resource': ['collection/' + self.vector_store_name],
249 | 'ResourceType': 'collection'}],
250 | 'AWSOwnedKey': True
251 | }),
252 | type='encryption'
253 | )
254 | except self.aoss_client.exceptions.ConflictException:
255 | encryption_policy = self.aoss_client.get_security_policy(
256 | name=self.encryption_policy_name,
257 | type='encryption'
258 | )
259 |
260 | try:
261 | network_policy = self.aoss_client.create_security_policy(
262 | name=self.network_policy_name,
263 | policy=json.dumps(
264 | [
265 | {'Rules': [{'Resource': ['collection/' + self.vector_store_name],
266 | 'ResourceType': 'collection'}],
267 | 'AllowFromPublic': True}
268 | ]),
269 | type='network'
270 | )
271 | except self.aoss_client.exceptions.ConflictException:
272 | network_policy = self.aoss_client.get_security_policy(
273 | name=self.network_policy_name,
274 | type='network'
275 | )
276 |
277 | try:
278 | access_policy = self.aoss_client.create_access_policy(
279 | name=self.access_policy_name,
280 | policy=json.dumps(
281 | [
282 | {
283 | 'Rules': [
284 | {
285 | 'Resource': ['collection/' + self.vector_store_name],
286 | 'Permission': [
287 | 'aoss:CreateCollectionItems',
288 | 'aoss:DeleteCollectionItems',
289 | 'aoss:UpdateCollectionItems',
290 | 'aoss:DescribeCollectionItems'],
291 | 'ResourceType': 'collection'
292 | },
293 | {
294 | 'Resource': ['index/' + self.vector_store_name + '/*'],
295 | 'Permission': [
296 | 'aoss:CreateIndex',
297 | 'aoss:DeleteIndex',
298 | 'aoss:UpdateIndex',
299 | 'aoss:DescribeIndex',
300 | 'aoss:ReadDocument',
301 | 'aoss:WriteDocument'],
302 | 'ResourceType': 'index'
303 | }],
304 | 'Principal': [self.identity, self.bedrock_kb_execution_role['Role']['Arn']],
305 | 'Description': 'Data access policy'
306 | }
307 | ]),
308 | type='data'
309 | )
310 | except self.aoss_client.exceptions.ConflictException:
311 | access_policy = self.aoss_client.get_access_policy(
312 | name=self.access_policy_name,
313 | type='data'
314 | )
315 |
316 | return encryption_policy, network_policy, access_policy
317 |
318 | def create_oss(self):
319 | try:
320 | collection = self.aoss_client.create_collection(
321 | name=self.vector_store_name,
322 | type='VECTORSEARCH'
323 | )
324 | collection_id = collection['createCollectionDetail']['id']
325 | collection_arn = collection['createCollectionDetail']['arn']
326 | except self.aoss_client.exceptions.ConflictException:
327 | collection = self.aoss_client.batch_get_collection(
328 | names=[self.vector_store_name]
329 | )['collectionDetails'][0]
330 | collection_id = collection['id']
331 | collection_arn = collection['arn']
332 |
333 | host = collection_id + '.' + self.region_name + '.aoss.amazonaws.com'
334 |
335 | response = self.aoss_client.batch_get_collection(names=[self.vector_store_name])
336 |         while response['collectionDetails'][0]['status'] == 'CREATING':
337 | print('Creating collection...')
338 | interactive_sleep(30)
339 | response = self.aoss_client.batch_get_collection(names=[self.vector_store_name])
340 |
341 | try:
342 | self.create_oss_policy(collection_id)
343 | print("Sleeping for a minute to ensure data access rules have been enforced")
344 | interactive_sleep(60)
345 | except Exception as e:
346 |             print("Policy already exists or could not be created:")
347 | pp.pprint(e)
348 |
349 | return host, collection, collection_id, collection_arn
350 |
351 | def create_oss_policy(self, collection_id):
352 | oss_policy_document = {
353 | "Version": "2012-10-17",
354 | "Statement": [
355 | {
356 | "Effect": "Allow",
357 | "Action": ["aoss:APIAccessAll"],
358 | "Resource": [f"arn:aws:aoss:{self.region_name}:{self.account_number}:collection/{collection_id}"]
359 | }
360 | ]
361 | }
362 | try:
363 | oss_policy = self.iam_client.create_policy(
364 | PolicyName=self.oss_policy_name,
365 | PolicyDocument=json.dumps(oss_policy_document),
366 | Description='Policy for accessing opensearch serverless',
367 | )
368 | oss_policy_arn = oss_policy["Policy"]["Arn"]
369 | except self.iam_client.exceptions.EntityAlreadyExistsException:
370 | oss_policy_arn = f"arn:aws:iam::{self.account_number}:policy/{self.oss_policy_name}"
371 |
372 | self.iam_client.attach_role_policy(
373 | RoleName=self.bedrock_kb_execution_role["Role"]["RoleName"],
374 | PolicyArn=oss_policy_arn
375 | )
376 |
377 | def create_vector_index(self):
378 | body_json = {
379 | "settings": {
380 | "index.knn": "true",
381 | "number_of_shards": 1,
382 | "knn.algo_param.ef_search": 512,
383 | "number_of_replicas": 0,
384 | },
385 | "mappings": {
386 | "properties": {
387 | "vector": {
388 | "type": "knn_vector",
389 | "dimension": embedding_context_dimensions[self.embedding_model],
390 | "method": {
391 | "name": "hnsw",
392 | "engine": "faiss",
393 | "space_type": "l2"
394 | },
395 | },
396 | "text": {
397 | "type": "text"
398 | },
399 | "text-metadata": {
400 | "type": "text"}
401 | }
402 | }
403 | }
404 |
405 | try:
406 | response = self.oss_client.indices.create(index=self.index_name, body=json.dumps(body_json))
407 | print('\nCreating index:')
408 | pp.pprint(response)
409 | interactive_sleep(60)
410 | except RequestError as e:
411 |             print(f'Error while trying to create the index: {e.error}')
412 |
413 | def create_chunking_strategy_config(self, strategy):
414 | configs = {
415 | "NONE": {
416 | "chunkingConfiguration": {"chunkingStrategy": "NONE"}
417 | },
418 | "FIXED_SIZE": {
419 | "chunkingConfiguration": {
420 | "chunkingStrategy": "FIXED_SIZE",
421 | "fixedSizeChunkingConfiguration": {
422 | "maxTokens": 300,
423 | "overlapPercentage": 20
424 | }
425 | }
426 | }
427 | }
428 | return configs.get(strategy, configs["NONE"])
429 |
430 | @retry(wait_random_min=1000, wait_random_max=2000, stop_max_attempt_number=7)
431 | def create_knowledge_base(self):
432 | opensearch_serverless_configuration = {
433 | "collectionArn": self.collection_arn,
434 | "vectorIndexName": self.index_name,
435 | "fieldMapping": {
436 | "vectorField": "vector",
437 | "textField": "text",
438 | "metadataField": "text-metadata"
439 | }
440 | }
441 |
442 | embedding_model_arn = f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.embedding_model}"
443 | knowledgebase_configuration = {
444 | "type": "VECTOR",
445 | "vectorKnowledgeBaseConfiguration": {
446 | "embeddingModelArn": embedding_model_arn
447 | }
448 | }
449 |
450 | try:
451 | create_kb_response = self.bedrock_agent_client.create_knowledge_base(
452 | name=self.kb_name,
453 | description=self.kb_description,
454 | roleArn=self.bedrock_kb_execution_role['Role']['Arn'],
455 | knowledgeBaseConfiguration=knowledgebase_configuration,
456 | storageConfiguration={
457 | "type": "OPENSEARCH_SERVERLESS",
458 | "opensearchServerlessConfiguration": opensearch_serverless_configuration
459 | }
460 | )
461 | kb = create_kb_response["knowledgeBase"]
462 | pp.pprint(kb)
463 | except self.bedrock_agent_client.exceptions.ConflictException:
464 | kbs = self.bedrock_agent_client.list_knowledge_bases(maxResults=100)
465 | kb_id = next((kb['knowledgeBaseId'] for kb in kbs['knowledgeBaseSummaries'] if kb['name'] == self.kb_name), None)
466 | response = self.bedrock_agent_client.get_knowledge_base(knowledgeBaseId=kb_id)
467 | kb = response['knowledgeBase']
468 | pp.pprint(kb)
469 |
470 | # Create Data Sources
471 | print("Creating Data Sources")
472 | ds_list = []
473 | chunking_strategy_configuration = self.create_chunking_strategy_config(self.chunking_strategy)
474 |
475 | for idx, ds in enumerate(self.data_sources):
476 | if ds['type'] == "S3":
477 |                 ds_name = f'{kb["knowledgeBaseId"]}-s3-{idx}'  # unique name per data source
478 | s3_data_source_configuration = {
479 | "type": "S3",
480 |                     "s3Configuration": {
481 | "bucketArn": f'arn:aws:s3:::{ds["bucket_name"]}'
482 | }
483 | }
484 |
485 | vector_ingestion_configuration = {
486 | "chunkingConfiguration": chunking_strategy_configuration['chunkingConfiguration']
487 | }
488 |
489 |                 create_ds_response = self.bedrock_agent_client.create_data_source(
490 |                     name=ds_name,
491 |                     description=self.kb_description,
492 |                     knowledgeBaseId=kb['knowledgeBaseId'],
493 |                     dataSourceConfiguration=s3_data_source_configuration,
494 |                     vectorIngestionConfiguration=vector_ingestion_configuration
495 |                 )
496 | ds = create_ds_response["dataSource"]
497 | pp.pprint(ds)
498 | ds_list.append(ds)
499 |
500 | return kb, ds_list
501 |
502 | def start_ingestion_job(self):
503 | for idx, ds in enumerate(self.data_source):
504 | try:
505 | start_job_response = self.bedrock_agent_client.start_ingestion_job(
506 | knowledgeBaseId=self.knowledge_base['knowledgeBaseId'],
507 | dataSourceId=ds["dataSourceId"]
508 | )
509 | job = start_job_response["ingestionJob"]
510 | print(f"job {idx+1} started successfully\n")
511 |
512 | while job['status'] not in ["COMPLETE", "FAILED", "STOPPED"]:
513 | get_job_response = self.bedrock_agent_client.get_ingestion_job(
514 | knowledgeBaseId=self.knowledge_base['knowledgeBaseId'],
515 | dataSourceId=ds["dataSourceId"],
516 | ingestionJobId=job["ingestionJobId"]
517 | )
518 | job = get_job_response["ingestionJob"]
519 | pp.pprint(job)
520 | interactive_sleep(40)
521 |
522 | except Exception as e:
523 |                 print(f"Couldn't start job {idx+1}.\n")
524 | print(e)
525 |
526 | def delete_kb(self, delete_s3_bucket=False):
527 | with warnings.catch_warnings():
528 | warnings.filterwarnings("ignore")
529 |
530 | # Delete data sources
531 | ds_id_list = self.bedrock_agent_client.list_data_sources(
532 | knowledgeBaseId=self.knowledge_base['knowledgeBaseId'],
533 | maxResults=100
534 | )['dataSourceSummaries']
535 |
536 |             for ds in ds_id_list:
537 |                 try:
538 |                     self.bedrock_agent_client.delete_data_source(
539 |                         dataSourceId=ds["dataSourceId"],
540 |                         knowledgeBaseId=self.knowledge_base['knowledgeBaseId']
541 |                     )
542 |                     print("======== Data source deleted =========")
543 |                 except Exception as e:
544 |                     print(e)
545 |
546 | # Delete KB
547 | try:
548 | self.bedrock_agent_client.delete_knowledge_base(
549 | knowledgeBaseId=self.knowledge_base['knowledgeBaseId']
550 | )
551 | print("======== Knowledge base deleted =========")
552 | except Exception as e:
553 | print(e)
554 |
555 | time.sleep(20)
556 |
557 | # Delete OSS collection and policies
558 | try:
559 | self.aoss_client.delete_collection(id=self.collection_id)
560 | self.aoss_client.delete_access_policy(type="data", name=self.access_policy_name)
561 | self.aoss_client.delete_security_policy(type="network", name=self.network_policy_name)
562 | self.aoss_client.delete_security_policy(type="encryption", name=self.encryption_policy_name)
563 | print("======== Vector Index, collection and associated policies deleted =========")
564 | except Exception as e:
565 | print(e)
566 |
567 | # Delete role and policies
568 | self.delete_iam_role_and_policies()
569 |
570 | # Delete S3 bucket if requested
571 | if delete_s3_bucket:
572 | for bucket_name in self.bucket_names:
573 | try:
574 | bucket = boto3.resource('s3').Bucket(bucket_name)
575 | bucket.objects.all().delete()
576 | bucket.delete()
577 | print(f"Deleted bucket {bucket_name}")
578 | except Exception as e:
579 | print(f"Error deleting bucket {bucket_name}: {e}")
580 |
581 | def delete_iam_role_and_policies(self):
582 | # Fetch attached policies
583 | response = self.iam_client.list_attached_role_policies(
584 | RoleName=self.kb_execution_role_name
585 | )
586 | policies_to_detach = response['AttachedPolicies']
587 |
588 | for policy in policies_to_detach:
589 | policy_arn = policy['PolicyArn']
590 | try:
591 | self.iam_client.detach_role_policy(
592 | RoleName=self.kb_execution_role_name,
593 | PolicyArn=policy_arn
594 | )
595 | self.iam_client.delete_policy(PolicyArn=policy_arn)
596 | except Exception as e:
597 | print(f"Error detaching/deleting policy {policy_arn}: {e}")
598 |
599 | try:
600 | self.iam_client.delete_role(RoleName=self.kb_execution_role_name)
601 | print("======== All IAM roles and policies deleted =========")
602 | except Exception as e:
603 | print(f"Error deleting role {self.kb_execution_role_name}: {e}")
604 |
605 |
606 | def create_nova_inference_profile(self, profile_name, throughput=1):
607 | try:
608 | bedrock_client = boto3.client('bedrock')
609 |
610 | request_params = {
611 | "inferenceProfileName": profile_name,
612 | "modelSource": {
613 | "copyFrom": f"arn:aws:bedrock:{self.region_name}:{self.account_number}:inference-profile/us.{self.generation_model}"
614 | }
615 | }
616 |
617 | try:
618 | existing_profiles = bedrock_client.list_inference_profiles()
619 | for profile in existing_profiles.get('inferenceProfiles', []):
620 | if profile['inferenceProfileName'] == profile_name:
621 | print(f"Profile {profile_name} already exists")
622 | return profile['inferenceProfileArn']
623 | except Exception as e:
624 | print(f"Error checking existing profiles: {str(e)}")
625 |
626 | response = bedrock_client.create_inference_profile(**request_params)
627 |
628 | profile_arn = response['inferenceProfileArn']
629 |
630 | max_attempts = 30
631 | attempt = 0
632 | while attempt < max_attempts:
633 | try:
634 | status_response = bedrock_client.get_inference_profile(
635 | inferenceProfileIdentifier=profile_arn
636 | )
637 | status = status_response['status']
638 | print(f"Current status: {status}")
639 |
640 | if status == 'ACTIVE':
641 | print(f"Inference profile created successfully: {profile_arn}")
642 | return profile_arn
643 | elif status in ['FAILED', 'DELETED']:
644 | raise Exception(f"Profile creation failed with status: {status}")
645 |
646 | print("Waiting for profile to be ready...")
647 | time.sleep(10)
648 | attempt += 1
649 | except Exception as e:
650 | print(f"Error checking status: {str(e)}")
651 | raise
652 |
653 | raise Exception("Profile creation timed out")
654 |
655 | except Exception as e:
656 | print(f"Error creating inference profile: {str(e)}")
657 | traceback.print_exc()
658 | raise
659 |
660 | def delete_nova_inference_profile(self, profile_name):
661 | try:
662 |             # delete_inference_profile lives on the 'bedrock' client, not 'bedrock-agent'
663 |             boto3.client('bedrock').delete_inference_profile(
664 |                 inferenceProfileIdentifier=profile_name)
665 | print(f"Profile {profile_name} deleted successfully")
666 |
667 | except Exception as e:
668 | print(f"Error deleting profile: {str(e)}")
669 |
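670 | # --- Illustrative usage sketch (comments only; not executed by this module) ---
671 | # Assumes AWS credentials with Bedrock, IAM, S3 and OpenSearch Serverless
672 | # permissions are configured; the names below are hypothetical.
673 | #
674 | # kb = BedrockKnowledgeBase(
675 | #     kb_name="sample-kb",
676 | #     data_sources=[{"type": "S3", "bucket_name": "sample-kb-documents"}],
677 | #     chunking_strategy="FIXED_SIZE",
678 | # )
679 | # kb.start_ingestion_job()              # index documents already in the bucket
680 | # kb.delete_kb(delete_s3_bucket=True)   # tear everything down when finished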
--------------------------------------------------------------------------------
/lab/utils/requirements.txt:
--------------------------------------------------------------------------------
1 | boto3
2 | opensearch-py
3 | botocore
4 | awscli
5 | retrying
6 | ragas==0.1.9
7 | ipywidgets
8 | iprogress
9 | langchain
10 | langchain_aws
11 | langchain_community
12 | s3fs
13 | requests
14 | pypdf
15 | tqdm
16 | pandas
17 | PyPDF2
18 | python-docx
19 | pillow
20 | sagemaker==2.237.1
--------------------------------------------------------------------------------