├── tests ├── __init__.py └── unit_tests │ ├── __init__.py │ ├── databases │ ├── __init__.py │ └── test_sqlite.py │ ├── prompts │ ├── __init__.py │ ├── test_data_catalog.py │ └── test_general.py │ ├── data_catalog │ ├── __init__.py │ └── test_sample_datacatalog.py │ ├── test_auto_analyst.py │ └── test_analysis.py ├── auto_analyst ├── __init__.py ├── llms │ ├── __init__.py │ ├── base.py │ └── openai.py ├── data_catalog │ ├── __init__.py │ ├── base.py │ └── sample_datacatalog.py ├── databases │ ├── __init__.py │ ├── sample_data │ │ ├── chinook.sqlite │ │ └── chinook_tables.csv │ ├── base.py │ ├── bigquery.py │ ├── redshift.py │ ├── sqlite.py │ └── postgres.py ├── forms.py ├── templates │ ├── partials │ │ └── _top_bar.html │ ├── home.html │ ├── config.html │ └── base.html ├── prompts │ ├── __init__.py │ ├── data_catalog.py │ └── general.py ├── config.json ├── config_parser.py ├── requirements.txt ├── app.py ├── analysis.py ├── static │ └── script.js └── auto_analyst.py ├── mypy.ini ├── misc ├── auto_analyst.png └── youtube_thumbnail.png ├── pytest.ini ├── .gitignore ├── setup.py ├── .flake8 ├── .pre-commit-config.yaml ├── LICENSE ├── CONTRIBUTING.md └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /auto_analyst/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /auto_analyst/llms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /auto_analyst/data_catalog/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /auto_analyst/databases/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit_tests/databases/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit_tests/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit_tests/data_catalog/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | files = . 
3 | ignore_missing_imports = true -------------------------------------------------------------------------------- /misc/auto_analyst.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aadityaubhat/auto-analyst/HEAD/misc/auto_analyst.png -------------------------------------------------------------------------------- /misc/youtube_thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aadityaubhat/auto-analyst/HEAD/misc/youtube_thumbnail.png -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = "tests" 3 | markers = 4 | llm: marks tests that use LLMs (deselect with '-m "not llm"') -------------------------------------------------------------------------------- /auto_analyst/databases/sample_data/chinook.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aadityaubhat/auto-analyst/HEAD/auto_analyst/databases/sample_data/chinook.sqlite -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | auto_analyst/config.py 2 | venv 3 | archive.py 4 | data.csv 5 | .vscode 6 | .pytest_cache 7 | .mypy_cache 8 | *.egg-info 9 | **/__pycache__/ 10 | logs/ 11 | auto_analyst/config.json -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="auto_analyst", 5 | version="0.0.1", 6 | packages=find_packages(), 7 | install_requires=[], 8 | entry_points={}, 9 | ) 10 | 
class ConfigForm(FlaskForm):
    """Form shown on the config page for uploading a replacement config file.

    The upload field is mandatory and only accepts ``.json`` files; the
    validator message is surfaced to the user when a non-JSON file is chosen.
    """

    config_file = FileField(
        "Upload a new config file",
        validators=[
            FileRequired(),
            FileAllowed(["json"], "JSON Files only"),
        ],
    )
class BaseDatabase(ABC):
    """Abstract interface every database backend must implement.

    Concrete subclasses (SQLite, Postgres, Redshift, BigQuery, ...) supply
    the connectivity; the rest of the application programs against this API.
    """

    @abstractmethod
    def run_query(self, query: str) -> pd.DataFrame:
        """Execute ``query`` and return the result set as a DataFrame."""
        raise NotImplementedError

    @abstractmethod
    def get_tables(self) -> List:
        """Return the tables available in the database."""
        raise NotImplementedError

    @abstractmethod
    def get_schema(self, table_name: str) -> pd.DataFrame:
        """Return the schema of ``table_name`` as a DataFrame."""
        raise NotImplementedError
5 | 6 |
7 | 8 |
9 | 10 |
11 |
12 | 13 | 14 |
15 |
16 |
class BaseLLM(ABC):
    """Abstract interface for a large-language-model backend."""

    @abstractmethod
    def get_reply(self, prompt: Optional[str], **kwargs) -> str:
        """Return the model's free-form reply for ``prompt``.

        Args:
            prompt (Optional[str]): Prompt to be used for generating reply
        Returns:
            str: Reply from LLM
        """
        raise NotImplementedError

    def get_code(self, prompt: Optional[str], **kwargs) -> str:
        """Return code generated by the model for ``prompt``.

        Optional hook: concrete backends may override it; the default
        implementation signals that code generation is unsupported.

        Args:
            prompt (Optional[str]): Prompt to be used for generating code
        Returns:
            str: Code from LLM
        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError
import jinja2
import pandas as pd

environment = jinja2.Environment()

system_prompt = "You are a helpful assistant that helps find the accurate tables for a given question. Answer as concisely as possible."

# Template rendered with the catalog's table listing plus the user question.
# NOTE: the exact wording (including the ". . . tableN" format line) is
# asserted byte-for-byte by the unit tests — do not reflow it.
source_tables_template = environment.from_string(
    """
