├── .gitattributes ├── .github └── workflows │ └── main.yml ├── .gitignore ├── README.md ├── api ├── .env.example ├── .gitignore ├── README.md ├── app │ ├── __init__.py │ ├── api │ │ ├── __init__.py │ │ ├── chat_gpt_plugin.py │ │ ├── chronjobs │ │ │ └── test_homepage_queries.py │ │ ├── discoverability_routes.py │ │ ├── routes.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── cached_queries │ │ │ └── featured_queries.py │ │ │ ├── caesar_logging.py │ │ │ ├── classification │ │ │ └── input_classification.py │ │ │ ├── few_shot_examples.py │ │ │ ├── geo_data.py │ │ │ ├── logging │ │ │ └── sentry.py │ │ │ ├── messages.py │ │ │ ├── sql_explanation │ │ │ ├── __init__.py │ │ │ └── sql_explanation.py │ │ │ ├── sql_gen │ │ │ ├── __init__.py │ │ │ ├── prompts.py │ │ │ ├── sql_helper.py │ │ │ ├── tests │ │ │ │ ├── __init__.py │ │ │ │ └── test_txt_to_sql.py │ │ │ ├── text_to_sql.py │ │ │ └── text_to_sql_chat.py │ │ │ ├── suggestions │ │ │ ├── __init__.py │ │ │ └── suggestions.py │ │ │ └── table_selection │ │ │ ├── __init__.py │ │ │ ├── table_details.py │ │ │ └── table_selection.py │ ├── config.py │ ├── data │ │ ├── city_lat_lon.json │ │ ├── few_shot_examples.json │ │ ├── sf_analysis_neighborhoods.json │ │ ├── sf_neighborhoods.json │ │ ├── sf_tables.json │ │ ├── tables.json │ │ ├── tables_many.json │ │ └── zip_lat_lon.json │ └── extensions.py ├── discordbot │ ├── bot.py │ └── responses.py ├── requirements.txt └── scripts │ ├── dev.sh │ └── setup.sh ├── byod ├── .gitignore ├── README.md ├── api │ ├── .gitignore │ ├── README.md │ ├── app │ │ ├── .env.example │ │ ├── __init__.py │ │ ├── config.py │ │ ├── extensions.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── in_context_examples.py │ │ │ ├── json │ │ │ │ ├── in_context_examples.json │ │ │ │ ├── table_metadata.json │ │ │ │ └── type_metadata.json │ │ │ ├── table_metadata.py │ │ │ └── type_metadata.py │ │ ├── setup │ │ │ ├── __init__.py │ │ │ ├── routes.py │ │ │ └── utils.py │ │ ├── sql_explanation │ │ │ ├── __init__.py │ │ │ ├── routes.py │ │ │ └── utils.py │ │ ├── sql_generation │ │ │ ├── __init__.py │ │ │ ├── routes.py │ │ │ └── utils.py │ │ ├── table_selection │ │ │ ├── __init__.py │ │ │ ├── routes.py │ │ │ └── utils.py │ │ ├── utils.py │ │ └── visualization │ │ │ ├── __init__.py │ │ │ ├── routes.py │ │ │ └── utils.py │ ├── requirements.txt │ └── scripts │ │ ├── dev.sh │ │ └── setup.sh └── client │ ├── .env.example │ ├── .gitignore │ ├── README.md │ ├── app.py │ ├── config.py │ ├── requirements.txt │ └── scripts │ ├── dev.sh │ └── setup.sh ├── client └── censusGPT │ ├── .env.production │ ├── .eslintignore │ ├── .eslintrc.json │ ├── .gitignore │ ├── .prettierignore │ ├── .prettierrc.json │ ├── README.md │ ├── package-lock.json │ ├── package.json │ ├── pnpm-lock.yaml │ ├── postcss.config.js │ ├── public │ ├── favicon.ico │ ├── index.html │ ├── logo192.png │ ├── logo512.png │ ├── manifest.json │ ├── mapbox-sample.png │ ├── official_logo.png │ └── robots.txt │ ├── src │ ├── App.js │ ├── SanFrancisco.js │ ├── components │ │ ├── banner.js │ │ ├── dataPlot.js │ │ ├── disclaimer.js │ │ ├── error.js │ │ ├── exampleCard.js │ │ ├── examples.js │ │ ├── examplesFeed.js │ │ ├── explanationModal.js │ │ ├── header.js │ │ ├── headerButtons.js │ │ ├── loadingSpinner.js │ │ ├── results │ │ │ ├── dataVisualization.js │ │ │ ├── resultsContainer.js │ │ │ └── sqlDisplay.js │ │ ├── searchBar.js │ │ ├── suggestion.js │ │ ├── table.js │ │ ├── toast.js │ │ └── vizSelector.js │ ├── contexts │ │ └── feedContext.js │ ├── css │ │ ├── App.css │ │ ├── index.css │ │ └── mapbox-gl.css │ 
├── featureFlags.js │ ├── index.js │ ├── logo.svg │ ├── misc │ │ ├── privacy.js │ │ └── tos.js │ ├── reportWebVitals.js │ ├── setupTests.js │ ├── utils │ │ ├── loggers │ │ │ ├── posthog.js │ │ │ └── sentry.js │ │ ├── mapbox-ui-config.js │ │ ├── plotly-ui-config.js │ │ ├── sf_analysis_neighborhoods.js │ │ ├── user.js │ │ └── utils.js │ └── vitals.js │ ├── tailwind.config.js │ └── yarn.lock ├── data └── README.md └── license.md /.gitattributes: -------------------------------------------------------------------------------- 1 | api/app/data/usa_zip_code_boundaries.json filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: JS Pipeline 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | jobs: 7 | test: 8 | name: Check the source code 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Install packages 13 | run: npm ci 14 | working-directory: ./client/censusGPT 15 | - name: Prettier 16 | run: npm run format:fix 17 | working-directory: ./client/censusGPT 18 | - name: Lint 19 | run: npm run lint:fix 20 | working-directory: ./client/censusGPT 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /client/censusGPT/node_modules 2 | .idea 3 | .DS_Store 4 | client/discord/hoops/.env -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 🚨 Check out the latest project from the creators of textSQL: [Julius.ai](https://julius.ai?utm_source=github&utm_campaign=textSQL) 🚨 2 | 3 | ### 4 | 5 | # Natural Language → SQL 6 | 7 | ### 8 | 9 | :bridge_at_night: Demo on San Francisco City Data: [SanFranciscoGPT.com](http://sanfranciscogpt.com) 10 | 11 | :us: Demo on US Census Data: [CensusGPT.com](https://censusgpt.com) 12 | 13 | 14 |

15 | [SanFranciscoGPT](http://sanfranciscogpt.com) • 16 | [CensusGPT](https://censusgpt.com) • 17 | [Join the Discord Server](https://discord.gg/JZtxhZQQus) 18 |

19 | 20 | 21 |

22 | PRs Welcome 23 | 24 | Github Stars 25 | License 26 | GitHub commit activity 27 |

28 | 29 | Welcome to textSQL, a project that uses LLMs to democratize access to data analysis. Example use cases of textSQL are San Francisco GPT and CensusGPT — natural language interfaces to public data (SF city data and US census data), enabling anyone to analyze and gain insights from the data. 30 | 31 | Screenshot 2023-03-10 at 12 55 44 AM 32 | 33 | ## :thinking: How it works: 34 | With CensusGPT, you can ask any question related to census data in natural language. 35 | 36 | These natural language questions get converted to SQL using GPT-3.5 and are then used to query the census database. 37 | 38 | Here are some examples: 39 | 40 | * [🔍 Five cities with a population over 100,000 and lowest crime](https://censusgpt.com/?s=five%20cities%20with%20a%20population%20over%20100%2C000%20and%20lowest%20crime) 41 | * [🔍 10 highest income areas in california](https://censusgpt.com/?s=10%20highest%20income%20areas%20in%20california) 42 | 43 | Here is a similar example from sfGPT: 44 | 45 | * [🔍 Which four neighborhoods had the most crime in San Francisco in 2021?](https://censusgpt.com/sf?s=Which+four+neighborhoods+had+the+most+crime+in+San+Francisco+in+2021%3F) 46 | 47 | 48 | #### Diagram: 49 | 50 | ![TextSQL diagram](https://raw.githubusercontent.com/zafileo23/textSQL/zafileo23-patch-2/TextSQL.svg) 51 | 52 | 53 | ## :world_map: Roadmap: 54 | 55 | We're splitting the roadmap for this project broadly into two categories: 56 | 57 | 58 | ### 1. Visualizations: 59 | 60 | Currently, textSQL only supports visualizing zip codes and cities on an interactive map and bar chart using [Mapbox](https://www.mapbox.com/) + [Plotly](https://plotly.com/). But data can be visualized in other interesting ways, such as heatmaps and pie charts. Not every kind of data can be (or should be) visualized on a map. For example, a query like _"What percent of total crime in San Francisco is burglary vs in New York City"_ is perfect for visualizing as a stacked bar chart, but really hard to visualize on a map. 61 | 62 | Bar Chart: 63 | 64 | Top 5 richest cities in Washington 65 | 66 | [coming soon] Heatmap: 67 | 68 | Screenshot 2023-03-10 at 12 58 33 AM 69 | 70 | [coming soon] Visualization-GPT: A way to create and iterate on data visualizations in natural language through a text-to-vega engine. 71 | 72 | ### 2. 🔌 Text-to-SQL BYOD (Bring Your Own Data) [here](https://github.com/caesarHQ/textSQL/tree/main/byod) 73 | 74 | 75 | You can now connect your own database & datasets to textSQL and self-host the service. Our vision is to continue to modularize and improve this process. 76 | 77 | #### Use cases 78 | 79 | - Public-facing interactive interfaces for data. Democratizing public data 80 | - Empowering researchers. Enabling journalists and other researchers to more easily explore data 81 | - Business intelligence. Reducing the burden on technical employees to build & run queries for non-technical colleagues 82 | 83 | 84 | Setup instructions for BYOD are [here](https://github.com/caesarHQ/textSQL/tree/main/byod). 85 | 86 | 87 | ## :pencil: Additional Notes 88 | 89 | #### Datasets: 90 | 91 | Many users of this project have asked for additional data for both CensusGPT and sfGPT — historical census data (trends), weather, health, transportation, and real-estate data. Feel free to create a pull request, drop a link to your dataset in our [Discord](https://discord.gg/JZtxhZQQus), or contribute data via our [dedicated submission form](https://airtable.com/shrDKRRGyRCihWEZd). 
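If you want to sanity-check the pipeline end-to-end against a local instance (for example, after wiring up a new dataset), the two-step flow described in "How it works" can be exercised directly over HTTP. This is a minimal sketch based on the request and response shapes used by `api/app/api/chat_gpt_plugin.py` and the health-check script in `api/app/api/chronjobs/`; the base URL is an assumption — point it at wherever `./scripts/dev.sh` serves the Flask app:

```python
import requests

API = "http://localhost:5000"  # assumption: your local API instance

question = "Three highest income zip codes in San Jose"

# Step 1: ask the API which tables are relevant to the question
tables = requests.post(
    f"{API}/api/get_tables",
    json={"natural_language_query": question},
).json()

# Step 2: generate SQL against those tables and execute it
answer = requests.post(
    f"{API}/api/text_to_sql",
    json={
        "natural_language_query": question,
        "table_names": tables["table_names"],
    },
).json()

print(answer["sql_query"])          # the generated SQL
print(answer["result"]["results"])  # rows as column -> value mappings
```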
92 | 93 | More data → Better CensusGPT and sfGPT 94 | 95 | #### Query Building: 96 | 97 | Users build complex queries progressively. They start with a simple query like _"Which neighborhoods in LA have the best schools?"_ and then progressively add details like _"with median income that is under $100,000"_. One of the most powerful aspects of textSQL is enabling iterating on a query as a process of uncovering insights. 98 | 99 | ### 100 | 101 | ## :computer: How to Contribute: 102 | 103 | Join our [discord](https://discord.gg/JZtxhZQQus) 104 | 105 | ReadMe for the backend [here](https://github.com/caesarHQ/textSQL/blob/main/api/README.md) 106 | 107 | ReadMe for the frontend [here](https://github.com/caesarHQ/textSQL/blob/main/client/censusGPT/README.md) 108 | 109 | 110 | 111 | 112 | 113 | ### 114 | 115 | **Note:** Census data, like any other dataset, has its limitations and potential biases. Some data may not be collected or reported uniformly across different regions or time periods, which can affect the comparability of results. Users should keep these limitations in mind when interpreting the results of their queries and exercise caution when making decisions based on census data. 116 | -------------------------------------------------------------------------------- /api/.env.example: -------------------------------------------------------------------------------- 1 | OPENAI_KEY="" 2 | PINECONE_KEY="" 3 | PINECONE_ENV="" 4 | DB_URL="postgresql://census_data_user:3PjePE3hVzm2m2UFPywLTLfIiC6w28HB@dpg-cg73gvhmbg5ab7mrk8qg-b.replica-cyan.oregon-postgres.render.com/census_data_w0ix" -------------------------------------------------------------------------------- /api/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Scratch 132 | scratch/ 133 | 134 | settings.json -------------------------------------------------------------------------------- /api/README.md: -------------------------------------------------------------------------------- 1 | # API for textSQL 2 | 3 | ## Prerequisites 4 | - `python3.10` 5 | 6 | ## Required configuration for development: 7 | - OpenAI Key 8 | - URL to the postgres DB (Read-only URL provided in `.env.example`) 9 | 10 | Make a copy of `.env.example`, rename it to `.env`, and configure the above variables. 11 | 12 | ## Local development 13 | 14 | Initial setup 15 | ```sh 16 | $ ./scripts/setup.sh 17 | ``` 18 | 19 | Activate virtual env 20 | ```sh 21 | $ source ./venv/bin/activate 22 | ``` 23 | 24 | Run local instance 25 | ```sh 26 | $ ./scripts/dev.sh 27 | ``` 28 | 29 | ### Test case queries to test out prompt or other API-related changes: 30 | 31 | ``` 32 | - Three highest income zip codes in {City} 33 | - Five zip codes in {State} with the highest income and hispanic population of at least 10,000 34 | - Which 3 zip codes in {City} have the highest female to male ratio? 35 | - Which zip code in {City} has the most racial diversity and what is the racial distribution? 36 | - 10 highest crime cities in {State} 37 | - Which 20 zip codes in {State} have median income that's closest to the national median income? 38 | - Which zip code has a median income that is closest to the national median income? 
39 | ``` 40 | -------------------------------------------------------------------------------- /api/app/__init__.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, jsonify, make_response 2 | from flask_admin import Admin 3 | from flask_cors import CORS 4 | from flask_migrate import Migrate 5 | 6 | from app.api.routes import bp as api_bp 7 | from app.api.discoverability_routes import discoverability 8 | from app.config import FlaskAppConfig, ENV 9 | from app.extensions import db 10 | 11 | from app.api.chat_gpt_plugin import plugin, plugin_config 12 | 13 | 14 | def create_app(config_object=FlaskAppConfig): 15 | app = Flask(__name__) 16 | app.config.from_object(config_object) 17 | CORS(app) 18 | 19 | # Initialize app with extensions 20 | db.init_app(app) 21 | # migrate = Migrate(app, db) 22 | with app.app_context(): 23 | db.create_all() 24 | admin = Admin(None, name='admin', template_mode='bootstrap3') 25 | admin.init_app(app) 26 | 27 | @app.route("/ping") 28 | def ping(): 29 | return 'pong' 30 | 31 | 32 | app.register_blueprint(api_bp, url_prefix='/api') 33 | app.register_blueprint(plugin, url_prefix='/plugin') 34 | app.register_blueprint(discoverability, url_prefix='/examples') 35 | app.register_blueprint(plugin_config) 36 | 37 | # from app.errors import bp as errors_bp 38 | # app.register_blueprint(errors_bp) 39 | 40 | # from app.main import bp as main_bp 41 | # app.register_blueprint(main_bp) 42 | 43 | return app -------------------------------------------------------------------------------- /api/app/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/api/app/api/__init__.py -------------------------------------------------------------------------------- /api/app/api/chat_gpt_plugin.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from flask import Blueprint, jsonify, request 3 | 4 | plugin = Blueprint('plugin', __name__) 5 | 6 | @plugin.route('get_census_data', methods=['GET']) 7 | def get_census_data_get(): 8 | """ 9 | Get census data based on the question 10 | """ 11 | 12 | question = request.args.get('question') 13 | if not question: 14 | return jsonify({ 15 | "error": "question is missing from the request args" 16 | }) 17 | 18 | raw_table_json = requests.post('https://text-sql-be.onrender.com/api/get_tables', json={"natural_language_query": question}) 19 | 20 | print('raw table: ', raw_table_json.json()) 21 | 22 | newJson = { 23 | **raw_table_json.json(), 24 | "natural_language_query": question, 25 | } 26 | 27 | print('new json: ', newJson) 28 | 29 | final_res = requests.post('https://text-sql-be.onrender.com/api/text_to_sql', json=newJson) 30 | 31 | print('final res: ', final_res.json()) 32 | 33 | parsed = final_res.json() 34 | 35 | resultsData = parsed['result']['results'] 36 | sqlData = parsed['sql_query'] 37 | 38 | return jsonify({ 39 | "answer": resultsData, 40 | "sql_query": sqlData, 41 | }) 42 | 43 | 44 | plugin_config = Blueprint('plugin_config', __name__) 45 | @plugin_config.route("/.well-known/ai-plugin.json") 46 | def openapi_json(): 47 | current_domain = request.host_url 48 | print('request to :', current_domain, '\n') 49 | return jsonify({ 50 | "schema_version": "v1", 51 | "name_for_human": "censusGPT", 52 | "name_for_model": "census_data_and_sql_queries", 53 | "description_for_human": "censusGPT", 54 | 
"description_for_model": "CensusGPT provides information derived from the 2020 US Census. The data is provided as a chart along with the SQL query used to calculate it.", 55 | "auth": { 56 | "type": "none" 57 | }, 58 | "api": { 59 | "type": "openapi", 60 | "url": f"{current_domain}openapi.yaml", 61 | "is_user_authenticated": False 62 | }, 63 | "logo_url":"https://censusgpt.com/logo192.png", 64 | "contact_email": "rahul@caesarhq.com", 65 | "legal_info_url": "https://censusgpt.com/privacy" 66 | }) 67 | 68 | @plugin_config.route("/openapi.yaml") 69 | def openapi_yaml(): 70 | current_domain = request.host_url 71 | return f"""components: 72 | schemas: 73 | CensusResponse: 74 | properties: 75 | answer: 76 | description: the answer to the question 77 | type: string 78 | sql_query: 79 | description: the sql query that generated the answer 80 | type: string 81 | type: object 82 | info: 83 | description: Data from the US Census 84 | title: Census Data 85 | version: '1.0' 86 | openapi: 3.0.2 87 | paths: 88 | /plugin/get_census_data: 89 | get: 90 | consumes: 91 | - application/json 92 | description: Provide census data for questions about populations and other location 93 | information. This provides data from the 2020 census. 94 | operationId: getCensusData 95 | parameters: 96 | - description: the question to ask the census data 97 | in: query 98 | name: question 99 | required: false 100 | schema: 101 | type: string 102 | responses: 103 | default: 104 | description: '' 105 | schema: 106 | $ref: '#/components/schemas/CensusResponse' 107 | tags: 108 | - Census Data 109 | options: 110 | consumes: 111 | - application/json 112 | description: Provide census data for questions about populations and other location 113 | information. This provides data from the 2020 census. 
114 | operationId: getCensusData 115 | parameters: 116 | - description: the question to ask the census data 117 | in: query 118 | name: question 119 | required: false 120 | schema: 121 | type: string 122 | responses: 123 | default: 124 | description: '' 125 | schema: 126 | $ref: '#/components/schemas/CensusResponse' 127 | tags: 128 | - Census Data 129 | servers: 130 | - url: {current_domain}""" -------------------------------------------------------------------------------- /api/app/api/chronjobs/test_homepage_queries.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import getenv 3 | from dotenv import load_dotenv 4 | from sqlalchemy import create_engine 5 | import requests 6 | from sqlalchemy import text 7 | 8 | load_dotenv() 9 | 10 | EVENTS_URL = getenv("EVENTS_URL") 11 | 12 | endpoint = 'https://text-sql-be2.onrender.com' 13 | 14 | def queryTextToTables(payload): 15 | headers = {'Content-Type': 'application/json'} 16 | res = requests.post(endpoint + '/api/get_tables', json=payload, headers=headers) 17 | return res.json() 18 | 19 | def queryTextToSQL(payload): 20 | headers = {'Content-Type': 'application/json'} 21 | res = requests.post(endpoint + '/api/text_to_sql', json=payload, headers=headers) 22 | return res.json() 23 | 24 | scope = ['SF'] 25 | 26 | good = [] 27 | bad = [] 28 | attempted = [] 29 | results = [] 30 | 31 | # go thru and test that each of the generate tables -> generate SQL works 32 | def testQueryWorks(query, scope): 33 | global good, bad, attempted, results 34 | print('trying ', query) 35 | attempted.append(query) 36 | try: 37 | payload = { 38 | "natural_language_query": query, 39 | "scope": scope 40 | } 41 | res = queryTextToTables(payload) 42 | print('tables: ', res['table_names']) 43 | payload = { 44 | "table_names": res['table_names'], 45 | "natural_language_query": query, 46 | "scope": scope 47 | } 48 | res2 = queryTextToSQL(payload) 49 | print(len(res2['result']['column_names']), ' columns') 50 | print(len(res2['result']['results']), 'rows') 51 | print('SQL query: \n', res2['sql_query']) 52 | print('\n \n---- \n \n') 53 | good.append(query) 54 | results.append({'q': query, 55 | 'columns': len(res2['result']['column_names']), 56 | 'rows': len(res2['result']['results']), 57 | 'sql': res2['sql_query'] 58 | }) 59 | 60 | 61 | except Exception as e: 62 | print('Failure!', str(e)) 63 | bad.append({'q': query, 'e':str(e)}) 64 | 65 | # test the SF homepage queries 66 | if 'SF' in scope: 67 | queries = [ 68 | 'plz Show me all the needles in SF', 69 | 'plz Show me all the muggings', 70 | 'plz Which two neighborhoods have the most homeless activity?', 71 | 'plz Which five neighborhoods have the most poop on the street?', 72 | 'plz Which four neighborhoods had the most crime incidents involving guns or knives in 2021?', 73 | 'plz 3 neighborhoods with the highest female to male ratio', 74 | 'plz What are the top 5 neighborhoods with the most encampments per capita?', 75 | 'plz What hours of the day do most burglaries occur?', 76 | ] 77 | for q in queries: 78 | testQueryWorks(q, "SF") 79 | 80 | print('good: ', len(good)) 81 | print('bad: ', len(bad)) 82 | print('attempted: ', len(attempted)) 83 | print('results: ', results) 84 | 85 | EVENTS_ENGINE = create_engine(EVENTS_URL) 86 | 87 | params = { 88 | 'app_name': 'sf_prod', 89 | 'passed': len(good), 90 | 'failed': len(bad), 91 | 'attempted': len(attempted), 92 | 'percent_passing': 0 if len(attempted) == 0 else len(good)/len(attempted), 93 | 'result_stats': 
json.dumps(results) 94 | } 95 | 96 | insert_query = text(""" 97 | INSERT INTO health_checks (app_name, passed, failed, attempted, percent_passing, result_stats) 98 | VALUES (:app_name, :passed, :failed, :attempted, :percent_passing, :result_stats)""") 99 | 100 | with EVENTS_ENGINE.connect() as conn: 101 | conn.execute(insert_query, params) 102 | conn.commit() -------------------------------------------------------------------------------- /api/app/api/discoverability_routes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discoverability routes 3 | 4 | This module contains the routes that are used to provide the feed and other discoverability information. 5 | """ 6 | 7 | from flask import Blueprint, jsonify, make_response 8 | from app.config import EVENTS_ENGINE 9 | from app.api.utils.caesar_logging import get_feed_data 10 | 11 | discoverability = Blueprint('discoverability', __name__) 12 | 13 | # discoverability is a get endpoint, takes a /{app} 14 | @discoverability.route('/', methods=['GET']) 15 | def get_discoverability(app): 16 | """ 17 | Get discoverability information for the app 18 | """ 19 | 20 | if not EVENTS_ENGINE: 21 | return make_response(jsonify({ 22 | "success": False, 23 | "error": "Events engine not configured" 24 | }), 200) 25 | 26 | feed_data = get_feed_data(app) 27 | if feed_data: 28 | return make_response(jsonify({ 29 | "success": True, 30 | "examples": feed_data 31 | }), 200) 32 | 33 | return make_response(jsonify({ 34 | "success": False, 35 | "error": "Not implemented" 36 | }), 200) 37 | 38 | -------------------------------------------------------------------------------- /api/app/api/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/api/app/api/utils/__init__.py -------------------------------------------------------------------------------- /api/app/api/utils/cached_queries/featured_queries.py: -------------------------------------------------------------------------------- 1 | from app.config import EVENTS_ENGINE 2 | from sqlalchemy import text 3 | 4 | def get_featured_table(input_str, scope="USA"): 5 | 6 | if not EVENTS_ENGINE: 7 | return False 8 | 9 | print('CHECKING SCOPE', scope) 10 | 11 | params = { 12 | "input_text": input_str, 13 | "scope": scope 14 | } 15 | query = text(""" 16 | SELECT * FROM featured_queries 17 | WHERE input_text = :input_text 18 | AND app = :scope 19 | """) 20 | try: 21 | with EVENTS_ENGINE.connect() as conn: 22 | result = conn.execute(query, params) 23 | conn.commit() 24 | res = result.fetchall() 25 | 26 | except Exception as e: 27 | return False 28 | 29 | if len(res) == 0: 30 | return False 31 | 32 | related_tables = res[0][1] 33 | 34 | return related_tables 35 | 36 | def get_featured_sql(input_str, scope="USA"): 37 | 38 | if not EVENTS_ENGINE: 39 | return False 40 | 41 | params = { 42 | "input_text": input_str, 43 | "scope": scope 44 | } 45 | query = text(""" 46 | SELECT * FROM featured_queries 47 | WHERE input_text ilike :input_text 48 | AND app = :scope 49 | """) 50 | 51 | try: 52 | with EVENTS_ENGINE.connect() as conn: 53 | result = conn.execute(query, params) 54 | conn.commit() 55 | res = result.fetchall() 56 | except Exception as e: 57 | return False 58 | 59 | if len(res) == 0: 60 | return False 61 | 62 | related_sql = res[0][2] 63 | 64 | return related_sql -------------------------------------------------------------------------------- 
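The two cache helpers above return `False` on a miss (or on any database error), which makes them easy to gate on before paying for a model call. A minimal sketch of how they are meant to compose with the generation path — the `answer_query` wrapper is hypothetical, while `use_cached_sql` and `text_to_sql_with_retry` are defined in `sql_gen/text_to_sql.py` further down:

```python
# Hypothetical wrapper (not part of the repo's routes) showing the intended
# composition of the featured-query cache with model-based SQL generation.
from app.api.utils.cached_queries.featured_queries import get_featured_sql
from app.api.utils.sql_gen.text_to_sql import text_to_sql_with_retry, use_cached_sql


def answer_query(natural_language_query, table_names, scope="USA"):
    cached_sql = get_featured_sql(natural_language_query, scope=scope)
    if cached_sql:
        # Featured hit: skip the model entirely and execute the stored SQL.
        return use_cached_sql(cached_sql), cached_sql
    # Miss: fall back to model-generated SQL with up to 3 attempts.
    return text_to_sql_with_retry(natural_language_query, table_names, scope=scope)
```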
/api/app/api/utils/classification/input_classification.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from app.config import EVENTS_ENGINE 4 | from app.api.utils.messages import call_chat 5 | from app.api.utils.caesar_logging import log_input_classification 6 | 7 | from app.api.utils.table_selection.table_details import get_minimal_table_schemas 8 | 9 | async def create_labels(user_input, scope="USA", parent_id=None, session_id=None): 10 | """ 11 | Create classification labels for the user input and return the logging generation id 12 | """ 13 | 14 | if not EVENTS_ENGINE: 15 | return None 16 | 17 | table_prefix = get_minimal_table_schemas(scope) 18 | 19 | user_message = f"""The user asked our database for: 20 | ---- 21 | {user_input} 22 | ---- 23 | 24 | Our schema has the following tables (here's parts of the script to create them): 25 | --- 26 | {table_prefix} 27 | --- 28 | 29 | give me a JSON object for classifying it in our database as well as if we have it. The object needs to consist of 30 | {{ 31 | topics: str[], 32 | categories: str[], 33 | locations: str[], 34 | relevant_tables_from_schema: str[], 35 | has_relevant_table: bool, 36 | }} 37 | Thanks! Provide the JSON and only the JSON. Values should be in all lowercase.""" 38 | 39 | messages = [{"role": "user", "content": user_message}] 40 | 41 | assistant_message = call_chat(messages, model="gpt-3.5-turbo", scope=scope, purpose="input_classification", session_id=session_id) 42 | 43 | try: 44 | parsed = json.loads(assistant_message) 45 | except Exception: 46 | parsed = {} 47 | 48 | generation_id = log_input_classification(scope, user_input, parsed, parent_id, session_id) 49 | 50 | # is_relevant_query = parsed.get("has_relevant_table", False) 51 | 52 | return generation_id -------------------------------------------------------------------------------- /api/app/api/utils/few_shot_examples.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | 4 | 5 | few_shot_examples = {} 6 | with open("app/data/few_shot_examples.json", "r") as f: 7 | few_shot_examples = json.load(f) 8 | 9 | 10 | def get_few_shot_example_messages(mode: str = "text_to_sql", scope="USA", n=-1) -> List[dict]: 11 | examples = few_shot_examples.get(scope, {}).get(mode, []) 12 | if n > 0: 13 | examples = examples[:n] 14 | if n == 0: 15 | examples = [] 16 | messages = [] 17 | for example in examples: 18 | messages.append({ 19 | "role": "user", 20 | "content": example["user"], 21 | }) 22 | messages.append({ 23 | "role": "assistant", 24 | "content": example["assistant"], 25 | }) 26 | return messages -------------------------------------------------------------------------------- /api/app/api/utils/geo_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | zip_lat_lon = {} 5 | with open("app/data/zip_lat_lon.json", "r") as f: 6 | zip_lat_lon = json.load(f) 7 | 8 | city_lat_lon = {} 9 | with open("app/data/city_lat_lon.json", "r") as f: 10 | city_lat_lon = json.load(f) 11 | 12 | neighborhood_shapes = {} 13 | # with open("app/data/sf_neighborhoods.json", "r") as f: 14 | with open("app/data/sf_analysis_neighborhoods.json", "r") as f: 15 | neighborhood_shapes = json.load(f) -------------------------------------------------------------------------------- /api/app/api/utils/logging/sentry.py: -------------------------------------------------------------------------------- 1 | from sentry_sdk import capture_exception 2 | 3 | from 
app.config import SENTRY_URL 4 | 5 | def log_sentry_exception(e): 6 | if SENTRY_URL: 7 | capture_exception(e) 8 | -------------------------------------------------------------------------------- /api/app/api/utils/messages.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import time 4 | from typing import List, Dict 5 | 6 | import openai 7 | import tiktoken 8 | 9 | from app.api.utils.caesar_logging import log_apicall 10 | 11 | 12 | def get_assistant_message_from_openai( 13 | messages: List[Dict[str, str]], 14 | temperature: int = 0, 15 | model: str = "gpt-3.5-turbo", 16 | scope: str = "USA", 17 | purpose: str = "Generic", 18 | session_id: str = None, 19 | test_failure: bool = False, 20 | # model: str = "gpt-4", 21 | ): 22 | # alright, it looks like gpt-3.5-turbo is ignoring the user messages in history 23 | # let's go and re-create the chat in the last message! 24 | final_payload = messages 25 | 26 | start = time.time() 27 | try: 28 | if test_failure: 29 | raise Exception("Test failure") 30 | res = openai.ChatCompletion.create( 31 | model=model, 32 | temperature=temperature, 33 | messages=final_payload 34 | ) 35 | except Exception as e: 36 | duration = time.time() - start 37 | log_apicall( 38 | duration, 39 | 'openai', 40 | model, 41 | 0, 42 | 0, 43 | scope, 44 | purpose, 45 | session_id = session_id, 46 | success=False, 47 | log_message = str(e), 48 | ) 49 | raise e 50 | duration = time.time() - start 51 | 52 | usage = res['usage'] 53 | input_tokens = usage['prompt_tokens'] 54 | output_tokens = usage['completion_tokens'] 55 | 56 | log_apicall( 57 | duration, 58 | 'openai', 59 | model, 60 | input_tokens, 61 | output_tokens, 62 | scope, 63 | purpose, 64 | session_id = session_id, 65 | ) 66 | 67 | # completion = res['choices'][0]["message"]["content"] 68 | assistant_message = res['choices'][0] 69 | 70 | return assistant_message 71 | 72 | def call_chat( 73 | messages: List[Dict[str, str]], 74 | temperature: int = 0, 75 | model: str = "gpt-3.5-turbo", 76 | scope: str = "USA", 77 | purpose: str = "Generic", 78 | session_id: str = None, 79 | # model: str = "gpt-4", 80 | ): 81 | 82 | start = time.time() 83 | try: 84 | res = openai.ChatCompletion.create( 85 | model=model, 86 | temperature=temperature, 87 | messages=messages 88 | ) 89 | except Exception as e: 90 | duration = time.time() - start 91 | log_apicall( 92 | duration, 93 | 'openai', 94 | model, 95 | 0, 96 | 0, 97 | scope, 98 | purpose, 99 | session_id = session_id, 100 | success=False, 101 | log_message = str(e), 102 | ) 103 | raise e 104 | 105 | duration = time.time() - start 106 | 107 | usage = res['usage'] 108 | input_tokens = usage['prompt_tokens'] 109 | output_tokens = usage['completion_tokens'] 110 | 111 | log_apicall( 112 | duration, 113 | 'openai', 114 | model, 115 | input_tokens, 116 | output_tokens, 117 | scope, 118 | purpose, 119 | session_id = session_id, 120 | ) 121 | 122 | # completion = res['choices'][0]["message"]["content"] 123 | assistant_message = res['choices'][0]['message']['content'] 124 | 125 | return assistant_message 126 | 127 | def clean_sql_message_content(assistant_message_content): 128 | """ 129 | Cleans message content to extract the SQL query 130 | """ 131 | # Ignore text after the last SQL query terminator `;` 132 | parts = assistant_message_content.split(";") 133 | assistant_message_content = ";".join(parts[:-1]) 134 | 135 | # Remove prefix for corrected query assistant message 136 | split_corrected_query_message = 
assistant_message_content.split(":") 137 | if len(split_corrected_query_message) > 1: 138 | sql_query = split_corrected_query_message[1].strip() 139 | else: 140 | sql_query = assistant_message_content 141 | 142 | return sql_query 143 | 144 | 145 | def extract_sql_query_from_message(assistant_message_content): 146 | try: 147 | data = json.loads(assistant_message_content) 148 | except Exception as e: 149 | print('e: ', e) 150 | raise e 151 | 152 | if data.get('MissingData'): 153 | return data 154 | 155 | sql = data['SQL'] 156 | 157 | return {"SQL": sql} 158 | 159 | 160 | def extract_sql_from_markdown(assistant_message_content): 161 | regex = r"```([\s\S]+?)```" 162 | matches = re.findall(regex, assistant_message_content) 163 | 164 | if matches: 165 | code_str = matches[0] 166 | match = re.search(r"(?i)sql\s+(.*)", code_str, re.DOTALL) 167 | if match: 168 | code_str = match.group(1) 169 | else: 170 | code_str = assistant_message_content 171 | 172 | return code_str -------------------------------------------------------------------------------- /api/app/api/utils/sql_explanation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/api/app/api/utils/sql_explanation/__init__.py -------------------------------------------------------------------------------- /api/app/api/utils/sql_explanation/sql_explanation.py: -------------------------------------------------------------------------------- 1 | from ..few_shot_examples import get_few_shot_example_messages 2 | from ..messages import get_assistant_message_from_openai 3 | from app.config import DIALECT 4 | 5 | 6 | def get_message_with_descriptions(): 7 | message = ( 8 | f"Provide a concise explanation for the following {DIALECT}" 9 | " query: ```{sql}```" 10 | ) 11 | return message 12 | 13 | 14 | def get_default_messages(): 15 | default_messages = [{ 16 | "role": "system", 17 | "content": ( 18 | f"You are a helpful assistant for providing an explanation for a {DIALECT} query." 
19 | ) 20 | }] 21 | default_messages.extend(get_few_shot_example_messages(mode="sql_explanation")) 22 | return default_messages 23 | 24 | 25 | def get_sql_explanation(sql) -> str: 26 | """ 27 | Use language model to generate explanation of SQL query 28 | """ 29 | content = get_message_with_descriptions().format(sql=sql) 30 | messages = get_default_messages().copy() 31 | messages.append({ 32 | "role": "user", 33 | "content": content 34 | }) 35 | 36 | # model = "gpt-4" 37 | model = "gpt-3.5-turbo" 38 | 39 | assistant_message_content = get_assistant_message_from_openai( 40 | messages=messages, 41 | model=model, 42 | purpose="sql_explanation" 43 | )["message"]["content"] 44 | return assistant_message_content -------------------------------------------------------------------------------- /api/app/api/utils/sql_gen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/api/app/api/utils/sql_gen/__init__.py -------------------------------------------------------------------------------- /api/app/api/utils/sql_gen/sql_helper.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import OrderedDict 3 | from sqlalchemy import text 4 | 5 | from app.config import ENGINE 6 | from ..geo_data import city_lat_lon, neighborhood_shapes, zip_lat_lon 7 | 8 | 9 | 10 | 11 | class NotReadOnlyException(Exception): 12 | pass 13 | 14 | 15 | class CityOrCountyWithoutStateException(Exception): 16 | pass 17 | 18 | 19 | class NullValueException(Exception): 20 | pass 21 | 22 | 23 | def execute_sql(sql_query: str): 24 | if not is_read_only_query(sql_query): 25 | raise NotReadOnlyException("Only read-only queries are allowed.") 26 | 27 | with ENGINE.connect() as connection: 28 | connection = connection.execution_options( 29 | postgresql_readonly=True 30 | ) 31 | with connection.begin(): 32 | sql_text = text(sql_query) 33 | result = connection.execute(sql_text) 34 | 35 | column_names = list(result.keys()) 36 | if 'state' not in column_names and any(c in column_names for c in ['city', 'county']): 37 | raise CityOrCountyWithoutStateException("Include `state` in the result table, too.") 38 | 39 | rows = [list(r) for r in result.all()] 40 | 41 | # Add lat and lon to zip_code 42 | zip_code_idx = None 43 | try: 44 | zip_code_idx = column_names.index("zip_code") 45 | except ValueError: 46 | zip_code_idx = None 47 | 48 | if zip_code_idx is not None: 49 | column_names.append("lat") 50 | column_names.append("long") 51 | for row in rows: 52 | zip_code = row[zip_code_idx] 53 | lat = zip_lat_lon.get(zip_code, {}).get('lat') 54 | lon = zip_lat_lon.get(zip_code, {}).get('lon') 55 | row.append(lat) 56 | row.append(lon) 57 | 58 | # No zip_code lat lon, so try to get city lat lon 59 | else: 60 | # Add lat and lon to city 61 | city_idx = None 62 | state_idx = None 63 | try: 64 | city_idx = column_names.index("city") 65 | state_idx = column_names.index("state") 66 | except ValueError: 67 | city_idx = None 68 | state_idx = None 69 | 70 | if city_idx is not None and state_idx is not None: 71 | column_names.append("lat") 72 | column_names.append("long") 73 | for row in rows: 74 | city = row[city_idx] 75 | state = row[state_idx] 76 | lat = city_lat_lon.get(state, {}).get(city, {}).get('lat') 77 | lon = city_lat_lon.get(state, {}).get(city, {}).get('lon') 78 | 79 | if "St." 
in city: 80 | new_city = city.replace("St.", "Saint") 81 | lat = city_lat_lon.get(state, {}).get(new_city, {}).get('lat') 82 | lon = city_lat_lon.get(state, {}).get(new_city, {}).get('lon') 83 | 84 | row.append(lat) 85 | row.append(lon) 86 | 87 | results = [] 88 | for row in rows: 89 | result = OrderedDict() 90 | for i, column_name in enumerate(column_names): 91 | result[column_name] = row[i] 92 | results.append(result) 93 | 94 | return { 95 | 'column_names': column_names, 96 | 'results': results, 97 | } 98 | 99 | 100 | def is_read_only_query(sql_query: str): 101 | """ 102 | Checks if the given SQL query string is read-only. 103 | Returns True if the query is read-only, False otherwise. 104 | """ 105 | # List of SQL statements that modify data in the database 106 | modifying_statements = ["INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "GRANT", "TRUNCATE", "LOCK TABLES", "UNLOCK TABLES"] 107 | 108 | # Check if the query contains any modifying statements 109 | for statement in modifying_statements: 110 | if not sql_query or statement in sql_query.upper(): 111 | return False 112 | 113 | # If no modifying statements are found, the query is read-only 114 | return True 115 | -------------------------------------------------------------------------------- /api/app/api/utils/sql_gen/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/api/app/api/utils/sql_gen/tests/__init__.py -------------------------------------------------------------------------------- /api/app/api/utils/sql_gen/tests/test_txt_to_sql.py: -------------------------------------------------------------------------------- 1 | from typing import Union, OrderedDict, Any, List, Tuple, Callable, Dict 2 | 3 | import pytest 4 | 5 | from ...table_selection.table_selection import get_relevant_tables 6 | 7 | Res = Dict[str, Union[List[OrderedDict[str, Any]], List[str]]] 8 | 9 | inputs: List[Tuple[str, List[Callable[[List[OrderedDict[str, Any]]], bool]]]] = [ 10 | ( 11 | "What are the three highest income zip codes in San Jose", 12 | [lambda result: any(sub_result.get('zip_code', None) == '95113' for sub_result in result), lambda result: all(sub_result.get('zip_code', None) != '94105' for sub_result in result)], 13 | ), 14 | ( 15 | "10 highest crime cities in California", 16 | [lambda result: any(sub_result.get('city', '') == 'Los Angeles' for sub_result in result), 17 | lambda result: any(sub_result.get('city', '') != 'Los Gatos' for sub_result in result)] 18 | ) 19 | ] 20 | -------------------------------------------------------------------------------- /api/app/api/utils/sql_gen/text_to_sql.py: -------------------------------------------------------------------------------- 1 | from app.config import DIALECT 2 | import tiktoken 3 | from ..few_shot_examples import get_few_shot_example_messages 4 | from ..messages import extract_sql_query_from_message, get_assistant_message_from_openai 5 | from ..table_selection.table_details import get_table_schemas, get_table_and_enums 6 | from .prompts import get_retry_prompt 7 | from ..caesar_logging import log_sql_failure 8 | from .sql_helper import execute_sql 9 | 10 | MSG_WITH_ERROR_TRY_AGAIN = ( 11 | "Try again. " 12 | f"Only respond with valid {DIALECT}. Write your answer in JSON. " 13 | f"The {DIALECT} query you just generated resulted in the following error message:\n" 14 | "{error_message}" 15 | "Check the table schema and ensure that the columns for the table exist and will provide the expected results." 16 | ) 17 | 18 | def make_default_messages(schemas: str, scope="USA", n=-1): 19 | default_messages = [] 20 | 21 | default_messages.extend(get_few_shot_example_messages(mode="text_to_sql", scope=scope, n=n)) 22 | return default_messages 23 | 24 | 25 | def make_rephrase_msg_with_schema_and_warnings(): 26 | return ( 27 | "Let's start by rephrasing the query to be more analytical. Use the schema context to rephrase the user question in a way that leads to optimal query results: {natural_language_query}" 28 | "The following are schemas of tables you can query:\n" 29 | "---------------------\n" 30 | "{schemas}" 31 | "\n" 32 | "---------------------\n" 33 | "Do not include any of the table names in the query." 34 | " Ask the natural language query the way a data analyst, with knowledge of these tables, would." 35 | ) 36 | 37 | def text_to_sql_with_retry(natural_language_query, table_names, k=3, messages=None, scope="USA", session_id=None): 38 | """ 39 | Tries up to k times to generate valid SQL that answers a natural language query 40 | """ 41 | if scope == "SF": 42 | model = "gpt-3.5-turbo" 43 | else: 44 | model = "gpt-3.5-turbo" 45 | 46 | example_messages = [] 47 | enums_message = [{'role': 'user', 'content': ''}] 48 | schema_message = [{'role': 'user', 'content': ''}] 49 | message_history = [] 50 | 51 | if not messages: 52 | table_text, enum_text = get_table_and_enums(table_names, scope) 53 | 54 | schema_message[0]['content'] = table_text 55 | enums_message[0]['content'] = enum_text 56 | 57 | enc = tiktoken.encoding_for_model(model) 58 | 59 | instruction_length = len(enc.encode(table_text + '\n\n' + enum_text)) 60 | 61 | content = get_retry_prompt(DIALECT, natural_language_query, scope) 62 | 63 | max_messages = -1 64 | if instruction_length > 2000: 65 | max_messages = 1 66 | elif instruction_length > 1500: 67 | max_messages = 2 68 | elif instruction_length > 1000: 69 | max_messages = 3 70 | 71 | example_messages = make_default_messages('', scope, n=max_messages) 72 | message_history.append({ 73 | "role": "user", 74 | "content": content 75 | }) 76 | 77 | assistant_message = None 78 | sql_query = "" 79 | for attempt_number in range(k): 80 | sql_query_data = {} 81 | try: 82 | purpose = "text_to_sql" if attempt_number == 0 else "text_to_sql_retry" 83 | try: 84 | payload = schema_message + message_history 85 | if attempt_number == 0: 86 | payload = example_messages + enums_message + payload 87 | assistant_message = get_assistant_message_from_openai(payload, model=model, scope=scope, purpose=purpose, session_id=session_id) 88 | except Exception: 89 | continue 90 | 91 | sql_query_data = extract_sql_query_from_message(assistant_message["message"]["content"]) 92 | 93 | if sql_query_data.get('MissingData'): 94 | return {"MissingData": sql_query_data['MissingData']}, "" 95 | 96 | sql_query = sql_query_data["SQL"] 97 | 98 | response = execute_sql(sql_query) 99 | # Generated SQL query did not produce exception. 
Return result 100 | return response, sql_query 101 | 102 | except Exception as e: 103 | 104 | log_sql_failure(natural_language_query, sql_query_data.get('SQL', ""), str(e), attempt_number, scope, session_id=session_id) 105 | 106 | message_history.append({ 107 | "role": "assistant", 108 | "content": assistant_message["message"]["content"] 109 | }) 110 | message_history.append({ 111 | "role": "user", 112 | "content": MSG_WITH_ERROR_TRY_AGAIN.format(error_message=str(e)) 113 | }) 114 | 115 | print(f"Could not generate {DIALECT} query after {k} tries.") 116 | return None, sql_query 117 | 118 | 119 | def use_cached_sql(sql): 120 | return execute_sql(sql) -------------------------------------------------------------------------------- /api/app/api/utils/sql_gen/text_to_sql_chat.py: -------------------------------------------------------------------------------- 1 | from app.config import DIALECT 2 | 3 | 4 | from ..few_shot_examples import get_few_shot_example_messages 5 | from ..messages import get_assistant_message_from_openai 6 | from ..table_selection.table_details import get_table_schemas 7 | from .prompts import get_retry_prompt 8 | from .text_to_sql import text_to_sql_with_retry 9 | 10 | MSG_WITH_ERROR_TRY_AGAIN = ( 11 | "Try again. " 12 | f"Only respond with valid {DIALECT}. Write your answer in markdown format. " 13 | f"The {DIALECT} query you just generated resulted in the following error message:\n" 14 | "{error_message}" 15 | ) 16 | 17 | def make_default_messages(schemas: str, scope="USA"): 18 | default_messages = [] 19 | 20 | default_messages.extend(get_few_shot_example_messages(mode="text_to_sql", scope=scope)) 21 | return default_messages 22 | 23 | 24 | def make_rephrase_msg_with_schema_and_warnings(): 25 | return ( 26 | "Let's start by rephrasing the query to be more analytical. Use the schema context to rephrase the user question in a way that leads to optimal query results: {natural_language_query}" 27 | "The following are schemas of tables you can query:\n" 28 | "---------------------\n" 29 | "{schemas}" 30 | "\n" 31 | "---------------------\n" 32 | "Do not include any of the table names in the query." 33 | " Ask the natural language query the way a data analyst, with knowledge of these tables, would." 34 | ) 35 | 36 | class NoMessagesException(Exception): 37 | pass 38 | 39 | class LastMessageNotUserException(Exception): 40 | pass 41 | 42 | 43 | def text_to_sql_chat_with_retry(messages, table_names=None, scope="USA"): 44 | """ 45 | Takes a series of messages and tries to respond to a natural language query with valid SQL 46 | """ 47 | if not messages: 48 | raise NoMessagesException("No messages provided.") 49 | if messages[-1]["role"] != "user": 50 | raise LastMessageNotUserException("Last message is not a user message.") 51 | 52 | # First question, prime with table schemas and rephrasing 53 | natural_language_query = messages[-1]["content"] 54 | # Ask the assistant to rephrase before generating the query 55 | schemas = get_table_schemas(table_names, scope) 56 | rephrase = [{ 57 | "role": "user", 58 | "content": make_rephrase_msg_with_schema_and_warnings().format( 59 | natural_language_query=natural_language_query, 60 | schemas=schemas 61 | ) 62 | }] 63 | rephrased_query = get_assistant_message_from_openai(rephrase)["message"]["content"] 64 | 65 | content = get_retry_prompt(DIALECT, rephrased_query, schemas, scope) 66 | # Don't return messages_copy to the front-end. 
It contains extra information for prompting 67 | messages_copy = make_default_messages(schemas) 68 | messages_copy.extend(messages) 69 | messages_copy[-1] = { 70 | "role": "user", 71 | "content": content 72 | } 73 | 74 | # Send all messages 75 | response, sql_query = text_to_sql_with_retry(natural_language_query, table_names, k=3, messages=messages_copy, scope=scope) 76 | 77 | if response is None: 78 | messages.append({ 79 | "role": "assistant", 80 | "content": "Sorry, I wasn't able to answer that. Try rephrasing your question to make it more specific and easier to understand." 81 | }) 82 | 83 | else: 84 | messages.append({ 85 | "role": "assistant", 86 | "content": sql_query 87 | }) 88 | 89 | return response, sql_query, messages -------------------------------------------------------------------------------- /api/app/api/utils/suggestions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/api/app/api/utils/suggestions/__init__.py -------------------------------------------------------------------------------- /api/app/api/utils/suggestions/suggestions.py: -------------------------------------------------------------------------------- 1 | from ..messages import get_assistant_message_from_openai 2 | from ..few_shot_examples import get_few_shot_example_messages 3 | from ..table_selection.table_details import get_table_schemas 4 | from ..caesar_logging import log_suggested_query 5 | 6 | 7 | def _get_failed_query_suggestion_message(scope="USA", natural_language_query=""): 8 | return f""" 9 | The following is a natural language query that cannot be answered with available data: 10 | --------------------- 11 | {natural_language_query} 12 | --------------------- 13 | Suggest a different natural language query (as similar as possible to the one given) that can be answered with a SQL query and the available data. 14 | Only return the suggested natural language query. 15 | Avoid using table names and column names in the suggested query. 16 | The following are descriptions of available tables and enums: 17 | --------------------- 18 | {get_table_schemas(scope=scope)} 19 | --------------------- 20 | """ 21 | 22 | 23 | def _get_query_suggestion_message(scope="USA", natural_language_query=""): 24 | return f""" 25 | The following is a natural language query: 26 | --------------------- 27 | {natural_language_query} 28 | --------------------- 29 | Suggest a different natural language query, similar to the one given, that can be answered with a SQL query and the available data. 30 | Only return the suggested natural language query. 31 | If possible, build on top of the given query to generate deeper insights into the data available. 32 | The following are descriptions of available tables and enums: 33 | --------------------- 34 | {get_table_schemas(scope=scope)} 35 | --------------------- 36 | """ 37 | 38 | 39 | def _get_failed_query_suggestion_messages(scope="USA"): 40 | # default_messages = [{ 41 | # "role": "system", 42 | # "content": ( 43 | # f""" 44 | # Users come to you with a natural language query that cannot be answered with available data. 45 | # You are a helpful assistant for suggesting a different natural language query (as similar as possible to the one given) that can be answered with a SQL query and the available data. 46 | # Only return the suggested natural language query. 
47 | # Avoid using table names and column names in the suggested query. 48 | # The following are descriptions of available tables and enums: 49 | # --------------------- 50 | # {get_table_schemas(scope=scope)} 51 | # --------------------- 52 | # """ 53 | # ) 54 | # }] 55 | default_messages = [] 56 | default_messages.extend(get_few_shot_example_messages(mode="failed_query_suggestion", scope=scope)) 57 | return default_messages 58 | 59 | 60 | def _get_query_suggestion_messages(scope="USA"): 61 | # default_messages = [{ 62 | # "role": "system", 63 | # "content": ( 64 | # """ 65 | # Users come to you with a natural language query that has been answered from available data. 66 | # You are a helpful assistant for suggesting a different query, similar to the one given, that can be answered with a SQL query and the available data. 67 | # Only return the suggested natural language query. 68 | # If possible, build on top of the given query to generate deeper insights into the data available. 69 | # The following are descriptions of available tables and enums: 70 | # --------------------- 71 | # {get_table_schemas(scope=scope)} 72 | # --------------------- 73 | # """ 74 | # ) 75 | # }] 76 | default_messages = [] 77 | default_messages.extend(get_few_shot_example_messages(mode="query_suggestion", scope=scope)) 78 | return default_messages 79 | 80 | 81 | def generate_suggestion_failed_query(scope, failed_query, parent_id=None, session_id=None): 82 | """ 83 | Get suggested query based on failed query 84 | """ 85 | messages = _get_failed_query_suggestion_messages(scope) 86 | 87 | prompt = _get_failed_query_suggestion_message(scope, failed_query) 88 | 89 | messages.append({ 90 | "role": "user", 91 | "content": prompt 92 | }) 93 | 94 | model = "gpt-3.5-turbo" 95 | 96 | response = get_assistant_message_from_openai( 97 | messages=messages, 98 | model=model, 99 | scope="USA", 100 | purpose="failed_query_suggestion", 101 | session_id=session_id 102 | )["message"]["content"] 103 | suggested_query = response 104 | 105 | suggestion_id = log_suggested_query( 106 | input_text=failed_query, 107 | reason="failed_query_suggestion", 108 | parent_id=parent_id, 109 | suggested_query=suggested_query, 110 | app_name=scope, 111 | prompt=prompt, 112 | model=model, 113 | session_id=session_id 114 | ) 115 | 116 | return suggested_query, str(suggestion_id) 117 | 118 | 119 | def generate_suggestion(scope, failed_query, parent_id=None, session_id=None): 120 | """ 121 | Get suggested query to build on top of a given query or as a similar query 122 | """ 123 | messages = _get_query_suggestion_messages(scope) 124 | 125 | prompt = _get_query_suggestion_message(scope, failed_query) 126 | model = "gpt-3.5-turbo" 127 | 128 | messages.append({ 129 | "role": "user", 130 | "content": prompt 131 | }) 132 | response = get_assistant_message_from_openai( 133 | messages=messages, 134 | model=model, 135 | scope="USA", 136 | purpose="query_suggestion" 137 | )["message"]["content"] 138 | suggested_query = response 139 | 140 | suggestion_id = log_suggested_query( 141 | input_text=failed_query, 142 | reason="successful_query_suggestion", 143 | parent_id=parent_id, 144 | suggested_query=suggested_query, 145 | app_name=scope, 146 | prompt=prompt, 147 | model=model, 148 | session_id=session_id 149 | ) 150 | 151 | return suggested_query, str(suggestion_id) -------------------------------------------------------------------------------- /api/app/api/utils/table_selection/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/api/app/api/utils/table_selection/__init__.py -------------------------------------------------------------------------------- /api/app/api/utils/table_selection/table_details.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | import re 4 | 5 | 6 | table_details = {} 7 | with open("app/data/tables_many.json", "r") as f: 8 | table_details = json.load(f) 9 | 10 | sf_table_details = {} 11 | with open("app/data/sf_tables.json", "r") as f: 12 | sf_table_details = json.load(f) 13 | 14 | 15 | def extract_text_from_markdown(text): 16 | regex = r"`([\s\S]+?)`" 17 | matches = re.findall(regex, text) 18 | 19 | if matches: 20 | extracted_text = matches[0] 21 | else: 22 | extracted_text = text 23 | 24 | return extracted_text 25 | 26 | def get_all_table_names(scope="USA") -> List[str]: 27 | if scope == "USA": 28 | return [table["name"] for table in table_details["tables"]] 29 | elif scope == "SF": 30 | return [table["name"] for table in sf_table_details["tables"]] 31 | return [] 32 | 33 | 34 | def get_table_schemas(table_names: List[str] = None, scope="USA") -> str: 35 | enums_list = [] 36 | tables_list = [] 37 | 38 | if scope == "USA": 39 | enums_list = table_details.get("enums", []) 40 | if table_names: 41 | for table in table_details['tables']: 42 | if table['name'] in table_names: 43 | tables_list.append(table) 44 | else: 45 | tables_list = table_details["tables"] 46 | elif scope == "SF": 47 | enums_list = sf_table_details["enums"] 48 | if table_names: 49 | for table in sf_table_details['tables']: 50 | if table['name'] in table_names: 51 | tables_list.append(table) 52 | else: 53 | tables_list = sf_table_details["tables"] 54 | 55 | enums_str_set = set() 56 | tables_str_list = [] 57 | for table in tables_list: 58 | if scope == "SF": 59 | 60 | tables_str = table['table_creation_query'] 61 | 62 | # get all the vars in backticks using regex from tables_str 63 | regex = r"`([\s\S]+?)`" 64 | matches = re.findall(regex, tables_str) 65 | if matches: 66 | # add each to enums_str_set 67 | for match in matches: 68 | enums_str_set.add(match) 69 | 70 | else: 71 | tables_str = f"table name: {table['name']}\n" 72 | tables_str += f"table description: {table['description']}\n" 73 | columns_str_list = [] 74 | for column in table['columns']: 75 | if column.get('description'): 76 | columns_str_list.append(f"{column['name']} [{column['type']}] ({column['description']})") 77 | if 'custom type' in column['description']: 78 | enums_str_set.add(extract_text_from_markdown(column['description'])) 79 | else: 80 | columns_str_list.append(f"{column['name']} [{column['type']}]") 81 | tables_str += f"table columns: {', '.join(columns_str_list)}\n" 82 | tables_str_list.append(tables_str) 83 | tables_description = "\n\n".join(tables_str_list) 84 | 85 | enums_str_list = [] 86 | for custom_type_str in enums_str_set: 87 | custom_type = next((t for t in enums_list if t["type"] == custom_type_str), None) 88 | if custom_type: 89 | enums_str = f"custom type: {custom_type['type']}\n" 90 | enums_str += f"valid values: {', '.join(custom_type['valid_values'])}\n" 91 | enums_str_list.append(enums_str) 92 | enums_description = "\n\n".join(enums_str_list) 93 | 94 | # return tables_description 95 | return enums_description + "\n\n" + tables_description 96 | 97 | def get_table_and_enums(table_names: 
List[str] = None, scope="USA") -> tuple[str, str]: 98 | enums_list = [] 99 | tables_list = [] 100 | 101 | if scope == "USA": 102 | enums_list = table_details.get("enums", []) 103 | if table_names: 104 | for table in table_details['tables']: 105 | if table['name'] in table_names: 106 | tables_list.append(table) 107 | else: 108 | tables_list = table_details["tables"] 109 | elif scope == "SF": 110 | enums_list = sf_table_details["enums"] 111 | if table_names: 112 | for table in sf_table_details['tables']: 113 | if table['name'] in table_names: 114 | tables_list.append(table) 115 | else: 116 | tables_list = sf_table_details["tables"] 117 | 118 | enums_str_set = set() 119 | tables_str_list = [] 120 | for table in tables_list: 121 | if scope == "SF": 122 | 123 | tables_str = table['table_creation_query'] 124 | 125 | # get all the vars in backticks using regex from tables_str 126 | regex = r"`([\s\S]+?)`" 127 | matches = re.findall(regex, tables_str) 128 | if matches: 129 | # add each to enums_str_set 130 | for match in matches: 131 | enums_str_set.add(match) 132 | 133 | else: 134 | tables_str = f"table name: {table['name']}\n" 135 | tables_str += f"table description: {table['description']}\n" 136 | columns_str_list = [] 137 | for column in table['columns']: 138 | if column.get('description'): 139 | columns_str_list.append(f"{column['name']} [{column['type']}] ({column['description']})") 140 | if 'custom type' in column['description']: 141 | enums_str_set.add(extract_text_from_markdown(column['description'])) 142 | else: 143 | columns_str_list.append(f"{column['name']} [{column['type']}]") 144 | tables_str += f"table columns: {', '.join(columns_str_list)}\n" 145 | tables_str_list.append(tables_str) 146 | tables_description = "\n\n".join(tables_str_list) 147 | 148 | enums_str_list = [] 149 | for custom_type_str in enums_str_set: 150 | custom_type = next((t for t in enums_list if t["type"] == custom_type_str), None) 151 | if custom_type: 152 | enums_str = f"custom type: {custom_type['type']}\n" 153 | enums_str += f"valid values: {', '.join(custom_type['valid_values'])}\n" 154 | enums_str_list.append(enums_str) 155 | enums_description = "\n\n".join(enums_str_list) 156 | 157 | # return tables_description 158 | return tables_description, enums_description 159 | 160 | def get_minimal_table_schemas(scope="USA") -> str: 161 | 162 | tables_list = [] 163 | 164 | if scope == "USA": 165 | tables_list = table_details["tables"] 166 | elif scope == "SF": 167 | tables_list = sf_table_details["tables"] 168 | 169 | tables_str_list = [] 170 | for table in tables_list: 171 | if scope == "SF": 172 | tables_str = f"table name: {table['name']}\n" 173 | tables_str += f"table description: {table['description']}\n" 174 | else: 175 | tables_str = f"table name: {table['name']}\n" 176 | tables_str += f"table description: {table['description']}\n" 177 | tables_str_list.append(tables_str) 178 | 179 | tables_description = "\n\n".join(tables_str_list) 180 | 181 | # return tables_description 182 | return tables_description 183 | -------------------------------------------------------------------------------- /api/app/api/utils/table_selection/table_selection.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import List 4 | 5 | import pinecone 6 | from openai.embeddings_utils import get_embedding 7 | 8 | from ....config import PINECONE_ENV, PINECONE_KEY 9 | from ..few_shot_examples import get_few_shot_example_messages 10 | from ..messages import 
get_assistant_message_from_openai 11 | from .table_details import get_table_schemas, get_all_table_names 12 | 13 | 14 | def _extract_text_from_markdown(text): 15 | matches = re.findall(r"```([\s\S]+?)```", text) 16 | if matches: 17 | return matches[0] 18 | return text 19 | 20 | def _get_table_selection_message_with_descriptions(scope="USA"): 21 | message = ( 22 | """ 23 | You are an expert data scientist. 24 | Return a JSON object with relevant SQL tables for answering the following natural language query: 25 | --------------- 26 | {natural_language_query} 27 | --------------- 28 | Respond in JSON format with your answer in a field named \"tables\" which is a list of strings. 29 | Respond with an empty list if you cannot identify any relevant tables. 30 | Write your answer in markdown format. 31 | """ 32 | ) 33 | return ( 34 | message + 35 | f""" 36 | The following are the scripts that created the tables and the definition of their enums: 37 | --------------------- 38 | {get_table_schemas(scope=scope)} 39 | --------------------- 40 | 41 | in your answer, provide the following information: 42 | 43 | - 44 | - 45 | - 46 | - the markdown formatted like this: 47 | ``` 48 | 49 | ``` 50 | 51 | Provide only the list of related tables and nothing else after. 52 | """ 53 | ) 54 | 55 | 56 | def _get_table_selection_messages(scope="USA"): 57 | # default_messages = [{ 58 | # "role": "system", 59 | # "content": ( 60 | # f""" 61 | # You are a helpful assistant for identifying relevant SQL tables to use for answering a natural language query. 62 | # You respond in JSON format with your answer in a field named \"tables\" which is a list of strings. 63 | # Respond with an empty list if you cannot identify any relevant tables. 64 | # Write your answer in markdown format. 
65 | # The following are descriptions of available tables and enums: 66 | # --------------------- 67 | # {get_table_schemas(scope=scope)} 68 | # --------------------- 69 | # """ 70 | # ) 71 | # }] 72 | default_messages = [] 73 | default_messages.extend(get_few_shot_example_messages(mode="table_selection", scope=scope)) 74 | return default_messages 75 | 76 | 77 | def get_relevant_tables_from_pinecone(natural_language_query, scope="USA") -> List[str]: 78 | vector = get_embedding(natural_language_query, "text-embedding-ada-002") 79 | 80 | if scope == "SF": 81 | index_name = "sf-gpt" 82 | elif scope == "USA": 83 | index_name = "usa-gpt" 84 | 85 | results = pinecone.Index(index_name).query( 86 | vector=vector, 87 | top_k=5, 88 | include_metadata=True, 89 | ) 90 | 91 | tables_set = set() 92 | for result in results["matches"]: 93 | for table_name in result.metadata["table_names"]: 94 | tables_set.add(table_name) 95 | 96 | if scope == "USA": 97 | if len(tables_set) == 1 and "crime_by_city" in tables_set: 98 | pass 99 | else: 100 | tables_set.add("location_data") 101 | 102 | return list(tables_set) 103 | 104 | def get_relevant_tables_from_lm(natural_language_query, scope="USA", model="gpt-3.5-turbo", session_id=None): 105 | """ 106 | Identify relevant tables for answering a natural language query via LM 107 | """ 108 | content = _get_table_selection_message_with_descriptions(scope).format( 109 | natural_language_query=natural_language_query, 110 | ) 111 | 112 | messages = _get_table_selection_messages(scope).copy() 113 | messages.append({ 114 | "role": "user", 115 | "content": content 116 | }) 117 | 118 | try: 119 | response = get_assistant_message_from_openai( 120 | messages=messages, 121 | model=model, 122 | scope=scope, 123 | purpose="table_selection", 124 | session_id=session_id, 125 | )["message"]["content"] 126 | tables_json_str = _extract_text_from_markdown(response) 127 | 128 | tables = json.loads(tables_json_str).get("tables") 129 | except: 130 | tables = [] 131 | 132 | possible_tables = get_all_table_names(scope=scope) 133 | 134 | tables = [table for table in tables if table in possible_tables] 135 | 136 | # only get the first 7 tables 137 | tables = tables[:7] 138 | 139 | return tables 140 | 141 | 142 | def get_relevant_tables(natural_language_query, scope="USA") -> List[str]: 143 | """ 144 | Identify relevant tables for answering a natural language query 145 | """ 146 | 147 | # temporary hack to always use LM for SF 148 | if scope == "SF": 149 | # model = "gpt-4" 150 | model = "gpt-3.5-turbo" 151 | return get_relevant_tables_from_lm(natural_language_query, scope, model) 152 | 153 | if PINECONE_KEY and PINECONE_ENV: 154 | return get_relevant_tables_from_pinecone(natural_language_query, scope=scope) 155 | 156 | if scope == "SF": 157 | # model = "gpt-4" 158 | model = "gpt-3.5-turbo" 159 | else: 160 | model = "gpt-3.5-turbo" 161 | 162 | return get_relevant_tables_from_lm(natural_language_query, scope, model) 163 | 164 | 165 | async def get_relevant_tables_async(natural_language_query, scope="USA", session_id=None) -> List[str]: 166 | """ 167 | Identify relevant tables for answering a natural language query 168 | """ 169 | 170 | # temporary hack to always use LM for SF 171 | if scope == "SF": 172 | # model = "gpt-4" 173 | model = "gpt-3.5-turbo" 174 | return get_relevant_tables_from_lm(natural_language_query, scope, model, session_id=session_id) 175 | 176 | if PINECONE_KEY and PINECONE_ENV: 177 | return get_relevant_tables_from_pinecone(natural_language_query, scope=scope) 178 | 179 | if 
scope == "SF": 180 | # model = "gpt-4" 181 | model = "gpt-3.5-turbo" 182 | else: 183 | model = "gpt-3.5-turbo" 184 | 185 | return get_relevant_tables_from_lm(natural_language_query, scope, model) 186 | -------------------------------------------------------------------------------- /api/app/config.py: -------------------------------------------------------------------------------- 1 | from os import getenv 2 | 3 | import openai 4 | import pinecone 5 | from dotenv import load_dotenv 6 | from sqlalchemy import create_engine 7 | 8 | import sentry_sdk 9 | from sentry_sdk.integrations.flask import FlaskIntegration 10 | 11 | load_dotenv() 12 | 13 | ENV = getenv("ENVIRONMENT") or "unknown" 14 | DB_URL = getenv("DB_URL") 15 | OPENAI_KEY = getenv("OPENAI_KEY") 16 | PINECONE_KEY = getenv("PINECONE_KEY") 17 | PINECONE_ENV = getenv("PINECONE_ENV") 18 | EVENTS_URL = getenv("EVENTS_URL") 19 | SENTRY_URL = getenv("SENTRY_URL") 20 | 21 | if SENTRY_URL: 22 | sentry_sdk.init( 23 | dsn=SENTRY_URL, 24 | environment=ENV, 25 | integrations=[ 26 | FlaskIntegration(), 27 | ], 28 | traces_sample_rate=1.0 29 | ) 30 | 31 | 32 | openai.api_key = OPENAI_KEY 33 | 34 | class FlaskAppConfig: 35 | CORS_HEADERS = "Content-Type" 36 | SQLALCHEMY_DATABASE_URI = DB_URL 37 | SQLALCHEMY_ENGINE_OPTIONS = {"pool_pre_ping": True} 38 | TIMEOUT = 60 39 | 40 | if DB_URL: 41 | ENGINE = create_engine(DB_URL) 42 | dialect_mapping = { 43 | "postgresql": "PostgreSQL 15.2", 44 | "mysql": "MySQL", 45 | } 46 | DIALECT = dialect_mapping.get(ENGINE.dialect.name) 47 | else: 48 | print('DB_URL not found, please check your environment') 49 | exit(1) 50 | 51 | if EVENTS_URL: 52 | EVENTS_ENGINE = create_engine(EVENTS_URL) 53 | else: 54 | EVENTS_ENGINE = None 55 | 56 | if PINECONE_KEY and PINECONE_ENV: 57 | pinecone.init( 58 | api_key=PINECONE_KEY, 59 | environment=PINECONE_ENV, 60 | ) -------------------------------------------------------------------------------- /api/app/extensions.py: -------------------------------------------------------------------------------- 1 | from flask_sqlalchemy import SQLAlchemy 2 | 3 | 4 | db = SQLAlchemy() -------------------------------------------------------------------------------- /api/discordbot/bot.py: -------------------------------------------------------------------------------- 1 | import discord 2 | import responses 3 | from discord import app_commands 4 | from discord import Interaction 5 | 6 | async def send_message(message, user_message, is_private): 7 | try: 8 | response = responses.get_response(user_message) 9 | await message.author.send(response) if is_private else await message.channel.send(response) 10 | 11 | except Exception as e: 12 | print(e) 13 | 14 | 15 | 16 | 17 | def run_discord_bot(): 18 | TOKEN = 'DISCORD_BOT_TOKEN' 19 | intents = discord.Intents.default() 20 | intents.message_content = True 21 | client = discord.Client(intents=intents) 22 | tree = app_commands.CommandTree(client) 23 | 24 | @client.event 25 | async def on_ready(): 26 | print(f'{client.user} is now running!') 27 | 28 | @tree.command(name="query", description="Query census gpt") 29 | async def query_command(interaction: Interaction): 30 | await interaction.response.defer() 31 | await interaction.response.send_message("Hello!") 32 | 33 | @client.event 34 | async def on_message(message): 35 | if message.author == client.user: 36 | return 37 | 38 | username = str(message.author) 39 | user_message = str(message.content) 40 | channel = str(message.channel) 41 | 42 | print(f'{username} said: "{user_message}" ({channel})') 43 | 
44 | if user_message and user_message[0] == '?': # guard against empty message content 45 | user_message = user_message[1:] 46 | await send_message(message, user_message, is_private=True) 47 | else: 48 | await send_message(message, user_message, is_private=False) 49 | 50 | client.run(TOKEN) -------------------------------------------------------------------------------- /api/discordbot/responses.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from tabulate import tabulate 3 | 4 | def get_response(message: str) -> str: 5 | p_message = message.lower() 6 | 7 | if p_message.startswith('!query'): 8 | natural_language_query = p_message.split('!query ')[-1] 9 | url = "https://text-sql-be.onrender.com/api/text_to_sql" 10 | 11 | payload = {"natural_language_query": natural_language_query} 12 | headers = {"Content-Type": "application/json"} 13 | 14 | response = requests.post(url, json=payload, headers=headers) 15 | if response.json()["result"] is None: 16 | return "Sorry, I couldn't find any results for that query" 17 | data = response.json()["result"]["results"] 18 | headers = response.json()["result"]["column_names"] 19 | table_data = [[d.get(header, "") for header in headers] for d in data] 20 | table = tabulate(table_data, headers=headers) 21 | return "```\n" + table + "\n```" 22 | -------------------------------------------------------------------------------- /api/requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.4 2 | aiosignal==1.3.1 3 | alembic==1.10.3 4 | anyio==3.6.2 5 | appnope==0.1.3 6 | argon2-cffi==21.3.0 7 | argon2-cffi-bindings==21.2.0 8 | arrow==1.2.3 9 | asttokens==2.2.1 10 | async-timeout==4.0.2 11 | attrs==22.2.0 12 | autopep8==2.0.2 13 | backcall==0.2.0 14 | beautifulsoup4==4.12.2 15 | bleach==6.0.0 16 | blinker==1.5 17 | certifi==2022.12.7 18 | cffi==1.15.1 19 | charset-normalizer==2.1.1 20 | click==8.1.3 21 | comm==0.1.3 22 | contourpy==1.0.7 23 | cycler==0.11.0 24 | debugpy==1.6.7 25 | decorator==5.1.1 26 | defusedxml==0.7.1 27 | discord==2.2.2 28 | discord.py==2.2.2 29 | dnspython==2.3.0 30 | et-xmlfile==1.1.0 31 | exceptiongroup==1.1.1 32 | executing==1.2.0 33 | fastjsonschema==2.16.3 34 | Flask==2.2.2 35 | Flask-Admin==1.6.1 36 | Flask-Cors==3.0.10 37 | Flask-Migrate==4.0.4 38 | Flask-SQLAlchemy==3.0.2 39 | fonttools==4.39.3 40 | fqdn==1.5.1 41 | frozenlist==1.3.3 42 | gunicorn==20.1.0 43 | idna==3.4 44 | iniconfig==2.0.0 45 | ipykernel==6.22.0 46 | ipython==8.12.0 47 | ipython-genutils==0.2.0 48 | ipywidgets==8.0.6 49 | isoduration==20.11.0 50 | itsdangerous==2.1.2 51 | jedi==0.18.2 52 | Jinja2==3.1.2 53 | joblib==1.2.0 54 | jsonpointer==2.3 55 | jsonschema==4.17.3 56 | jupyter==1.0.0 57 | jupyter-console==6.6.3 58 | jupyter-events==0.6.3 59 | jupyter_client==8.1.0 60 | jupyter_core==5.3.0 61 | jupyter_server==2.5.0 62 | jupyter_server_terminals==0.4.4 63 | jupyterlab-pygments==0.2.2 64 | jupyterlab-widgets==3.0.7 65 | kiwisolver==1.4.4 66 | loguru==0.6.0 67 | Mako==1.2.4 68 | MarkupSafe==2.1.2 69 | matplotlib==3.7.1 70 | matplotlib-inline==0.1.6 71 | mistune==2.0.5 72 | multidict==6.0.4 73 | nbclassic==0.5.5 74 | nbclient==0.7.3 75 | nbconvert==7.3.0 76 | nbformat==5.8.0 77 | nest-asyncio==1.5.6 78 | newrelic==8.7.0 79 | notebook==6.5.4 80 | notebook_shim==0.2.2 81 | numpy==1.24.2 82 | openai==0.27.4 83 | openpyxl==3.1.2 84 | packaging==23.0 85 | pandas==2.0.0 86 | pandas-stubs==1.5.3.230321 87 | pandocfilters==1.5.0 88 | parso==0.8.3 89 | pexpect==4.8.0 90 | pickleshare==0.7.5 91 | Pillow==9.5.0 92
| pinecone-client==2.2.1 93 | platformdirs==3.2.0 94 | plotly==5.14.1 95 | pluggy==1.0.0 96 | prometheus-client==0.16.0 97 | prompt-toolkit==3.0.38 98 | psutil==5.9.4 99 | psycopg2-binary==2.9.5 100 | ptyprocess==0.7.0 101 | pure-eval==0.2.2 102 | pycodestyle==2.10.0 103 | pycparser==2.21 104 | Pygments==2.14.0 105 | pyparsing==3.0.9 106 | pyrsistent==0.19.3 107 | pytest==7.2.2 108 | python-dateutil==2.8.2 109 | python-dotenv==1.0.0 110 | python-json-logger==2.0.7 111 | pytz==2023.3 112 | PyYAML==6.0 113 | pyzmq==25.0.2 114 | qtconsole==5.4.2 115 | QtPy==2.3.1 116 | regex==2023.3.23 117 | requests==2.28.1 118 | rfc3339-validator==0.1.4 119 | rfc3986-validator==0.1.1 120 | scikit-learn==1.2.2 121 | scipy==1.10.1 122 | Send2Trash==1.8.0 123 | sentry-sdk==1.16.0 124 | six==1.16.0 125 | sniffio==1.3.0 126 | soupsieve==2.4 127 | SQLAlchemy==2.0.9 128 | stack-data==0.6.2 129 | tabulate==0.9.0 130 | tenacity==8.2.2 131 | terminado==0.17.1 132 | threadpoolctl==3.1.0 133 | tiktoken==0.3.2 134 | tinycss2==1.2.1 135 | tomli==2.0.1 136 | tornado==6.2 137 | tqdm==4.65.0 138 | traitlets==5.9.0 139 | types-pytz==2023.3.0.0 140 | typing_extensions==4.5.0 141 | tzdata==2023.3 142 | uri-template==1.2.0 143 | urllib3==1.26.15 144 | wcwidth==0.2.6 145 | webcolors==1.13 146 | webencodings==0.5.1 147 | websocket-client==1.5.1 148 | Werkzeug==2.2.3 149 | widgetsnbextension==4.0.7 150 | WTForms==3.0.1 151 | yarl==1.8.2 152 | -------------------------------------------------------------------------------- /api/scripts/dev.sh: -------------------------------------------------------------------------------- 1 | FLASK_APP=app.py FLASK_DEBUG=true flask run -p 9000 -------------------------------------------------------------------------------- /api/scripts/setup.sh: -------------------------------------------------------------------------------- 1 | # create python venv named venv 2 | python3.10 -m venv ./venv 3 | 4 | # activate venv 5 | source venv/bin/activate 6 | 7 | # upgrade pip 8 | pip install --upgrade pip 9 | 10 | # install project dependencies 11 | cat requirements.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -n 1 python -m pip install 12 | -------------------------------------------------------------------------------- /byod/.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | DS_Store 3 | venv/ 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /byod/README.md: -------------------------------------------------------------------------------- 1 | # 🔌 Text-to-SQL BYOD (Bring Your Own Data) 2 | 3 | 4 | You can now connect your own database & datasets to textSQL and self-host the service. Our vision is to continue to modularize and improve this process. 5 | 6 | ### Use cases 7 | 8 | - Public-facing interactive interfaces for data. Democratizing public data 9 | - Empowering researchers. Enabling journalists and other researchers to more easily explore data 10 | - Business intelligence. Reducing the burden on technical employees to build & run queries for non-technical teammates 11 | 12 | ### Setup instructions 13 | 14 | These instructions will walk you through running your own API and client. You can run this all on localhost and then deploy it wherever you would like.
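Once the API and client are both up, a quick way to confirm the API is reachable is its `/ping` health check (defined in `byod/api/app/__init__.py`). A minimal sketch in Python, assuming the API listens on port 9000 as in the client's example `.env` below, and that the `requests` package is installed:

```python
# Smoke test for a local BYOD deployment. Assumes the API runs on port 9000
# (adjust API_BASE if you chose a different port) and `requests` is installed.
import requests

API_BASE = "http://localhost:9000"

# The Flask app's /ping route returns the plain-text body "pong".
resp = requests.get(f"{API_BASE}/ping", timeout=5)
assert resp.text == "pong", f"unexpected /ping response: {resp.text!r}"
print("BYOD API is up")
```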
15 | 16 | ## API 17 | 18 | #### Prerequisites 19 | - `python3.10` 20 | 21 | #### Required configuration for development 22 | 23 | - OpenAI Key 24 | - URL to the postgres DB 25 | 26 | Configure the above in a `.env` file in the following path: `/byod/api/app/` 27 | 28 | Here's an example of a `.env` file that points to the CensusGPT Postgres database: 29 | 30 | ``` 31 | OPENAI_KEY="YOUR_OPENAI_KEY" 32 | DB_URL="postgresql://census_data_user:3PjePE3hVzm2m2UFPywLTLfIiC6w28HB@dpg-cg73gvhmbg5ab7mrk8qg-b.replica-cyan.oregon-postgres.render.com/census_data_w0ix" 33 | ``` 34 | 35 | #### Local development 36 | 37 | Initial setup 38 | ```sh 39 | $ ./scripts/setup.sh 40 | ``` 41 | 42 | Activate virtual env 43 | ```sh 44 | $ source ./venv/bin/activate 45 | ``` 46 | 47 | Run local instance 48 | ```sh 49 | $ ./scripts/dev.sh 50 | ``` 51 | 52 | ## Client 53 | 54 | A front-end Streamlit application for Text-to-SQL (alternatively, you can use your own frontend) 55 | 56 | [Screenshot 2023-04-13 at 8 48 24 PM] 57 | 58 | #### Prerequisites 59 | - `python3.10` 60 | 61 | #### Required configuration for development: 62 | - base URL for the TextSQL API 63 | 64 | Configure the above in a `.env` file 65 | 66 | Example of a `.env` file, which should go in the following path: `/byod/client` 67 | ``` 68 | API_BASE="http://localhost:9000" 69 | ``` 70 | 71 | When everything is running on localhost, this will point to the BYOD API on port 9000. 72 | 73 | #### Local development 74 | 75 | Initial setup 76 | ``` 77 | $ ./scripts/setup.sh 78 | ``` 79 | 80 | Activate virtual env 81 | ``` 82 | $ source ./venv/bin/activate 83 | ``` 84 | 85 | Run local instance 86 | ``` 87 | $ ./scripts/dev.sh 88 | ``` 89 | 90 | ## Facing issues? Got questions? 91 | 92 | Reach out in the Discord for support: https://discord.com/invite/JZtxhZQQus 93 | -------------------------------------------------------------------------------- /byod/api/.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | DS_Store 3 | venv/ 4 | __pycache__/ 5 | scratch/ -------------------------------------------------------------------------------- /byod/api/README.md: -------------------------------------------------------------------------------- 1 | # TextSQL API 2 | 3 | A Flask API for Text-to-SQL 4 | 5 | ## Prerequisites 6 | - `python3.10` 7 | 8 | ## Required configuration for development: 9 | - OpenAI Key 10 | - URL to the postgres DB 11 | 12 | Configure the above in `.env`. Refer to `.env.example`.
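Once configured and running (see the Local development steps below), the core endpoint is `/text_to_sql`, defined in `app/sql_generation/routes.py`. A minimal sketch of calling it, assuming the server listens on port 9000 and using a made-up example question:

```python
# Example call to the /text_to_sql endpoint of a locally running TextSQL API.
# `natural_language_query` is required; `table_names` is optional -- when
# omitted, the API selects relevant tables itself.
import requests

resp = requests.post(
    "http://localhost:9000/text_to_sql",
    json={"natural_language_query": "How many rows are in each table?"},
    timeout=120,
)
resp.raise_for_status()
body = resp.json()
print(body["sql_query"])  # the generated SQL
print(body["result"])     # {"column_names": [...], "results": [...], ...}
```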
13 | 14 | ## Local development 15 | 16 | Initial setup 17 | ```sh 18 | $ ./scripts/setup.sh 19 | ``` 20 | 21 | Activate virtual env 22 | ```sh 23 | $ source ./venv/bin/activate 24 | ``` 25 | 26 | Run local instance 27 | ```sh 28 | $ ./scripts/dev.sh 29 | ``` 30 | -------------------------------------------------------------------------------- /byod/api/app/.env.example: -------------------------------------------------------------------------------- 1 | OPENAI_KEY="" 2 | PINECONE_KEY="" 3 | PINECONE_ENV="" 4 | DB_URL="" 5 | DB_MANAGED_METADATA="false" -------------------------------------------------------------------------------- /byod/api/app/__init__.py: -------------------------------------------------------------------------------- 1 | from app.config import FlaskAppConfig, DB_MANAGED_METADATA 2 | from app.extensions import db 3 | # import models to create tables if they don't exist 4 | from app.models import in_context_examples, table_metadata, type_metadata 5 | from app.setup.routes import bp as setup_bp 6 | from app.sql_explanation.routes import bp as sql_explanation_bp 7 | from app.sql_generation.routes import bp as sql_gen_bp 8 | from app.table_selection.routes import bp as table_selection_bp 9 | from app.visualization.routes import bp as visualization_bp 10 | from app.table_selection.utils import load_tables_and_types_metadata 11 | from app.utils import load_in_context_examples 12 | from flask import Flask 13 | from flask_admin import Admin 14 | from flask_cors import CORS 15 | from flask_migrate import Migrate 16 | 17 | 18 | def create_app(config_object=FlaskAppConfig): 19 | app = Flask(__name__) 20 | app.config.from_object(config_object) 21 | CORS(app) 22 | 23 | 24 | # Initialize app with extensions 25 | db.init_app(app) 26 | migrate = Migrate(app, db) 27 | with app.app_context(): 28 | if DB_MANAGED_METADATA: 29 | db.create_all() 30 | load_tables_and_types_metadata() 31 | load_in_context_examples() 32 | admin = Admin(None, name='admin', template_mode='bootstrap3') 33 | admin.init_app(app) 34 | 35 | 36 | @app.route("/ping") 37 | def ping(): 38 | return 'pong' 39 | 40 | app.register_blueprint(setup_bp) 41 | app.register_blueprint(sql_explanation_bp) 42 | app.register_blueprint(sql_gen_bp) 43 | app.register_blueprint(table_selection_bp) 44 | app.register_blueprint(visualization_bp) 45 | 46 | # from app.errors import bp as errors_bp 47 | # app.register_blueprint(errors_bp) 48 | 49 | # from app.main import bp as main_bp 50 | # app.register_blueprint(main_bp) 51 | 52 | @app.teardown_request 53 | def session_clear(exception=None): 54 | db.session.remove() 55 | if exception: 56 | if db.session.is_active: 57 | db.session.rollback() 58 | 59 | return app -------------------------------------------------------------------------------- /byod/api/app/config.py: -------------------------------------------------------------------------------- 1 | from os import getenv 2 | 3 | import openai 4 | import pinecone 5 | from dotenv import load_dotenv 6 | from sqlalchemy import create_engine 7 | 8 | load_dotenv() 9 | 10 | DB_URL = getenv("DB_URL") 11 | OPENAI_KEY = getenv("OPENAI_KEY") 12 | PINECONE_KEY = getenv("PINECONE_KEY") 13 | PINECONE_ENV = getenv("PINECONE_ENV") 14 | DB_MANAGED_METADATA = getenv("DB_MANAGED_METADATA") 15 | DB_MANAGED_METADATA= False if DB_MANAGED_METADATA is None else DB_MANAGED_METADATA.lower() == 'true' 16 | 17 | openai.api_key = OPENAI_KEY 18 | 19 | class FlaskAppConfig: 20 | CORS_HEADERS = "Content-Type" 21 | SQLALCHEMY_DATABASE_URI = DB_URL 22 | SQLALCHEMY_ENGINE_OPTIONS 
= {"pool_pre_ping": True} 23 | 24 | if DB_URL: 25 | ENGINE = create_engine(DB_URL) 26 | else: 27 | print('DB_URL not found, please check your environment') 28 | exit(1) 29 | 30 | if PINECONE_KEY and PINECONE_ENV: 31 | pinecone.init( 32 | api_key=PINECONE_KEY, 33 | environment=PINECONE_ENV 34 | ) -------------------------------------------------------------------------------- /byod/api/app/extensions.py: -------------------------------------------------------------------------------- 1 | from flask_sqlalchemy import SQLAlchemy 2 | 3 | 4 | db = SQLAlchemy() -------------------------------------------------------------------------------- /byod/api/app/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/byod/api/app/models/__init__.py -------------------------------------------------------------------------------- /byod/api/app/models/in_context_examples.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List 3 | 4 | from app.extensions import db 5 | 6 | @dataclass 7 | class InContextExamples(db.Model): 8 | __tablename__ = "ai_sql_in_context_examples" 9 | mode = db.Column(db.String, primary_key=True) 10 | examples: List[Dict[str, str]] = db.Column(db.JSON) -------------------------------------------------------------------------------- /byod/api/app/models/json/type_metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "sex": { 3 | "type": "sex", 4 | "valid_values": [ 5 | "Both Sexes", 6 | "Female", 7 | "Male" 8 | ] 9 | }, 10 | "race": { 11 | "type": "race", 12 | "valid_values": [ 13 | "African American/Black", 14 | "All Races", 15 | "Asian", 16 | "Hispanic/Latino", 17 | "Multirace", 18 | "Native American", 19 | "Pacific Islander", 20 | "Some other race", 21 | "White" 22 | ] 23 | }, 24 | "mortality_reasons": { 25 | "type": "mortality_reasons", 26 | "valid_values": [ 27 | "All Cause Mortality", 28 | "All other and unspecified accidents and adverse effects", 29 | "All other diseases (residual)", 30 | "All other external causes", 31 | "Alzheimers disease", 32 | "Assault (homicide)", 33 | "Assault (homicide) by discharge of firearms", 34 | "Assault (homicide) by other and unspecified means", 35 | "Atherosclerosis", 36 | "Cerebrovascular diseases", 37 | "Certain conditions originating in the perinatal period", 38 | "Chronic liver disease and cirrhosis", 39 | "Chronic lower respiratory diseases", 40 | "Complications of medical and surgical care", 41 | "Congenital malformations", 42 | "Diabetes mellitus", 43 | "Essential hypertension and hypertensive renal disease", 44 | "Human immunodeficiency virus (HIV) disease", 45 | "Hypertensive heart disease with or without renal disease", 46 | "Influenza and pneumonia", 47 | "Intentional self-harm (suicide)", 48 | "Intentional self-harm (suicide) by discharge of firearms", 49 | "Intentional self-harm (suicide) by other and unspecified means", 50 | "Ischemic heart diseases", 51 | "Leukemia", 52 | "Malignant neoplasm of breast", 53 | "Malignant neoplasm of pancreas", 54 | "Malignant neoplasm of prostate", 55 | "Malignant neoplasm of stomach", 56 | "Malignant neoplasms of cervix uteri, corpus uteri and ovary", 57 | "Malignant neoplasms of colon, rectum and anus", 58 | "Malignant neoplasms of trachea, bronchus and lung", 59 | "Malignant neoplasms of urinary tract", 60 | 
"Motor vehicle accidents", 61 | "Nephritis, nephrotic syndrome and nephrosis", 62 | "Non-Hodgkins lymphoma", 63 | "Other diseases of circulatory system", 64 | "Other diseases of heart", 65 | "Other malignant neoplasms", 66 | "Peptic ulcer", 67 | "Pregnancy, childbirth and the puerperium", 68 | "Sudden infant death syndrome", 69 | "Other neonatal (excluding SIDS)", 70 | "Tuberculosis" 71 | ] 72 | }, 73 | "Year": { 74 | "type": "Year", 75 | "valid_values": [ 76 | "2012", 77 | "2013", 78 | "2014", 79 | "2015", 80 | "2016", 81 | "2017", 82 | "2018", 83 | "2019", 84 | "2020" 85 | ] 86 | } 87 | } -------------------------------------------------------------------------------- /byod/api/app/models/table_metadata.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List 3 | 4 | from app.extensions import db 5 | 6 | @dataclass 7 | class TableMetadata(db.Model): 8 | __tablename__ = "ai_sql_table_metadata" 9 | table_name = db.Column(db.String, primary_key=True) 10 | table_metadata: Dict[str, List[object]] = db.Column(db.JSON) -------------------------------------------------------------------------------- /byod/api/app/models/type_metadata.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List 3 | 4 | from app.extensions import db 5 | 6 | @dataclass 7 | class TypeMetadata(db.Model): 8 | __tablename__ = "ai_sql_type_metadata" 9 | type_name = db.Column(db.String, primary_key=True) 10 | type_metadata: Dict[str, List[object]] = db.Column(db.JSON) -------------------------------------------------------------------------------- /byod/api/app/setup/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/byod/api/app/setup/__init__.py -------------------------------------------------------------------------------- /byod/api/app/setup/routes.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, jsonify, make_response, request 2 | 3 | from ..config import ENGINE 4 | from .utils import (ENUMS_METADATA_DICT, TABLES_METADATA_DICT, 5 | generate_few_shot_queries, generate_table_metadata, 6 | generate_type_metadata, get_table_names, get_type_names, 7 | save_table_metadata, save_type_metadata) 8 | 9 | bp = Blueprint('setup_bp', __name__) 10 | 11 | @bp.route('/setup', methods=['POST']) 12 | def setup_db(): 13 | """ 14 | Set up database for text to SQL 15 | """ 16 | request_body = request.get_json() 17 | db_credentials = {} 18 | db_credentials["address"] = request_body.get("address") 19 | db_credentials["database"] = request_body.get("database") 20 | db_credentials["username"] = request_body.get("username") 21 | db_credentials["password"] = request_body.get("password") 22 | db_credentials["port"] = request_body.get("port", 5432) 23 | 24 | for key, value in db_credentials.items(): 25 | if not value: 26 | error_msg = f"`{key}` is missing from request body" 27 | return make_response(jsonify({"error": error_msg}), 400) 28 | 29 | 30 | @bp.route('/tables', methods=['GET']) 31 | def get_tables(): 32 | """ 33 | Get table names from database 34 | """ 35 | table_names = get_table_names() 36 | return make_response(jsonify({"table_names": table_names}), 200) 37 | 38 | 39 | @bp.route('/get_tables_metadata', methods=['POST']) 40 | def 
get_tables_metadata(): 41 | """ 42 | Get tables metadata 43 | """ 44 | request_body = request.get_json() 45 | table_names = request_body.get('table_names') 46 | 47 | tables_metadata = {} 48 | for t in table_names: 49 | metadata = generate_table_metadata(t) 50 | tables_metadata[t] = metadata 51 | 52 | return make_response(jsonify({"tables_metadata": tables_metadata}), 200) 53 | 54 | 55 | @bp.route('/types', methods=['GET']) 56 | def get_types(): 57 | """ 58 | Get type names from database 59 | """ 60 | type_names = get_type_names() 61 | return make_response(jsonify({"type_names": type_names}), 200) 62 | 63 | 64 | @bp.route('/get_types_metadata', methods=['POST']) 65 | def get_types_metadata(): 66 | """ 67 | Get types metadata 68 | """ 69 | request_body = request.get_json() 70 | type_names = request_body.get('type_names') 71 | 72 | types_metadata = {} 73 | for t in type_names: 74 | metadata = generate_type_metadata(t) 75 | types_metadata[t] = metadata 76 | 77 | return make_response(jsonify({"types_metadata": types_metadata}), 200) 78 | 79 | 80 | @bp.route('/save_metadata', methods=['POST']) 81 | def save_metadata(): 82 | request_body = request.get_json() 83 | tables_metadata_dict = request_body.get("tables_metadata_dict", {}) 84 | types_metadata_dict = request_body.get("types_metadata_dict", {}) 85 | 86 | for name, metadata in tables_metadata_dict.items(): 87 | save_table_metadata(name, metadata) 88 | 89 | for name, metadata in types_metadata_dict.items(): 90 | save_type_metadata(name, metadata) 91 | 92 | return "Success" 93 | 94 | 95 | # TODO: delete metadata 96 | 97 | 98 | # DEPRECATED 99 | @bp.route('/setup_metadata', methods=['POST']) 100 | def setup_metadata(): 101 | 102 | # overwrite existing tables and enums metadata 103 | TABLES_METADATA_DICT = {} 104 | ENUMS_METADATA_DICT = {} 105 | 106 | for table_name in get_table_names(): 107 | save_table_metadata(table_name, generate_table_metadata(table_name)) 108 | for type_name in get_type_names(): 109 | save_type_metadata(type_name, generate_type_metadata(type_name)) 110 | 111 | return "Success" -------------------------------------------------------------------------------- /byod/api/app/sql_explanation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/byod/api/app/sql_explanation/__init__.py -------------------------------------------------------------------------------- /byod/api/app/sql_explanation/routes.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, jsonify, make_response, request 2 | 3 | from .utils import get_sql_explanation 4 | 5 | bp = Blueprint('sql_explanation_bp', __name__) 6 | 7 | @bp.route('/explain_sql', methods=['POST']) 8 | def get_tables(): 9 | """ 10 | Explains SQL in natural language 11 | """ 12 | request_body = request.get_json() 13 | sql = request_body.get('sql') 14 | 15 | if not sql: 16 | error_msg = '`sql` is missing from request body' 17 | return make_response(jsonify({"error": error_msg}), 400) 18 | 19 | explanation = get_sql_explanation(sql) 20 | return make_response(jsonify({'explanation': explanation}), 200) -------------------------------------------------------------------------------- /byod/api/app/sql_explanation/utils.py: -------------------------------------------------------------------------------- 1 | from ..utils import get_assistant_message, get_few_shot_messages 2 | 3 | 4 | def 
get_message_with_descriptions(): 5 | message = ( 6 | "Provide a concise explanation for the following SQL query: ```{sql}```" 7 | ) 8 | return message 9 | 10 | 11 | def get_default_messages(): 12 | default_messages = [{ 13 | "role": "system", 14 | "content": "You are a helpful assistant for providing an explanation for a SQL query." 15 | }] 16 | default_messages.extend(get_few_shot_messages(mode="sql_explanation")) 17 | return default_messages 18 | 19 | 20 | def get_sql_explanation(sql) -> str: 21 | """ 22 | Use language model to generate explanation of SQL query 23 | """ 24 | content = get_message_with_descriptions().format(sql=sql) 25 | messages = get_default_messages().copy() 26 | messages.append({ 27 | "role": "user", 28 | "content": content 29 | }) 30 | 31 | model = "gpt-3.5-turbo" 32 | 33 | assistant_message_content = get_assistant_message(messages=messages, model=model)["message"]["content"] 34 | return assistant_message_content -------------------------------------------------------------------------------- /byod/api/app/sql_generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/byod/api/app/sql_generation/__init__.py -------------------------------------------------------------------------------- /byod/api/app/sql_generation/routes.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, jsonify, make_response, request 2 | 3 | from ..config import PINECONE_ENV, PINECONE_KEY 4 | from ..table_selection.utils import (get_relevant_tables_from_lm, 5 | get_relevant_tables_from_pinecone) 6 | from .utils import text_to_sql_with_retry 7 | 8 | bp = Blueprint('sql_generation_bp', __name__) 9 | 10 | 11 | @bp.route('/text_to_sql', methods=['POST']) 12 | def text_to_sql(): 13 | """ 14 | Convert natural language query to SQL 15 | """ 16 | request_body = request.get_json() 17 | natural_language_query = request_body.get("natural_language_query") 18 | table_names = request_body.get("table_names") 19 | 20 | if not natural_language_query: 21 | error_msg = "`natural_language_query` is missing from request body" 22 | return make_response(jsonify({"error": error_msg}), 400) 23 | 24 | try: 25 | if not table_names: 26 | if PINECONE_ENV and PINECONE_KEY: 27 | table_names = get_relevant_tables_from_pinecone(natural_language_query) 28 | else: 29 | table_names = get_relevant_tables_from_lm(natural_language_query) 30 | result, sql_query = text_to_sql_with_retry(natural_language_query, table_names) 31 | except Exception as e: 32 | error_msg = f"Error processing request: {str(e)}" 33 | return make_response(jsonify({"error": error_msg}), 500) 34 | 35 | return make_response(jsonify({"result": result, "sql_query": sql_query}), 200) 36 | -------------------------------------------------------------------------------- /byod/api/app/sql_generation/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict, List 3 | 4 | from app.config import ENGINE 5 | from sqlalchemy import text 6 | 7 | from ..table_selection.utils import get_table_schemas_str 8 | from ..utils import (extract_sql_query_from_message, get_assistant_message, 9 | get_few_shot_messages) 10 | 11 | MSG_WITH_ERROR_TRY_AGAIN = ( 12 | """ 13 | Try again. 14 | Only respond with valid SQL. Make sure to write your answer in markdown format. 
15 | The SQL query you just generated resulted in the following error message: 16 | --------------------- 17 | {error_message} 18 | --------------------- 19 | """ 20 | ) 21 | 22 | 23 | def make_default_messages(schemas_str: str) -> List[Dict[str, str]]: 24 | # default_messages = [{ 25 | # "role": "system", 26 | # "content": ( 27 | # f""" 28 | # You are a helpful assistant for generating syntactically correct read-only SQL to answer a given question or command. 29 | # The following are tables you can query: 30 | # --------------------- 31 | # {schemas_str} 32 | # --------------------- 33 | # Make sure to write your answer in markdown format. 34 | # """ 35 | # # TODO: place warnings here 36 | # # i.e. "Make sure each value in the result table is not null." 37 | # ) 38 | # }] 39 | default_messages = [] 40 | default_messages.extend(get_few_shot_messages(mode="text_to_sql")) 41 | return default_messages 42 | 43 | 44 | def make_rephrase_msg_with_schema_and_warnings(): 45 | return ( 46 | """ 47 | Let's start by fixing and rephrasing the query to be more analytical. Use the schema context to rephrase the user question in a way that leads to optimal query results: {natural_language_query} 48 | The following are schemas of tables you can query: 49 | --------------------- 50 | {schemas_str} 51 | --------------------- 52 | Do not include any of the table names in the query. 53 | Ask the natural language query the way a data analyst, with knowledge of these tables, would. 54 | """ 55 | ) 56 | 57 | def make_msg_with_schema_and_warnings(): 58 | return ( 59 | """ 60 | Generate syntactically correct read-only SQL to answer the following question/command: {natural_language_query} 61 | The following are schemas of tables you can query: 62 | --------------------- 63 | {schemas_str} 64 | --------------------- 65 | Make sure to write your answer in markdown format. 66 | """ 67 | # TODO: place warnings here 68 | # i.e. "Make sure each value in the result table is not null."" 69 | ) 70 | 71 | def is_read_only_query(sql_query: str) -> bool: 72 | """ 73 | Checks if the given SQL query string is read-only. 74 | Returns True if the query is read-only, False otherwise. 
75 | """ 76 | # List of SQL statements that modify data in the database 77 | modifying_statements = ["INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "GRANT", "TRUNCATE", "LOCK TABLES", "UNLOCK TABLES"] 78 | 79 | # Check if the query contains any modifying statements 80 | for statement in modifying_statements: 81 | if statement in sql_query.upper(): 82 | return False 83 | 84 | # If no modifying statements are found, the query is read-only 85 | return True 86 | 87 | 88 | class NotReadOnlyException(Exception): 89 | pass 90 | 91 | 92 | class NullValueException(Exception): 93 | pass 94 | 95 | 96 | def execute_sql(sql_query: str): 97 | if not is_read_only_query(sql_query): 98 | raise NotReadOnlyException("Only read-only queries are allowed.") 99 | 100 | with ENGINE.connect() as connection: 101 | connection = connection.execution_options(postgresql_readonly=True) 102 | with connection.begin(): 103 | result = connection.execute(text(sql_query)) 104 | 105 | column_names = list(result.keys()) 106 | 107 | rows = [list(r) for r in result.all()] 108 | 109 | # Check for null values 110 | # for row in rows: 111 | # for value in row: 112 | # if value is None: 113 | # raise NullValueException("Make sure each value in the result table is not null.") 114 | 115 | 116 | results = [] 117 | for row in rows: 118 | result = OrderedDict() 119 | for i, column_name in enumerate(column_names): 120 | result[column_name] = row[i] 121 | results.append(result) 122 | 123 | result_dict = { 124 | "column_names": column_names, 125 | "results": results, 126 | } 127 | if results: 128 | result_dict["column_types"] = [type(r).__name__ for r in results[0]] 129 | 130 | return result_dict 131 | 132 | 133 | def text_to_sql_with_retry(natural_language_query, table_names, k=3, messages=None): 134 | """ 135 | Tries to take a natural language query and generate valid SQL to answer it K times 136 | """ 137 | if not messages: 138 | # ask the assistant to rephrase before generating the query 139 | schemas_str = get_table_schemas_str(table_names) 140 | # rephrase = [{ 141 | # "role": "user", 142 | # "content": make_rephrase_msg_with_schema_and_warnings().format( 143 | # natural_language_query=natural_language_query, 144 | # schemas_str=schemas_str 145 | # ) 146 | # }] 147 | # rephrased_query = get_assistant_message(rephrase)["message"]["content"] 148 | # natural_language_query = rephrased_query 149 | 150 | content = make_msg_with_schema_and_warnings().format( 151 | natural_language_query=natural_language_query, 152 | schemas_str=schemas_str 153 | ) 154 | 155 | messages = make_default_messages(schemas_str) 156 | messages.append({ 157 | "role": "user", 158 | "content": content 159 | }) 160 | 161 | assistant_message = None 162 | 163 | for _ in range(k): 164 | try: 165 | # model = "gpt-4" 166 | # model = "gpt-3.5-turbo" 167 | model = "gpt-3.5-turbo-0301" 168 | assistant_message = get_assistant_message(messages, model=model) 169 | sql_query = extract_sql_query_from_message(assistant_message["message"]["content"]) 170 | 171 | response = execute_sql(sql_query) 172 | # Generated SQL query did not produce exception. 
Return result 173 | return response, sql_query 174 | 175 | except Exception as e: 176 | if assistant_message: # the model call itself may have failed, leaving no reply to record 177 | messages.append({ "role": "assistant", 178 | "content": assistant_message["message"]["content"] 179 | }) 180 | messages.append({ 181 | "role": "user", 182 | "content": MSG_WITH_ERROR_TRY_AGAIN.format(error_message=str(e)) 183 | }) 184 | 185 | print("Could not generate SQL query after {k} tries.".format(k=k)) 186 | return None, None 187 | -------------------------------------------------------------------------------- /byod/api/app/table_selection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/byod/api/app/table_selection/__init__.py -------------------------------------------------------------------------------- /byod/api/app/table_selection/routes.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, jsonify, make_response, request 2 | 3 | from .utils import get_relevant_tables_from_pinecone 4 | 5 | bp = Blueprint('table_selection_bp', __name__) 6 | 7 | @bp.route('/get_tables', methods=['POST']) 8 | def get_tables(): 9 | """ 10 | Select relevant tables given a natural language query 11 | """ 12 | request_body = request.get_json() 13 | natural_language_query = request_body.get("natural_language_query") 14 | 15 | if not natural_language_query: 16 | error_msg = '`natural_language_query` is missing from request body' 17 | return make_response(jsonify({"error": error_msg}), 400) 18 | 19 | table_names = get_relevant_tables_from_pinecone(natural_language_query) 20 | return make_response(jsonify({"table_names": table_names}), 200) -------------------------------------------------------------------------------- /byod/api/app/table_selection/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import List 4 | 5 | import pinecone 6 | from app.config import DB_MANAGED_METADATA 7 | from app.extensions import db 8 | from app.models.table_metadata import TableMetadata 9 | from app.models.type_metadata import TypeMetadata 10 | from openai.embeddings_utils import get_embedding 11 | 12 | from ..utils import get_assistant_message, get_few_shot_messages 13 | 14 | ENUMS_METADATA_DICT = {} 15 | TABLES_METADATA_DICT = {} 16 | def load_tables_and_types_metadata(): 17 | """ 18 | Setup metadata dicts for tables and enums 19 | """ 20 | global ENUMS_METADATA_DICT 21 | global TABLES_METADATA_DICT 22 | 23 | if not DB_MANAGED_METADATA: 24 | with open("app/models/json/table_metadata.json", "r") as f: 25 | TABLES_METADATA_DICT = json.load(f) 26 | with open("app/models/json/type_metadata.json", "r") as f: 27 | ENUMS_METADATA_DICT = json.load(f) 28 | return 29 | 30 | try: 31 | enums_metadata = TypeMetadata.query.all() 32 | except Exception as e: 33 | print(e) 34 | enums_metadata = [] 35 | for enum_metadata in enums_metadata: 36 | # ENUMS_METADATA_DICT[enum_metadata.type_name] = enum_metadata 37 | ENUMS_METADATA_DICT[enum_metadata.type_name] = enum_metadata.type_metadata 38 | 39 | try: 40 | tables_metadata = TableMetadata.query.all() 41 | except Exception as e: 42 | print(e) 43 | tables_metadata = [] 44 | for table_metadata in tables_metadata: 45 | # TABLES_METADATA_DICT[table_metadata.table_name] = table_metadata 46 | TABLES_METADATA_DICT[table_metadata.table_name] = table_metadata.table_metadata 47 | 48 | 49 | 50 | def
save_tables_metadata_to_json(): 51 | with open("app/models/json/table_metadata.json", "w") as f: 52 | json.dump(TABLES_METADATA_DICT, f, indent=4) 53 | 54 | 55 | def save_enums_metadata_to_json(): 56 | with open("app/models/json/type_metadata.json", "w") as f: 57 | json.dump(ENUMS_METADATA_DICT, f, indent=4) 58 | 59 | # # TODO: load few shot from json 60 | # def save_few_shots_to_json(): 61 | # with open("app/models/json/in_context_examples.json", "w") as f: 62 | # json.dump(IN_CONTEXT_EXAMPLES_DICT, f, indent=4) 63 | 64 | 65 | # TODO: refac this to access JSON fields instead of tables 66 | def get_table_schemas_str(table_names: List[str] = []) -> str: 67 | """ 68 | Format table and types metadata into string to be used in prompt 69 | """ 70 | global ENUMS_METADATA_DICT 71 | global TABLES_METADATA_DICT 72 | 73 | tables_to_use = [] 74 | if table_names: 75 | tables_to_use = [TABLES_METADATA_DICT[t_name] for t_name in table_names] 76 | else: 77 | tables_to_use = [t for t in TABLES_METADATA_DICT.values()] 78 | 79 | enums_to_use = set() 80 | tables_str_list = [] 81 | for table in tables_to_use: 82 | tables_str = f"table name: {table['name']}\n" 83 | if table.get("description"): 84 | tables_str += f"table description: {table.get('description')}\n" 85 | columns_str_list = [] 86 | for column in table.get("columns", []): 87 | columns_str_list.append(f"{column['name']} [{column['type']}]") 88 | if column.get("type") in ENUMS_METADATA_DICT.keys(): 89 | enums_to_use.add(column.get("type")) 90 | tables_str += f"table columns: {', '.join(columns_str_list)}\n" 91 | tables_str_list.append(tables_str) 92 | tables_details = "\n\n".join(tables_str_list) 93 | 94 | enums_str_list = [] 95 | for custom_type_str in enums_to_use: 96 | custom_type = ENUMS_METADATA_DICT.get(custom_type_str) 97 | if custom_type: 98 | enums_str = f"enum: {custom_type['type']}\n" 99 | enums_str += f"valid values: {', '.join(custom_type.get('valid_values'))}\n" 100 | enums_str_list.append(enums_str) 101 | enums_details = "\n\n".join(enums_str_list) 102 | 103 | return enums_details + "\n\n" + tables_details 104 | 105 | 106 | def get_relevant_tables_from_pinecone(natural_language_query, index_name="text_to_sql") -> List[str]: 107 | """ 108 | Identify relevant tables for answering a natural language query via vector store 109 | """ 110 | vector = get_embedding(natural_language_query, "text-embedding-ada-002") 111 | 112 | results = pinecone.Index(index_name).query( 113 | vector=vector, 114 | top_k=5, 115 | include_metadata=True, 116 | ) 117 | 118 | table_names = set() 119 | for result in results["matches"]: 120 | for table_name in result.metadata["table_names"]: 121 | table_names.add(table_name) 122 | 123 | print(results["matches"]) 124 | 125 | return list(table_names) 126 | 127 | 128 | def _get_table_selection_message_with_descriptions(natural_language_query): 129 | return f""" 130 | Return a JSON object with relevant SQL tables for answering the following natural language query: 131 | --------------------- 132 | {natural_language_query} 133 | --------------------- 134 | Respond in JSON format with your answer in a field named \"tables\" which is a list of strings. 135 | Respond with an empty list if you cannot identify any relevant tables. 136 | Make sure to write your answer in markdown format. 
137 | The following are descriptions of available tables and enums: 138 | --------------------- 139 | {get_table_schemas_str()} 140 | --------------------- 141 | """ 142 | 143 | 144 | def _get_table_selection_messages(): 145 | # default_messages = [{ 146 | # "role": "system", 147 | # "content": ( 148 | # f""" 149 | # You are a helpful assistant for identifying relevant SQL tables to use for answering a natural language query. 150 | # You respond in JSON format with your answer in a field named \"tables\" which is a list of strings. 151 | # Respond with an empty list if you cannot identify any relevant tables. 152 | # Make sure to write your answer in markdown format. 153 | # The following are descriptions of available tables and enums: 154 | # --------------------- 155 | # {get_table_schemas_str()} 156 | # --------------------- 157 | # """ 158 | # ) 159 | # }] 160 | default_messages = [] 161 | default_messages.extend(get_few_shot_messages(mode="table_selection")) 162 | return default_messages 163 | 164 | 165 | def _extract_text_from_markdown(text): 166 | matches = re.findall(r"```([\s\S]+?)```", text) 167 | if matches: 168 | return matches[0] 169 | return text 170 | 171 | 172 | def get_relevant_tables_from_lm(natural_language_query): 173 | """ 174 | Identify relevant tables for answering a natural language query via LM 175 | """ 176 | content = _get_table_selection_message_with_descriptions(natural_language_query) 177 | messages = _get_table_selection_messages().copy() 178 | messages.append({ 179 | "role": "user", 180 | "content": content 181 | }) 182 | 183 | tables_json_str = _extract_text_from_markdown( 184 | get_assistant_message( 185 | messages=messages, 186 | model="gpt-3.5-turbo-0301", 187 | )["message"]["content"] 188 | ) 189 | tables = json.loads(tables_json_str).get("tables") 190 | return tables -------------------------------------------------------------------------------- /byod/api/app/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Dict, List 4 | 5 | import openai 6 | from app.config import DB_MANAGED_METADATA 7 | from app.extensions import db 8 | from app.models.in_context_examples import InContextExamples 9 | 10 | IN_CONTEXT_EXAMPLES_DICT = {} 11 | def load_in_context_examples(): 12 | """ 13 | Setup in context examples dict 14 | """ 15 | global IN_CONTEXT_EXAMPLES_DICT 16 | 17 | if not DB_MANAGED_METADATA: 18 | with open("app/models/json/in_context_examples.json", "r") as f: 19 | IN_CONTEXT_EXAMPLES_DICT = json.load(f) 20 | return 21 | 22 | try: 23 | in_context_examples = InContextExamples.query.all() 24 | except Exception as e: 25 | print(e) 26 | in_context_examples = [] 27 | for in_context_example in in_context_examples: 28 | IN_CONTEXT_EXAMPLES_DICT[in_context_example.mode] = in_context_example.examples 29 | 30 | 31 | def get_few_shot_messages(mode: str = "text_to_sql") -> List[Dict]: 32 | global IN_CONTEXT_EXAMPLES_DICT 33 | 34 | examples = IN_CONTEXT_EXAMPLES_DICT.get(mode, []) 35 | messages = [] 36 | for example in examples: 37 | messages.append({ 38 | "role": "user", 39 | "content": example["user"], 40 | }) 41 | messages.append({ 42 | "role": "assistant", 43 | "content": example["assistant"], 44 | }) 45 | return messages 46 | 47 | 48 | def get_assistant_message( 49 | messages: List[Dict[str, str]], 50 | temperature: int = 0, 51 | model: str = "gpt-3.5-turbo-0301", 52 | # model: str = "gpt-3.5-turbo", 53 | # model: str = "gpt-4", 54 | ): 55 | res = openai.ChatCompletion.create( 56
| model=model, 57 | temperature=temperature, 58 | messages=messages 59 | ) 60 | # completion = res['choices'][0]["message"]["content"] 61 | assistant_message = res['choices'][0] 62 | return assistant_message 63 | 64 | 65 | def clean_message_content(assistant_message_content): 66 | """ 67 | Cleans message content to extract the SQL query 68 | """ 69 | # Ignore text after the SQL query terminator `;` 70 | assistant_message_content = assistant_message_content.split(";")[0] 71 | 72 | # Remove prefix for corrected query assistant message 73 | split_corrected_query_message = assistant_message_content.split(":") 74 | if len(split_corrected_query_message) > 1: 75 | sql_query = split_corrected_query_message[1].strip() 76 | else: 77 | sql_query = assistant_message_content 78 | return sql_query 79 | 80 | 81 | def extract_sql_query_from_message(assistant_message_content): 82 | print(assistant_message_content) 83 | content = extract_code_from_markdown(assistant_message_content) 84 | return clean_message_content(content) 85 | 86 | 87 | def extract_code_from_markdown(assistant_message_content): 88 | matches = re.findall(r"```([\s\S]+?)```", assistant_message_content) 89 | 90 | if matches: 91 | code_str = matches[0] 92 | match = re.search(r"(?i)sql\s+(.*)", code_str, re.DOTALL) 93 | if match: 94 | code_str = match.group(1) 95 | else: 96 | code_str = assistant_message_content 97 | 98 | return code_str -------------------------------------------------------------------------------- /byod/api/app/visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/byod/api/app/visualization/__init__.py -------------------------------------------------------------------------------- /byod/api/app/visualization/routes.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, jsonify, make_response, request 2 | 3 | from .utils import get_changed_vega, get_vega_lite_spec 4 | 5 | bp = Blueprint('visualization_bp', __name__) 6 | 7 | 8 | @bp.route("/viz", methods=["POST"]) 9 | def get_visualization(): 10 | """ 11 | Get Vega-Lite spec from data 12 | """ 13 | request_body = request.get_json() 14 | data = request_body.get("data") 15 | 16 | if not data: 17 | return make_response(jsonify({"error": "`data` is missing from request body"}), 400) 18 | 19 | vega_lite_spec = get_vega_lite_spec(data) 20 | return make_response(jsonify({"vega_lite_spec": vega_lite_spec}), 200) 21 | 22 | 23 | @bp.route('/text_to_viz', methods=['POST']) 24 | def modify_visualization(): 25 | """ 26 | Change Vega-Lite spec based on a command 27 | """ 28 | request_body = request.get_json() 29 | natural_language_command = request_body.get('natural_language_command') 30 | vega_lite_spec = request_body.get('vega_lite_spec') 31 | 32 | if not natural_language_command: 33 | return make_response(jsonify({"error": "`natural_language_command` is missing from request body"}), 400) 34 | 35 | if not vega_lite_spec: 36 | return make_response(jsonify({"error": "`vega_lite_spec` is missing from request body"}), 400) 37 | 38 | changed_vega = get_changed_vega(natural_language_command, vega_lite_spec) 39 | return make_response(jsonify({"changed_vega": changed_vega}), 200) -------------------------------------------------------------------------------- /byod/api/app/visualization/utils.py: -------------------------------------------------------------------------------- 1 | import 
json 2 | import re 3 | from typing import Dict 4 | 5 | from ..utils import get_assistant_message, get_few_shot_messages 6 | 7 | 8 | def make_default_visualize_data_messages(): 9 | default_messages = [{ 10 | "role": "system", 11 | "content": ( 12 | "You are a helpful assistant for generating syntactically correct Vega-Lite specs that are best for visualizing given data." 13 | " Write responses in markdown format." 14 | " You will be given a JSON object in the following format." 15 | "\n\n" 16 | """ 17 | { 18 | "fields": [ 19 | { 20 | "name": "field_name", // name of the field 21 | "type": "nominal", // type of the field (quantitative, nominal, ordinal, temporal) 22 | "sample_value": "sample_value" // example value for the field 23 | } 24 | ], 25 | "total_rows": 100 // total number of rows in the result 26 | } 27 | """ 28 | ) 29 | }] 30 | default_messages.extend(get_few_shot_messages(mode="visualization")) 31 | return default_messages 32 | 33 | 34 | def make_visualize_data_message(): 35 | return ( 36 | "Generate a syntactically correct Vega-Lite spec to best visualize the given data." 37 | "\n\n" 38 | "{data}" 39 | ) 40 | 41 | 42 | def make_default_visualization_change_messages(): 43 | default_messages = [{ 44 | "role": "system", 45 | "content": ( 46 | "You are a helpful assistant for making changes to a Vega-Lite spec." 47 | " You generate a syntactically correct Vega-Lite spec." 48 | " You will be given a Vega-Lite spec and a command." 49 | " Write responses in markdown format." 50 | ) 51 | }] 52 | default_messages.extend(get_few_shot_messages(mode="visualization_edits")) 53 | return default_messages 54 | 55 | 56 | def make_visualization_change_message(): 57 | return ( 58 | "Make the following changes to the given Vega-Lite spec to best visualize the data."
59 | "\n\n" 60 | "changes: {command}" 61 | "\n\n" 62 | "Vega-Lite spec: {vega_lite_spec}" 63 | ) 64 | 65 | 66 | def get_vega_lite_spec(data) -> Dict: 67 | messages = make_default_visualize_data_messages() 68 | messages.append({ 69 | "role": "user", 70 | "content": make_visualize_data_message().format( 71 | data=json.dumps(data, indent=2) 72 | ) 73 | }) 74 | vega = extract_json_str_from_markdown( 75 | get_assistant_message(messages)["message"]["content"] 76 | ) 77 | return json.loads(vega) 78 | 79 | 80 | def get_changed_vega(command, vega_lite_spec) -> Dict: 81 | messages = make_default_visualization_change_messages() 82 | messages.append({ 83 | "role": "user", 84 | "content": make_visualization_change_message().format( 85 | command=command, 86 | vega_lite_spec=vega_lite_spec 87 | ) 88 | }) 89 | vega = extract_json_str_from_markdown( 90 | get_assistant_message(messages)["message"]["content"] 91 | ) 92 | return json.loads(vega) 93 | 94 | 95 | def extract_json_str_from_markdown(assistant_message_content) -> str: 96 | matches = re.findall(r"```([\s\S]+?)```", assistant_message_content) 97 | 98 | if matches: 99 | code_str = matches[0] 100 | match = re.search(r"(?i)bash\s+(.*)", code_str, re.DOTALL) 101 | if match: 102 | code_str = match.group(1) 103 | else: 104 | code_str = assistant_message_content 105 | 106 | return code_str -------------------------------------------------------------------------------- /byod/api/requirements.txt: -------------------------------------------------------------------------------- 1 | Flask===2.2.2 2 | flask-admin==1.6.1 3 | flask-sqlalchemy==3.0.2 4 | flask-cors==3.0.10 5 | Flask-Migrate==4.0.4 6 | psycopg2-binary==2.9.5 7 | gunicorn==20.1.0 8 | openai[embeddings]==0.27.2 9 | openai==0.27.2 10 | python-dotenv==1.0.0 11 | blinker==1.5 12 | joblib==1.2.0 13 | requests==2.28.1 14 | pytest==7.2.2 15 | pinecone-client==2.2.1 16 | snowflake-connector-python==3.0.2 -------------------------------------------------------------------------------- /byod/api/scripts/dev.sh: -------------------------------------------------------------------------------- 1 | FLASK_APP=app.py FLASK_DEBUG=true flask run -p 9000 -------------------------------------------------------------------------------- /byod/api/scripts/setup.sh: -------------------------------------------------------------------------------- 1 | # create python venv named venv 2 | python3.10 -m venv ./venv 3 | 4 | # activate venv 5 | source venv/bin/activate 6 | 7 | # upgrade pip 8 | pip install --upgrade pip 9 | 10 | # install project dependencies 11 | cat requirements.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -n 1 python -m pip install 12 | -------------------------------------------------------------------------------- /byod/client/.env.example: -------------------------------------------------------------------------------- 1 | API_BASE="http://localhost:9000" -------------------------------------------------------------------------------- /byod/client/.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | DS_Store 3 | venv/ 4 | __pycache__/ 5 | scratch/ -------------------------------------------------------------------------------- /byod/client/README.md: -------------------------------------------------------------------------------- 1 | # TextSQL front-end 2 | 3 | A front-end streamlit application for Text-to-SQL 4 | 5 | ## Prerequisites 6 | `python3.10` 7 | 8 | ## Required configuration for development: 9 | - base URL for TextSQL API 10 | 11 | Configure 
the above in `.env`. Refer to `.env.example`. 12 | 13 | ## Local development 14 | 15 | Initial setup 16 | ``` 17 | $ ./scripts/setup.sh 18 | ``` 19 | 20 | Activate virtual env 21 | ``` 22 | $ source ./venv/bin/activate 23 | ``` 24 | 25 | Run local instance 26 | ``` 27 | $ ./scripts/dev.sh 28 | ``` -------------------------------------------------------------------------------- /byod/client/app.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import requests 4 | import streamlit as st 5 | from config import API_BASE 6 | 7 | VEGA_LITE_TYPES_MAP = { 8 | "int": "quantitative", 9 | "float": "quantitative", 10 | "str": "nominal", 11 | "bool": "nominal", 12 | "date": "temporal", 13 | "time": "temporal", 14 | "datetime": "temporal", 15 | } 16 | 17 | 18 | def create_viz_data_dict(column_names, column_types, results): 19 | data = { 20 | "fields": [], 21 | "total_rows": len(results), 22 | } 23 | for i, column_name in enumerate(column_names): 24 | data["fields"].append({ 25 | "name": column_name, 26 | "type": VEGA_LITE_TYPES_MAP.get(column_types[i], "nominal"), 27 | }) 28 | for i in range(len(results)): 29 | # include only the first row as a sample value 30 | if i == 1: 31 | break 32 | r = results[i] 33 | for j, column_name in enumerate(column_names): 34 | data["fields"][j]["sample_value"] = r[column_name] 35 | return data 36 | 37 | 38 | def main(): 39 | st.title("Text-to-SQL") 40 | 41 | natural_language_query = st.text_input(label="Ask anything...", label_visibility="hidden", placeholder="Ask anything...") 42 | 43 | if natural_language_query: 44 | with st.spinner(text="Generating SQL..."): 45 | start_time = time.time() 46 | response = requests.post(f"{API_BASE}/text_to_sql", json={"natural_language_query": natural_language_query}) 47 | end_time = time.time() 48 | time_taken = end_time - start_time 49 | if response.status_code == 200: 50 | st.info(f"SQL generated in {time_taken:.2f} seconds") 51 | SQL = f"""```sql 52 | {response.json().get("sql_query")} 53 | ```""" 54 | st.markdown(SQL) 55 | 56 | RESULT = response.json().get("result", {}) 57 | st.table(RESULT.get("results", [])) 58 | 59 | with st.spinner(text="Generating visualization..."): 60 | start_time = time.time() 61 | response = requests.post(f"{API_BASE}/viz", 62 | json={ 63 | "data": create_viz_data_dict( 64 | RESULT.get("column_names", []), 65 | RESULT.get("column_types", []), 66 | RESULT.get("results", []), 67 | ) 68 | } 69 | ) 70 | end_time = time.time() 71 | time_taken = end_time - start_time 72 | if response.status_code == 200: 73 | st.info(f"Visualization generated in {time_taken:.2f} seconds") 74 | VEGA_LITE_SPEC = response.json().get("vega_lite_spec") 75 | st.vega_lite_chart(data=RESULT.get("results", []), spec=VEGA_LITE_SPEC) 76 | else: 77 | st.error(f"{response.status_code}: {response.reason}") 78 | st.info("Sorry, I couldn't generate a visualization. Please try again.") 79 | else: 80 | st.error(f"{response.status_code}: {response.reason}") 81 | st.info("Sorry, I couldn't answer your question/command.
Please try again.") 82 | 83 | 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /byod/client/config.py: -------------------------------------------------------------------------------- 1 | from os import getenv 2 | 3 | from dotenv import load_dotenv 4 | 5 | 6 | load_dotenv() 7 | 8 | API_BASE = getenv("API_BASE") -------------------------------------------------------------------------------- /byod/client/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.18.1 2 | requests==2.28.2 3 | python-dotenv==1.0.0 -------------------------------------------------------------------------------- /byod/client/scripts/dev.sh: -------------------------------------------------------------------------------- 1 | streamlit run app.py -------------------------------------------------------------------------------- /byod/client/scripts/setup.sh: -------------------------------------------------------------------------------- 1 | # create python venv named venv 2 | python3.10 -m venv ./venv 3 | 4 | # activate venv 5 | source venv/bin/activate 6 | 7 | # upgrade pip version 8 | pip install --upgrade pip 9 | 10 | # install project dependencies 11 | cat requirements.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -n 1 python -m pip install 12 | -------------------------------------------------------------------------------- /client/censusGPT/.env.production: -------------------------------------------------------------------------------- 1 | # `REACT_APP` prefix is required to expose to client-side 2 | REACT_APP_VERCEL_ANALYTICS_ID=$VERCEL_ANALYTICS_ID -------------------------------------------------------------------------------- /client/censusGPT/.eslintignore: -------------------------------------------------------------------------------- 1 | **/node_modules/* 2 | **/out/* 3 | **/.next/* 4 | -------------------------------------------------------------------------------- /client/censusGPT/.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "browser": true, 4 | "es6": true 5 | }, 6 | "extends": [ 7 | "eslint:recommended", 8 | "plugin:react/recommended" 9 | ], 10 | "parserOptions": { 11 | "ecmaFeatures": { 12 | "jsx": true 13 | }, 14 | "ecmaVersion": 2020, 15 | "sourceType": "module" 16 | }, 17 | "plugins": [ 18 | "react", 19 | "import" 20 | ], 21 | "rules": { 22 | "react/prop-types": 0, 23 | "no-undef": "error", 24 | "no-unused-vars": "warn", 25 | "react/no-unescaped-entities": "off", 26 | "no-irregular-whitespace": "off", 27 | "react/no-unknown-property": "off" 28 | }, 29 | "globals": { 30 | "process": "readonly", 31 | "React": "writable", 32 | "require": "readonly" 33 | } 34 | } -------------------------------------------------------------------------------- /client/censusGPT/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | /dist 14 | 15 | # misc 16 | .DS_Store 17 | .env.local 18 | .env.development.local 19 | .env.test.local 20 | .env.production.local 21 | 22 | npm-debug.log* 23 | yarn-debug.log* 24 | yarn-error.log* 25 | 26 | # Environment Variables 27 | .env 28 | .env.build 29 | .vercel 30 | 31 | .idea/ -------------------------------------------------------------------------------- /client/censusGPT/.prettierignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | # Ignore artifacts: 3 | build 4 | coverage -------------------------------------------------------------------------------- /client/censusGPT/.prettierrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "trailingComma": "es5", 3 | "tabWidth": 4, 4 | "semi": false, 5 | "singleQuote": true 6 | } -------------------------------------------------------------------------------- /client/censusGPT/README.md: -------------------------------------------------------------------------------- 1 | # textSQL Frontend 2 | ## Development 3 | 4 | In the project directory, you can run: 5 | 6 | ### `npm run dev` 7 | 8 | Runs the app with the **API pointing to the development environment**. Open [http://localhost:3000](http://localhost:3000) to view it in your browser. 9 | 10 | ### `npm start` 11 | 12 | Runs the app in development mode with the **API pointing to production**. Open [http://localhost:3000](http://localhost:3000) to view it in your browser. 13 | The page will reload when you make changes. You may also see any lint errors in the console. 14 | 15 | ### `npm test` 16 | 17 | Launches the test runner in the interactive watch mode. See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information. 18 | 19 | ### `npm run build` 20 | 21 | Builds the app for production to the `build` folder. 22 | 23 | It correctly bundles React in production mode and optimizes the build for the best performance. The build is minified and the filenames include the hashes.
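24 | 25 | ### Pointing at a local API 26 | 27 | The client resolves its backend from `REACT_APP_API_URL` (falling back to a hosted default), and `npm run dev` sets `REACT_APP_HOST_ENV=dev`, which points requests at `http://localhost:9000` (see `src/components/suggestion.js` and `src/contexts/feedContext.js`). As a minimal sketch, assuming an API instance is listening locally on port 9000, the endpoint can also be overridden directly: 28 | 29 | ``` 30 | REACT_APP_API_URL=http://localhost:9000 npm start 31 | ```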
-------------------------------------------------------------------------------- /client/censusGPT/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "private": true, 3 | "dependencies": { 4 | "@headlessui/react": "^1.7.12", 5 | "@heroicons/react": "^2.0.16", 6 | "@sentry/react": "^7.42.0", 7 | "@sentry/tracing": "^7.42.0", 8 | "@tailwindcss/forms": "^0.5.3", 9 | "@testing-library/jest-dom": "^5.16.5", 10 | "@testing-library/react": "^13.4.0", 11 | "@testing-library/user-event": "^14.4.3", 12 | "@turf/turf": "^6.5.0", 13 | "autoprefixer": "^10.4.13", 14 | "deck.gl": "^8.8.26", 15 | "mapbox-gl": "^2.13.0", 16 | "plotly.js": "^2.18.2", 17 | "postcss-cli": "^10.1.0", 18 | "posthog-js": "^1.50.4", 19 | "react": "^18.2.0", 20 | "react-dom": "^18.2.0", 21 | "react-github-btn": "^1.4.0", 22 | "react-hot-toast": "^2.4.0", 23 | "react-icons": "^4.8.0", 24 | "react-map-gl": "^7.0.21", 25 | "react-plotly.js": "^2.6.0", 26 | "react-router-dom": "^6.8.2", 27 | "react-scripts": "5.0.1", 28 | "react-syntax-highlighter": "^15.5.0", 29 | "tailwindcss": "^3.2.7", 30 | "turf": "^3.0.14", 31 | "use-debounce": "^9.0.3", 32 | "web-vitals": "^3.0.4", 33 | "worker-loader": "^3.0.8", 34 | "@turf/bbox": "^6.5.0" 35 | }, 36 | "devDependencies": { 37 | "eslint": "^8.36.0", 38 | "eslint-plugin-only-warn": "^1.1.0", 39 | "eslint-plugin-react": "^7.32.2", 40 | "prettier": "^2.8.4" 41 | }, 42 | "scripts": { 43 | "start": "react-scripts start", 44 | "build": "react-scripts build", 45 | "dev": "REACT_APP_HOST_ENV=dev react-scripts start", 46 | "lint": "eslint ./src", 47 | "lint:fix": "eslint --fix ./src", 48 | "test": "react-scripts test", 49 | "eject": "react-scripts eject", 50 | "format": "prettier --check ./src", 51 | "format:fix": "prettier --write ./src" 52 | }, 53 | "eslintConfig": { 54 | "extends": [ 55 | "react-app", 56 | "react-app/jest" 57 | ] 58 | }, 59 | "browserslist": { 60 | "production": [ 61 | ">0.2%", 62 | "not dead", 63 | "not op_mini all" 64 | ], 65 | "development": [ 66 | "last 1 chrome version", 67 | "last 1 firefox version", 68 | "last 1 safari version" 69 | ] 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /client/censusGPT/postcss.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | plugins: [require('tailwindcss'), require('autoprefixer')], 3 | }; 4 | -------------------------------------------------------------------------------- /client/censusGPT/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/client/censusGPT/public/favicon.ico -------------------------------------------------------------------------------- /client/censusGPT/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 12 | 13 | 17 | 18 | 27 | Census GPT 28 | 29 | 30 | 31 |
32 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /client/censusGPT/public/logo192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/client/censusGPT/public/logo192.png -------------------------------------------------------------------------------- /client/censusGPT/public/logo512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/client/censusGPT/public/logo512.png -------------------------------------------------------------------------------- /client/censusGPT/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "Census GPT", 3 | "name": "Census GPT", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "logo192.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "logo512.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 | -------------------------------------------------------------------------------- /client/censusGPT/public/mapbox-sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/client/censusGPT/public/mapbox-sample.png -------------------------------------------------------------------------------- /client/censusGPT/public/official_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caesarHQ/textSQL/7393e19b1458f1004d312713f5fe49968449e33c/client/censusGPT/public/official_logo.png -------------------------------------------------------------------------------- /client/censusGPT/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /client/censusGPT/src/SanFrancisco.js: -------------------------------------------------------------------------------- 1 | import App from "./App" 2 | 3 | function SanFrancisco(props) { 4 | return <App version="San Francisco" /> 5 | } 6 | export default SanFrancisco 7 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/banner.js: -------------------------------------------------------------------------------- 1 | import { XMarkIcon } from '@heroicons/react/20/solid'; 2 | 3 | export default function PromoBanner() { 4 | return ( 5 |
6 |

7 | Hey! The team behind CensusGPT is now working on Julius, your personal AI data analyst  8 |

9 | 13 | Check out Julius 14 | 15 |
16 | 20 |
21 |
22 | ); 23 | } 24 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/dataPlot.js: -------------------------------------------------------------------------------- 1 | import Plot from 'react-plotly.js' 2 | import { getPlotConfig } from '../utils/plotly-ui-config' 3 | 4 | 5 | const DataPlot = (props) => { 6 | let config = getPlotConfig(props.rows, props.cols) 7 | 8 | return ( 9 | 15 | ); 16 | } 17 | 18 | export default DataPlot -------------------------------------------------------------------------------- /client/censusGPT/src/components/disclaimer.js: -------------------------------------------------------------------------------- 1 | const Disclaimer = (props) => { 2 | const SF_disclaimer = ( 3 | <> 4 | Note: SanFranciscoGPT currently only has data for crime, 311 cases, demographics, income, population, food, parks, and housing in SF. But we are working to add more data! 5 |
6 | 311 data and crime data are sourced from the city's website for public datasets and include data from 1/1/21 to 4/7/23. 7 |
8 | This app uses SF Analysis Neighborhoods which have boundaries formed specifically to fit census tracts. 9 | 10 | ); 11 | 12 | const Census_disclaimer = ( 13 | <> 14 | Note: CensusGPT currently only has data for crime, demographics, income, education levels, and population in the USA. But we are working to add more data! 15 |
16 | Census data is sourced from the 2021 ACS (latest). Crime data is sourced from the FBI's 2019 UCR (latest). 17 | 18 | ); 19 | 20 | return ( 21 |
22 |
23 | {props.version === 'San Francisco' ? SF_disclaimer : Census_disclaimer} 24 |
25 |
26 | ) 27 | } 28 | 29 | 30 | export default Disclaimer 31 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/error.js: -------------------------------------------------------------------------------- 1 | import { XCircleIcon } from '@heroicons/react/20/solid' 2 | 3 | /** 4 | * The error message component 5 | * @param {*} props 6 | * @param {string} props.errorMessage - The error message to display 7 | * @returns {JSX.Element} – The error message component 8 | */ 9 | function ErrorMessage(props) { 10 | return ( 11 |
12 |
13 |
14 |
19 |
20 |

21 | There were errors with your submission 22 |

23 |
24 |
    25 |
  • {props.errorMessage.toString()}
  • 26 |
27 |
28 |
29 |
30 |
31 | ) 32 | } 33 | 34 | export default ErrorMessage 35 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/exampleCard.js: -------------------------------------------------------------------------------- 1 | import { capturePosthog } from '../utils/loggers/posthog' 2 | 3 | export const ExampleCard = ({ example, props }) => { 4 | return ( 5 |
{ 13 | capturePosthog('example_clicked', { 14 | natural_language_query: example.input_text, 15 | }) 16 | props.setQuery(example.input_text) 17 | props.handleClick(example.input_text) 18 | }} 19 | > 20 | {example.img && ( 21 | 29 | )} 30 |
31 | {example.emoji} 32 |
33 |

39 | {example.input_text} 40 |

41 |
42 | ) 43 | } 44 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/examples.js: -------------------------------------------------------------------------------- 1 | // Examples 2 | import { capturePosthog } from "../utils/loggers/posthog" 3 | /** 4 | * Examples component 5 | * @param {*} props – The props for the example component used to pass in callback functions 6 | * @param {*} props.posthogInstance - The posthog instance 7 | * @param {*} props.setQuery - Sets the query in the search bar 8 | * @param {*} props.handleClick - Handles the search button click 9 | * @returns {JSX.Element} – The examples component 10 | */ 11 | const Examples = (props) => { 12 | let basic_example_queries = [ 13 | 'Five cities in Florida with the highest crime', 14 | 'Richest neighborhood in Houston, TX', 15 | ] 16 | let advanced_example_queries = [ 17 | '3 neighborhoods in San Francisco that have the highest female to male ratio', 18 | 'Which area in San Francisco has the highest racial diversity and what is the percentage population of each race in that area?', 19 | // "Which 5 areas have the median income closest to the national median income?" 20 | ] 21 | 22 | if (props.version === 'San Francisco') { 23 | basic_example_queries = [ 24 | 'Show me the locations of the 10 highest rated coffee shops with at least 100 ratings.', 25 | 'Which neighborhood has the most parks?', 26 | 'Show me all the needles in SF', 27 | 'Show me all the muggings', 28 | 'Which two neighborhoods have the most homeless activity?', 29 | 'Which five neighborhoods have the most poop on the street?', 30 | ] 31 | advanced_example_queries = [ 32 | 'Which four neighborhoods had the most crime incidents involving guns or knives in 2021?', 33 | '3 neighborhoods with the highest female to male ratio', 34 | 'What are the top 5 neighborhoods with the most encampments per capita?', 35 | 'What hour of the day do most burglaries occur?', 36 | ] 37 | } 38 | 39 | return ( 40 |
41 |

Try these:

42 |
43 |

Basic

44 |
45 | {basic_example_queries.map((q) => ( 46 |
50 |
51 |

{ 54 | capturePosthog( 55 | 'example_clicked', 56 | { natural_language_query: q } 57 | ) 58 | props.setQuery(q) 59 | props.handleClick(q) 60 | }} 61 | > 62 |

{q}

67 |

68 |
69 | 74 | 75 | 76 |
77 | ))} 78 |
79 |
80 |
81 |

Advanced

82 |
83 | {advanced_example_queries.map((q) => ( 84 |
88 |
89 |

{ 92 | capturePosthog( 93 | 'example_clicked', 94 | { natural_language_query: q } 95 | ) 96 | props.setQuery(q) 97 | props.handleClick(q) 98 | }} 99 | > 100 |

{q}

105 |

106 |
107 | 112 | 113 | 114 |
115 | ))} 116 |
117 |
118 |
119 | ) 120 | } 121 | 122 | export default Examples 123 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/examplesFeed.js: -------------------------------------------------------------------------------- 1 | // Examples 2 | import { useContext } from 'react' 3 | import { FeedContext } from '../contexts/feedContext' 4 | import { ExampleCard } from './exampleCard' 5 | /** 6 | * Examples component 7 | * @param {*} props – The props for the example component used to pass in callback functions 8 | * @param {*} props.posthogInstance - The posthog instance 9 | * @param {*} props.setQuery - Sets the query in the search bar 10 | * @param {*} props.handleClick - Handles the search button click 11 | * @returns {JSX.Element} – The examples component 12 | */ 13 | const ExamplesFeed = (props) => { 14 | const { examples } = useContext(FeedContext) 15 | 16 | return ( 17 |
18 |

Try one of these examples:

19 |
20 |
27 | {examples.map((example, idx) => ( 28 | 33 | ))} 34 |
35 |
36 |
37 | ) 38 | } 39 | 40 | export default ExamplesFeed 41 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/explanationModal.js: -------------------------------------------------------------------------------- 1 | import { useEffect, useState } from 'react' 2 | import { XCircleIcon } from '@heroicons/react/20/solid' 3 | 4 | export const ExplanationModal = ({showExplanationModal, setShowExplanationModal, version}) =>{ 5 | 6 | const messageToShow = showExplanationModal == 'no_tables' ? "Sorry, we don't think we're able to help with that query yet ='(" : "Sorry, we tried to answer your question but weren't able to get a working query." 7 | 8 | return ( 9 | 10 |
11 |
12 |
13 |
23 |
24 |

25 |
{messageToShow}
26 | 27 | 28 |

29 |
30 |
31 |
32 | ) 33 | 34 | } -------------------------------------------------------------------------------- /client/censusGPT/src/components/header.js: -------------------------------------------------------------------------------- 1 | import { useState } from 'react' 2 | import { Dialog } from '@headlessui/react' 3 | import { Bars3Icon, XMarkIcon } from '@heroicons/react/24/outline' 4 | 5 | const navigation = [ 6 | { name: 'Star on Github', href: 'https://github.com/caesarhq/textSQL' }, 7 | { name: 'Discord', href: 'https://discord.gg/JZtxhZQQus' }, 8 | { name: 'How does it work?', href: '#' }, 9 | ] 10 | 11 | export default function Header(props) { 12 | const [mobileMenuOpen, setMobileMenuOpen] = useState(false) 13 | 14 | return ( 15 |
16 | 57 | 58 | {/* TODO @rahul : This is for Mobile. Get it to work */} 59 | 60 | {/* 61 |
62 | 63 | 95 |
96 | {navigation.map((item) => ( 97 | 102 | {item.name} 103 | 104 | ))} 105 |
106 |
107 |
*/} 108 |
109 | ) 110 | } 111 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/headerButtons.js: -------------------------------------------------------------------------------- 1 | import { useState } from 'react' 2 | import { 3 | BsArrowUp, 4 | BsDiscord, 5 | BsGithub, 6 | BsMoonFill, 7 | BsSunFill, 8 | BsUpload, 9 | } from 'react-icons/bs' 10 | 11 | const HeaderButton = ({ title, icon, onClick }) => ( 12 | 20 | ) 21 | 22 | export const DiscordButton = () => ( 23 | } 26 | onClick={() => window.open('https://discord.gg/JZtxhZQQus', '_blank')} 27 | /> 28 | ) 29 | 30 | // "https://github.com/caesarhq/textSQL" 31 | export const GithubButton = () => ( 32 | } 35 | onClick={() => 36 | window.open('https://github.com/caesarhq/textSQL', '_blank') 37 | } 38 | /> 39 | ) 40 | 41 | export const ContributeButton = () => ( 42 | } 45 | onClick={() => 46 | window.open('https://airtable.com/shrDKRRGyRCihWEZd', '_blank') 47 | } 48 | /> 49 | ) 50 | 51 | export const DarkModeButton = () => { 52 | const [darkMode, setDarkMode] = useState( 53 | document.documentElement.classList.contains('dark') 54 | ) 55 | return ( 56 | : } 58 | onClick={() => { 59 | if (darkMode) { 60 | document.documentElement.classList.remove('dark') 61 | setDarkMode(false) 62 | } else { 63 | document.documentElement.classList.add('dark') 64 | setDarkMode(true) 65 | } 66 | }} 67 | /> 68 | ) 69 | } 70 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/loadingSpinner.js: -------------------------------------------------------------------------------- 1 | // The loading spinner 2 | 3 | /** 4 | * Generates loading spinner component 5 | * @param {*} props – The loading spinner props 6 | * @param {boolean} props.isLoading – Whether the spinner should be displayed or not 7 | * @returns {JSX.Element} – The loading spinner component 8 | */ 9 | function LoadingSpinner(props) { 10 | return props.isLoading ? ( 11 |
12 | 28 | Loading... 29 |
30 | ) : ( 31 | <> 32 | ) 33 | } 34 | 35 | export default LoadingSpinner 36 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/results/dataVisualization.js: -------------------------------------------------------------------------------- 1 | import mapboxgl from 'mapbox-gl' 2 | import Map, { Layer, Source } from 'react-map-gl' 3 | 4 | import DataPlot from '../dataPlot' 5 | import { VizSelector } from '../vizSelector' 6 | import { FEATURE_FLAGS } from '../../featureFlags' 7 | 8 | // Mapbox UI configuration 9 | import { 10 | zipcodeFeatures, 11 | citiesFeatures, 12 | zipcodeLayerHigh, 13 | zipcodeLayerLow, 14 | citiesLayer, 15 | polygonsLayer, 16 | pointsFeatures, 17 | pointsLayer, 18 | } from '../../utils/mapbox-ui-config' 19 | 20 | // The following is required to stop "npm build" from transpiling mapbox code. 21 | // notice the exclamation point in the import. 22 | // @ts-ignore 23 | // prettier-ignore 24 | // eslint-disable-next-line import/no-webpack-loader-syntax, import/no-unresolved 25 | mapboxgl.workerClass = require('worker-loader!mapbox-gl/dist/mapbox-gl-csp-worker').default; 26 | 27 | mapboxgl.Map.prototype.toImage = function (width, height, callback) { 28 | const originalWidth = this.getCanvas().width 29 | const originalHeight = this.getCanvas().height 30 | 31 | const originalStyleWidth = this.getCanvas().style.width 32 | const originalStyleHeight = this.getCanvas().style.height 33 | 34 | this.getCanvas().width = width 35 | this.getCanvas().height = height 36 | this.getCanvas().style.width = `${width}px` 37 | this.getCanvas().style.height = `${height}px` 38 | 39 | this.once('render', () => { 40 | setTimeout(() => { 41 | const imgData = this.getCanvas().toDataURL('image/png') 42 | this.getCanvas().width = originalWidth 43 | this.getCanvas().height = originalHeight 44 | this.getCanvas().style.width = originalStyleWidth 45 | this.getCanvas().style.height = originalStyleHeight 46 | this.resize() 47 | callback(imgData) 48 | }, 100) 49 | }) 50 | 51 | this.resize() 52 | this._renderTaskQueue.run() 53 | } 54 | 55 | export const DataVisualization = ({ 56 | visualization, 57 | setVisualization, 58 | mobileTableRef, 59 | mobileSqlRef, 60 | mapRef, 61 | initialView, 62 | zipcodes, 63 | zipcodesFormatted, 64 | cities, 65 | polygonsGeoJSON, 66 | tableInfo, 67 | points, 68 | sql, 69 | props, 70 | }) => { 71 | const handleDownloadMap = async () => { 72 | const downloadButton = document.querySelector('#downloadButton') 73 | downloadButton.disabled = true 74 | 75 | const map = mapRef.current.getMap() 76 | map.toImage(250, 250, (imgData) => { 77 | const link = document.createElement('a') 78 | link.href = imgData 79 | link.download = 'map.png' 80 | link.click() 81 | 82 | // Re-enable the download button after the download has finished. 83 | downloadButton.disabled = false 84 | }) 85 | } 86 | 87 | return ( 88 |
89 |
90 | 97 |
98 |
99 | {visualization == 'map' ? ( 100 | <> 101 | 115 | 120 | 123 | 124 | 132 | 133 | 134 | 142 | 143 | 144 | 149 | 150 | 151 | 159 | 160 | 161 | 162 | {FEATURE_FLAGS.downloadButton && ( 163 |
164 | 171 |
172 | )} 173 | 174 | ) : ( 175 | // following
helps plot better scale bar widths for responsiveness 176 |
177 | 181 |
182 | )} 183 |
184 |
185 | ) 186 | } 187 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/results/resultsContainer.js: -------------------------------------------------------------------------------- 1 | import React, { useState } from 'react' 2 | import { DataVisualization } from './dataVisualization' 3 | import { SQLDisplay } from './sqlDisplay' 4 | import Table from '../table' 5 | import Examples from '../examples' 6 | import ExamplesFeed from '../examplesFeed' 7 | 8 | import { BsChevronCompactDown, BsDashLg, BsTable } from 'react-icons/bs' 9 | 10 | export const ResultsContainer = ({ 11 | visualization, 12 | setVisualization, 13 | mobileTableRef, 14 | mobileSqlRef, 15 | mapRef, 16 | initialView, 17 | zipcodes, 18 | zipcodesFormatted, 19 | cities, 20 | polygonsGeoJSON, 21 | tableInfo, 22 | points, 23 | sql, 24 | props, 25 | isStartingState, 26 | isLoading, 27 | isGetTablesLoading, 28 | setQuery, 29 | fetchBackend, 30 | useServerFeed, 31 | tableColumns, 32 | tableRows, 33 | tableNames, 34 | sqlExplanationIsOpen, 35 | setSqlExplanationIsOpen, 36 | isExplainSqlLoading, 37 | sqlExplanation, 38 | explainSql, 39 | executeSql, 40 | setSQL, 41 | title, 42 | }) => { 43 | return ( 44 |
45 | {!isStartingState && ( 46 | 62 | )} 63 | 64 |
65 |
66 | {!isLoading && sql.length !== 0 && ( 67 | <> 68 |
69 | 82 |
83 | 84 | 85 | 86 | )} 87 | {tableNames && ( 88 | 89 | )} 90 | 91 | 92 | 93 | ) 94 | } 95 | 96 | const TableNamesDisplay = ({ tableNames }) => { 97 | const [minimizeTableNames, setMinimizeTableNames] = useState(false) 98 | return ( 99 |
100 |
101 |
102 | 103 | Tables Queried 104 |
105 | 106 | 116 |
117 | 118 | {!minimizeTableNames && ( 119 |
    120 | {tableNames.map((tableName, index) => ( 121 |
  • 132 | {tableName} 133 |
  • 134 | ))} 135 |
136 | )} 137 |
138 | ) 139 | } 140 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/searchBar.js: -------------------------------------------------------------------------------- 1 | import { AiOutlineSearch } from 'react-icons/ai' 2 | import { FaTimes } from 'react-icons/fa' 3 | import Suggestion from './suggestion' 4 | 5 | 6 | const SearchButton = (props) => { 7 | return ( 8 | 15 | ) 16 | } 17 | 18 | const SearchBar = (props) => { 19 | const { 20 | value, 21 | onSearchChange, 22 | onClear, 23 | suggestedQuery, 24 | setTitle, 25 | setQuery, 26 | fetchBackend, 27 | currentSuggestionId, 28 | } = props 29 | 30 | return ( 31 |
32 |
33 |
34 | 48 | 55 |
56 | 57 |
58 |
59 | { 60 | suggestedQuery ? 61 | 68 | : 69 | null 70 | } 71 |
72 |
73 | ) 74 | } 75 | 76 | export default SearchBar 77 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/suggestion.js: -------------------------------------------------------------------------------- 1 | import { useSearchParams } from 'react-router-dom' 2 | 3 | let api_endpoint = process.env.REACT_APP_API_URL || 'https://dev-text-sql-be.onrender.com' 4 | 5 | if (process.env.REACT_APP_HOST_ENV === 'dev') { 6 | api_endpoint = 'http://localhost:9000' 7 | } 8 | 9 | const acceptSuggestion = async (id) => { 10 | const url = `${api_endpoint}/api/accept_suggestion` 11 | const body = { 12 | id, 13 | } 14 | const response = await fetch(url, { 15 | method: 'POST', 16 | headers: { 17 | 'Content-Type': 'application/json', 18 | }, 19 | body: JSON.stringify(body), 20 | }) 21 | const data = await response.json() 22 | console.log(data) 23 | return "success" 24 | } 25 | 26 | 27 | 28 | const Suggestion = (props) => { 29 | const { 30 | suggestedQuery, 31 | setTitle, 32 | setQuery, 33 | fetchBackend, 34 | currentSuggestionId, 35 | } = props 36 | 37 | const [searchParams, setSearchParams] = useSearchParams(); 38 | 39 | const handleClick = () => { 40 | acceptSuggestion(currentSuggestionId) 41 | setSearchParams(new URLSearchParams({ s: props.suggestedQuery })) 42 | setTitle(suggestedQuery) 43 | setQuery(suggestedQuery) 44 | fetchBackend(suggestedQuery, currentSuggestionId) 45 | }; 46 | 47 | const clickableQuery = ( 48 |
49 | Try this: {props.suggestedQuery} 50 |
51 | ); 52 | 53 | return ( 54 |
55 |
56 | {clickableQuery} 57 |
58 |
59 | ) 60 | } 61 | 62 | 63 | 64 | export default Suggestion -------------------------------------------------------------------------------- /client/censusGPT/src/components/table.js: -------------------------------------------------------------------------------- 1 | import { memo } from 'react' 2 | 3 | // Path: client/src/components/table.js 4 | // Custom components for Table 5 | 6 | /** 7 | * Converts the value to title case 8 | * @param {string} value – Value to be converted to title case 9 | * @returns {string} - The converted value 10 | */ 11 | const convertToTitleCase = (value) => { 12 | // Convert the table header values to title case 13 | return value 14 | .split('_') 15 | .map((x) => x.charAt(0).toUpperCase() + x.slice(1)) 16 | .join(' ') 17 | } 18 | 19 | const formatNumber = (value, col) => { 20 | if (value === null) { 21 | return 'Unknown' 22 | } 23 | // Format the number to have commas 24 | if (col == 'zip_code') { 25 | // Don't format the zip code 26 | return value 27 | } 28 | if (col.includes('date')) { 29 | // Don't format the date 30 | return value 31 | } 32 | if (col.includes('time')) { 33 | // Don't format the time 34 | return value 35 | } 36 | if (col.includes('percentage')) { 37 | let newValue = value.toString() 38 | if (newValue.includes('.')) { 39 | // Round to 2 decimal places 40 | newValue = newValue.slice(0, newValue.indexOf('.') + 3) 41 | } else { 42 | // Add commas if no decimal places (assign the result so the formatting sticks) 43 | newValue = newValue.replace(/\B(?=(\d{3})+(?!\d))/g, ',') 44 | } 45 | newValue = newValue + '%' // Add the percentage sign 46 | return newValue 47 | } 48 | 49 | if (!value.toString().includes('.')) { 50 | return value.toString().replace(/\B(?=(\d{3})+(?!\d))/g, ',') 51 | } 52 | 53 | return value 54 | } 55 | 56 | /** 57 | * Generates the table header 58 | * @param {object} props - The table columns data 59 | * @returns {JSX.Element} – The table header component 60 | */ 61 | const TableHeader = (props) => { 62 | return ( 63 |
64 | 65 | {props.columns.map((x, index) => ( 66 | 73 | ))} 74 | 75 | 76 | ) 77 | } 78 | 79 | /** 80 | * Generates the table rows 81 | * @param {object} props - The table rows data 82 | * @returns {JSX.Element} – The table rows component 83 | */ 84 | const TableRows = (props) => { 85 | 86 | return ( 87 | 88 | {props.values.slice(0, 50).map((row, i) => ( 89 | 90 | {row.map((rowValue, columnIndex) => ( 91 | 97 | ))} 98 | 99 | ))} 100 | 101 | ) 102 | } 103 | 104 | /** 105 | * Generates the Table component 106 | * @param {*} props - The table columns and rows data 107 | * @returns {JSX.Element} – The table component 108 | */ 109 | const Table = ({columns, values}) => { 110 | return ( 111 |
112 |
113 |
114 |
115 |
71 | {convertToTitleCase(x)} 72 |
95 | {formatNumber(rowValue, props.columns[columnIndex])} 96 |
116 | 119 | 123 |
124 |
125 |
126 |
127 |
128 | ) 129 | } 130 | 131 | export default memo(Table, (prevProps, nextProps) => { 132 | return prevProps.columns === nextProps.columns && prevProps.values === nextProps.values; 133 | }); 134 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/toast.js: -------------------------------------------------------------------------------- 1 | import toast from 'react-hot-toast' 2 | 3 | export const notify = (err) => { 4 | console.error(JSON.stringify(err)) 5 | toast.error(JSON.stringify(err)) 6 | } 7 | -------------------------------------------------------------------------------- /client/censusGPT/src/components/vizSelector.js: -------------------------------------------------------------------------------- 1 | export const VizSelector = (props) => { 2 | let selected = props.selected 3 | let mapClassName = 4 | 'relative mt-px inline-flex items-center rounded-t-md rounded-tr-none rounded-l-md px-3 py-2 text-sm font-semibold text-gray-900 dark:text-neutral-200 ring-1 ring-inset ring-gray-300 dark:ring-dark-300 hover:bg-gray-100 focus:z-10 ' + 5 | (selected == 'map' 6 | ? 'bg-gray-100 dark:bg-neutral-700' 7 | : 'bg-white dark:bg-neutral-600 hover:bg-gray-100 hover:dark:bg-neutral-700') 8 | let chartClassName = 9 | 'relative mt-px -ml-px inline-flex items-center rounded-r-md px-3 py-2 text-sm font-semibold text-gray-900 dark:text-neutral-200 ring-1 ring-inset ring-gray-300 dark:ring-dark-300 hover:bg-gray-100 focus:z-10 ' + 10 | (selected == 'chart' 11 | ? 'bg-gray-100 dark:bg-neutral-700' 12 | : 'bg-white dark:bg-neutral-600 hover:bg-gray-100 hover:dark:bg-neutral-700') 13 | 14 | return ( 15 | <> 16 |
17 | 18 | 27 | 28 | 37 | 38 |
39 | 40 | ) 41 | } 42 | -------------------------------------------------------------------------------- /client/censusGPT/src/contexts/feedContext.js: -------------------------------------------------------------------------------- 1 | import { createContext, useState, useEffect } from 'react' 2 | import { FEATURE_FLAGS } from '../featureFlags' 3 | export const FeedContext = createContext() 4 | 5 | let api_endpoint = 6 | process.env.REACT_APP_API_URL || 'https://dev-text-sql-be.onrender.com' 7 | 8 | if (process.env.REACT_APP_HOST_ENV === 'dev') { 9 | api_endpoint = 'http://localhost:9000' 10 | } 11 | 12 | const FeedProvider = ({ app, children }) => { 13 | const [examples, setExamples] = useState([]) 14 | const [useServerFeed, setUseServerFeed] = useState( 15 | FEATURE_FLAGS.exampleFeed 16 | ) 17 | 18 | const fetchExamples = async () => { 19 | try { 20 | const response = await fetch(`${api_endpoint}/examples/${app}`) 21 | const data = await response.json() 22 | if (data.success) { 23 | setExamples(data.examples) 24 | } else { 25 | setUseServerFeed(false) 26 | } 27 | } catch (e) { 28 | setUseServerFeed(false) 29 | } 30 | } 31 | 32 | useEffect(() => { 33 | fetchExamples() 34 | }, []) 35 | 36 | return ( 37 | 38 | {children} 39 | 40 | ) 41 | } 42 | 43 | export default FeedProvider 44 | -------------------------------------------------------------------------------- /client/censusGPT/src/css/App.css: -------------------------------------------------------------------------------- 1 | .App { 2 | text-align: center; 3 | } 4 | 5 | .App-logo { 6 | height: 40vmin; 7 | pointer-events: none; 8 | } 9 | 10 | @media (prefers-reduced-motion: no-preference) { 11 | .App-logo { 12 | animation: App-logo-spin infinite 20s linear; 13 | } 14 | } 15 | 16 | .map-container { 17 | height: 500px; 18 | } 19 | 20 | .App-header { 21 | background-color: #282c34; 22 | min-height: 100vh; 23 | display: flex; 24 | flex-direction: column; 25 | align-items: center; 26 | justify-content: center; 27 | font-size: calc(10px + 2vmin); 28 | color: white; 29 | } 30 | 31 | .App-link { 32 | color: #61dafb; 33 | } 34 | 35 | @keyframes App-logo-spin { 36 | from { 37 | transform: rotate(0deg); 38 | } 39 | to { 40 | transform: rotate(360deg); 41 | } 42 | } 43 | 44 | .form { 45 | margin-left: 23rem; 46 | margin-right: 23rem; 47 | } 48 | -------------------------------------------------------------------------------- /client/censusGPT/src/css/index.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | 5 | body { 6 | margin: 0; 7 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 8 | 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 9 | 'Helvetica Neue', sans-serif; 10 | -webkit-font-smoothing: antialiased; 11 | -moz-osx-font-smoothing: grayscale; 12 | } 13 | 14 | code { 15 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', 16 | monospace; 17 | } -------------------------------------------------------------------------------- /client/censusGPT/src/css/mapbox-gl.css: -------------------------------------------------------------------------------- 1 | .map-container { 2 | height: 500px; 3 | } 4 | -------------------------------------------------------------------------------- /client/censusGPT/src/featureFlags.js: -------------------------------------------------------------------------------- 1 | export const FEATURE_FLAGS = { 2 | exampleFeed: true, 3 | downloadButton: false, 4 | } 5 | 
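6 | // A minimal usage sketch, mirroring the existing call sites in src/contexts/feedContext.js (exampleFeed) and src/components/results/dataVisualization.js (downloadButton): 7 | // 8 | //   import { FEATURE_FLAGS } from '../featureFlags' 9 | //   if (FEATURE_FLAGS.downloadButton) { /* render the map download button */ }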
-------------------------------------------------------------------------------- /client/censusGPT/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react' 2 | import ReactDOM from 'react-dom' 3 | import './css/index.css' 4 | import App from './App' 5 | import SanFrancisco from './SanFrancisco' 6 | import reportWebVitals from './reportWebVitals' 7 | import { sendToVercelAnalytics } from './vitals' 8 | import 'mapbox-gl/dist/mapbox-gl.css' 9 | import { createBrowserRouter, RouterProvider } from 'react-router-dom' 10 | import TermsOfService from './misc/tos' 11 | import PrivacyPolicy from './misc/privacy' 12 | import FeedProvider from './contexts/feedContext' 13 | 14 | const router = createBrowserRouter([ 15 | { 16 | path: '/', 17 | element: ( 18 | 19 | 20 | 21 | ), 22 | }, 23 | { 24 | path: '/sf', 25 | element: ( 26 | 27 | 28 | 29 | ), 30 | }, 31 | { 32 | path: '/sanfrancisco', 33 | element: ( 34 | 35 | 36 | 37 | ), 38 | }, 39 | { 40 | path: '/tos', 41 | element: , 42 | }, 43 | { 44 | path: '/privacy', 45 | element: , 46 | }, 47 | ]) 48 | 49 | ReactDOM.render( 50 | 51 | 52 | , 53 | document.getElementById('root') 54 | ) 55 | 56 | reportWebVitals(sendToVercelAnalytics) 57 | -------------------------------------------------------------------------------- /client/censusGPT/src/logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /client/censusGPT/src/misc/privacy.js: -------------------------------------------------------------------------------- 1 | const PrivacyPolicy = () => { 2 | return ( 3 |
4 |

Privacy Policy

5 |

CaesarHQ operates the censusGPT.com website, which provides a service for generating SQL queries on census data.

6 |

Information Collection and Use

7 |

We collect information that you provide when using our website, including the SQL queries that you generate and any feedback that you provide. We use this information to improve our website and provide you with a better user experience.

8 |

Log Data

9 |

Like many website operators, we collect information that your browser sends whenever you visit our website. This may include information such as your computer's Internet Protocol (IP) address, browser type, browser version, the pages of our website that you visit, the time and date of your visit, the time spent on those pages and other statistics.

10 |

Cookies

11 |

We use cookies to collect information about your preferences and to personalize your experience on our website. You can instruct your browser to refuse all cookies or to indicate when a cookie is being sent. However, if you do not accept cookies, you may not be able to use some portions of our website.

12 |

Third Party Services

13 |

We may use third party services, such as Google Analytics, to collect, monitor and analyze usage of our website.

14 |

Security

15 |

The security of your personal information is important to us, but please remember that no method of transmission over the Internet, or method of electronic storage, is 100% secure. While we strive to use commercially acceptable means to protect your personal information, we cannot guarantee its absolute security.

16 |

Changes to Privacy Policy

17 |

CaesarHQ reserves the right to modify this privacy policy at any time without notice. Your continued use of our website following any such modification constitutes your agreement to be bound by the modified privacy policy.

18 |

Contact Us

19 |

If you have any questions about this privacy policy, please contact us at team@caesarhq.com.

20 |

This privacy policy is effective as of April 10, 2023.

21 |
22 | ); 23 | }; 24 | 25 | export default PrivacyPolicy; 26 | 27 | 28 | -------------------------------------------------------------------------------- /client/censusGPT/src/misc/tos.js: -------------------------------------------------------------------------------- 1 | 2 | const TermsOfService = () => { 3 | return ( 4 |
            <h1>Terms of Service</h1>

            <p>Welcome to censusGPT.com, a website owned and operated by CaesarHQ. By accessing or using our website, you agree to be bound by the following terms and conditions:</p>

            <h2>Use of Service</h2>

            <p>Our website is designed to help you generate SQL queries on census data for informational and research purposes only. You agree to use our website in accordance with all applicable laws and regulations.</p>

            <h2>Intellectual Property</h2>

            <p>Our website and its entire contents, features, and functionality (including but not limited to all information, software, text, displays, images, video, and audio, and the design, selection, and arrangement thereof) are owned by CaesarHQ or its licensors and are protected by United States and international copyright, trademark, patent, trade secret, and other intellectual property or proprietary rights laws.</p>

            <h2>Disclaimer of Warranties</h2>

            <p>Our website and its contents are provided on an "as is" and "as available" basis without any warranties of any kind, either express or implied, including but not limited to warranties of merchantability, fitness for a particular purpose, or non-infringement. CaesarHQ does not warrant that our website will be uninterrupted or error-free, that defects will be corrected, or that our website or the server that makes it available are free of viruses or other harmful components.</p>

            <h2>Limitation of Liability</h2>

            <p>In no event shall CaesarHQ, its affiliates, licensors, or service providers be liable for any direct, indirect, incidental, special, consequential, or punitive damages arising from or related to your use of our website, including but not limited to any errors or omissions in any content, or any loss or damage of any kind incurred as a result of the use of any content (or product) posted, transmitted, or otherwise made available via the website.</p>

            <h2>Indemnification</h2>

            <p>You agree to indemnify, defend, and hold harmless CaesarHQ, its affiliates, licensors, and service providers, and its and their respective officers, directors, employees, contractors, agents, licensors, suppliers, successors, and assigns from and against any claims, liabilities, damages, judgments, awards, losses, costs, expenses, or fees (including reasonable attorneys' fees) arising out of or relating to your violation of these terms of service or your use of our website.</p>

            <h2>Changes to Terms of Service</h2>

            <p>CaesarHQ reserves the right to modify these terms of service at any time without notice. Your continued use of our website following any such modification constitutes your agreement to be bound by the modified terms of service.</p>

            <h2>Governing Law</h2>

            <p>These terms of service and any dispute or claim arising out of or related to your use of our website shall be governed by and construed in accordance with the laws of the State of California, without giving effect to any principles of conflicts of law.</p>

            <h2>Contact Us</h2>

            <p>If you have any questions about these terms of service, please contact us at team@caesarhq.com.</p>
        </div>
    );
};

export default TermsOfService;

--------------------------------------------------------------------------------
/client/censusGPT/src/reportWebVitals.js:
--------------------------------------------------------------------------------

const reportWebVitals = (onPerfEntry) => {
    if (onPerfEntry && onPerfEntry instanceof Function) {
        import('web-vitals').then(
            ({ getCLS, getFID, getFCP, getLCP, getTTFB }) => {
                getCLS(onPerfEntry)
                getFID(onPerfEntry)
                getFCP(onPerfEntry)
                getLCP(onPerfEntry)
                getTTFB(onPerfEntry)
            }
        )
    }
}

export default reportWebVitals

--------------------------------------------------------------------------------
/client/censusGPT/src/setupTests.js:
--------------------------------------------------------------------------------

// jest-dom adds custom jest matchers for asserting on DOM nodes.
// allows you to do things like:
// expect(element).toHaveTextContent(/react/i)
// learn more: https://github.com/testing-library/jest-dom
import '@testing-library/jest-dom'

--------------------------------------------------------------------------------
/client/censusGPT/src/utils/loggers/posthog.js:
--------------------------------------------------------------------------------

import posthog from 'posthog-js'

const POSTHOG_KEY = process.env.REACT_APP_POSTHOG_KEY

// Only initialize PostHog when a project key is configured
if (POSTHOG_KEY) {
    posthog.init(POSTHOG_KEY, {
        api_host: 'https://app.posthog.com',
    })
}

// No-op when PostHog is not configured, so callers never need to check
export const capturePosthog = (eventName, properties) => {
    if (POSTHOG_KEY) {
        posthog.capture(eventName, properties)
    }
}

--------------------------------------------------------------------------------
/client/censusGPT/src/utils/loggers/sentry.js:
--------------------------------------------------------------------------------

import * as Sentry from '@sentry/react'
import { BrowserTracing } from '@sentry/tracing'

const SENTRY_ROUTE = process.env.REACT_APP_SENTRY_ROUTE

// Only initialize Sentry when a DSN is configured
if (SENTRY_ROUTE) {
    Sentry.init({
        dsn: SENTRY_ROUTE,
        integrations: [new BrowserTracing()],
        tracesSampleRate: 1.0,
    })
}

export const logSentryError = (queryContext, err) => {
    if (SENTRY_ROUTE) {
        console.log('LOGGING TO SENTRY')
        Sentry.setContext('queryContext', queryContext)
        Sentry.captureException(err)
    }
}
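
// --- Usage sketch (illustrative, not part of the original file) ---
// A minimal example, assuming a hypothetical async query handler, of how the
// capturePosthog and logSentryError wrappers are meant to be called together.
// `runQuery` and the queryContext shape are assumptions, not project code.
//
// import { capturePosthog } from './posthog'
// import { logSentryError } from './sentry'
//
// async function handleSearch(userQuery) {
//     capturePosthog('query_submitted', { query: userQuery })
//     try {
//         return await runQuery(userQuery) // hypothetical API helper
//     } catch (err) {
//         logSentryError({ query: userQuery }, err)
//     }
// }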
--------------------------------------------------------------------------------
/client/censusGPT/src/utils/mapbox-ui-config.js:
--------------------------------------------------------------------------------

/**
 * This file contains the UI configuration for the Mapbox UI.
 */

export const zipcodeFeatures = (zipcodes) => {
    return zipcodes.map((z) => {
        return {
            type: 'Feature',
            geometry: {
                type: 'Point',
                coordinates: [z.long, z.lat],
            },
        }
    })
}

export const citiesFeatures = (cities) => {
    return cities.map((c) => {
        return {
            type: 'Feature',
            geometry: {
                type: 'Point',
                coordinates: [c.long, c.lat],
            },
        }
    })
}

export const zipcodeLayerLow = (zipcodesFormatted) => {
    return {
        id: 'zips-kml',
        type: 'fill',
        source: 'zips-kml',
        minzoom: 5,
        layout: {
            visibility: 'visible',
        },
        paint: {
            'fill-outline-color': 'black',
            'fill-opacity': 0.9,
            'fill-color': '#006AF9',
        },
        'source-layer': 'Layer_0',
        filter: [
            'in',
            ['get', 'Name'],
            ['literal', zipcodesFormatted], // Zip code in the feature is formatted like this: 94105
        ],
    }
}

export const zipcodeLayerHigh = {
    id: 'Zip',
    type: 'circle',
    layout: {
        visibility: 'visible',
    },
    maxzoom: 8,
    paint: {
        'circle-radius': 10,
        'circle-color': '#006AF9',
        'circle-opacity': 1,
    },
}

export const citiesLayer = {
    id: 'cities',
    type: 'circle',
    layout: {
        visibility: 'visible',
    },
    paint: {
        'circle-radius': 18,
        'circle-color': '#006AF9',
        'circle-opacity': 0.8,
    },
}

export const polygonsLayer = {
    id: 'polygons',
    type: 'fill',
    source: 'polygons',
    layout: {
        visibility: 'visible',
    },
    paint: {
        'fill-outline-color': 'black',
        'fill-color': '#006AF9',
        'fill-opacity': 0.8,
    },
}

export const pointsFeatures = (points) => {
    return points.map((p) => {
        return {
            type: 'Feature',
            geometry: {
                type: 'Point',
                coordinates: [p.long, p.lat],
            },
        }
    })
}

export const pointsLayer = {
    id: 'points',
    type: 'circle',
    layout: {
        visibility: 'visible',
    },
    paint: {
        'circle-radius': 5,
        'circle-color': '#006AF9',
        'circle-opacity': 0.8,
    },
}
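
// --- Usage sketch (illustrative, not part of the original file) ---
// A minimal example, assuming a react-map-gl <Source>/<Layer> setup, of how
// these feature builders and layer configs might be wired together. The
// component structure below is an assumption, not the project's actual code.
//
// import { Source, Layer } from 'react-map-gl'
// import { zipcodeFeatures, zipcodeLayerHigh } from './mapbox-ui-config'
//
// const geojson = { type: 'FeatureCollection', features: zipcodeFeatures(zipcodes) }
//
// <Source id="Zip" type="geojson" data={geojson}>
//     <Layer {...zipcodeLayerHigh} />
// </Source>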
--------------------------------------------------------------------------------
/client/censusGPT/src/utils/plotly-ui-config.js:
--------------------------------------------------------------------------------

/**
 * This file contains the UI configuration for the Plotly UI.
 */

const isGeoColumn = (columnName) => {
    return columnName == 'zip_code' || columnName == 'city' || columnName == 'state'
}

export const getPlotConfig = (rows, cols) => {
    let data = []
    let layout = {}

    if (rows.length == 0 || cols.length == 0) {
        return {}
    } else if (cols.length == 2) {
        // 2 cols, N rows ==> Bar chart
        // Col 0 is X axis, Col 1 is Y axis
        // Example query: "Top 5 cities in CA with the highest crime and what is the total crime in each of those cities"

        data = [
            {
                x: rows.map(x => '\b' + x[0]), // convert to string, otherwise plotly treats 94102 as 94.1k
                y: rows.map(x => x[1]),
                type: 'bar',
                marker: { color: '#006AF9' }
            }
        ];

        layout = {
            xaxis: { title: cols[0] },
            yaxis: { title: cols[1] },
        }
    } else if (rows.length == 1 && cols.length >= 1) {
        // N cols, 1 row ==> Bar chart
        // columns is X axis, row 1 is Y axis
        // Example query: "What is the distribution of different categories of crimes in Dallas, TX"

        data = [
            {
                x: isGeoColumn(cols[0]) ? cols.slice(1) : cols,
                y: isGeoColumn(cols[0]) ? rows[0].slice(1) : rows[0],
                type: 'bar',
                marker: { color: '#006AF9' }
            }
        ];
    } else {
        // N cols, N rows ==> Stacked chart
        // column 0 is X axis, columns 1 to N are Y axis
        // Example query: "What is the percentage population of asian, black and hispanic people in all zipcodes in san francisco"

        for (let i = 1; i < cols.length; i++) {
            // if the column is not a number, don't plot it
            if (typeof rows[0][i] !== 'number') {
                continue
            }

            data.push({
                x: rows.map(x => '\b' + x[0]), // convert to string, otherwise plotly treats 94102 as 94.1k
                y: rows.map(x => x[i]),
                name: cols[i],
                type: 'bar'
            })
        }

        layout = {
            barmode: 'stack',
            xaxis: { title: cols[0] },
        }
    }

    // Match Plotly's colors to the current light/dark theme
    layout = document.documentElement.classList.contains('dark') ? {
        ...layout,
        font: { color: '#fff' },
        yaxis: { gridcolor: '#444' }
    } : {
        ...layout,
        font: { color: '#000' }
    }

    return { data, layout }
}
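
// --- Usage sketch (illustrative, not part of the original file) ---
// A minimal example, assuming react-plotly.js, of rendering the config that
// getPlotConfig returns; the rows/cols values below are hypothetical.
//
// import Plot from 'react-plotly.js'
// import { getPlotConfig } from './plotly-ui-config'
//
// const rows = [['94105', 42000], ['94110', 17000]] // hypothetical query rows
// const cols = ['zip_code', 'population']
// const { data, layout } = getPlotConfig(rows, cols)
// <Plot data={data} layout={layout} />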
--------------------------------------------------------------------------------
/client/censusGPT/src/utils/user.js:
--------------------------------------------------------------------------------

import { v4 as uuidv4 } from 'uuid';

export const getUserId = () => {
    const localStorageKey = 'census_user_id';
    let userId = localStorage.getItem(localStorageKey);

    if (!userId) {
        // Generate a unique ID for the user
        userId = `${uuidv4()}_${new Date().getTime()}`;

        // Save the user ID in local storage
        localStorage.setItem(localStorageKey, userId);
    }

    return userId;
}

--------------------------------------------------------------------------------
/client/censusGPT/src/utils/utils.js:
--------------------------------------------------------------------------------

// Utils for the client side

/**
 * Formatted zipcodes for Mapbox
 * @param {*} zips
 * @returns {string[]} – The formatted zipcodes
 */
export const getZipcodesMapboxFormatted = (zips) => {
    return zips.map((x) => String(x['zipcode']))
}

/**
 * Gets each zipcode's latitude/longitude from the query search results
 * @param {*} result – The search results
 * @returns {object[]} – The formatted zipcodes
 */
export const getZipcodes = (result) => {
    let zipcode_index = result.column_names.indexOf('zip_code')
    if (zipcode_index == -1 || !result.results) return []

    return result.results.map((x) => {
        return { zipcode: x['zip_code'], lat: x['lat'], long: x['long'] }
    })
}

/**
 * Gets each city's latitude/longitude from the query search results
 * @param {*} result – The search results
 * @returns {object[]} – The cities
 */
export const getCities = (result) => {
    let city_index = result.column_names.indexOf('city')
    if (city_index == -1 || !result.results) return []

    return result.results.map((x) => {
        return { city: x['city'], lat: x['lat'], long: x['long'] }
    })
}
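
// --- Usage sketch (illustrative, not part of the original file) ---
// A hypothetical API result, showing what getZipcodes extracts from it:
//
// const result = {
//     column_names: ['zip_code', 'population', 'lat', 'long'],
//     results: [{ zip_code: '94105', population: 42000, lat: 37.79, long: -122.39 }],
// }
// getZipcodes(result)
// // => [{ zipcode: '94105', lat: 37.79, long: -122.39 }]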
--------------------------------------------------------------------------------
/client/censusGPT/src/vitals.js:
--------------------------------------------------------------------------------

const vitalsUrl = 'https://vitals.vercel-analytics.com/v1/vitals'

function getConnectionSpeed() {
    return 'connection' in navigator &&
        navigator['connection'] &&
        'effectiveType' in navigator['connection']
        ? navigator['connection']['effectiveType']
        : ''
}

export function sendToVercelAnalytics(metric) {
    const analyticsId = process.env.REACT_APP_VERCEL_ANALYTICS_ID
    if (!analyticsId) {
        return
    }

    const body = {
        dsn: analyticsId,
        id: metric.id,
        page: window.location.pathname,
        href: window.location.href,
        event_name: metric.name,
        value: metric.value.toString(),
        speed: getConnectionSpeed(),
    }

    const blob = new Blob([new URLSearchParams(body).toString()], {
        // This content type is necessary for `sendBeacon`
        type: 'application/x-www-form-urlencoded',
    })
    if (navigator.sendBeacon) {
        navigator.sendBeacon(vitalsUrl, blob)
    } else {
        fetch(vitalsUrl, {
            body: blob,
            method: 'POST',
            credentials: 'omit',
            keepalive: true,
        })
    }
}

--------------------------------------------------------------------------------
/client/censusGPT/tailwind.config.js:
--------------------------------------------------------------------------------

/** @type {import('tailwindcss').Config} */
module.exports = {
    content: ["./src/**/*.{html,js}", "./src/*.{html,js}"],
    darkMode: 'class',
    theme: {
        extend: {
            colors: {
                'dark': {
                    900: '#1F1F1F',
                    800: '#292929',
                    300: '#777777'
                }
            },
        },
    },
    plugins: [require('@tailwindcss/forms')],
}

--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------

## Data

This directory holds all the .CSV, .XLSX, and other data files that are (or will be) included in [CensusGPT](https://censusgpt.com)'s dataset.

If you would like to contribute data, please create a PR with your file and include a short description of the data you're contributing.

--------------------------------------------------------------------------------
/license.md:
--------------------------------------------------------------------------------

MIT License

Copyright (c) 2023

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------