├── .gitignore
├── .streamlit
│   └── config.toml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app.py
├── architecture.png
├── config.py
├── init.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | faiss_index
2 | __pycache__
3 |
--------------------------------------------------------------------------------
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | primaryColor="#9328C2"
3 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 |
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AWS Glue Data Catalog Text-to-SQL 👾
2 |
3 | **AWS Glue Data Catalog Text-to-SQL** is a plug-and-play Generative AI application that integrates with your Glue Data Catalog to enhance table search and SQL query generation. It helps data analysts, data scientists, and other data users find the right datasets for their use cases and data products, boosting their productivity.
4 |
5 | 
6 |
7 | ## Prerequisites
8 | ### Region
9 | This code should run in the **region** where:
10 | - Your Glue Data Catalog is hosted
11 | - Amazon Bedrock is generally available
12 |
13 | ### Python version
14 | This code has been tested with Python 3.8.
15 |
16 | If you are on Amazon Linux and do not have python>=3.8, install it:
17 | ```
18 | sudo yum remove python3.7
19 | sudo yum install -y amazon-linux-extras
20 | sudo amazon-linux-extras enable python3.8
21 | sudo yum install python3.8
22 | ```
23 | You might need to fix pip before the Installation step:
24 | ```
25 | curl -O https://bootstrap.pypa.io/get-pip.py
26 | python3.8 get-pip.py --user
27 | ```
28 |
29 | ## Installation
30 | Install the required dependencies:
31 | ```bash
32 | pip install -r requirements.txt
33 | ```
34 |
35 | Troubleshooting: if this pip install fails, check that the versions of the installed libraries match the versions pinned in ```requirements.txt```; if they do not match, adjust the pinned versions accordingly.
36 |
37 | ## Configuration
38 |
39 | Using **OpenSearch** as a vector store is entirely **optional**; you can use the local FAISS implementation instead.
40 |
41 | If you are using **OpenSearch** as a vector store, configure the domain's name and its corresponding endpoint in the ```config.py``` file.
42 |
43 | > ⚠️ **Warning:** At this time, this sample only supports the provisioned version of Amazon OpenSearch Service.
44 |
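For reference, a filled-in ```config.py``` might look like the following; the endpoint and domain name below are placeholders, not real resources:

```python
# Example config.py (placeholder values, replace with your own)
opensearch = dict(
    domain_endpoint='https://search-my-domain-abc123.us-east-1.es.amazonaws.com',
    domain_name='my-domain',
)
_global = dict(
    region='us-east-1'
)
```
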
45 | ## Initialization
46 |
47 | Initialize your Vector Database with the existing AWS Glue Data Catalog assets.
48 |
49 | **FAISS (local)**
50 |
51 | ``` bash
52 | python3.8 init.py faiss
53 | ```
54 |
55 | **OpenSearch**
56 |
57 | ``` bash
58 | python3.8 init.py opensearch
59 | ```
60 |
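Under the hood, ```init.py``` flattens each Glue table into a one-line document of the form ```database.table (col1, col2, ...)```, embeds it with the Amazon Titan embeddings model, and writes the vectors to the chosen store. An indexed document might read, for example, ```salesdb.inventory (warehouse_id, item_id, quantity)``` (illustrative names).
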
61 | ## Usage
62 |
63 | Run the Streamlit app:
64 |
65 | ``` bash
66 | streamlit run app.py --server.port <port>
67 | ```
68 |
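Replace ```<port>``` with the port you want to serve the app on; if the flag is omitted entirely, Streamlit falls back to its default port, 8501.
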
69 | ## Security
70 |
71 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
72 |
73 | ## License
74 |
75 | This library is licensed under the MIT-0 License. See the LICENSE file.
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import boto3
3 | import time
4 | import config
5 | import langchain
6 | from requests_aws4auth import AWS4Auth
7 | from opensearchpy import OpenSearch, RequestsHttpConnection
8 | from langchain.embeddings import BedrockEmbeddings
9 | from langchain.llms.bedrock import Bedrock
10 | from langchain.vectorstores import FAISS
11 | from langchain.vectorstores import OpenSearchVectorSearch
12 | from langchain.docstore.document import Document
13 | from langchain.prompts import PromptTemplate
14 | from langchain.chains import RetrievalQA
15 |
16 | if __name__ == "__main__":
17 |
18 | # Page configuration
19 |
20 | st.set_page_config(
21 | page_title='AWS Glue Data Catalog Text-to-SQL',
22 | page_icon=':space_invader:',
23 | initial_sidebar_state='collapsed')
24 | st.title(':violet[AWS Glue] Data Catalog Text-to-SQL :space_invader:')
25 | st.caption('Supercharge your Glue Data Catalog :rocket:')
26 |
27 | # Variables
28 |
29 | langchain.verbose = True
30 | session = boto3.session.Session()
31 | region = config._global['region']
32 | credentials = session.get_credentials()
33 | service = 'es'
34 | http_auth = AWS4Auth(
35 | credentials.access_key,
36 | credentials.secret_key,
37 | region,
38 | service,
39 | session_token=credentials.token)
40 | opensearch_cluster_domain_endpoint = config.opensearch['domain_endpoint']
41 | domain_name = config.opensearch['domain_name']
42 | index_name = "index-superglue"
43 |
44 | # Create AWS Glue client
45 |
46 | glue_client = boto3.client('glue', region_name=region)
47 |
48 | # Function to get all tables from Glue Data Catalog
49 |
50 |
51 | def get_tables(glue_client):
52 | # get all AWS Glue databases
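# Note: get_databases()/get_tables() paginate results via NextToken;
# this sample only reads the first page of each call.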
53 | databases = glue_client.get_databases()
54 |
55 | tables = []
56 |
57 | num_db = len(databases['DatabaseList'])
58 |
59 | for db in databases['DatabaseList']:
60 | tables = tables + \
61 | glue_client.get_tables(DatabaseName=db['Name'])["TableList"]
62 |
63 | num_tables = len(tables)
64 |
65 | return tables, num_db, num_tables
66 |
67 | # Function to flatten JSON representations of Glue tables
68 |
69 |
70 | def dict_to_multiline_string(d):
71 |
72 | lines = []
73 | db_name = d['DatabaseName']
74 | table_name = d['Name']
75 | columns = [c['Name'] for c in d['StorageDescriptor']['Columns']]
76 |
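# Example output: "salesdb.inventory (warehouse_id, item_id, quantity)" (illustrative names)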
77 | line = f"{db_name}.{table_name} ({', '.join(columns)})"
78 | lines.append(line)
79 |
80 | return "\n".join(lines)
81 |
82 | # Function to render user input elements
83 |
84 |
85 | def render_form(catalog):
86 | if (num_tables or num_db):
87 | st.write(
88 | "A total of ",
89 | num_tables,
90 | "tables and ",
91 | num_db,
92 | "databases were indexed")
93 |
94 | k = st.selectbox(
95 | 'How many tables do you want to include in table search result?',
96 | (1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
97 | index=2)
98 |
99 | query = st.text_area(
100 | 'Prompt',
101 | "What is the total inventory per warehouse?")
102 |
103 | with st.sidebar:
104 | st.subheader(":violet[Data Catalog] :point_down:")
105 | st.write(catalog)
106 |
107 | return k, query
108 |
109 | # Function to perform a similarity search
110 |
111 |
112 | def search_tables(vectorstore, k, query):
113 | relevant_documents = vectorstore.similarity_search_with_score(query, k=k)
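# The first whitespace-separated token of each indexed document is its "db.table" name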
114 | for rel_doc in relevant_documents:
115 | st.write(rel_doc[0].page_content.split(" ")[0])
116 | st.write("Score: ", rel_doc[1])
117 | st.divider()
118 |
119 |
120 | # Function to generate LLM response (SQL + Explanation)
121 |
122 | def generate_sql(vectorstore, k, query):
123 | prompt_template = """
124 | \n\nHuman: Between <tables></tables> tags, you have a description of tables with their associated columns. Create a SQL query to answer the question between <question></question> tags only using the tables described between the <tables></tables> tags. If you cannot find the solution with the provided tables, say that you are unable to generate the SQL query.
125 | 
126 | <tables>
127 | {context}
128 | </tables>
129 | 
130 | <question>Question: {question}</question>
131 | 
132 | Provide your answer using the following xml format: <sql>SQL query</sql><explanation>Explain clearly your approach, what the query does, and its syntax</explanation>
133 | 
134 | Assistant:"""
135 |
136 | PROMPT = PromptTemplate(
137 | template=prompt_template, input_variables=["context", "question"]
138 | )
139 |
140 | qa = RetrievalQA.from_chain_type(
141 | llm=bedrock_llm,
142 | chain_type="stuff",
143 | retriever=vectorstore.as_retriever(
144 | search_type="similarity", search_kwargs={"k": k}
145 | ),
146 | return_source_documents=True,
147 | chain_type_kwargs={"prompt": PROMPT},
148 | verbose=True
149 | )
150 | with st.status("Generating response :thinking_face:"):
151 | answer = qa({"query": query})
152 |
153 | # st.write(answer)
154 |
155 | with st.status("Searching tables :books:"):
156 | time.sleep(1)
157 |
158 | for i, rel_doc in enumerate(answer["source_documents"]):
159 | st.write(rel_doc.page_content.split(" ")[0])
160 |
161 | with st.status("Rendering response :fire:"):
162 | sql_query = answer["result"].split("<sql>")[1].split("</sql>")[0]
163 | explanation = answer["result"].split("<explanation>")[
164 | 1].split("</explanation>")[0]
165 |
166 | st.code(sql_query, language='sql')
167 | st.link_button(
168 | "Athena console :sun_with_face:",
169 | "https://{0}.console.aws.amazon.com/athena/home?region={0}".format(region))
170 |
171 | st.write(explanation)
172 |
173 | # Amazon Bedrock LangChain clients
174 |
175 |
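# Claude v2 generates the SQL answer; amazon.titan-embed-text-v1 embeds the
# table descriptions for similarity search.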
176 | bedrock_llm = Bedrock(
177 | model_id="anthropic.claude-v2",
178 | model_kwargs={
179 | 'max_tokens_to_sample': 3000})
180 | bedrock_embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
181 |
182 | # VectorDB type
183 |
184 | vectorDB = st.selectbox(
185 | "VectorDB",
186 | ("FAISS (local)", "OpenSearch (Persistent)"),
187 | index=0
188 | )
189 |
190 | if vectorDB == "FAISS (local)":
191 |
192 | with st.status("Connecting to Glue Data Catalog :man_dancing:"):
193 |
194 | catalog, num_db, num_tables = get_tables(glue_client)
195 |
196 | # Check if an index copy of FAISS is stored locally
197 |
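# (Delete the local "faiss_index" folder, which is gitignored, to force a rebuild.)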
198 | try:
199 | vectorstore_faiss = FAISS.load_local(
200 | "faiss_index", bedrock_embeddings)
201 | except BaseException:
202 | docs = [
203 | Document(
204 | page_content=dict_to_multiline_string(x),
205 | metadata={
206 | "source": "local"}) for x in catalog]
207 |
208 | vectorstore_faiss = FAISS.from_documents(
209 | docs,
210 | bedrock_embeddings,
211 | )
212 |
213 | vectorstore_faiss.save_local("faiss_index")
214 |
215 | k, query = render_form(catalog)
216 |
217 | if st.button('Search relevant tables :dart:'):
218 |
219 | search_tables(vectorstore=vectorstore_faiss, k=k, query=query)
220 |
221 | if st.button('Generate SQL :crystal_ball:'):
222 |
223 | generate_sql(vectorstore=vectorstore_faiss, k=k, query=query)
224 |
225 | elif vectorDB == "OpenSearch (Persistent)":
226 |
227 | with st.status("Connecting to Glue Data Catalog :man_dancing:"):
228 |
229 | catalog, num_db, num_tables = get_tables(glue_client)
230 |
231 | # Initialize Opensearch Vector Search clients
232 |
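# Requests are SigV4-signed (AWS4Auth) and sent over HTTPS to the domain
# endpoint configured in config.py.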
233 | vectorstore_opensearch = OpenSearchVectorSearch(
234 | index_name=index_name,
235 | embedding_function=bedrock_embeddings,
236 | opensearch_url=opensearch_cluster_domain_endpoint,
237 | engine="faiss",
238 | timeout=300,
239 | use_ssl=True,
240 | verify_certs=True,
241 | http_auth=http_auth,
242 | connection_class=RequestsHttpConnection
243 | )
244 |
245 | k, query = render_form(catalog)
246 |
247 | if st.button('Search relevant tables :dart:'):
248 | search_tables(vectorstore=vectorstore_opensearch, k=k, query=query)
249 |
250 | if st.button('Generate SQL :crystal_ball:'):
251 |
252 | generate_sql(vectorstore=vectorstore_opensearch, k=k, query=query)
253 |
--------------------------------------------------------------------------------
/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/aws-glue-data-catalog-text2sql/23c0f77b62e9e9bfcf689c10b2bef0b89d1dc336/architecture.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | opensearch = dict(
2 | domain_endpoint='',
3 | domain_name='',
4 | )
5 | _global = dict(
6 | region='us-east-1'
7 | )
8 |
--------------------------------------------------------------------------------
/init.py:
--------------------------------------------------------------------------------
1 | import boto3
2 | import config
3 | import langchain
4 | import sys
5 | from requests_aws4auth import AWS4Auth
6 | from opensearchpy import OpenSearch, RequestsHttpConnection
7 | from langchain.embeddings import BedrockEmbeddings
8 | from langchain.vectorstores import FAISS
9 | from langchain.vectorstores import OpenSearchVectorSearch
10 | from langchain.docstore.document import Document
11 |
12 | if __name__ == "__main__":
13 |
14 |
15 | session = boto3.session.Session()
16 | region = config._global['region']
17 | credentials = session.get_credentials()
18 | service = 'es'
19 | http_auth = AWS4Auth(
20 | credentials.access_key,
21 | credentials.secret_key,
22 | region,
23 | service,
24 | session_token=credentials.token)
25 | opensearch_cluster_domain_endpoint = config.opensearch['domain_endpoint']
26 | domain_name = config.opensearch['domain_name']
27 | index_name = "index-superglue"
28 |
29 | # Create AWS Glue client
30 |
31 | glue_client = boto3.client('glue', region_name=region)
32 |
33 |
34 | # Create Amazon Opensearch client
35 |
36 | def get_opensearch_cluster_client():
37 | opensearch_client = OpenSearch(
38 | hosts=opensearch_cluster_domain_endpoint,
39 | http_auth=http_auth,
40 | engine="faiss",
41 | index_name=index_name,
42 | use_ssl=True,
43 | verify_certs=True,
44 | connection_class=RequestsHttpConnection,
45 | timeout=300
46 | )
47 | return opensearch_client
48 |
49 | # Function to get all tables from Glue Data Catalog
50 |
51 |
52 | def get_tables(glue_client):
53 | # get all AWS Glue databases
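# Note: get_databases()/get_tables() paginate results via NextToken;
# this sample only reads the first page of each call.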
54 | databases = glue_client.get_databases()
55 |
56 | tables = []
57 |
58 | num_db = len(databases['DatabaseList'])
59 |
60 | for db in databases['DatabaseList']:
61 | tables = tables + \
62 | glue_client.get_tables(DatabaseName=db['Name'])["TableList"]
63 |
64 | num_tables = len(tables)
65 |
66 | return tables, num_db, num_tables
67 |
68 | # Function to flatten JSON representations of Glue tables
69 |
70 |
71 | def dict_to_multiline_string(d):
72 |
73 | lines = []
74 | db_name = d['DatabaseName']
75 | table_name = d['Name']
76 | columns = [c['Name'] for c in d['StorageDescriptor']['Columns']]
77 |
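# Example output: "salesdb.inventory (warehouse_id, item_id, quantity)" (illustrative names)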
78 | line = f"{db_name}.{table_name} ({', '.join(columns)})"
79 | lines.append(line)
80 |
81 | return "\n".join(lines)
82 |
83 | # Amazon Bedrock LangChain clients
84 |
85 |
86 | bedrock_embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
87 |
88 | # VectorDB type
89 |
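# Expected CLI argument: "faiss" or "opensearch", e.g. python3.8 init.py faiss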
90 | vectorDB = sys.argv[1]
91 |
92 | if vectorDB == "faiss":
93 |
94 | print("INFO: Indexing FAISS started.")
95 |
96 | catalog, num_db, num_tables = get_tables(glue_client)
97 |
98 | docs = [
99 | Document(
100 | page_content=dict_to_multiline_string(x),
101 | metadata={
102 | "source": "local"}) for x in catalog]
103 |
104 | vectorstore_faiss = FAISS.from_documents(
105 | docs,
106 | bedrock_embeddings,
107 | )
108 |
109 | print("INFO: Loaded Documents in FAISS.")
110 |
111 | vectorstore_faiss.save_local("faiss_index")
112 |
113 | print("COMPLETE: FAISS Index saved.")
114 |
115 | elif vectorDB == "opensearch":
116 |
117 | print("INFO: Indexing OpenSearch started.")
118 |
119 | catalog, num_db, num_tables = get_tables(glue_client)
120 |
121 | # Initialize Opensearch clients
122 |
123 | opensearch_client = get_opensearch_cluster_client()
124 |
125 | vectorstore_opensearch = OpenSearchVectorSearch(
126 | index_name=index_name,
127 | embedding_function=bedrock_embeddings,
128 | opensearch_url=opensearch_cluster_domain_endpoint,
129 | engine="faiss",
130 | timeout=300,
131 | use_ssl=True,
132 | verify_certs=True,
133 | http_auth=http_auth,
134 | connection_class=RequestsHttpConnection
135 | )
136 |
137 | # Delete index for initial batch embedding
138 |
139 | try:
140 | opensearch_client.indices.delete(index_name)
141 | except BaseException:
142 | print("Index does not exist.")
143 |
144 | # Prepare and add documents
145 |
146 | docs = [
147 | Document(
148 | page_content=dict_to_multiline_string(x),
149 | metadata={
150 | "source": "local"}) for x in catalog]
151 |
152 | vectorstore_opensearch.add_documents(docs)
153 |
154 | print("COMPLETE: Loaded Document Embeddings in Opensearch.")
155 |
156 |
157 | else:
158 | print("ERROR: Invalid vector database type.")
159 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit==1.30.0
2 | langchain==0.0.329
3 | boto3==1.28.57
4 | faiss-cpu==1.7.4
5 | opensearch-py==2.3.1
6 | requests-aws4auth==1.2.3
7 | urllib3==1.26.18
--------------------------------------------------------------------------------