From the list of following tables:
{{ tables_df.to_string(index=False) }}

Select the appropriate source tables for the following question:
{{ question }}

Answer in the following format:
table1, table2, table3, . . . tableN

If no appropriate tables are found, say 'No Tables Found'"""
)


def render_source_tables_prompt(question: str, tables_df: pd.DataFrame) -> str:
    """Render prompt to select source tables for a given question.

    Args:
        question (str): Question to be answered
        tables_df (pd.DataFrame): Dataframe containing list of all tables

    Returns:
        str: Rendered prompt listing the candidate tables and the question
    """
    return source_tables_template.render(question=question, tables_df=tables_df)
tableN 28 | 29 | If no appropriate tables are found, say 'No Tables Found'""" 30 | ) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Aaditya Bhat 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /auto_analyst/templates/config.html: -------------------------------------------------------------------------------- 1 | {% extends 'base.html' %} 2 | 3 | {% block content %} 4 |
5 |
6 | {% if not config_updated %} 7 |

Update Configuration

8 |
9 | {{ form.hidden_tag() }} 10 |
11 | {{ form.config_file.label(class="sr-only") }} 12 | {{ form.config_file(class="form-control-file", id="config_file", 13 | onchange="document.getElementById('update_button').style.display = 'inline-block'") }} 14 |
15 | 16 |
17 | {% else %} 18 |

Update Configuration

19 | New Config loaded 20 | {% endif %} 21 |
22 |
23 |

Current Configuration

24 |
{{ content | tojson(4) }}
25 |
26 | 27 | 28 |
class BigQueryDatabase(BaseDatabase):
    """Implementation of BaseDatabase for BigQuery"""

    def __init__(self, project_id: str, credentials_path: str):
        """Initialize BigQueryDatabase with project_id and path to service account credentials"""
        self.project_id = project_id
        # Build credentials from the service-account JSON key file on disk.
        self.credentials = service_account.Credentials.from_service_account_file(
            credentials_path
        )
        self.client = bigquery.Client(
            credentials=self.credentials, project=self.project_id
        )

    def run_query(self, query: str) -> pd.DataFrame:
        """Run query using BigQuery"""
        query_job = self.client.query(query)  # API request
        # Blocks until the job completes, then materializes the result set.
        return query_job.to_dataframe()

    def get_tables(self) -> List[str]:
        """List tables in BigQuery"""
        # NOTE(review): Client.list_tables expects a dataset reference/ID,
        # but a *project* id is passed here — confirm against the
        # google-cloud-bigquery API docs; this looks like it would fail
        # at runtime.
        tables = list(self.client.list_tables(self.project_id))
        return [table.table_id for table in tables]

    def get_schema(self, table_name: str) -> List[bigquery.SchemaField]:
        """Get schema for the given table in BigQuery"""
        # NOTE(review): returns a list of bigquery.SchemaField objects,
        # whereas BaseDatabase.get_schema is declared to return a
        # pd.DataFrame — callers expecting a DataFrame will break; confirm
        # which contract is intended.
        table = self.client.get_table(table_name)  # API request
        return table.schema
class TestSQLLite(unittest.TestCase):
    """Exercises the SQLLite wrapper against the bundled Chinook database."""

    def setUp(self):
        # SQLLite keeps its connection on flask.g, so every test needs an
        # active application context.
        self.app = Flask(__name__)
        self.ctx = self.app.app_context()
        self.ctx.push()
        self.db = SQLLite()

    def tearDown(self):
        self.ctx.pop()
        with self.app.app_context():
            self.db.close_connection()

    def test_list_tables(self):
        expected_tables = [
            "Album",
            "Artist",
            "Customer",
            "Employee",
            "Genre",
            "Invoice",
            "InvoiceLine",
            "MediaType",
            "Playlist",
            "PlaylistTrack",
            "Track",
        ]
        with self.app.app_context():
            tables = self.db.get_tables()
            self.assertEqual(len(tables), 11)
            self.assertEqual(sorted(tables["table_name"].tolist()), expected_tables)

    def test_get_schema(self):
        with self.app.app_context():
            album_schema = self.db.get_schema("Album")
            self.assertEqual(len(album_schema), 3)
            self.assertEqual(
                sorted(album_schema["name"].tolist()),
                ["AlbumId", "ArtistId", "Title"],
            )

    def test_run_query(self):
        with self.app.app_context():
            invoices = self.db.run_query("select * from Invoice")
            self.assertEqual(len(invoices), 412)


# Allow running this file directly.
if __name__ == "__main__":
    unittest.main()
