├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app_chatbot
│   ├── chatbot.py
│   └── requirements.txt
├── data_preparation.ipynb
├── data_preparation
│   ├── preprocessing.py
│   ├── requirements.txt
│   └── split_paragraph.py
├── images
│   ├── architecture_UI.png
│   └── offline_architecture.png
├── test_raw_data
│   ├── demo-video-sagemaker-doc.mp4
│   └── test.webm
└── video_question_answering_langchain.ipynb
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 |
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Implement a RAG solution for Video/Audio data
2 |
3 | Existing LLM solutions such as RAG or chatbots typically support only text data sources. However, video/audio data is also one of the most important knowledge sources for organizations that hold massive amounts of media. In addition, compared to text data such as documents or books, it is harder to look up information in video/audio data: people may have to go through entire video/audio files to locate the information they need.
4 |
5 | In this project, we provide a video and audio processing solution for adopting generative AI on video and audio data. There are two main scenarios: 1) enterprises can enrich their knowledge base with existing video/audio data, making it more likely that RAG retrieves relevant information from the knowledge base; 2) individual users can efficiently find the information they are interested in and jump to the most relevant location in a video/audio file, saving considerable time otherwise spent searching.
6 |
7 | We demonstrate how to use our outputs in a RAG solution with [Question answering using Retrieval Augmented Generation with foundation models in Amazon SageMaker JumpStart](https://aws.amazon.com/blogs/machine-learning/question-answering-using-retrieval-augmented-generation-with-foundation-models-in-amazon-sagemaker-jumpstart/). If you want to try RAG with OpenSearch, you can refer to [Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain](https://aws.amazon.com/blogs/machine-learning/build-a-powerful-question-answering-bot-with-amazon-sagemaker-amazon-opensearch-service-streamlit-and-langchain/) and adapt the solution accordingly.
8 |
9 | The solution architecture is shown below:
10 |
11 |
12 |
![Solution Architecture](images/offline_architecture.png)
13 |
14 |
15 | The workflow mainly consists of the following stages:
16 |
17 | * Convert video to text with Speech-to-text model and sentence embedding model
18 | * Intelligent video search using Retrieval Augmented Generation (RAG) based approach and LangChain
19 | * Example - Build a multi-functional chatbot with Amazon SageMaker
20 |
21 |
22 |
23 | ## Convert video to text with Speech-to-text model and sentence embedding model
24 | We use Whisper to transcribe video and audio data, and a sentence embedding approach to chunk the transcripts into paragraphs. You can run [Convert video to text with Speech-to-text model and sentence embedding model](data_preparation.ipynb) for this task.
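
For a quick check of the real-time path, the sketch below sends a short audio clip to an already deployed Whisper endpoint; the endpoint name is the default used in the notebook, so adjust it and the file path to your setup:

```python
import json
import boto3

# Assumed: a Whisper endpoint deployed via data_preparation.ipynb (the name below is the notebook default).
endpoint_name = "whisper-large-v2"
client = boto3.client("runtime.sagemaker")

# Read a short audio clip; the repository ships test_raw_data/test.webm for testing.
with open("test_raw_data/test.webm", "rb") as f:
    audio_bytes = f.read()

# Send the raw audio bytes to the endpoint and print the transcript.
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="audio/x-audio",
    Body=audio_bytes,
)
print(json.loads(response["Body"].read())["text"])
```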
25 |
26 |
27 | ## Intelligent video search using Retrieval Augmented Generation (RAG) based approach and LangChain
28 | We use the data transcribed from video/audio files to build a RAG solution with LangChain, following the blog Question answering using Retrieval Augmented Generation with foundation models in Amazon SageMaker JumpStart and modifying its source code. You can run [Question Answering based on Custom Video/Audio Dataset with Open-sourced LangChain Library](video_question_answering_langchain.ipynb) for this task.
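
The sketch below shows the retrieval side of this flow over the generated transcripts. It is only an illustration: it uses a local sentence-transformers embedding model and a hypothetical query string, whereas the chatbot in this repository uses a SageMaker GPT-J embedding endpoint, and it assumes the transcript `.txt` files have been synced locally (e.g. to `./video-scripts/` as in the notebook):

```python
# Minimal retrieval sketch for the transcript knowledge base (illustration only).
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Load the paragraph-level transcript files produced by the data preparation step.
loader = DirectoryLoader("./video-scripts/", glob="**/*.txt")
documents = loader.load()

# Local embedding model used here for illustration; chatbot.py uses a SageMaker embedding endpoint instead.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
docsearch = FAISS.from_documents(documents, embeddings)

# Each transcript file name encodes start/end timestamps, so the matched source
# points back to the relevant location in the original video.
query = "What does the video say about SageMaker?"  # illustrative placeholder
for doc, score in docsearch.similarity_search_with_score(query, k=3):
    print(score, doc.metadata["source"])
```

In `app_chatbot/chatbot.py`, the retrieved documents and the user question are then passed to a `load_qa_chain` backed by the SageMaker-hosted Falcon endpoint to generate the final answer.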
29 |
30 |
31 | ## Example - Build a multi-functional chatbot with Amazon SageMaker
32 |
33 |
![Solution Architecture](images/architecture_UI.png)
34 |
35 | We demonstrate how to use Streamlit, LangChain, and SageMaker to build a multi-functional chatbot that provides an interactive experience for users. To run the Streamlit application, first update the endpoint names in the environment variables to match the endpoints deployed in your account. Open a terminal in SageMaker Studio, navigate to the cloned GitHub repository folder, and run the commands below:
36 |
37 | ```bash
38 | export falcon_ep_name=
39 | export wp_ep_name=
40 | export embed_ep_name=
41 | streamlit run app_chatbot/chatbot.py --server.port 6006 --server.maxUploadSize 6
42 | ```
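
These variables map to the endpoint names that `app_chatbot/chatbot.py` reads at startup; if a variable is not set, the script falls back to the default names shown in this sketch of the lookup:

```python
import os

# Endpoint name resolution in app_chatbot/chatbot.py; the defaults are the names
# used in the script and are overridden by the exports above.
falcon_endpoint_name = os.getenv("falcon_ep_name", default="falcon-40b-instruct-12xl")
whisper_endpoint_name = os.getenv("wp_ep_name", default="whisper-large-v2")
embedding_endpoint_name = os.getenv("embed_ep_name", default="huggingface-textembedding-gpt-j-6b")
```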
43 |
44 | To access the Streamlit UI, copy your SageMaker Studio URL and replace `lab?` with `proxy/[PORT NUMBER]/`. Because we set the server port to 6006, the URL should look like:
45 | `https://.studio..sagemaker.aws/jupyter/default/proxy/6006/`
46 | Replace the domain ID and region with the correct values for your account to access the UI.
47 |
48 |
49 | ## Application
50 | We demonstrate a [chatbot](./app_chatbot) application for video and audio data with Streamlit.
51 |
52 | ## Test data
53 | We provide recorded [video data](test_raw_data/demo-video-sagemaker-doc.mp4) and [audio data](test_raw_data/test.webm) that you can use to test this solution.
54 |
55 | ## Security
56 |
57 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
58 |
59 | ## License
60 |
61 | This library is licensed under the MIT-0 License. See the LICENSE file.
--------------------------------------------------------------------------------
/app_chatbot/chatbot.py:
--------------------------------------------------------------------------------
1 | ## streamlit run chatbot.py --server.port 6006 --server.maxUploadSize 6
2 |
3 | import streamlit as st
4 | from streamlit_chat import message
5 | from langchain.chains import ConversationChain
6 | from langchain.memory import ConversationBufferMemory
7 | from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
8 | from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
9 | from typing import Any, Dict, List, Optional
10 | import json
11 | from io import StringIO, BytesIO
12 | from random import randint
13 | from transformers import AutoTokenizer
14 | from PIL import Image
15 | import boto3
16 | import numpy as np
17 | import pandas as pd
18 | import json
19 | import os
20 | import base64
21 | from langchain.embeddings import SagemakerEndpointEmbeddings
22 | from langchain.indexes import VectorstoreIndexCreator
23 | from langchain.vectorstores import FAISS
24 | from langchain.text_splitter import CharacterTextSplitter
25 | from langchain.document_loaders import DirectoryLoader
26 | from langchain import PromptTemplate
27 | from langchain.chains.question_answering import load_qa_chain
28 |
29 | client = boto3.client('runtime.sagemaker')
30 | aws_region = boto3.Session().region_name
31 | source = []
32 |
33 | def query_endpoint_with_json_payload(encoded_json, endpoint_name):
34 | response = client.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/json', Body=encoded_json)
35 | return response
36 |
37 | def parse_response(query_response):
38 | response_dict = json.loads(query_response['Body'].read())
39 | return response_dict['generated_images'], response_dict['prompt']
40 |
41 | st.set_page_config(page_title="Document Analysis", page_icon=":robot:")
42 |
43 |
44 | Falcon_endpoint_name = os.getenv("falcon_ep_name", default="falcon-40b-instruct-12xl")
45 | whisper_endpoint_name = os.getenv('wp_ep_name', default="whisper-large-v2")
46 | embedding_endpoint_name = os.getenv('embed_ep_name', default="huggingface-textembedding-gpt-j-6b")
47 |
48 | endpoint_names = {
49 | "NLP":Falcon_endpoint_name,
50 | "Audio":whisper_endpoint_name
51 | }
52 |
53 | ################# Prepare for RAG solution #######################
54 | class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
55 | def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:
56 | """Compute doc embeddings using a SageMaker Inference Endpoint.
57 |
58 | Args:
59 | texts: The list of texts to embed.
60 | chunk_size: The chunk size defines how many input texts will
61 | be grouped together as request. If None, will use the
62 | chunk size specified by the class.
63 |
64 | Returns:
65 | List of embeddings, one for each text.
66 | """
67 | results = []
68 | _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
69 |
70 | for i in range(0, len(texts), _chunk_size):
71 | response = self._embedding_func(texts[i : i + _chunk_size])
73 | results.extend(response)
74 | return results
75 |
76 |
77 | class ContentHandlerEmbed(EmbeddingsContentHandler):
78 | content_type = "application/json"
79 | accepts = "application/json"
80 |
81 | def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
82 | input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
83 | return input_str.encode("utf-8")
84 |
85 | def transform_output(self, output: bytes) -> str:
86 | response_json = json.loads(output.read().decode("utf-8"))
87 | embeddings = response_json["embedding"]
88 | return embeddings
89 |
90 |
91 | content_handler_embed = ContentHandlerEmbed()
92 |
93 | embeddings = SagemakerEndpointEmbeddingsJumpStart(
94 | endpoint_name=embedding_endpoint_name,
95 | region_name=aws_region,
96 | content_handler=content_handler_embed,
97 | )
98 |
99 | @st.cache_resource
100 | def generate_index():
101 | loader = DirectoryLoader("./data/demo-video-sagemaker-doc/", glob="**/*.txt")
102 | documents = loader.load()
103 | docsearch = FAISS.from_documents(documents, embeddings)
104 | return docsearch
105 |
106 | docsearch = generate_index()
107 |
108 |
109 | ################# Prepare for chatbot with memory #######################
110 |
111 | class ContentHandler(LLMContentHandler):
112 | content_type = "application/json"
113 | accepts = "application/json"
114 | len_prompt = 0
115 |
116 | def transform_input(self, prompt: str, model_kwargs: Dict={}) -> bytes:
117 | self.len_prompt = len(prompt)
118 | input_str = json.dumps({"inputs": prompt, "parameters":{"max_new_tokens": st.session_state.max_token, "temperature":st.session_state.temperature, "seed":st.session_state.seed, "stop": ["Human:"], "num_beams":1, "return_full_text": False}})
119 | print(input_str)
120 | return input_str.encode('utf-8')
121 |
122 | def transform_output(self, output: bytes) -> str:
123 | response_json = output.read()
124 | res = json.loads(response_json)
125 | # print(res)
126 | ans = res[0]['generated_text']#[self.len_prompt:]
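        # keep only the text generated before the model starts a new "Human" turn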
127 | ans = ans[:ans.rfind("Human")].strip()
128 |
129 | return ans
130 |
131 |
132 |
133 | content_handler = ContentHandler()
134 |
135 | llm = SagemakerEndpoint(
136 | endpoint_name=Falcon_endpoint_name,
137 |     region_name=aws_region,
138 | content_handler=content_handler,
139 | )
140 |
141 | @st.cache_resource
142 | def load_chain(endpoint_name: str=Falcon_endpoint_name):
143 |
144 | memory = ConversationBufferMemory(return_messages=True)
145 | chain = ConversationChain(llm=llm, memory=memory)
146 | return chain
147 |
148 | chatchain = load_chain()
149 |
150 |
151 | # initialise session variables
152 | if 'generated' not in st.session_state:
153 | st.session_state['generated'] = []
154 | if 'past' not in st.session_state:
155 | st.session_state['past'] = []
156 | chatchain.memory.clear()
157 |
158 | if 'widget_key' not in st.session_state:
159 | st.session_state['widget_key'] = str(randint(1000, 100000000))
160 | if 'max_token' not in st.session_state:
161 | st.session_state.max_token = 200
162 | if 'temperature' not in st.session_state:
163 | st.session_state.temperature = 0.1
164 | if 'seed' not in st.session_state:
165 | st.session_state.seed = 0
166 | if 'extract_audio' not in st.session_state:
167 | st.session_state.extract_audio = False
168 | if 'option' not in st.session_state:
169 | st.session_state.option = "NLP"
170 |
171 | def clear_button_fn():
172 | st.session_state['generated'] = []
173 | st.session_state['past'] = []
174 | st.session_state['widget_key'] = str(randint(1000, 100000000))
176 | st.session_state.extract_audio = False
177 | chatchain = load_chain(endpoint_name=endpoint_names['NLP'])
178 | chatchain.memory.clear()
179 |
180 |
181 | def on_file_upload():
182 | st.session_state.extract_audio = True
183 | st.session_state['generated'] = []
184 | st.session_state['past'] = []
185 | # st.session_state['widget_key'] = str(randint(1000, 100000000))
186 | chatchain.memory.clear()
187 |
188 |
189 |
190 | with st.sidebar:
191 |     # Sidebar - the clear button will flush the memory of the conversation
192 | st.sidebar.title("Conversation setup")
193 | clear_button = st.sidebar.button("Clear Conversation", key="clear", on_click=clear_button_fn)
194 |
195 | # upload file button
196 | uploaded_file = st.sidebar.file_uploader("Upload a file (text or audio)",
197 | key=st.session_state['widget_key'],
198 | on_change=on_file_upload,
199 | )
200 | if uploaded_file:
201 | filename = uploaded_file.name
202 | print(filename)
203 |         if filename.lower().endswith(('.flac', '.wav', '.webm', '.mp3')):
204 | st.session_state.option = "Audio"
205 | byteio = BytesIO(uploaded_file.getvalue())
206 | data = byteio.read()
207 | st.audio(data, format='audio/webm')
208 | else:
209 | st.session_state.option = "NLP"
210 |
211 | rag = st.checkbox('Use knowledge base')
212 |
213 |
214 |
215 |
216 | left_column, _, right_column = st.columns([50, 2, 20])
217 |
218 | with left_column:
219 | st.header("Building a multifunctional chatbot with Amazon SageMaker")
220 | # this is the container that displays the past conversation
221 | response_container = st.container()
222 | # this is the container with the input text box
223 | container = st.container()
224 |
225 | with container:
226 | # define the input text box
227 | with st.form(key='my_form', clear_on_submit=True):
228 | user_input = st.text_area("Input text:", key='input', height=100)
229 | submit_button = st.form_submit_button(label='Send')
230 |
231 |
232 | # when the submit button is pressed we send the user query to the chatchain object and save the chat history
233 | if submit_button and user_input:
234 | st.session_state.option = "NLP"
235 | if rag:
236 | # output = index.query(question=user_input, llm=llm)
237 | docs = docsearch.similarity_search_with_score(user_input)
238 | contexts = []
239 |
240 | for doc, score in docs:
241 | print(f"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}")
242 | if score <= 0.9:
243 | contexts.append(doc)
244 | source.append(doc.metadata['source'].split('/')[-1])
245 | print(f"\n INPUT CONTEXT:{contexts}")
246 | prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.:\n\n{context}\n\nQuestion: {question}\nHelpful Answer:"""
247 |
248 | PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
249 | chain = load_qa_chain(llm=llm, prompt=PROMPT)
250 | result = chain({"input_documents": contexts, "question": user_input},
251 | return_only_outputs=True)["output_text"]
252 | output = result
253 | else:
254 | output = chatchain(user_input)["response"]
255 | print(output)
256 | st.session_state['past'].append(user_input)
257 | st.session_state['generated'].append(output)
258 | # when a file is uploaded we also send the content to the chatchain object and ask for confirmation
259 | elif uploaded_file is not None:
260 | if st.session_state.option == "Audio" and st.session_state.extract_audio:
261 | byteio = BytesIO(uploaded_file.getvalue())
262 | data = byteio.read()
263 | response = client.invoke_endpoint(EndpointName=whisper_endpoint_name, ContentType='audio/x-audio', Body=data)
264 | output = json.loads(response['Body'].read())["text"]
265 |                 st.session_state['past'].append("I have uploaded an audio file. Please extract the text from this audio file")
266 | st.session_state['generated'].append(output)
267 | content = "=== BEGIN AUDIO FILE ===\n"
268 | content += output
269 | content += "\n=== END AUDIO FILE ===\nPlease remember the audio file by saying 'Yes, I remembered the audio file'"
270 | output = chatchain(content)["response"]
271 | print(output)
272 | st.session_state.extract_audio = False
273 | elif st.session_state.option == "NLP":
274 | stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
275 | content = "=== BEGIN FILE ===\n"
276 | content += stringio.read().strip()
277 | content += "\n=== END FILE ===\nPlease confirm that you have read that file by saying 'Yes, I have read the file'"
278 | output = chatchain(content)["response"]
279 | st.session_state['past'].append("I have uploaded a file. Please confirm that you have read that file.")
280 | st.session_state['generated'].append(output)
281 |
282 | if len(source) != 0:
283 | df = pd.DataFrame(source, columns=['knowledge source'])
284 | st.data_editor(df)
285 | source = []
286 |
287 | st.write(f"Currently using a {st.session_state.option} model")
288 |
289 |
290 | # this loop is responsible for displaying the chat history
291 | if st.session_state['generated']:
292 | with response_container:
293 | for i in range(len(st.session_state['generated'])):
294 | message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
295 | message(st.session_state["generated"][i], key=str(i))
296 |
297 |
298 | with right_column:
299 |
300 | max_tokens= st.slider(
301 | min_value=8,
302 | max_value=1024,
303 | step=1,
304 | # value=200,
305 | label="Number of tokens to generate",
306 | key="max_token"
307 | )
308 | temperature = st.slider(
309 | min_value=0.1,
310 | max_value=2.5,
311 | step=0.1,
312 | # value=0.4,
313 | label="Temperature",
314 | key="temperature"
315 | )
316 | seed = st.slider(
317 | min_value=0,
318 | max_value=1000,
319 | # value=0,
320 | step=1,
321 | label="Random seed to use for the generation",
322 | key="seed"
323 | )
324 |
325 |
326 |
--------------------------------------------------------------------------------
/app_chatbot/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.4
2 | aiosignal==1.3.1
3 | altair==5.0.1
4 | async-timeout==4.0.2
5 | attrs==23.1.0
6 | blinker==1.6.2
7 | boto3==1.26.150
8 | botocore==1.29.150
9 | cachetools==5.3.1
10 | certifi==2023.5.7
11 | charset-normalizer==3.1.0
12 | click==8.1.3
13 | dataclasses-json==0.5.7
14 | decorator==5.1.1
15 | frozenlist==1.3.3
16 | gitdb==4.0.10
17 | GitPython==3.1.31
18 | idna==3.4
19 | importlib-metadata==6.6.0
20 | Jinja2==3.1.2
21 | jmespath==1.0.1
22 | jsonschema==4.17.3
23 | langchain==0.0.195
24 | langchainplus-sdk==0.0.8
25 | markdown-it-py==2.2.0
26 | MarkupSafe==2.1.3
27 | marshmallow==3.19.0
28 | marshmallow-enum==1.5.1
29 | mdurl==0.1.2
30 | multidict==6.0.4
31 | mypy-extensions==1.0.0
32 | numexpr==2.8.4
33 | numpy==1.24.3
34 | openapi-schema-pydantic==1.2.4
35 | packaging==23.1
36 | pandas==2.0.2
37 | Pillow==9.5.0
38 | protobuf==4.23.2
39 | pyarrow==12.0.0
40 | pydantic==1.10.9
41 | pydeck==0.8.1b0
42 | Pygments==2.15.1
43 | Pympler==1.0.1
44 | pyrsistent==0.19.3
45 | python-dateutil==2.8.2
46 | pytz==2023.3
47 | pytz-deprecation-shim==0.1.0.post0
48 | PyYAML==6.0
49 | requests==2.31.0
50 | rich==13.4.1
51 | s3transfer==0.6.1
52 | six==1.16.0
53 | smmap==5.0.0
54 | SQLAlchemy==2.0.15
55 | streamlit==1.23.1
56 | streamlit-chat==0.0.2.2
57 | tenacity==8.2.2
58 | toml==0.10.2
59 | toolz==0.12.0
60 | tornado==6.3.2
61 | typing-inspect==0.9.0
62 | typing_extensions==4.6.3
63 | tzdata==2023.3
64 | tzlocal==4.3
65 | urllib3==1.26.16
66 | validators==0.20.0
67 | yarl==1.9.2
68 | zipp==3.15.0
69 | unstructured==0.8.1
70 | transformers~=4.30.2
71 | faiss-cpu==1.7.4
--------------------------------------------------------------------------------
/data_preparation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "cfd8d0c0-2c2d-4751-8f5d-999d79ae6108",
6 | "metadata": {},
7 | "source": [
8 | "## Convert video to text with Speech-to-text model and sentence embedding model\n",
9 | "\n",
10 | "In this notebook, we will extract information from video/audio files with [Whipser model](https://github.com/openai/whisper). Be leveraging multilingual support, we can extract tanscripts from videos files mixed different languages, even for one video file with different languanges. We provide the following options for whisper inference:\n",
11 | "- Batch inference with SageMaker Processing job, we can process massive data and store them into vector database for RAG solution.\n",
12 | "- Real-time inference with SageMaker Endpoint, we can leverage it to do summarizaton or QA with a short video/audio file (less than 6MB)."
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "94a02764-403f-4c35-9754-e99e2a8d5b58",
19 | "metadata": {
20 | "tags": []
21 | },
22 | "outputs": [],
23 | "source": [
24 | "!pip install -U sagemaker -q"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "id": "38922488-e64f-45e2-983a-5c176f4e13ab",
30 | "metadata": {},
31 | "source": [
32 | "## Set up"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "id": "3c209ddd-7a9a-4c7c-b582-6729ce88a4d9",
39 | "metadata": {
40 | "tags": []
41 | },
42 | "outputs": [],
43 | "source": [
44 | "from sagemaker.huggingface import HuggingFaceProcessor\n",
45 | "from sagemaker import get_execution_role\n",
46 | "from sagemaker.processing import ProcessingInput, ProcessingOutput\n",
47 | "from sagemaker.huggingface import HuggingFaceModel\n",
48 | "import sagemaker\n",
49 | "import boto3\n",
50 | "import json\n",
51 | "\n",
52 | "try:\n",
53 | " role = sagemaker.get_execution_role()\n",
54 | "except ValueError:\n",
55 | " iam = boto3.client('iam')\n",
56 | " role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
57 | "\n",
58 | "sess = sagemaker.session.Session()\n",
59 | "bucket = sess.default_bucket()\n",
60 | "prefix = \"sagemaker/rag_video\"\n",
61 | "folder_name = \"genai_workshop\"\n",
62 | "s3_input = f\"s3://{bucket}/{prefix}/raw_data/{folder_name}\" # Directory for video files\n",
63 | "s3_output_clips = f\"s3://{bucket}/{prefix}/clips\" # Directory for video clips\n",
64 | "s3_output_transcript = f\"s3://{bucket}/{prefix}/transcript\" # Directory for transcripts"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "id": "7600cddd-9a23-4ab1-abf6-6cedd8b7fa55",
71 | "metadata": {
72 | "tags": []
73 | },
74 | "outputs": [],
75 | "source": [
76 | "%store s3_output_transcript"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "id": "5892e7ff-edfe-4112-9ccd-bc34b2fc1bde",
82 | "metadata": {},
83 | "source": [
84 | "## Upload test data to S3 bucket\n",
85 | "\n",
86 | "Download data from YouTube."
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "id": "388a7d5d-eeaa-48bd-9b5b-9a352489cc89",
93 | "metadata": {
94 | "tags": []
95 | },
96 | "outputs": [],
97 | "source": [
98 | "# Download data from YouTube\n",
99 | "!pip install pytube"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "id": "d5eee50a-9c66-456b-879f-98a09d85d87e",
106 | "metadata": {
107 | "tags": []
108 | },
109 | "outputs": [],
110 | "source": [
111 | "from pytube import YouTube\n",
112 | "\n",
113 | "VIDEO_SAVE_DIRECTORY = \"./videos\"\n",
114 | "AUDIO_SAVE_DIRECTORY = \"./audio\"\n",
115 | "video_name = \"genai_interview.mp4\"\n",
116 | "def download(video_url):\n",
117 | " video = YouTube(video_url)\n",
118 | " video = video.streams.get_highest_resolution()\n",
119 | "\n",
120 | " try:\n",
121 | " video.download(VIDEO_SAVE_DIRECTORY, filename=video_name)\n",
122 | " except:\n",
123 | " print(\"Failed to download video\")\n",
124 | "\n",
125 | " print(\"video was downloaded successfully\")\n",
126 | " \n",
127 | "def download_audio(video_url):\n",
128 | " video = YouTube(video_url)\n",
129 | " audio = video.streams.filter(only_audio = True).first()\n",
130 | "\n",
131 | " try:\n",
132 | " audio.download(AUDIO_SAVE_DIRECTORY)\n",
133 | " except:\n",
134 | " print(\"Failed to download audio\")\n",
135 | "\n",
136 | " print(\"audio was downloaded successfully\")"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "id": "4888c5f4-52a0-4b3d-95e5-0bd0cca84af5",
143 | "metadata": {
144 | "tags": []
145 | },
146 | "outputs": [],
147 | "source": [
148 | "# JAWS-UG AI/ML (Japanese) #16 Generative AI: https://www.youtube.com/watch?v=PkZenNAXtYs\n",
149 | "# New York Summit 2023 AIML: https://www.youtube.com/watch?v=1PkABWCJINM Totally 36mins"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "id": "b8e6fe10-3eed-4578-b4ad-7b4e044f84a7",
156 | "metadata": {
157 | "tags": []
158 | },
159 | "outputs": [],
160 | "source": [
161 | "download(\"https://www.youtube.com/watch?v=dBzCGcwYCJo\")"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "id": "02daea10-0cbf-4f5d-b294-023f6d56b12f",
168 | "metadata": {
169 | "tags": []
170 | },
171 | "outputs": [],
172 | "source": [
173 | "!aws s3 cp videos/{video_name} {s3_input}/"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "id": "7f00c2e3-c897-4a2c-a12e-b07c75ca5986",
179 | "metadata": {
180 | "tags": []
181 | },
182 | "source": [
183 | "## Batch inference with SageMaker Processing"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": null,
189 | "id": "2e002118-b7dc-4915-9fe0-f80e3bbfe847",
190 | "metadata": {
191 | "tags": []
192 | },
193 | "outputs": [],
194 | "source": [
195 | "hfp = HuggingFaceProcessor(\n",
196 | " role=get_execution_role(), \n",
197 | " instance_count=1,\n",
198 | " instance_type='ml.p3.2xlarge',\n",
199 | " transformers_version='4.28.1',\n",
200 | " pytorch_version='2.0.0', \n",
201 | " base_job_name='frameworkprocessor-hf',\n",
202 | " py_version=\"py310\"\n",
203 | ")"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "id": "90383250-179b-4499-8583-fdca5320ee75",
210 | "metadata": {
211 | "scrolled": true,
212 | "tags": []
213 | },
214 | "outputs": [],
215 | "source": [
216 | "hfp.run(\n",
217 | " code='preprocessing.py',\n",
218 | " source_dir=\"data_preparation\",\n",
219 | " inputs=[\n",
220 | " ProcessingInput(source=s3_input, destination=\"/opt/ml/processing/input\")\n",
221 | " ], \n",
222 | " outputs=[\n",
223 | " ProcessingOutput(source='/opt/ml/processing/output_clips', destination=s3_output_clips),\n",
224 | " ProcessingOutput(source='/opt/ml/processing/transcripts', destination=s3_output_transcript),\n",
225 | " ],\n",
226 | " arguments=[\n",
227 | " \"--whisper-model\", \"whisper-large-v2\",\n",
228 | " \"--target-language\", \"en\",\n",
229 | " \"--sentence-embedding-model\", \"all-mpnet-base-v2\",\n",
230 | " \"--order\", \"5\"\n",
231 | " ]\n",
232 | ")"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "id": "1c4073c6-19a6-4c5c-93eb-b66eb99eb676",
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "!mkdir -p video-scripts\n",
243 | "!aws s3 sync $s3_output_transcript/genai_interview/ video-scripts"
244 | ]
245 | },
246 | {
247 | "cell_type": "markdown",
248 | "id": "1efea032-76d4-478c-943a-a4be59d47ea7",
249 | "metadata": {},
250 | "source": [
251 | "## Deploy Whipser model to SageMaker for real-time inference"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "id": "0ed73b43-b5c7-4e1a-bd15-8e020cab8f51",
258 | "metadata": {
259 | "tags": []
260 | },
261 | "outputs": [],
262 | "source": [
263 | "endpoint_name=\"whisper-large-v2\"\n",
264 | "# Hub Model configuration. https://huggingface.co/models\n",
265 | "hub = {\n",
266 | " 'HF_MODEL_ID':'openai/whisper-large-v2',\n",
267 | " 'HF_TASK':'automatic-speech-recognition',\n",
268 | "}\n",
269 | "\n",
270 | "# create Hugging Face Model Class\n",
271 | "huggingface_model = HuggingFaceModel(\n",
272 | " transformers_version='4.26.0',\n",
273 | " pytorch_version='1.13.1',\n",
274 | " py_version='py39',\n",
275 | " \n",
276 | " env=hub,\n",
277 | " role=role\n",
278 | ")"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": null,
284 | "id": "3d00b59f-c17c-4446-a900-55e2835c5625",
285 | "metadata": {
286 | "tags": []
287 | },
288 | "outputs": [],
289 | "source": [
290 | "# deploy model to SageMaker Inference\n",
291 | "predictor = huggingface_model.deploy(\n",
292 | " endpoint_name=endpoint_name,\n",
293 | " initial_instance_count=1, # number of instances\n",
294 | " instance_type='ml.g5.xlarge' # ec2 instance type\n",
295 | ")"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": null,
301 | "id": "9cf3d59a-0ccf-4baf-95d9-a292c43872dc",
302 | "metadata": {
303 | "tags": []
304 | },
305 | "outputs": [],
306 | "source": [
307 | "client = boto3.client('runtime.sagemaker')\n",
308 | "file = \"test_raw_data/test.webm\"\n",
309 | "with open(file, \"rb\") as f:\n",
310 | " data = f.read()"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "id": "e8e8d3ec-15a3-4e5f-b85b-19dd5ce98dd1",
317 | "metadata": {
318 | "tags": []
319 | },
320 | "outputs": [],
321 | "source": [
322 | "response = client.invoke_endpoint(EndpointName=endpoint_name, ContentType='audio/x-audio', Body=data)\n",
323 | "output = json.loads(response['Body'].read())\n",
324 | "print(f\"Extracted text from the audio file:\\n {output['text']}\")"
325 | ]
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "id": "959e9aac-7690-478e-a703-9e9d433f9fb0",
330 | "metadata": {
331 | "tags": []
332 | },
333 | "source": [
334 | "You can follow section for `Example - Build a multi-functional chatbot with Amazon SageMaker` in [REAMDE](./README.md) to build a multi-functional chatbot with whipser endpoint.\n",
335 | "Please delete endpoint once you don't it."
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "id": "f23df075-6f8b-42f7-994c-b8510d87c3dd",
342 | "metadata": {
343 | "tags": []
344 | },
345 | "outputs": [],
346 | "source": [
347 | "predictor.delete_endpoint()"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": null,
353 | "id": "1a8c26ef-a1e2-47af-a621-309a594a80fd",
354 | "metadata": {},
355 | "outputs": [],
356 | "source": []
357 | }
358 | ],
359 | "metadata": {
360 | "availableInstances": [
361 | {
362 | "_defaultOrder": 0,
363 | "_isFastLaunch": true,
364 | "category": "General purpose",
365 | "gpuNum": 0,
366 | "hideHardwareSpecs": false,
367 | "memoryGiB": 4,
368 | "name": "ml.t3.medium",
369 | "vcpuNum": 2
370 | },
371 | {
372 | "_defaultOrder": 1,
373 | "_isFastLaunch": false,
374 | "category": "General purpose",
375 | "gpuNum": 0,
376 | "hideHardwareSpecs": false,
377 | "memoryGiB": 8,
378 | "name": "ml.t3.large",
379 | "vcpuNum": 2
380 | },
381 | {
382 | "_defaultOrder": 2,
383 | "_isFastLaunch": false,
384 | "category": "General purpose",
385 | "gpuNum": 0,
386 | "hideHardwareSpecs": false,
387 | "memoryGiB": 16,
388 | "name": "ml.t3.xlarge",
389 | "vcpuNum": 4
390 | },
391 | {
392 | "_defaultOrder": 3,
393 | "_isFastLaunch": false,
394 | "category": "General purpose",
395 | "gpuNum": 0,
396 | "hideHardwareSpecs": false,
397 | "memoryGiB": 32,
398 | "name": "ml.t3.2xlarge",
399 | "vcpuNum": 8
400 | },
401 | {
402 | "_defaultOrder": 4,
403 | "_isFastLaunch": true,
404 | "category": "General purpose",
405 | "gpuNum": 0,
406 | "hideHardwareSpecs": false,
407 | "memoryGiB": 8,
408 | "name": "ml.m5.large",
409 | "vcpuNum": 2
410 | },
411 | {
412 | "_defaultOrder": 5,
413 | "_isFastLaunch": false,
414 | "category": "General purpose",
415 | "gpuNum": 0,
416 | "hideHardwareSpecs": false,
417 | "memoryGiB": 16,
418 | "name": "ml.m5.xlarge",
419 | "vcpuNum": 4
420 | },
421 | {
422 | "_defaultOrder": 6,
423 | "_isFastLaunch": false,
424 | "category": "General purpose",
425 | "gpuNum": 0,
426 | "hideHardwareSpecs": false,
427 | "memoryGiB": 32,
428 | "name": "ml.m5.2xlarge",
429 | "vcpuNum": 8
430 | },
431 | {
432 | "_defaultOrder": 7,
433 | "_isFastLaunch": false,
434 | "category": "General purpose",
435 | "gpuNum": 0,
436 | "hideHardwareSpecs": false,
437 | "memoryGiB": 64,
438 | "name": "ml.m5.4xlarge",
439 | "vcpuNum": 16
440 | },
441 | {
442 | "_defaultOrder": 8,
443 | "_isFastLaunch": false,
444 | "category": "General purpose",
445 | "gpuNum": 0,
446 | "hideHardwareSpecs": false,
447 | "memoryGiB": 128,
448 | "name": "ml.m5.8xlarge",
449 | "vcpuNum": 32
450 | },
451 | {
452 | "_defaultOrder": 9,
453 | "_isFastLaunch": false,
454 | "category": "General purpose",
455 | "gpuNum": 0,
456 | "hideHardwareSpecs": false,
457 | "memoryGiB": 192,
458 | "name": "ml.m5.12xlarge",
459 | "vcpuNum": 48
460 | },
461 | {
462 | "_defaultOrder": 10,
463 | "_isFastLaunch": false,
464 | "category": "General purpose",
465 | "gpuNum": 0,
466 | "hideHardwareSpecs": false,
467 | "memoryGiB": 256,
468 | "name": "ml.m5.16xlarge",
469 | "vcpuNum": 64
470 | },
471 | {
472 | "_defaultOrder": 11,
473 | "_isFastLaunch": false,
474 | "category": "General purpose",
475 | "gpuNum": 0,
476 | "hideHardwareSpecs": false,
477 | "memoryGiB": 384,
478 | "name": "ml.m5.24xlarge",
479 | "vcpuNum": 96
480 | },
481 | {
482 | "_defaultOrder": 12,
483 | "_isFastLaunch": false,
484 | "category": "General purpose",
485 | "gpuNum": 0,
486 | "hideHardwareSpecs": false,
487 | "memoryGiB": 8,
488 | "name": "ml.m5d.large",
489 | "vcpuNum": 2
490 | },
491 | {
492 | "_defaultOrder": 13,
493 | "_isFastLaunch": false,
494 | "category": "General purpose",
495 | "gpuNum": 0,
496 | "hideHardwareSpecs": false,
497 | "memoryGiB": 16,
498 | "name": "ml.m5d.xlarge",
499 | "vcpuNum": 4
500 | },
501 | {
502 | "_defaultOrder": 14,
503 | "_isFastLaunch": false,
504 | "category": "General purpose",
505 | "gpuNum": 0,
506 | "hideHardwareSpecs": false,
507 | "memoryGiB": 32,
508 | "name": "ml.m5d.2xlarge",
509 | "vcpuNum": 8
510 | },
511 | {
512 | "_defaultOrder": 15,
513 | "_isFastLaunch": false,
514 | "category": "General purpose",
515 | "gpuNum": 0,
516 | "hideHardwareSpecs": false,
517 | "memoryGiB": 64,
518 | "name": "ml.m5d.4xlarge",
519 | "vcpuNum": 16
520 | },
521 | {
522 | "_defaultOrder": 16,
523 | "_isFastLaunch": false,
524 | "category": "General purpose",
525 | "gpuNum": 0,
526 | "hideHardwareSpecs": false,
527 | "memoryGiB": 128,
528 | "name": "ml.m5d.8xlarge",
529 | "vcpuNum": 32
530 | },
531 | {
532 | "_defaultOrder": 17,
533 | "_isFastLaunch": false,
534 | "category": "General purpose",
535 | "gpuNum": 0,
536 | "hideHardwareSpecs": false,
537 | "memoryGiB": 192,
538 | "name": "ml.m5d.12xlarge",
539 | "vcpuNum": 48
540 | },
541 | {
542 | "_defaultOrder": 18,
543 | "_isFastLaunch": false,
544 | "category": "General purpose",
545 | "gpuNum": 0,
546 | "hideHardwareSpecs": false,
547 | "memoryGiB": 256,
548 | "name": "ml.m5d.16xlarge",
549 | "vcpuNum": 64
550 | },
551 | {
552 | "_defaultOrder": 19,
553 | "_isFastLaunch": false,
554 | "category": "General purpose",
555 | "gpuNum": 0,
556 | "hideHardwareSpecs": false,
557 | "memoryGiB": 384,
558 | "name": "ml.m5d.24xlarge",
559 | "vcpuNum": 96
560 | },
561 | {
562 | "_defaultOrder": 20,
563 | "_isFastLaunch": false,
564 | "category": "General purpose",
565 | "gpuNum": 0,
566 | "hideHardwareSpecs": true,
567 | "memoryGiB": 0,
568 | "name": "ml.geospatial.interactive",
569 | "supportedImageNames": [
570 | "sagemaker-geospatial-v1-0"
571 | ],
572 | "vcpuNum": 0
573 | },
574 | {
575 | "_defaultOrder": 21,
576 | "_isFastLaunch": true,
577 | "category": "Compute optimized",
578 | "gpuNum": 0,
579 | "hideHardwareSpecs": false,
580 | "memoryGiB": 4,
581 | "name": "ml.c5.large",
582 | "vcpuNum": 2
583 | },
584 | {
585 | "_defaultOrder": 22,
586 | "_isFastLaunch": false,
587 | "category": "Compute optimized",
588 | "gpuNum": 0,
589 | "hideHardwareSpecs": false,
590 | "memoryGiB": 8,
591 | "name": "ml.c5.xlarge",
592 | "vcpuNum": 4
593 | },
594 | {
595 | "_defaultOrder": 23,
596 | "_isFastLaunch": false,
597 | "category": "Compute optimized",
598 | "gpuNum": 0,
599 | "hideHardwareSpecs": false,
600 | "memoryGiB": 16,
601 | "name": "ml.c5.2xlarge",
602 | "vcpuNum": 8
603 | },
604 | {
605 | "_defaultOrder": 24,
606 | "_isFastLaunch": false,
607 | "category": "Compute optimized",
608 | "gpuNum": 0,
609 | "hideHardwareSpecs": false,
610 | "memoryGiB": 32,
611 | "name": "ml.c5.4xlarge",
612 | "vcpuNum": 16
613 | },
614 | {
615 | "_defaultOrder": 25,
616 | "_isFastLaunch": false,
617 | "category": "Compute optimized",
618 | "gpuNum": 0,
619 | "hideHardwareSpecs": false,
620 | "memoryGiB": 72,
621 | "name": "ml.c5.9xlarge",
622 | "vcpuNum": 36
623 | },
624 | {
625 | "_defaultOrder": 26,
626 | "_isFastLaunch": false,
627 | "category": "Compute optimized",
628 | "gpuNum": 0,
629 | "hideHardwareSpecs": false,
630 | "memoryGiB": 96,
631 | "name": "ml.c5.12xlarge",
632 | "vcpuNum": 48
633 | },
634 | {
635 | "_defaultOrder": 27,
636 | "_isFastLaunch": false,
637 | "category": "Compute optimized",
638 | "gpuNum": 0,
639 | "hideHardwareSpecs": false,
640 | "memoryGiB": 144,
641 | "name": "ml.c5.18xlarge",
642 | "vcpuNum": 72
643 | },
644 | {
645 | "_defaultOrder": 28,
646 | "_isFastLaunch": false,
647 | "category": "Compute optimized",
648 | "gpuNum": 0,
649 | "hideHardwareSpecs": false,
650 | "memoryGiB": 192,
651 | "name": "ml.c5.24xlarge",
652 | "vcpuNum": 96
653 | },
654 | {
655 | "_defaultOrder": 29,
656 | "_isFastLaunch": true,
657 | "category": "Accelerated computing",
658 | "gpuNum": 1,
659 | "hideHardwareSpecs": false,
660 | "memoryGiB": 16,
661 | "name": "ml.g4dn.xlarge",
662 | "vcpuNum": 4
663 | },
664 | {
665 | "_defaultOrder": 30,
666 | "_isFastLaunch": false,
667 | "category": "Accelerated computing",
668 | "gpuNum": 1,
669 | "hideHardwareSpecs": false,
670 | "memoryGiB": 32,
671 | "name": "ml.g4dn.2xlarge",
672 | "vcpuNum": 8
673 | },
674 | {
675 | "_defaultOrder": 31,
676 | "_isFastLaunch": false,
677 | "category": "Accelerated computing",
678 | "gpuNum": 1,
679 | "hideHardwareSpecs": false,
680 | "memoryGiB": 64,
681 | "name": "ml.g4dn.4xlarge",
682 | "vcpuNum": 16
683 | },
684 | {
685 | "_defaultOrder": 32,
686 | "_isFastLaunch": false,
687 | "category": "Accelerated computing",
688 | "gpuNum": 1,
689 | "hideHardwareSpecs": false,
690 | "memoryGiB": 128,
691 | "name": "ml.g4dn.8xlarge",
692 | "vcpuNum": 32
693 | },
694 | {
695 | "_defaultOrder": 33,
696 | "_isFastLaunch": false,
697 | "category": "Accelerated computing",
698 | "gpuNum": 4,
699 | "hideHardwareSpecs": false,
700 | "memoryGiB": 192,
701 | "name": "ml.g4dn.12xlarge",
702 | "vcpuNum": 48
703 | },
704 | {
705 | "_defaultOrder": 34,
706 | "_isFastLaunch": false,
707 | "category": "Accelerated computing",
708 | "gpuNum": 1,
709 | "hideHardwareSpecs": false,
710 | "memoryGiB": 256,
711 | "name": "ml.g4dn.16xlarge",
712 | "vcpuNum": 64
713 | },
714 | {
715 | "_defaultOrder": 35,
716 | "_isFastLaunch": false,
717 | "category": "Accelerated computing",
718 | "gpuNum": 1,
719 | "hideHardwareSpecs": false,
720 | "memoryGiB": 61,
721 | "name": "ml.p3.2xlarge",
722 | "vcpuNum": 8
723 | },
724 | {
725 | "_defaultOrder": 36,
726 | "_isFastLaunch": false,
727 | "category": "Accelerated computing",
728 | "gpuNum": 4,
729 | "hideHardwareSpecs": false,
730 | "memoryGiB": 244,
731 | "name": "ml.p3.8xlarge",
732 | "vcpuNum": 32
733 | },
734 | {
735 | "_defaultOrder": 37,
736 | "_isFastLaunch": false,
737 | "category": "Accelerated computing",
738 | "gpuNum": 8,
739 | "hideHardwareSpecs": false,
740 | "memoryGiB": 488,
741 | "name": "ml.p3.16xlarge",
742 | "vcpuNum": 64
743 | },
744 | {
745 | "_defaultOrder": 38,
746 | "_isFastLaunch": false,
747 | "category": "Accelerated computing",
748 | "gpuNum": 8,
749 | "hideHardwareSpecs": false,
750 | "memoryGiB": 768,
751 | "name": "ml.p3dn.24xlarge",
752 | "vcpuNum": 96
753 | },
754 | {
755 | "_defaultOrder": 39,
756 | "_isFastLaunch": false,
757 | "category": "Memory Optimized",
758 | "gpuNum": 0,
759 | "hideHardwareSpecs": false,
760 | "memoryGiB": 16,
761 | "name": "ml.r5.large",
762 | "vcpuNum": 2
763 | },
764 | {
765 | "_defaultOrder": 40,
766 | "_isFastLaunch": false,
767 | "category": "Memory Optimized",
768 | "gpuNum": 0,
769 | "hideHardwareSpecs": false,
770 | "memoryGiB": 32,
771 | "name": "ml.r5.xlarge",
772 | "vcpuNum": 4
773 | },
774 | {
775 | "_defaultOrder": 41,
776 | "_isFastLaunch": false,
777 | "category": "Memory Optimized",
778 | "gpuNum": 0,
779 | "hideHardwareSpecs": false,
780 | "memoryGiB": 64,
781 | "name": "ml.r5.2xlarge",
782 | "vcpuNum": 8
783 | },
784 | {
785 | "_defaultOrder": 42,
786 | "_isFastLaunch": false,
787 | "category": "Memory Optimized",
788 | "gpuNum": 0,
789 | "hideHardwareSpecs": false,
790 | "memoryGiB": 128,
791 | "name": "ml.r5.4xlarge",
792 | "vcpuNum": 16
793 | },
794 | {
795 | "_defaultOrder": 43,
796 | "_isFastLaunch": false,
797 | "category": "Memory Optimized",
798 | "gpuNum": 0,
799 | "hideHardwareSpecs": false,
800 | "memoryGiB": 256,
801 | "name": "ml.r5.8xlarge",
802 | "vcpuNum": 32
803 | },
804 | {
805 | "_defaultOrder": 44,
806 | "_isFastLaunch": false,
807 | "category": "Memory Optimized",
808 | "gpuNum": 0,
809 | "hideHardwareSpecs": false,
810 | "memoryGiB": 384,
811 | "name": "ml.r5.12xlarge",
812 | "vcpuNum": 48
813 | },
814 | {
815 | "_defaultOrder": 45,
816 | "_isFastLaunch": false,
817 | "category": "Memory Optimized",
818 | "gpuNum": 0,
819 | "hideHardwareSpecs": false,
820 | "memoryGiB": 512,
821 | "name": "ml.r5.16xlarge",
822 | "vcpuNum": 64
823 | },
824 | {
825 | "_defaultOrder": 46,
826 | "_isFastLaunch": false,
827 | "category": "Memory Optimized",
828 | "gpuNum": 0,
829 | "hideHardwareSpecs": false,
830 | "memoryGiB": 768,
831 | "name": "ml.r5.24xlarge",
832 | "vcpuNum": 96
833 | },
834 | {
835 | "_defaultOrder": 47,
836 | "_isFastLaunch": false,
837 | "category": "Accelerated computing",
838 | "gpuNum": 1,
839 | "hideHardwareSpecs": false,
840 | "memoryGiB": 16,
841 | "name": "ml.g5.xlarge",
842 | "vcpuNum": 4
843 | },
844 | {
845 | "_defaultOrder": 48,
846 | "_isFastLaunch": false,
847 | "category": "Accelerated computing",
848 | "gpuNum": 1,
849 | "hideHardwareSpecs": false,
850 | "memoryGiB": 32,
851 | "name": "ml.g5.2xlarge",
852 | "vcpuNum": 8
853 | },
854 | {
855 | "_defaultOrder": 49,
856 | "_isFastLaunch": false,
857 | "category": "Accelerated computing",
858 | "gpuNum": 1,
859 | "hideHardwareSpecs": false,
860 | "memoryGiB": 64,
861 | "name": "ml.g5.4xlarge",
862 | "vcpuNum": 16
863 | },
864 | {
865 | "_defaultOrder": 50,
866 | "_isFastLaunch": false,
867 | "category": "Accelerated computing",
868 | "gpuNum": 1,
869 | "hideHardwareSpecs": false,
870 | "memoryGiB": 128,
871 | "name": "ml.g5.8xlarge",
872 | "vcpuNum": 32
873 | },
874 | {
875 | "_defaultOrder": 51,
876 | "_isFastLaunch": false,
877 | "category": "Accelerated computing",
878 | "gpuNum": 1,
879 | "hideHardwareSpecs": false,
880 | "memoryGiB": 256,
881 | "name": "ml.g5.16xlarge",
882 | "vcpuNum": 64
883 | },
884 | {
885 | "_defaultOrder": 52,
886 | "_isFastLaunch": false,
887 | "category": "Accelerated computing",
888 | "gpuNum": 4,
889 | "hideHardwareSpecs": false,
890 | "memoryGiB": 192,
891 | "name": "ml.g5.12xlarge",
892 | "vcpuNum": 48
893 | },
894 | {
895 | "_defaultOrder": 53,
896 | "_isFastLaunch": false,
897 | "category": "Accelerated computing",
898 | "gpuNum": 4,
899 | "hideHardwareSpecs": false,
900 | "memoryGiB": 384,
901 | "name": "ml.g5.24xlarge",
902 | "vcpuNum": 96
903 | },
904 | {
905 | "_defaultOrder": 54,
906 | "_isFastLaunch": false,
907 | "category": "Accelerated computing",
908 | "gpuNum": 8,
909 | "hideHardwareSpecs": false,
910 | "memoryGiB": 768,
911 | "name": "ml.g5.48xlarge",
912 | "vcpuNum": 192
913 | },
914 | {
915 | "_defaultOrder": 55,
916 | "_isFastLaunch": false,
917 | "category": "Accelerated computing",
918 | "gpuNum": 8,
919 | "hideHardwareSpecs": false,
920 | "memoryGiB": 1152,
921 | "name": "ml.p4d.24xlarge",
922 | "vcpuNum": 96
923 | },
924 | {
925 | "_defaultOrder": 56,
926 | "_isFastLaunch": false,
927 | "category": "Accelerated computing",
928 | "gpuNum": 8,
929 | "hideHardwareSpecs": false,
930 | "memoryGiB": 1152,
931 | "name": "ml.p4de.24xlarge",
932 | "vcpuNum": 96
933 | }
934 | ],
935 | "instance_type": "ml.t3.medium",
936 | "kernelspec": {
937 | "display_name": "Python 3 (ipykernel)",
938 | "language": "python",
939 | "name": "python3"
940 | },
941 | "language_info": {
942 | "codemirror_mode": {
943 | "name": "ipython",
944 | "version": 3
945 | },
946 | "file_extension": ".py",
947 | "mimetype": "text/x-python",
948 | "name": "python",
949 | "nbconvert_exporter": "python",
950 | "pygments_lexer": "ipython3",
951 | "version": "3.10.13"
952 | }
953 | },
954 | "nbformat": 4,
955 | "nbformat_minor": 5
956 | }
957 |
--------------------------------------------------------------------------------
/data_preparation/preprocessing.py:
--------------------------------------------------------------------------------
1 | from moviepy.editor import *
2 | import whisper
3 | from transformers import pipeline
4 |
5 | import glob
6 | import torch
7 | import argparse
8 | import logging
9 | import os
10 |
11 | from split_paragraph import gen_parag
12 |
13 | def extract_transcript(file_path, pipe, save_dir, chunk_length_s=20, sentence_embedding_model="all-minilm-l6-v2", p_size=10, order=10, target_language='en'):
14 | # load audio and pad/trim it to fit 30 seconds
15 | audio = whisper.load_audio(file_path)
16 | audio = whisper.pad_or_trim(audio)
17 |
18 | file_name = file_path.split('/')[-1].replace('.mp3', '')
19 | # decode the audio
20 |
21 | generate_kwargs = {"task":"transcribe", "language":f"<|{target_language}|>"}
22 | prediction = pipe(
23 | file_path,
24 | return_timestamps=True,
25 | chunk_length_s=chunk_length_s,
26 | stride_length_s=(5),
27 | generate_kwargs=generate_kwargs
28 | )
29 |
30 | para_chunks, para_timestamp = gen_parag(
31 | prediction['chunks'],
32 | model_name=sentence_embedding_model,
33 | p_size=p_size,
34 | order=order
35 | )
36 |
37 | for chunk, timestamp in zip(para_chunks, para_timestamp):
38 | trans_path = f"{save_dir}/{file_name}_{timestamp[0]}_{timestamp[1]}.txt"
39 | with open(trans_path, 'w', encoding='utf-8') as f:
40 | f.write(chunk)
41 |
42 | if __name__ == "__main__":
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument("--whisper-model", type=str, default="whisper-large-v2")
45 | parser.add_argument("--clip-duration", type=int, default=120)
46 | parser.add_argument("--target-language", type=str, default="en")
47 | parser.add_argument("--sentence-embedding-model", type=str, default="all-mpnet-base-v2")
48 | parser.add_argument("--chunk-length", type=int, default=20)
49 | parser.add_argument("--p-size", type=int, default=10)
50 | parser.add_argument("--order", type=int, default=2)
51 | args, _ = parser.parse_known_args()
52 |
53 | input_dir = "/opt/ml/processing/input"
54 | transcript_dir = "/opt/ml/processing/transcripts"
55 |
56 | clip_duration = args.clip_duration #second
57 |
58 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
59 | pipe = pipeline(
60 | "automatic-speech-recognition",
61 | model=f"openai/{args.whisper_model}",
62 | device=device
63 | )
64 |
65 | mp4_list = glob.glob(input_dir + "/*.mp4")
66 | mp3_list = glob.glob(input_dir + "/*.mp3")
67 | file_list = mp4_list + mp3_list
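    # For each input file: extract the audio track from .mp4 as .mp3 if needed, then transcribe it and write paragraph-level transcript files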
68 | for file_path in file_list:
69 | print(file_path)
70 | if file_path.endswith('.mp4'):
71 | file_name = file_path.split('/')[-1].replace('.mp4', '')
72 |
73 | video = VideoFileClip(file_path)
74 | file_path = file_path.replace('.mp4', '.mp3')
75 | video.audio.write_audiofile(file_path)
76 | elif file_path.endswith('.mp3'):
77 | file_name = file_path.split('/')[-1].replace('.mp3', '')
78 |
79 | trans_dir = f"{transcript_dir}/{file_name}"
80 | if not os.path.exists(trans_dir):
81 | os.makedirs(trans_dir)
82 |
83 | extract_transcript(
84 | file_path,
85 | pipe,
86 | trans_dir,
87 | chunk_length_s=args.chunk_length,
88 | p_size=args.p_size,
89 | order=args.order,
90 | sentence_embedding_model=args.sentence_embedding_model,
91 | target_language=args.target_language
92 | )
93 |
--------------------------------------------------------------------------------
/data_preparation/requirements.txt:
--------------------------------------------------------------------------------
1 | tiktoken
2 | moviepy
3 | git+https://github.com/openai/whisper.git
4 | sagemaker
5 | pydub
6 | sentence-transformers
--------------------------------------------------------------------------------
/data_preparation/split_paragraph.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sentence_transformers import SentenceTransformer
3 | from sklearn.metrics.pairwise import cosine_similarity
4 | from scipy.signal import argrelextrema
5 | import math
6 |
7 | # The transcript is one line, so we replace question marks with dots so that we can split it into sentences correctly.
8 | def split_sentence(text):
9 | text = text.replace("?", ".")
10 | sentences = text.split('. ')
11 | sentences[-1] = sentences[-1].replace('.', '')
12 | return sentences
13 |
14 | def unify_sentence(sentences):
15 | # Get the length of each sentence
16 | sentece_length = [len(each) for each in sentences]
17 | # Determine longest outlier
18 | long = np.mean(sentece_length) + np.std(sentece_length) *2
19 | # Determine shortest outlier
20 | short = np.mean(sentece_length) - np.std(sentece_length) *2
21 | # Shorten long sentences
22 | text = ''
23 | prev_each = ''
24 |
25 | for i, each in enumerate(sentences):
26 | if each == prev_each or len(each.strip()) == 0:
27 | continue
28 | if len(each) > long:
29 | # let's replace all the commas with dots
30 | comma_splitted = each.replace(',', '.')
31 | text+= f'{comma_splitted}. '
32 | else:
33 | text+= f'{each}. '
34 |
35 | prev_each = each
36 |
37 | sentences = text.split('. ')
38 | sentences[-1] = sentences[-1].replace('.', '')
39 | # Now let's concatenate short ones
40 | text = ''
41 | for each in sentences:
42 | if len(each) == 0:
43 | continue
44 | if len(each) < short:
45 | text+= f'{each} '
46 | else:
47 | text+= f'{each}. '
48 |
49 | return text
50 |
51 | def rev_sigmoid(x:float)->float:
52 | return (1 / (1 + math.exp(0.5*x)))
53 |
54 | def activate_similarities(similarities:np.array, p_size=10, order=5)->np.array:
55 | """ Function returns list of weighted sums of activated sentence similarities
56 | Args:
57 |         similarities (numpy array): a square matrix where each entry is the cosine similarity between two sentences
58 |         p_size (int): number of sentences used to calculate each weighted sum
59 |     Returns:
60 |         tuple: indices of the relative minima of the weighted sums, as returned by argrelextrema
61 | """
62 | if similarities.shape[0] < p_size:
63 | p_size = similarities.shape[0]
64 | x = np.linspace(-10,10,p_size)
65 | # Then we need to apply activation function to the created space
66 | y = np.vectorize(rev_sigmoid)
67 |     # Because we only apply the activation to p_size sentences, we pad with zeros to neglect the effect of every additional sentence and to match the length of the vector we will multiply
68 | activation_weights = np.pad(y(x),(0,similarities.shape[0]-p_size))
69 | ### 1. Take each diagonal to the right of the main diagonal
70 | diagonals = [similarities.diagonal(each) for each in range(0,similarities.shape[0])]
71 | ### 2. Pad each diagonal by zeros at the end. Because each diagonal is different length we should pad it with zeros at the end
72 | diagonals = [np.pad(each, (0,similarities.shape[0]-len(each))) for each in diagonals]
73 | ### 3. Stack those diagonals into new matrix
74 | diagonals = np.stack(diagonals)
75 | ### 4. Apply activation weights to each row. Multiply similarities with our activation.
76 | diagonals = diagonals * activation_weights.reshape(-1,1)
77 | ### 5. Calculate the weighted sum of activated similarities
78 | activated_similarities = np.sum(diagonals, axis=0)
79 |     ### 6. Find the relative minima of the vector with argrelextrema; these minima are the candidate split points
80 |     minima = argrelextrema(activated_similarities, np.less, order=order)  # The order parameter controls how frequent the splits are; changing it is not recommended.
81 | 
82 |     return minima
83 |
84 | def correct_chunks(chunks):
85 | prev_chunk = None
86 | new_chunks = []
87 | for chunk in chunks:
88 | if prev_chunk:
89 | chunk['text'] = prev_chunk['text'] + chunk['text']
90 | chunk['timestamp'] = (prev_chunk['timestamp'][0], chunk['timestamp'][1])
91 |
92 | if not chunk['text'].endswith('.'):
93 | prev_chunk = chunk
94 | else:
95 | new_chunks.append(chunk)
96 | prev_chunk = None
97 | return new_chunks
98 |
99 | def gen_parag(input_chunks, model_name='all-minilm-l6-v2', p_size=10, order=5):
100 | sentences_all = []
101 | timestamps_all = []
102 |
103 | corrected_chunks = correct_chunks(input_chunks)
104 |
105 | for chunk in corrected_chunks:
106 | sentences = split_sentence(chunk['text'])
107 | text = unify_sentence(sentences)
108 | text = text.strip()
109 | sentences = text.split('. ')
110 | sentences[-1] = sentences[-1].replace('.', '')
111 | timestamps = [chunk['timestamp']]*len(sentences)
112 |
113 | sentences_all += sentences
114 | timestamps_all += timestamps
115 |
116 | # Embed sentences
117 | model = SentenceTransformer(model_name)
118 | embeddings = model.encode(sentences_all)
119 | # Create similarities matrix
120 | similarities = cosine_similarity(embeddings)
121 |
122 |     # Apply the activation function. For long texts, a p_size of 10 or more sentences is recommended.
123 |     minima = activate_similarities(similarities, p_size=p_size, order=order)
124 | 
125 |     # Collect the split points and build paragraphs with their timestamps
126 |     split_points = [each for each in minima[0]]
127 | text = ''
128 |
129 | para_chunks = []
130 | para_timestamp = []
131 | start_timestamp = 0
132 |
133 | for num, each in enumerate(sentences_all):
134 | current_timestamp = timestamps_all[num]
135 |
136 | if text == '' and (start_timestamp == current_timestamp[1]):
137 | start_timestamp = current_timestamp[0]
138 |
139 | if num in split_points:
140 | #text+=f'{each}. '
141 | para_chunks.append(text)
142 | para_timestamp.append([start_timestamp, current_timestamp[1]])
143 | text = f'{each}. '
144 | start_timestamp = current_timestamp[1]
145 | else:
146 | text+=f'{each}. '
147 |
148 | if len(text):
149 | para_chunks.append(text)
150 | para_timestamp.append([start_timestamp, timestamps_all[-1][1]])
151 |
152 | return para_chunks, para_timestamp
153 |
--------------------------------------------------------------------------------
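A quick note on usage: `gen_parag` above expects Whisper-style ASR output, i.e. a list of chunks where each chunk is a dict with a `text` transcript string and a `timestamp` (start, end) pair, and it returns parallel lists of paragraph texts and paragraph timestamps. The following is a minimal sketch under stated assumptions: the chunk contents are made up for illustration, the script is run from the `data_preparation` directory so `split_paragraph` is importable, and the `all-minilm-l6-v2` sentence-transformers weights can be downloaded.

```python
# Minimal usage sketch for gen_parag (hypothetical chunk data).
# Each chunk mirrors the ASR output format the module assumes:
# {"text": <transcript string>, "timestamp": (start_seconds, end_seconds)}.
from split_paragraph import gen_parag

chunks = [
    {"text": "Amazon SageMaker is a fully managed service. It helps you build ML models.",
     "timestamp": (0.0, 12.5)},
    {"text": "You can train and deploy models at scale. Endpoints serve real-time traffic.",
     "timestamp": (12.5, 27.0)},
]

# Returns parallel lists: paragraph texts and [start, end] timestamps per paragraph.
paragraphs, timestamps = gen_parag(chunks, model_name="all-minilm-l6-v2", p_size=10, order=5)

for text, (start, end) in zip(paragraphs, timestamps):
    print(f"[{start:.1f}s - {end:.1f}s] {text}")
```

With only a handful of sentences there are usually no interior minima, so everything comes back as a single paragraph; the splitting behaviour only becomes visible on longer transcripts.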
/images/architecture_UI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-llm-on-video-audio/dbb56783227d34df8708a18699d15300b7970231/images/architecture_UI.png
--------------------------------------------------------------------------------
/images/offline_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-llm-on-video-audio/dbb56783227d34df8708a18699d15300b7970231/images/offline_architecture.png
--------------------------------------------------------------------------------
/test_raw_data/demo-video-sagemaker-doc.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-llm-on-video-audio/dbb56783227d34df8708a18699d15300b7970231/test_raw_data/demo-video-sagemaker-doc.mp4
--------------------------------------------------------------------------------
/test_raw_data/test.webm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-llm-on-video-audio/dbb56783227d34df8708a18699d15300b7970231/test_raw_data/test.webm
--------------------------------------------------------------------------------
/video_question_answering_langchain.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Retrieval-Augmented Generation: Question Answering based on Custom Video/Audio Dataset with Open-sourced [LangChain](https://python.langchain.com/en/latest/index.html) Library\n"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 |     "We use the data transcribed from video/audio files to build a RAG solution with LangChain by following the blog [Question answering using Retrieval Augmented Generation with foundation models in Amazon SageMaker JumpStart](https://aws.amazon.com/blogs/machine-learning/question-answering-using-retrieval-augmented-generation-with-foundation-models-in-amazon-sagemaker-jumpstart/) and modifying the [source code](https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/question_answering_retrieval_augmented_generation/question_answering_langchain_jumpstart.ipynb)."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "## Step 1. Deploy large language model (LLM) and embedding model in SageMaker JumpStart\n",
22 | "\n",
23 |     "To better illustrate the idea, let's first deploy all the models that are required to perform the demo. This notebook deploys the Falcon 40B Instruct model as the large language model (LLM) together with the GPT-J 6B embedding model; you can add other LLMs and compare their performance by modifying the `_MODEL_CONFIG_` Python dictionary defined below."
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {
30 | "collapsed": false,
31 | "jupyter": {
32 | "outputs_hidden": false
33 | },
34 | "pycharm": {
35 | "name": "#%%\n"
36 | },
37 | "tags": []
38 | },
39 | "outputs": [],
40 | "source": [
41 | "!pip install --upgrade sagemaker --quiet\n",
42 | "!pip install ipywidgets==7.0.0 --quiet\n",
43 | "!pip install langchain==0.0.148 --quiet\n",
44 | "!pip install faiss-cpu --quiet\n",
45 | "!pip install unstructured --quiet"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {
52 | "tags": []
53 | },
54 | "outputs": [],
55 | "source": [
56 | "import time\n",
57 | "import sagemaker, boto3, json\n",
58 | "from sagemaker.session import Session\n",
59 | "from sagemaker.model import Model\n",
60 | "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n",
61 | "from sagemaker.predictor import Predictor\n",
62 | "from sagemaker.utils import name_from_base\n",
63 | "from typing import Any, Dict, List, Optional\n",
64 | "from langchain.embeddings import SagemakerEndpointEmbeddings\n",
65 | "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
66 | "\n",
67 | "sagemaker_session = Session()\n",
68 | "role = sagemaker_session.get_caller_identity_arn()\n",
69 | "aws_region = boto3.Session().region_name\n",
70 | "sess = sagemaker.Session()\n",
71 | "model_version = \"*\""
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {
78 | "tags": []
79 | },
80 | "outputs": [],
81 | "source": [
82 | "def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type=\"application/json\"):\n",
83 | " client = boto3.client(\"runtime.sagemaker\")\n",
84 | " response = client.invoke_endpoint(\n",
85 | " EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json\n",
86 | " )\n",
87 | " return response\n",
88 | "\n",
89 | "\n",
90 | "def parse_response_model_falcon(query_response):\n",
91 | " model_predictions = json.loads(query_response[\"Body\"].read())\n",
92 | " generated_text = model_predictions[0]['generated_text']\n",
93 | " return generated_text\n"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {},
99 | "source": [
100 |     "Deploy SageMaker endpoint(s) for the large language model and the GPT-J 6B embedding model. You can add or uncomment entries in `_MODEL_CONFIG_` below if you want to deploy multiple LLMs and compare their performance."
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {
107 | "tags": []
108 | },
109 | "outputs": [],
110 | "source": [
111 | "hf_model_id = \"tiiuae/falcon-40b-instruct\" # model id from huggingface.co/models\n",
112 | "number_of_gpu = 4 # number of gpus to use for inference and tensor parallelism\n",
113 |     "health_check_timeout = 1200 # Increase the health check timeout to 20 minutes to allow time to download the model\n",
114 | "instance_type = \"ml.g5.12xlarge\" # instance type to use for deployment\n",
115 | "_MODEL_CONFIG_ = {\n",
116 | "\n",
117 | " \"huggingface-textembedding-gpt-j-6b\": {\n",
118 | " \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
119 | " },\n",
120 | " \"huggingface-falcon-40b\": {\n",
121 | " \"hf_model_id\":hf_model_id,\n",
122 | " \"env\": {\n",
123 | " 'HF_MODEL_ID': hf_model_id,\n",
124 | " 'HF_MODEL_QUANTIZE': \"bitsandbytes\", # comment in to quantize\n",
125 | " 'SM_NUM_GPUS': json.dumps(number_of_gpu),\n",
126 | " 'MAX_INPUT_LENGTH': json.dumps(1200), # Max length of input text\n",
127 | " 'MAX_TOTAL_TOKENS': json.dumps(2048), \n",
128 | " },\n",
129 | " \"parse_function\": parse_response_model_falcon,\n",
130 | " \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n",
131 | " },\n",
132 | "}"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {
139 | "tags": []
140 | },
141 | "outputs": [],
142 | "source": [
143 | "from sagemaker.huggingface import get_huggingface_llm_image_uri\n",
144 | "from sagemaker.huggingface import HuggingFaceModel\n",
145 | "# retrieve the llm image uri\n",
146 | "llm_image = get_huggingface_llm_image_uri(\n",
147 | " \"huggingface\",\n",
148 | " version=\"0.8.2\"\n",
149 | ")\n",
150 | "\n",
151 | "# print ecr image uri\n",
152 | "print(f\"llm image uri: {llm_image}\")"
153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "metadata": {},
158 | "source": [
159 | "### Deploy the LLM"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {
166 | "tags": []
167 | },
168 | "outputs": [],
169 | "source": [
170 | "newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n",
171 | "# deploy falcon 40b model from hugging face\n",
172 | "model_id = \"huggingface-falcon-40b\"\n",
173 | "_MODEL_CONFIG_[model_id][\"endpoint_name\"] = name_from_base(f\"{model_id}\")\n",
174 | "llm_model = HuggingFaceModel(\n",
175 | " role=role,\n",
176 | " image_uri=llm_image,\n",
177 | " env={\n",
178 | " 'HF_MODEL_ID': _MODEL_CONFIG_[model_id]['hf_model_id'],\n",
179 | " # 'HF_MODEL_QUANTIZE': \"bitsandbytes\", # comment in to quantize\n",
180 | " 'SM_NUM_GPUS': json.dumps(number_of_gpu),\n",
181 | " 'MAX_INPUT_LENGTH': json.dumps(1200), # Max length of input text\n",
182 | " 'MAX_TOTAL_TOKENS': json.dumps(2048), # Max length of the generation (including input text)\n",
183 | " }\n",
184 | " )\n",
185 | "llm = llm_model.deploy(\n",
186 | " initial_instance_count=1,\n",
187 | " instance_type=instance_type,\n",
188 | " container_startup_health_check_timeout=health_check_timeout,\n",
189 | " endpoint_name=_MODEL_CONFIG_[model_id][\"endpoint_name\"],\n",
190 | ")\n",
191 | "print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {},
197 | "source": [
198 | "### Deploy the embedding model"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "metadata": {
205 | "tags": []
206 | },
207 | "outputs": [],
208 | "source": [
209 | "model_id = \"huggingface-textembedding-gpt-j-6b\"\n",
210 | "# Retrieve the model uri.\n",
211 | "model_uri = model_uris.retrieve(\n",
212 | " model_id=model_id, model_version=model_version, model_scope=\"inference\"\n",
213 | ")\n",
214 | "_MODEL_CONFIG_[model_id][\"endpoint_name\"] = name_from_base(f\"{model_id}\")\n",
215 | "\n",
216 | "# Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n",
217 | "deploy_image_uri = image_uris.retrieve(\n",
218 | " region=None,\n",
219 | " framework=None, # automatically inferred from model_id\n",
220 | " image_scope=\"inference\",\n",
221 | " model_id=model_id,\n",
222 | " model_version=model_version,\n",
223 | " instance_type=instance_type,\n",
224 | ")\n",
225 | "model_inference = Model(\n",
226 | " image_uri=deploy_image_uri,\n",
227 | " model_data=model_uri,\n",
228 | " role=role,\n",
229 | " predictor_cls=Predictor,\n",
230 | " name=_MODEL_CONFIG_[model_id][\"endpoint_name\"],\n",
231 | " env=_MODEL_CONFIG_[model_id][\"env\"],\n",
232 | ")\n",
233 | "model_predictor_inference = model_inference.deploy(\n",
234 | " initial_instance_count=1,\n",
235 | " instance_type=instance_type,\n",
236 | " predictor_cls=Predictor,\n",
237 | " endpoint_name=_MODEL_CONFIG_[model_id][\"endpoint_name\"],\n",
238 | ")\n",
239 | "print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {},
245 | "source": [
246 |     "## Step 2. Use a RAG-based approach with [LangChain](https://python.langchain.com/en/latest/index.html) and SageMaker endpoints to build a simplified question answering application.\n",
247 |     "\n",
248 |     "\n",
249 |     "We plan to use document embeddings to fetch the most relevant documents from our knowledge library and combine them with the prompt that we provide to the LLM.\n",
250 |     "\n",
251 |     "To achieve that, we will do the following:\n",
252 | "\n",
253 |     "1. **Generate embeddings for each document in the knowledge library with the SageMaker GPT-J-6B embedding model.**\n",
254 | "2. **Identify top K most relevant documents based on user query.**\n",
255 | " - 2.1 **For a query of your interest, generate the embedding of the query using the same embedding model.**\n",
256 | " - 2.2 **Search the indexes of top K most relevant documents in the embedding space using in-memory Faiss search.**\n",
257 |     "    - 2.3 **Use the indexes to retrieve the corresponding documents.**\n",
258 |     "3. **Combine the retrieved documents with the prompt and the question and send them to the SageMaker LLM.**\n",
259 | "\n",
260 | "\n",
261 | "\n",
262 |     "Note: The retrieved document/text should be large enough to contain enough information to answer a question, but small enough to fit into the LLM prompt -- maximum sequence length of 1024 tokens. \n",
263 | "\n",
264 | "---\n",
265 |     "To build a simplified QA application with LangChain, we need: \n",
266 |     "1. Wrap our SageMaker endpoints for the embedding model and the LLM into `langchain.embeddings.SagemakerEndpointEmbeddings` and `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`. This requires a small override of the `SagemakerEndpointEmbeddings` class to make it compatible with the SageMaker embedding model.\n",
267 |     "2. Prepare the dataset to build the knowledge database. \n",
268 | "\n",
269 | "---"
270 | ]
271 | },
272 | {
273 | "cell_type": "markdown",
274 | "metadata": {},
275 | "source": [
276 |     "Wrap our SageMaker endpoint for the embedding model into `langchain.embeddings.SagemakerEndpointEmbeddings`. This requires a small override of the `SagemakerEndpointEmbeddings` class to make it compatible with the SageMaker embedding model."
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": null,
282 | "metadata": {
283 | "tags": []
284 | },
285 | "outputs": [],
286 | "source": [
287 | "from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
288 | "\n",
289 | "\n",
290 | "class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):\n",
291 | " def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:\n",
292 | " \"\"\"Compute doc embeddings using a SageMaker Inference Endpoint.\n",
293 | "\n",
294 | " Args:\n",
295 | " texts: The list of texts to embed.\n",
296 | " chunk_size: The chunk size defines how many input texts will\n",
297 | " be grouped together as request. If None, will use the\n",
298 | " chunk size specified by the class.\n",
299 | "\n",
300 | " Returns:\n",
301 | " List of embeddings, one for each text.\n",
302 | " \"\"\"\n",
303 | " results = []\n",
304 | " _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size\n",
305 | "\n",
306 | " for i in range(0, len(texts), _chunk_size):\n",
307 | " response = self._embedding_func(texts[i : i + _chunk_size])\n",
308 |     "\n",
309 | " results.extend(response)\n",
310 | " return results\n",
311 | "\n",
312 | "\n",
313 | "class ContentHandler(EmbeddingsContentHandler):\n",
314 | " content_type = \"application/json\"\n",
315 | " accepts = \"application/json\"\n",
316 | "\n",
317 | " def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n",
318 | " input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n",
319 | " return input_str.encode(\"utf-8\")\n",
320 | "\n",
321 | " def transform_output(self, output: bytes) -> str:\n",
322 | " response_json = json.loads(output.read().decode(\"utf-8\"))\n",
323 | " embeddings = response_json[\"embedding\"]\n",
324 | " return embeddings\n",
325 | "\n",
326 | "\n",
327 | "content_handler = ContentHandler()\n",
328 | "\n",
329 | "embeddings = SagemakerEndpointEmbeddingsJumpStart(\n",
330 | " endpoint_name=_MODEL_CONFIG_[\"huggingface-textembedding-gpt-j-6b\"][\"endpoint_name\"],\n",
331 | " region_name=aws_region,\n",
332 | " content_handler=content_handler,\n",
333 | ")"
334 | ]
335 | },
336 | {
337 | "cell_type": "markdown",
338 | "metadata": {},
339 | "source": [
340 |     "Next, we wrap our SageMaker endpoint for the LLM into `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`. "
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": null,
346 | "metadata": {
347 | "tags": []
348 | },
349 | "outputs": [],
350 | "source": [
351 | "from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint\n",
352 | "\n",
353 | "parameters = {\n",
354 | " \"max_new_tokens\": 1000,\n",
355 | " \"return_full_text\": False\n",
356 | "}\n",
357 | "\n",
358 | "\n",
359 | "class ContentHandler(LLMContentHandler):\n",
360 | " content_type = \"application/json\"\n",
361 | " accepts = \"application/json\"\n",
362 | "\n",
363 | " def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n",
364 | " self.len_prompt = len(prompt)\n",
365 | " input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
366 | " return input_str.encode(\"utf-8\")\n",
367 | "\n",
368 | " def transform_output(self, output: bytes) -> str:\n",
369 | " response_json = output.read()\n",
370 | " res = json.loads(response_json)\n",
371 | " print(res)\n",
372 | " ans = res[0]['generated_text'][self.len_prompt:]\n",
373 | " return ans \n",
374 | "\n",
375 | "\n",
376 | "content_handler = ContentHandler()\n",
377 | "\n",
378 | "sm_llm = SagemakerEndpoint(\n",
379 | " endpoint_name=_MODEL_CONFIG_[\"huggingface-falcon-40b\"][\"endpoint_name\"],\n",
380 | " region_name=aws_region,\n",
381 | " model_kwargs=parameters,\n",
382 | " content_handler=content_handler,\n",
383 | ")"
384 | ]
385 | },
386 | {
387 | "cell_type": "markdown",
388 | "metadata": {},
389 | "source": [
390 |     "Now, let's download the example data and prepare it for demonstration. We will use the transcripts generated from the video/audio files in the data preparation step as the knowledge library. The transcripts are stored in Amazon S3 as `txt` files, and each transcript file serves as a document from which relevant passages are retrieved based on a query. \n",
391 |     "\n",
392 |     "**For your own use case, you can replace the example dataset with your own data to build a custom question answering application.**"
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "execution_count": null,
398 | "metadata": {
399 | "tags": []
400 | },
401 | "outputs": [],
402 | "source": [
403 | "from langchain.chains import RetrievalQA\n",
404 | "from langchain.llms import OpenAI\n",
405 | "from langchain.document_loaders import TextLoader\n",
406 | "from langchain.indexes import VectorstoreIndexCreator\n",
407 | "from langchain.vectorstores import Chroma, AtlasDB, FAISS\n",
408 | "from langchain.text_splitter import CharacterTextSplitter\n",
409 | "from langchain import PromptTemplate\n",
410 | "from langchain.chains.question_answering import load_qa_chain\n",
411 | "from langchain.document_loaders.csv_loader import CSVLoader\n",
412 | "from langchain.document_loaders import DirectoryLoader"
413 | ]
414 | },
415 | {
416 | "cell_type": "markdown",
417 | "metadata": {},
418 | "source": [
419 |     "Use LangChain to read the transcript `txt` data. There are multiple built-in loaders in LangChain for different file formats such as `csv`, `html`, and `pdf`. For details, see [LangChain document loaders](https://python.langchain.com/en/latest/modules/indexes/document_loaders.html)."
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "execution_count": null,
425 | "metadata": {
426 | "scrolled": true,
427 | "tags": []
428 | },
429 | "outputs": [],
430 | "source": [
431 | "%store -r s3_output_transcript\n",
432 | "!aws s3 cp --recursive {s3_output_transcript} ./data"
433 | ]
434 | },
435 | {
436 | "cell_type": "code",
437 | "execution_count": null,
438 | "metadata": {
439 | "tags": []
440 | },
441 | "outputs": [],
442 | "source": [
443 | "loader = DirectoryLoader(\"./data/genai_interview\", glob=\"**/*.txt\")"
444 | ]
445 | },
446 | {
447 | "cell_type": "code",
448 | "execution_count": null,
449 | "metadata": {
450 | "tags": []
451 | },
452 | "outputs": [],
453 | "source": [
454 | "documents = loader.load()\n",
455 | "# text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0)\n",
456 |     "# texts = text_splitter.split_documents(documents)  # Uncomment these two lines if you load the text with langchain.document_loaders.TextLoader\n",
457 |     "## and want to split it into smaller chunks."
458 | ]
459 | },
460 | {
461 | "cell_type": "markdown",
462 | "metadata": {},
463 | "source": [
464 |     "**Now, we can build a QA application. LangChain makes it extremely simple with the following few lines of code.**"
465 | ]
466 | },
467 | {
468 | "cell_type": "markdown",
469 | "metadata": {},
470 | "source": [
471 |     "Based on the question below, we can accomplish the steps outlined in Step 2 with just a few lines of code, as shown below."
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "execution_count": null,
477 | "metadata": {},
478 | "outputs": [],
479 | "source": [
480 | "index_creator = VectorstoreIndexCreator(\n",
481 | " vectorstore_cls=FAISS,\n",
482 | " embedding=embeddings,\n",
483 | " text_splitter=CharacterTextSplitter(chunk_size=600, chunk_overlap=0),\n",
484 | ")"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": null,
490 | "metadata": {
491 | "scrolled": true,
492 | "tags": []
493 | },
494 | "outputs": [],
495 | "source": [
496 | "index = index_creator.from_loaders([loader])"
497 | ]
498 | },
499 | {
500 | "cell_type": "code",
501 | "execution_count": null,
502 | "metadata": {
503 | "tags": []
504 | },
505 | "outputs": [],
506 | "source": [
507 | "question = \"how does bedrock help customers in generative AI?\""
508 | ]
509 | },
510 | {
511 | "cell_type": "code",
512 | "execution_count": null,
513 | "metadata": {},
514 | "outputs": [],
515 | "source": [
516 | "index.query(question=question, llm=sm_llm)"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "metadata": {},
522 | "source": [
523 |     "## Step 3. Customize the QA application above with a different prompt.\n",
524 |     "\n",
525 |     "Now we see how simple it is to use LangChain to build a question answering application with just a few lines of code. Let's break down the `VectorstoreIndexCreator` above and see what's happening under the hood. Furthermore, we will see how to incorporate a customized prompt rather than using the default prompt of `VectorstoreIndexCreator`."
526 | ]
527 | },
528 | {
529 | "cell_type": "markdown",
530 | "metadata": {},
531 | "source": [
532 |     "First, we **generate embeddings for each document in the knowledge library with the SageMaker GPT-J-6B embedding model.**"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": null,
538 | "metadata": {},
539 | "outputs": [],
540 | "source": [
541 | "docsearch = FAISS.from_documents(documents, embeddings)"
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "execution_count": null,
547 | "metadata": {},
548 | "outputs": [],
549 | "source": [
550 | "question"
551 | ]
552 | },
553 | {
554 | "cell_type": "markdown",
555 | "metadata": {},
556 | "source": [
557 |     "Based on the question above, we then **identify the top K most relevant documents based on the user query**."
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": null,
563 | "metadata": {},
564 | "outputs": [],
565 | "source": [
566 | "docs = docsearch.similarity_search_with_score(question)"
567 | ]
568 | },
569 | {
570 | "cell_type": "markdown",
571 | "metadata": {},
572 | "source": [
573 |     "Print out the retrieved documents and their similarity scores as below."
574 | ]
575 | },
576 | {
577 | "cell_type": "code",
578 | "execution_count": null,
579 | "metadata": {},
580 | "outputs": [],
581 | "source": [
582 | "docs"
583 | ]
584 | },
585 | {
586 | "cell_type": "code",
587 | "execution_count": null,
588 | "metadata": {
589 | "tags": []
590 | },
591 | "outputs": [],
592 | "source": [
593 | "source = []\n",
594 | "context = []\n",
595 | "for doc, score in docs[:1]:\n",
596 | " context.append(doc)\n",
597 | " source.append(doc.metadata['source'].split('/')[-1])"
598 | ]
599 | },
600 | {
601 | "cell_type": "markdown",
602 | "metadata": {},
603 | "source": [
604 |     "Finally, we **combine the retrieved documents with the prompt and the question and send them to the SageMaker LLM.** \n",
605 | "\n",
606 | "We define a customized prompt as below."
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": null,
612 | "metadata": {},
613 | "outputs": [],
614 | "source": [
615 | "prompt_template = \"\"\"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.:\\n\\n{context}\\n\\nQuestion: {question}\\nHelpful Answer:\"\"\"\n",
616 | "\n",
617 | "PROMPT = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])"
618 | ]
619 | },
620 | {
621 | "cell_type": "code",
622 | "execution_count": null,
623 | "metadata": {},
624 | "outputs": [],
625 | "source": [
626 | "chain = load_qa_chain(llm=sm_llm, prompt=PROMPT)"
627 | ]
628 | },
629 | {
630 | "cell_type": "markdown",
631 | "metadata": {},
632 | "source": [
633 |     "Send the most relevant document and the question to the LLM to get an answer."
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": null,
639 | "metadata": {},
640 | "outputs": [],
641 | "source": [
642 | "result = chain({\"input_documents\": context, \"question\": question}, return_only_outputs=True)[\n",
643 | " \"output_text\"\n",
644 | "]"
645 | ]
646 | },
647 | {
648 | "cell_type": "markdown",
649 | "metadata": {},
650 | "source": [
651 |     "Print the final answer from the LLM as below, which is accurate."
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "execution_count": null,
657 | "metadata": {},
658 | "outputs": [],
659 | "source": [
660 | "result"
661 | ]
662 | },
663 | {
664 | "cell_type": "markdown",
665 | "metadata": {},
666 | "source": [
667 | "## Clean up"
668 | ]
669 | },
670 | {
671 | "cell_type": "code",
672 | "execution_count": null,
673 | "metadata": {
674 | "tags": []
675 | },
676 | "outputs": [],
677 | "source": [
678 | "llm.delete_endpoint()\n",
679 | "model_predictor_inference.delete_endpoint()"
680 | ]
681 | },
682 | {
683 | "cell_type": "code",
684 | "execution_count": null,
685 | "metadata": {},
686 | "outputs": [],
687 | "source": []
688 | }
689 | ],
690 | "metadata": {
691 | "availableInstances": [
692 | {
693 | "_defaultOrder": 0,
694 | "_isFastLaunch": true,
695 | "category": "General purpose",
696 | "gpuNum": 0,
697 | "hideHardwareSpecs": false,
698 | "memoryGiB": 4,
699 | "name": "ml.t3.medium",
700 | "vcpuNum": 2
701 | },
702 | {
703 | "_defaultOrder": 1,
704 | "_isFastLaunch": false,
705 | "category": "General purpose",
706 | "gpuNum": 0,
707 | "hideHardwareSpecs": false,
708 | "memoryGiB": 8,
709 | "name": "ml.t3.large",
710 | "vcpuNum": 2
711 | },
712 | {
713 | "_defaultOrder": 2,
714 | "_isFastLaunch": false,
715 | "category": "General purpose",
716 | "gpuNum": 0,
717 | "hideHardwareSpecs": false,
718 | "memoryGiB": 16,
719 | "name": "ml.t3.xlarge",
720 | "vcpuNum": 4
721 | },
722 | {
723 | "_defaultOrder": 3,
724 | "_isFastLaunch": false,
725 | "category": "General purpose",
726 | "gpuNum": 0,
727 | "hideHardwareSpecs": false,
728 | "memoryGiB": 32,
729 | "name": "ml.t3.2xlarge",
730 | "vcpuNum": 8
731 | },
732 | {
733 | "_defaultOrder": 4,
734 | "_isFastLaunch": true,
735 | "category": "General purpose",
736 | "gpuNum": 0,
737 | "hideHardwareSpecs": false,
738 | "memoryGiB": 8,
739 | "name": "ml.m5.large",
740 | "vcpuNum": 2
741 | },
742 | {
743 | "_defaultOrder": 5,
744 | "_isFastLaunch": false,
745 | "category": "General purpose",
746 | "gpuNum": 0,
747 | "hideHardwareSpecs": false,
748 | "memoryGiB": 16,
749 | "name": "ml.m5.xlarge",
750 | "vcpuNum": 4
751 | },
752 | {
753 | "_defaultOrder": 6,
754 | "_isFastLaunch": false,
755 | "category": "General purpose",
756 | "gpuNum": 0,
757 | "hideHardwareSpecs": false,
758 | "memoryGiB": 32,
759 | "name": "ml.m5.2xlarge",
760 | "vcpuNum": 8
761 | },
762 | {
763 | "_defaultOrder": 7,
764 | "_isFastLaunch": false,
765 | "category": "General purpose",
766 | "gpuNum": 0,
767 | "hideHardwareSpecs": false,
768 | "memoryGiB": 64,
769 | "name": "ml.m5.4xlarge",
770 | "vcpuNum": 16
771 | },
772 | {
773 | "_defaultOrder": 8,
774 | "_isFastLaunch": false,
775 | "category": "General purpose",
776 | "gpuNum": 0,
777 | "hideHardwareSpecs": false,
778 | "memoryGiB": 128,
779 | "name": "ml.m5.8xlarge",
780 | "vcpuNum": 32
781 | },
782 | {
783 | "_defaultOrder": 9,
784 | "_isFastLaunch": false,
785 | "category": "General purpose",
786 | "gpuNum": 0,
787 | "hideHardwareSpecs": false,
788 | "memoryGiB": 192,
789 | "name": "ml.m5.12xlarge",
790 | "vcpuNum": 48
791 | },
792 | {
793 | "_defaultOrder": 10,
794 | "_isFastLaunch": false,
795 | "category": "General purpose",
796 | "gpuNum": 0,
797 | "hideHardwareSpecs": false,
798 | "memoryGiB": 256,
799 | "name": "ml.m5.16xlarge",
800 | "vcpuNum": 64
801 | },
802 | {
803 | "_defaultOrder": 11,
804 | "_isFastLaunch": false,
805 | "category": "General purpose",
806 | "gpuNum": 0,
807 | "hideHardwareSpecs": false,
808 | "memoryGiB": 384,
809 | "name": "ml.m5.24xlarge",
810 | "vcpuNum": 96
811 | },
812 | {
813 | "_defaultOrder": 12,
814 | "_isFastLaunch": false,
815 | "category": "General purpose",
816 | "gpuNum": 0,
817 | "hideHardwareSpecs": false,
818 | "memoryGiB": 8,
819 | "name": "ml.m5d.large",
820 | "vcpuNum": 2
821 | },
822 | {
823 | "_defaultOrder": 13,
824 | "_isFastLaunch": false,
825 | "category": "General purpose",
826 | "gpuNum": 0,
827 | "hideHardwareSpecs": false,
828 | "memoryGiB": 16,
829 | "name": "ml.m5d.xlarge",
830 | "vcpuNum": 4
831 | },
832 | {
833 | "_defaultOrder": 14,
834 | "_isFastLaunch": false,
835 | "category": "General purpose",
836 | "gpuNum": 0,
837 | "hideHardwareSpecs": false,
838 | "memoryGiB": 32,
839 | "name": "ml.m5d.2xlarge",
840 | "vcpuNum": 8
841 | },
842 | {
843 | "_defaultOrder": 15,
844 | "_isFastLaunch": false,
845 | "category": "General purpose",
846 | "gpuNum": 0,
847 | "hideHardwareSpecs": false,
848 | "memoryGiB": 64,
849 | "name": "ml.m5d.4xlarge",
850 | "vcpuNum": 16
851 | },
852 | {
853 | "_defaultOrder": 16,
854 | "_isFastLaunch": false,
855 | "category": "General purpose",
856 | "gpuNum": 0,
857 | "hideHardwareSpecs": false,
858 | "memoryGiB": 128,
859 | "name": "ml.m5d.8xlarge",
860 | "vcpuNum": 32
861 | },
862 | {
863 | "_defaultOrder": 17,
864 | "_isFastLaunch": false,
865 | "category": "General purpose",
866 | "gpuNum": 0,
867 | "hideHardwareSpecs": false,
868 | "memoryGiB": 192,
869 | "name": "ml.m5d.12xlarge",
870 | "vcpuNum": 48
871 | },
872 | {
873 | "_defaultOrder": 18,
874 | "_isFastLaunch": false,
875 | "category": "General purpose",
876 | "gpuNum": 0,
877 | "hideHardwareSpecs": false,
878 | "memoryGiB": 256,
879 | "name": "ml.m5d.16xlarge",
880 | "vcpuNum": 64
881 | },
882 | {
883 | "_defaultOrder": 19,
884 | "_isFastLaunch": false,
885 | "category": "General purpose",
886 | "gpuNum": 0,
887 | "hideHardwareSpecs": false,
888 | "memoryGiB": 384,
889 | "name": "ml.m5d.24xlarge",
890 | "vcpuNum": 96
891 | },
892 | {
893 | "_defaultOrder": 20,
894 | "_isFastLaunch": false,
895 | "category": "General purpose",
896 | "gpuNum": 0,
897 | "hideHardwareSpecs": true,
898 | "memoryGiB": 0,
899 | "name": "ml.geospatial.interactive",
900 | "supportedImageNames": [
901 | "sagemaker-geospatial-v1-0"
902 | ],
903 | "vcpuNum": 0
904 | },
905 | {
906 | "_defaultOrder": 21,
907 | "_isFastLaunch": true,
908 | "category": "Compute optimized",
909 | "gpuNum": 0,
910 | "hideHardwareSpecs": false,
911 | "memoryGiB": 4,
912 | "name": "ml.c5.large",
913 | "vcpuNum": 2
914 | },
915 | {
916 | "_defaultOrder": 22,
917 | "_isFastLaunch": false,
918 | "category": "Compute optimized",
919 | "gpuNum": 0,
920 | "hideHardwareSpecs": false,
921 | "memoryGiB": 8,
922 | "name": "ml.c5.xlarge",
923 | "vcpuNum": 4
924 | },
925 | {
926 | "_defaultOrder": 23,
927 | "_isFastLaunch": false,
928 | "category": "Compute optimized",
929 | "gpuNum": 0,
930 | "hideHardwareSpecs": false,
931 | "memoryGiB": 16,
932 | "name": "ml.c5.2xlarge",
933 | "vcpuNum": 8
934 | },
935 | {
936 | "_defaultOrder": 24,
937 | "_isFastLaunch": false,
938 | "category": "Compute optimized",
939 | "gpuNum": 0,
940 | "hideHardwareSpecs": false,
941 | "memoryGiB": 32,
942 | "name": "ml.c5.4xlarge",
943 | "vcpuNum": 16
944 | },
945 | {
946 | "_defaultOrder": 25,
947 | "_isFastLaunch": false,
948 | "category": "Compute optimized",
949 | "gpuNum": 0,
950 | "hideHardwareSpecs": false,
951 | "memoryGiB": 72,
952 | "name": "ml.c5.9xlarge",
953 | "vcpuNum": 36
954 | },
955 | {
956 | "_defaultOrder": 26,
957 | "_isFastLaunch": false,
958 | "category": "Compute optimized",
959 | "gpuNum": 0,
960 | "hideHardwareSpecs": false,
961 | "memoryGiB": 96,
962 | "name": "ml.c5.12xlarge",
963 | "vcpuNum": 48
964 | },
965 | {
966 | "_defaultOrder": 27,
967 | "_isFastLaunch": false,
968 | "category": "Compute optimized",
969 | "gpuNum": 0,
970 | "hideHardwareSpecs": false,
971 | "memoryGiB": 144,
972 | "name": "ml.c5.18xlarge",
973 | "vcpuNum": 72
974 | },
975 | {
976 | "_defaultOrder": 28,
977 | "_isFastLaunch": false,
978 | "category": "Compute optimized",
979 | "gpuNum": 0,
980 | "hideHardwareSpecs": false,
981 | "memoryGiB": 192,
982 | "name": "ml.c5.24xlarge",
983 | "vcpuNum": 96
984 | },
985 | {
986 | "_defaultOrder": 29,
987 | "_isFastLaunch": true,
988 | "category": "Accelerated computing",
989 | "gpuNum": 1,
990 | "hideHardwareSpecs": false,
991 | "memoryGiB": 16,
992 | "name": "ml.g4dn.xlarge",
993 | "vcpuNum": 4
994 | },
995 | {
996 | "_defaultOrder": 30,
997 | "_isFastLaunch": false,
998 | "category": "Accelerated computing",
999 | "gpuNum": 1,
1000 | "hideHardwareSpecs": false,
1001 | "memoryGiB": 32,
1002 | "name": "ml.g4dn.2xlarge",
1003 | "vcpuNum": 8
1004 | },
1005 | {
1006 | "_defaultOrder": 31,
1007 | "_isFastLaunch": false,
1008 | "category": "Accelerated computing",
1009 | "gpuNum": 1,
1010 | "hideHardwareSpecs": false,
1011 | "memoryGiB": 64,
1012 | "name": "ml.g4dn.4xlarge",
1013 | "vcpuNum": 16
1014 | },
1015 | {
1016 | "_defaultOrder": 32,
1017 | "_isFastLaunch": false,
1018 | "category": "Accelerated computing",
1019 | "gpuNum": 1,
1020 | "hideHardwareSpecs": false,
1021 | "memoryGiB": 128,
1022 | "name": "ml.g4dn.8xlarge",
1023 | "vcpuNum": 32
1024 | },
1025 | {
1026 | "_defaultOrder": 33,
1027 | "_isFastLaunch": false,
1028 | "category": "Accelerated computing",
1029 | "gpuNum": 4,
1030 | "hideHardwareSpecs": false,
1031 | "memoryGiB": 192,
1032 | "name": "ml.g4dn.12xlarge",
1033 | "vcpuNum": 48
1034 | },
1035 | {
1036 | "_defaultOrder": 34,
1037 | "_isFastLaunch": false,
1038 | "category": "Accelerated computing",
1039 | "gpuNum": 1,
1040 | "hideHardwareSpecs": false,
1041 | "memoryGiB": 256,
1042 | "name": "ml.g4dn.16xlarge",
1043 | "vcpuNum": 64
1044 | },
1045 | {
1046 | "_defaultOrder": 35,
1047 | "_isFastLaunch": false,
1048 | "category": "Accelerated computing",
1049 | "gpuNum": 1,
1050 | "hideHardwareSpecs": false,
1051 | "memoryGiB": 61,
1052 | "name": "ml.p3.2xlarge",
1053 | "vcpuNum": 8
1054 | },
1055 | {
1056 | "_defaultOrder": 36,
1057 | "_isFastLaunch": false,
1058 | "category": "Accelerated computing",
1059 | "gpuNum": 4,
1060 | "hideHardwareSpecs": false,
1061 | "memoryGiB": 244,
1062 | "name": "ml.p3.8xlarge",
1063 | "vcpuNum": 32
1064 | },
1065 | {
1066 | "_defaultOrder": 37,
1067 | "_isFastLaunch": false,
1068 | "category": "Accelerated computing",
1069 | "gpuNum": 8,
1070 | "hideHardwareSpecs": false,
1071 | "memoryGiB": 488,
1072 | "name": "ml.p3.16xlarge",
1073 | "vcpuNum": 64
1074 | },
1075 | {
1076 | "_defaultOrder": 38,
1077 | "_isFastLaunch": false,
1078 | "category": "Accelerated computing",
1079 | "gpuNum": 8,
1080 | "hideHardwareSpecs": false,
1081 | "memoryGiB": 768,
1082 | "name": "ml.p3dn.24xlarge",
1083 | "vcpuNum": 96
1084 | },
1085 | {
1086 | "_defaultOrder": 39,
1087 | "_isFastLaunch": false,
1088 | "category": "Memory Optimized",
1089 | "gpuNum": 0,
1090 | "hideHardwareSpecs": false,
1091 | "memoryGiB": 16,
1092 | "name": "ml.r5.large",
1093 | "vcpuNum": 2
1094 | },
1095 | {
1096 | "_defaultOrder": 40,
1097 | "_isFastLaunch": false,
1098 | "category": "Memory Optimized",
1099 | "gpuNum": 0,
1100 | "hideHardwareSpecs": false,
1101 | "memoryGiB": 32,
1102 | "name": "ml.r5.xlarge",
1103 | "vcpuNum": 4
1104 | },
1105 | {
1106 | "_defaultOrder": 41,
1107 | "_isFastLaunch": false,
1108 | "category": "Memory Optimized",
1109 | "gpuNum": 0,
1110 | "hideHardwareSpecs": false,
1111 | "memoryGiB": 64,
1112 | "name": "ml.r5.2xlarge",
1113 | "vcpuNum": 8
1114 | },
1115 | {
1116 | "_defaultOrder": 42,
1117 | "_isFastLaunch": false,
1118 | "category": "Memory Optimized",
1119 | "gpuNum": 0,
1120 | "hideHardwareSpecs": false,
1121 | "memoryGiB": 128,
1122 | "name": "ml.r5.4xlarge",
1123 | "vcpuNum": 16
1124 | },
1125 | {
1126 | "_defaultOrder": 43,
1127 | "_isFastLaunch": false,
1128 | "category": "Memory Optimized",
1129 | "gpuNum": 0,
1130 | "hideHardwareSpecs": false,
1131 | "memoryGiB": 256,
1132 | "name": "ml.r5.8xlarge",
1133 | "vcpuNum": 32
1134 | },
1135 | {
1136 | "_defaultOrder": 44,
1137 | "_isFastLaunch": false,
1138 | "category": "Memory Optimized",
1139 | "gpuNum": 0,
1140 | "hideHardwareSpecs": false,
1141 | "memoryGiB": 384,
1142 | "name": "ml.r5.12xlarge",
1143 | "vcpuNum": 48
1144 | },
1145 | {
1146 | "_defaultOrder": 45,
1147 | "_isFastLaunch": false,
1148 | "category": "Memory Optimized",
1149 | "gpuNum": 0,
1150 | "hideHardwareSpecs": false,
1151 | "memoryGiB": 512,
1152 | "name": "ml.r5.16xlarge",
1153 | "vcpuNum": 64
1154 | },
1155 | {
1156 | "_defaultOrder": 46,
1157 | "_isFastLaunch": false,
1158 | "category": "Memory Optimized",
1159 | "gpuNum": 0,
1160 | "hideHardwareSpecs": false,
1161 | "memoryGiB": 768,
1162 | "name": "ml.r5.24xlarge",
1163 | "vcpuNum": 96
1164 | },
1165 | {
1166 | "_defaultOrder": 47,
1167 | "_isFastLaunch": false,
1168 | "category": "Accelerated computing",
1169 | "gpuNum": 1,
1170 | "hideHardwareSpecs": false,
1171 | "memoryGiB": 16,
1172 | "name": "ml.g5.xlarge",
1173 | "vcpuNum": 4
1174 | },
1175 | {
1176 | "_defaultOrder": 48,
1177 | "_isFastLaunch": false,
1178 | "category": "Accelerated computing",
1179 | "gpuNum": 1,
1180 | "hideHardwareSpecs": false,
1181 | "memoryGiB": 32,
1182 | "name": "ml.g5.2xlarge",
1183 | "vcpuNum": 8
1184 | },
1185 | {
1186 | "_defaultOrder": 49,
1187 | "_isFastLaunch": false,
1188 | "category": "Accelerated computing",
1189 | "gpuNum": 1,
1190 | "hideHardwareSpecs": false,
1191 | "memoryGiB": 64,
1192 | "name": "ml.g5.4xlarge",
1193 | "vcpuNum": 16
1194 | },
1195 | {
1196 | "_defaultOrder": 50,
1197 | "_isFastLaunch": false,
1198 | "category": "Accelerated computing",
1199 | "gpuNum": 1,
1200 | "hideHardwareSpecs": false,
1201 | "memoryGiB": 128,
1202 | "name": "ml.g5.8xlarge",
1203 | "vcpuNum": 32
1204 | },
1205 | {
1206 | "_defaultOrder": 51,
1207 | "_isFastLaunch": false,
1208 | "category": "Accelerated computing",
1209 | "gpuNum": 1,
1210 | "hideHardwareSpecs": false,
1211 | "memoryGiB": 256,
1212 | "name": "ml.g5.16xlarge",
1213 | "vcpuNum": 64
1214 | },
1215 | {
1216 | "_defaultOrder": 52,
1217 | "_isFastLaunch": false,
1218 | "category": "Accelerated computing",
1219 | "gpuNum": 4,
1220 | "hideHardwareSpecs": false,
1221 | "memoryGiB": 192,
1222 | "name": "ml.g5.12xlarge",
1223 | "vcpuNum": 48
1224 | },
1225 | {
1226 | "_defaultOrder": 53,
1227 | "_isFastLaunch": false,
1228 | "category": "Accelerated computing",
1229 | "gpuNum": 4,
1230 | "hideHardwareSpecs": false,
1231 | "memoryGiB": 384,
1232 | "name": "ml.g5.24xlarge",
1233 | "vcpuNum": 96
1234 | },
1235 | {
1236 | "_defaultOrder": 54,
1237 | "_isFastLaunch": false,
1238 | "category": "Accelerated computing",
1239 | "gpuNum": 8,
1240 | "hideHardwareSpecs": false,
1241 | "memoryGiB": 768,
1242 | "name": "ml.g5.48xlarge",
1243 | "vcpuNum": 192
1244 | },
1245 | {
1246 | "_defaultOrder": 55,
1247 | "_isFastLaunch": false,
1248 | "category": "Accelerated computing",
1249 | "gpuNum": 8,
1250 | "hideHardwareSpecs": false,
1251 | "memoryGiB": 1152,
1252 | "name": "ml.p4d.24xlarge",
1253 | "vcpuNum": 96
1254 | },
1255 | {
1256 | "_defaultOrder": 56,
1257 | "_isFastLaunch": false,
1258 | "category": "Accelerated computing",
1259 | "gpuNum": 8,
1260 | "hideHardwareSpecs": false,
1261 | "memoryGiB": 1152,
1262 | "name": "ml.p4de.24xlarge",
1263 | "vcpuNum": 96
1264 | }
1265 | ],
1266 | "instance_type": "ml.t3.medium",
1267 | "kernelspec": {
1268 | "display_name": "conda_pytorch_p310",
1269 | "language": "python",
1270 | "name": "conda_pytorch_p310"
1271 | },
1272 | "language_info": {
1273 | "codemirror_mode": {
1274 | "name": "ipython",
1275 | "version": 3
1276 | },
1277 | "file_extension": ".py",
1278 | "mimetype": "text/x-python",
1279 | "name": "python",
1280 | "nbconvert_exporter": "python",
1281 | "pygments_lexer": "ipython3",
1282 | "version": "3.10.10"
1283 | }
1284 | },
1285 | "nbformat": 4,
1286 | "nbformat_minor": 4
1287 | }
1288 |
--------------------------------------------------------------------------------
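As a closing note, the notebook defines `query_endpoint_with_json_payload` and `parse_response_model_falcon` but ultimately invokes the Falcon endpoint through LangChain. For reference, here is a minimal sketch of how those helpers could be used to call the deployed endpoint directly; it assumes the notebook has been run so the helpers and `_MODEL_CONFIG_` are in scope, the endpoint is still in service, and the TGI container accepts the usual `inputs`/`parameters` payload.

```python
# Minimal sketch: call the deployed Falcon endpoint directly with the notebook's helpers.
# Assumes the notebook above has been executed, so _MODEL_CONFIG_,
# query_endpoint_with_json_payload and parse_response_model_falcon are defined
# and the endpoint has not yet been deleted in the clean-up step.
import json

endpoint_name = _MODEL_CONFIG_["huggingface-falcon-40b"]["endpoint_name"]

payload = {
    "inputs": 'question: "What is Amazon SageMaker?"\n\nContext: ""\n\nAnswer:',
    "parameters": {"max_new_tokens": 200, "return_full_text": False},
}

# invoke_endpoint expects the JSON payload as bytes; the helper sets the content type.
response = query_endpoint_with_json_payload(
    json.dumps(payload).encode("utf-8"), endpoint_name
)
print(parse_response_model_falcon(response))
```

This bypasses LangChain entirely, which can be handy for checking the raw endpoint behaviour before wiring it into the retrieval chain.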