├── Readme.md ├── app.py ├── db_schema_logic.py ├── new.db └── requirements.txt /Readme.md: -------------------------------------------------------------------------------- 1 | # LangChain Database Schema Interaction 2 | 3 | ## Overview 4 | 5 | This project demonstrates a secure method of generating SQL queries using LangChain and OpenAI's language models, based on a database schema rather than direct database access. This approach enhances security by working with the structure of the database without exposing actual data to the language model. 6 | 7 | ## Features 8 | 9 | - Extracts database schema without accessing data 10 | - Generates SQL queries using natural language processing 11 | - Utilizes OpenAI's language models via LangChain 12 | - Provides a layer of abstraction between the query generator and the actual database 13 | 14 | ## Requirements 15 | 16 | - Python 3.7+ 17 | - OpenAI API key 18 | - LangChain 19 | - SQLAlchemy 20 | - A SQL database (SQLite used in this example) 21 | 22 | ## Installation 23 | 24 | 1. Clone this repository: 25 | ``` 26 | git clone https://github.com/paras55/advanced-chat-with-db.git 27 | cd advanced-chat-with-db 28 | ``` 29 | 30 | 2. Install the required packages: 31 | ``` 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## Usage 36 | 37 | 1. Run the main script: 38 | ``` 39 | streamlit run app.py 40 | ``` 41 | 42 | 2. The script will start the Streamlit application on your local server 43 | 44 | 3. To generate queries for your own questions, upload your database or use the sample one, and then write your query: 45 | 46 | 47 | ## How It Works 48 | 49 | 1. **Schema Extraction**: The script connects to the database and extracts its schema (table names and column details) without accessing the actual data. 50 | 51 | 2. **Query Generation**: Using LangChain and OpenAI's language model, the script generates SQL queries based on the extracted schema and natural language questions. 52 | 53 | 3. 
**Security**: By working with the schema instead of the actual database, this approach adds a layer of security, preventing direct data access during query generation. 54 | 55 | ## Customization 56 | 57 | - Modify the `prompt_template` to adjust how the AI generates SQL queries. 58 | - Extend the `extract_schema` function to include more detailed schema information if needed. 59 | 60 | ## Important Notes 61 | 62 | - This script generates SQL queries but does not execute them. Implement proper security measures before executing generated queries on a real database. 63 | - Always review and validate generated SQL queries before execution to ensure they meet your security and performance requirements. 64 | - Keep your OpenAI API key secure and do not share it in your code repository. 65 | 66 | ## Contributing 67 | 68 | Contributions to improve the functionality, security, or efficiency of this project are welcome. Please submit a pull request or open an issue to discuss proposed changes. 69 | 70 | ## License 71 | 72 | [MIT License](LICENSE) 73 | 74 | ## Disclaimer 75 | 76 | This project is a demonstration and should be carefully reviewed and adapted before use in any production environment. Always prioritize security and data privacy when working with databases and language models. 
# File: app.py

import atexit
import os
import tempfile

import streamlit as st

from db_schema_logic import extract_schema, setup_llm_chain, generate_sql_query

# Page config must be the first Streamlit call in the script.
st.set_page_config(page_title="SQL Query Generator", page_icon=":mag:")

# NOTE(review): the original file injected two custom CSS <style> blocks via
# st.markdown(..., unsafe_allow_html=True); their content is elided in this
# view of the file — restore the original <style> payload here.
st.markdown("""
""", unsafe_allow_html=True)


def cleanup():
    """Delete the temporary database file created for an uploaded database.

    Registered with atexit, so it can fire after the Streamlit script context
    is torn down; all session-state access is therefore guarded.
    """
    try:
        temp_db_file = st.session_state.get("temp_db_file")
    except Exception:
        # Session state is unavailable outside a script run (e.g. at
        # interpreter exit) — nothing we can clean up safely.
        return
    if temp_db_file:
        temp_db_file.close()
        try:
            os.unlink(temp_db_file.name)
        except OSError:
            pass  # best effort: the file may already have been removed
        st.session_state.temp_db_file = None


# Register cleanup once, before main() runs, so it is in place even if the
# script errors part-way through a rerun.
atexit.register(cleanup)