class SQLLite(BaseDatabase):
    """SQLite-backed implementation of BaseDatabase.

    The connection is cached on Flask's per-request context object ``g``,
    so each request lazily opens (and later closes) its own connection.

    Attributes:
        db_path (str): Path to the SQLLite database file.
    """

    def __init__(self, db_path=None):
        """Create the wrapper.

        Args:
            db_path (str): Path to the SQLLite database; defaults to the
                bundled Chinook sample database when omitted.
        """
        self.db_path = (
            "auto_analyst/databases/sample_data/chinook.sqlite"
            if db_path is None
            else db_path
        )

    def get_cursor(self) -> sqlite3.Cursor:
        """Return a cursor, opening the connection on first use in this context.

        Returns:
            sqlite3.Cursor: Cursor bound to the per-context connection.
        """
        if "db" not in g:
            g.db = sqlite3.connect(self.db_path)
            g.cursor = g.db.cursor()
        return g.cursor

    def close_connection(self) -> None:
        """Close and discard this context's connection, if one was opened."""
        connection = g.pop("db", None)
        if connection is not None:
            connection.close()

    def run_query(self, query: str) -> pd.DataFrame:
        """Execute ``query`` and return the results.

        Args:
            query (str): SQL to execute.
        Returns:
            pd.DataFrame: Dataframe containing the results of the query.
        """
        # get_cursor() is idempotent: it only connects when g.db is absent.
        self.get_cursor()
        return pd.read_sql_query(query, g.db)

    def get_tables(self) -> pd.DataFrame:
        """Return all table names as a single-column DataFrame."""
        return self.run_query(
            "select name as table_name from sqlite_master where type='table'"
        )

    def get_schema(self, table_name: str):
        """Return the schema of ``table_name``.

        Args:
            table_name (str): Name of the table.
        Returns:
            pd.DataFrame: Dataframe containing table schema.
        """
        return self.run_query(f"PRAGMA table_info({table_name})")
