├── .gitignore ├── config.ini.default ├── terminal.py ├── requirements.txt ├── README.md ├── google_sql_connector.py ├── controller.py └── chatgpt.py /.gitignore: -------------------------------------------------------------------------------- 1 | debug.log 2 | config.ini 3 | credentials.json 4 | token.json 5 | __pycache__/ 6 | -------------------------------------------------------------------------------- /config.ini.default: -------------------------------------------------------------------------------- 1 | [database] 2 | driver = ODBC Driver 18 for SQL Server 3 | server = 127.0.0.1 4 | database = xxxxxx 5 | user = xxxxx 6 | password = xxxxxx 7 | encrypt = yes 8 | 9 | [openai] 10 | api_key = xxxxxx 11 | org = xxxx 12 | -------------------------------------------------------------------------------- /terminal.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from controller import Controller 3 | 4 | # Configure the logging settings 5 | logging.basicConfig(filename='debug.log', filemode='a', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') 6 | 7 | def main(): 8 | print("Ask any question about the data. Enter 'q' to quit. Enter 'r' to reset ChatGPT.") 9 | controller = Controller() 10 | while True: 11 | user_input = input("Question: ") 12 | if user_input.lower() == 'q': 13 | break 14 | if user_input == "r": 15 | controller.chatModel.reset() 16 | continue 17 | try: 18 | result = controller.run(message=user_input, sender="USER") 19 | print(f"ChatGPT: {result}") 20 | except ValueError: 21 | print("Invalid input. Please enter a number or 'q' to quit.") 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.4 2 | aiosignal==1.3.1 3 | async-timeout==4.0.2 4 | attrs==22.2.0 5 | build==0.10.0 6 | CacheControl==0.12.11 7 | cachetools==5.3.0 8 | certifi==2022.12.7 9 | charset-normalizer==3.1.0 10 | cleo==2.0.1 11 | colorama==0.4.6 12 | crashtest==0.4.1 13 | dataclasses-json==0.5.7 14 | distlib==0.3.6 15 | dulwich==0.21.3 16 | exceptiongroup==1.1.1 17 | filelock==3.11.0 18 | frozenlist==1.3.3 19 | google-api-core==2.11.0 20 | google-api-python-client==2.84.0 21 | google-auth==2.17.2 22 | google-auth-httplib2==0.1.0 23 | google-auth-oauthlib==1.0.0 24 | googleapis-common-protos==1.59.0 25 | greenlet==2.0.2 26 | html5lib==1.1 27 | httplib2==0.22.0 28 | idna==3.4 29 | importlib-metadata==6.3.0 30 | iniconfig==2.0.0 31 | installer==0.7.0 32 | jaraco.classes==3.2.3 33 | jsonschema==4.17.3 34 | keyring==23.13.1 35 | langchain==0.0.135 36 | lockfile==0.12.2 37 | marshmallow==3.19.0 38 | marshmallow-enum==1.5.1 39 | more-itertools==9.1.0 40 | msgpack==1.0.5 41 | multidict==6.0.4 42 | mypy-extensions==1.0.0 43 | numpy==1.24.2 44 | oauthlib==3.2.2 45 | openai==0.27.2 46 | openapi-schema-pydantic==1.2.4 47 | packaging==23.0 48 | pandas==2.0.0 49 | pexpect==4.8.0 50 | pkginfo==1.9.6 51 | platformdirs==2.6.2 52 | pluggy==1.0.0 53 | poetry==1.4.2 54 | poetry-core==1.5.2 55 | poetry-plugin-export==1.3.0 56 | protobuf==4.22.1 57 | ptyprocess==0.7.0 58 | pyasn1==0.4.8 59 | pyasn1-modules==0.2.8 60 | pydantic==1.10.7 61 | pyodbc==4.0.35 62 | pyparsing==3.0.9 63 | pyproject_hooks==1.0.0 64 | pyrsistent==0.19.3 65 | pytest==7.3.0 66 | python-dateutil==2.8.2 67 | pytz==2023.3 68 | pywin32-ctypes==0.2.0 69 | PyYAML==6.0 70 | rapidfuzz==2.15.1 71 | requests==2.28.2 72 | requests-oauthlib==1.3.1 73 | requests-toolbelt==0.10.1 74 | rsa==4.9 75 | shellingham==1.5.0.post1 76 | six==1.16.0 77 | SQLAlchemy==1.4.47 78 | tabulate==0.9.0 79 | tenacity==8.2.2 80 | tomli==2.0.1 81 | tomlkit==0.11.7 82 | tqdm==4.65.0 83 | trove-classifiers==2023.3.9 84 | typing-inspect==0.8.0 85 | typing_extensions==4.5.0 86 | tzdata==2023.3 87 | uritemplate==4.1.1 88 | urllib3==1.26.15 89 | virtualenv==20.21.0 90 | webencodings==0.5.1 91 | yarl==1.8.2 92 | zipp==3.15.0 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EDIT July 27, 2023: 2 | 3 | Significant progress has been made on open source frameworks that allow LLM's to interact with other software (such as Databases). For any production use cases, I recommend using such a framework (e.g. [LangChain](https://python.langchain.com/docs/get_started/introduction.html)). This Repository can still serve as a straight forward (learning) example of how LLM tool use can be implemented from the ground up. 4 | 5 | # ChatGPT SQL 6 | 7 | ![Demo](https://s2.gifyu.com/images/chatgpt-sql-demo.gif) 8 | 9 | Connecting ChatGPT to an SQL Server so that you can ask questions about data in natural language, and you also get an answer in natural language. It works by creating a layer between the user and ChatGPT that routes messages between the user and ChatGPT on one side, and ChatGPT and the SQL server on the other. ChatGPT is indicating whether the message it meant for the user or for the server. What you see in the below image is the following: 10 | 11 | 1. The user asks a question ("What is the top selling product by revenue in 2013?") 12 | 2. The script forwards the question to ChatGPT. 13 | 3. ChatGPT responds with a request for schema information. 14 | 4. The script queries the database schema in SQL server and sends the result back to ChatGPT. 15 | 5. ChatGPT Formulates an SQL query based on the schema information 16 | 6. The script executes the query against the database and sends the result back to ChatGPT. 17 | 7. ChatGPT interprets the result and formulates an answer (indicating it should be shared with the user) 18 | 8. The script shares the answer with the user. ("The top selling product by revenue in 2013 is the Mountain-200 Black, 38."). 19 | 20 | The app is currently configured to work with the [AdventureWorks database](https://learn.microsoft.com/en-us/sql/samples/adventureworks-install-configure?view=sql-server-ver16&tabs=ssms), but it should be possible to make it work with any other database as well. 21 | 22 | Some caveats: 23 | * The database I use is the AdventureWorks database, which ChatGPT also knows from it's training data, which means it's probably performing better than it would with other databases. 24 | * It sometimes breaks, due to ChatGPT not following protocol. 25 | 26 | Still, I think it's a pretty cool POC :slightly_smiling_face: 27 | 28 | -------------------------------------------------------------------------------- /google_sql_connector.py: -------------------------------------------------------------------------------- 1 | import pyodbc 2 | import csv 3 | import logging 4 | 5 | from io import StringIO 6 | 7 | class GoogleCloudSQL: 8 | 9 | def __init__(self, driver, server, database, user, password, encrypt="yes"): 10 | self.driver = driver 11 | self.server = server 12 | self.database = database 13 | self.user = user 14 | self.password = password 15 | self.encrypt = encrypt 16 | 17 | def connect(self): 18 | try: 19 | self.conn = pyodbc.connect(f'DRIVER={{{self.driver}}};SERVER={self.server};DATABASE={self.database};UID={self.user};PWD={self.password};ENCRYPT={self.encrypt}') 20 | return True 21 | except Exception as e: 22 | return str(e) 23 | 24 | def close(self): 25 | self.conn.close() 26 | 27 | def execute_query(self, query): 28 | print(f'\033[94mExecuting Query:{query}\033[0m') 29 | try: 30 | cursor = self.conn.cursor() 31 | cursor.execute(query) 32 | result = cursor.fetchall() 33 | if len(result) == 0: 34 | result = "0 rows returned" 35 | logging.debug(result) 36 | print(f'\033[96m{result}\033[0m') 37 | return result 38 | 39 | headers = [column[0] for column in cursor.description] 40 | output = StringIO() 41 | csv_writer = csv.writer(output) 42 | csv_writer.writerow(headers) 43 | csv_writer.writerows(result) 44 | result = output.getvalue() 45 | logging.debug(result) 46 | print(f'\033[96m{result}\033[0m') 47 | return result 48 | except Exception as e: 49 | return str(e) 50 | 51 | def process_table_string(self, input_str): 52 | items = input_str.split(',') 53 | items = [item.split('.')[-1] for item in items] 54 | formatted_str = "', '".join(items) 55 | result = f"'{formatted_str}'" 56 | return result 57 | 58 | def execute_schema(self, table_list): 59 | queryPart = self.process_table_string(table_list) 60 | return f"SELECT CONCAT(TABLE_SCHEMA, '.', TABLE_NAME, ', ', COLUMN_NAME, ', ', DATA_TYPE) AS 'Table, Column, DataType' FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME IN ({queryPart})" 61 | -------------------------------------------------------------------------------- /controller.py: -------------------------------------------------------------------------------- 1 | import json 2 | from chatgpt import ChatGPT 3 | from google_sql_connector import GoogleCloudSQL 4 | import configparser 5 | 6 | # Read the config file 7 | config = configparser.ConfigParser() 8 | config.read("config.ini") 9 | 10 | # Access the config values 11 | driver = config.get("database", "driver") 12 | server = config.get("database", "server") 13 | database = config.get("database", "database") 14 | user = config.get("database", "user") 15 | password = config.get("database", "password") 16 | encrypt = config.get("database", "encrypt") 17 | openai_api_key = config.get("openai", "api_key") 18 | openai_org = config.get("openai", "org") 19 | openai_model = config.get("openai", "model") 20 | 21 | 22 | class Controller: 23 | 24 | def __init__(self): 25 | # initialise all the things 26 | self.google_sql = GoogleCloudSQL(driver, server, database, user, password, encrypt) 27 | self.google_sql.connect() 28 | self.chatModel = ChatGPT(openai_api_key, openai_org, openai_model) 29 | 30 | def run(self, message, sender, counter=0): 31 | if (counter > 4): 32 | return 'error: too many requests' 33 | responseString = self.chatModel.message(message, sender) 34 | try: 35 | response = json.loads(responseString[:-1] if responseString.endswith('.') else responseString) 36 | except ValueError: 37 | return self.run("Please repeat that answer but use valid JSON only.", "SYSTEM", counter + 1) 38 | match response["recipient"]: 39 | case "USER": 40 | return response["message"] 41 | case "SERVER": 42 | match response["action"]: 43 | case "QUERY": 44 | result = self.google_sql.execute_query(response["message"]) 45 | return self.run(result, None, counter + 1) 46 | case "SCHEMA": 47 | result = self.google_sql.execute_schema(response["message"]) 48 | return self.run(result, None, counter + 1) 49 | case _: 50 | print('error invalid action') 51 | print(response) 52 | case _: 53 | print('error, invalid recipient') 54 | print(response) 55 | 56 | 57 | def reset(self): 58 | self.chatModel.reset() 59 | -------------------------------------------------------------------------------- /chatgpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | import logging 4 | import json 5 | 6 | class ChatGPT: 7 | 8 | startMessageStack = [ 9 | {"role": "system", "content": "You act as the middleman between USER and a DATABASE. Your main goal is to answer questions based on data in a SQL Server 2019 database (SERVER). You do this by executing valid queries against the database and interpreting the results to anser the questions from the USER."}, 10 | {"role": "user", "content": "From now you will only ever respond with JSON. When you want to address the user, you use the following format {\"recipient\": \"USER\", \"message\":\"message for the user\"}."}, 11 | {"role": "assistant", "content": "{\"recipient\": \"USER\", \"message\":\"I understand.\"}."}, 12 | {"role": "user", "content": "You can address the SQL Server by using the SERVER recipient. When calling the server, you must also specify an action. The action can be QUERY when you want to QUERY the database, or SCHEMA when you need SCHEMA information for a comma separated list of tables. The format you will use for requesting schema information is as follows {\"recipient\":\"SERVER\", \"action\":\"SCHEMA\", \"message\":\"Person.Person, Person.Address\"}. The format you will use for executing a query is as follows: {\"recipient\":\"SERVER\", \"action\":\"QUERY\", \"message\":\"SELECT SUM(OrderQty) FROM Sales.SalesOrderDetail;\"}"}, 13 | # At some point the list of tables should become dynamic. Todo: Figure out how the flows can be dynamic, perhaps some sort of config. 14 | {"role": "user", "content": "THe following tables are available in the database: Sales.SalesOrderDetail, Sales.SalesOrderHeader, Person.Address, Person.Person, Production.Product, Production.ProductCategory, Production.ProductSubcategory. You will always first request the SCHEMA for a table before using the table in a QUERY."}, 15 | {"role": "user", "content": "Let's start! How many Orders did we ship to Bellflower?"}, 16 | {"role": "assistant", "content": "{\"recipient\":\"SERVER\", \"action\":\"SCHEMA\", \"message\":\"Person.Address, Sales.SalesOrderHeader\"}"}, 17 | {"role": "user", "content": "Table, Column, DataType\nPerson.Address, AddressID, int\nPerson.Address, AddressLine1, nvarchar\nPerson.Address, AddressLine2, nvarchar\nPerson.Address, City, nvarchar\nPerson.Address, StateProvinceID, int\nPerson.Address, PostalCode, nvarchar\nPerson.Address, SpatialLocation, geography\nPerson.Address, rowguid, uniqueidentifier\nPerson.Address, ModifiedDate, datetime\nSales.SalesOrderHeader, SalesOrderID, int\nSales.SalesOrderHeader, RevisionNumber, tinyint\nSales.SalesOrderHeader, OrderDate, datetime\nSales.SalesOrderHeader, DueDate, datetime\nSales.SalesOrderHeader, ShipDate, datetime\nSales.SalesOrderHeader, Status, tinyint\nSales.SalesOrderHeader, OnlineOrderFlag, bit\nSales.SalesOrderHeader, SalesOrderNumber, nvarchar\nSales.SalesOrderHeader, PurchaseOrderNumber, nvarchar\nSales.SalesOrderHeader, AccountNumber, nvarchar\nSales.SalesOrderHeader, CustomerID, int\nSales.SalesOrderHeader, SalesPersonID, int\nSales.SalesOrderHeader, TerritoryID, int\nSales.SalesOrderHeader, BillToAddressID, int\nSales.SalesOrderHeader, ShipToAddressID, int\nSales.SalesOrderHeader, ShipMethodID, int\nSales.SalesOrderHeader, CreditCardID, int\nSales.SalesOrderHeader, CreditCardApprovalCode, varchar\nSales.SalesOrderHeader, CurrencyRateID, int\nSales.SalesOrderHeader, SubTotal, money\nSales.SalesOrderHeader, TaxAmt, money\nSales.SalesOrderHeader, Freight, money\nSales.SalesOrderHeader, TotalDue, money\nSales.SalesOrderHeader, Comment, nvarchar\nSales.SalesOrderHeader, rowguid, uniqueidentifier\nSales.SalesOrderHeader, ModifiedDate, datetime"}, 18 | {"role": "assistant", "content": "{\"recipient\":\"SERVER\", \"action\":\"QUERY\", \"message\":\"SELECT COUNT(*) AS COUNT FROM Sales.SalesOrderHeader sh JOIN Person.Address a ON sh.ShipToAddressID = a.AddressID WHERE a.City = 'Bellflower'\"}"}, 19 | {"role": "user", "content": "COUNT\n243"}, 20 | {"role": "assistant", "content": "{\"recipient\": \"USER\", \"message\":\"243 orders were shipped to Bellflower.\"}."}, 21 | {"role": "user", "content": "excellent, thank you! Now the next question. How many orders did we ship to Milwaukee?"}, 22 | {"role": "assistant", "content": "{\"recipient\":\"SERVER\", \"action\":\"QUERY\", \"message\":\"SELECT COUNT(*) AS COUNT FROM Sales.SalesOrderHeader sh JOIN Person.Address a ON sh.ShipToAddressID = a.AddressID WHERE a.City = 'Milwaukee'\"}"}, 23 | {"role": "user", "content": "COUNT\n4"}, 24 | {"role": "assistant", "content": "{\"recipient\": \"USER\", \"message\":\"4 orders were shipped to Milwaukee.\"}."}, 25 | {"role": "user", "content": "great, thank you! Now the next question. How many orders did we ship to Chicago?"}, 26 | # The below messages will be generated by ChatGPT based on the above. Still keeping them in to reinforce the idea, not sure if useful(?). 27 | {"role": "assistant", "content": "{\"recipient\":\"SERVER\", \"action\":\"QUERY\", \"message\":\"SELECT COUNT(*) AS COUNT FROM Sales.SalesOrderHeader sh JOIN Person.Address a ON sh.ShipToAddressID = a.AddressID WHERE a.City = 'Chicago'\"}", "role": "assistant"}, 28 | {"role": "user", "content": "COUNT\n30"}, 29 | {"role": "assistant", "content": "{\"recipient\": \"USER\", \"message\":\"30 orders were shipped to Chicago.\"}."}, 30 | {"role": "user", "content": "What was the top selling product in 2014 by quantity?"}, 31 | {"role": "assistant", "content": "{\"recipient\":\"SERVER\", \"action\":\"SCHEMA\", \"message\":\"Production.Product, Sales.SalesOrderDetail\"}"}, 32 | {"role": "user", "content": "Production.Product, ProductID, int\nProduction.Product, Name, nvarchar\nProduction.Product, ProductNumber, nvarchar\nProduction.Product, MakeFlag, bit\nProduction.Product, FinishedGoodsFlag, bit\nProduction.Product, Color, nvarchar\nProduction.Product, SafetyStockLevel, smallint\nProduction.Product, ReorderPoint, smallint\nProduction.Product, StandardCost, money\nProduction.Product, ListPrice, money\nProduction.Product, Size, nvarchar\nProduction.Product, SizeUnitMeasureCode, nchar\nProduction.Product, WeightUnitMeasureCode, nchar\nProduction.Product, Weight, decimal\nProduction.Product, DaysToManufacture, int\nProduction.Product, ProductLine, nchar\nProduction.Product, Class, nchar\nProduction.Product, Style, nchar\nProduction.Product, ProductSubcategoryID, int\nProduction.Product, ProductModelID, int\nProduction.Product, SellStartDate, datetime\nProduction.Product, SellEndDate, datetime\nProduction.Product, DiscontinuedDate, datetime\nProduction.Product, rowguid, uniqueidentifier\nProduction.Product, ModifiedDate, datetime\nSales.SalesOrderDetail, SalesOrderID, int\nSales.SalesOrderDetail, SalesOrderDetailID, int\nSales.SalesOrderDetail, CarrierTrackingNumber, nvarchar\nSales.SalesOrderDetail, OrderQty, smallint\nSales.SalesOrderDetail, ProductID, int\nSales.SalesOrderDetail, SpecialOfferID, int\nSales.SalesOrderDetail, UnitPrice, money\nSales.SalesOrderDetail, UnitPriceDiscount, money\nSales.SalesOrderDetail, LineTotal, numeric\nSales.SalesOrderDetail, rowguid, uniqueidentifier\nSales.SalesOrderDetail, ModifiedDate, datetime"}, 33 | {"role": "assistant", "content": "{\"recipient\":\"SERVER\", \"action\":\"QUERY\", \"message\":\"SELECT TOP 1 Production.Product.Name FROM Production.Product JOIN Sales.SalesOrderDetail ON Product.ProductID = SalesOrderDetail.ProductID JOIN Sales.SalesOrderHeader ON SalesOrderDetail.SalesOrderID = SalesOrderHeader.SalesOrderID WHERE YEAR(SalesOrderHeader.OrderDate) = 2014 GROUP BY Product.Name ORDER BY SUM(OrderQty) DESC;"}, 34 | {"role": "user", "content": "Name\nWater Bottle - 30 oz."}, 35 | {"role": "assistant", "content": "{\"recipient\": \"USER\", \"message\":\"The top selling product by quantity in 2014 is the Water Bottle - 30 oz.\"}."} 36 | ] 37 | 38 | def __init__(self, api_key, api_org = "", model = "gpt-3.5-turbo"): 39 | if api_org: 40 | openai.api_key 41 | openai.api_key = api_key 42 | self.model = model 43 | self.messages = self.startMessageStack.copy() 44 | 45 | def message(self, message, sender): 46 | logging.debug(message) 47 | if (sender): 48 | message = json.dumps({'message':message, 'sender':sender}) 49 | self.messages.append({"role": "user", "content": message}) 50 | completion = openai.ChatCompletion.create( 51 | model=self.model, 52 | messages=self.messages 53 | ) 54 | response = completion.choices[0].message.content 55 | logging.debug(response) 56 | self.messages.append({"role": "assistant", "content": response}) 57 | return response 58 | 59 | def reset(self): 60 | self.messages = self.startMessageStack.copy() 61 | print('model was reset to intial state') 62 | 63 | --------------------------------------------------------------------------------