def main():
    """Render the app: upload a SQLite DB, show its schema, and generate SQL
    from natural-language questions.

    Generated queries are displayed only — never executed.
    """
    st.title("Advanced Text to SQL")
    st.write("Generate SQL queries from natural language")

    # Sidebar for configuration
    st.sidebar.header("Configuration")
    openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")

    # Default database URL (bundled sample DB)
    default_db_url = "sqlite:///new.db"

    # File uploader for database
    uploaded_file = st.sidebar.file_uploader("Upload a SQLite database", type=["db", "sqlite"])

    # Button to process the uploaded database
    process_db = st.sidebar.button("Process Uploaded Database")

    # Initialize session state (persists across Streamlit reruns)
    st.session_state.setdefault("db_processed", False)
    st.session_state.setdefault("db_url", default_db_url)
    st.session_state.setdefault("schema", None)
    st.session_state.setdefault("temp_db_file", None)

    if uploaded_file is not None and process_db:
        # Persist the upload to a real file on disk so SQLAlchemy can open it
        # by path; close any file left over from a previous upload first.
        if st.session_state.temp_db_file:
            st.session_state.temp_db_file.close()

        st.session_state.temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix='.db')
        st.session_state.temp_db_file.write(uploaded_file.getvalue())
        st.session_state.temp_db_file.flush()

        st.session_state.db_url = f"sqlite:///{st.session_state.temp_db_file.name}"
        st.session_state.db_processed = True
        with st.spinner("Extracting database schema..."):
            st.session_state.schema = extract_schema(st.session_state.db_url)
        st.sidebar.success("Database processed successfully!")
    elif process_db and uploaded_file is None:
        st.sidebar.warning("Please upload a database file first.")

    st.sidebar.markdown("[Download sample database](https://drive.google.com/file/d/1RyTj-yCPlAwtQKfTThahZkMtDEKxi2Lr/view?usp=sharing)")  # Replace with actual download link

    if not openai_api_key:
        st.warning("Please enter your OpenAI API key in the sidebar.")
        return

    if st.session_state.schema:
        st.subheader("Database Schema")
        st.code(st.session_state.schema)

        # Set up LLM chain; setup_llm_chain raises ValueError on bad config.
        try:
            chain = setup_llm_chain(openai_api_key)
        except ValueError as e:
            st.error(f"Error setting up the language model: {str(e)}")
            return

        # User input
        user_question = st.text_input("Enter your question:")

        if user_question:
            with st.spinner("Generating SQL query..."):
                sql_query = generate_sql_query(chain, st.session_state.schema, user_question)

            st.subheader("Generated SQL Query")
            # st.code renders a built-in copy-to-clipboard button. The
            # previous "Copy Query" button was a fake: it called
            # st.experimental_set_query_params (URL params, deprecated API)
            # and st.write(..., icon=...) — st.write has no icon kwarg, so it
            # raised TypeError. Both removed.
            st.code(sql_query, language="sql")

    else:
        st.info("Please upload a database and click 'Process Uploaded Database' to start.")

    st.sidebar.markdown("---")
    st.sidebar.write("Note: This app generates SQL queries based on the schema but does not execute them.")


if __name__ == "__main__":
    main()
# File: db_schema_logic.py

# NOTE(review): `from langchain import OpenAI, LLMChain` is the legacy import
# path and is removed in langchain >= 0.2 — pin langchain accordingly in
# requirements.txt, or migrate to langchain_openai / LCEL when dependencies
# can be updated.
from langchain import OpenAI, LLMChain
from langchain.prompts import PromptTemplate
from sqlalchemy import create_engine, inspect


def extract_schema(db_url):
    """Extract the schema (tables and columns) without reading any row data.

    Args:
        db_url: SQLAlchemy database URL, e.g. "sqlite:///path/to.db".

    Returns:
        A newline-joined, human-readable listing: "Table: <name>" lines,
        each followed by " - <column> (<type>)" lines for its columns.
    """
    engine = create_engine(db_url)
    try:
        inspector = inspect(engine)

        schema_info = []
        for table_name in inspector.get_table_names():
            schema_info.append(f"Table: {table_name}")
            for column in inspector.get_columns(table_name):
                schema_info.append(f" - {column['name']} ({column['type']})")

        return "\n".join(schema_info)
    finally:
        # The engine is single-use here; dispose it so its connection pool
        # does not leak across repeated uploads.
        engine.dispose()


def setup_llm_chain(openai_api_key):
    """Build the LangChain pipeline that turns (schema, question) into SQL.

    Args:
        openai_api_key: OpenAI API key; must be non-empty.

    Returns:
        An LLMChain whose prompt exposes {schema} and {question} variables.

    Raises:
        ValueError: if the API key is missing/blank (the caller in app.py
            catches ValueError and surfaces it to the user).
    """
    if not openai_api_key or not openai_api_key.strip():
        raise ValueError("An OpenAI API key is required.")

    # temperature=0 for deterministic, schema-faithful query generation.
    llm = OpenAI(temperature=0, openai_api_key=openai_api_key)

    prompt_template = """
 You are an AI assistant that generates SQL queries based on user requests.
 You have access to the following database schema:

 {schema}

 Based ONLY on this schema, generate a SQL query to answer the following question:

 {question}

 If the question cannot be answered using ONLY the provided schema, respond with "I cannot answer this question based on the given schema."

 SQL Query:
 """

    prompt = PromptTemplate(
        input_variables=["schema", "question"],
        template=prompt_template,
    )

    return LLMChain(llm=llm, prompt=prompt)


def generate_sql_query(chain, schema, question):
    """Generate a SQL query for *question* against *schema* via *chain*.

    Returns the model completion with surrounding whitespace stripped (the
    prompt ends with "SQL Query:\\n", so raw completions start with a newline).
    """
    return chain.run(schema=schema, question=question).strip()