class PostgresDatabase(BaseDatabase):
    """Production level implementation of BaseDatabase for Postgres."""

    def __init__(self, dbname: str, user: str, password: str, host: str, port: str):
        """Initialize PostgresDatabase with connection details."""
        self.dbname = dbname
        self.user = user
        self.password = password
        self.host = host
        self.port = port

    def _connect(self):
        """Open a new connection; callers are responsible for closing it."""
        return psycopg2.connect(
            dbname=self.dbname,
            user=self.user,
            password=self.password,
            host=self.host,
            port=self.port,
        )

    def run_query(self, query: str) -> pd.DataFrame:
        """Run ``query`` against Postgres and return the result set."""
        with closing(self._connect()) as conn:
            return pd.read_sql_query(query, conn)

    def get_tables(self) -> List[str]:
        """List tables in the ``public`` schema."""
        query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public'
        """
        with closing(self._connect()) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(query)
                return [table[0] for table in cursor.fetchall()]

    def get_schema(self, table_name: str) -> pd.DataFrame:
        """Get schema (column names and types) for the given table.

        Args:
            table_name (str): Name of the table to describe.
        Returns:
            pd.DataFrame: Columns ``column_name`` and ``data_type``.
        """
        # BUG FIX: the previous version interpolated table_name via
        # sql.Identifier, which renders it as a double-quoted identifier
        # ("mytable"); Postgres then parses it as a column reference and
        # the query errors out. The value is data in a comparison, so it
        # must be bound as a query parameter instead.
        query = """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_name = %s
        """
        with closing(self._connect()) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(query, (table_name,))
                return pd.DataFrame(
                    cursor.fetchall(), columns=["column_name", "data_type"]
                )
Update the `auto_analyst/config.json` file with your: 21 | - OpenAI API Key. 22 | - Flask secret key. 23 | - Database credentials. 24 | - Data catalog connection details. 25 | 2. Run `python -m auto_analyst.app`. 26 | 3. Open a web browser and go to `localhost:5000`. 27 | 28 | Once AutoAnalyst is running, you should see something similar to this in your web browser: 29 | 30 | ![Screenshot of the demo](https://github.com/aadityaubhat/auto-analyst/blob/main/misc/auto_analyst.png) 31 | 32 | ## Demo 33 | 34 | You can watch a comprehensive demo of the AutoAnalyst suite 35 | [![IMAGE ALT TEXT HERE](https://github.com/aadityaubhat/auto-analyst/blob/main/misc/youtube_thumbnail.png)](https://www.youtube.com/watch?v=fp1nv-GdKic) 36 | 37 | ## Contributing 38 | 39 | Contributions to the AutoAnalyst project are welcome and appreciated. If you're interested in contributing, please see our [CONTRIBUTING.md](CONTRIBUTING.md) guide. You'll find information on how to get started, our coding standards, and how to submit your changes for review. 40 | 41 | For major changes, please open an issue first to discuss what you would like to change. We encourage you to use the issues page for bug reports and feature requests. 42 | 43 | Together, we can build a robust, intuitive, and versatile self-service analytics suite! 
@pytest.fixture(scope="module")
def analysis():
    """Module-scoped Analysis shared by every test below (tests mutate it in order)."""
    analysis = Analysis("What is the total revenue?", uuid4())
    yield analysis


def test_metadata(analysis):
    """Metadata assignments accumulate: the setter merges, it does not replace."""
    assert analysis.metadata == {}
    analysis.metadata = {"source_data": ["Invoice", "InvoiceLine"]}
    analysis.metadata = {"table_schema": {"Invoice": ["InvoiceId", "Total"]}}
    # Both keys survive: the second assignment merged into the first.
    assert analysis.metadata == {
        "source_data": ["Invoice", "InvoiceLine"],
        "table_schema": {"Invoice": ["InvoiceId", "Total"]},
    }


def test_query(analysis):
    """Query starts unset and round-trips through the property setter."""
    assert analysis.query is None
    analysis.query = "select * from Invoice"
    assert analysis.query == "select * from Invoice"


def test_result_data(analysis):
    """result_data starts unset and stores the assigned DataFrame unchanged."""
    assert analysis.result_data is None
    analysis.result_data = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
    assert analysis.result_data.equals(
        pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
    )


def test_result_plot(analysis):
    """result_plot starts unset and stores the assigned Plotly figure."""
    assert analysis.result_plot is None
    analysis.result_plot = px.bar(
        x=["giraffes", "orangutans", "monkeys"], y=[20, 14, 23]
    )
    assert analysis.result_plot == px.bar(
        x=["giraffes", "orangutans", "monkeys"], y=[20, 14, 23]
    )
@dataclass
class Column:
    """A single column of a catalog table.

    Attributes:
        name (str): Name of the column
        datatype (str): Datatype of the column
        description (Optional[str]): Free-text description, if known
        cardinality (Optional[int]): Cardinality of the column, if known
        unique_values (Optional[int]): Count of distinct values, if known
    """

    name: str
    datatype: str
    description: Optional[str] = None
    cardinality: Optional[int] = None
    unique_values: Optional[int] = None

    def to_str(self) -> str:
        """Render this column as a ``Column(field=value, ...)`` descriptor string."""
        rendered = ", ".join(
            "{}={}".format(attr, getattr(self, attr))
            for attr in ("name", "datatype", "description", "cardinality", "unique_values")
        )
        return "Column({})".format(rendered)
the table 46 | columns (List[Column]): List of columns in the table 47 | """ 48 | 49 | name: str 50 | description: Optional[str] = field(default=None) 51 | columns: List[Column] = field(default_factory=list) 52 | 53 | def to_str(self) -> str: 54 | return f"{self.name} : {self.description}" 55 | 56 | 57 | class BaseDataCatalog(ABC): 58 | """Abstract base class responsible for defining Data Catalog""" 59 | 60 | @abstractmethod 61 | def get_table_schemas(self, table_list: List[str]) -> Dict[str, str]: 62 | """Get Table Schemas for a given list of tables 63 | Args: 64 | table_list (List[str]): List of table names 65 | Returns: 66 | Dict[str, str]: Dictionary of table schemas {table_name: table_schema}""" 67 | raise NotImplementedError 68 | 69 | @abstractmethod 70 | def get_source_tables(self, question: str) -> List[Optional[Table]]: 71 | """Get source tables for the given question 72 | Args: 73 | question (str): Question to be answered 74 | Returns: 75 | List[Dict]: List of tables [{table_name: str, table_description: str}, ...] 
import json
from auto_analyst.llms.openai import OpenAILLM, Model
from auto_analyst.databases.sqlite import SQLLite
from auto_analyst.data_catalog.sample_datacatalog import SampleDataCatalog

# Single source of truth for the config location (both readers use it).
CONFIG_PATH = "auto_analyst/config.json"


def parse_config():
    """Parse config.json and build the application's core objects.

    Returns:
        tuple: (database, data_catalog, driver_llm, auto_analyst_settings)

    Raises:
        Exception: if any required section is missing or has an
            unsupported ``type``.
    """
    with open(CONFIG_PATH) as f:
        config = json.load(f)

    # --- LLM section -------------------------------------------------
    llm_config = config.get("llms", {})
    if not llm_config:
        raise Exception("LLM config is required")

    driverllm_config = llm_config.get("driverllm", "")
    if not driverllm_config:
        raise Exception("Driver LLM config is required")

    # Renamed from the original's `data_catalog_config`, which shadowed the
    # top-level data-catalog section parsed further below.
    data_catalog_llm_config = llm_config.get("data_catalog_llm", "")
    if not data_catalog_llm_config:
        raise Exception("Data Catalog LLM config is required")

    if driverllm_config.get("type", "") == "openai":
        driver_llm = OpenAILLM(
            api_key=driverllm_config.get("api_key", ""),
            model=Model[driverllm_config.get("model", "")],
        )
    else:
        raise Exception("Invalid Driver LLM type")

    if data_catalog_llm_config.get("type", "") == "openai":
        data_catalog_llm = OpenAILLM(
            api_key=data_catalog_llm_config.get("api_key", ""),
            model=Model[data_catalog_llm_config.get("model", "")],
        )
    else:
        raise Exception("Invalid Data Catalog LLM type")

    # --- Database section --------------------------------------------
    database_config = config.get("database", "")
    if not database_config:
        raise Exception("Database config is required")

    if database_config.get("type", "") == "sqlite":
        database = SQLLite(database_config.get("database_path", ""))
    else:
        raise Exception("Invalid database type")

    # --- Data catalog section ----------------------------------------
    data_catalog_config = config.get("data_catalog", "")
    if not data_catalog_config:
        raise Exception("Data Catalog config is required")

    if data_catalog_config.get("type", "") == "sample":
        data_catalog = SampleDataCatalog(data_catalog_llm)
    else:
        # BUG FIX: previously there was no else branch here, so an unknown
        # catalog type left `data_catalog` unbound and the function crashed
        # with a NameError at the return statement instead of a clear error.
        raise Exception("Invalid data catalog type")

    auto_analyst_settings = config.get("auto_analyst_settings", {})
    if not auto_analyst_settings:
        raise Exception("Auto Analyst Settings are required")

    return database, data_catalog, driver_llm, auto_analyst_settings


def parse_openai_api_key():
    """Return the driver LLM's OpenAI API key from config.json ('' if absent)."""
    with open(CONFIG_PATH) as f:
        config = json.load(f)
    llm_config = config.get("llms", {})
    driverllm_config = llm_config.get("driverllm", "")
    return driverllm_config.get("api_key", "")
