├── .gitignore ├── README.md ├── assets ├── Model-Accuracy-Comparison.png ├── Models-VES-Comparison.png ├── agent_flow.png ├── agent_server.png ├── architecture.png ├── eval_result_filtered.png └── execution_guided_decoding.png ├── examples ├── agent_flow.png ├── agent_server.ipynb ├── agents.ipynb ├── datasets.ipynb ├── error_dataset.ipynb ├── evaluation.ipynb ├── finetuning.ipynb ├── generators.ipynb ├── lora_tuning.py └── simple_pipeline.ipynb ├── premsql ├── __init__.py ├── agents │ ├── __init__.py │ ├── base.py │ ├── baseline │ │ ├── __init__.py │ │ ├── main.py │ │ ├── prompts.py │ │ └── workers │ │ │ ├── __init__.py │ │ │ ├── analyser.py │ │ │ ├── followup.py │ │ │ ├── plotter.py │ │ │ └── text2sql.py │ ├── memory.py │ ├── models.py │ ├── router.py │ ├── tools │ │ ├── __init__.py │ │ └── plot │ │ │ ├── base.py │ │ │ └── matplotlib_tool.py │ └── utils.py ├── cli.py ├── datasets │ ├── __init__.py │ ├── base.py │ ├── collator.py │ ├── error_dataset.py │ ├── real │ │ ├── bird.py │ │ ├── domains.py │ │ └── spider.py │ └── synthetic │ │ └── gretel.py ├── evaluator │ ├── README.md │ ├── __init__.py │ └── base.py ├── executors │ ├── __init__.py │ ├── base.py │ ├── from_langchain.py │ └── from_sqlite.py ├── generators │ ├── __init__.py │ ├── base.py │ ├── huggingface.py │ ├── mlx.py │ ├── ollama_model.py │ ├── openai.py │ └── premai.py ├── logger.py ├── playground │ ├── __init__.py │ ├── backend │ │ ├── api │ │ │ ├── __init__.py │ │ │ ├── admin.py │ │ │ ├── apps.py │ │ │ ├── migrations │ │ │ │ ├── 0001_initial.py │ │ │ │ └── __init__.py │ │ │ ├── models.py │ │ │ ├── pydantic_models.py │ │ │ ├── serializers.py │ │ │ ├── services.py │ │ │ ├── tests.py │ │ │ ├── urls.py │ │ │ ├── utils.py │ │ │ └── views.py │ │ ├── backend │ │ │ ├── __init__.py │ │ │ ├── asgi.py │ │ │ ├── settings.py │ │ │ ├── urls.py │ │ │ └── wsgi.py │ │ ├── backend_client.py │ │ └── manage.py │ ├── frontend │ │ ├── components │ │ │ ├── chat.py │ │ │ ├── session.py │ │ │ ├── streamlit_plot.py │ │ │ └── uploader.py │ │ ├── main.py │ │ └── utils.py │ └── inference_server │ │ ├── api_client.py │ │ └── service.py ├── prompts.py ├── tuner │ ├── __init__.py │ ├── callback.py │ ├── config.py │ ├── full.py │ └── peft.py └── utils.py └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | experiments 3 | output 4 | test.py 5 | exps 6 | 7 | 8 | # Python specific 9 | *.pyc 10 | *.pyo 11 | __pycache__/ 12 | 13 | # Virtual environments 14 | venv/ 15 | env/ 16 | env.bak/ 17 | env1/ 18 | env2/ 19 | .env 20 | 21 | # IDE specific 22 | .idea/ 23 | .vscode/ 24 | 25 | # Compiled files 26 | *.pyc 27 | *.pyo 28 | *.pyd 29 | *.so 30 | *.dll 31 | *.exe 32 | *.out 33 | *.pyc 34 | *.whl 35 | 36 | # Logs and databases 37 | *.log 38 | *.sqlite3 39 | *.db 40 | 41 | # Data science and ML specific 42 | data/ 43 | models/ 44 | *.h5 45 | *.pkl 46 | *.joblib 47 | 48 | # Jupyter Notebook specific 49 | .ipynb_checkpoints/ 50 | -------------------------------------------------------------------------------- /assets/Model-Accuracy-Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/assets/Model-Accuracy-Comparison.png -------------------------------------------------------------------------------- /assets/Models-VES-Comparison.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/assets/Models-VES-Comparison.png -------------------------------------------------------------------------------- /assets/agent_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/assets/agent_flow.png -------------------------------------------------------------------------------- /assets/agent_server.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/assets/agent_server.png -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/assets/architecture.png -------------------------------------------------------------------------------- /assets/eval_result_filtered.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/assets/eval_result_filtered.png -------------------------------------------------------------------------------- /assets/execution_guided_decoding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/assets/execution_guided_decoding.png -------------------------------------------------------------------------------- /examples/agent_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/examples/agent_flow.png -------------------------------------------------------------------------------- /examples/agent_server.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "/Users/anindya/personal/PremSQL/v2_agent/premsql\n" 13 | ] 14 | }, 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "/Users/anindya/Library/Caches/pypoetry/virtualenvs/text2sql-jLjiS8B5-py3.11/lib/python3.11/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n", 20 | " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "cd .." 
26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 1, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import random" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "[7546]" 46 | ] 47 | }, 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "random.sample(range(7000, 9000), k=1)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "text/plain": [ 65 | "8194" 66 | ] 67 | }, 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "output_type": "execute_result" 71 | } 72 | ], 73 | "source": [ 74 | "random.choice(range(7000, 9000))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "from premsql.generators import Text2SQLGeneratorOpenAI\n", 84 | "\n", 85 | "Text2SQLGeneratorOpenAI(openai_api_key=)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "Create a file named `serve.py` (any name works) and add the following lines to it:\n", 93 | "\n", 94 | "```Python\n", 95 | "from premsql.playground import AgentServer\n", 96 | "from premsql.agents import BaseLineAgent\n", 97 | "from premsql.generators import Text2SQLGeneratorMLX\n", 98 | "from premsql.executors import ExecutorUsingLangChain\n", 99 | "from premsql.agents.tools import SimpleMatplotlibTool\n", 100 | "\n", 101 | "db_connection_uri = (\n", 102 | " \"sqlite://///Users/anindya/personal/PremSQL/v2_agent/premsql/codebase_community.sqlite\"\n", 103 | ")\n", 104 | "text2sql_model = Text2SQLGeneratorMLX(\n", 105 | " model_name_or_path=\"premai-io/prem-1B-SQL\", experiment_name=\"text2sql_model\", type=\"test\"\n", 106 | ")\n", 107 | "\n", 108 | "analyser_plotter_model = Text2SQLGeneratorMLX(\n", 109 | " model_name_or_path=\"meta-llama/Llama-3.2-1B-Instruct\", experiment_name=\"analyser_model\", type=\"test\",\n", 110 | ")\n", 111 | "\n", 112 | "baseline = BaseLineAgent(\n", 113 | " session_name=\"local_db_rag\", # A unique session name must be provided\n", 114 | " db_connection_uri=db_connection_uri, # Database to connect to for text-to-SQL\n", 115 | " specialized_model1=text2sql_model, # This refers to the text-to-SQL model\n", 116 | " specialized_model2=analyser_plotter_model, # This refers to any model other than the text-to-SQL one\n", 117 | " executor=ExecutorUsingLangChain(), # Which DB executor to use\n", 118 | " auto_filter_tables=False, # Whether to filter tables before text-to-SQL or not (uses an LLM)\n", 119 | " plot_tool=SimpleMatplotlibTool() # Matplotlib tool used by the plotter worker\n", 120 | ")\n", 121 | "\n", 122 | "agent_server = AgentServer(agent=baseline, port=8263)\n", 123 | "agent_server.launch()\n", 124 | "```\n", 125 | "\n", 126 | "After this, just run:\n", 127 | "\n", 128 | "```bash\n", 129 | "python serve.py\n", 130 | "```\n", 131 | "\n", 132 | "A FastAPI server will start on the specified port with output similar to:\n", 133 | "\n", 134 | "```bash\n", 135 | "INFO: Started server process [78518]\n", 136 | "INFO: Waiting for application startup.\n", 137 | "2024-10-28 00:29:46,953 - [FASTAPI-INFERENCE-SERVICE] - INFO - Starting up the application\n", 138 | "INFO: Application startup complete.\n", 139 | "INFO: Uvicorn running on http://0.0.0.0:8263 (Press
CTRL+C to quit)\n", 140 | "```\n", 141 | "\n", 142 | "This means that our server has started now we can query it with our Terminal using Curl or Python requests or Javascript axios. " 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 10, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "from premsql.playground import InferenceServerAPIClient\n", 152 | "from premsql.agents.tools import SimpleMatplotlibTool" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "text2sql-jLjiS8B5-py3.11", 166 | "language": "python", 167 | "name": "python3" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.11.10" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 2 184 | } 185 | -------------------------------------------------------------------------------- /examples/lora_tuning.py: -------------------------------------------------------------------------------- 1 | from premsql.datasets import ( 2 | BirdDataset, 3 | DomainsDataset, 4 | GretelAIDataset, 5 | SpiderUnifiedDataset, 6 | Text2SQLDataset, 7 | ) 8 | from premsql.datasets.error_dataset import ErrorDatasetGenerator 9 | from premsql.executors.from_sqlite import SQLiteExecutor 10 | from premsql.tuner.peft import Text2SQLPeftTuner 11 | 12 | path = "/root/anindya/text2sql/data" 13 | model_name_or_path = "premai-io/prem-1B-SQL" 14 | 15 | bird_train = BirdDataset( 16 | split="train", 17 | dataset_folder=path, 18 | ).setup_dataset( 19 | num_rows=100, 20 | ) 21 | 22 | spider_train = SpiderUnifiedDataset( 23 | split="train", dataset_folder="./data" 24 | ).setup_dataset(num_rows=100) 25 | 26 | domains_dataset = DomainsDataset( 27 | split="train", 28 | dataset_folder="./data", 29 | ).setup_dataset(num_rows=100) 30 | 31 | gertelai_dataset = GretelAIDataset( 32 | split="train", 33 | dataset_folder="./data", 34 | ).setup_dataset(num_rows=100) 35 | 36 | existing_error_dataset = ErrorDatasetGenerator.from_existing( 37 | experiment_name="testing_error_gen" 38 | ) 39 | merged_dataset = [ 40 | *spider_train, 41 | *bird_train, 42 | *domains_dataset, 43 | *gertelai_dataset, 44 | *existing_error_dataset, 45 | ] 46 | bird_dev = Text2SQLDataset( 47 | dataset_name="bird", 48 | split="validation", 49 | dataset_folder=path, 50 | ).setup_dataset(num_rows=10, filter_by=("difficulty", "challenging")) 51 | 52 | tuner = Text2SQLPeftTuner( 53 | model_name_or_path=model_name_or_path, experiment_name="lora_tuning" 54 | ) 55 | 56 | tuner.train( 57 | train_datasets=merged_dataset, 58 | output_dir="./output", 59 | num_train_epochs=1, 60 | per_device_train_batch_size=1, 61 | gradient_accumulation_steps=1, 62 | evaluation_dataset=bird_dev, 63 | eval_steps=100, 64 | max_seq_length=1024, 65 | executor=SQLiteExecutor(), 66 | filter_eval_results_by=("difficulty", "challenging"), 67 | ) 68 | -------------------------------------------------------------------------------- /premsql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/premsql/__init__.py 
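Note: the agent built in the `serve.py` example above can also be exercised without the HTTP layer, since `BaseLineAgent` is directly callable and `SimpleRouterWorker` (see `premsql/agents/router.py` below) routes on the `/query`, `/analyse` and `/plot` prefixes. A minimal sketch under that assumption — the question text is illustrative only, not taken from the repository:

```python
# Minimal sketch: calling the BaseLineAgent directly instead of through AgentServer.
# Reuses the `baseline` object constructed in the serve.py example above.
result = baseline("/query list the 10 most recent posts")   # '/query' routes to the text-to-SQL worker
print(result.sql_string)                                     # generated SQL
print(result.show_output_dataframe())                        # execution result as a pandas DataFrame

followup = baseline("/analyse summarise the table above")    # '/analyse' reuses the last dataframe from history
print(followup.analysis)
```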
-------------------------------------------------------------------------------- /premsql/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from premsql.agents.baseline.main import BaseLineAgent 2 | from premsql.agents.memory import AgentInteractionMemory 3 | 4 | __all__ = ["BaseLineAgent", "AgentInteractionMemory"] -------------------------------------------------------------------------------- /premsql/agents/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Union 3 | 4 | import pandas as pd 5 | 6 | from premsql.executors.base import BaseExecutor 7 | from premsql.executors.from_langchain import SQLDatabase 8 | from premsql.generators.base import Text2SQLGeneratorBase 9 | from premsql.logger import setup_console_logger 10 | from premsql.agents.memory import AgentInteractionMemory 11 | from premsql.agents.models import ( 12 | AgentOutput, 13 | AnalyserWorkerOutput, 14 | ChartPlotWorkerOutput, 15 | ExitWorkerOutput, 16 | RouterWorkerOutput, 17 | Text2SQLWorkerOutput, 18 | ) 19 | 20 | logger = setup_console_logger("[PIPELINE-BASE]") 21 | 22 | 23 | # If someone wants to make a new worker class 24 | class WorkerBase(ABC): 25 | @abstractmethod 26 | def run(self): 27 | return NotImplementedError() 28 | 29 | 30 | class AnalysisWorkerBase(ABC): 31 | @abstractmethod 32 | def run( 33 | self, question: str, input_dataframe: Optional[pd.DataFrame] = None 34 | ) -> AnalyserWorkerOutput: 35 | raise NotImplementedError 36 | 37 | 38 | class ChartPlotWorkerBase(ABC): 39 | @abstractmethod 40 | def run( 41 | self, question: str, input_dataframe: Optional[pd.DataFrame] = None 42 | ) -> ChartPlotWorkerOutput: 43 | raise NotImplementedError 44 | 45 | 46 | class RouterWorkerBase(ABC): 47 | @abstractmethod 48 | def run( 49 | self, question: str, input_dataframe: Optional[pd.DataFrame] = None 50 | ) -> RouterWorkerOutput: 51 | raise NotImplementedError 52 | 53 | 54 | class Text2SQLWorkerBase(ABC): 55 | def __init__( 56 | self, 57 | db_connection_uri: str, 58 | generator: Text2SQLGeneratorBase, 59 | executor: BaseExecutor, 60 | include_tables: Optional[str] = None, 61 | exclude_tables: Optional[str] = None, 62 | ) -> None: 63 | 64 | self.generator, self.executor = generator, executor 65 | self.db_connection_uri = db_connection_uri 66 | self.db = self.initialize_database( 67 | db_connection_uri=db_connection_uri, 68 | include_tables=include_tables, 69 | exclude_tables=exclude_tables, 70 | ) 71 | 72 | @abstractmethod 73 | def run(self, question: str, **kwargs) -> Text2SQLWorkerOutput: 74 | raise NotImplementedError 75 | 76 | def initialize_database( 77 | self, 78 | db_connection_uri: str, 79 | include_tables: Optional[list] = None, 80 | exclude_tables: Optional[list] = None, 81 | ) -> SQLDatabase: 82 | """This method should return a db object 83 | 84 | To customise this method you make a different db object but 85 | it should have similar methods and behaviour like 86 | langchain SQLDatbase. 
You can find the implementation of SQLDatabase 87 | here: https://api.python.langchain.com/en/latest/_modules/langchain_community/utilities/sql_database.html#SQLDatabase 88 | """ 89 | try: 90 | return SQLDatabase.from_uri( 91 | database_uri=db_connection_uri, 92 | sample_rows_in_table_info=0, 93 | ignore_tables=exclude_tables, 94 | include_tables=include_tables 95 | ) 96 | except Exception as e: 97 | logger.error(f"Error loading the database: {e}") 98 | raise RuntimeError(f"Error loading the database: {e}") 99 | 100 | 101 | class AgentBase(ABC): 102 | def __init__( 103 | self, 104 | session_name: str, 105 | db_connection_uri: str, 106 | session_db_path: Optional[str] = None, 107 | route_worker_kwargs: Optional[dict] = None, 108 | ) -> None: 109 | self.session_name, self.db_connection_uri = session_name, db_connection_uri 110 | self.history = AgentInteractionMemory( 111 | session_name=session_name, db_path=session_db_path 112 | ) 113 | self.session_db_path = self.history.db_path 114 | self.route_worker_kwargs = route_worker_kwargs 115 | 116 | @abstractmethod 117 | def run( 118 | self, 119 | question: str, 120 | input_dataframe: Optional[dict] = None, 121 | server_mode: Optional[bool] = False, 122 | ) -> Union[ExitWorkerOutput, AgentOutput]: 123 | # Make sure you convert the dataframe to a table 124 | raise NotImplementedError() 125 | 126 | def convert_exit_output_to_agent_output( 127 | self, exit_output: ExitWorkerOutput 128 | ) -> AgentOutput: 129 | return AgentOutput( 130 | session_name=exit_output.session_name, 131 | question=exit_output.question, 132 | db_connection_uri=exit_output.db_connection_uri, 133 | route_taken=exit_output.route_taken, 134 | input_dataframe=exit_output.sql_input_dataframe 135 | or exit_output.analysis_input_dataframe 136 | or exit_output.plot_input_dataframe, 137 | output_dataframe=exit_output.sql_output_dataframe 138 | or exit_output.plot_output_dataframe, 139 | sql_string=exit_output.sql_string, 140 | analysis=exit_output.analysis, 141 | reasoning=exit_output.sql_reasoning 142 | or exit_output.analysis_reasoning 143 | or exit_output.plot_reasoning, 144 | plot_config=exit_output.plot_config, 145 | image_to_plot=exit_output.image_to_plot, 146 | followup_route=exit_output.followup_route_to_take, 147 | followup_suggestion=exit_output.followup_suggestion, 148 | error_from_pipeline=( 149 | exit_output.error_from_sql_worker 150 | or exit_output.error_from_analysis_worker 151 | or exit_output.error_from_plot_worker 152 | or exit_output.error_from_followup_worker 153 | ), 154 | ) 155 | 156 | def __call__( 157 | self, 158 | question: str, 159 | input_dataframe: Optional[dict] = None, 160 | server_mode: Optional[bool] = False, 161 | ) -> Union[ExitWorkerOutput, AgentOutput]: 162 | if server_mode: 163 | kwargs = self.route_worker_kwargs.get("plot", None) 164 | kwargs = ( 165 | {"plot_image": False} 166 | if kwargs is None 167 | else {**kwargs, "plot_image": False} 168 | ) 169 | self.route_worker_kwargs["plot"] = kwargs 170 | 171 | output = self.run(question=question, input_dataframe=input_dataframe) 172 | # TODO: Watch out dict here type mismatch with run 173 | self.history.push(output=output) 174 | if server_mode: 175 | output = self.convert_exit_output_to_agent_output(exit_output=output) 176 | return output 177 | -------------------------------------------------------------------------------- /premsql/agents/baseline/__init__.py: -------------------------------------------------------------------------------- 1 | from premsql.agents.baseline.workers import ( 2 | 
BaseLineAnalyserWorker, 3 | BaseLineFollowupWorker, 4 | BaseLinePlotWorker, 5 | BaseLineText2SQLWorker, 6 | ) 7 | 8 | __all__ = [ 9 | "BaseLineAnalyserWorker", 10 | "BaseLineFollowupWorker", 11 | "BaseLinePlotWorker", 12 | "BaseLineText2SQLWorker", 13 | ] 14 | -------------------------------------------------------------------------------- /premsql/agents/baseline/main.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import pandas as pd 4 | 5 | from premsql.executors.base import BaseExecutor 6 | from premsql.generators.base import Text2SQLGeneratorBase 7 | from premsql.agents.base import AgentBase, ExitWorkerOutput 8 | from premsql.agents.baseline.workers import ( 9 | BaseLineAnalyserWorker, 10 | BaseLineFollowupWorker, 11 | BaseLinePlotWorker, 12 | BaseLineText2SQLWorker, 13 | ) 14 | from premsql.agents.router import SimpleRouterWorker 15 | from premsql.agents.tools.plot.base import BasePlotTool 16 | 17 | # TODO: Should the name be changed from baseline to eda or autoeda? 18 | 19 | 20 | class BaseLineAgent(AgentBase): 21 | def __init__( 22 | self, 23 | session_name: str, 24 | db_connection_uri: str, 25 | specialized_model1: Text2SQLGeneratorBase, 26 | specialized_model2: Text2SQLGeneratorBase, 27 | executor: BaseExecutor, 28 | plot_tool: BasePlotTool, 29 | session_db_path: Optional[str] = None, 30 | include_tables: Optional[list] = None, 31 | exclude_tables: Optional[list] = None, 32 | auto_filter_tables: Optional[bool] = False, 33 | route_worker_kwargs: Optional[dict] = {}, 34 | ) -> None: 35 | super().__init__( 36 | session_name=session_name, 37 | db_connection_uri=db_connection_uri, 38 | session_db_path=session_db_path, 39 | route_worker_kwargs=route_worker_kwargs, 40 | ) 41 | self.text2sql_worker = BaseLineText2SQLWorker( 42 | db_connection_uri=db_connection_uri, 43 | generator=specialized_model1, 44 | helper_model=specialized_model1, 45 | executor=executor, 46 | include_tables=include_tables, 47 | exclude_tables=exclude_tables, 48 | auto_filter_tables=auto_filter_tables, 49 | ) 50 | self.analysis_worker = BaseLineAnalyserWorker(generator=specialized_model2) 51 | self.plotter_worker = BaseLinePlotWorker( 52 | generator=specialized_model2, plot_tool=plot_tool 53 | ) 54 | self.followup_worker = BaseLineFollowupWorker(generator=specialized_model2) 55 | self.router = SimpleRouterWorker() 56 | 57 | def run( 58 | self, question: str, input_dataframe: Optional[pd.DataFrame] = None 59 | ) -> ExitWorkerOutput: 60 | decision = self.router.run(question=question, input_dataframe=input_dataframe) 61 | dataframe_from_history = None 62 | # TODO: This is an assumption that the output tables will be in last 63 | # 10 conversation 64 | 65 | history_entries = self.history.get(limit=10) 66 | for entry in history_entries: 67 | content = entry["message"] 68 | df = content.show_output_dataframe() 69 | if df is not None and len(df) > 0: 70 | dataframe_from_history = content.show_output_dataframe() 71 | break 72 | 73 | if decision.route_to in ["query", "analyse", "plot"]: 74 | worker_output = self._execute_worker( 75 | question=question, 76 | route_to=decision.route_to, 77 | input_dataframe=input_dataframe, 78 | dataframe_from_history=dataframe_from_history, 79 | ) 80 | exit_output = self._create_exit_worker_output( 81 | question=question, 82 | route_taken=decision.route_to, 83 | worker_output=worker_output, 84 | ) 85 | if any( 86 | [ 87 | exit_output.error_from_analysis_worker, 88 | exit_output.error_from_plot_worker, 89 | 
exit_output.error_from_sql_worker, 90 | ] 91 | ): 92 | followup_output = self._handle_followup(exit_output) 93 | exit_output.followup_suggestion = followup_output.suggestion 94 | exit_output.followup_route_to_take = ( 95 | followup_output.alternative_route or "query" 96 | ) # This is the default route 97 | exit_output.error_from_followup_worker = ( 98 | followup_output.error_from_model 99 | ) 100 | else: 101 | exit_output = self._handle_followup_route(question=question) 102 | return exit_output 103 | 104 | def _execute_worker( 105 | self, 106 | question: str, 107 | route_to: str, 108 | input_dataframe: Optional[pd.DataFrame], 109 | dataframe_from_history: Optional[pd.DataFrame], 110 | ): 111 | decision_mappign = { 112 | "query": lambda: self.text2sql_worker.run( 113 | question=question, 114 | render_results_using="json", 115 | **self.route_worker_kwargs.get("query", {}) 116 | ), 117 | "analyse": lambda: self.analysis_worker.run( 118 | question=question, 119 | input_dataframe=( 120 | dataframe_from_history 121 | if input_dataframe is None 122 | else input_dataframe 123 | ), 124 | **self.route_worker_kwargs.get("analyse", {}) 125 | ), 126 | "plot": lambda: self.plotter_worker.run( 127 | question=question, 128 | input_dataframe=( 129 | dataframe_from_history 130 | if input_dataframe is None 131 | else input_dataframe 132 | ), 133 | **self.route_worker_kwargs.get("plot", {}) 134 | ), 135 | } 136 | return decision_mappign[route_to]() 137 | 138 | def _create_exit_worker_output( 139 | self, 140 | question: str, 141 | route_taken: str, 142 | worker_output: Any, # TODO: change it Literal of worker fixed outputs 143 | ) -> ExitWorkerOutput: 144 | exit_output = ExitWorkerOutput( 145 | session_name=self.session_name, 146 | question=question, 147 | route_taken=route_taken, 148 | db_connection_uri=self.db_connection_uri, 149 | additional_input=getattr(worker_output, "additional_input", None), 150 | ) 151 | if route_taken == "query": 152 | exit_output.sql_string = worker_output.sql_string 153 | exit_output.sql_reasoning = worker_output.sql_reasoning 154 | exit_output.sql_output_dataframe = worker_output.output_dataframe 155 | exit_output.error_from_sql_worker = worker_output.error_from_model 156 | 157 | elif route_taken == "analyse": 158 | exit_output.analysis = worker_output.analysis 159 | exit_output.analysis_reasoning = worker_output.analysis_reasoning 160 | exit_output.analysis_input_dataframe = worker_output.input_dataframe 161 | exit_output.error_from_analysis_worker = worker_output.error_from_model 162 | 163 | elif route_taken == "plot": 164 | exit_output.plot_config = worker_output.plot_config 165 | exit_output.plot_input_dataframe = worker_output.input_dataframe 166 | exit_output.plot_output_dataframe = worker_output.output_dataframe 167 | exit_output.image_to_plot = worker_output.image_plot 168 | exit_output.plot_reasoning = worker_output.plot_reasoning 169 | exit_output.error_from_plot_worker = worker_output.error_from_model 170 | 171 | return exit_output 172 | 173 | def _handle_followup(self, prev_output: ExitWorkerOutput): 174 | return self.followup_worker.run( 175 | prev_output=prev_output, 176 | db_schema=self.text2sql_worker.db.get_context()["table_info"], 177 | user_feedback=None, 178 | ) 179 | 180 | def _handle_followup_route(self, question: str) -> ExitWorkerOutput: 181 | history_entries = self.history.get() 182 | if len(history_entries) == 0: 183 | return ExitWorkerOutput( 184 | session_name=self.session_name, 185 | question=question, 186 | route_taken="followup", 187 | 
db_connection_uri=self.db_connection_uri, 188 | additional_input=None, 189 | followup_suggestion="Before Writing a followup please either query / analyse / plot", 190 | followup_route_to_take="query", 191 | error_from_followup_worker=None, 192 | ) 193 | else: 194 | followup_output = self.followup_worker.run( 195 | prev_output=self.history.get(limit=1)[0]["message"], 196 | user_feedback=question, 197 | db_schema=self.text2sql_worker.db.get_context()["table_info"], 198 | **self.route_worker_kwargs.get("followup", {}) 199 | ) 200 | return ExitWorkerOutput( 201 | session_name=self.session_name, 202 | question=question, 203 | route_taken="followup", 204 | db_connection_uri=self.db_connection_uri, 205 | additional_input=None, 206 | followup_suggestion=followup_output.suggestion, 207 | followup_route_to_take=followup_output.alternative_route 208 | or "query", # query should alaways be the default route 209 | error_from_followup_worker=followup_output.error_from_model, 210 | ) 211 | -------------------------------------------------------------------------------- /premsql/agents/baseline/workers/__init__.py: -------------------------------------------------------------------------------- 1 | from premsql.agents.baseline.workers.analyser import BaseLineAnalyserWorker 2 | from premsql.agents.baseline.workers.followup import BaseLineFollowupWorker 3 | from premsql.agents.baseline.workers.plotter import BaseLinePlotWorker 4 | from premsql.agents.baseline.workers.text2sql import BaseLineText2SQLWorker 5 | 6 | __all__ = [ 7 | "BaseLineText2SQLWorker", 8 | "BaseLineAnalyserWorker", 9 | "BaseLinePlotWorker", 10 | "BaseLineFollowupWorker", 11 | ] 12 | -------------------------------------------------------------------------------- /premsql/agents/baseline/workers/analyser.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pandas as pd 4 | from tqdm.auto import tqdm 5 | 6 | from premsql.generators.base import Text2SQLGeneratorBase 7 | from premsql.logger import setup_console_logger 8 | from premsql.agents.base import AnalyserWorkerOutput, AnalysisWorkerBase 9 | from premsql.agents.baseline.prompts import ( 10 | BASELINE_ANALYSIS_MERGER_PROMPT, 11 | BASELINE_ANALYSIS_WORKER_PROMPT, 12 | ) 13 | from premsql.agents.utils import convert_df_to_dict 14 | 15 | logger = setup_console_logger("[BASELINE-ANALYSER-WORKER]") 16 | 17 | CHUNK_TEMPLATE = """ 18 | # Analysis: 19 | {analysis} 20 | 21 | # Reasoning 22 | {reasoning} 23 | """ 24 | 25 | # TODO: Need to think of the case when there is no df being passed 26 | 27 | 28 | class BaseLineAnalyserWorker(AnalysisWorkerBase): 29 | def __init__(self, generator: Text2SQLGeneratorBase) -> None: 30 | self.generator = generator 31 | 32 | def run_chunkwise_analysis( 33 | self, 34 | question: str, 35 | input_dataframe: pd.DataFrame, 36 | chunk_size: Optional[int] = 20, 37 | max_chunks: Optional[int] = 20, 38 | temperature: Optional[float] = 0.19, 39 | max_new_tokens: Optional[int] = 600, 40 | analysis_prompt_template: Optional[str] = BASELINE_ANALYSIS_WORKER_PROMPT, 41 | merger_prompt_template: Optional[str] = BASELINE_ANALYSIS_MERGER_PROMPT, 42 | verbose: Optional[bool] = False, 43 | ) -> tuple[str, str]: 44 | num_chunks = (len(input_dataframe) + chunk_size - 1) // chunk_size 45 | chunks = [ 46 | input_dataframe[i * chunk_size : (i + 1) * chunk_size] 47 | for i in range(num_chunks) 48 | ][:max_chunks] 49 | analysis_list = [] 50 | num_errors = 0 51 | 52 | for i, chunk in tqdm(enumerate(chunks), 
total=len(chunks)): 53 | analysis, error_from_model = self.analyse( 54 | question=question, 55 | input_dataframe=chunk, 56 | temperature=temperature, 57 | max_new_tokens=max_new_tokens, 58 | prompt_template=analysis_prompt_template, 59 | ) 60 | if error_from_model: 61 | num_errors += 1 62 | logger.error(f"Error while analysing: {i}, Skipping ...") 63 | continue 64 | 65 | if verbose: 66 | logger.info( 67 | CHUNK_TEMPLATE.format( 68 | analysis=analysis["analysis"], 69 | reasoning=analysis["analysis_reasoning"], 70 | ) 71 | ) 72 | analysis_list.append(analysis) 73 | 74 | analysis_list_str = "\n".join( 75 | [ 76 | analysis["analysis"] + " " + analysis["analysis_reasoning"] 77 | for analysis in analysis_list 78 | ] 79 | ) 80 | if num_errors < len(chunks): 81 | summarized_analysis_prompt = merger_prompt_template.format( 82 | analysis=analysis_list_str 83 | ) 84 | summary = self.generator.generate( 85 | data_blob={"prompt": summarized_analysis_prompt}, 86 | temperature=temperature, 87 | max_new_tokens=max_new_tokens, 88 | postprocess=False, 89 | ) 90 | analysis = { 91 | "analysis": summary, 92 | "analysis_reasoning": "Analysis summarised by AI", 93 | } 94 | error_from_model = None 95 | 96 | else: 97 | analysis = { 98 | "analysis": "\n".join( 99 | [ 100 | content["analyse"] if "analyse" in content else "" 101 | for content in analysis_list 102 | ] 103 | ), 104 | "analysis_reasoning": "Appending all the analysis", 105 | } 106 | error_from_model = "Model not able to summarise analysis" 107 | 108 | return analysis, error_from_model 109 | 110 | def analyse( 111 | self, 112 | question: str, 113 | input_dataframe: pd.DataFrame, 114 | temperature: Optional[float] = 0.19, 115 | max_new_tokens: Optional[int] = 512, 116 | prompt_template: Optional[str] = BASELINE_ANALYSIS_WORKER_PROMPT, 117 | ) -> dict: 118 | output = self.generator.generate( 119 | data_blob={ 120 | "prompt": prompt_template.format( 121 | dataframe=str(input_dataframe), question=question 122 | ) 123 | }, 124 | temperature=temperature, 125 | max_new_tokens=max_new_tokens, 126 | postprocess=False, 127 | ) 128 | try: 129 | sections = output.split('# ') 130 | analysis_from_model, reasoning_from_model = '', '' 131 | for section in sections: 132 | if section.startswith('Analysis:'): 133 | analysis_from_model = section.strip() 134 | elif section.startswith('Reasoning:'): 135 | reasoning_from_model = section.strip() 136 | 137 | analysis = { 138 | "analysis": analysis_from_model, 139 | "analysis_reasoning": reasoning_from_model 140 | } 141 | error_from_model = None 142 | except Exception as e: 143 | analysis = { 144 | "analysis": output, 145 | "analysis_reasoning": "Not able to split analysis and reasoning", 146 | } 147 | error_from_model = str(e) 148 | 149 | logger.info(analysis) 150 | logger.info("------------") 151 | logger.info(error_from_model) 152 | 153 | return analysis, error_from_model 154 | 155 | def run( 156 | self, 157 | question: str, 158 | input_dataframe: pd.DataFrame, 159 | do_chunkwise_analysis: Optional[bool] = False, 160 | chunk_size: Optional[int] = 20, 161 | max_chunks: Optional[int] = 20, 162 | temperature: Optional[float] = 0.19, 163 | max_new_tokens: Optional[int] = 600, 164 | analysis_prompt_template: Optional[str] = BASELINE_ANALYSIS_WORKER_PROMPT, 165 | analysis_merger_template: Optional[str] = BASELINE_ANALYSIS_MERGER_PROMPT, 166 | verbose: Optional[bool] = False, 167 | ) -> AnalyserWorkerOutput: 168 | if len(input_dataframe) > chunk_size and do_chunkwise_analysis: 169 | logger.info("Going for chunk wise analysis ...") 
170 | analysis, error_from_model = self.run_chunkwise_analysis( 171 | question=question, 172 | input_dataframe=input_dataframe, 173 | chunk_size=chunk_size, 174 | max_chunks=max_chunks, 175 | analysis_prompt_template=analysis_prompt_template, 176 | merger_prompt_template=analysis_merger_template, 177 | temperature=temperature, 178 | max_new_tokens=max_new_tokens, 179 | verbose=verbose, 180 | ) 181 | else: 182 | if len(input_dataframe) > chunk_size: 183 | logger.info( 184 | "Truncating table, you can also choose chunk wise analysis, but it takes more time." 185 | ) 186 | analysis, error_from_model = self.analyse( 187 | question=question, 188 | input_dataframe=input_dataframe.iloc[:chunk_size, :], 189 | temperature=temperature, 190 | max_new_tokens=max_new_tokens, 191 | prompt_template=analysis_prompt_template, 192 | ) 193 | return AnalyserWorkerOutput( 194 | question=question, 195 | input_dataframe=convert_df_to_dict(df=input_dataframe), 196 | analysis=analysis.get("analysis", "Not able to analyse"), 197 | analysis_reasoning=analysis.get("analysis_reasoning", None), 198 | error_from_model=error_from_model, 199 | additional_input={ 200 | "temperature": temperature, 201 | "max_new_tokens": max_new_tokens, 202 | "chunkwise_analysis": do_chunkwise_analysis, 203 | "chunk_size": chunk_size, 204 | "max_chunks": max_chunks, 205 | }, 206 | ) 207 | -------------------------------------------------------------------------------- /premsql/agents/baseline/workers/followup.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pandas as pd 4 | 5 | from premsql.generators.base import Text2SQLGeneratorBase 6 | from premsql.logger import setup_console_logger 7 | from premsql.agents.base import WorkerBase 8 | from premsql.agents.baseline.prompts import BASELINE_FOLLOWUP_WORKER_PROMPT 9 | from premsql.agents.models import ExitWorkerOutput, FollowupWorkerOutput 10 | 11 | logger = setup_console_logger("[BASELINE-FOLLOWUP-WORKER]") 12 | 13 | 14 | class BaseLineFollowupWorker(WorkerBase): 15 | def __init__(self, generator: Text2SQLGeneratorBase) -> None: 16 | self.generator = generator 17 | 18 | def run( 19 | self, 20 | prev_output: ExitWorkerOutput, 21 | db_schema: str, 22 | user_feedback: Optional[str] = None, 23 | prompt_template: Optional[str] = BASELINE_FOLLOWUP_WORKER_PROMPT, 24 | temperature: Optional[float] = 0.18, 25 | max_new_tokens: Optional[int] = 128, 26 | ) -> FollowupWorkerOutput: 27 | if prev_output.route_taken == "query": 28 | error = "\n".join( 29 | filter(None, [prev_output.error_from_sql_worker, user_feedback]) 30 | ) 31 | dataframe = prev_output.sql_output_dataframe 32 | elif prev_output.route_taken == "plot": 33 | error = "\n".join( 34 | filter(None, [prev_output.error_from_plot_worker, user_feedback]) 35 | ) 36 | dataframe = prev_output.plot_input_dataframe 37 | elif prev_output.route_taken == "analyse": 38 | dataframe = prev_output.analysis_input_dataframe 39 | error = "\n".join( 40 | filter(None, [prev_output.error_from_analysis_worker, user_feedback]) 41 | ) 42 | else: 43 | error = user_feedback 44 | dataframe = next( 45 | ( 46 | df 47 | for df in [ 48 | prev_output.sql_output_dataframe, 49 | prev_output.plot_input_dataframe, 50 | prev_output.analysis_input_dataframe, 51 | ] 52 | if df is not None 53 | ), 54 | None, 55 | ) 56 | 57 | if dataframe: 58 | if isinstance(dataframe, dict) and "data" in dataframe and "columns" in dataframe: 59 | dataframe = pd.DataFrame(dataframe["data"], columns=dataframe["columns"]) 
60 | elif not isinstance(dataframe, pd.DataFrame): 61 | try: 62 | dataframe = pd.DataFrame(dataframe) 63 | except: 64 | dataframe = None 65 | 66 | prompt = prompt_template.format( 67 | schema=db_schema, 68 | decision=prev_output.route_taken, 69 | question=prev_output.question, 70 | dataframe=dataframe, 71 | analysis=prev_output.analysis, 72 | error_from_model=error, 73 | ) 74 | try: 75 | result = self.generator.generate( 76 | data_blob={"prompt": prompt}, 77 | temperature=temperature, 78 | max_new_tokens=max_new_tokens, 79 | postprocess=False, 80 | ) 81 | result = eval(result.replace("null", "None")) 82 | error_from_model = None 83 | assert "alternate_decision" in result 84 | assert "suggestion" in result 85 | except Exception as e: 86 | result = { 87 | "alternate_decision": prev_output.route_taken, 88 | "suggestion": "Worker unable to generate alternative suggestion", 89 | } 90 | error_from_model = str(e) 91 | 92 | return FollowupWorkerOutput( 93 | question=user_feedback or prev_output.question, 94 | error_from_model=error_from_model, 95 | route_taken=result["alternate_decision"], 96 | suggestion=result["suggestion"], 97 | additional_input={ 98 | "temperature": temperature, 99 | "max_new_tokens": max_new_tokens, 100 | }, 101 | ) 102 | -------------------------------------------------------------------------------- /premsql/agents/baseline/workers/plotter.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pandas as pd 4 | 5 | from premsql.generators.base import Text2SQLGeneratorBase 6 | from premsql.logger import setup_console_logger 7 | from premsql.agents.base import ChartPlotWorkerBase, ChartPlotWorkerOutput 8 | from premsql.agents.baseline.prompts import BASELINE_CHART_WORKER_PROMPT_TEMPLATE 9 | from premsql.agents.tools.plot.base import BasePlotTool 10 | from premsql.agents.utils import convert_df_to_dict 11 | 12 | logger = setup_console_logger("[PLOT-WORKER]") 13 | 14 | 15 | class BaseLinePlotWorker(ChartPlotWorkerBase): 16 | def __init__( 17 | self, generator: Text2SQLGeneratorBase, plot_tool: BasePlotTool 18 | ) -> None: 19 | self.generator, self.plot_tool = generator, plot_tool 20 | 21 | def run( 22 | self, 23 | question: str, 24 | input_dataframe: pd.DataFrame, 25 | temperature: Optional[float] = 0.1, 26 | max_new_tokens: Optional[int] = 100, 27 | plot_image: Optional[bool] = True, 28 | prompt_template: Optional[str] = BASELINE_CHART_WORKER_PROMPT_TEMPLATE, 29 | **kwargs, 30 | ) -> ChartPlotWorkerOutput: 31 | prompt = prompt_template.format( 32 | columns=list(input_dataframe.columns), question=question 33 | ) 34 | try: 35 | logger.info("Going for generation") 36 | to_plot = self.generator.generate( 37 | data_blob={"prompt": prompt}, 38 | temperature=temperature, 39 | max_new_tokens=max_new_tokens, 40 | postprocess=False, 41 | ) 42 | to_plot = to_plot.replace("null", "None") 43 | plot_config = eval(to_plot) 44 | fig = self.plot_tool.run(data=input_dataframe, plot_config=plot_config) 45 | logger.info(f"Plot config: {plot_config}") 46 | 47 | if plot_image: 48 | output = self.plot_tool.convert_image_to_base64( 49 | self.plot_tool.convert_plot_to_image(fig=fig) 50 | ) 51 | logger.info("Done base64 conversion") 52 | else: 53 | output = None 54 | 55 | return ChartPlotWorkerOutput( 56 | question=question, 57 | input_dataframe=convert_df_to_dict(input_dataframe), 58 | plot_config=plot_config, 59 | plot_reasoning=None, 60 | output_dataframe=None, 61 | image_plot=output, 62 | error_from_model=None, 63 | 
additional_input={ 64 | "temperature": temperature, 65 | "max_new_tokens": max_new_tokens, 66 | **kwargs, 67 | }, 68 | ) 69 | 70 | except Exception as e: 71 | error_message = f"Error during plot generation: {str(e)}" 72 | return ChartPlotWorkerOutput( 73 | question=question, 74 | input_dataframe=convert_df_to_dict(input_dataframe), 75 | plot_config=None, 76 | image_plot=None, 77 | plot_reasoning=None, 78 | error_from_model=error_message, 79 | additional_input={ 80 | "temperature": temperature, 81 | "max_new_tokens": max_new_tokens, 82 | **kwargs, 83 | }, 84 | ) 85 | -------------------------------------------------------------------------------- /premsql/agents/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Dict, Literal, Optional 3 | 4 | import pandas as pd 5 | from pydantic import BaseModel, Field 6 | 7 | from premsql.logger import setup_console_logger 8 | 9 | logger = setup_console_logger("[BASE-MODELS]") 10 | 11 | 12 | class BaseWorkerOutput(BaseModel): 13 | """Base model for worker outputs with common fields.""" 14 | 15 | question: str 16 | error_from_model: Optional[str] = None 17 | additional_input: Optional[Dict] = Field( 18 | default=None, description="Additional input data" 19 | ) 20 | 21 | 22 | class Text2SQLWorkerOutput(BaseWorkerOutput): 23 | """Output model for Text2SQL worker.""" 24 | 25 | db_connection_uri: str 26 | sql_string: str 27 | sql_reasoning: Optional[str] = None 28 | input_dataframe: Optional[Dict] = None 29 | output_dataframe: Optional[Dict] = None 30 | 31 | def show_output_dataframe(self) -> pd.DataFrame: 32 | if self.output_dataframe: 33 | return pd.DataFrame( 34 | self.output_dataframe["data"], columns=self.output_dataframe["columns"] 35 | ) 36 | return pd.DataFrame() 37 | 38 | 39 | class AnalyserWorkerOutput(BaseWorkerOutput): 40 | """Output model for Analyser worker.""" 41 | 42 | analysis: str 43 | input_dataframe: Optional[Dict] = None 44 | analysis_reasoning: Optional[str] = None 45 | 46 | 47 | class ChartPlotWorkerOutput(BaseWorkerOutput): 48 | """Output model for ChartPlot worker.""" 49 | 50 | input_dataframe: Optional[Dict] = None 51 | plot_config: Optional[Dict] = None 52 | image_plot: Optional[str] = None 53 | plot_reasoning: Optional[str] = None 54 | output_dataframe: Optional[Dict] = None 55 | 56 | 57 | class RouterWorkerOutput(BaseWorkerOutput): 58 | """Output model for Router worker.""" 59 | 60 | route_to: Literal["followup", "plot", "analyse", "query"] 61 | input_dataframe: Optional[Dict] = None 62 | decision_reasoning: Optional[str] = None 63 | 64 | 65 | # This is a more of a custom worker 66 | class FollowupWorkerOutput(BaseWorkerOutput): 67 | """Output model for Followup worker.""" 68 | 69 | route_taken: Literal["followup", "plot", "analyse", "query"] 70 | suggestion: str 71 | alternative_route: Optional[Literal["followup", "plot", "analyse", "query"]] = None 72 | 73 | 74 | class ExitWorkerOutput(BaseModel): 75 | """Output model for Exit worker, combining results from all workers.""" 76 | 77 | session_name: str 78 | question: str 79 | db_connection_uri: str 80 | route_taken: Literal["plot", "analyse", "query", "followup"] 81 | 82 | # Text2SQL fields 83 | sql_string: Optional[str] = None 84 | sql_reasoning: Optional[str] = None 85 | sql_input_dataframe: Optional[Dict] = None 86 | sql_output_dataframe: Optional[Dict] = None 87 | error_from_sql_worker: Optional[str] = None 88 | 89 | # Analysis worker fields 90 | analysis: Optional[str] = None 91 
| analysis_reasoning: Optional[str] = None 92 | analysis_input_dataframe: Optional[Dict] = None 93 | error_from_analysis_worker: Optional[str] = None 94 | 95 | # Plot Worker fields 96 | plot_config: Optional[Dict] = None 97 | plot_input_dataframe: Optional[Dict] = None 98 | plot_output_dataframe: Optional[Dict] = None 99 | image_to_plot: Optional[str] = None 100 | plot_reasoning: Optional[str] = None 101 | error_from_plot_worker: Optional[str] = None 102 | 103 | # Followup Worker fields 104 | followup_route_to_take: Optional[ 105 | Literal["plot", "analyse", "query", "followup"] 106 | ] = None 107 | followup_suggestion: Optional[str] = None 108 | error_from_followup_worker: Optional[str] = None 109 | 110 | # Additional input 111 | additional_input: Optional[Dict] = Field( 112 | default=None, description="Additional input data" 113 | ) 114 | 115 | def show_output_dataframe( 116 | self, 117 | ) -> pd.DataFrame: 118 | dataframe = None 119 | if self.route_taken == "query": 120 | dataframe = self.sql_output_dataframe 121 | elif self.route_taken == "plot": 122 | dataframe = self.plot_output_dataframe 123 | elif self.route_taken == "analyse": 124 | dataframe = self.analysis_input_dataframe 125 | 126 | if dataframe: 127 | return pd.DataFrame(dataframe["data"], columns=dataframe["columns"]) 128 | return pd.DataFrame() 129 | 130 | 131 | class AgentOutput(BaseModel): 132 | """Final output model for the entire pipeline.""" 133 | 134 | session_name: str 135 | question: str 136 | db_connection_uri: str 137 | route_taken: Literal["plot", "analyse", "query", "followup"] 138 | input_dataframe: Optional[Dict] = None 139 | output_dataframe: Optional[Dict] = None 140 | sql_string: Optional[str] = None 141 | analysis: Optional[str] = None 142 | reasoning: Optional[str] = None 143 | plot_config: Optional[Dict] = None 144 | image_to_plot: Optional[str] = None 145 | followup_route: Optional[Literal["plot", "analyse", "query", "followup"]] = None 146 | followup_suggestion: Optional[str] = None 147 | error_from_pipeline: Optional[str] = None 148 | created_at: datetime = Field(default_factory=datetime.now) 149 | 150 | def show_output_dataframe( 151 | self, 152 | ) -> pd.DataFrame: 153 | dataframe = None 154 | if self.route_taken == "query": 155 | dataframe = self.output_dataframe 156 | elif self.route_taken == "plot": 157 | dataframe = self.output_dataframe 158 | elif self.route_taken == "analyse": 159 | dataframe = self.input_dataframe 160 | 161 | if dataframe: 162 | return pd.DataFrame(dataframe["data"], columns=dataframe["columns"]) 163 | return pd.DataFrame() 164 | -------------------------------------------------------------------------------- /premsql/agents/router.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pandas as pd 4 | 5 | from premsql.logger import setup_console_logger 6 | from premsql.agents.base import RouterWorkerBase, RouterWorkerOutput 7 | from premsql.agents.utils import convert_df_to_dict 8 | 9 | logger = setup_console_logger("[BASELINE-ROUTER]") 10 | 11 | 12 | class SimpleRouterWorker(RouterWorkerBase): 13 | def run( 14 | self, question: str, input_dataframe: Optional[pd.DataFrame] 15 | ) -> RouterWorkerOutput: 16 | if question.startswith("/query"): 17 | route_to = "query" 18 | elif question.startswith("/analyse"): 19 | route_to = "analyse" 20 | elif question.startswith("/plot"): 21 | route_to = "plot" 22 | else: 23 | route_to = "followup" 24 | logger.info(f"Routing to: {route_to}") 25 |
question = ( 26 | question.split(f"/{route_to}")[1] if route_to != "followup" else question 27 | ) 28 | 29 | return RouterWorkerOutput( 30 | question=question, 31 | route_to=route_to, 32 | input_dataframe=( 33 | convert_df_to_dict(df=input_dataframe) if input_dataframe else None 34 | ), 35 | decision_reasoning="Simple routing based on question prefix", 36 | additional_input={}, 37 | error_from_model=None, 38 | ) 39 | -------------------------------------------------------------------------------- /premsql/agents/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from premsql.agents.tools.plot.matplotlib_tool import SimpleMatplotlibTool 2 | 3 | __all__ = ["SimpleMatplotlibTool"] 4 | -------------------------------------------------------------------------------- /premsql/agents/tools/plot/base.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | from abc import ABC, abstractmethod 4 | 5 | import pandas as pd 6 | from PIL import Image 7 | 8 | 9 | class BasePlotTool(ABC): 10 | @abstractmethod 11 | def run(self, data: pd.DataFrame, plot_config: dict): 12 | raise NotImplementedError() 13 | 14 | @abstractmethod 15 | def convert_plot_to_image(self, fig): 16 | raise NotImplementedError 17 | 18 | def convert_image_to_base64(self, image: Image.Image) -> str: 19 | buffered = io.BytesIO() 20 | image.save(buffered, format="PNG") 21 | return base64.b64encode(buffered.getvalue()).decode() 22 | 23 | def save_image(self, image: Image.Image, file_path: str, format: str = "PNG"): 24 | image.save(file_path, format=format) 25 | 26 | def plot_from_base64(self, output_base64: str): 27 | image_data = base64.b64decode(output_base64) 28 | return Image.open(io.BytesIO(image_data)) 29 | -------------------------------------------------------------------------------- /premsql/agents/tools/plot/matplotlib_tool.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import Callable, Dict 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | from matplotlib.axes import Axes 7 | from matplotlib.figure import Figure 8 | from PIL import Image 9 | 10 | from premsql.logger import setup_console_logger 11 | from premsql.agents.tools.plot.base import BasePlotTool 12 | 13 | logger = setup_console_logger("[MATPLOTLIB-TOOL]") 14 | 15 | 16 | class SimpleMatplotlibTool(BasePlotTool): 17 | def __init__(self): 18 | self.plot_functions: Dict[ 19 | str, Callable[[pd.DataFrame, str, str, Axes], None] 20 | ] = { 21 | "area": self._area_plot, 22 | "bar": self._bar_plot, 23 | "scatter": self._scatter_plot, 24 | "histogram": self._histogram_plot, 25 | "line": self._line_plot, 26 | } 27 | 28 | def run(self, data: pd.DataFrame, plot_config: Dict[str, str]) -> Figure: 29 | try: 30 | self._validate_config(data, plot_config) 31 | 32 | plot_type = plot_config["plot_type"] 33 | x = plot_config["x"] 34 | y = plot_config["y"] 35 | 36 | fig, ax = plt.subplots(figsize=(10, 6)) 37 | self.plot_functions[plot_type](data, x, y, ax) 38 | 39 | plt.title(f"{plot_type.capitalize()} Plot: {x} vs {y}") 40 | plt.xlabel(x) 41 | plt.ylabel(y) 42 | plt.tight_layout() 43 | 44 | return fig 45 | except Exception as e: 46 | logger.error(f"Error creating plot: {str(e)}") 47 | return plt.figure() # Return an empty figure on error 48 | 49 | def _validate_config(self, df: pd.DataFrame, plot_config: Dict[str, str]) -> None: 50 | required_keys = ["plot_type", "x", "y"] 51 | missing_keys 
= [key for key in required_keys if key not in plot_config] 52 | if missing_keys: 53 | raise ValueError( 54 | f"Missing required keys in plot_config: {', '.join(missing_keys)}" 55 | ) 56 | 57 | if plot_config["x"] not in df.columns: 58 | raise ValueError(f"Column '{plot_config['x']}' not found in DataFrame") 59 | 60 | if plot_config["y"] not in df.columns: 61 | raise ValueError(f"Column '{plot_config['y']}' not found in DataFrame") 62 | 63 | if plot_config["plot_type"] not in self.plot_functions: 64 | raise ValueError(f"Unsupported plot type: {plot_config['plot_type']}") 65 | 66 | def _area_plot(self, df: pd.DataFrame, x: str, y: str, ax: Axes) -> None: 67 | ax.fill_between(df[x], df[y]) 68 | 69 | def _bar_plot(self, df: pd.DataFrame, x: str, y: str, ax: Axes) -> None: 70 | ax.bar(df[x], df[y]) 71 | 72 | def _scatter_plot(self, df: pd.DataFrame, x: str, y: str, ax: Axes) -> None: 73 | ax.scatter(df[x], df[y]) 74 | 75 | def _histogram_plot(self, df: pd.DataFrame, x: str, y: str, ax: Axes) -> None: 76 | ax.hist(df[x], bins=20) 77 | 78 | def _line_plot(self, df: pd.DataFrame, x: str, y: str, ax: Axes) -> None: 79 | ax.plot(df[x], df[y]) 80 | 81 | def convert_plot_to_image(self, fig: Figure) -> Image.Image: 82 | buf = io.BytesIO() 83 | fig.savefig(buf, format="png") 84 | buf.seek(0) 85 | return Image.open(buf) 86 | -------------------------------------------------------------------------------- /premsql/agents/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Literal 2 | 3 | import pandas as pd 4 | 5 | from premsql.executors.from_langchain import SQLDatabase 6 | from premsql.logger import setup_console_logger 7 | from premsql.agents.models import AgentOutput, ExitWorkerOutput 8 | 9 | logger = setup_console_logger("[PIPELINE-UTILS]") 10 | 11 | 12 | def convert_df_to_dict(df: pd.DataFrame): 13 | return {"data": df.to_dict(), "columns": list(df.keys())} 14 | 15 | 16 | def execute_and_render_result( 17 | db: SQLDatabase, sql: str, using: Literal["dataframe", "json"] 18 | ): 19 | result = db.run_no_throw(command=sql, fetch="cursor") 20 | 21 | if isinstance(result, str): 22 | return _render_error(result, sql, using) 23 | return _render_data(result, sql, using) 24 | 25 | 26 | def _render_error(error: str, sql: str, using: str) -> Dict[str, Any]: 27 | to_show = {"sql_string": sql, "error_from_model": error, "dataframe": None} 28 | 29 | if using == "dataframe": 30 | to_show["dataframe"] = pd.DataFrame() # empty DataFrame 31 | elif using == "json": 32 | to_show["dataframe"] = {"data": {}, "columns": []} # empty JSON structure 33 | return to_show 34 | 35 | 36 | def _render_data(result, sql: str, using: str) -> Dict[str, Any]: 37 | table = pd.DataFrame(data=result.fetchall(), columns=result.keys()) 38 | if len(table) > 200: 39 | logger.info("Truncating output table to first 200 rows only") 40 | table = table.iloc[:200, :] 41 | 42 | if any(table.columns.duplicated()): 43 | logger.info(f"Found duplicate columns: {table.columns[table.columns.duplicated()].tolist()}") 44 | # Create unique column names by adding suffixes 45 | table.columns = [f"{col}_{i}" if i > 0 else col 46 | for i, col in enumerate(table.columns)] 47 | logger.info(f"Renamed columns to: {table.columns.tolist()}") 48 | 49 | to_show = {"sql_string": sql, "error_from_model": None, "dataframe": table} 50 | 51 | if using == "json": 52 | to_show["dataframe"] = {"columns": list(table.columns), "data": table.to_dict()} 53 | return to_show 54 | 55 | 56 | 57 | def 
convert_exit_output_to_agent_output(exit_output: ExitWorkerOutput) -> AgentOutput: 58 | return AgentOutput( 59 | session_name=exit_output.session_name, 60 | question=exit_output.question, 61 | db_connection_uri=exit_output.db_connection_uri, 62 | route_taken=exit_output.route_taken, 63 | input_dataframe=exit_output.sql_input_dataframe 64 | or exit_output.analysis_input_dataframe 65 | or exit_output.plot_input_dataframe, 66 | output_dataframe=exit_output.sql_output_dataframe 67 | or exit_output.plot_output_dataframe, 68 | sql_string=exit_output.sql_string, 69 | analysis=exit_output.analysis, 70 | reasoning=exit_output.sql_reasoning 71 | or exit_output.analysis_reasoning 72 | or exit_output.plot_reasoning, 73 | plot_config=exit_output.plot_config, 74 | image_to_plot=exit_output.image_to_plot, 75 | followup_route=exit_output.followup_route_to_take, 76 | followup_suggestion=exit_output.followup_suggestion, 77 | error_from_pipeline=( 78 | exit_output.error_from_sql_worker 79 | or exit_output.error_from_analysis_worker 80 | or exit_output.error_from_plot_worker 81 | or exit_output.error_from_followup_worker 82 | ), 83 | ) -------------------------------------------------------------------------------- /premsql/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | from pathlib import Path 5 | 6 | import click 7 | 8 | @click.group() 9 | @click.version_option() 10 | def cli(): 11 | """PremSQL CLI to manage API servers and Streamlit app""" 12 | pass 13 | 14 | @cli.group() 15 | def launch(): 16 | """Launch PremSQL services""" 17 | pass 18 | 19 | 20 | @launch.command(name='all') 21 | def launch_all(): 22 | """Launch both API server and Streamlit app""" 23 | premsql_path = Path(__file__).parent.parent.absolute() 24 | env = os.environ.copy() 25 | env["PYTHONPATH"] = str(premsql_path) 26 | 27 | # Start API server 28 | manage_py_path = premsql_path / "premsql" / "playground" / "backend" / "manage.py" 29 | if not manage_py_path.exists(): 30 | click.echo(f"Error: manage.py not found at {manage_py_path}", err=True) 31 | sys.exit(1) 32 | 33 | # Run migrations first 34 | click.echo("Running database migrations...") 35 | try: 36 | subprocess.run([sys.executable, str(manage_py_path), "makemigrations"], env=env, check=True) 37 | subprocess.run([sys.executable, str(manage_py_path), "migrate"], env=env, check=True) 38 | except subprocess.CalledProcessError as e: 39 | click.echo(f"Error running migrations: {e}", err=True) 40 | sys.exit(1) 41 | 42 | click.echo("Starting the PremSQL backend API server...") 43 | subprocess.Popen([sys.executable, str(manage_py_path), "runserver"], env=env) 44 | 45 | # Launch the streamlit app 46 | click.echo("Starting the PremSQL Streamlit app...") 47 | main_py_path = premsql_path / "premsql" / "playground" / "frontend" / "main.py" 48 | if not main_py_path.exists(): 49 | click.echo(f"Error: main.py not found at {main_py_path}", err=True) 50 | sys.exit(1) 51 | 52 | cmd = [sys.executable, "-m", "streamlit", "run", str(main_py_path), "--server.maxUploadSize=500"] 53 | try: 54 | subprocess.run(cmd, env=env, check=True) 55 | except KeyboardInterrupt: 56 | click.echo("Stopping all services...") 57 | stop() 58 | 59 | @launch.command(name='api') 60 | def launch_api(): 61 | """Launch only the API server""" 62 | premsql_path = Path(__file__).parent.parent.absolute() 63 | env = os.environ.copy() 64 | env["PYTHONPATH"] = str(premsql_path) 65 | manage_py_path = premsql_path / "premsql" / "playground" / 
"backend" / "manage.py" 66 | 67 | if not manage_py_path.exists(): 68 | click.echo(f"Error: manage.py not found at {manage_py_path}", err=True) 69 | sys.exit(1) 70 | 71 | # Run makemigrations 72 | click.echo("Running database migrations...") 73 | try: 74 | subprocess.run([sys.executable, str(manage_py_path), "makemigrations"], env=env, check=True) 75 | subprocess.run([sys.executable, str(manage_py_path), "migrate"], env=env, check=True) 76 | except subprocess.CalledProcessError as e: 77 | click.echo(f"Error running migrations: {e}", err=True) 78 | sys.exit(1) 79 | 80 | click.echo("Starting the PremSQL backend API server...") 81 | cmd = [sys.executable, str(manage_py_path), "runserver"] 82 | try: 83 | subprocess.run(cmd, env=env, check=True) 84 | except KeyboardInterrupt: 85 | click.echo("API server stopped.") 86 | 87 | @cli.command() 88 | def stop(): 89 | """Stop all PremSQL services""" 90 | click.echo("Stopping all PremSQL services...") 91 | 92 | try: 93 | if sys.platform == "win32": 94 | subprocess.run( 95 | ["taskkill", "/F", "/IM", "python.exe", "/FI", "WINDOWTITLE eq premsql*"], 96 | check=True, 97 | ) 98 | else: 99 | subprocess.run(["pkill", "-f", "manage.py runserver"], check=True) 100 | subprocess.run(["pkill", "-f", "streamlit"], check=True) 101 | click.echo("All services stopped successfully.") 102 | except subprocess.CalledProcessError: 103 | click.echo("No running services found.") 104 | except Exception as e: 105 | click.echo(f"Error stopping services: {e}", err=True) 106 | sys.exit(1) 107 | 108 | if __name__ == "__main__": 109 | cli() -------------------------------------------------------------------------------- /premsql/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Union 3 | 4 | from premsql.datasets.base import StandardDataset, Text2SQLBaseDataset 5 | from premsql.datasets.real.bird import BirdDataset 6 | from premsql.datasets.real.domains import DomainsDataset 7 | from premsql.datasets.real.spider import SpiderUnifiedDataset 8 | from premsql.datasets.synthetic.gretel import GretelAIDataset 9 | from premsql.utils import get_accepted_filters 10 | 11 | 12 | class Text2SQLDataset: 13 | def __init__( 14 | self, 15 | dataset_name: str, 16 | split: str, 17 | dataset_folder: Optional[Union[str, Path]] = "./data", 18 | hf_token: Optional[str] = None, 19 | force_download: Optional[bool] = False, 20 | **kwargs 21 | ): 22 | assert dataset_name in ["bird", "domains", "spider", "gretel"], ValueError( 23 | "Dataset should be one of bird, domains, spider, gretel" 24 | ) 25 | dataset_mapping = { 26 | "bird": BirdDataset, 27 | "domains": DomainsDataset, 28 | "spider": SpiderUnifiedDataset, 29 | "gretel": GretelAIDataset, 30 | } 31 | self._text2sql_dataset: Text2SQLBaseDataset = dataset_mapping[dataset_name]( 32 | split=split, 33 | dataset_folder=dataset_folder, 34 | hf_token=hf_token, 35 | force_download=force_download, 36 | **kwargs 37 | ) 38 | 39 | @property 40 | def raw_dataset(self): 41 | return self._text2sql_dataset.dataset 42 | 43 | @property 44 | def filter_availables(self): 45 | return get_accepted_filters(data=self._text2sql_dataset.dataset) 46 | 47 | def setup_dataset( 48 | self, 49 | filter_by: tuple | None = None, 50 | num_rows: int | None = None, 51 | num_fewshot: int | None = None, 52 | model_name_or_path: str | None = None, 53 | prompt_template: str | None = None, 54 | tokenize: bool | None = False 55 | ): 56 | return self._text2sql_dataset.setup_dataset( 
57 | filter_by=filter_by, 58 | num_rows=num_rows, 59 | model_name_or_path=model_name_or_path, 60 | tokenize=tokenize, 61 | prompt_template=prompt_template, 62 | num_fewshot=num_fewshot 63 | ) 64 | 65 | 66 | __all__ = [ 67 | "StandardDataset", 68 | "GretelAIDataset", 69 | "SpiderUnifiedDataset", 70 | "BirdDataset", 71 | "DomainsDataset", 72 | "Text2SQLDataset", 73 | ] -------------------------------------------------------------------------------- /premsql/datasets/collator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Sequence 3 | from premsql.logger import setup_console_logger 4 | 5 | logger = setup_console_logger("[DATASET-COLLATOR]") 6 | 7 | try: 8 | import torch 9 | import transformers 10 | except ImportError: 11 | logger.warn("Ensure torch and transformers are installed.") 12 | logger.warn("Install them by: pip install torch transformers") 13 | 14 | @dataclass 15 | class DataCollatorForSupervisedDataset: 16 | tokenizer: "transformers.PreTrainedTokenizer" 17 | 18 | def __call__(self, instances: Sequence[dict]) -> dict[str, torch.Tensor]: 19 | input_ids, labels = tuple( 20 | [instance[key] for instance in instances] for key in ("input_ids", "labels") 21 | ) 22 | input_ids = torch.nn.utils.rnn.pad_sequence( 23 | input_ids, 24 | batch_first=True, 25 | padding_value=self.tokenizer.pad_token_id, 26 | ) 27 | labels = torch.nn.utils.rnn.pad_sequence( 28 | labels, batch_first=True, padding_value=-100 29 | ) 30 | return dict( 31 | input_ids=input_ids, 32 | labels=labels, 33 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 34 | ) 35 | -------------------------------------------------------------------------------- /premsql/datasets/error_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional, Sequence 4 | 5 | from tqdm.auto import tqdm 6 | 7 | from premsql.datasets.base import ( 8 | SupervisedDatasetForTraining, 9 | Text2SQLBaseDataset, 10 | Text2SQLBaseInstance, 11 | ) 12 | from premsql.evaluator.base import BaseExecutor, Text2SQLEvaluator 13 | from premsql.generators.base import Text2SQLGeneratorBase 14 | from premsql.logger import setup_console_logger 15 | from premsql.prompts import ERROR_HANDLING_PROMPT 16 | 17 | logger = setup_console_logger("[ERROR-HANDLING-DATASET]") 18 | 19 | 20 | class ErrorDatasetInstance(Text2SQLBaseInstance): 21 | 22 | def __init__(self, dataset: list[dict]) -> None: 23 | super().__init__(dataset=dataset) 24 | 25 | def apply_prompt(self, prompt_template: Optional[str] = ERROR_HANDLING_PROMPT): 26 | data_to_return = [] 27 | for content in tqdm( 28 | self.dataset, total=len(self.dataset), desc="Applying error prompt" 29 | ): 30 | assert "error" in content, "key error is not present" 31 | error_msg = content["error"] 32 | 33 | if error_msg is not None: 34 | prompt = content["prompt"].split("# SQL:")[0].strip() 35 | prediction = content["generated"] 36 | error_prompt = prompt_template.format( 37 | existing_prompt=prompt, sql=prediction, error_msg=error_msg 38 | ) 39 | data_to_return.append( 40 | { 41 | "db_id": content["db_id"], 42 | "question": content["question"], 43 | "SQL": content["SQL"], 44 | "prompt": error_prompt, 45 | "db_path": content["db_path"], 46 | } 47 | ) 48 | return data_to_return 49 | 50 | 51 | class ErrorDatasetGenerator: 52 | @classmethod 53 | def from_existing( 54 | cls, 55 | experiment_name: str, 56 | experiment_folder: 
Optional[str] = None, 57 |         tokenize_model_name_or_path: Optional[str] = None, 58 |         hf_token: Optional[str] = None, 59 |     ) -> dict: 60 |         experiment_folder = Path(experiment_folder) if experiment_folder is not None else Path("./experiments") 61 |         experiment_path = ( 62 |             experiment_folder / "train" / experiment_name / "error_dataset.json" 63 |         ) 64 |         if not experiment_path.exists(): 65 |             raise FileNotFoundError(f"Path {experiment_path} does not exist") 66 |         dataset = json.load(open(experiment_path, "r")) 67 |         return ( 68 |             ErrorDatasetInstance(dataset=dataset) 69 |             if not tokenize_model_name_or_path 70 |             else SupervisedDatasetForTraining( 71 |                 dataset=dataset, 72 |                 model_name_or_path=tokenize_model_name_or_path, 73 |                 hf_token=hf_token, 74 |             ) 75 |         ) 76 | 77 |     def __init__( 78 |         self, 79 |         generator: Text2SQLGeneratorBase, 80 |         executor: BaseExecutor, 81 |     ): 82 |         self.generator = generator 83 |         self.evaluator = Text2SQLEvaluator( 84 |             executor=executor, experiment_path=self.generator.experiment_path 85 |         ) 86 | 87 |     def generate_and_save( 88 |         self, 89 |         datasets: Sequence[Text2SQLBaseDataset], 90 |         path_to_save: Optional[str] = None, 91 |         force: Optional[bool] = False, 92 |         tokenize: Optional[bool] = False, 93 |         prompt_template: Optional[str] = ERROR_HANDLING_PROMPT, 94 |         hf_token: Optional[str] = None, 95 |     ) -> None: 96 | 97 |         path_to_save = ( 98 |             (self.generator.experiment_path / "error_dataset.json") 99 |             if path_to_save is None 100 |             else Path(path_to_save) 101 |         ) 102 |         if path_to_save.exists() and force == False: 103 |             logger.info("Error dataset already exists") 104 |             with open(path_to_save, "r") as json_file: 105 |                 data_to_return = json.load(json_file) 106 |             return data_to_return 107 | 108 |         responses = self.generator.generate_and_save_results( 109 |             dataset=datasets, temperature=0.1, max_new_tokens=256, force=force 110 |         ) 111 |         logger.info("Starting Evaluation") 112 |         _ = self.evaluator.execute( 113 |             metric_name="accuracy", 114 |             model_responses=responses, 115 |         ) 116 |         del responses 117 | 118 |         # Now iterate over the error dataset 119 |         with open(self.generator.experiment_path / "predict.json", "r") as file: 120 |             error_dataset = json.load(file) 121 | 122 |         error_instances = ErrorDatasetInstance(dataset=error_dataset).apply_prompt( 123 |             prompt_template=prompt_template 124 |         ) 125 | 126 |         with open(path_to_save, "w") as json_file: 127 |             json.dump(error_instances, json_file, indent=4) 128 | 129 |         return ( 130 |             error_instances 131 |             if not tokenize 132 |             else SupervisedDatasetForTraining( 133 |                 dataset=error_instances, 134 |                 model_name_or_path=self.generator.model_name_or_path, 135 |                 hf_token=hf_token, 136 |             ) 137 |         ) -------------------------------------------------------------------------------- /premsql/datasets/real/bird.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Union 3 | 4 | from huggingface_hub import snapshot_download 5 | 6 | from premsql.datasets.base import Text2SQLBaseDataset 7 | from premsql.logger import setup_console_logger 8 | 9 | logger = setup_console_logger("[BIRD-DATASET]") 10 | 11 | 12 | class BirdDataset(Text2SQLBaseDataset): 13 |     def __init__( 14 |         self, 15 |         split: str, 16 |         dataset_folder: Optional[Union[str, Path]] = "./data", 17 |         hf_token: Optional[str] = None, 18 |         force_download: Optional[bool] = False, 19 |         **kwargs 20 |     ): 21 |         dataset_folder = Path(dataset_folder) 22 |         bird_folder = dataset_folder / "bird" 23 |         if not bird_folder.exists() or force_download: 24 |
bird_folder.mkdir(parents=True, exist_ok=True) 25 | 26 | # Download it from hf hub 27 | snapshot_download( 28 | repo_id="premai-io/birdbench", 29 | repo_type="dataset", 30 | local_dir=dataset_folder / "bird", 31 | force_download=force_download, 32 | ) 33 | 34 | dataset_path = bird_folder / split 35 | 36 | database_folder_name = kwargs.get("database_folder_name", None) or ( 37 | "train_databases" if split == "train" else "dev_databases" 38 | ) 39 | json_file_name = kwargs.get("json_file_name", None) or ( 40 | "train.json" if split == "train" else "validation.json" 41 | ) 42 | 43 | super().__init__( 44 | split=split, 45 | dataset_path=dataset_path, 46 | database_folder_name=database_folder_name, 47 | json_file_name=json_file_name, 48 | hf_token=hf_token, 49 | ) 50 | logger.info("Loaded Bird Dataset") 51 | 52 | def setup_dataset( 53 | self, 54 | filter_by: tuple | None = None, 55 | num_rows: int | None = None, 56 | num_fewshot: int | None = None, 57 | model_name_or_path: str | None = None, 58 | prompt_template: str | None = None, 59 | tokenize: bool | None = False 60 | ): 61 | logger.info("Setting up Bird Dataset") 62 | return super().setup_dataset( 63 | filter_by=filter_by, 64 | num_rows=num_rows, 65 | model_name_or_path=model_name_or_path, 66 | tokenize=tokenize, 67 | prompt_template=prompt_template, 68 | num_fewshot=num_fewshot 69 | ) -------------------------------------------------------------------------------- /premsql/datasets/real/domains.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Union 3 | 4 | from huggingface_hub import snapshot_download 5 | 6 | from premsql.datasets.base import Text2SQLBaseDataset 7 | from premsql.logger import setup_console_logger 8 | 9 | logger = setup_console_logger("[DOMAINS-DATASET]") 10 | 11 | 12 | class DomainsDataset(Text2SQLBaseDataset): 13 | def __init__( 14 | self, 15 | split: str, 16 | dataset_folder: Optional[Union[str, Path]] = "./data", 17 | hf_token: Optional[str] = None, 18 | force_download: Optional[bool] = False, 19 | ): 20 | dataset_folder = Path(dataset_folder) 21 | domains_folder = dataset_folder / "domains" 22 | if not domains_folder.exists() or force_download: 23 | domains_folder.mkdir(parents=True, exist_ok=True) 24 | 25 | # Download it from hf hub 26 | snapshot_download( 27 | repo_id="premai-io/domains", 28 | repo_type="dataset", 29 | local_dir=dataset_folder / "domains", 30 | force_download=force_download, 31 | ) 32 | 33 | assert split in ["train", "validation"], ValueError( 34 | "Split should be either train or validation" 35 | ) 36 | json_file_name = "train.json" if split == "train" else "validation.json" 37 | super().__init__( 38 | split=split, 39 | dataset_path=domains_folder, 40 | database_folder_name="databases", 41 | json_file_name=json_file_name, 42 | hf_token=hf_token, 43 | ) 44 | logger.info("Loaded Domains Dataset") 45 | 46 | # An extra step for Domains Dataset so that it can be 47 | # compatible with the Base dataset and Base instance 48 | 49 | for content in self.dataset: 50 | content["SQL"] = content["query"] 51 | 52 | def setup_dataset( 53 | self, 54 | filter_by: tuple | None = None, 55 | num_rows: int | None = None, 56 | num_fewshot: int | None = None, 57 | model_name_or_path: str | None = None, 58 | prompt_template: str | None = None, 59 | tokenize: bool | None = False 60 | ): 61 | logger.info("Setting up Domains Dataset") 62 | return super().setup_dataset( 63 | filter_by=filter_by, 64 | num_rows=num_rows, 65 | 
model_name_or_path=model_name_or_path, 66 | tokenize=tokenize, 67 | prompt_template=prompt_template, 68 | num_fewshot=num_fewshot 69 | ) -------------------------------------------------------------------------------- /premsql/datasets/real/spider.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Union 3 | 4 | from huggingface_hub import snapshot_download 5 | 6 | from premsql.datasets.base import Text2SQLBaseDataset 7 | from premsql.logger import setup_console_logger 8 | 9 | logger = setup_console_logger("[SPIDER-DATASET]") 10 | 11 | 12 | class SpiderUnifiedDataset(Text2SQLBaseDataset): 13 | def __init__( 14 | self, 15 | split: str, 16 | dataset_folder: Optional[Union[str, Path]] = "./data", 17 | hf_token: Optional[str] = None, 18 | force_download: Optional[bool] = False, 19 | ): 20 | dataset_folder = Path(dataset_folder) 21 | spider_folder = dataset_folder / "spider" 22 | if not spider_folder.exists() or force_download: 23 | spider_folder.mkdir(parents=True, exist_ok=True) 24 | 25 | # Download it from hf hub 26 | snapshot_download( 27 | repo_id="premai-io/spider", 28 | repo_type="dataset", 29 | local_dir=dataset_folder / "spider", 30 | force_download=force_download, 31 | ) 32 | 33 | assert split in ["train", "validation"], ValueError( 34 | "Split should be either train or validation" 35 | ) 36 | json_file_name = "train.json" if split == "train" else "validation.json" 37 | super().__init__( 38 | split=split, 39 | dataset_path=spider_folder, 40 | database_folder_name="database", 41 | json_file_name=json_file_name, 42 | hf_token=hf_token, 43 | ) 44 | logger.info("Loaded Spider Dataset") 45 | 46 | # An extra step for Spider Dataset so that it can be 47 | # compatible with the Base dataset and Base instance 48 | 49 | for content in self.dataset: 50 | content["SQL"] = content["query"] 51 | 52 | def setup_dataset( 53 | self, 54 | filter_by: tuple | None = None, 55 | num_rows: int | None = None, 56 | num_fewshot: int | None = None, 57 | model_name_or_path: str | None = None, 58 | prompt_template: str | None = None, 59 | tokenize: bool | None = False 60 | ): 61 | logger.info("Setting up Spider Dataset") 62 | return super().setup_dataset( 63 | filter_by=filter_by, 64 | num_rows=num_rows, 65 | model_name_or_path=model_name_or_path, 66 | tokenize=tokenize, 67 | prompt_template=prompt_template, 68 | num_fewshot=num_fewshot 69 | ) -------------------------------------------------------------------------------- /premsql/datasets/synthetic/gretel.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Union 3 | 4 | from datasets import load_dataset 5 | from tqdm.auto import tqdm 6 | 7 | from premsql.datasets.base import ( 8 | SupervisedDatasetForTraining, 9 | Text2SQLBaseDataset, 10 | Text2SQLBaseInstance, 11 | ) 12 | from premsql.logger import setup_console_logger 13 | from premsql.prompts import BASE_TEXT2SQL_PROMPT 14 | from premsql.utils import filter_options, save_to_json 15 | 16 | logger = setup_console_logger("[GRETELAI-DATASET]") 17 | 18 | 19 | class GretelAIInstance(Text2SQLBaseInstance): 20 | def __init__(self, dataset: list[dict]) -> None: 21 | super().__init__(dataset) 22 | 23 | def apply_prompt( 24 | self, 25 | num_fewshot: Optional[int] = None, 26 | prompt_template: Optional[str] = BASE_TEXT2SQL_PROMPT, 27 | ): 28 | prompt_template = ( 29 | BASE_TEXT2SQL_PROMPT if prompt_template is None else 
prompt_template 30 |         ) 31 |         for blob in tqdm(self.dataset, total=len(self.dataset), desc="Applying prompt"): 32 |             few_shot_prompt = ( 33 |                 "" 34 |                 if num_fewshot is None 35 |                 else self.add_few_shot_examples(db_id=blob["db_id"], k=num_fewshot) 36 |             ) 37 |             final_prompt = prompt_template.format( 38 |                 schemas=blob["context"], 39 |                 additional_knowledge="", 40 |                 few_shot_examples=few_shot_prompt, 41 |                 question=blob["question"], 42 |             ) 43 |             blob["prompt"] = final_prompt 44 |         return self.dataset 45 | 46 | 47 | class GretelAIDataset(Text2SQLBaseDataset): 48 |     def __init__( 49 |         self, 50 |         split: Optional[str] = "train", 51 |         dataset_folder: Optional[Union[str, Path]] = "./data", 52 |         hf_token: Optional[str] = None, 53 |         force_download: Optional[bool] = False, 54 |     ): 55 |         dataset_folder = Path(dataset_folder) 56 |         dataset_path = dataset_folder / "gretel" 57 |         if not dataset_path.exists() or force_download: 58 |             dataset_path.mkdir(parents=True, exist_ok=True) 59 |             dataset = [] 60 |             raw_dataset = load_dataset("gretelai/synthetic_text_to_sql", token=hf_token) 61 | 62 |             for split in ["train", "test"]: 63 |                 for content in raw_dataset[split]: 64 |                     blob_content = { 65 |                         "id": content["id"], 66 |                         "question": content["sql_prompt"], 67 |                         "schema": content["sql_context"], 68 |                         "SQL": content["sql"], 69 |                         "context": content["sql_context"], 70 |                         "task_type": content["sql_task_type"], 71 |                         "complexity": content["sql_complexity"], 72 |                         "db_id": content["domain"], 73 |                         "db_path": None, 74 |                     } 75 |                     dataset.append(blob_content) 76 | 77 |             save_to_json(save_path=dataset_path / "train.json", json_object=dataset) 78 | 79 |         super().__init__( 80 |             split="train", 81 |             dataset_path=dataset_path, 82 |             database_folder_name=None, 83 |             json_file_name="train.json", 84 |         ) 85 | 86 |     def setup_dataset( 87 |         self, 88 |         filter_by: Optional[tuple] = None, 89 |         num_rows: Optional[int] = None, 90 |         num_fewshot: Optional[int] = None, 91 |         model_name_or_path: Optional[str] = None, 92 |         prompt_template: Optional[str] = BASE_TEXT2SQL_PROMPT, 93 |     ): 94 |         if filter_by: 95 |             self.dataset = filter_options(data=self.dataset, filter_by=filter_by) 96 | 97 |         if num_rows: 98 |             self.dataset = self.dataset[:num_rows] 99 | 100 |         self.dataset = GretelAIInstance(dataset=self.dataset).apply_prompt( 101 |             num_fewshot=num_fewshot, prompt_template=prompt_template 102 |         ) 103 | 104 |         return SupervisedDatasetForTraining( 105 |             dataset=self.dataset, 106 |             model_name_or_path=model_name_or_path, 107 |             hf_token=self.hf_token, 108 |         ) 109 | -------------------------------------------------------------------------------- /premsql/evaluator/README.md: -------------------------------------------------------------------------------- 1 | ## Evaluators 2 | 3 | premsql evaluators help you evaluate your text-to-SQL models on various validation datasets. 4 | Currently, we support two metrics for evaluation: 5 | 6 | - Execution Accuracy 7 | - Valid Efficiency Score 8 | 9 | **Execution Accuracy (EX):** As the name suggests, this metric measures correctness by comparing the execution results of the 10 | generated SQL with those of the ground truth. 11 | 12 | **Valid Efficiency Score (VES):** The primary objective of an LLM-generated SQL query is to be accurate. 13 | However, it also needs to be performance-optimized when dealing with big data. This metric assesses both 14 | objectives: it quantifies how efficient the query is as well as whether it is accurate. Roughly, it 15 | is computed as shown below.
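The helper below is a minimal sketch of that computation, mirroring `Text2SQLEvaluator.compute_metric` (the function name `valid_efficiency_score` is illustrative and not part of the library): a correct query contributes the square root of its execution-time ratio, scaled to a percentage, an incorrect query contributes zero, and the scores are averaged over every evaluated query.

```python
import math

def valid_efficiency_score(results: list[dict]) -> float:
    # Each result carries a "ves" field: the (averaged) execution-time ratio
    # between the gold and the generated SQL when their result sets match,
    # and 0 when they do not.
    return sum(math.sqrt(res["ves"]) * 100 for res in results) / len(results)
```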
16 | 17 | Here is a quick start on how to use the premsql evaluators: 18 | 19 | ```python 20 | from premsql.datasets import Text2SQLDataset 21 | from premsql.generators.premai import Text2SQLGeneratorPremAI 22 | from premsql.evaluator import Text2SQLEvaluator 23 | from premsql.executors import SQLiteExecutor 24 | 25 | # Get the validation dataset 26 | dataset = Text2SQLDataset( 27 |     dataset_name="bird", 28 |     split="test", 29 |     database_folder_name="test_databases", 30 |     json_file_name="test.json", 31 |     dataset_folder="/root/anindya/Submission/text2sql/data", 32 | ).setup_dataset( 33 |     num_rows=10, 34 |     num_fewshot=3, 35 | ) 36 | 37 | generator = Text2SQLGeneratorPremAI( 38 |     model_name="gpt-4o", 39 |     project_id=1234, 40 |     premai_api_key="FK-xxxx-xxx-xxx", 41 |     experiment_name="test_generators", 42 |     device="cuda:0", 43 |     type="test" 44 | ) 45 | 46 | # Generate the responses that will be evaluated 47 | responses = generator.generate_and_save_results(dataset=dataset) 48 | 49 | executor = SQLiteExecutor() 50 | evaluator = Text2SQLEvaluator( 51 |     executor=executor, experiment_path=generator.experiment_path 52 | ) 53 | 54 | # Calculate Execution Accuracy 55 | ex = evaluator.execute( 56 |     metric_name="accuracy", 57 |     model_responses=responses, 58 |     filter_by="difficulty" 59 | ) 60 | 61 | # Similarly, calculate the Valid Efficiency Score 62 | ves = evaluator.execute( 63 |     metric_name="ves", 64 |     model_responses=responses, 65 |     filter_by="difficulty" 66 | ) 67 | ``` 68 | 69 | **Output** 70 | 71 | Here is a sample output of the evaluation: 72 | 73 | ``` 74 | Accuracy: 75 | --------- 76 | +-------------+-------------------+-------------------+ 77 | | Category | num_correct (%) | total questions | 78 | +=============+===================+===================+ 79 | | simple | 58.4865 | 925 | 80 | +-------------+-------------------+-------------------+ 81 | | moderate | 43.75 | 464 | 82 | +-------------+-------------------+-------------------+ 83 | | challenging | 42.7586 | 145 | 84 | +-------------+-------------------+-------------------+ 85 | | overall | 52.5424 | 1534 | 86 | +-------------+-------------------+-------------------+ 87 | 88 | Valid Efficiency Score (VES): 89 | ----------------------------- 90 | 91 | +-------------+-----------+-------------------+ 92 | | Category | VES (%) | total questions | 93 | +=============+===========+===================+ 94 | | simple | 60.1844 | 925 | 95 | +-------------+-----------+-------------------+ 96 | | moderate | 46.4345 | 464 | 97 | +-------------+-----------+-------------------+ 98 | | challenging | 43.9845 | 145 | 99 | +-------------+-----------+-------------------+ 100 | | overall | 54.4941 | 1534 | 101 | +-------------+-----------+-------------------+ 102 | ``` 103 | 104 | We have also benchmarked several closed and open-source models. Here are some results for the following models: 105 | 106 | - gpt-4o 107 | - gpt-4o-mini 108 | - claude-3.5-sonnet 109 | - codellama-70b-instruct 110 | - claude-3-opus 111 | - llama-3.1-405b-instruct 112 | 113 | **Accuracy** 114 | 115 | ![accuracy comparison](/assets/Model-Accuracy-Comparison.png) 116 | 117 | **Valid Efficiency Score** 118 | 119 | ![ves comparison](/assets/Models-VES-Comparison.png) 120 | 121 | We have also written a detailed blog post about this. If you are interested in the full analysis, check out the [blog post here](https://blog.premai.io/text2sql-eval).
122 | -------------------------------------------------------------------------------- /premsql/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from premsql.evaluator.base import Text2SQLEvaluator 2 | 3 | __all__ = ["Text2SQLEvaluator"] 4 | -------------------------------------------------------------------------------- /premsql/evaluator/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | import traceback 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | 6 | from func_timeout import FunctionTimedOut, func_timeout 7 | from tqdm.auto import tqdm 8 | 9 | from premsql.executors.base import BaseExecutor 10 | from premsql.utils import save_to_json 11 | 12 | 13 | class Text2SQLEvaluator: 14 | def __init__( 15 | self, executor: BaseExecutor, experiment_path: Union[str, Path] 16 | ) -> None: 17 | self.executor = executor 18 | self.experiment_path = Path(experiment_path) 19 | 20 | def _execute_model( 21 | self, 22 | metric_name: str, 23 | generated_sql: str, 24 | gold_sql: str, 25 | dsn_or_db_path: str, 26 | meta_time_out: Optional[int] = 1000, 27 | num_iterations: Optional[int] = None, 28 | debug: Optional[bool] = False, 29 | ): 30 | assert metric_name in ["accuracy", "ves"], "Invalid metric name" 31 | try: 32 | if metric_name == "accuracy": 33 | result = func_timeout( 34 | meta_time_out, 35 | self.executor.match_sqls, 36 | args=(generated_sql, gold_sql, dsn_or_db_path), 37 | ) 38 | elif metric_name == "ves": 39 | num_iterations = 10 if num_iterations is None else num_iterations 40 | result = func_timeout( 41 | meta_time_out, 42 | self.executor.iterated_execution, 43 | args=(generated_sql, gold_sql, dsn_or_db_path, num_iterations), 44 | ) 45 | else: 46 | raise ValueError(f"Invalid metric name: {metric_name}") 47 | 48 | return { 49 | metric_name: result["result"], 50 | "error": result["error"], 51 | } 52 | except FunctionTimedOut as e: 53 | return { 54 | metric_name: 0, 55 | "error": f"Function Timed out: {e}", 56 | } 57 | except Exception as e: 58 | if debug: 59 | traceback.print_exc() 60 | 61 | return { 62 | metric_name: 0, 63 | "error": f"Exception: {e}", 64 | } 65 | 66 | def execute( 67 | self, 68 | metric_name: str, 69 | model_responses: list[dict], 70 | filter_by: Optional[str] = None, 71 | num_iterations: Optional[int] = 10, 72 | meta_time_out: Optional[int] = 10, # change it later to 1000 73 | debug: Optional[bool] = False, 74 | ) -> dict: 75 | data_with_results = [] 76 | 77 | for response in tqdm(model_responses, total=len(model_responses)): 78 | result = self._execute_model( 79 | metric_name=metric_name, 80 | generated_sql=response["generated"], 81 | gold_sql=response["SQL"], 82 | dsn_or_db_path=response["db_path"], 83 | num_iterations=num_iterations, 84 | meta_time_out=meta_time_out, 85 | debug=debug, 86 | ) 87 | data_with_results.append({**response, **result}) 88 | 89 | execution_result = {} 90 | if filter_by: 91 | if filter_by not in data_with_results[0]: 92 | raise KeyError(f"Filter key: {filter_by} is not found in responses") 93 | 94 | filter_values = {response[filter_by] for response in data_with_results} 95 | total_responses = len(data_with_results) 96 | overall_metric = 0.0 97 | 98 | for value in filter_values: 99 | filtered_responses = [ 100 | response 101 | for response in data_with_results 102 | if response[filter_by] == value 103 | ] 104 | metric_value = self.compute_metric( 105 | results=filtered_responses, metric_name=metric_name 106 | ) 
107 | execution_result[value] = metric_value 108 | overall_metric += ( 109 | metric_value * len(filtered_responses) / total_responses 110 | ) 111 | 112 | execution_result["overall"] = overall_metric 113 | else: 114 | execution_result["overall"] = self.compute_metric( 115 | results=data_with_results, metric_name=metric_name 116 | ) 117 | 118 | save_to_json( 119 | json_object=execution_result, 120 | save_path=self.experiment_path / f"{metric_name}.json", 121 | ) 122 | 123 | # also save the data_with_results 124 | save_to_json( 125 | json_object=data_with_results, 126 | save_path=self.experiment_path / "predict.json", 127 | ) 128 | return execution_result 129 | 130 | def compute_metric(self, results: list[dict], metric_name: str) -> float: 131 | if metric_name == "accuracy": 132 | return sum(res["accuracy"] for res in results) / len(results) * 100 133 | 134 | elif metric_name == "ves": 135 | num_queries = len(results) 136 | total_ratio = 0.0 137 | for result in results: 138 | total_ratio += math.sqrt(result["ves"]) * 100 139 | ves = total_ratio / num_queries 140 | return ves 141 | 142 | else: 143 | raise ValueError(f"Invalid metric name: {metric_name}") 144 | -------------------------------------------------------------------------------- /premsql/executors/__init__.py: -------------------------------------------------------------------------------- 1 | from premsql.executors.from_langchain import ExecutorUsingLangChain 2 | from premsql.executors.from_sqlite import SQLiteExecutor, OptimizedSQLiteExecutor 3 | 4 | __all__ = ["ExecutorUsingLangChain", "SQLiteExecutor", "OptimizedSQLiteExecutor"] 5 | -------------------------------------------------------------------------------- /premsql/executors/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | 5 | 6 | class BaseExecutor(ABC): 7 | @abstractmethod 8 | def execute_sql(self, sql: str, dsn_or_db_path: str) -> dict: 9 | return {"result": None, "execution_time": None, "error": None} 10 | 11 | def match_sqls( 12 | self, predicted_sql: str, gold_sql: str, dsn_or_db_path: str 13 | ) -> bool: 14 | prediction = self.execute_sql(sql=predicted_sql, dsn_or_db_path=dsn_or_db_path) 15 | gold = self.execute_sql(sql=gold_sql, dsn_or_db_path=dsn_or_db_path) 16 | if prediction["error"]: 17 | return { 18 | "result": 0, 19 | "error": prediction["error"], 20 | } 21 | 22 | is_match = set(prediction["result"]) == set(gold["result"]) 23 | return { 24 | "result": int(is_match), 25 | "error": None if is_match else "Table mismatch", 26 | } 27 | 28 | def clean_abnormal(self, input: list[float]) -> list[float]: 29 | input_array = np.asarray(input) 30 | mean = np.mean(input_array) 31 | std = np.std(input_array) 32 | return [x for x in input_array if mean - 3 * std < x < mean + 3 * std] 33 | 34 | def iterated_execution( 35 | self, 36 | predicted_sql: str, 37 | gold_sql: str, 38 | dsn_or_db_path: str, 39 | num_iterations: int, 40 | ) -> dict: 41 | is_match = self.match_sqls( 42 | predicted_sql=predicted_sql, 43 | gold_sql=gold_sql, 44 | dsn_or_db_path=dsn_or_db_path, 45 | ) 46 | 47 | if is_match["result"] == 1: 48 | diff_list = [ 49 | self.execute_sql(sql=gold_sql, dsn_or_db_path=dsn_or_db_path)[ 50 | "execution_time" 51 | ] 52 | / self.execute_sql(sql=gold_sql, dsn_or_db_path=dsn_or_db_path)[ 53 | "execution_time" 54 | ] 55 | for _ in range(num_iterations) 56 | ] 57 | processed_diff_list = self.clean_abnormal(diff_list) 58 | return { 59 | "result": 
sum(processed_diff_list) / len(processed_diff_list), 60 | "error": None, 61 | } 62 | return {"result": 0, "error": is_match["error"]} 63 | -------------------------------------------------------------------------------- /premsql/executors/from_langchain.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Union 3 | 4 | from langchain_community.utilities.sql_database import SQLDatabase 5 | 6 | from premsql.executors.base import BaseExecutor 7 | from premsql.utils import convert_sqlite_path_to_dsn 8 | 9 | 10 | class ExecutorUsingLangChain(BaseExecutor): 11 | 12 | def execute_sql(self, sql: str, dsn_or_db_path: Union[str, SQLDatabase]) -> dict: 13 | if isinstance(dsn_or_db_path, str): 14 | if dsn_or_db_path.endswith("sqlite"): 15 | dsn_or_db_path = convert_sqlite_path_to_dsn(path=dsn_or_db_path) 16 | db = SQLDatabase.from_uri(dsn_or_db_path) 17 | else: 18 | db = dsn_or_db_path 19 | 20 | start_time = time.time() 21 | response = db.run_no_throw(sql) 22 | end_time = time.time() 23 | 24 | error = response if response.startswith("Error") else None 25 | return { 26 | "result": None if error else response, 27 | "error": error, 28 | "execution_time": end_time - start_time, 29 | } 30 | -------------------------------------------------------------------------------- /premsql/executors/from_sqlite.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import time 3 | 4 | from contextlib import contextmanager 5 | from typing import Any, Dict, Generator 6 | from premsql.executors.base import BaseExecutor 7 | from premsql.logger import setup_console_logger 8 | 9 | 10 | class OptimizedSQLiteExecutor(BaseExecutor): 11 | def __init__(self, timeout: float = 1000.0) -> None: 12 | self.timeout = timeout 13 | self.logger = setup_console_logger(name="[OPTIMIZED-SQLite-EXEC]") 14 | 15 | @contextmanager 16 | def get_connection(self, db_path: str) -> Generator[sqlite3.Connection, None, None]: 17 | if db_path.startswith("sqlite:///"): 18 | db_path = db_path.split("sqlite:///")[1] 19 | 20 | conn = sqlite3.connect(db_path, timeout=self.timeout) 21 | conn.execute("PRAGMA journal_mode = WAL") 22 | conn.execute("PRAGMA synchronous = NORMAL") 23 | conn.execute("PRAGMA cache_size = -64000") # 64MB cache 24 | conn.execute("PRAGMA temp_store = MEMORY") 25 | conn.row_factory = sqlite3.Row 26 | try: 27 | yield conn 28 | finally: 29 | conn.close() 30 | 31 | def execute_sql(self, sql: str, dsn_or_db_path: str) -> Dict[str, Any]: 32 | start_time = time.time() 33 | try: 34 | with self.get_connection(dsn_or_db_path) as conn: 35 | cursor = conn.cursor() 36 | cursor.execute("EXPLAIN QUERY PLAN " + sql) 37 | query_plan = cursor.fetchall() 38 | 39 | if any("SCAN TABLE" in str(row) for row in query_plan): 40 | self.logger.warn("Warning: Full table scan detected. 
Consider adding an index.") 41 | 42 | cursor.execute(sql) 43 | result = [dict(row) for row in cursor.fetchall()] 44 | error = None 45 | except sqlite3.Error as e: 46 | result = None 47 | error = str(e) 48 | finally: 49 | end_time = time.time() 50 | 51 | return { 52 | "result": result, 53 | "error": error, 54 | "execution_time": end_time - start_time, 55 | } 56 | 57 | def match_sqls(self, predicted_sql: str, gold_sql: str, dsn_or_db_path: str) -> Dict[str, Any]: 58 | with self.get_connection(dsn_or_db_path) as conn: 59 | prediction = self.execute_sql(predicted_sql, dsn_or_db_path) 60 | gold = self.execute_sql(gold_sql, dsn_or_db_path) 61 | 62 | if prediction["error"]: 63 | return {"result": 0, "error": prediction["error"]} 64 | 65 | is_match = set(map(tuple, prediction["result"])) == set(map(tuple, gold["result"])) 66 | return { 67 | "result": int(is_match), 68 | "error": None if is_match else "Table mismatch", 69 | } 70 | 71 | def iterated_execution(self, predicted_sql: str, gold_sql: str, dsn_or_db_path: str, num_iterations: int) -> Dict[str, Any]: 72 | is_match = self.match_sqls(predicted_sql, gold_sql, dsn_or_db_path) 73 | 74 | if is_match["result"] == 1: 75 | with self.get_connection(dsn_or_db_path) as conn: 76 | diff_list = [] 77 | for _ in range(num_iterations): 78 | gold_time = self.execute_sql(gold_sql, dsn_or_db_path)["execution_time"] 79 | predicted_time = self.execute_sql(predicted_sql, dsn_or_db_path)["execution_time"] 80 | diff_list.append(predicted_time / gold_time if gold_time > 0 else float('inf')) 81 | 82 | processed_diff_list = self.clean_abnormal(diff_list) 83 | return { 84 | "result": sum(processed_diff_list) / len(processed_diff_list) if processed_diff_list else 0, 85 | "error": None, 86 | } 87 | return {"result": 0, "error": is_match["error"]} 88 | 89 | 90 | 91 | class SQLiteExecutor(BaseExecutor): 92 | def execute_sql(self, sql: str, dsn_or_db_path: str) -> dict: 93 | if dsn_or_db_path.startswith("sqlite:///"): 94 | dsn_or_db_path = dsn_or_db_path.split("sqlite:///")[1] 95 | 96 | conn = sqlite3.connect(dsn_or_db_path) 97 | cursor = conn.cursor() 98 | 99 | start_time = time.time() 100 | try: 101 | cursor.execute(sql) 102 | result = cursor.fetchall() 103 | error = None 104 | except Exception as e: 105 | result = None 106 | error = str(e) 107 | 108 | end_time = time.time() 109 | cursor.close() 110 | conn.close() 111 | 112 | result = { 113 | "result": result, 114 | "error": error, 115 | "execution_time": end_time - start_time, 116 | } 117 | return result -------------------------------------------------------------------------------- /premsql/generators/__init__.py: -------------------------------------------------------------------------------- 1 | from premsql.generators.huggingface import Text2SQLGeneratorHF 2 | from premsql.generators.openai import Text2SQLGeneratorOpenAI 3 | from premsql.generators.premai import Text2SQLGeneratorPremAI 4 | from premsql.generators.mlx import Text2SQLGeneratorMLX 5 | from premsql.generators.ollama_model import Text2SQLGeneratorOllama 6 | 7 | __all__ = ["Text2SQLGeneratorHF", "Text2SQLGeneratorPremAI", "Text2SQLGeneratorOpenAI", "Text2SQLGeneratorMLX", "Text2SQLGeneratorOllama"] 8 | -------------------------------------------------------------------------------- /premsql/generators/base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from abc import ABC, abstractmethod 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | import sqlparse 8 | from 
tqdm.auto import tqdm 9 | from platformdirs import user_cache_dir 10 | 11 | from premsql.evaluator.base import BaseExecutor 12 | from premsql.logger import setup_console_logger 13 | from premsql.prompts import ERROR_HANDLING_PROMPT 14 | 15 | logger = setup_console_logger(name="[GENERATOR]") 16 | 17 | 18 | class Text2SQLGeneratorBase(ABC): 19 | def __init__( 20 | self, experiment_name: str, type: str, experiment_folder: Optional[str] = None 21 | ): 22 | self.experiment_folder = ( 23 | Path(experiment_folder) 24 | if experiment_folder is not None 25 | else Path(user_cache_dir()) / "premsql" / "experiments" 26 | ) 27 | 28 | self.experiment_path = Path(self.experiment_folder) / type / experiment_name 29 | if not self.experiment_path.exists(): 30 | self.experiment_path.mkdir(parents=True, exist_ok=True) 31 | logger.info(f"Created new experiment folder: {self.experiment_path}") 32 | else: 33 | logger.info(f"Experiment folder found in: {self.experiment_path}") 34 | 35 | self.client = self.load_client 36 | self.tokenizer = self.load_tokenizer 37 | 38 | @property 39 | @abstractmethod 40 | def load_client(self): 41 | return NotImplementedError 42 | 43 | @property 44 | @abstractmethod 45 | def load_tokenizer(self): 46 | return NotImplementedError 47 | 48 | @property 49 | @abstractmethod 50 | def model_name_or_path(self): 51 | pass 52 | 53 | @abstractmethod 54 | def generate( 55 | self, 56 | data_blob: dict, 57 | temperature: Optional[float] = 0.0, 58 | max_new_tokens: Optional[int] = 256, 59 | postprocess: Optional[bool] = True, 60 | **kwargs, 61 | ) -> str: 62 | raise NotImplementedError 63 | 64 | def execution_guided_decoding( 65 | self, 66 | data_blob: dict, 67 | executor: BaseExecutor, 68 | temperature: Optional[float] = 0.0, 69 | max_new_tokens: Optional[int] = 256, 70 | max_retries: Optional[int] = 5, 71 | postprocess: Optional[bool] = True, 72 | **kwargs, 73 | ): 74 | error_already_found = False 75 | for _ in range(max_retries): 76 | sql = self.generate( 77 | data_blob=data_blob, 78 | temperature=temperature, 79 | max_new_tokens=max_new_tokens, 80 | postprocess=postprocess, 81 | **kwargs, 82 | ) 83 | error = executor.execute_sql(sql=sql, dsn_or_db_path=data_blob["db_path"])[ 84 | "error" 85 | ] 86 | if not error: 87 | return sql 88 | 89 | if not error_already_found: 90 | prompt = data_blob["prompt"].split("# SQL:")[0].strip() 91 | error_prompt = ERROR_HANDLING_PROMPT.format( 92 | existing_prompt=prompt, sql=sql, error_msg=error 93 | ) 94 | data_blob["prompt"] = error_prompt 95 | error_already_found = True 96 | return sql 97 | 98 | def postprocess(self, output_string: str): 99 | sql_start_keywords = [ 100 | r"\bSELECT\b", 101 | r"\bINSERT\b", 102 | r"\bUPDATE\b", 103 | r"\bDELETE\b", 104 | r"\bWITH\b", 105 | ] 106 | 107 | sql_start_pattern = re.compile("|".join(sql_start_keywords), re.IGNORECASE) 108 | match = sql_start_pattern.search(output_string) 109 | if match: 110 | start_pos = match.start() 111 | sql_statement = output_string[start_pos:] 112 | else: 113 | sql_statement = output_string 114 | 115 | return sqlparse.format(sql_statement.split("# SQL:")[-1].strip()) 116 | 117 | def load_results_from_folder(self): 118 | item_names = [item.name for item in self.experiment_path.iterdir()] 119 | 120 | if self.experiment_path.exists() and "predict.json" in item_names: 121 | return json.load(open(self.experiment_path / "predict.json", "r")) 122 | return None 123 | 124 | def generate_and_save_results( 125 | self, 126 | dataset: list[dict], 127 | temperature: Optional[float] = 0.0, 128 | 
max_new_tokens: Optional[int] = 256, 129 | force: Optional[bool] = False, 130 | postprocess: Optional[bool] = False, 131 | executor: Optional[BaseExecutor] = None, 132 | max_retries: Optional[int] = 5, 133 | **kwargs, 134 | ) -> dict: 135 | 136 | existing_response = self.load_results_from_folder() 137 | if existing_response is not None and force == False: 138 | logger.info("Already results found") 139 | return existing_response 140 | 141 | to_dump = [] 142 | for content in tqdm(dataset, total=len(dataset), desc="Generating result ..."): 143 | sql = ( 144 | self.execution_guided_decoding( 145 | data_blob=content, 146 | executor=executor, 147 | temperature=temperature, 148 | postprocess=postprocess, 149 | max_new_tokens=max_new_tokens, 150 | max_retries=max_retries, 151 | **kwargs, 152 | ) 153 | if executor is not None 154 | else self.generate( 155 | data_blob=content, 156 | temperature=temperature, 157 | max_new_tokens=max_new_tokens, 158 | postprocess=postprocess, 159 | **kwargs, 160 | ) 161 | ) 162 | 163 | to_dump.append({**content, "generated": sql}) 164 | 165 | json.dump(to_dump, open(self.experiment_path / "predict.json", "w"), indent=4) 166 | logger.info(f"All responses are written to: {self.experiment_path}") 167 | return to_dump 168 | -------------------------------------------------------------------------------- /premsql/generators/huggingface.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Union 3 | 4 | from premsql.generators.base import Text2SQLGeneratorBase 5 | from premsql.logger import setup_console_logger 6 | 7 | logger = setup_console_logger(name="[HF-GENERATOR]") 8 | 9 | try: 10 | import torch 11 | import transformers 12 | except ImportError: 13 | logger.warn("Ensure torch and transformers are installed.") 14 | logger.warn("Install them by: pip install torch transformers") 15 | 16 | class Text2SQLGeneratorHF(Text2SQLGeneratorBase): 17 | def __init__( 18 | self, 19 | model_or_name_or_path: Union[str, "transformers.PreTrainedModel"], 20 | experiment_name: str, 21 | type: str, 22 | experiment_folder: Optional[str] = None, 23 | hf_token: Optional[str] = None, 24 | device: Optional[str] = None, 25 | **kwargs 26 | ): 27 | self.hf_api_key = os.environ.get("HF_TOKEN") or hf_token 28 | self._kwargs = kwargs 29 | self.device = ( 30 | device 31 | if device is not None 32 | else ("cuda:0" if torch.cuda.is_available() else "cpu") 33 | ) 34 | self.model_or_name_or_path = model_or_name_or_path 35 | super().__init__( 36 | experiment_name=experiment_name, 37 | experiment_folder=experiment_folder, 38 | type=type, 39 | ) 40 | 41 | @property 42 | def load_client(self) -> "transformers.PreTrainedModel": 43 | if isinstance(self.model_or_name_or_path, str): 44 | return transformers.AutoModelForCausalLM.from_pretrained( 45 | pretrained_model_name_or_path=self.model_or_name_or_path, 46 | token=self.hf_api_key, 47 | **{ 48 | "device_map": self.device, 49 | "torch_dtype": torch.float16, 50 | **self._kwargs, 51 | } 52 | ) 53 | return self.model_or_name_or_path 54 | 55 | @property 56 | def load_tokenizer(self) -> "transformers.PreTrainedTokenizer": 57 | tokenizer = transformers.AutoTokenizer.from_pretrained( 58 | pretrained_model_name_or_path=self.client.config.name_or_path, 59 | token=self.hf_api_key, 60 | padding_side="right", 61 | ) 62 | tokenizer.pad_token = tokenizer.eos_token 63 | return tokenizer 64 | 65 | @property 66 | def model_name_or_path(self): 67 | return self.model_or_name_or_path 68 | 69 | def 
generate( 70 | self, 71 | data_blob: dict, 72 | temperature: Optional[float] = 0.0, 73 | max_new_tokens: Optional[int] = 256, 74 | postprocess: Optional[bool] = True, 75 | **kwargs 76 | ) -> str: 77 | 78 | prompt = data_blob["prompt"] 79 | input_ids = self.tokenizer.encode( 80 | text=prompt, 81 | return_tensors="pt", 82 | padding="longest", 83 | max_length=self.tokenizer.model_max_length, 84 | truncation=False, 85 | ).to(self.device) 86 | 87 | do_sample = False if temperature == 0.0 else True 88 | generation_config = transformers.GenerationConfig( 89 | **{**kwargs, "temperature": temperature, "max_new_tokens": max_new_tokens} 90 | ) 91 | output_tokens = ( 92 | self.client.generate( 93 | input_ids=input_ids, 94 | do_sample=do_sample, 95 | generation_config=generation_config, 96 | pad_token_id=self.tokenizer.eos_token_id, 97 | ) 98 | .detach() 99 | .tolist()[0] 100 | ) 101 | output_tokens = ( 102 | output_tokens[len(input_ids[0]) :] 103 | if len(output_tokens) > len(input_ids[0]) 104 | else output_tokens 105 | ) 106 | generated = self.tokenizer.decode(output_tokens, skip_special_tokens=True) 107 | return self.postprocess(output_string=generated) if postprocess else generated 108 | -------------------------------------------------------------------------------- /premsql/generators/mlx.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from premsql.generators.base import Text2SQLGeneratorBase 5 | from premsql.logger import setup_console_logger 6 | 7 | logger = setup_console_logger(name="[MLX-GENERATOR]") 8 | 9 | try: 10 | from mlx_lm import generate 11 | from mlx_lm.tokenizer_utils import load_tokenizer 12 | from mlx_lm.utils import get_model_path, load_model 13 | except ImportError as e: 14 | logger.error("Install mlx using: pip install mlx mlx-lm") 15 | 16 | 17 | 18 | class Text2SQLGeneratorMLX(Text2SQLGeneratorBase): 19 | def __init__( 20 | self, 21 | model_name_or_path: str, 22 | experiment_name: str, 23 | type: str, 24 | experiment_folder: Optional[str] = None, 25 | hf_token: Optional[str] = None, 26 | **kwargs 27 | ): 28 | self.hf_api_key = os.environ.get("HF_TOKEN") or hf_token 29 | self._kwargs = kwargs 30 | self.mlx_model_name_or_path = model_name_or_path 31 | super().__init__( 32 | experiment_name=experiment_name, 33 | experiment_folder=experiment_folder, 34 | type=type, 35 | ) 36 | 37 | @property 38 | def load_client(self): 39 | model_path = get_model_path(self.model_name_or_path) 40 | model = load_model(model_path, **self._kwargs) 41 | return model 42 | 43 | @property 44 | def load_tokenizer(self): 45 | model_path = get_model_path(self.model_name_or_path) 46 | return load_tokenizer(model_path, **self._kwargs) 47 | 48 | @property 49 | def model_name_or_path(self): 50 | return self.mlx_model_name_or_path 51 | 52 | def generate( 53 | self, 54 | data_blob: dict, 55 | temperature: Optional[float] = 0.0, 56 | max_new_tokens: Optional[int] = 256, 57 | postprocess: Optional[bool] = True, 58 | **kwargs 59 | ) -> str: 60 | prompt = data_blob["prompt"] 61 | temp = temperature 62 | generation_args = {"temp": temp, **kwargs} 63 | output = generate( 64 | model=self.client, 65 | tokenizer=self.tokenizer, 66 | prompt=prompt, 67 | max_tokens=max_new_tokens, 68 | **generation_args 69 | ) 70 | return self.postprocess(output) if postprocess else output 71 | -------------------------------------------------------------------------------- /premsql/generators/ollama_model.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from premsql.generators.base import Text2SQLGeneratorBase 4 | from premsql.logger import setup_console_logger 5 | 6 | logger = setup_console_logger(name="[OLLAMA-GENERATOR]") 7 | 8 | try: 9 | from ollama import Client 10 | except ImportError: 11 | logger.warn("Ensure ollama is installed") 12 | logger.warn("Install Ollama: curl -fsSL https://ollama.com/install.sh | sh") 13 | logger.warn("Install Ollama python: pip install ollama") 14 | 15 | 16 | class Text2SQLGeneratorOllama(Text2SQLGeneratorBase): 17 | def __init__( 18 | self, 19 | model_name: str, 20 | experiment_name: str, 21 | type: str, 22 | experiment_folder: Optional[str]=None, 23 | **kwargs 24 | ): 25 | self._kwargs = kwargs 26 | self.model_name = model_name 27 | super().__init__( 28 | experiment_name=experiment_name, 29 | experiment_folder=experiment_folder, 30 | type=type 31 | ) 32 | 33 | @property 34 | def load_client(self): 35 | return Client(host='http://localhost:11434') 36 | 37 | @property 38 | def load_tokenizer(self): 39 | pass 40 | 41 | @property 42 | def model_name_or_path(self): 43 | return self.model_name 44 | 45 | def generate( 46 | self, 47 | data_blob: dict, 48 | temperature: Optional[float] = 0.0, 49 | max_new_tokens: Optional[int] = 256, 50 | postprocess: Optional[bool] = True, 51 | **kwargs 52 | ) -> str: 53 | prompt = data_blob["prompt"] 54 | response = self.load_client.chat( 55 | model=self.model_name_or_path, 56 | messages=[{"role":"user", "content":prompt}], 57 | options=dict( 58 | temperature=temperature, 59 | num_ctx=2048 + max_new_tokens 60 | ) 61 | )["message"]["content"] 62 | return self.postprocess(output_string=response) if postprocess else response 63 | -------------------------------------------------------------------------------- /premsql/generators/openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from premsql.generators.base import Text2SQLGeneratorBase 5 | 6 | try: 7 | from openai import OpenAI 8 | except ImportError: 9 | raise ImportError("Module openai is not installed") 10 | 11 | 12 | class Text2SQLGeneratorOpenAI(Text2SQLGeneratorBase): 13 | def __init__( 14 | self, 15 | model_name: str, 16 | experiment_name: str, 17 | type: str, 18 | experiment_folder: Optional[str] = None, 19 | openai_api_key: Optional[str] = None, 20 | ): 21 | self._api_key = openai_api_key or os.environ.get("OPENAI_API_KEY") 22 | self.model_name = model_name 23 | super().__init__( 24 | experiment_folder=experiment_folder, 25 | experiment_name=experiment_name, 26 | type=type, 27 | ) 28 | 29 | @property 30 | def load_client(self): 31 | client = OpenAI(api_key=self._api_key) 32 | return client 33 | 34 | @property 35 | def load_tokenizer(self): 36 | pass 37 | 38 | @property 39 | def model_name_or_path(self): 40 | return self.model_name 41 | 42 | def generate( 43 | self, 44 | data_blob: dict, 45 | temperature: Optional[float] = 0.0, 46 | max_new_tokens: Optional[int] = 256, 47 | postprocess: Optional[bool] = True, 48 | **kwargs 49 | ) -> str: 50 | prompt = data_blob["prompt"] 51 | max_tokens = max_new_tokens 52 | generation_config = { 53 | **kwargs, 54 | **{"temperature": temperature, "max_tokens": max_tokens}, 55 | } 56 | completion = ( 57 | self.client.chat.completions.create( 58 | model=self.model_name, 59 | messages=[{"role": "user", "content": prompt}], 60 | **generation_config 61 | ) 62 | .choices[0] 63 | 
.message.content 64 | ) 65 | return self.postprocess(output_string=completion) if postprocess else completion 66 | -------------------------------------------------------------------------------- /premsql/generators/premai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from premai import Prem 5 | 6 | from premsql.generators.base import Text2SQLGeneratorBase 7 | from premsql.logger import setup_console_logger 8 | 9 | logger = setup_console_logger(name="[PREMAI-GENERATOR]") 10 | 11 | 12 | class Text2SQLGeneratorPremAI(Text2SQLGeneratorBase): 13 | def __init__( 14 | self, 15 | model_name: str, 16 | project_id: str, 17 | experiment_name: str, 18 | type: str, 19 | experiment_folder: Optional[str] = None, 20 | premai_api_key: Optional[str] = None, 21 | **kwargs 22 | ): 23 | self.project_id = project_id 24 | self.premai_api_key = premai_api_key or os.environ.get("PREMAI_API_KEY") 25 | self._kwargs = kwargs 26 | self.model_name = model_name 27 | 28 | super().__init__( 29 | experiment_name=experiment_name, 30 | experiment_folder=experiment_folder, 31 | type=type, 32 | ) 33 | 34 | @property 35 | def load_client(self) -> Prem: 36 | return Prem(api_key=self.premai_api_key) 37 | 38 | @property 39 | def load_tokenizer(self) -> None: 40 | pass 41 | 42 | @property 43 | def model_name_or_path(self) -> str: 44 | return self.model_name 45 | 46 | def generate( 47 | self, 48 | data_blob: dict, 49 | temperature: Optional[float] = 0.0, 50 | max_new_tokens: Optional[int] = 256, 51 | postprocess: Optional[bool] = True, 52 | **kwargs 53 | ) -> str: 54 | prompt = data_blob["prompt"] 55 | max_tokens = max_new_tokens 56 | generation_config = { 57 | **kwargs, 58 | **{"temperature": temperature, "max_tokens": max_tokens}, 59 | } 60 | generated = ( 61 | self.client.chat.completions.create( 62 | project_id=self.project_id, 63 | messages=[{"role": "user", "content": prompt}], 64 | **generation_config 65 | ) 66 | .choices[0] 67 | .message.content 68 | ) 69 | return self.postprocess(output_string=generated) if postprocess else generated 70 | -------------------------------------------------------------------------------- /premsql/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_console_logger(name, level=logging.INFO): 5 | """Function to setup a console logger.""" 6 | formatter = logging.Formatter( 7 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 8 | ) 9 | 10 | console_handler = logging.StreamHandler() 11 | console_handler.setFormatter(formatter) 12 | 13 | logger = logging.getLogger(name) 14 | logger.setLevel(level) 15 | logger.addHandler(console_handler) 16 | 17 | return logger 18 | -------------------------------------------------------------------------------- /premsql/playground/__init__.py: -------------------------------------------------------------------------------- 1 | from premsql.playground.backend.backend_client import BackendAPIClient 2 | from premsql.playground.inference_server.api_client import InferenceServerAPIClient 3 | from premsql.playground.inference_server.service import AgentServer 4 | 5 | __all__ = ["AgentServer", "InferenceServerAPIClient", "BackendAPIClient"] -------------------------------------------------------------------------------- /premsql/playground/backend/api/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/premsql/playground/backend/api/__init__.py -------------------------------------------------------------------------------- /premsql/playground/backend/api/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | from .models import Completions, Session 4 | 5 | admin.site.register(Session) 6 | admin.site.register(Completions) 7 | -------------------------------------------------------------------------------- /premsql/playground/backend/api/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class ApiConfig(AppConfig): 5 | default_auto_field = "django.db.models.BigAutoField" 6 | name = "api" 7 | -------------------------------------------------------------------------------- /premsql/playground/backend/api/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 5.1.2 on 2024-10-26 09:06 2 | 3 | import django.db.models.deletion 4 | from django.db import migrations, models 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | initial = True 10 | 11 | dependencies = [] 12 | 13 | operations = [ 14 | migrations.CreateModel( 15 | name="Session", 16 | fields=[ 17 | ("session_id", models.AutoField(primary_key=True, serialize=False)), 18 | ("db_connection_uri", models.URLField()), 19 | ("session_name", models.CharField(max_length=255, unique=True)), 20 | ("created_at", models.DateTimeField(auto_now_add=True)), 21 | ("base_url", models.URLField()), 22 | ("session_db_path", models.CharField(max_length=255)), 23 | ], 24 | options={ 25 | "ordering": ["created_at"], 26 | }, 27 | ), 28 | migrations.CreateModel( 29 | name="Completions", 30 | fields=[ 31 | ("chat_id", models.AutoField(primary_key=True, serialize=False)), 32 | ("message_id", models.IntegerField(blank=True, null=True)), 33 | ("session_name", models.CharField(max_length=255)), 34 | ("created_at", models.DateTimeField()), 35 | ("question", models.TextField(blank=True, null=True)), 36 | ( 37 | "session", 38 | models.ForeignKey( 39 | on_delete=django.db.models.deletion.CASCADE, 40 | related_name="messages", 41 | to="api.session", 42 | ), 43 | ), 44 | ], 45 | options={ 46 | "verbose_name_plural": "Completions", 47 | "ordering": ["-created_at"], 48 | }, 49 | ), 50 | ] 51 | -------------------------------------------------------------------------------- /premsql/playground/backend/api/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/premsql/playground/backend/api/migrations/__init__.py -------------------------------------------------------------------------------- /premsql/playground/backend/api/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class Session(models.Model): 5 | session_id = models.AutoField(primary_key=True) 6 | db_connection_uri = models.URLField() 7 | session_name = models.CharField(max_length=255, unique=True) 8 | created_at = models.DateTimeField(auto_now_add=True) 9 | base_url = models.URLField() 10 | session_db_path = models.CharField(max_length=255) 11 | 12 | class Meta: 13 | ordering = ["created_at"] 14 | 15 | 16 | class Completions(models.Model): 17 | 
chat_id = models.AutoField(primary_key=True) 18 | message_id = models.IntegerField(blank=True, null=True) 19 | session = models.ForeignKey( 20 | Session, on_delete=models.CASCADE, related_name="messages" 21 | ) 22 | session_name = models.CharField(max_length=255) 23 | created_at = models.DateTimeField() 24 | question = models.TextField(blank=True, null=True) 25 | 26 | class Meta: 27 | ordering = ["-created_at"] 28 | verbose_name_plural = "Completions" 29 | -------------------------------------------------------------------------------- /premsql/playground/backend/api/pydantic_models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Literal, Optional 3 | 4 | from pydantic import BaseModel, ConfigDict, Field 5 | 6 | from premsql.agents.models import AgentOutput 7 | 8 | # All the Session Models 9 | 10 | 11 | class SessionCreationRequest(BaseModel): 12 | base_url: str = Field(...) 13 | model_config = ConfigDict(extra="forbid") 14 | 15 | 16 | class SessionCreationResponse(BaseModel): 17 | status_code: Literal[200, 500] = Field(...) 18 | status: Literal["success", "error"] = Field(...) 19 | 20 | session_id: Optional[int] = None 21 | session_name: Optional[str] = None 22 | db_connection_uri: Optional[str] = None 23 | session_db_path: Optional[str] = None 24 | created_at: Optional[datetime] = None 25 | error_message: Optional[str] = None 26 | 27 | 28 | class SessionSummary(BaseModel): 29 | session_id: int 30 | session_name: str 31 | created_at: datetime 32 | base_url: str 33 | db_connection_uri: str 34 | session_db_path: str 35 | 36 | model_config = ConfigDict(from_attributes=True) 37 | 38 | 39 | class SessionListResponse(BaseModel): 40 | status_code: Literal[200, 500] 41 | status: Literal["success", "error"] 42 | sessions: Optional[List[SessionSummary]] = None 43 | total_count: Optional[int] = None 44 | page: Optional[int] = None 45 | page_size: Optional[int] = None 46 | error_message: Optional[str] = None 47 | 48 | 49 | class SessionDeleteResponse(BaseModel): 50 | session_name: str 51 | status_code: Literal[200, 404, 500] 52 | status: Literal["success", "error"] 53 | error_message: Optional[str] = None 54 | 55 | 56 | # All the chat message models 57 | 58 | 59 | class CompletionCreationRequest(BaseModel): 60 | session_name: str 61 | question: str 62 | 63 | 64 | class CompletionCreationResponse(BaseModel): 65 | status_code: Literal[200, 500] 66 | status: Literal["success", "error"] 67 | message_id: Optional[int] = None 68 | session_name: Optional[str] = None 69 | created_at: Optional[datetime] = None 70 | message: Optional[AgentOutput] = None 71 | question: Optional[str] = None 72 | error_message: Optional[str] = None 73 | 74 | 75 | class CompletionSummary(BaseModel): 76 | message_id: int 77 | session_name: str 78 | base_url: str 79 | created_at: datetime 80 | question: Optional[str] = None 81 | 82 | model_config = ConfigDict(from_attributes=True) 83 | 84 | 85 | class CompletionListResponse(BaseModel): 86 | status_code: Literal[200, 500] 87 | status: Literal["success", "error"] 88 | completions: Optional[List[CompletionSummary]] = None 89 | total_count: Optional[int] = None 90 | error_message: Optional[str] = None 91 | -------------------------------------------------------------------------------- /premsql/playground/backend/api/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | 4 | class
AgentOutputSerializer(serializers.Serializer): 5 | session_name = serializers.CharField() 6 | question = serializers.CharField() 7 | db_connection_uri = serializers.CharField() 8 | route_taken = serializers.ChoiceField( 9 | choices=["plot", "analyse", "query", "followup"] 10 | ) 11 | input_dataframe = serializers.DictField(allow_null=True) 12 | output_dataframe = serializers.DictField(allow_null=True) 13 | sql_string = serializers.CharField(allow_null=True) 14 | analysis = serializers.CharField(allow_null=True) 15 | reasoning = serializers.CharField(allow_null=True) 16 | plot_config = serializers.DictField(allow_null=True) 17 | image_to_plot = serializers.CharField(allow_null=True) 18 | followup_route = serializers.ChoiceField( 19 | choices=["plot", "analyse", "query", "followup"], allow_null=True 20 | ) 21 | followup_suggestion = serializers.CharField(allow_null=True) 22 | error_from_pipeline = serializers.CharField(allow_null=True) 23 | 24 | 25 | # Sessions 26 | class SessionCreationRequestSerializer(serializers.Serializer): 27 | base_url = serializers.CharField() 28 | 29 | 30 | class SessionCreationResponseSerializer(serializers.Serializer): 31 | status_code = serializers.ChoiceField(choices=[200, 500]) 32 | status = serializers.ChoiceField(choices=["success", "error"]) 33 | 34 | session_id = serializers.IntegerField(allow_null=True) 35 | session_name = serializers.CharField(allow_null=True) 36 | db_connection_uri = serializers.CharField(allow_null=True) 37 | session_db_path = serializers.CharField(allow_null=True) 38 | created_at = serializers.DateTimeField(allow_null=True) 39 | error_message = serializers.CharField(allow_null=True) 40 | 41 | 42 | class SessionSummarySerializer(serializers.Serializer): 43 | session_id = serializers.IntegerField() 44 | session_name = serializers.CharField(max_length=255) 45 | created_at = serializers.DateTimeField() 46 | base_url = serializers.CharField() 47 | db_connection_uri = serializers.CharField() 48 | session_db_path = serializers.CharField() 49 | 50 | 51 | class SessionListResponseSerializer(serializers.Serializer): 52 | status_code = serializers.ChoiceField(choices=[200, 500]) 53 | status = serializers.ChoiceField(choices=["success", "error"]) 54 | sessions = SessionSummarySerializer(many=True, allow_null=True) 55 | total_count = serializers.IntegerField(allow_null=True) 56 | page = serializers.IntegerField(allow_null=True) 57 | page_size = serializers.IntegerField(allow_null=True) 58 | error_message = serializers.CharField(allow_null=True) 59 | 60 | 61 | class SessionDeletionResponse(serializers.Serializer): 62 | session_name = serializers.CharField(max_length=255) 63 | status_code = serializers.ChoiceField(choices=[200, 404, 500]) 64 | status = serializers.ChoiceField(choices=["success", "error"]) 65 | error_message = serializers.CharField(allow_null=True) 66 | 67 | 68 | # Chats (Completions) 69 | class CompletionCreationRequestSerializer(serializers.Serializer): 70 | session_name = serializers.CharField() 71 | question = serializers.CharField() 72 | 73 | 74 | class CompletionCreationResponseSerializer(serializers.Serializer): 75 | status_code = serializers.ChoiceField(choices=[200, 500]) 76 | status = serializers.ChoiceField(choices=["success", "error"]) 77 | message_id = serializers.IntegerField(allow_null=True) 78 | session_name = serializers.CharField(allow_null=True) 79 | message = AgentOutputSerializer(allow_null=True) 80 | created_at = serializers.DateTimeField(allow_null=True) 81 | question =
serializers.CharField(allow_null=True) 82 | error_message = serializers.CharField(allow_null=True) 83 | 84 | 85 | class CompletionSummarySerializer(serializers.Serializer): 86 | message_id = serializers.IntegerField() 87 | session_name = serializers.CharField() 88 | base_url = serializers.CharField() 89 | created_at = serializers.DateTimeField() 90 | question = serializers.CharField(allow_null=True) 91 | 92 | 93 | class CompletionListResponseSerializer(serializers.Serializer): 94 | status_code = serializers.ChoiceField(choices=[200, 500]) 95 | status = serializers.ChoiceField(choices=["success", "error"]) 96 | completions = CompletionSummarySerializer(many=True, allow_null=True) 97 | total_count = serializers.IntegerField(allow_null=True) 98 | error_message = serializers.CharField(allow_null=True) 99 | 100 | 101 | # Utility function for creating model serializers 102 | def create_model_serializer(model_class): 103 | class ModelSerializer(serializers.ModelSerializer): 104 | class Meta: 105 | model = model_class 106 | fields = "__all__" 107 | 108 | return ModelSerializer 109 | -------------------------------------------------------------------------------- /premsql/playground/backend/api/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 4 | -------------------------------------------------------------------------------- /premsql/playground/backend/api/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path 2 | 3 | from . import views 4 | 5 | urlpatterns = [ 6 | path("session/list/", views.list_sessions, name="list_sessions"), 7 | path("session/create", views.create_session, name="create_session"), 8 | path("session/<str:session_name>/", views.get_session, name="get_session"), 9 | path("session/<str:session_name>", views.delete_session, name="delete_session"), 10 | # Chat urls 11 | path("chat/completion", views.create_completion, name="completion"), 12 | path( 13 | "chat/history/<str:session_name>/", views.get_chat_history, name="chat_history" 14 | ), 15 | ] 16 | -------------------------------------------------------------------------------- /premsql/playground/backend/api/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import signal 4 | import subprocess 5 | 6 | from premsql.logger import setup_console_logger 7 | 8 | logger = setup_console_logger("[BACKEND-UTILS]") 9 | 10 | 11 | def stop_server_on_port(port: int): 12 | try: 13 | result = subprocess.run( 14 | ["lsof", "-ti", f":{port}"], capture_output=True, text=True 15 | ) 16 | if result.returncode == 0: 17 | pid = int(result.stdout.strip()) 18 | os.kill(pid, signal.SIGTERM) 19 | logger.info(f"Server running on port {port} (PID {pid}) stopped.") 20 | else: 21 | logger.info(f"No server found running on port {port}") 22 | except subprocess.CalledProcessError: 23 | logger.info(f"No server found running on port {port}") 24 | except ProcessLookupError: 25 | logger.info(f"Process on port {port} no longer exists") 26 | -------------------------------------------------------------------------------- /premsql/playground/backend/api/views.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from drf_yasg import openapi 4 | from drf_yasg.utils import swagger_auto_schema 5 | from rest_framework import status 6 | from rest_framework.decorators import api_view 7 | from rest_framework.exceptions
import ValidationError 8 | from rest_framework.response import Response 9 | 10 | from premsql.logger import setup_console_logger 11 | from premsql.playground.backend.api.pydantic_models import ( 12 | CompletionCreationRequest, 13 | SessionCreationRequest, 14 | SessionListResponse, 15 | SessionSummary, 16 | ) 17 | from premsql.playground.backend.api.serializers import ( 18 | CompletionCreationRequestSerializer, 19 | CompletionCreationResponseSerializer, 20 | CompletionListResponseSerializer, 21 | SessionCreationRequestSerializer, 22 | SessionCreationResponseSerializer, 23 | SessionListResponseSerializer, 24 | SessionSummarySerializer, 25 | ) 26 | 27 | from .services import CompletionService, SessionManageService 28 | 29 | logger = setup_console_logger("[VIEWS]") 30 | 31 | 32 | @swagger_auto_schema( 33 | method="post", 34 | request_body=SessionCreationRequestSerializer, 35 | responses={ 36 | 200: SessionCreationResponseSerializer, 37 | 400: "Bad Request", 38 | 500: SessionCreationResponseSerializer, 39 | }, 40 | ) 41 | @api_view(["POST"]) 42 | def create_session(request): 43 | try: 44 | session_request = SessionCreationRequest(**request.data) 45 | response = SessionManageService().create_session(request=session_request) 46 | return Response(response.model_dump()) 47 | except json.JSONDecodeError: 48 | return Response( 49 | {"status": "error", "error_message": "Invalid JSON"}, 50 | status=status.HTTP_400_BAD_REQUEST, 51 | ) 52 | except Exception as e: 53 | return Response( 54 | {"status": "error", "error_message": str(e)}, 55 | status=status.HTTP_500_INTERNAL_SERVER_ERROR, 56 | ) 57 | 58 | 59 | @swagger_auto_schema( 60 | method="get", 61 | manual_parameters=[ 62 | openapi.Parameter( 63 | "session_name", 64 | openapi.IN_PATH, 65 | description="Name of the session", 66 | type=openapi.TYPE_STRING, 67 | ), 68 | ], 69 | responses={ 70 | 200: SessionSummarySerializer, 71 | 400: "Bad Request", 72 | 500: SessionSummarySerializer, 73 | }, 74 | ) 75 | @api_view(["GET"]) 76 | def get_session(request, session_name): 77 | session = SessionManageService().get_session(session_name=session_name) 78 | if session: 79 | session_summary = SessionSummary.model_validate(session) 80 | response = SessionListResponse( 81 | status="success", 82 | status_code=200, 83 | sessions=[session_summary.model_dump()], 84 | total_count=1, 85 | page=1, 86 | page_size=1, 87 | ) 88 | else: 89 | response = SessionListResponse( 90 | status="error", 91 | status_code=500, 92 | error_message="The requested session does not exist.", 93 | ) 94 | return Response( 95 | response.model_dump(), 96 | status=status.HTTP_200_OK if session else status.HTTP_404_NOT_FOUND, 97 | ) 98 | 99 | 100 | @swagger_auto_schema( 101 | method="get", 102 | manual_parameters=[ 103 | openapi.Parameter( 104 | "page", 105 | openapi.IN_QUERY, 106 | description="Page number", 107 | type=openapi.TYPE_INTEGER, 108 | default=1, 109 | ), 110 | openapi.Parameter( 111 | "page_size", 112 | openapi.IN_QUERY, 113 | description="Number of items per page", 114 | type=openapi.TYPE_INTEGER, 115 | default=20, 116 | ), 117 | ], 118 | responses={ 119 | 200: SessionListResponseSerializer, 120 | 400: "Bad Request", 121 | 500: SessionListResponseSerializer, 122 | }, 123 | ) 124 | @api_view(["GET"]) 125 | def list_sessions(request): 126 | page = int(request.query_params.get("page", 1)) 127 | page_size = int(request.query_params.get("page_size", 20)) 128 | response = SessionManageService().list_session(page=page, page_size=page_size) 129 | return Response(response.model_dump()) 130 | 131 | 
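# Usage sketch for the session endpoints above. They are normally exercised through the
# BackendAPIClient defined in premsql/playground/backend/backend_client.py rather than with
# raw HTTP calls; a minimal, illustrative example (assuming the Django backend is running on
# http://127.0.0.1:8000 and an AgentServer is reachable at http://localhost:8100) might be:
#
#     from premsql.playground import BackendAPIClient
#     from premsql.playground.backend.api.pydantic_models import SessionCreationRequest
#
#     client = BackendAPIClient()
#     # Register a session that points at a running AgentServer instance.
#     created = client.create_session(SessionCreationRequest(base_url="http://localhost:8100"))
#     # Page through registered sessions, or fetch the one just created.
#     sessions = client.list_sessions(page=1, page_size=20)
#     single = client.get_session(session_name=created.session_name)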
132 | @swagger_auto_schema( 133 | method="delete", 134 | manual_parameters=[ 135 | openapi.Parameter( 136 | "session_name", 137 | openapi.IN_PATH, 138 | description="Name of the session to delete", 139 | type=openapi.TYPE_STRING, 140 | required=True, 141 | ), 142 | ], 143 | responses={ 144 | 200: openapi.Response( 145 | "Session deleted successfully", 146 | schema=openapi.Schema( 147 | type=openapi.TYPE_OBJECT, 148 | properties={ 149 | "status": openapi.Schema( 150 | type=openapi.TYPE_STRING, example="success" 151 | ), 152 | "message": openapi.Schema(type=openapi.TYPE_STRING), 153 | }, 154 | ), 155 | ), 156 | 404: "Not Found", 157 | 500: "Internal Server Error", 158 | }, 159 | ) 160 | @api_view(["DELETE"]) 161 | def delete_session(request, session_name): 162 | try: 163 | result = SessionManageService().delete_session(session_name=session_name) 164 | return Response(result.model_dump(), status=result.status_code) 165 | except Exception as e: 166 | return Response( 167 | {"status": "error", "error_message": str(e)}, 168 | status=status.HTTP_500_INTERNAL_SERVER_ERROR, 169 | ) 170 | 171 | 172 | # Completion Views 173 | 174 | 175 | @swagger_auto_schema( 176 | method="post", 177 | request_body=CompletionCreationRequestSerializer, 178 | responses={ 179 | 200: CompletionCreationResponseSerializer, 180 | 400: "Bad Request", 181 | 404: "Not Found", 182 | 500: "Internal Server Error", 183 | }, 184 | ) 185 | @api_view(["POST"]) 186 | def create_completion(request): 187 | try: 188 | completion_request = CompletionCreationRequest(**request.data) 189 | response = CompletionService().completion(request=completion_request) 190 | return Response( 191 | response.model_dump(), 192 | status=( 193 | status.HTTP_200_OK 194 | if response.status == "success" 195 | else status.HTTP_500_INTERNAL_SERVER_ERROR 196 | ), 197 | ) 198 | except ValidationError as e: 199 | return Response( 200 | {"status": "error", "error_message": str(e)}, 201 | status=status.HTTP_400_BAD_REQUEST, 202 | ) 203 | except Exception as e: 204 | return Response( 205 | {"status": "error", "error_message": str(e)}, 206 | status=status.HTTP_500_INTERNAL_SERVER_ERROR, 207 | ) 208 | 209 | 210 | @swagger_auto_schema( 211 | method="get", 212 | manual_parameters=[ 213 | openapi.Parameter( 214 | "session_name", 215 | openapi.IN_PATH, 216 | description="Name of the session", 217 | type=openapi.TYPE_STRING, 218 | required=True, 219 | ), 220 | openapi.Parameter( 221 | "page", 222 | openapi.IN_QUERY, 223 | description="Page number", 224 | type=openapi.TYPE_INTEGER, 225 | default=1, 226 | ), 227 | openapi.Parameter( 228 | "page_size", 229 | openapi.IN_QUERY, 230 | description="Number of items per page", 231 | type=openapi.TYPE_INTEGER, 232 | default=20, 233 | ), 234 | ], 235 | responses={ 236 | 200: CompletionListResponseSerializer, 237 | 400: "Bad Request", 238 | 404: "Not Found", 239 | 500: "Internal Server Error", 240 | }, 241 | ) 242 | @api_view(["GET"]) 243 | def get_chat_history(request, session_name): 244 | try: 245 | page = int(request.query_params.get("page", 1)) 246 | page_size = int(request.query_params.get("page_size", 20)) 247 | 248 | response = CompletionService().chat_history( 249 | session_name=session_name, page=page, page_size=page_size 250 | ) 251 | 252 | return Response( 253 | response.model_dump(), 254 | status=( 255 | status.HTTP_200_OK 256 | if response.status == "success" 257 | else status.HTTP_404_NOT_FOUND 258 | ), 259 | ) 260 | except ValueError: 261 | return Response( 262 | {"status": "error", "error_message": "Invalid page or 
page_size parameter"}, 263 | status=status.HTTP_400_BAD_REQUEST, 264 | ) 265 | except Exception as e: 266 | return Response( 267 | {"status": "error", "error_message": str(e)}, 268 | status=status.HTTP_500_INTERNAL_SERVER_ERROR, 269 | ) 270 | -------------------------------------------------------------------------------- /premsql/playground/backend/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/premAI-io/premsql/7041239e5ce10f87e99e8df1bc669a2b778f9cb4/premsql/playground/backend/backend/__init__.py -------------------------------------------------------------------------------- /premsql/playground/backend/backend/asgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ASGI config for backend project. 3 | 4 | It exposes the ASGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/5.1/howto/deployment/asgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.asgi import get_asgi_application 13 | 14 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings") 15 | 16 | application = get_asgi_application() 17 | -------------------------------------------------------------------------------- /premsql/playground/backend/backend/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for backend project. 3 | 4 | Generated by 'django-admin startproject' using Django 5.1.2. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/5.1/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/5.1/ref/settings/ 11 | """ 12 | 13 | from pathlib import Path 14 | 15 | # Build paths inside the project like this: BASE_DIR / 'subdir'. 16 | BASE_DIR = Path(__file__).resolve().parent.parent 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/5.1/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = "django-insecure-v3#txach78pic91j!s=ia3w+h@58niky5ozim)j0+6r56m$pmj" 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 
26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = [] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | "django.contrib.admin", 35 | "django.contrib.auth", 36 | "django.contrib.contenttypes", 37 | "django.contrib.sessions", 38 | "django.contrib.messages", 39 | "django.contrib.staticfiles", 40 | "rest_framework", 41 | "drf_yasg", 42 | "api", 43 | ] 44 | 45 | MIDDLEWARE = [ 46 | "django.middleware.security.SecurityMiddleware", 47 | "django.contrib.sessions.middleware.SessionMiddleware", 48 | "django.middleware.common.CommonMiddleware", 49 | "django.middleware.csrf.CsrfViewMiddleware", 50 | "django.contrib.auth.middleware.AuthenticationMiddleware", 51 | "django.contrib.messages.middleware.MessageMiddleware", 52 | "django.middleware.clickjacking.XFrameOptionsMiddleware", 53 | ] 54 | 55 | ROOT_URLCONF = "backend.urls" 56 | 57 | TEMPLATES = [ 58 | { 59 | "BACKEND": "django.template.backends.django.DjangoTemplates", 60 | "DIRS": [], 61 | "APP_DIRS": True, 62 | "OPTIONS": { 63 | "context_processors": [ 64 | "django.template.context_processors.debug", 65 | "django.template.context_processors.request", 66 | "django.contrib.auth.context_processors.auth", 67 | "django.contrib.messages.context_processors.messages", 68 | ], 69 | }, 70 | }, 71 | ] 72 | 73 | WSGI_APPLICATION = "backend.wsgi.application" 74 | 75 | 76 | # Database 77 | # https://docs.djangoproject.com/en/5.1/ref/settings/#databases 78 | 79 | DATABASES = { 80 | "default": { 81 | "ENGINE": "django.db.backends.sqlite3", 82 | "NAME": BASE_DIR / "db.sqlite3", 83 | } 84 | } 85 | 86 | 87 | # Password validation 88 | # https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators 89 | 90 | AUTH_PASSWORD_VALIDATORS = [ 91 | { 92 | "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", 93 | }, 94 | { 95 | "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", 96 | }, 97 | { 98 | "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", 99 | }, 100 | { 101 | "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", 102 | }, 103 | ] 104 | 105 | 106 | # Internationalization 107 | # https://docs.djangoproject.com/en/5.1/topics/i18n/ 108 | 109 | LANGUAGE_CODE = "en-us" 110 | 111 | TIME_ZONE = "UTC" 112 | 113 | USE_I18N = True 114 | 115 | USE_TZ = True 116 | 117 | 118 | # Static files (CSS, JavaScript, Images) 119 | # https://docs.djangoproject.com/en/5.1/howto/static-files/ 120 | 121 | STATIC_URL = "static/" 122 | 123 | # Default primary key field type 124 | # https://docs.djangoproject.com/en/5.1/ref/settings/#default-auto-field 125 | 126 | DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" 127 | -------------------------------------------------------------------------------- /premsql/playground/backend/backend/urls.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | from django.urls import include, path 3 | from drf_yasg import openapi 4 | from drf_yasg.views import get_schema_view 5 | from rest_framework import permissions 6 | 7 | schema_view = get_schema_view( 8 | openapi.Info( 9 | title="PremSQL API", 10 | default_version="v0.0.1", 11 | description="API which controls PremSQL pipelines and agents", 12 | contact=openapi.Contact(email="anindyadeep@premai.io"), 13 | license=openapi.License(name="MIT"), 14 | ), 15 | public=True, 16 | permission_classes=(permissions.AllowAny,), 17 | ) 18 | 19 | urlpatterns = [ 20 | path("admin/", admin.site.urls), 21 | path("api/", 
include("api.urls")), 22 | path( 23 | "swagger<format>/", schema_view.without_ui(cache_timeout=0), name="schema-json" 24 | ), 25 | path( 26 | "swagger/", 27 | schema_view.with_ui("swagger", cache_timeout=0), 28 | name="schema-swagger-ui", 29 | ), 30 | path("redoc/", schema_view.with_ui("redoc", cache_timeout=0), name="schema-redoc"), 31 | ] 32 | -------------------------------------------------------------------------------- /premsql/playground/backend/backend/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for backend project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/5.1/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings") 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /premsql/playground/backend/backend_client.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from premsql.logger import setup_console_logger 3 | from premsql.playground.backend.api.pydantic_models import ( 4 | SessionCreationResponse, 5 | SessionDeleteResponse, 6 | SessionListResponse, 7 | SessionCreationRequest, 8 | CompletionCreationRequest, 9 | CompletionCreationResponse, 10 | CompletionListResponse, 11 | ) 12 | 13 | BASE_URL = "http://127.0.0.1:8000/api" 14 | 15 | logger = setup_console_logger("BACKEND-API-CLIENT") 16 | 17 | 18 | class BackendAPIClient: 19 | def __init__(self): 20 | self.base_url = BASE_URL 21 | self.headers = { 22 | 'accept': 'application/json', 23 | 'Content-Type': 'application/json', 24 | } 25 | 26 | def create_session(self, request: SessionCreationRequest) -> SessionCreationResponse: 27 | try: 28 | response = requests.post( 29 | f"{self.base_url}/session/create", 30 | json=request.model_dump(), 31 | headers=self.headers 32 | ) 33 | response.raise_for_status() # Raises an HTTPError for bad responses 34 | 35 | return SessionCreationResponse(**response.json()) 36 | except requests.RequestException as e: 37 | logger.error(f"Error creating session: {str(e)}") 38 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 39 | return SessionCreationResponse( 40 | status="error", 41 | status_code=response.status_code if 'response' in locals() and hasattr(response, 'status_code') else 500, 42 | error_message=f"Failed to create session: {str(e)}" 43 | ) 44 | except ValueError as e: 45 | logger.error(f"Error parsing response: {str(e)}") 46 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 47 | return SessionCreationResponse( 48 | status="error", 49 | status_code=500, 50 | error_message=f"Failed to parse server response: {str(e)}" 51 | ) 52 | 53 | except requests.RequestException as e: 54 | logger.error(f"Error creating session: {str(e)}") 55 | logger.error(f"Response content: {response.text}") 56 | return SessionCreationResponse( 57 | status="error", 58 | status_code=response.status_code if hasattr(response, 'status_code') else 500, 59 | error_message=f"Failed to create session: {str(e)}" 60 | ) 61 | except ValueError as e: 62 | logger.error(f"Error parsing response: {str(e)}") 63 | logger.error(f"Response content: {response.text}") 64 | return
SessionCreationResponse( 65 | status="error", 66 | status_code=500, 67 | error_message=f"Failed to parse server response: {str(e)}" 68 | ) 69 | 70 | def list_sessions(self, page: int = 1, page_size: int = 20) -> SessionListResponse: 71 | try: 72 | response = requests.get( 73 | f"{self.base_url}/session/list/", 74 | params={"page": page, "page_size": page_size}, 75 | headers=self.headers 76 | ) 77 | response.raise_for_status() 78 | return SessionListResponse(**response.json()) 79 | except requests.RequestException as e: 80 | logger.error(f"Error listing sessions: {str(e)}") 81 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 82 | return SessionListResponse( 83 | status="error", 84 | status_code=response.status_code if 'response' in locals() and hasattr(response, 'status_code') else 500, 85 | error_message=f"Failed to list sessions: {str(e)}", 86 | sessions=[], 87 | total_count=0 88 | ) 89 | except ValueError as e: 90 | logger.error(f"Error parsing response: {str(e)}") 91 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 92 | return SessionListResponse( 93 | status="error", 94 | status_code=500, 95 | error_message=f"Failed to parse server response: {str(e)}", 96 | sessions=[], 97 | total_count=0 98 | ) 99 | 100 | def get_session(self, session_name: str) -> SessionListResponse: 101 | try: 102 | response = requests.get( 103 | f"{self.base_url}/session/{session_name}/", 104 | headers=self.headers 105 | ) 106 | response.raise_for_status() 107 | return SessionListResponse(**response.json()) 108 | except requests.RequestException as e: 109 | logger.error(f"Error getting session: {str(e)}") 110 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 111 | return SessionListResponse( 112 | status="error", 113 | status_code=response.status_code if 'response' in locals() and hasattr(response, 'status_code') else 500, 114 | error_message=f"Failed to get session: {str(e)}", 115 | name="", 116 | created_at="", 117 | sessions=[] 118 | ) 119 | except (ValueError, KeyError, IndexError) as e: 120 | logger.error(f"Error parsing response: {str(e)}") 121 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 122 | return SessionListResponse( 123 | status="error", 124 | status_code=500, 125 | error_message=f"Failed to parse server response: {str(e)}", 126 | name="", 127 | created_at="", 128 | sessions=[] 129 | ) 130 | 131 | def delete_session(self, session_name: str) -> SessionDeleteResponse: 132 | try: 133 | response = requests.delete( 134 | f"{self.base_url}/session/{session_name}", 135 | headers=self.headers 136 | ) 137 | response.raise_for_status() 138 | return SessionDeleteResponse(**response.json()) 139 | except requests.RequestException as e: 140 | logger.error(f"Error deleting session: {str(e)}") 141 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 142 | return SessionDeleteResponse( 143 | status="error", 144 | status_code=response.status_code if 'response' in locals() and hasattr(response, 'status_code') else 500, 145 | error_message=f"Failed to delete session: {str(e)}" 146 | ) 147 | except ValueError as e: 148 | logger.error(f"Error parsing response: {str(e)}") 149 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 150 | return SessionDeleteResponse( 151 | status="error", 152 | status_code=500, 153 | error_message=f"Failed to 
parse server response: {str(e)}" 154 | ) 155 | 156 | # Chats 157 | def create_completion(self, request: CompletionCreationRequest) -> CompletionCreationResponse: 158 | try: 159 | response = requests.post( 160 | f"{self.base_url}/chat/completion", 161 | json=request.model_dump(), 162 | headers=self.headers 163 | ) 164 | response.raise_for_status() 165 | return CompletionCreationResponse(**response.json()) 166 | except requests.RequestException as e: 167 | logger.error(f"Error creating completion: {str(e)}") 168 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 169 | return CompletionCreationResponse( 170 | status="error", 171 | status_code=response.status_code if 'response' in locals() and hasattr(response, 'status_code') else 500, 172 | error_message=f"Failed to create completion: {str(e)}", 173 | completion="" 174 | ) 175 | except ValueError as e: 176 | logger.error(f"Error parsing response: {str(e)}") 177 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 178 | return CompletionCreationResponse( 179 | status="error", 180 | status_code=500, 181 | error_message=f"Failed to parse server response: {str(e)}", 182 | completion="" 183 | ) 184 | 185 | def get_chat_history(self, session_name: str, page: int = 1, page_size: int = 20) -> CompletionListResponse: 186 | try: 187 | response = requests.get( 188 | f"{self.base_url}/chat/history/{session_name}/", 189 | params={"page": page, "page_size": page_size}, 190 | headers=self.headers 191 | ) 192 | response.raise_for_status() 193 | return CompletionListResponse(**response.json()) 194 | except requests.RequestException as e: 195 | logger.error(f"Error getting chat history: {str(e)}") 196 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 197 | return CompletionListResponse( 198 | status="error", 199 | status_code=500, 200 | error_message=f"Failed to get chat history: {str(e)}", 201 | completions=[], 202 | total_count=0 203 | ) 204 | except ValueError as e: 205 | logger.error(f"Error parsing response: {str(e)}") 206 | logger.error(f"Response content: {response.text if 'response' in locals() else 'No response'}") 207 | return CompletionListResponse( 208 | status="error", 209 | status_code=500, 210 | error_message=f"Failed to parse server response: {str(e)}", 211 | completions=[], 212 | total_count=0 213 | ) 214 | -------------------------------------------------------------------------------- /premsql/playground/backend/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | import os 4 | import sys 5 | 6 | 7 | def main(): 8 | """Run administrative tasks.""" 9 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings") 10 | try: 11 | from django.core.management import execute_from_command_line 12 | import django 13 | # Patch Django's CommandParser before executing command 14 | from django.core.management.base import CommandParser 15 | original_init = CommandParser.__init__ 16 | def new_init(self, **kwargs): 17 | kwargs.pop('allow_abbrev', None) # Remove allow_abbrev if present 18 | original_init(self, **kwargs) 19 | CommandParser.__init__ = new_init 20 | except ImportError as exc: 21 | raise ImportError( 22 | "Couldn't import Django. Are you sure it's installed and " 23 | "available on your PYTHONPATH environment variable? 
Did you " 24 | "forget to activate a virtual environment?" 25 | ) from exc 26 | execute_from_command_line(sys.argv) 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /premsql/playground/frontend/components/chat.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import streamlit as st 3 | from premsql.playground.backend.backend_client import BackendAPIClient 4 | from premsql.playground.inference_server.api_client import InferenceServerAPIClient 5 | from premsql.playground.backend.api.pydantic_models import CompletionCreationRequest 6 | from premsql.playground.frontend.components.streamlit_plot import StreamlitPlotTool 7 | from premsql.agents.memory import AgentInteractionMemory 8 | from premsql.agents.utils import convert_exit_output_to_agent_output 9 | from premsql.agents.models import ExitWorkerOutput, AgentOutput 10 | from premsql.logger import setup_console_logger 11 | 12 | logger = setup_console_logger("FRONTEND-CHAT") 13 | 14 | class ChatComponent: 15 | def __init__(self) -> None: 16 | self.backend_client = BackendAPIClient() 17 | self.inference_client = InferenceServerAPIClient() 18 | self.plotter = StreamlitPlotTool() 19 | 20 | def _streamlit_chat_output(self, message: AgentOutput | ExitWorkerOutput): 21 | if isinstance(message, ExitWorkerOutput): 22 | message = convert_exit_output_to_agent_output(exit_output=message) 23 | 24 | if message.output_dataframe: 25 | try: 26 | df = message.output_dataframe 27 | df = pd.DataFrame(df["data"], columns=df["columns"]) 28 | if message.plot_config is None: 29 | st.dataframe(df) 30 | except Exception as e: 31 | st.error(f"Error: {e}") 32 | 33 | if message.analysis: 34 | st.markdown(message.analysis) 35 | if message.plot_config: 36 | df = message.input_dataframe 37 | if df: 38 | self.plotter.run( 39 | data=pd.DataFrame(df["data"], columns=df["columns"]), 40 | plot_config=message.plot_config 41 | ) 42 | if message.followup_suggestion: 43 | st.warning(message.followup_suggestion) 44 | with st.expander(label="Reasoning"): 45 | if message.sql_string: 46 | st.code(message.sql_string) 47 | if message.reasoning: 48 | st.markdown(message.reasoning) 49 | if message.plot_config: 50 | st.json(message.plot_config) 51 | if message.error_from_pipeline: 52 | st.error(message.error_from_pipeline) 53 | 54 | 55 | def render_chat_env(self, session_name: str) -> None: 56 | session_info = self.backend_client.get_session( 57 | session_name=session_name 58 | ) 59 | if session_info.status_code == 500: 60 | st.error(f"Failed to render chat history for session: {session_name}") 61 | return 62 | session = session_info.sessions[0] 63 | session_db_path = session.session_db_path 64 | base_url = session.base_url 65 | # TODO: Need to understand how can I start the session 66 | 67 | history = AgentInteractionMemory( 68 | session_name=session_name, db_path=session_db_path 69 | ) 70 | 71 | messages = history.generate_messages_from_session(session_name=session_name, server_mode=True) 72 | if not messages: 73 | st.warning("No chat history available for this session.") 74 | else: 75 | for message in messages: 76 | with st.chat_message("user"): st.markdown(message.question) 77 | with st.chat_message("assistant"): 78 | self._streamlit_chat_output(message=message) 79 | 80 | 81 | base_url = f"http://{base_url}" if not base_url.startswith("http://") else base_url 82 | is_session_online_status = self.inference_client.is_online(base_url=base_url) 83
| if is_session_online_status != 200: 84 | st.divider() 85 | st.warning(f"Session ended. Restart Agent Server to start the session at: {base_url}") 86 | 87 | else: 88 | if prompt := st.chat_input("What is your question?"): 89 | with st.chat_message("user"): 90 | st.markdown(prompt) 91 | 92 | with st.chat_message("assistant"): 93 | with st.spinner("Thinking..."): 94 | response = self.backend_client.create_completion( 95 | CompletionCreationRequest( 96 | session_name=session_name, 97 | question=prompt 98 | ) 99 | ) 100 | if response.status_code == 200: 101 | self._streamlit_chat_output( 102 | message=history.get_by_message_id(message_id=response.message_id) 103 | ) 104 | else: 105 | st.error("Something went wrong. Try again") 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /premsql/playground/frontend/components/session.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from premsql.playground.backend.backend_client import BackendAPIClient 3 | from premsql.playground.backend.api.pydantic_models import SessionCreationRequest 4 | 5 | additional_link_markdown = """ 6 | Here are some quick links to get you started with Prem: 7 | 8 | - Head over to [Prem App](https://app.premai.io/projects/) to start building on Gen AI. 9 | - [Prem AI documentation](https://docs.premai.io/get-started/why-prem) 10 | - [PremSQL documentation](https://docs.premai.io/premsql/introduction) 11 | """ 12 | 13 | class SessionComponent: 14 | def __init__(self) -> None: 15 | self.backend_client = BackendAPIClient() 16 | 17 | def render_list_sessions(self): 18 | with st.sidebar: 19 | st.sidebar.title("Your Past Sessions") 20 | all_sessions = self.backend_client.list_sessions(page_size=100).sessions 21 | if all_sessions: 22 | all_sessions = [session.session_name for session in all_sessions] 23 | 24 | selected_session = st.selectbox( 25 | label="Your Sessions (refresh if you have created a new one)", 26 | options=all_sessions, 27 | ) 28 | return selected_session 29 | 30 | def render_register_session(self): 31 | with st.sidebar: 32 | st.sidebar.title("Register new Session") 33 | with st.form( 34 | key="session_creation", 35 | clear_on_submit=True, 36 | border=100, 37 | enter_to_submit=False, 38 | ): 39 | base_url = st.text_input( 40 | label="base_url", 41 | placeholder="the base url in which AgentServer is running" 42 | ) 43 | button = st.form_submit_button(label="Submit") 44 | if button: 45 | response = self.backend_client.create_session( 46 | request=SessionCreationRequest(base_url=base_url) 47 | ) 48 | if response.status_code == 500: 49 | st.toast(body=st.markdown(response.error_message), icon="❌") 50 | else: 51 | st.toast(body=f"Session: {response.session_name} created successfully", icon="🥳") 52 | return response 53 | 54 | def render_additional_links(self): 55 | with st.sidebar: 56 | with st.container(height=200): 57 | st.markdown(additional_link_markdown) 58 | 59 | 60 | def render_delete_session_view(self): 61 | with st.sidebar: 62 | with st.expander(label="Delete a session"): 63 | with st.form(key="delete_session", clear_on_submit=True, enter_to_submit=False): 64 | session_name = st.text_input(label="Enter session name") 65 | button = st.form_submit_button(label="Submit") 66 | if button: 67 | all_sessions = self.backend_client.list_sessions(page_size=100).sessions 68 | all_sessions = [session.session_name for session in all_sessions] 69 | if session_name not in all_sessions: 70 | 
st.error("Session does not exist") 71 | else: 72 | self.backend_client.delete_session(session_name=session_name) 73 | st.success(f"Deleted session: {session_name}. Please refresh") -------------------------------------------------------------------------------- /premsql/playground/frontend/components/streamlit_plot.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from typing import Dict, Any 3 | import pandas as pd 4 | import streamlit as st 5 | from premsql.logger import setup_console_logger 6 | from premsql.agents.tools.plot.base import BasePlotTool 7 | 8 | logger = setup_console_logger("[STREAMLIT-TOOL]") 9 | 10 | class StreamlitPlotTool(BasePlotTool): 11 | def __init__(self): 12 | self.plot_functions = { 13 | "area": self._area_plot, 14 | "bar": self._bar_plot, 15 | "scatter": self._scatter_plot, 16 | "histogram": self._histogram_plot, 17 | "line": self._line_plot, 18 | } 19 | 20 | def run(self, data: pd.DataFrame, plot_config: Dict[str, str]) -> Any: 21 | try: 22 | self._validate_config(data, plot_config) 23 | 24 | plot_type = plot_config["plot_type"] 25 | x = plot_config["x"] 26 | y = plot_config["y"] 27 | 28 | st.markdown(f"**{plot_type.capitalize()} Plot: {x} vs {y}**") 29 | return self.plot_functions[plot_type](data, x, y) 30 | except Exception as e: 31 | error_msg = f"Error creating plot: {str(e)}" 32 | stack_trace = traceback.format_exc() 33 | logger.error(f"{error_msg}\n{stack_trace}") 34 | logger.error(f"Error creating plot: {str(e)}") 35 | st.error(f"Error creating plot: {str(e)}") 36 | return None 37 | 38 | def _validate_config(self, df: pd.DataFrame, plot_config: Dict[str, str]) -> None: 39 | required_keys = ["plot_type", "x", "y"] 40 | missing_keys = [key for key in required_keys if key not in plot_config] 41 | if missing_keys: 42 | raise ValueError(f"Missing required keys in plot_config: {', '.join(missing_keys)}") 43 | 44 | for key in ["x", "y"]: 45 | if key not in plot_config: 46 | raise ValueError(f"'{key}' is missing from plot_config") 47 | if not isinstance(plot_config[key], str): 48 | raise TypeError(f"plot_config['{key}'] should be a string, but got {type(plot_config[key])}") 49 | 50 | if not isinstance(df, pd.DataFrame): 51 | raise TypeError(f"Expected df to be a pandas DataFrame, but got {type(df)}") 52 | 53 | if not hasattr(df, 'columns'): 54 | raise AttributeError(f"df does not have a 'columns' attribute. Type: {type(df)}") 55 | 56 | if plot_config["x"] not in df.columns: 57 | raise ValueError(f"Column '{plot_config['x']}' not found in DataFrame. Available columns: {', '.join(df.columns)}") 58 | 59 | if plot_config["y"] not in df.columns: 60 | raise ValueError(f"Column '{plot_config['y']}' not found in DataFrame. Available columns: {', '.join(df.columns)}") 61 | 62 | if plot_config["plot_type"] not in self.plot_functions: 63 | raise ValueError(f"Unsupported plot type: {plot_config['plot_type']}. 
Supported types: {', '.join(self.plot_functions.keys())}") 64 | 65 | def _area_plot(self, df: pd.DataFrame, x: str, y: str) -> Any: 66 | chart_data = df[[x, y]].set_index(x) 67 | return st.area_chart(chart_data) 68 | 69 | def _bar_plot(self, df: pd.DataFrame, x: str, y: str) -> Any: 70 | chart_data = df[[x, y]].set_index(x) 71 | return st.bar_chart(chart_data) 72 | 73 | def _scatter_plot(self, df: pd.DataFrame, x: str, y: str) -> Any: 74 | chart_data = df[[x, y]] 75 | return st.scatter_chart(chart_data, x=x, y=y) 76 | 77 | def _histogram_plot(self, df: pd.DataFrame, x: str, y: str) -> Any: 78 | # Streamlit doesn't have a built-in histogram function, so we'll use a bar chart 79 | hist_data = df[x].value_counts().sort_index() 80 | chart_data = pd.DataFrame({x: hist_data.index, 'count': hist_data.values}) 81 | return st.bar_chart(chart_data.set_index(x)) 82 | 83 | def _line_plot(self, df: pd.DataFrame, x: str, y: str) -> Any: 84 | chart_data = df[[x, y]].set_index(x) 85 | return st.line_chart(chart_data) 86 | 87 | def convert_plot_to_image(self, fig): 88 | pass -------------------------------------------------------------------------------- /premsql/playground/frontend/main.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from premsql.playground.frontend.components.chat import ChatComponent 3 | from premsql.playground.frontend.components.session import SessionComponent 4 | from premsql.playground.frontend.components.uploader import UploadComponent 5 | 6 | st.set_page_config(page_title="PremSQL Playground", page_icon="🔍", layout="wide") 7 | 8 | def render_main_view(): 9 | session_component = SessionComponent() 10 | 11 | selected_session = session_component.render_list_sessions() 12 | session_creation = session_component.render_register_session() 13 | session_component.render_additional_links() 14 | 15 | if session_creation is not None: 16 | if session_creation.status_code == 200: 17 | new_session_name = session_creation.session_name 18 | st.success(f"New session created: {new_session_name}") 19 | ChatComponent().render_chat_env(session_name=new_session_name) 20 | elif selected_session is not None: 21 | ChatComponent().render_chat_env(session_name=selected_session) 22 | 23 | session_component.render_delete_session_view() 24 | 25 | def main(): 26 | _, col2, _ = st.sidebar.columns([1, 2, 1]) 27 | with col2: 28 | st.image( 29 | "https://static.premai.io/logo.svg", 30 | use_container_width=True, 31 | width=150, 32 | clamp=True, 33 | ) 34 | st.header("PremSQL Playground") 35 | st.title("PremSQL Playground") 36 | 37 | # Add navigation 38 | selected_page = st.sidebar.selectbox("Navigation", ["Chat", "Upload csvs or use Kaggle"]) 39 | 40 | if selected_page == "Chat": 41 | st.write("Welcome to the PremSQL Playground. Select or create a session to get started.") 42 | render_main_view() 43 | else: 44 | st.write( 45 | "You can either upload multiple csv files or enter a valid Kaggle ID. " 46 | "This will migrate all the csvs into a SQLite Database. You can then " 47 | "use them for natural language powered analysis using PremSQL." 
48 | ) 49 | UploadComponent.render_kaggle_view() 50 | UploadComponent.render_csv_upload_view() 51 | 52 | if __name__ == "__main__": 53 | main() -------------------------------------------------------------------------------- /premsql/playground/frontend/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import pandas as pd 4 | import kagglehub 5 | import sqlite3 6 | from pathlib import Path 7 | from platformdirs import user_cache_dir 8 | from premsql.logger import setup_console_logger 9 | 10 | logger = setup_console_logger("[FRONTEND-UTILS]") 11 | 12 | def _is_valid_kaggle_id(kaggle_id: str) -> bool: 13 | pattern = r'^[a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+$' 14 | return bool(re.match(pattern, kaggle_id)) 15 | 16 | def download_from_kaggle(kaggle_dataset_id: str): 17 | path = kagglehub.dataset_download(handle=kaggle_dataset_id) 18 | return path 19 | 20 | def _migrate_to_sqlite(csv_folder: Path, sqlite_db_path: Path) -> Path: 21 | """Common migration logic for both Kaggle and local CSV uploads.""" 22 | conn = sqlite3.connect(sqlite_db_path) 23 | try: 24 | for csv_file in csv_folder.glob('*.csv'): 25 | table_name = csv_file.stem 26 | df = pd.read_csv(csv_file) 27 | df.to_sql(table_name, conn, if_exists='replace', index=False) 28 | logger.info(f"Migrated {csv_file.name} to table '{table_name}'") 29 | 30 | logger.info(f"Successfully migrated all CSV files to {sqlite_db_path}") 31 | return sqlite_db_path 32 | except Exception as e: 33 | logger.error(f"Error during migration: {e}") 34 | raise 35 | finally: 36 | conn.close() 37 | 38 | def migrate_from_csv_to_sqlite( 39 | folder_containing_csvs: str, 40 | session_name: str 41 | ) -> Path: 42 | sqlite_db_folder = Path(user_cache_dir()) / "premsql" / "kaggle" 43 | os.makedirs(sqlite_db_folder, exist_ok=True) 44 | sqlite_db_path = sqlite_db_folder / f"{session_name}.sqlite" 45 | return _migrate_to_sqlite(Path(folder_containing_csvs), sqlite_db_path) 46 | 47 | def migrate_local_csvs_to_sqlite( 48 | uploaded_files: list, 49 | session_name: str 50 | ) -> Path: 51 | cache_dir = Path(user_cache_dir()) 52 | csv_folder = cache_dir / "premsql" / "csv_uploads" / session_name 53 | sqlite_db_folder = cache_dir / "premsql" / "csv_uploads" 54 | 55 | os.makedirs(csv_folder, exist_ok=True) 56 | os.makedirs(sqlite_db_folder, exist_ok=True) 57 | 58 | sqlite_db_path = sqlite_db_folder / f"{session_name}.sqlite" 59 | 60 | # Save uploaded files to CSV folder 61 | for uploaded_file in uploaded_files: 62 | file_path = csv_folder / uploaded_file.name 63 | with open(file_path, 'wb') as f: 64 | f.write(uploaded_file.getvalue()) 65 | 66 | return _migrate_to_sqlite(csv_folder, sqlite_db_path) -------------------------------------------------------------------------------- /premsql/playground/inference_server/api_client.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | from urllib.parse import urljoin 3 | 4 | import requests 5 | 6 | 7 | class InferenceServerAPIError(Exception): 8 | pass 9 | 10 | 11 | class InferenceServerAPIClient: 12 | def __init__(self, timeout: int = 600) -> None: 13 | self.headers = { 14 | "accept": "application/json", 15 | "Content-Type": "application/json", 16 | } 17 | self.timeout = timeout 18 | 19 | def _make_request( 20 | self, 21 | base_url: str, 22 | method: str, 23 | endpoint: str, 24 | data: Optional[Dict[str, Any]] = None, 25 | ) -> Dict[str, Any]: 26 | url = urljoin(base_url.rstrip("/"), endpoint) 27 | 
try: 28 | response = requests.request( 29 | method, url, headers=self.headers, json=data, timeout=self.timeout 30 | ) 31 | response.raise_for_status() 32 | return response.json() 33 | except requests.RequestException as e: 34 | raise InferenceServerAPIError(f"API request failed: {str(e)}") 35 | 36 | def is_online(self, base_url: str) -> bool: 37 | endpoint = "/health" 38 | try: 39 | response = self._make_request(base_url, "GET", endpoint) 40 | return response.get("status_code") 41 | except Exception as e: 42 | return 500 43 | 44 | def post_completion(self, base_url: str, question: str) -> Dict[str, Any]: 45 | if not question.strip(): 46 | raise ValueError("Question cannot be empty") 47 | endpoint = "/completion" 48 | data = {"question": question} 49 | return self._make_request(base_url, "POST", endpoint, data) 50 | 51 | def get_session_info(self, base_url: str) -> Dict[str, Any]: 52 | endpoint = "/session_info" 53 | return self._make_request(base_url, "GET", endpoint) 54 | 55 | def get_chat_history(self, base_url: str, message_id: int) -> Dict[str, Any]: 56 | if message_id < 1: 57 | raise ValueError("Message ID must be a positive integer") 58 | endpoint = f"/chat_history/{message_id}" 59 | return self._make_request(base_url, "GET", endpoint) 60 | 61 | def delete_session(self, base_url: str) -> Dict[str, Any]: 62 | endpoint = "/delete_session/" 63 | return self._make_request(base_url, "DELETE", endpoint) 64 | -------------------------------------------------------------------------------- /premsql/playground/inference_server/service.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | from contextlib import asynccontextmanager 4 | from datetime import datetime 5 | from typing import Optional 6 | 7 | from fastapi import FastAPI, HTTPException 8 | from fastapi.middleware.cors import CORSMiddleware 9 | from pydantic import BaseModel 10 | 11 | from premsql.logger import setup_console_logger 12 | from premsql.agents.base import AgentBase, AgentOutput 13 | 14 | logger = setup_console_logger("[FASTAPI-INFERENCE-SERVICE]") 15 | 16 | 17 | class QuestionInput(BaseModel): 18 | question: str 19 | 20 | 21 | class SessionInfoResponse(BaseModel): 22 | status: int 23 | session_name: Optional[str] = None 24 | db_connection_uri: Optional[str] = None 25 | session_db_path: Optional[str] = None 26 | base_url: Optional[str] = None 27 | created_at: Optional[datetime] = None 28 | 29 | 30 | class ChatHistoryResponse(BaseModel): 31 | message_id: int 32 | agent_output: AgentOutput 33 | 34 | 35 | class CompletionResponse(BaseModel): 36 | message_id: int 37 | message: AgentOutput 38 | 39 | 40 | class AgentServer: 41 | def __init__( 42 | self, 43 | agent: AgentBase, 44 | url: Optional[str] = "localhost", 45 | port: Optional[int] = 8100, 46 | ) -> None: 47 | self.agent = agent 48 | self.port = port 49 | self.url = url 50 | self.app = self.create_app() 51 | 52 | @asynccontextmanager 53 | async def lifespan(self, app: FastAPI): 54 | # Startup: Log the initialization 55 | logger.info("Starting up the application") 56 | yield 57 | # Shutdown: Clean up resources 58 | logger.info("Shutting down the application") 59 | if hasattr(self.agent, "cleanup"): 60 | await self.agent.cleanup() 61 | 62 | def create_app(self): 63 | app = FastAPI(lifespan=self.lifespan) 64 | app.add_middleware( 65 | CORSMiddleware, 66 | allow_origins=["*"], # Allows all origins 67 | allow_credentials=True, 68 | allow_methods=["*"], # Allows all methods 69 | allow_headers=["*"], # Allows all 
headers 70 | ) 71 | 72 | @app.post("/completion", response_model=CompletionResponse) 73 | async def completion(input_data: QuestionInput): 74 | try: 75 | result = self.agent(question=input_data.question, server_mode=True) 76 | message_id = self.agent.history.get_latest_message_id() 77 | return CompletionResponse( 78 | message=AgentOutput(**result.model_dump()), message_id=message_id 79 | ) 80 | except Exception as e: 81 | stack_trace = traceback.format_exc() 82 | logger.error(stack_trace) 83 | logger.error(f"Error processing query: {str(e)}") 84 | raise HTTPException( 85 | status_code=500, detail=f"Error processing query: {str(e)}" 86 | ) 87 | 88 | # TODO: I need a method which will just get the "latets message_id" 89 | 90 | @app.get("/chat_history/{message_id}", response_model=ChatHistoryResponse) 91 | async def get_chat_history(message_id: int): 92 | try: 93 | exit_output = self.agent.history.get_by_message_id( 94 | message_id=message_id 95 | ) 96 | if exit_output is None: 97 | raise HTTPException( 98 | status_code=404, 99 | detail=f"Message with ID {message_id} not found", 100 | ) 101 | agent_output = self.agent.convert_exit_output_to_agent_output( 102 | exit_output=exit_output 103 | ) 104 | return ChatHistoryResponse( 105 | message_id=message_id, agent_output=agent_output 106 | ) 107 | except Exception as e: 108 | logger.error(f"Error retrieving chat history: {str(e)}") 109 | raise HTTPException( 110 | status_code=500, detail=f"Error retrieving chat history: {str(e)}" 111 | ) 112 | 113 | @app.get("/") 114 | async def health_check(): 115 | return { 116 | "status_code": 200, 117 | "status": f"healthy, running: {self.agent.session_name}" 118 | } 119 | 120 | @app.get("/health") 121 | async def health_check(): 122 | return {"status_code": 200, "status": "healthy"} 123 | 124 | @app.get("/session_info", response_model=SessionInfoResponse) 125 | async def get_session_info(): 126 | try: 127 | session_name = getattr(self.agent, "session_name", None) 128 | db_connection_uri = getattr(self.agent, "db_connection_uri", None) 129 | session_db_path = getattr(self.agent, "session_db_path", None) 130 | 131 | if any( 132 | attr is None 133 | for attr in [session_name, db_connection_uri, session_db_path] 134 | ): 135 | raise ValueError("One or more required attributes are None") 136 | 137 | return SessionInfoResponse( 138 | status=200, 139 | session_name=session_name, 140 | db_connection_uri=db_connection_uri, 141 | session_db_path=session_db_path, 142 | base_url=f"{self.url}:{self.port}", 143 | created_at=datetime.now(), 144 | ) 145 | except Exception as e: 146 | logger.error(f"Error getting session info: {str(e)}") 147 | return SessionInfoResponse( 148 | status=500, 149 | session_name=None, 150 | db_connection_uri=None, 151 | session_db_path=None, 152 | base_url=None, 153 | created_at=None, 154 | ) 155 | 156 | return app 157 | 158 | def launch(self): 159 | import uvicorn 160 | 161 | logger.info(f"Starting server on port {self.port}") 162 | uvicorn.run(self.app, host=self.url, port=int(self.port)) 163 | -------------------------------------------------------------------------------- /premsql/prompts.py: -------------------------------------------------------------------------------- 1 | BASE_TEXT2SQL_PROMPT = """ 2 | # Follow these instruction: 3 | You will be given schemas of tables of a database. Your job is to write correct 4 | error free SQL query based on the question asked. Please make sure: 5 | 6 | 1. Do not add ``` at start / end of the query. 
It should be a single line query in a single line (string format) 7 | 2. Make sure the column names are correct and exists in the table 8 | 3. For column names which has a space with it, make sure you have put `` in that column name 9 | 4. Think step by step and always check schema and question and the column names before writing the 10 | query. 11 | 12 | # Database and Table Schema: 13 | {schemas} 14 | 15 | {additional_knowledge} 16 | 17 | # Here are some Examples on how to generate SQL statements and use column names: 18 | {few_shot_examples} 19 | 20 | # Question: {question} 21 | 22 | # SQL: 23 | """ 24 | 25 | OLD_BASE_TEXT2SQL_PROMPT = """ 26 | # Instruction: 27 | - You will be given a question and a database schema. 28 | - You need to write a SQL query to answer the question. 29 | Do not add ``` at start / end of the query. It should be a single line query in 30 | a single line (string format). 31 | - Make sure the column names are correct and exists in the table 32 | - For column names which has a space with it, make sure you have put `` in that column name 33 | 34 | # Database and Table Schema: 35 | {schemas} 36 | 37 | {additional_knowledge} 38 | 39 | # Here are some Examples on how to generate SQL statements and use column names: 40 | {few_shot_examples} 41 | 42 | # Question: {question} 43 | 44 | # SQL: 45 | """ 46 | 47 | ERROR_HANDLING_PROMPT = """ 48 | {existing_prompt} 49 | 50 | # Generated SQL: {sql} 51 | 52 | ## Error Message 53 | 54 | {error_msg} 55 | 56 | Carefully review the original question and error message, then rewrite the SQL query to address the identified issues. 57 | Ensure your corrected query uses correct column names, 58 | follows proper SQL syntax, and accurately answers the original question 59 | without introducing new errors. 60 | 61 | # SQL: 62 | """ 63 | -------------------------------------------------------------------------------- /premsql/tuner/__init__.py: -------------------------------------------------------------------------------- 1 | from premsql.tuner.callback import Text2SQLEvaluationCallback 2 | from premsql.tuner.config import ( 3 | DefaultLoraConfig, 4 | DefaultPeftArguments, 5 | DefaultTrainingArguments, 6 | ) 7 | from premsql.tuner.full import Text2SQLFullFinetuner 8 | from premsql.tuner.peft import Text2SQLPeftTuner 9 | 10 | __all__ = [ 11 | "Text2SQLFullFinetuner", 12 | "Text2SQLPeftTuner", 13 | "DefaultLoraConfig", 14 | "DefaultPeftArguments", 15 | "Text2SQLEvaluationCallback", 16 | ] 17 | -------------------------------------------------------------------------------- /premsql/tuner/callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from premsql.datasets.base import Text2SQLBaseDataset 5 | from premsql.evaluator.base import BaseExecutor, Text2SQLEvaluator 6 | from premsql.generators.huggingface import Text2SQLGeneratorHF 7 | from premsql.logger import setup_console_logger 8 | 9 | logger = setup_console_logger("[EVALUATION-CALLBACK]") 10 | 11 | try: 12 | from torch.utils.tensorboard import SummaryWriter 13 | from transformers import ( 14 | Trainer, 15 | TrainerCallback, 16 | TrainerControl, 17 | TrainerState, 18 | TrainingArguments, 19 | ) 20 | except ImportError: 21 | logger.warn("Unable to import torch and transformers. 
Install: pip install torch transformers") 22 | 23 | 24 | class Text2SQLEvaluationCallback(TrainerCallback): 25 | def __init__( 26 | self, 27 | trainer: Trainer, 28 | trainer_args: TrainingArguments, 29 | eval_dataset: Text2SQLBaseDataset, 30 | executor: BaseExecutor, 31 | experiment_name: str, 32 | model_or_name_or_id: str, 33 | eval_steps: int, 34 | hf_token: Optional[str] = None, 35 | filter_results_by: Optional[tuple] = None, 36 | ): 37 | self.trainer = trainer 38 | self.eval_steps = eval_steps 39 | self.experiment_name = experiment_name 40 | 41 | log_dir = trainer_args.logging_dir 42 | os.makedirs(log_dir, exist_ok=True) 43 | 44 | self.tb_writer = SummaryWriter(log_dir=log_dir) 45 | logger.info(f"TensorBoard log directory: {log_dir}") 46 | 47 | self.model_or_name_or_id = model_or_name_or_id 48 | self.hf_token = hf_token 49 | self.dataset = eval_dataset 50 | self.executor = executor 51 | self.filter_by = filter_results_by 52 | 53 | def on_step_end( 54 | self, 55 | args: TrainingArguments, 56 | state: TrainerState, 57 | control: TrainerControl, 58 | **kwargs, 59 | ): 60 | if args.local_rank == 0 and state.global_step % self.eval_steps == 0: 61 | logger.info(f"Evaluating at step {state.global_step}") 62 | model = Text2SQLGeneratorHF( 63 | model_or_name_or_path=self.trainer.model, 64 | experiment_name=f"{self.experiment_name}_step_{state.global_step}", 65 | type="test", 66 | device="cuda:0", 67 | ) 68 | responses = model.generate_and_save_results( 69 | dataset=self.dataset, temperature=0.1, max_new_tokens=256, force=True 70 | ) 71 | evaluator = Text2SQLEvaluator( 72 | executor=self.executor, experiment_path=model.experiment_path 73 | ) 74 | if self.filter_by: 75 | ex_score = evaluator.execute( 76 | metric_name="accuracy", 77 | model_responses=responses, 78 | filter_by=self.filter_by[0], 79 | ) 80 | else: 81 | ex_score = evaluator.execute( 82 | metric_name="accuracy", 83 | model_responses=responses, 84 | ) 85 | logger.info(f"Execution Accuracy at step {state.global_step} | {ex_score}") 86 | 87 | # Log into tensorboard 88 | logger.info(f"Logging to TensorBoard: {ex_score}") 89 | for difficulty, score in ex_score.items(): 90 | logger.info(f"Logging {difficulty}: {score}") 91 | self.tb_writer.add_scalar( 92 | f"execution_accuracy/{difficulty}", score, state.global_step 93 | ) 94 | self.tb_writer.flush() # Force writing to disk 95 | 96 | state.log_history.append( 97 | { 98 | "step": state.global_step, 99 | "execution_accuracy": ( 100 | ex_score.get(self.filter_by[1]) 101 | if self.filter_by 102 | else ex_score.get("overall") 103 | ), 104 | "selected_difficulty": ( 105 | self.filter_by[0] if self.filter_by else "overall" 106 | ), 107 | } 108 | ) 109 | return control 110 | 111 | def on_train_end( 112 | self, 113 | args: TrainingArguments, 114 | state: TrainerState, 115 | control: TrainerControl, 116 | **kwargs, 117 | ): 118 | self.tb_writer.close() 119 | logger.info("TensorBoard writer closed") 120 | -------------------------------------------------------------------------------- /premsql/tuner/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Optional 3 | from premsql.logger import setup_console_logger 4 | 5 | logger = setup_console_logger("[TUNER-CONFIG]") 6 | 7 | try: 8 | from peft import LoraConfig, TaskType 9 | from transformers import TrainingArguments 10 | except ImportError: 11 | logger.warn("Unable to find peft and transformers. 
Install: pip install peft transformers") 12 | 13 | 14 | @dataclass 15 | class DefaultTrainingArguments(TrainingArguments): 16 | output_dir: str 17 | num_train_epochs: int 18 | per_device_train_batch_size: int 19 | gradient_accumulation_steps: int 20 | 21 | load_best_model_at_end: Optional[bool] = field(default=True) 22 | gradient_checkpointing: Optional[bool] = field(default=True) 23 | evaluation_strategy: Optional[str] = field(default="no") 24 | 25 | cache_dir: Optional[str] = field(default=None) 26 | optim: str = field(default="adamw_hf") 27 | model_max_length: int = field( 28 | default=1024, 29 | metadata={ 30 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 31 | }, 32 | ) 33 | max_seq_length: int = field(default=1024) 34 | ddp_find_unused_parameters: Optional[bool] = field(default=False) 35 | fp16: bool = field(default=False) 36 | bf16: bool = field(default=True) 37 | 38 | weight_decay: float = field(default=0.1) 39 | lr_scheduler_type: str = field(default="cosine") 40 | warmup_ratio: float = field(default=0.01) 41 | logging_steps: int = field(default=10) 42 | save_strategy: str = field(default="steps") 43 | save_steps: int = field(default=200) 44 | save_total_limit: int = field(default=3) 45 | auto_find_batch_size: Optional[bool] = field(default=False) 46 | report_to: List[str] = field(default_factory=lambda: ["tensorboard"]) 47 | 48 | 49 | @dataclass 50 | class DefaultPeftArguments(TrainingArguments): 51 | output_dir: str 52 | num_train_epochs: int 53 | per_device_train_batch_size: int 54 | gradient_accumulation_steps: int 55 | 56 | load_best_model_at_end: Optional[bool] = field(default=False) 57 | gradient_checkpointing: Optional[bool] = field(default=True) 58 | evaluation_strategy: Optional[str] = field(default="no") 59 | optim: str = field(default="adamw_hf") 60 | 61 | max_grad_norm: Optional[bool] = field(default=0.3) 62 | weight_decay: float = field(default=0.1) 63 | lr_scheduler_type: str = field(default="cosine") 64 | warmup_ratio: float = field(default=0.01) 65 | logging_steps: int = field(default=10) 66 | save_strategy: str = field(default="steps") 67 | save_steps: int = field(default=200) 68 | save_total_limit: int = field(default=3) 69 | auto_find_batch_size: Optional[bool] = field(default=False) 70 | report_to: List[str] = field(default_factory=lambda: ["tensorboard"]) 71 | 72 | fp16: Optional[bool] = field(default=False) 73 | bf16: Optional[bool] = field(default=True) 74 | neftune_noise_alpha: Optional[int] = field(default=5) 75 | 76 | 77 | @dataclass 78 | class DefaultLoraConfig(LoraConfig): 79 | lora_alpha: float = field(default=32) 80 | lora_dropout: float = field(default=0.1) 81 | r: int = field(default=64) 82 | target_modules: List[str] = field( 83 | default_factory=lambda: [ 84 | "q_proj", 85 | "v_proj", 86 | "k_proj", 87 | "o_proj", 88 | "gate_proj", 89 | "up_proj", 90 | "down_proj", 91 | "lm_head", 92 | ] 93 | ) 94 | task_type: TaskType = field(default=TaskType.CAUSAL_LM) 95 | -------------------------------------------------------------------------------- /premsql/tuner/full.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | import transformers 4 | 5 | from premsql.datasets.base import Text2SQLBaseDataset 6 | from premsql.datasets.collator import DataCollatorForSupervisedDataset 7 | from premsql.evaluator.base import BaseExecutor 8 | from premsql.logger import setup_console_logger 9 | from premsql.tuner.callback import 
Text2SQLEvaluationCallback 10 | from premsql.tuner.config import DefaultTrainingArguments 11 | 12 | logger = setup_console_logger("[FULL-FINETUNE]") 13 | 14 | 15 | class Text2SQLFullFinetuner: 16 | def __init__( 17 | self, 18 | model_name_or_path: str, 19 | experiment_name: str, 20 | hf_token: Optional[str] = None, 21 | **model_kwargs, 22 | ): 23 | self.model_name_or_path = model_name_or_path 24 | 25 | logger.warning("Setting up Pretrained-Model: " + str(model_name_or_path)) 26 | self.model = transformers.AutoModelForCausalLM.from_pretrained( 27 | model_name_or_path, token=hf_token, **model_kwargs 28 | ) 29 | self.tokenizer = transformers.AutoTokenizer.from_pretrained( 30 | model_name_or_path, padding_size="right", token=hf_token 31 | ) 32 | self.data_collator = DataCollatorForSupervisedDataset(tokenizer=self.tokenizer) 33 | 34 | self._hf_token = hf_token 35 | self.experiment_name = experiment_name 36 | 37 | def train( 38 | self, 39 | train_datasets: Sequence[Text2SQLBaseDataset], 40 | output_dir: str, 41 | num_train_epochs: int, 42 | per_device_train_batch_size: int, 43 | gradient_accumulation_steps: int, 44 | evaluation_dataset: Optional[Text2SQLBaseDataset] = None, 45 | eval_steps: Optional[int] = 500, 46 | executor: Optional[BaseExecutor] = None, 47 | filter_eval_results_by: Optional[tuple] = None, 48 | **training_arguments, 49 | ): 50 | self.training_arguments = DefaultTrainingArguments( 51 | output_dir=output_dir, 52 | num_train_epochs=num_train_epochs, 53 | per_device_train_batch_size=per_device_train_batch_size, 54 | gradient_accumulation_steps=gradient_accumulation_steps, 55 | **training_arguments, 56 | ) 57 | 58 | data_module = dict( 59 | train_dataset=train_datasets, 60 | eval_dataset=None, 61 | data_collator=self.data_collator, 62 | ) 63 | trainer = transformers.Trainer( 64 | model=self.model, 65 | tokenizer=self.tokenizer, 66 | args=self.training_arguments, 67 | **data_module, 68 | ) 69 | 70 | if evaluation_dataset is not None and executor is not None: 71 | eval_callback = Text2SQLEvaluationCallback( 72 | trainer=trainer, 73 | trainer_args=self.training_arguments, 74 | eval_dataset=evaluation_dataset, 75 | experiment_name=self.experiment_name, 76 | model_or_name_or_id=self.model_name_or_path, 77 | eval_steps=eval_steps, 78 | executor=executor, 79 | filter_results_by=filter_eval_results_by, 80 | hf_token=self._hf_token, 81 | ) 82 | trainer.add_callback(eval_callback) 83 | 84 | trainer.train() 85 | trainer.save_model(output_dir=self.training_arguments.output_dir) 86 | -------------------------------------------------------------------------------- /premsql/tuner/peft.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | from premsql.datasets.base import Text2SQLBaseDataset 4 | from premsql.datasets.collator import DataCollatorForSupervisedDataset 5 | from premsql.evaluator.base import BaseExecutor 6 | from premsql.logger import setup_console_logger 7 | from premsql.tuner.callback import Text2SQLEvaluationCallback 8 | from premsql.tuner.config import DefaultLoraConfig, DefaultPeftArguments 9 | 10 | logger = setup_console_logger("[LORA-FINETUNE]") 11 | 12 | try: 13 | import torch 14 | import transformers 15 | from peft import LoraConfig 16 | from transformers import BitsAndBytesConfig 17 | from trl import SFTTrainer 18 | except ImportError: 19 | logger.warn("Ensure torch transformers peft and trl are installed.") 20 | logger.warn("Install them by: pip install torch peft trl transformers") 21 | 22 | 23 
| class Text2SQLPeftTuner: 24 | def __init__( 25 | self, 26 | model_name_or_path: str, 27 | experiment_name: str, 28 | peft_config: Optional[LoraConfig] = None, 29 | bnb_config: Optional[BitsAndBytesConfig] = None, 30 | hf_token: Optional[str] = None, 31 | **model_kwargs, 32 | ): 33 | self.peft_config = peft_config or DefaultLoraConfig() 34 | self.bnb_config = bnb_config 35 | self.model_name_or_path = model_name_or_path 36 | 37 | logger.warning("Setting up Pretrained-Model: " + str(model_name_or_path)) 38 | 39 | self.model = transformers.AutoModelForCausalLM.from_pretrained( 40 | model_name_or_path, 41 | token=hf_token, 42 | torch_dtype=torch.bfloat16, 43 | quantization_config=bnb_config, 44 | **model_kwargs, 45 | ) 46 | self.tokenizer = transformers.AutoTokenizer.from_pretrained( 47 | model_name_or_path, padding_size="right", token=hf_token 48 | ) 49 | self.data_collator = DataCollatorForSupervisedDataset(tokenizer=self.tokenizer) 50 | 51 | self._hf_token = hf_token 52 | self.experiment_name = experiment_name 53 | 54 | def train( 55 | self, 56 | train_datasets: Sequence[Text2SQLBaseDataset], 57 | output_dir: str, 58 | num_train_epochs: int, 59 | max_seq_length: int, 60 | per_device_train_batch_size: int, 61 | gradient_accumulation_steps: int, 62 | evaluation_dataset: Optional[Text2SQLBaseDataset] = None, 63 | eval_steps: Optional[int] = 500, 64 | executor: Optional[BaseExecutor] = None, 65 | filter_eval_results_by: Optional[tuple] = None, 66 | **training_arguments, 67 | ): 68 | self.training_arguments = transformers.TrainingArguments( 69 | **DefaultPeftArguments( 70 | output_dir=output_dir, 71 | num_train_epochs=num_train_epochs, 72 | per_device_train_batch_size=per_device_train_batch_size, 73 | gradient_accumulation_steps=gradient_accumulation_steps, 74 | **training_arguments, 75 | ).to_dict() 76 | ) 77 | 78 | if "raw" in train_datasets[0]: 79 | formatting_func = lambda x: x["raw"]["prompt"] 80 | else: 81 | formatting_func = lambda x: x["prompt"] 82 | 83 | trainer = SFTTrainer( 84 | model=self.model, 85 | train_dataset=train_datasets, 86 | peft_config=self.peft_config, 87 | tokenizer=self.tokenizer, 88 | args=self.training_arguments, 89 | packing=True, 90 | formatting_func=formatting_func, 91 | max_seq_length=max_seq_length, 92 | ) 93 | 94 | if evaluation_dataset is not None and executor is not None: 95 | eval_callback = Text2SQLEvaluationCallback( 96 | trainer=trainer, 97 | trainer_args=self.training_arguments, 98 | eval_dataset=evaluation_dataset, 99 | experiment_name=self.experiment_name, 100 | model_or_name_or_id=self.model_name_or_path, 101 | eval_steps=eval_steps, 102 | executor=executor, 103 | filter_results_by=filter_eval_results_by, 104 | hf_token=self._hf_token, 105 | ) 106 | trainer.add_callback(eval_callback) 107 | 108 | trainer.train() 109 | trainer.save_model(output_dir=self.training_arguments.output_dir) 110 | -------------------------------------------------------------------------------- /premsql/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import re 5 | import sqlite3 6 | from collections import defaultdict 7 | from pathlib import Path 8 | from textwrap import dedent 9 | from typing import Optional, Sequence, Union 10 | 11 | from tqdm.auto import tqdm 12 | from premsql.logger import setup_console_logger 13 | 14 | logger = setup_console_logger(name="[UTILS]") 15 | 16 | try: 17 | from transformers import PreTrainedTokenizer 18 | except ImportError: 19 | logger.warn("Unable to 
use transformers. Install using: pip install transformers") 20 | 21 | 22 | def convert_sqlite_path_to_dsn(path: str): 23 | sqlite3_pattern = r"^sqlite:\/\/\/.*" 24 | if re.match(sqlite3_pattern, path): 25 | return path 26 | return f"sqlite:///{os.path.abspath(path)}" 27 | 28 | 29 | def convert_sqlite_dsn_to_path(dsn: str) -> str: 30 | sqlite3_pattern = r"^sqlite:\/\/\/(.*)" 31 | match = re.match(sqlite3_pattern, dsn) 32 | if match: 33 | return os.path.abspath(match.group(1)) 34 | return dsn 35 | 36 | 37 | def print_data(data: dict): 38 | if "prompt" in data: 39 | prompt = data["prompt"] 40 | data["prompt"] = prompt[:100] + "...." + prompt[-100:] 41 | 42 | elif "prompt" in data["raw"]: 43 | prompt = data["raw"]["prompt"] 44 | data["raw"]["prompt"] = prompt[:100] + "...." + prompt[-100:] 45 | 46 | else: 47 | raise ValueError("Prompt key not found in data") 48 | 49 | return data 50 | 51 | 52 | def save_to_json(save_path: Union[str, Path], json_object: dict): 53 | try: 54 | save_path = Path(save_path) if isinstance(save_path, str) else save_path 55 | with open(save_path, "w") as json_file: 56 | json.dump(json_object, json_file, indent=4, ensure_ascii=False) 57 | logger.info(f"Saved JSON in: {save_path}") 58 | except Exception as e: 59 | logger.error(f"Unable to save JSON, Error: {e}") 60 | 61 | 62 | def load_from_json(result_json_path: str) -> dict: 63 | try: 64 | with open(result_json_path, "r") as json_file: 65 | return json.load(json_file) 66 | except Exception as e: 67 | logger.error(f"Unable to load JSON, Error: {e}") 68 | 69 | 70 | def sqlite_schema_prompt(db_path: str) -> str: 71 | schemas = {} 72 | conn = sqlite3.connect(db_path) 73 | cursor = conn.cursor() 74 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 75 | tables = cursor.fetchall() 76 | 77 | for table in tables: 78 | table_name = table[0] 79 | if table_name == "sqlite_sequence": 80 | continue 81 | cursor.execute( 82 | f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}';" 83 | ) 84 | create_table_sql = cursor.fetchone() 85 | if create_table_sql: 86 | schemas[table_name] = create_table_sql[0] 87 | else: 88 | schemas[table_name] = "Schema does not exist" 89 | 90 | schema_prompt = "\n".join( 91 | schemas[table[0]] for table in tables if table[0] != "sqlite_sequence" 92 | ) 93 | return schema_prompt 94 | 95 | 96 | def get_random_few_shot_prompts(dataset: list[dict], num_few_shot: int): 97 | assert "db_id" in dataset[0], ValueError( 98 | "db_id key should be present to use this function" 99 | ) 100 | 101 | grouped_content = defaultdict(list) 102 | few_shot_prompts = {} 103 | template = dedent( 104 | """ 105 | Question: {question} 106 | SQL: {sql} 107 | """ 108 | ) 109 | 110 | for content in dataset: 111 | grouped_content[content["db_id"]].append(content) 112 | 113 | for db_id, contents in grouped_content.items(): 114 | num_few_shot = min(num_few_shot, len(contents)) 115 | random_sample = random.sample(contents, num_few_shot) 116 | 117 | few_shot_prompt = "".join( 118 | template.format(question=element["question"], sql=element["SQL"]) 119 | for element in random_sample 120 | ) 121 | few_shot_prompts[db_id] = few_shot_prompt 122 | return few_shot_prompts 123 | 124 | 125 | def get_accepted_filters(data: list[dict]) -> Sequence[str]: 126 | key_num_mapping = {} 127 | for key in data[0].keys(): 128 | key_num_mapping[key] = len(set([content[key] for content in data])) 129 | 130 | accepted_keys = [] 131 | for key, num in key_num_mapping.items(): 132 | if num < len(data) * 0.5 and key != "db_path": 
133 | accepted_keys.append(key) 134 | return accepted_keys 135 | 136 | 137 | def filter_options( 138 | data: list[dict], filter_by: tuple, accepted_keys: Optional[Sequence[str]] = None 139 | ): 140 | filter_key, filter_value = filter_by 141 | accepted_keys = ( 142 | get_accepted_filters(data=data) if accepted_keys is None else accepted_keys 143 | ) 144 | 145 | assert filter_key in accepted_keys, ValueError( 146 | f"Filtering is supported for keys: `{''.join(accepted_keys)}`" 147 | ) 148 | for key in accepted_keys: 149 | if filter_key == key: 150 | accepted_values = set([content[key] for content in data]) 151 | assert filter_value in accepted_values, ValueError( 152 | f"Available values for key: {key} are: {', '.join(accepted_values)}" 153 | ) 154 | 155 | filtered_data = [content for content in data if content[filter_key] == filter_value] 156 | return filtered_data 157 | 158 | 159 | def tokenize_fn(strings: Sequence[str], tokenizer: "PreTrainedTokenizer") -> dict: 160 | """Tokenizes a list of string""" 161 | tokenized_list = [ 162 | tokenizer( 163 | text=text, 164 | return_tensors="pt", 165 | padding="longest", 166 | max_length=tokenizer.model_max_length, 167 | truncation=False, 168 | ) 169 | for text in tqdm(strings, total=len(strings), desc="Tokenizing") 170 | ] 171 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 172 | input_ids_lens = label_ids_lens = [ 173 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() 174 | for tokenized in tokenized_list 175 | ] 176 | return dict( 177 | input_ids=input_ids, 178 | labels=labels, 179 | input_ids_lens=input_ids_lens, 180 | label_ids_lens=label_ids_lens, 181 | ) 182 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "premsql" 3 | version = "0.2.10" 4 | description = "" 5 | authors = ["Anindyadeep "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10" 10 | datasets = "^2.20.0" 11 | einops = "^0.8.0" 12 | black = "^24.4.2" 13 | fastapi = "^0.112.0" 14 | huggingface-hub = "^0.24.5" 15 | isort = "^5.13.2" 16 | numpy = "^1.26.3" 17 | tqdm = "^4.66.4" 18 | mysql-connector-python = "^9.0.0" 19 | SQLAlchemy = "^2.0.30" 20 | sqlparse = "^0.5.1" 21 | click = "^8.1.3" 22 | langchain-community = "^0.3.3" 23 | openai = "^1.52.0" 24 | premai = "^0.3.73" 25 | django = "^5.1.2" 26 | djangorestframework = "^3.15.2" 27 | drf-yasg = "^1.21.8" 28 | func_timeout = "^4.3.5" 29 | matplotlib = "^3.9.2" 30 | pillow = ">=8,<11" 31 | uvicorn = "^0.32.0" 32 | streamlit = "^1.40.0" 33 | kagglehub = "^0.3.3" 34 | 35 | [tool.poetry.extras] 36 | mac = ["mlx", "mlx-lm"] 37 | 38 | [tool.poetry.group.mac] 39 | optional = true 40 | 41 | [tool.poetry.group.mac.dependencies] 42 | mlx = "^0.19.1" 43 | mlx-lm = "^0.19.2" 44 | 45 | [tool.poetry.group.linux.dependencies] 46 | transformers = "^4.43.3" 47 | torch = "^2.4.0" 48 | 49 | [tool.poetry.group.windows.dependencies] 50 | transformers = "^4.43.3" 51 | torch = "^2.4.0" 52 | 53 | [build-system] 54 | requires = ["poetry-core"] 55 | build-backend = "poetry.core.masonry.api" 56 | 57 | [tool.poetry.scripts] 58 | premsql = "premsql.cli:cli" 59 | --------------------------------------------------------------------------------
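
The two inference-server files shown earlier (premsql/playground/inference_server/api_client.py and service.py) define, respectively, the HTTP client wrapper and the FastAPI routes /completion, /chat_history/{message_id}, /session_info, / and /health. The following is a minimal usage sketch, not part of the repository: it assumes an AgentBase implementation is already constructed as `agent` (for instance one of the baseline agents under premsql/agents/baseline/), and the question text, port and timeout values are illustrative; the endpoint paths and response shapes are taken directly from service.py above.

# Illustrative sketch only. Assumes `agent` is an existing AgentBase implementation.
import requests

from premsql.playground.inference_server.service import AgentServer

# Server side: wrap the agent and serve it on localhost:8100.
# launch() blocks (it calls uvicorn.run), so run it in a separate process or terminal.
# server = AgentServer(agent=agent, url="localhost", port=8100)
# server.launch()

BASE_URL = "http://localhost:8100"

# /health returns {"status_code": 200, "status": "healthy"}
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# /completion expects {"question": ...} and returns
# {"message_id": ..., "message": <serialized AgentOutput>}
payload = {"question": "How many rows does the users table have?"}  # illustrative question
completion = requests.post(f"{BASE_URL}/completion", json=payload, timeout=600).json()
message_id = completion["message_id"]

# /chat_history/{message_id} replays the same turn from the agent's session history
history = requests.get(f"{BASE_URL}/chat_history/{message_id}", timeout=10).json()
print(history["agent_output"])

# /session_info reports the session name, DB connection URI and session DB path
print(requests.get(f"{BASE_URL}/session_info", timeout=10).json())

Note that CompletionResponse carries the message_id alongside the serialized AgentOutput; that identifier is what makes the later /chat_history/{message_id} lookup possible, and it is the same value validated as a positive integer by the client's get_chat_history method.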