hydralit-components==1.0.10 42 | identify==2.5.24 43 | idna==3.4 44 | importlib-metadata==6.0.0 45 | iniconfig==2.0.0 46 | itsdangerous==2.1.2 47 | Jinja2==3.1.2 48 | jmespath==1.0.1 49 | jsonschema==4.17.3 50 | lxml==4.9.2 51 | markdown-it-py==2.2.0 52 | MarkupSafe==2.1.2 53 | marshmallow==3.19.0 54 | marshmallow-enum==1.5.1 55 | mccabe==0.7.0 56 | mdurl==0.1.2 57 | multidict==6.0.4 58 | mypy==1.3.0 59 | mypy-extensions==1.0.0 60 | nodeenv==1.8.0 61 | numpy==1.24.2 62 | openai==0.27.0 63 | packaging==23.0 64 | pandas==1.5.3 65 | pathspec==0.11.1 66 | patsy==0.5.3 67 | pickleshare==0.7.5 68 | Pillow==9.4.0 69 | platformdirs==3.5.1 70 | plotly==5.13.1 71 | pluggy==1.0.0 72 | pre-commit==3.3.2 73 | protobuf==3.20.3 74 | psutil==5.9.4 75 | ptyprocess==0.7.0 76 | pure-eval==0.2.2 77 | pyarrow==11.0.0 78 | pycares==4.3.0 79 | pycodestyle==2.10.0 80 | pycparser==2.21 81 | pydantic==1.10.5 82 | pydeck==0.8.0 83 | pyflakes==3.0.1 84 | Pygments==2.14.0 85 | Pympler==1.0.1 86 | pyrsistent==0.19.3 87 | pytest==7.2.2 88 | python-dateutil==2.8.2 89 | pytz==2022.7.1 90 | pytz-deprecation-shim==0.1.0.post0 91 | PyYAML==6.0 92 | pyzmq==25.0.2 93 | redshift-connector==2.0.910 94 | requests==2.28.2 95 | rich==13.3.2 96 | s3transfer==0.6.0 97 | scipy==1.10.1 98 | scramp==1.4.4 99 | semver==2.13.0 100 | six==1.16.0 101 | smmap==5.0.0 102 | soupsieve==2.4 103 | SQLAlchemy==1.4.46 104 | sqlparse==0.4.3 105 | statsmodels==0.13.5 106 | tenacity==8.2.1 107 | tokenizers==0.13.2 108 | toml==0.10.2 109 | tomli==2.0.1 110 | toolz==0.12.0 111 | tornado==6.2 112 | tqdm==4.64.1 113 | traitlets==5.9.0 114 | typing-inspect==0.8.0 115 | typing_extensions==4.5.0 116 | tzdata==2022.7 117 | tzlocal==4.2 118 | urllib3==1.26.14 119 | validators==0.20.0 120 | virtualenv==20.23.0 121 | watchdog==3.0.0 122 | wcwidth==0.2.6 123 | Werkzeug==2.3.4 124 | WTForms==3.0.1 125 | yarl==1.8.2 126 | zipp==3.15.0 127 | -------------------------------------------------------------------------------- 
def test_render_type_messages():
    """render_type_messages should wrap the question in the fixed few-shot
    chat transcript used to classify a question as query/data/plot."""
    question = "What is the average rating of all the movies?"
    messages = general.render_type_messages(question)
    # Expected transcript: system prompt, the fixed few-shot examples,
    # then the caller's question as the final user message.
    assert messages == [
        {
            "role": "system",
            "content": "You are a helpful assistant that determines whether a question is asking for a SQL query, tabular data or a plot.",
        },
        {"role": "user", "content": "How many sales were made in August?"},
        {"role": "assistant", "content": "data"},
        {"role": "user", "content": "Relationship between customer age and time spent"},
        {"role": "assistant", "content": "plot"},
        {
            "role": "user",
            "content": "Query to get 1000 random customers who live in USA",
        },
        {"role": "assistant", "content": "query"},
        {"role": "user", "content": "What is the average amount per transaction?"},
        {"role": "assistant", "content": "data"},
        {
            "role": "user",
            "content": "1, 7, 14, 28 retention for customer who signed up in August",
        },
        {"role": "assistant", "content": "plot"},
        {"role": "user", "content": "Plot the timeseries of ad impressions"},
        {"role": "assistant", "content": "plot"},
        {"role": "user", "content": "Query to get last quarters' sales"},
        {"role": "assistant", "content": "query"},
        {"role": "user", "content": "Top 10 customers by number of transactions"},
        {"role": "assistant", "content": "data"},
        {"role": "user", "content": "Histogram of customer age"},
        {"role": "assistant", "content": "plot"},
        {"role": "user", "content": "What is the average rating of all the movies?"},
    ]
44 | source_data = [Table("movie", "Movie table")] 45 | table_Schema = { 46 | "movie": pd.DataFrame( 47 | { 48 | "column": ["id", "name", "rating"], 49 | "description": ["id", "name of the movie", "rating of the movie"], 50 | } 51 | ) 52 | } 53 | analysis_type = "data" 54 | transformation = "" 55 | 56 | prompt = general.render_query_prompt( 57 | question, source_data, table_Schema, analysis_type, transformation 58 | ) 59 | 60 | assert ( 61 | prompt 62 | == """ 63 | Given the following tables 64 | 65 | movie: Movie table 66 | 67 | 68 | With schema: 69 | 70 | movie: column description 71 | id id 72 | name name of the movie 73 | rating rating of the movie 74 | 75 | 76 | 77 | Write a SQL query to answer the following question: 78 | What is the average rating of all the movies?""" 79 | ) 80 | -------------------------------------------------------------------------------- /auto_analyst/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, render_template, request, jsonify, redirect, url_for 2 | from .auto_analyst import AutoAnalyst 3 | from auto_analyst.databases.sqlite import SQLLite 4 | from flask_wtf.csrf import CSRFProtect 5 | import logging 6 | from logging.handlers import RotatingFileHandler 7 | import os 8 | from .config_parser import parse_config 9 | import json 10 | from .forms import ConfigForm 11 | 12 | 13 | app = Flask(__name__) 14 | # Set up the logger 15 | if not os.path.exists("logs"): 16 | os.mkdir("logs") 17 | file_handler = RotatingFileHandler( 18 | "logs/app.log", maxBytes=1024 * 1024 * 10, backupCount=10 19 | ) 20 | file_handler.setFormatter( 21 | logging.Formatter( 22 | "%(asctime)s %(levelname)s: %(message)s " "[in %(pathname)s:%(lineno)d]" 23 | ) 24 | ) 25 | file_handler.setLevel(logging.INFO) 26 | logging.getLogger().addHandler(file_handler) 27 | logging.getLogger().setLevel(logging.INFO) 28 | app.logger.info("Flaskapp startup") 29 | 30 | # Parse config 31 | database, data_catalog, 
@app.teardown_appcontext
def close_connection(exception):
    """Close the SQLite connection when the app context tears down.

    NOTE(review): this constructs a fresh SQLLite() just to close it, which
    only works if SQLLite shares one connection across instances (singleton
    or app-context-cached) — confirm against databases/sqlite.py.
    """
    SQLLite().close_connection()


@app.route("/analyze", methods=["POST"])
def analyze():
    """Run an analysis for the posted question and return it as JSON.

    Expects a JSON body with a ``question`` key; responds 400 when it is
    missing and 500 when the analysis itself raises.
    """
    data = request.get_json(force=True)
    question = data.get("question", "")
    app.logger.info(f"Question: {question}")
    if not question:
        return jsonify({"error": "Question is required"}), 400
    try:
        analysis = auto_analyst.analyze(question)
    except Exception as e:
        app.logger.error(f"Error: {e}")
        return jsonify({"error": str(e)}), 500
    app.logger.info(f"Analysis Results: {analysis.to_json()}")
    return analysis.to_json()


@app.route("/")
def home():
    """Render the main chat page."""
    return render_template("home.html")


@app.route("/config", methods=["GET", "POST"])
def config():
    """Display the current config and accept an uploaded replacement.

    GET renders the current ``config.json``; POST saves the uploaded file
    into ``UPLOAD_FOLDER`` and redirects back with a ``config_updated`` flag.
    """
    form = ConfigForm()
    config_path = "auto_analyst/config.json"
    with open(config_path) as f:
        content = json.load(f)

    if form.validate_on_submit():
        file = form.config_file.data
        filename = "config.json"
        saved_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
        file.save(saved_path)

        # BUG FIX: re-read the config from where it was actually saved.
        # Previously this opened a broken hard-coded path (an f-string with
        # no placeholder), so the freshly uploaded file was never re-read.
        with open(saved_path) as f:
            content = json.load(f)
        # After form submission, redirect back with the 'config_updated' flag
        return redirect(url_for("config", config_updated=True))

    config_updated = "config_updated" in request.args
    # Pass the 'config_updated' flag to the template
    return render_template(
        "config.html", form=form, content=content, config_updated=config_updated
    )


if __name__ == "__main__":
    app.run(debug=True)
pd.read_csv(df_path) 43 | 44 | def _get_table_schema(self, table_name: str) -> pd.DataFrame: 45 | """Get table schema 46 | 47 | Args: 48 | table_name (str): Name of the table 49 | Returns: 50 | pd.DataFrame: Dataframe containing table schema""" 51 | return self.db.get_schema(table_name) 52 | 53 | def get_source_tables(self, question: str) -> List[Optional[Table]]: 54 | """ 55 | Get source tables for the given question returns empty list if no tables found 56 | 57 | Args: 58 | question (str): Question to be answered 59 | 60 | Returns: 61 | List[Dict]: List of tables [{table_name: str, table_description: str}, ...] 62 | """ 63 | tables_df = self._get_all_tables() 64 | logger.info(f"Question: {question}") 65 | 66 | # Find the appropriate tables to answer the question 67 | response = self.llm.get_reply( 68 | system_prompt=system_prompt, 69 | prompt=render_source_tables_prompt(question, tables_df), 70 | ) 71 | 72 | table_list: List[Optional[Table]] = [] 73 | 74 | if response.lower().strip() == "no tables found": 75 | return table_list 76 | else: 77 | tables = [tbl.strip() for tbl in response.split(",")] 78 | logger.info(f"Tables: {tables}") 79 | logger.info( 80 | f"Length of tables DF: {len(tables_df[tables_df.table_name.isin(tables)])}" 81 | ) 82 | logger.info( 83 | f"Length of tables DF: {len(tables_df[tables_df.table_name.isin(tables)])}" 84 | ) 85 | 86 | for _, row in tables_df[tables_df.table_name.isin(tables)].iterrows(): 87 | table_list.append( 88 | Table(name=row.table_name, description=row.description) 89 | ) 90 | 91 | return table_list 92 | 93 | def get_table_schemas(self, table_list: List[str]) -> Dict[str, pd.DataFrame]: 94 | """Get schema for schema 95 | 96 | Args: 97 | table_list (List[str]): List of table names 98 | Returns: 99 | Dict[str, pd.DataFrame]: Dictionary of table schemas {table_name: table_schema} 100 | """ 101 | result = {} 102 | for table in table_list: 103 | result[table] = self._get_table_schema(table) 104 | 105 | return result 106 | 
-------------------------------------------------------------------------------- /tests/unit_tests/data_catalog/test_sample_datacatalog.py: -------------------------------------------------------------------------------- 1 | from flask import Flask 2 | from auto_analyst.data_catalog.sample_datacatalog import SampleDataCatalog 3 | from auto_analyst.llms.openai import OpenAILLM, Model 4 | from auto_analyst.config_parser import parse_openai_api_key 5 | import pandas as pd 6 | import pytest 7 | 8 | 9 | @pytest.fixture(scope="module") 10 | def app(): 11 | app = Flask(__name__) 12 | yield app 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def sample_datacatalog(): 17 | llm = OpenAILLM(parse_openai_api_key(), Model.GPT_3_5_TURBO) 18 | sample_datacatalog = SampleDataCatalog(llm) 19 | yield sample_datacatalog 20 | 21 | 22 | @pytest.mark.llm 23 | def test_get_source_tables(sample_datacatalog, app): 24 | with app.app_context(): 25 | question = "What is the total sales by country?" 26 | input_tables = sample_datacatalog.get_source_tables(question) 27 | print(input_tables) 28 | assert ["Customer", "Invoice", "InvoiceLine"] == [d.name for d in input_tables] 29 | 30 | 31 | def test_get_table_schemas(sample_datacatalog, app): 32 | with app.app_context(): 33 | table_list = ["customer"] 34 | table_schemas = sample_datacatalog.get_table_schemas(table_list) 35 | assert table_schemas["customer"].equals( 36 | pd.DataFrame( 37 | { 38 | "cid": [ 39 | 0, 40 | 1, 41 | 2, 42 | 3, 43 | 4, 44 | 5, 45 | 6, 46 | 7, 47 | 8, 48 | 9, 49 | 10, 50 | 11, 51 | 12, 52 | ], 53 | "name": [ 54 | "CustomerId", 55 | "FirstName", 56 | "LastName", 57 | "Company", 58 | "Address", 59 | "City", 60 | "State", 61 | "Country", 62 | "PostalCode", 63 | "Phone", 64 | "Fax", 65 | "Email", 66 | "SupportRepId", 67 | ], 68 | "type": [ 69 | "INTEGER", 70 | "NVARCHAR(40)", 71 | "NVARCHAR(20)", 72 | "NVARCHAR(80)", 73 | "NVARCHAR(70)", 74 | "NVARCHAR(40)", 75 | "NVARCHAR(40)", 76 | "NVARCHAR(40)", 77 | "NVARCHAR(10)", 78 
class AnalysisStatus(Enum):
    """Lifecycle states of an analysis, in rough pipeline order.

    Values are the human-readable strings stored for each state."""

    INITIATED = "initiated"
    QUESTION_TYPE_DONE = "determined question type"
    SOURCE_DATA_DONE = "determined source data"
    QUERY_DONE = "query done"
    RUNNING_QUERY = "running query"
    COMPLETED = "completed"
    FAILED = "failed"


class AnalysisType(Enum):
    """Kind of output the user asked for: raw SQL, a result table, or a plot."""

    QUERY = "query"
    DATA = "data"
    PLOT = "plot"
    @property
    def analysis_uuid(self) -> uuid.UUID:
        """Get analysis UUID (read-only; assigned at construction)."""
        return self._analysis_uuid

    @property
    def analysis_status(self) -> AnalysisStatus:
        """Get analysis status"""
        return self._analysis_status

    @analysis_status.setter
    def analysis_status(self, analysis_status: AnalysisStatus) -> None:
        """Set analysis status (expects an AnalysisStatus member, not a raw string)"""
        self._analysis_status = analysis_status

    @property
    def metadata(self) -> Dict:
        """Get metadata"""
        return self._metadata

    @metadata.setter
    def metadata(self, metadata: Dict) -> None:
        """Add metadata to the analysis.

        NOTE: merges into the existing metadata via dict.update — repeated
        assignments accumulate keys rather than replacing the dict."""
        self._metadata.update(metadata)

    @property
    def query(self) -> Optional[str]:
        """Get query"""
        return self._query

    @query.setter
    def query(self, query: str) -> None:
        """Add query to the analysis"""
        self._query = query

    @property
    def result_data(self) -> Union[pd.DataFrame, None]:
        """Get result data"""
        return self._result_data

    @result_data.setter
    def result_data(self, result_data: Union[pd.DataFrame, None]) -> None:
        """Add result data to the analysis"""
        self._result_data = result_data

    @property
    def result_plot(self) -> Union[Figure, None]:
        """Get result plot"""
        return self._result_plot

    @result_plot.setter
    def result_plot(self, result_plot: Union[Figure, None]) -> None:
        """Add result plot to the analysis"""
        self._result_plot = result_plot

    @property
    def analysis_type(self) -> Optional[AnalysisType]:
        """Get analysis type (None until the setter has been given a valid value)"""
        return self._analysis_type
    def get_results(self) -> Dict:
        """Bundle the analysis results into a ``{"result": {...}}`` dict,
        including only the pieces that exist (plot > data > query).

        NOTE(review): this embeds the raw AnalysisType enum, while to_json
        below serializes ``.value`` — confirm callers of this method expect
        the enum before passing the dict to a JSON encoder.
        """
        if isinstance(self.result_plot, plotly.graph_objs._figure.Figure):
            return {
                "result": {
                    "analysis_type": self.analysis_type,
                    "plot": self.result_plot.to_json(),
                    "data": self.result_data.to_dict(orient="records"),  # type: ignore
                    "query": self.query,
                }
            }
        elif isinstance(self.result_data, pd.DataFrame):
            return {
                "result": {
                    "analysis_type": self.analysis_type,
                    "data": self.result_data.to_dict(orient="records"),
                    "query": self.query,
                }
            }
        else:
            return {
                "result": {"analysis_type": self.analysis_type, "query": self.query}
            }

    def to_json(self) -> Dict:
        """Convert analysis to JSON-friendly fields.

        Enum and DataFrame members are converted; ``analysis_uuid`` stays a
        uuid.UUID object for the caller to serialize."""
        return {
            "analysis_uuid": self.analysis_uuid,
            "analysis_type": self.analysis_type.value if self.analysis_type else None,
            "metadata": self.metadata,
            "query": self.query,
            "result_data": self.result_data.to_dict(orient="records")
            if isinstance(self.result_data, pd.DataFrame)
            else None,
            "result_plot": self.result_plot.to_json() if self.result_plot else None,
        }

    @classmethod
    def get_instance(cls, analysis_uuid: uuid.UUID) -> "Analysis":
        """Get analysis instance from the class-level registry.

        Raises KeyError if no analysis with that UUID was created."""
        return Analysis.instances[analysis_uuid]
clipboard!'); 7 | }); 8 | 9 | clipboard.on('error', function (e) { 10 | console.log(e); 11 | alert('Failed to copy.'); 12 | }); 13 | 14 | function addToChat(sender, message, isQuery = false, data = null) { 15 | let messageItem = $('
'); 16 | messageItem.addClass('message'); 17 | messageItem.addClass(sender.toLowerCase()); 18 | 19 | let senderElement = $('').text(sender); 20 | let senderLine = $('
').append(senderElement); 21 | senderLine.addClass('d-flex justify-content-between align-items-center'); 22 | 23 | let copyButton; 24 | if (sender.toLowerCase() === 'autoanalyst' && (isQuery || data)) { 25 | // Create the copy to clipboard button and append it to the sender line 26 | copyButton = $('