├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── python-publish.yaml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── img ├── top-10-customers.png └── vanna-readme-diagram.png ├── papers ├── ai-sql-accuracy-2023-08-17.md └── img │ ├── accuracy-by-llm.png │ ├── accuracy-using-contextual-examples.png │ ├── accuracy-using-schema-only.png │ ├── accuracy-using-static-examples.png │ ├── chat-gpt-question.png │ ├── chatgpt-results.png │ ├── framework-for-sql-generation.png │ ├── question-flow.png │ ├── schema-only.png │ ├── sql-error.png │ ├── summary-table.png │ ├── summary.png │ ├── test-architecture.png │ ├── test-levers.png │ ├── using-contextually-relevant-examples.png │ └── using-sql-examples.png ├── pyproject.toml ├── setup.cfg ├── src ├── .editorconfig └── vanna │ ├── ZhipuAI │ ├── ZhipuAI_Chat.py │ ├── ZhipuAI_embeddings.py │ └── __init__.py │ ├── __init__.py │ ├── advanced │ └── __init__.py │ ├── anthropic │ ├── __init__.py │ └── anthropic_chat.py │ ├── azuresearch │ ├── __init__.py │ └── azuresearch_vector.py │ ├── base │ ├── __init__.py │ └── base.py │ ├── bedrock │ ├── __init__.py │ └── bedrock_converse.py │ ├── chromadb │ ├── __init__.py │ └── chromadb_vector.py │ ├── cohere │ ├── __init__.py │ ├── cohere_chat.py │ └── cohere_embeddings.py │ ├── deepseek │ ├── __init__.py │ └── deepseek_chat.py │ ├── exceptions │ └── __init__.py │ ├── faiss │ ├── __init__.py │ └── faiss.py │ ├── flask │ ├── __init__.py │ ├── assets.py │ └── auth.py │ ├── google │ ├── __init__.py │ ├── bigquery_vector.py │ └── gemini_chat.py │ ├── hf │ ├── __init__.py │ └── hf.py │ ├── local.py │ ├── marqo │ ├── __init__.py │ └── marqo.py │ ├── milvus │ ├── __init__.py │ └── milvus_vector.py │ ├── mistral │ ├── __init__.py │ └── mistral.py │ ├── mock │ ├── __init__.py │ ├── embedding.py │ ├── llm.py │ └── vectordb.py │ ├── ollama │ ├── __init__.py │ └── ollama.py │ ├── 
openai │ ├── __init__.py │ ├── openai_chat.py │ └── openai_embeddings.py │ ├── opensearch │ ├── __init__.py │ ├── opensearch_vector.py │ └── opensearch_vector_semantic.py │ ├── oracle │ ├── __init__.py │ └── oracle_vector.py │ ├── pgvector │ ├── __init__.py │ └── pgvector.py │ ├── pinecone │ ├── __init__.py │ └── pinecone_vector.py │ ├── qdrant │ ├── __init__.py │ └── qdrant.py │ ├── qianfan │ ├── Qianfan_Chat.py │ ├── Qianfan_embeddings.py │ └── __init__.py │ ├── qianwen │ ├── QianwenAI_chat.py │ ├── QianwenAI_embeddings.py │ └── __init__.py │ ├── remote.py │ ├── types │ └── __init__.py │ ├── utils.py │ ├── vannadb │ ├── __init__.py │ └── vannadb_vector.py │ ├── vllm │ ├── __init__.py │ └── vllm.py │ ├── weaviate │ ├── __init__.py │ └── weaviate_vector.py │ └── xinference │ ├── __init__.py │ └── xinference.py ├── tests ├── test_imports.py ├── test_instantiation.py ├── test_pgvector.py └── test_vanna.py ├── tox.ini └── training_data ├── cybersyn-data-commons └── questions.json ├── cybersyn-financial-data └── questions.json ├── cybersyn-us-global-public └── questions.json ├── fivetran-ads-snowflake └── questions.json ├── sample-fraud └── questions.json ├── sample-imdb └── questions.json ├── sample-retention └── questions.json ├── sample-salaries └── questions.json ├── similarweb └── questions.json ├── snowflake-cost └── questions.json └── tpc-h └── questions.json /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-detectable=false 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: ["bug"] 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 
12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Error logs/Screenshots** 24 | If applicable, add logs/screenshots to give more information about the issue. 25 | 26 | **Desktop (please complete the following information where):** 27 | - OS: [e.g. Ubuntu] 28 | - Version: [e.g. 20.04] 29 | - Python: [3.9] 30 | - Vanna: [2.8.0] 31 | 32 | **Additional context** 33 | Add any other context about the problem here. 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: ["enhancements"] 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Basic Integration Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 3.10 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.10" 20 | - name: Install pip 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install tox 24 | - name: Run tests 
25 | env: 26 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: python 27 | VANNA_API_KEY: ${{ secrets.VANNA_API_KEY }} 28 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 29 | MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} 30 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 31 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} 32 | SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }} 33 | SNOWFLAKE_USERNAME: ${{ secrets.SNOWFLAKE_USERNAME }} 34 | SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }} 35 | run: tox -e py310 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | **.egg-info 3 | venv 4 | .DS_Store 5 | notebooks/* 6 | tests/__pycache__ 7 | __pycache__/ 8 | .idea 9 | .coverage 10 | docs/*.html 11 | .ipynb_checkpoints/ 12 | .tox/ 13 | notebooks/chroma.sqlite3 14 | dist 15 | .env 16 | *.sqlite 17 | htmlcov 18 | chroma.sqlite3 19 | *.bin 20 | .coverage.* 21 | milvus.db 22 | .milvus.db.lock 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'docs|node_modules|migrations|.git|.tox|assets.py' 2 | default_stages: [ commit ] 3 | fail_fast: true 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v3.2.0 8 | hooks: 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-merge-conflict 12 | - id: debug-statements 13 | - id: mixed-line-ending 14 | 15 | - repo: https://github.com/pycqa/isort 16 | rev: 5.12.0 17 | hooks: 18 | - id: isort 19 | args: [ "--profile", "black", "--filter-files" ] 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | ## Setup 4 | ```bash 5 | git clone 
https://github.com/vanna-ai/vanna.git 6 | cd vanna/ 7 | 8 | python3 -m venv venv 9 | source venv/bin/activate 10 | 11 | # install package in editable mode 12 | pip install -e '.[all]' tox pre-commit 13 | 14 | # Setup pre-commit hooks 15 | pre-commit install 16 | 17 | # List dev targets 18 | tox list 19 | 20 | # Run tests 21 | tox -e py310 22 | ``` 23 | 24 | ## Running the test on a Mac 25 | ```bash 26 | tox -e mac 27 | ``` 28 | 29 | ## Do this before you submit a PR: 30 | 31 | Find the most relevant sample notebook and then replace the install command with: 32 | 33 | ```bash 34 | %pip install 'git+https://github.com/vanna-ai/vanna@your-branch#egg=vanna[chromadb,snowflake,openai]' 35 | ``` 36 | 37 | Run the necessary cells and verify that it works as expected in a real-world scenario. 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Vanna.AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | | GitHub | PyPI | Documentation | Gurubase | 4 | | ------ | ---- | ------------- | -------- | 5 | | [![GitHub](https://img.shields.io/badge/GitHub-vanna-blue?logo=github)](https://github.com/vanna-ai/vanna) | [![PyPI](https://img.shields.io/pypi/v/vanna?logo=pypi)](https://pypi.org/project/vanna/) | [![Documentation](https://img.shields.io/badge/Documentation-vanna-blue?logo=read-the-docs)](https://vanna.ai/docs/) | [![Gurubase](https://img.shields.io/badge/Gurubase-Ask%20Vanna%20Guru-006BFF)](https://gurubase.io/g/vanna) | 6 | 7 | # Vanna 8 | Vanna is an MIT-licensed open-source Python RAG (Retrieval-Augmented Generation) framework for SQL generation and related functionality. 9 | 10 | https://github.com/vanna-ai/vanna/assets/7146154/1901f47a-515d-4982-af50-f12761a3b2ce 11 | 12 | ![vanna-quadrants](https://github.com/vanna-ai/vanna/assets/7146154/1c7c88ba-c144-4ecf-a028-cf5ba7344ca2) 13 | 14 | ## How Vanna works 15 | 16 | ![Screen Recording 2024-01-24 at 11 21 37 AM](https://github.com/vanna-ai/vanna/assets/7146154/1d2718ad-12a8-4a76-afa2-c61754462f93) 17 | 18 | 19 | Vanna works in two easy steps - train a RAG "model" on your data, and then ask questions which will return SQL queries that can be set up to automatically run on your database. 20 | 21 | 1. **Train a RAG "model" on your data**. 22 | 2. **Ask questions**. 23 | 24 | ![](img/vanna-readme-diagram.png) 25 | 26 | If you don't know what RAG is, don't worry -- you don't need to know how this works under the hood to use it. 
You just need to know that you "train" a model, which stores some metadata and then use it to "ask" questions. 27 | 28 | See the [base class](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) for more details on how this works under the hood. 29 | 30 | ## User Interfaces 31 | These are some of the user interfaces that we've built using Vanna. You can use these as-is or as a starting point for your own custom interface. 32 | 33 | - [Jupyter Notebook](https://vanna.ai/docs/postgres-openai-vanna-vannadb/) 34 | - [vanna-ai/vanna-streamlit](https://github.com/vanna-ai/vanna-streamlit) 35 | - [vanna-ai/vanna-flask](https://github.com/vanna-ai/vanna-flask) 36 | - [vanna-ai/vanna-slack](https://github.com/vanna-ai/vanna-slack) 37 | 38 | ## Supported LLMs 39 | 40 | - [OpenAI](https://github.com/vanna-ai/vanna/tree/main/src/vanna/openai) 41 | - [Anthropic](https://github.com/vanna-ai/vanna/tree/main/src/vanna/anthropic) 42 | - [Gemini](https://github.com/vanna-ai/vanna/blob/main/src/vanna/google/gemini_chat.py) 43 | - [HuggingFace](https://github.com/vanna-ai/vanna/blob/main/src/vanna/hf/hf.py) 44 | - [AWS Bedrock](https://github.com/vanna-ai/vanna/tree/main/src/vanna/bedrock) 45 | - [Ollama](https://github.com/vanna-ai/vanna/tree/main/src/vanna/ollama) 46 | - [Qianwen](https://github.com/vanna-ai/vanna/tree/main/src/vanna/qianwen) 47 | - [Qianfan](https://github.com/vanna-ai/vanna/tree/main/src/vanna/qianfan) 48 | - [Zhipu](https://github.com/vanna-ai/vanna/tree/main/src/vanna/ZhipuAI) 49 | 50 | ## Supported VectorStores 51 | 52 | - [AzureSearch](https://github.com/vanna-ai/vanna/tree/main/src/vanna/azuresearch) 53 | - [Opensearch](https://github.com/vanna-ai/vanna/tree/main/src/vanna/opensearch) 54 | - [PgVector](https://github.com/vanna-ai/vanna/tree/main/src/vanna/pgvector) 55 | - [PineCone](https://github.com/vanna-ai/vanna/tree/main/src/vanna/pinecone) 56 | - [ChromaDB](https://github.com/vanna-ai/vanna/tree/main/src/vanna/chromadb) 57 | - 
[FAISS](https://github.com/vanna-ai/vanna/tree/main/src/vanna/faiss) 58 | - [Marqo](https://github.com/vanna-ai/vanna/tree/main/src/vanna/marqo) 59 | - [Milvus](https://github.com/vanna-ai/vanna/tree/main/src/vanna/milvus) 60 | - [Qdrant](https://github.com/vanna-ai/vanna/tree/main/src/vanna/qdrant) 61 | - [Weaviate](https://github.com/vanna-ai/vanna/tree/main/src/vanna/weaviate) 62 | - [Oracle](https://github.com/vanna-ai/vanna/tree/main/src/vanna/oracle) 63 | 64 | ## Supported Databases 65 | 66 | - [PostgreSQL](https://www.postgresql.org/) 67 | - [MySQL](https://www.mysql.com/) 68 | - [PrestoDB](https://prestodb.io/) 69 | - [Apache Hive](https://hive.apache.org/) 70 | - [ClickHouse](https://clickhouse.com/) 71 | - [Snowflake](https://www.snowflake.com/en/) 72 | - [Oracle](https://www.oracle.com/) 73 | - [Microsoft SQL Server](https://www.microsoft.com/en-us/sql-server/sql-server-downloads) 74 | - [BigQuery](https://cloud.google.com/bigquery) 75 | - [SQLite](https://www.sqlite.org/) 76 | - [DuckDB](https://duckdb.org/) 77 | 78 | 79 | ## Getting started 80 | See the [documentation](https://vanna.ai/docs/) for specifics on your desired database, LLM, etc. 81 | 82 | If you want to get a feel for how it works after training, you can try this [Colab notebook](https://vanna.ai/docs/app/). 83 | 84 | 85 | ### Install 86 | ```bash 87 | pip install vanna 88 | ``` 89 | 90 | There are a number of optional packages that can be installed so see the [documentation](https://vanna.ai/docs/) for more details. 91 | 92 | ### Import 93 | See the [documentation](https://vanna.ai/docs/) if you're customizing the LLM or vector database. 94 | 95 | ```python 96 | # The import statement will vary depending on your LLM and vector database. 
This is an example for OpenAI + ChromaDB 97 | 98 | from vanna.openai.openai_chat import OpenAI_Chat 99 | from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore 100 | 101 | class MyVanna(ChromaDB_VectorStore, OpenAI_Chat): 102 | def __init__(self, config=None): 103 | ChromaDB_VectorStore.__init__(self, config=config) 104 | OpenAI_Chat.__init__(self, config=config) 105 | 106 | vn = MyVanna(config={'api_key': 'sk-...', 'model': 'gpt-4-...'}) 107 | 108 | # See the documentation for other options 109 | 110 | ``` 111 | 112 | 113 | ## Training 114 | You may or may not need to run these `vn.train` commands depending on your use case. See the [documentation](https://vanna.ai/docs/) for more details. 115 | 116 | These statements are shown to give you a feel for how it works. 117 | 118 | ### Train with DDL Statements 119 | DDL statements contain information about the table names, columns, data types, and relationships in your database. 120 | 121 | ```python 122 | vn.train(ddl=""" 123 | CREATE TABLE IF NOT EXISTS my-table ( 124 | id INT PRIMARY KEY, 125 | name VARCHAR(100), 126 | age INT 127 | ) 128 | """) 129 | ``` 130 | 131 | ### Train with Documentation 132 | Sometimes you may want to add documentation about your business terminology or definitions. 133 | 134 | ```python 135 | vn.train(documentation="Our business defines XYZ as ...") 136 | ``` 137 | 138 | ### Train with SQL 139 | You can also add SQL queries to your training data. This is useful if you have some queries already laying around. You can just copy and paste those from your editor to begin generating new SQL. 
140 | 141 | ```python 142 | vn.train(sql="SELECT name, age FROM my-table WHERE name = 'John Doe'") 143 | ``` 144 | 145 | 146 | ## Asking questions 147 | ```python 148 | vn.ask("What are the top 10 customers by sales?") 149 | ``` 150 | 151 | You'll get SQL 152 | ```sql 153 | SELECT c.c_name as customer_name, 154 | sum(l.l_extendedprice * (1 - l.l_discount)) as total_sales 155 | FROM snowflake_sample_data.tpch_sf1.lineitem l join snowflake_sample_data.tpch_sf1.orders o 156 | ON l.l_orderkey = o.o_orderkey join snowflake_sample_data.tpch_sf1.customer c 157 | ON o.o_custkey = c.c_custkey 158 | GROUP BY customer_name 159 | ORDER BY total_sales desc limit 10; 160 | ``` 161 | 162 | If you've connected to a database, you'll get the table: 163 |
164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 |
CUSTOMER_NAMETOTAL_SALES
0Customer#0001435006757566.0218
1Customer#0000952576294115.3340
2Customer#0000871156184649.5176
3Customer#0001311136080943.8305
4Customer#0001343806075141.9635
5Customer#0001038346059770.3232
6Customer#0000696826057779.0348
7Customer#0001020226039653.6335
8Customer#0000985876027021.5855
9Customer#0000646605905659.6159
225 |
226 | 227 | You'll also get an automated Plotly chart: 228 | ![](img/top-10-customers.png) 229 | 230 | ## RAG vs. Fine-Tuning 231 | RAG 232 | - Portable across LLMs 233 | - Easy to remove training data if any of it becomes obsolete 234 | - Much cheaper to run than fine-tuning 235 | - More future-proof -- if a better LLM comes out, you can just swap it out 236 | 237 | Fine-Tuning 238 | - Good if you need to minimize tokens in the prompt 239 | - Slow to get started 240 | - Expensive to train and run (generally) 241 | 242 | ## Why Vanna? 243 | 244 | 1. **High accuracy on complex datasets.** 245 | - Vanna’s capabilities are tied to the training data you give it 246 | - More training data means better accuracy for large and complex datasets 247 | 2. **Secure and private.** 248 | - Your database contents are never sent to the LLM or the vector database 249 | - SQL execution happens in your local environment 250 | 3. **Self learning.** 251 | - If using via Jupyter, you can choose to "auto-train" it on the queries that were successfully executed 252 | - If using via other interfaces, you can have the interface prompt the user to provide feedback on the results 253 | - Correct question to SQL pairs are stored for future reference and make the future results more accurate 254 | 4. **Supports any SQL database.** 255 | - The package allows you to connect to any SQL database that you can otherwise connect to with Python 256 | 5. **Choose your front end.** 257 | - Most people start in a Jupyter Notebook. 258 | - Expose to your end users via Slackbot, web app, Streamlit app, or a custom front end. 259 | 260 | ## Extending Vanna 261 | Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. 
You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details. 262 | 263 | ## Vanna in 100 Seconds 264 | 265 | https://github.com/vanna-ai/vanna/assets/7146154/eb90ee1e-aa05-4740-891a-4fc10e611cab 266 | 267 | ## More resources 268 | - [Full Documentation](https://vanna.ai/docs/) 269 | - [Website](https://vanna.ai) 270 | - [Discord group for support](https://discord.gg/qUZYKHremx) 271 | -------------------------------------------------------------------------------- /img/top-10-customers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/img/top-10-customers.png -------------------------------------------------------------------------------- /img/vanna-readme-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/img/vanna-readme-diagram.png -------------------------------------------------------------------------------- /papers/img/accuracy-by-llm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/accuracy-by-llm.png -------------------------------------------------------------------------------- /papers/img/accuracy-using-contextual-examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/accuracy-using-contextual-examples.png -------------------------------------------------------------------------------- /papers/img/accuracy-using-schema-only.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/accuracy-using-schema-only.png -------------------------------------------------------------------------------- /papers/img/accuracy-using-static-examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/accuracy-using-static-examples.png -------------------------------------------------------------------------------- /papers/img/chat-gpt-question.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/chat-gpt-question.png -------------------------------------------------------------------------------- /papers/img/chatgpt-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/chatgpt-results.png -------------------------------------------------------------------------------- /papers/img/framework-for-sql-generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/framework-for-sql-generation.png -------------------------------------------------------------------------------- /papers/img/question-flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/question-flow.png -------------------------------------------------------------------------------- /papers/img/schema-only.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/schema-only.png -------------------------------------------------------------------------------- /papers/img/sql-error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/sql-error.png -------------------------------------------------------------------------------- /papers/img/summary-table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/summary-table.png -------------------------------------------------------------------------------- /papers/img/summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/summary.png -------------------------------------------------------------------------------- /papers/img/test-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/test-architecture.png -------------------------------------------------------------------------------- /papers/img/test-levers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/test-levers.png -------------------------------------------------------------------------------- /papers/img/using-contextually-relevant-examples.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/using-contextually-relevant-examples.png -------------------------------------------------------------------------------- /papers/img/using-sql-examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/using-sql-examples.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.2,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "vanna" 7 | version = "0.7.9" 8 | authors = [ 9 | { name="Zain Hoda", email="zain@vanna.ai" }, 10 | ] 11 | 12 | description = "Generate SQL queries from natural language" 13 | readme = "README.md" 14 | requires-python = ">=3.9" 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ] 20 | dependencies = [ 21 | "requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "flask-sock", "flasgger", "sqlalchemy" 22 | ] 23 | 24 | [project.urls] 25 | "Homepage" = "https://github.com/vanna-ai/vanna" 26 | "Bug Tracker" = "https://github.com/vanna-ai/vanna/issues" 27 | 28 | [project.optional-dependencies] 29 | postgres = ["psycopg2-binary", "db-dtypes"] 30 | mysql = ["PyMySQL"] 31 | clickhouse = ["clickhouse_connect"] 32 | bigquery = ["google-cloud-bigquery"] 33 | snowflake = ["snowflake-connector-python"] 34 | duckdb = ["duckdb"] 35 | google = ["google-generativeai", "google-cloud-aiplatform"] 36 | all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb<1.0.0", "anthropic", "zhipuai", 
"marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "langchain-community", "langchain-huggingface", "xinference-client"] 37 | test = ["tox"] 38 | chromadb = ["chromadb<1.0.0"] 39 | openai = ["openai"] 40 | qianfan = ["qianfan"] 41 | mistralai = ["mistralai>=1.0.0"] 42 | anthropic = ["anthropic"] 43 | gemini = ["google-generativeai"] 44 | marqo = ["marqo"] 45 | zhipuai = ["zhipuai"] 46 | ollama = ["ollama", "httpx"] 47 | qdrant = ["qdrant-client", "fastembed"] 48 | vllm = ["vllm"] 49 | pinecone = ["pinecone", "fastembed"] 50 | opensearch = ["opensearch-py", "opensearch-dsl", "langchain-community", "langchain-huggingface"] 51 | hf = ["transformers"] 52 | milvus = ["pymilvus[model]"] 53 | bedrock = ["boto3", "botocore"] 54 | weaviate = ["weaviate-client"] 55 | azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"] 56 | pgvector = ["langchain-postgres>=0.0.12"] 57 | faiss-cpu = ["faiss-cpu"] 58 | faiss-gpu = ["faiss-gpu"] 59 | xinference-client = ["xinference-client"] 60 | oracle = ["oracledb", "chromadb<1.0.0"] 61 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = BLK100,W503,E203,E722,F821,F841 3 | max-line-length = 100 4 | exclude = .tox,.git,docs,venv,jupyter_notebook_config.py,jupyter_lab_config.py,assets.py 5 | 6 | [tool:brunette] 7 | verbose = true 8 | single-quotes = false 9 | target-version = py39 10 | exclude = .tox,.git,docs,venv,assets.py 11 | -------------------------------------------------------------------------------- /src/.editorconfig: 
import re
from typing import List

import pandas as pd
from zhipuai import ZhipuAI

from ..base import VannaBase


class ZhipuAI_Chat(VannaBase):
    """VannaBase LLM backend for ZhipuAI's GLM chat models.

    Config keys:
        api_key (str, required): ZhipuAI API key.
        model (str, optional): chat model name. Defaults to "glm-4".
    """

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)
        if config is None:
            return
        if "api_key" not in config:
            raise Exception("Missing api_key in config")
        self.api_key = config["api_key"]
        self.model = config["model"] if "model" in config else "glm-4"
        # NOTE(review): this REST endpoint is not used anywhere in this class;
        # requests go through the zhipuai SDK client in submit_prompt. Kept so
        # existing code reading the attribute keeps working.
        self.api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"

    # --- chat-message formatting helpers ---------------------------------

    @staticmethod
    def system_message(message: str) -> dict:
        """Wrap text as a system-role chat message."""
        return {"role": "system", "content": message}

    @staticmethod
    def user_message(message: str) -> dict:
        """Wrap text as a user-role chat message."""
        return {"role": "user", "content": message}

    @staticmethod
    def assistant_message(message: str) -> dict:
        """Wrap text as an assistant-role chat message."""
        return {"role": "assistant", "content": message}

    @staticmethod
    def str_to_approx_token_count(string: str) -> int:
        """Approximate the token count of a string (~4 characters per token).

        Fix: previously returned a float despite the declared int return type.
        """
        return int(len(string) / 4)

    @staticmethod
    def add_ddl_to_prompt(
        initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
    ) -> str:
        """Append DDL statements to the prompt until the token budget is hit.

        Statements that would push the approximate token count past
        ``max_tokens`` are skipped (not truncated).
        """
        if len(ddl_list) > 0:
            initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

            for ddl in ddl_list:
                if (
                    ZhipuAI_Chat.str_to_approx_token_count(initial_prompt)
                    + ZhipuAI_Chat.str_to_approx_token_count(ddl)
                    < max_tokens
                ):
                    initial_prompt += f"{ddl}\n\n"

        return initial_prompt

    @staticmethod
    def add_documentation_to_prompt(
        initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
    ) -> str:
        """Append documentation snippets to the prompt within the token budget."""
        if len(documentation_List) > 0:
            initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

            for documentation in documentation_List:
                if (
                    ZhipuAI_Chat.str_to_approx_token_count(initial_prompt)
                    + ZhipuAI_Chat.str_to_approx_token_count(documentation)
                    < max_tokens
                ):
                    initial_prompt += f"{documentation}\n\n"

        return initial_prompt

    @staticmethod
    def add_sql_to_prompt(
        initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
    ) -> str:
        """Append question/SQL example pairs to the prompt within the token budget.

        ``sql_List`` items are dicts with "question" and "sql" keys; only the
        SQL text counts toward the budget check (matching prior behavior).
        """
        if len(sql_List) > 0:
            initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

            for question in sql_List:
                if (
                    ZhipuAI_Chat.str_to_approx_token_count(initial_prompt)
                    + ZhipuAI_Chat.str_to_approx_token_count(question["sql"])
                    < max_tokens
                ):
                    initial_prompt += f"{question['question']}\n{question['sql']}\n\n"

        return initial_prompt

    def get_sql_prompt(
        self,
        question: str,
        question_sql_list: List,
        ddl_list: List,
        doc_list: List,
        **kwargs,
    ):
        """Build the message log used to ask the LLM for SQL.

        Returns a list of chat messages: a system prompt with DDL and
        documentation context, few-shot question/SQL examples, then the
        user's question.
        """
        initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"

        initial_prompt = ZhipuAI_Chat.add_ddl_to_prompt(
            initial_prompt, ddl_list, max_tokens=14000
        )

        initial_prompt = ZhipuAI_Chat.add_documentation_to_prompt(
            initial_prompt, doc_list, max_tokens=14000
        )

        message_log = [ZhipuAI_Chat.system_message(initial_prompt)]

        # Fix: the original nested a redundant second None-check inside the
        # else branch; a single if/elif covers both cases.
        for example in question_sql_list:
            if example is None:
                print("example is None")
            elif "question" in example and "sql" in example:
                message_log.append(ZhipuAI_Chat.user_message(example["question"]))
                message_log.append(ZhipuAI_Chat.assistant_message(example["sql"]))

        message_log.append({"role": "user", "content": question})

        return message_log

    def get_followup_questions_prompt(
        self,
        question: str,
        df: pd.DataFrame,
        question_sql_list: List,
        ddl_list: List,
        doc_list: List,
        **kwargs,
    ):
        """Build the message log used to ask the LLM for follow-up questions."""
        initial_prompt = f"The user initially asked the question: '{question}': \n\n"

        initial_prompt = ZhipuAI_Chat.add_ddl_to_prompt(
            initial_prompt, ddl_list, max_tokens=14000
        )

        initial_prompt = ZhipuAI_Chat.add_documentation_to_prompt(
            initial_prompt, doc_list, max_tokens=14000
        )

        initial_prompt = ZhipuAI_Chat.add_sql_to_prompt(
            initial_prompt, question_sql_list, max_tokens=14000
        )

        message_log = [ZhipuAI_Chat.system_message(initial_prompt)]
        message_log.append(
            ZhipuAI_Chat.user_message(
                "Generate a List of followup questions that the user might ask about this data. Respond with a List of questions, one per line. Do not answer with any explanations -- just the questions."
            )
        )

        return message_log

    def generate_question(self, sql: str, **kwargs) -> str:
        """Ask the LLM to guess the business question a SQL query answers."""
        response = self.submit_prompt(
            [
                self.system_message(
                    "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
                ),
                self.user_message(sql),
            ],
            **kwargs,
        )

        return response

    def _extract_python_code(self, markdown_string: str) -> str:
        """Extract the first fenced code block from a markdown string.

        Falls back to returning the input unchanged when no fenced block is
        found (the model may have answered with bare code).
        """
        # Matches ```python ...``` fences first, then any generic ``` fence.
        pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"

        matches = re.findall(pattern, markdown_string, re.IGNORECASE)

        python_code = []
        for match in matches:
            python = match[0] if match[0] else match[1]
            python_code.append(python.strip())

        if len(python_code) == 0:
            return markdown_string

        return python_code[0]

    def _sanitize_plotly_code(self, raw_plotly_code: str) -> str:
        """Strip fig.show() so the generated chart code can run headless."""
        plotly_code = raw_plotly_code.replace("fig.show()", "")

        return plotly_code

    def generate_plotly_code(
        self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs
    ) -> str:
        """Ask the LLM for Plotly code that charts a query-result DataFrame."""
        if question is not None:
            system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'"
        else:
            system_msg = "The following is a pandas DataFrame "

        if sql is not None:
            system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n"

        system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}"

        message_log = [
            self.system_message(system_msg),
            self.user_message(
                "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
            ),
        ]

        # Fix: the original passed kwargs=kwargs, which sent the dict as a
        # single keyword argument named "kwargs" instead of forwarding it.
        plotly_code = self.submit_prompt(message_log, **kwargs)

        return self._sanitize_plotly_code(self._extract_python_code(plotly_code))

    def submit_prompt(
        self, prompt, max_tokens=500, temperature=0.7, top_p=0.7, stop=None, **kwargs
    ):
        """Send a chat message log to ZhipuAI and return the reply text.

        Raises:
            Exception: if the prompt is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        client = ZhipuAI(api_key=self.api_key)
        response = client.chat.completions.create(
            # Fix: the model name was hard-coded to "glm-4", silently ignoring
            # the "model" value supplied in the config.
            model=self.model,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            messages=prompt,
        )

        return response.choices[0].message.content
""" 8 | [future functionality] This function is used to generate embeddings from ZhipuAI. 9 | 10 | Args: 11 | VannaBase (_type_): _description_ 12 | """ 13 | def __init__(self, config=None): 14 | VannaBase.__init__(self, config=config) 15 | if "api_key" not in config: 16 | raise Exception("Missing api_key in config") 17 | self.api_key = config["api_key"] 18 | self.client = ZhipuAI(api_key=self.api_key) 19 | 20 | def generate_embedding(self, data: str, **kwargs) -> List[float]: 21 | 22 | embedding = self.client.embeddings.create( 23 | model="embedding-2", 24 | input=data, 25 | ) 26 | 27 | return embedding.data[0].embedding 28 | 29 | 30 | 31 | class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]): 32 | """ 33 | A embeddingFunction that uses ZhipuAI to generate embeddings which can use in chromadb. 34 | usage: 35 | class MyVanna(ChromaDB_VectorStore, ZhipuAI_Chat): 36 | def __init__(self, config=None): 37 | ChromaDB_VectorStore.__init__(self, config=config) 38 | ZhipuAI_Chat.__init__(self, config=config) 39 | 40 | config={'api_key': 'xxx'} 41 | zhipu_embedding_function = ZhipuAIEmbeddingFunction(config=config) 42 | config = {"api_key": "xxx", "model": "glm-4","path":"xy","embedding_function":zhipu_embedding_function} 43 | 44 | vn = MyVanna(config) 45 | 46 | """ 47 | def __init__(self, config=None): 48 | if config is None or "api_key" not in config: 49 | raise ValueError("Missing 'api_key' in config") 50 | 51 | self.api_key = config["api_key"] 52 | self.model_name = config.get("model_name", "embedding-2") 53 | 54 | try: 55 | self.client = ZhipuAI(api_key=self.api_key) 56 | except Exception as e: 57 | raise ValueError(f"Error initializing ZhipuAI client: {e}") 58 | 59 | def __call__(self, input: Documents) -> Embeddings: 60 | # Replace newlines, which can negatively affect performance. 
61 | input = [t.replace("\n", " ") for t in input] 62 | all_embeddings = [] 63 | print(f"Generating embeddings for {len(input)} documents") 64 | 65 | # Iterating over each document for individual API calls 66 | for document in input: 67 | try: 68 | response = self.client.embeddings.create( 69 | model=self.model_name, 70 | input=document 71 | ) 72 | # print(response) 73 | embedding = response.data[0].embedding 74 | all_embeddings.append(embedding) 75 | # print(f"Cost required: {response.usage.total_tokens}") 76 | except Exception as e: 77 | raise ValueError(f"Error generating embedding for document: {e}") 78 | 79 | return all_embeddings -------------------------------------------------------------------------------- /src/vanna/ZhipuAI/__init__.py: -------------------------------------------------------------------------------- 1 | from .ZhipuAI_Chat import ZhipuAI_Chat 2 | from .ZhipuAI_embeddings import ZhipuAI_Embeddings, ZhipuAIEmbeddingFunction 3 | -------------------------------------------------------------------------------- /src/vanna/__init__.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import os 4 | from dataclasses import dataclass 5 | from typing import Callable, List, Tuple, Union 6 | 7 | import pandas as pd 8 | import requests 9 | import plotly.graph_objs 10 | 11 | from .exceptions import ( 12 | OTPCodeError, 13 | ValidationError, 14 | ) 15 | from .types import ( 16 | ApiKey, 17 | Status, 18 | TrainingData, 19 | UserEmail, 20 | UserOTP, 21 | ) 22 | from .utils import sanitize_model_name, validate_config_path 23 | 24 | api_key: Union[str, None] = None # API key for Vanna.AI 25 | 26 | fig_as_img: bool = False # Whether or not to return Plotly figures as images 27 | 28 | run_sql: Union[ 29 | Callable[[str], pd.DataFrame], None 30 | ] = None # Function to convert SQL to a Pandas DataFrame 31 | """ 32 | **Example** 33 | ```python 34 | vn.run_sql = lambda sql: pd.read_sql(sql, 
def error_deprecation():
    """Raise an Exception directing callers to the new VannaDefault API.

    The module-level functions below are deprecated shims that all route
    through this helper.
    """
    raise Exception("""
Please switch to the following method for initializing Vanna:

from vanna.remote import VannaDefault

api_key = # Your API key from https://vanna.ai/account/profile
vanna_model_name = # Your model name from https://vanna.ai/account/profile

vn = VannaDefault(model=vanna_model_name, api_key=api_key)
""")


def __unauthenticated_rpc_call(method, params):
    """POST a JSON-RPC-style request to the unauthenticated Vanna endpoint.

    Args:
        method (str): RPC method name.
        params (list): dataclass instances serialized to dicts for transport.

    Returns:
        dict: the decoded JSON response.
    """
    headers = {
        "Content-Type": "application/json",
    }
    data = {"method": method, "params": [__dataclass_to_dict(obj) for obj in params]}

    # Fix: a missing timeout could hang the interpreter indefinitely on a
    # stalled connection.
    response = requests.post(
        _unauthenticated_endpoint, headers=headers, data=json.dumps(data), timeout=30
    )
    return response.json()


def __dataclass_to_dict(obj):
    """Convert a dataclass instance to a plain dict for JSON serialization."""
    return dataclasses.asdict(obj)


def get_api_key(email: str, otp_code: Union[str, None] = None) -> str:
    """
    **Example:**
    ```python
    vn.get_api_key(email="my-email@example.com")
    ```

    Login to the Vanna.AI API.

    Args:
        email (str): The email address to login with.
        otp_code (Union[str, None]): The OTP code to login with. If None, an OTP code will be sent to the email address.

    Returns:
        str: The API key.
    """
    # A key in the environment wins over the interactive OTP flow.
    vanna_api_key = os.environ.get("VANNA_API_KEY", None)

    if vanna_api_key is not None:
        return vanna_api_key

    if email == "my-email@example.com":
        raise ValidationError(
            "Please replace 'my-email@example.com' with your email address."
        )

    if otp_code is None:
        params = [UserEmail(email=email)]

        d = __unauthenticated_rpc_call(method="send_otp", params=params)

        if "result" not in d:
            raise OTPCodeError("Error sending OTP code.")

        status = Status(**d["result"])

        if not status.success:
            raise OTPCodeError(f"Error sending OTP code: {status.message}")

        otp_code = input("Check your email for the code and enter it here: ")

    params = [UserOTP(email=email, otp=otp_code)]

    d = __unauthenticated_rpc_call(method="verify_otp", params=params)

    if "result" not in d:
        raise OTPCodeError("Error verifying OTP code.")

    key = ApiKey(**d["result"])

    if key is None:
        raise OTPCodeError("Error verifying OTP code.")

    # Fix: the original assigned to a local named api_key, which shadowed the
    # module-level `api_key` global without updating it; return directly.
    return key.key


# --- Deprecated module-level API: every function below raises via
# --- error_deprecation() directing users to vanna.remote.VannaDefault.

def set_api_key(key: str) -> None:
    error_deprecation()


def get_models() -> List[str]:
    error_deprecation()


def create_model(model: str, db_type: str) -> bool:
    error_deprecation()


def add_user_to_model(model: str, email: str, is_admin: bool) -> bool:
    error_deprecation()


def update_model_visibility(public: bool) -> bool:
    error_deprecation()


def set_model(model: str):
    error_deprecation()


def add_sql(
    question: str, sql: str, tag: Union[str, None] = "Manually Trained"
) -> bool:
    error_deprecation()


def add_ddl(ddl: str) -> bool:
    error_deprecation()


def add_documentation(documentation: str) -> bool:
    error_deprecation()
@dataclass
class TrainingPlanItem:
    """A single unit of training data inside a TrainingPlan.

    Attributes follow the dataclass field order used by callers that
    construct items positionally.
    """

    item_type: str   # one of ITEM_TYPE_SQL / ITEM_TYPE_DDL / ITEM_TYPE_IS
    item_group: str  # grouping key (e.g. database or schema name)
    item_name: str   # human-readable identifier for the item
    item_value: str  # the payload that would be trained on

    def __str__(self):
        if self.item_type == self.ITEM_TYPE_SQL:
            return f"Train on SQL: {self.item_group} {self.item_name}"
        elif self.item_type == self.ITEM_TYPE_DDL:
            return f"Train on DDL: {self.item_group} {self.item_name}"
        elif self.item_type == self.ITEM_TYPE_IS:
            return f"Train on Information Schema: {self.item_group} {self.item_name}"
        # Fix: unknown item types previously fell through and returned None,
        # which makes str(item) raise TypeError. Provide a generic fallback.
        return f"Train on {self.item_type}: {self.item_group} {self.item_name}"

    ITEM_TYPE_SQL = "sql"
    ITEM_TYPE_DDL = "ddl"
    ITEM_TYPE_IS = "is"


class TrainingPlan:
    """
    A class representing a training plan. You can see what's in it, and remove items from it that you don't want trained.

    **Example:**
    ```python
    plan = vn.get_training_plan()

    plan.get_summary()
    ```

    """

    _plan: List[TrainingPlanItem]

    def __init__(self, plan: List[TrainingPlanItem]):
        self._plan = plan

    def __str__(self):
        return "\n".join(self.get_summary())

    def __repr__(self):
        return self.__str__()

    def get_summary(self) -> List[str]:
        """
        **Example:**
        ```python
        plan = vn.get_training_plan()

        plan.get_summary()
        ```

        Get a summary of the training plan.

        Returns:
            List[str]: A list of strings describing the training plan.
        """

        return [f"{item}" for item in self._plan]

    def remove_item(self, item: str):
        """
        **Example:**
        ```python
        plan = vn.get_training_plan()

        plan.remove_item("Train on SQL: What is the average salary of employees?")
        ```

        Remove an item from the training plan.

        Args:
            item (str): The item to remove.
        """
        # Items are matched by their string rendering; only the first match
        # is removed.
        for plan_item in self._plan:
            if str(plan_item) == item:
                self._plan.remove(plan_item)
                break


# --- Deprecated module-level API: these raise via error_deprecation().

def get_training_plan_postgres(
    filter_databases: Union[List[str], None] = None,
    filter_schemas: Union[List[str], None] = None,
    include_information_schema: bool = False,
    use_historical_queries: bool = True,
) -> TrainingPlan:
    error_deprecation()


def get_training_plan_generic(df) -> TrainingPlan:
    error_deprecation()


def get_training_plan_experimental(
    filter_databases: Union[List[str], None] = None,
    filter_schemas: Union[List[str], None] = None,
    include_information_schema: bool = False,
    use_historical_queries: bool = True,
) -> TrainingPlan:
    error_deprecation()


def train(
    question: str = None,
    sql: str = None,
    ddl: str = None,
    documentation: str = None,
    json_file: str = None,
    sql_file: str = None,
    plan: TrainingPlan = None,
) -> bool:
    error_deprecation()


def flag_sql_for_review(
    question: str, sql: Union[str, None] = None, error_msg: Union[str, None] = None
) -> bool:
    error_deprecation()


def remove_sql(question: str) -> bool:
    error_deprecation()


def remove_training_data(id: str) -> bool:
    error_deprecation()


def generate_sql(question: str) -> str:
    error_deprecation()


# NOTE(review): annotation quoted so the name need not be imported at
# definition time; resolves to vanna.types.TrainingData.
def get_related_training_data(question: str) -> "TrainingData":
    error_deprecation()


def generate_meta(question: str) -> str:
    error_deprecation()


def generate_followup_questions(question: str, df: pd.DataFrame) -> List[str]:
    error_deprecation()


def generate_questions() -> List[str]:
    error_deprecation()
# --- Deprecated module-level API (continued): each function raises via
# --- error_deprecation(), pointing users at vanna.remote.VannaDefault.

def generate_plotly_code(
    question: Union[str, None],
    sql: Union[str, None],
    df: pd.DataFrame,
    chart_instructions: Union[str, None] = None,
) -> str:
    # Deprecated: raises immediately.
    error_deprecation()


def get_plotly_figure(
    plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
) -> plotly.graph_objs.Figure:
    # Deprecated: raises immediately.
    error_deprecation()


def get_results(cs, default_database: str, sql: str) -> pd.DataFrame:
    # Deprecated: raises immediately.
    error_deprecation()


def generate_explanation(sql: str) -> str:
    # Deprecated: raises immediately.
    error_deprecation()


def generate_question(sql: str) -> str:
    # Deprecated: raises immediately.
    error_deprecation()


def get_all_questions() -> pd.DataFrame:
    # Deprecated: raises immediately.
    error_deprecation()


def get_training_data() -> pd.DataFrame:
    # Deprecated: raises immediately.
    error_deprecation()


def connect_to_sqlite(url: str):
    # Deprecated: raises immediately.
    error_deprecation()


def connect_to_snowflake(
    account: str,
    username: str,
    password: str,
    database: str,
    schema: Union[str, None] = None,
    role: Union[str, None] = None,
):
    # Deprecated: raises immediately.
    error_deprecation()


def connect_to_postgres(
    host: str = None,
    dbname: str = None,
    user: str = None,
    password: str = None,
    port: int = None,
):
    # Deprecated: raises immediately.
    error_deprecation()


def connect_to_bigquery(cred_file_path: str = None, project_id: str = None):
    # Deprecated: raises immediately.
    error_deprecation()


def connect_to_duckdb(url: str = "memory", init_sql: str = None):
    # Deprecated: raises immediately.
    error_deprecation()
from abc import ABC, abstractmethod


class VannaAdvanced(ABC):
    """Abstract interface for backends that manage reusable Vanna functions.

    Implementations store, retrieve, update, and delete named functions
    (question + SQL + plotly code bundles).
    """

    def __init__(self, config=None):
        self.config = config

    @abstractmethod
    def get_function(self, question: str, additional_data: dict = {}) -> dict:
        """Look up the stored function matching a question."""
        pass

    @abstractmethod
    def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
        """Create and store a new function from its parts."""
        pass

    @abstractmethod
    def update_function(self, old_function_name: str, updated_function: dict) -> bool:
        """Replace an existing function; returns success."""
        pass

    @abstractmethod
    def delete_function(self, function_name: str) -> bool:
        """Delete a function by name; returns success."""
        pass

    @abstractmethod
    def get_all_functions(self) -> list:
        """Return every stored function."""
        pass


import os

import anthropic

from ..base import VannaBase


class Anthropic_Chat(VannaBase):
    """VannaBase LLM backend for Anthropic Claude models.

    Config keys:
        api_key (str, optional): Anthropic API key; falls back to the
            ANTHROPIC_API_KEY environment variable.
        model (str): model name used by submit_prompt.
        temperature (float, optional): sampling temperature. Default 0.7.
        max_tokens (int, optional): response token cap. Default 500.
    """

    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        # default parameters - can be overridden using config
        self.temperature = 0.7
        self.max_tokens = 500

        # Fix: the original tested `"temperature" in config` before checking
        # for a None config, so constructing with only a client raised
        # TypeError instead of using the defaults.
        if config is not None:
            if "temperature" in config:
                self.temperature = config["temperature"]
            if "max_tokens" in config:
                self.max_tokens = config["max_tokens"]

        if client is not None:
            self.client = client
            return

        if config is not None and "api_key" in config:
            self.client = anthropic.Anthropic(api_key=config["api_key"])
        else:
            # Fix: a config without an api_key previously left self.client
            # unset entirely; fall back to the environment variable as the
            # no-config path already did.
            self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

    def system_message(self, message: str) -> any:
        """Wrap text as a system-role chat message."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """Wrap text as a user-role chat message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """Wrap text as an assistant-role chat message."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a chat message log to the Anthropic Messages API.

        The system message is extracted from the log and passed via the
        dedicated `system` field, as required by the API.

        Raises:
            Exception: if the prompt is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log
        # Use 4 as an approximation for the number of characters per token
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        if self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )

        # Claude requires the system message as a single top-level field:
        # https://docs.anthropic.com/claude/reference/messages_post
        system_message = ''
        no_system_prompt = []
        for prompt_message in prompt:
            role = prompt_message['role']
            if role == 'system':
                system_message = prompt_message['content']
            else:
                no_system_prompt.append({"role": role, "content": prompt_message['content']})

        response = self.client.messages.create(
            model=self.config["model"],
            messages=no_system_prompt,
            system=system_message,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
        )

        return response.content[0].text
import json
from typing import List

import pandas as pd
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
    ExhaustiveKnnAlgorithmConfiguration,
    ExhaustiveKnnParameters,
    SearchableField,
    SearchField,
    SearchFieldDataType,
    SearchIndex,
    VectorSearch,
    VectorSearchAlgorithmKind,
    VectorSearchAlgorithmMetric,
    VectorSearchProfile,
)
from azure.search.documents.models import VectorFilterMode, VectorizedQuery
from fastembed import TextEmbedding

from ..base import VannaBase
from ..utils import deterministic_uuid


class AzureAISearch_VectorStore(VannaBase):
    """
    AzureAISearch_VectorStore is a class that provides a vector store for Azure AI Search.

    Args:
      config (dict): Configuration dictionary. Defaults to {}. You must provide an API key in the config.
        - azure_search_endpoint (str, optional): Azure Search endpoint. Defaults to "https://azcognetive.search.windows.net".
        - azure_search_api_key (str): Azure Search API key.
        - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
        - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
        - index_name (str, optional): Name of the index. Defaults to "vanna-index".
        - n_results (int, optional): Number of results to return. Defaults to 10.
        - n_results_ddl (int, optional): Number of results to return for DDL queries. Defaults to the value of n_results.
        - n_results_sql (int, optional): Number of results to return for SQL queries. Defaults to the value of n_results.
        - n_results_documentation (int, optional): Number of results to return for documentation queries. Defaults to the value of n_results.

    Raises:
      ValueError: If config is None, or if 'azure_search_api_key' is not provided in the config.
    """

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

        self.config = config

        if config is None:
            raise ValueError(
                "config is required, pass an API key, 'azure_search_api_key', in the config."
            )

        azure_search_endpoint = config.get("azure_search_endpoint", "https://azcognetive.search.windows.net")
        azure_search_api_key = config.get("azure_search_api_key")

        self.dimensions = config.get("dimensions", 384)
        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")

        self.index_name = config.get("index_name", "vanna-index")

        # Per-type result limits, each falling back to the shared n_results.
        self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
        self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
        self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))

        if not azure_search_api_key:
            raise ValueError(
                "'azure_search_api_key' is required in config to use AzureAISearch_VectorStore"
            )

        # Lazily-initialized fastembed model; see generate_embedding.
        self._embedding_model = None

        self.index_client = SearchIndexClient(
            endpoint=azure_search_endpoint,
            credential=AzureKeyCredential(azure_search_api_key)
        )

        self.search_client = SearchClient(
            endpoint=azure_search_endpoint,
            index_name=self.index_name,
            credential=AzureKeyCredential(azure_search_api_key)
        )

        if self.index_name not in self._get_indexes():
            self._create_index()

    def _create_index(self) -> bool:
        """Create (or update) the search index with an exhaustive-KNN vector profile."""
        fields = [
            SearchableField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
            SearchableField(name="document", type=SearchFieldDataType.String, searchable=True, filterable=True),
            SearchField(name="type", type=SearchFieldDataType.String, filterable=True, searchable=True),
            SearchField(name="document_vector", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), searchable=True, vector_search_dimensions=self.dimensions, vector_search_profile_name="ExhaustiveKnnProfile"),
        ]

        vector_search = VectorSearch(
            algorithms=[
                ExhaustiveKnnAlgorithmConfiguration(
                    name="ExhaustiveKnn",
                    kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
                    parameters=ExhaustiveKnnParameters(
                        metric=VectorSearchAlgorithmMetric.COSINE
                    )
                )
            ],
            profiles=[
                VectorSearchProfile(
                    name="ExhaustiveKnnProfile",
                    algorithm_configuration_name="ExhaustiveKnn",
                )
            ]
        )

        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
        result = self.index_client.create_or_update_index(index)
        print(f'{result.name} created')
        # Fix: the method was annotated -> bool but returned None.
        return True

    def _get_indexes(self) -> list:
        """Return the names of all indexes visible to this client."""
        return list(self.index_client.list_index_names())

    def add_ddl(self, ddl: str) -> str:
        """Store a DDL statement with its embedding; returns the document id."""
        id = deterministic_uuid(ddl) + "-ddl"
        document = {
            "id": id,
            "document": ddl,
            "type": "ddl",
            "document_vector": self.generate_embedding(ddl)
        }
        self.search_client.upload_documents(documents=[document])
        return id

    def add_documentation(self, doc: str) -> str:
        """Store a documentation snippet with its embedding; returns the document id."""
        id = deterministic_uuid(doc) + "-doc"
        document = {
            "id": id,
            "document": doc,
            "type": "doc",
            "document_vector": self.generate_embedding(doc)
        }
        self.search_client.upload_documents(documents=[document])
        return id

    def add_question_sql(self, question: str, sql: str) -> str:
        """Store a question/SQL pair (as JSON) with its embedding; returns the id."""
        question_sql_json = json.dumps({"question": question, "sql": sql}, ensure_ascii=False)
        id = deterministic_uuid(question_sql_json) + "-sql"
        document = {
            "id": id,
            "document": question_sql_json,
            "type": "sql",
            "document_vector": self.generate_embedding(question_sql_json)
        }
        self.search_client.upload_documents(documents=[document])
        return id

    def get_related_ddl(self, text: str) -> List[str]:
        """Vector-search stored DDL statements related to the given text."""
        result = []
        vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
        df = pd.DataFrame(
            self.search_client.search(
                top=self.n_results_ddl,
                vector_queries=[vector_query],
                select=["id", "document", "type"],
                filter="type eq 'ddl'"
            )
        )

        if len(df):
            result = df["document"].tolist()
        return result

    def get_related_documentation(self, text: str) -> List[str]:
        """Vector-search stored documentation related to the given text."""
        result = []
        vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")

        df = pd.DataFrame(
            self.search_client.search(
                top=self.n_results_documentation,
                vector_queries=[vector_query],
                select=["id", "document", "type"],
                filter="type eq 'doc'",
                vector_filter_mode=VectorFilterMode.PRE_FILTER
            )
        )

        if len(df):
            result = df["document"].tolist()
        return result

    def get_similar_question_sql(self, question: str) -> List[str]:
        """Vector-search stored question/SQL pairs similar to the question."""
        result = []
        vector_query = VectorizedQuery(vector=self.generate_embedding(question), fields="document_vector")
        df = pd.DataFrame(
            self.search_client.search(
                top=self.n_results_sql,
                vector_queries=[vector_query],
                select=["id", "document", "type"],
                filter="type eq 'sql'"
            )
        )

        if len(df):  # Check if there is similar query and the result is not empty
            # Fix: documents are stored with json.dumps (see add_question_sql),
            # so decode with json.loads; ast.literal_eval can choke on valid
            # JSON tokens such as true/false/null.
            result = [json.loads(element) for element in df["document"].tolist()]

        return result

    def get_training_data(self) -> pd.DataFrame:
        """Return all stored training data as a DataFrame with
        id/question/content/type columns (question is NaN for non-SQL rows).

        Fix: the return annotation previously claimed List[str] while the
        method returns a DataFrame.
        """
        search = self.search_client.search(
            search_text="*",
            select=['id', 'document', 'type'],
            filter="(type eq 'sql') or (type eq 'ddl') or (type eq 'doc')"
        ).by_page()

        df = pd.DataFrame([item for page in search for item in page])

        if len(df):
            # SQL rows store a JSON {"question", "sql"} payload; unpack it.
            df.loc[df["type"] == "sql", "question"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["question"])
            df.loc[df["type"] == "sql", "content"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["sql"])
            df.loc[df["type"] != "sql", "content"] = df.loc[df["type"] != "sql"]["document"]

            return df[["id", "question", "content", "type"]]

        return pd.DataFrame()

    def remove_training_data(self, id: str) -> bool:
        """Delete a stored document by id; returns whether the delete succeeded."""
        result = self.search_client.delete_documents(documents=[{'id': id}])
        return result[0].succeeded

    def remove_index(self):
        """Drop the entire search index."""
        self.index_client.delete_index(self.index_name)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed a string with the configured fastembed model.

        Fix: the model was re-instantiated (re-loading weights) on every
        call; it is now created once and cached on the instance.
        """
        if getattr(self, "_embedding_model", None) is None:
            self._embedding_model = TextEmbedding(model_name=self.fastembed_model)
        embedding = next(self._embedding_model.embed(data))
        return embedding.tolist()
import boto3 5 | from botocore.exceptions import ClientError 6 | except ImportError: 7 | raise ImportError("Please install boto3 and botocore to use Amazon Bedrock models") 8 | 9 | class Bedrock_Converse(VannaBase): 10 | def __init__(self, client=None, config=None): 11 | VannaBase.__init__(self, config=config) 12 | 13 | # default parameters 14 | self.temperature = 0.0 15 | self.max_tokens = 500 16 | 17 | if client is None: 18 | raise ValueError( 19 | "A valid Bedrock runtime client must be provided to invoke Bedrock models" 20 | ) 21 | else: 22 | self.client = client 23 | 24 | if config is None: 25 | raise ValueError( 26 | "Config is required with model_id and inference parameters" 27 | ) 28 | 29 | if "modelId" not in config: 30 | raise ValueError( 31 | "config must contain a modelId to invoke" 32 | ) 33 | else: 34 | self.model = config["modelId"] 35 | 36 | if "temperature" in config: 37 | self.temperature = config["temperature"] 38 | 39 | if "max_tokens" in config: 40 | self.max_tokens = config["max_tokens"] 41 | 42 | def system_message(self, message: str) -> dict: 43 | return {"role": "system", "content": message} 44 | 45 | def user_message(self, message: str) -> dict: 46 | return {"role": "user", "content": message} 47 | 48 | def assistant_message(self, message: str) -> dict: 49 | return {"role": "assistant", "content": message} 50 | 51 | def submit_prompt(self, prompt, **kwargs) -> str: 52 | inference_config = { 53 | "temperature": self.temperature, 54 | "maxTokens": self.max_tokens 55 | } 56 | additional_model_fields = { 57 | "top_p": 1, # setting top_p value for nucleus sampling 58 | } 59 | 60 | system_message = None 61 | no_system_prompt = [] 62 | for prompt_message in prompt: 63 | role = prompt_message["role"] 64 | if role == "system": 65 | system_message = prompt_message["content"] 66 | else: 67 | no_system_prompt.append({"role": role, "content":[{"text": prompt_message["content"]}]}) 68 | 69 | converse_api_params = { 70 | "modelId": self.model, 71 | 
"messages": no_system_prompt, 72 | "inferenceConfig": inference_config, 73 | "additionalModelRequestFields": additional_model_fields 74 | } 75 | 76 | if system_message: 77 | converse_api_params["system"] = [{"text": system_message}] 78 | 79 | try: 80 | response = self.client.converse(**converse_api_params) 81 | text_content = response["output"]["message"]["content"][0]["text"] 82 | return text_content 83 | except ClientError as err: 84 | message = err.response["Error"]["Message"] 85 | raise Exception(f"A Bedrock client error occurred: {message}") -------------------------------------------------------------------------------- /src/vanna/chromadb/__init__.py: -------------------------------------------------------------------------------- 1 | from .chromadb_vector import ChromaDB_VectorStore 2 | -------------------------------------------------------------------------------- /src/vanna/chromadb/chromadb_vector.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | 4 | import chromadb 5 | import pandas as pd 6 | from chromadb.config import Settings 7 | from chromadb.utils import embedding_functions 8 | 9 | from ..base import VannaBase 10 | from ..utils import deterministic_uuid 11 | 12 | default_ef = embedding_functions.DefaultEmbeddingFunction() 13 | 14 | 15 | class ChromaDB_VectorStore(VannaBase): 16 | def __init__(self, config=None): 17 | VannaBase.__init__(self, config=config) 18 | if config is None: 19 | config = {} 20 | 21 | path = config.get("path", ".") 22 | self.embedding_function = config.get("embedding_function", default_ef) 23 | curr_client = config.get("client", "persistent") 24 | collection_metadata = config.get("collection_metadata", None) 25 | self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10)) 26 | self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10)) 27 | self.n_results_ddl = config.get("n_results_ddl", 
config.get("n_results", 10)) 28 | 29 | if curr_client == "persistent": 30 | self.chroma_client = chromadb.PersistentClient( 31 | path=path, settings=Settings(anonymized_telemetry=False) 32 | ) 33 | elif curr_client == "in-memory": 34 | self.chroma_client = chromadb.EphemeralClient( 35 | settings=Settings(anonymized_telemetry=False) 36 | ) 37 | elif isinstance(curr_client, chromadb.api.client.Client): 38 | # allow providing client directly 39 | self.chroma_client = curr_client 40 | else: 41 | raise ValueError(f"Unsupported client was set in config: {curr_client}") 42 | 43 | self.documentation_collection = self.chroma_client.get_or_create_collection( 44 | name="documentation", 45 | embedding_function=self.embedding_function, 46 | metadata=collection_metadata, 47 | ) 48 | self.ddl_collection = self.chroma_client.get_or_create_collection( 49 | name="ddl", 50 | embedding_function=self.embedding_function, 51 | metadata=collection_metadata, 52 | ) 53 | self.sql_collection = self.chroma_client.get_or_create_collection( 54 | name="sql", 55 | embedding_function=self.embedding_function, 56 | metadata=collection_metadata, 57 | ) 58 | 59 | def generate_embedding(self, data: str, **kwargs) -> List[float]: 60 | embedding = self.embedding_function([data]) 61 | if len(embedding) == 1: 62 | return embedding[0] 63 | return embedding 64 | 65 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str: 66 | question_sql_json = json.dumps( 67 | { 68 | "question": question, 69 | "sql": sql, 70 | }, 71 | ensure_ascii=False, 72 | ) 73 | id = deterministic_uuid(question_sql_json) + "-sql" 74 | self.sql_collection.add( 75 | documents=question_sql_json, 76 | embeddings=self.generate_embedding(question_sql_json), 77 | ids=id, 78 | ) 79 | 80 | return id 81 | 82 | def add_ddl(self, ddl: str, **kwargs) -> str: 83 | id = deterministic_uuid(ddl) + "-ddl" 84 | self.ddl_collection.add( 85 | documents=ddl, 86 | embeddings=self.generate_embedding(ddl), 87 | ids=id, 88 | ) 89 | return id 90 | 
91 | def add_documentation(self, documentation: str, **kwargs) -> str: 92 | id = deterministic_uuid(documentation) + "-doc" 93 | self.documentation_collection.add( 94 | documents=documentation, 95 | embeddings=self.generate_embedding(documentation), 96 | ids=id, 97 | ) 98 | return id 99 | 100 | def get_training_data(self, **kwargs) -> pd.DataFrame: 101 | sql_data = self.sql_collection.get() 102 | 103 | df = pd.DataFrame() 104 | 105 | if sql_data is not None: 106 | # Extract the documents and ids 107 | documents = [json.loads(doc) for doc in sql_data["documents"]] 108 | ids = sql_data["ids"] 109 | 110 | # Create a DataFrame 111 | df_sql = pd.DataFrame( 112 | { 113 | "id": ids, 114 | "question": [doc["question"] for doc in documents], 115 | "content": [doc["sql"] for doc in documents], 116 | } 117 | ) 118 | 119 | df_sql["training_data_type"] = "sql" 120 | 121 | df = pd.concat([df, df_sql]) 122 | 123 | ddl_data = self.ddl_collection.get() 124 | 125 | if ddl_data is not None: 126 | # Extract the documents and ids 127 | documents = [doc for doc in ddl_data["documents"]] 128 | ids = ddl_data["ids"] 129 | 130 | # Create a DataFrame 131 | df_ddl = pd.DataFrame( 132 | { 133 | "id": ids, 134 | "question": [None for doc in documents], 135 | "content": [doc for doc in documents], 136 | } 137 | ) 138 | 139 | df_ddl["training_data_type"] = "ddl" 140 | 141 | df = pd.concat([df, df_ddl]) 142 | 143 | doc_data = self.documentation_collection.get() 144 | 145 | if doc_data is not None: 146 | # Extract the documents and ids 147 | documents = [doc for doc in doc_data["documents"]] 148 | ids = doc_data["ids"] 149 | 150 | # Create a DataFrame 151 | df_doc = pd.DataFrame( 152 | { 153 | "id": ids, 154 | "question": [None for doc in documents], 155 | "content": [doc for doc in documents], 156 | } 157 | ) 158 | 159 | df_doc["training_data_type"] = "documentation" 160 | 161 | df = pd.concat([df, df_doc]) 162 | 163 | return df 164 | 165 | def remove_training_data(self, id: str, **kwargs) -> 
bool: 166 | if id.endswith("-sql"): 167 | self.sql_collection.delete(ids=id) 168 | return True 169 | elif id.endswith("-ddl"): 170 | self.ddl_collection.delete(ids=id) 171 | return True 172 | elif id.endswith("-doc"): 173 | self.documentation_collection.delete(ids=id) 174 | return True 175 | else: 176 | return False 177 | 178 | def remove_collection(self, collection_name: str) -> bool: 179 | """ 180 | This function can reset the collection to empty state. 181 | 182 | Args: 183 | collection_name (str): sql or ddl or documentation 184 | 185 | Returns: 186 | bool: True if collection is deleted, False otherwise 187 | """ 188 | if collection_name == "sql": 189 | self.chroma_client.delete_collection(name="sql") 190 | self.sql_collection = self.chroma_client.get_or_create_collection( 191 | name="sql", embedding_function=self.embedding_function 192 | ) 193 | return True 194 | elif collection_name == "ddl": 195 | self.chroma_client.delete_collection(name="ddl") 196 | self.ddl_collection = self.chroma_client.get_or_create_collection( 197 | name="ddl", embedding_function=self.embedding_function 198 | ) 199 | return True 200 | elif collection_name == "documentation": 201 | self.chroma_client.delete_collection(name="documentation") 202 | self.documentation_collection = self.chroma_client.get_or_create_collection( 203 | name="documentation", embedding_function=self.embedding_function 204 | ) 205 | return True 206 | else: 207 | return False 208 | 209 | @staticmethod 210 | def _extract_documents(query_results) -> list: 211 | """ 212 | Static method to extract the documents from the results of a query. 213 | 214 | Args: 215 | query_results (pd.DataFrame): The dataframe to use. 216 | 217 | Returns: 218 | List[str] or None: The extracted documents, or an empty list or 219 | single document if an error occurred. 
220 | """ 221 | if query_results is None: 222 | return [] 223 | 224 | if "documents" in query_results: 225 | documents = query_results["documents"] 226 | 227 | if len(documents) == 1 and isinstance(documents[0], list): 228 | try: 229 | documents = [json.loads(doc) for doc in documents[0]] 230 | except Exception as e: 231 | return documents[0] 232 | 233 | return documents 234 | 235 | def get_similar_question_sql(self, question: str, **kwargs) -> list: 236 | return ChromaDB_VectorStore._extract_documents( 237 | self.sql_collection.query( 238 | query_texts=[question], 239 | n_results=self.n_results_sql, 240 | ) 241 | ) 242 | 243 | def get_related_ddl(self, question: str, **kwargs) -> list: 244 | return ChromaDB_VectorStore._extract_documents( 245 | self.ddl_collection.query( 246 | query_texts=[question], 247 | n_results=self.n_results_ddl, 248 | ) 249 | ) 250 | 251 | def get_related_documentation(self, question: str, **kwargs) -> list: 252 | return ChromaDB_VectorStore._extract_documents( 253 | self.documentation_collection.query( 254 | query_texts=[question], 255 | n_results=self.n_results_documentation, 256 | ) 257 | ) 258 | -------------------------------------------------------------------------------- /src/vanna/cohere/__init__.py: -------------------------------------------------------------------------------- 1 | from .cohere_chat import Cohere_Chat 2 | from .cohere_embeddings import Cohere_Embeddings -------------------------------------------------------------------------------- /src/vanna/cohere/cohere_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from openai import OpenAI 4 | 5 | from ..base import VannaBase 6 | 7 | 8 | class Cohere_Chat(VannaBase): 9 | def __init__(self, client=None, config=None): 10 | VannaBase.__init__(self, config=config) 11 | 12 | # default parameters - can be overridden using config 13 | self.temperature = 0.2 # Lower temperature for more precise SQL generation 14 | 
self.model = "command-a-03-2025" # Cohere's default model 15 | 16 | if config is not None: 17 | if "temperature" in config: 18 | self.temperature = config["temperature"] 19 | if "model" in config: 20 | self.model = config["model"] 21 | 22 | if client is not None: 23 | self.client = client 24 | return 25 | 26 | # Check for API key in environment variable 27 | api_key = os.getenv("COHERE_API_KEY") 28 | 29 | # Check for API key in config 30 | if config is not None and "api_key" in config: 31 | api_key = config["api_key"] 32 | 33 | # Validate API key 34 | if not api_key: 35 | raise ValueError("Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable.") 36 | 37 | # Initialize client with validated API key 38 | self.client = OpenAI( 39 | base_url="https://api.cohere.ai/compatibility/v1", 40 | api_key=api_key, 41 | ) 42 | 43 | def system_message(self, message: str) -> any: 44 | return {"role": "developer", "content": message} # Cohere uses 'developer' for system role 45 | 46 | def user_message(self, message: str) -> any: 47 | return {"role": "user", "content": message} 48 | 49 | def assistant_message(self, message: str) -> any: 50 | return {"role": "assistant", "content": message} 51 | 52 | def submit_prompt(self, prompt, **kwargs) -> str: 53 | if prompt is None: 54 | raise Exception("Prompt is None") 55 | 56 | if len(prompt) == 0: 57 | raise Exception("Prompt is empty") 58 | 59 | # Count the number of tokens in the message log 60 | # Use 4 as an approximation for the number of characters per token 61 | num_tokens = 0 62 | for message in prompt: 63 | num_tokens += len(message["content"]) / 4 64 | 65 | # Use model from kwargs, config, or default 66 | model = kwargs.get("model", self.model) 67 | if self.config is not None and "model" in self.config and model == self.model: 68 | model = self.config["model"] 69 | 70 | print(f"Using model {model} for {num_tokens} tokens (approx)") 71 | try: 72 | response = 
self.client.chat.completions.create( 73 | model=model, 74 | messages=prompt, 75 | temperature=self.temperature, 76 | ) 77 | 78 | # Check if response has expected structure 79 | if not response or not hasattr(response, 'choices') or not response.choices: 80 | raise ValueError("Received empty or malformed response from API") 81 | 82 | if not response.choices[0] or not hasattr(response.choices[0], 'message'): 83 | raise ValueError("Response is missing expected 'message' field") 84 | 85 | if not hasattr(response.choices[0].message, 'content'): 86 | raise ValueError("Response message is missing expected 'content' field") 87 | 88 | return response.choices[0].message.content 89 | 90 | except Exception as e: 91 | # Log the error and raise a more informative exception 92 | error_msg = f"Error processing Cohere chat response: {str(e)}" 93 | print(error_msg) 94 | raise Exception(error_msg) -------------------------------------------------------------------------------- /src/vanna/cohere/cohere_embeddings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from openai import OpenAI 4 | 5 | from ..base import VannaBase 6 | 7 | 8 | class Cohere_Embeddings(VannaBase): 9 | def __init__(self, client=None, config=None): 10 | VannaBase.__init__(self, config=config) 11 | 12 | # Default embedding model 13 | self.model = "embed-multilingual-v3.0" 14 | 15 | if config is not None and "model" in config: 16 | self.model = config["model"] 17 | 18 | if client is not None: 19 | self.client = client 20 | return 21 | 22 | # Check for API key in environment variable 23 | api_key = os.getenv("COHERE_API_KEY") 24 | 25 | # Check for API key in config 26 | if config is not None and "api_key" in config: 27 | api_key = config["api_key"] 28 | 29 | # Validate API key 30 | if not api_key: 31 | raise ValueError("Cohere API key is required. 
Please provide it via config or set the COHERE_API_KEY environment variable.") 32 | 33 | # Initialize client with validated API key 34 | self.client = OpenAI( 35 | base_url="https://api.cohere.ai/compatibility/v1", 36 | api_key=api_key, 37 | ) 38 | 39 | def generate_embedding(self, data: str, **kwargs) -> list[float]: 40 | if not data: 41 | raise ValueError("Cannot generate embedding for empty input data") 42 | 43 | # Use model from kwargs, config, or default 44 | model = kwargs.get("model", self.model) 45 | if self.config is not None and "model" in self.config and model == self.model: 46 | model = self.config["model"] 47 | 48 | try: 49 | embedding = self.client.embeddings.create( 50 | model=model, 51 | input=data, 52 | encoding_format="float", # Ensure we get float values 53 | ) 54 | 55 | # Check if response has expected structure 56 | if not embedding or not hasattr(embedding, 'data') or not embedding.data: 57 | raise ValueError("Received empty or malformed embedding response from API") 58 | 59 | if not embedding.data[0] or not hasattr(embedding.data[0], 'embedding'): 60 | raise ValueError("Embedding response is missing expected 'embedding' field") 61 | 62 | if not embedding.data[0].embedding: 63 | raise ValueError("Received empty embedding vector") 64 | 65 | return embedding.data[0].embedding 66 | 67 | except Exception as e: 68 | # Log the error and raise a more informative exception 69 | error_msg = f"Error generating embedding with Cohere: {str(e)}" 70 | print(error_msg) 71 | raise Exception(error_msg) -------------------------------------------------------------------------------- /src/vanna/deepseek/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepseek_chat import DeepSeekChat 2 | -------------------------------------------------------------------------------- /src/vanna/deepseek/deepseek_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 
from openai import OpenAI 4 | 5 | from ..base import VannaBase 6 | 7 | 8 | 9 | # from vanna.chromadb import ChromaDB_VectorStore 10 | 11 | # class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat): 12 | # def __init__(self, config=None): 13 | # ChromaDB_VectorStore.__init__(self, config=config) 14 | # DeepSeekChat.__init__(self, config=config) 15 | 16 | # vn = DeepSeekVanna(config={"api_key": "sk-************", "model": "deepseek-chat"}) 17 | 18 | 19 | class DeepSeekChat(VannaBase): 20 | def __init__(self, config=None): 21 | if config is None: 22 | raise ValueError( 23 | "For DeepSeek, config must be provided with an api_key and model" 24 | ) 25 | if "api_key" not in config: 26 | raise ValueError("config must contain a DeepSeek api_key") 27 | 28 | if "model" not in config: 29 | raise ValueError("config must contain a DeepSeek model") 30 | 31 | api_key = config["api_key"] 32 | model = config["model"] 33 | self.model = model 34 | self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1") 35 | 36 | def system_message(self, message: str) -> any: 37 | return {"role": "system", "content": message} 38 | 39 | def user_message(self, message: str) -> any: 40 | return {"role": "user", "content": message} 41 | 42 | def assistant_message(self, message: str) -> any: 43 | return {"role": "assistant", "content": message} 44 | 45 | def generate_sql(self, question: str, **kwargs) -> str: 46 | # 使用父类的 generate_sql 47 | sql = super().generate_sql(question, **kwargs) 48 | 49 | # 替换 "\_" 为 "_" 50 | sql = sql.replace("\\_", "_") 51 | 52 | return sql 53 | 54 | def submit_prompt(self, prompt, **kwargs) -> str: 55 | chat_response = self.client.chat.completions.create( 56 | model=self.model, 57 | messages=prompt, 58 | ) 59 | 60 | return chat_response.choices[0].message.content 61 | -------------------------------------------------------------------------------- /src/vanna/exceptions/__init__.py: 
-------------------------------------------------------------------------------- 1 | class ImproperlyConfigured(Exception): 2 | """Raise for incorrect configuration.""" 3 | 4 | pass 5 | 6 | 7 | class DependencyError(Exception): 8 | """Raise for missing dependencies.""" 9 | 10 | pass 11 | 12 | 13 | class ConnectionError(Exception): 14 | """Raise for connection""" 15 | 16 | pass 17 | 18 | 19 | class OTPCodeError(Exception): 20 | """Raise for invalid otp or not able to send it""" 21 | 22 | pass 23 | 24 | 25 | class SQLRemoveError(Exception): 26 | """Raise when not able to remove SQL""" 27 | 28 | pass 29 | 30 | 31 | class ExecutionError(Exception): 32 | """Raise when not able to execute Code""" 33 | 34 | pass 35 | 36 | 37 | class ValidationError(Exception): 38 | """Raise for validations""" 39 | 40 | pass 41 | 42 | 43 | class APIError(Exception): 44 | """Raise for API errors""" 45 | 46 | pass 47 | -------------------------------------------------------------------------------- /src/vanna/faiss/__init__.py: -------------------------------------------------------------------------------- 1 | from .faiss import FAISS -------------------------------------------------------------------------------- /src/vanna/faiss/faiss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import uuid 4 | from typing import List, Dict, Any 5 | 6 | import faiss 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from ..base import VannaBase 11 | from ..exceptions import DependencyError 12 | 13 | class FAISS(VannaBase): 14 | def __init__(self, config=None): 15 | if config is None: 16 | config = {} 17 | 18 | VannaBase.__init__(self, config=config) 19 | 20 | try: 21 | import faiss 22 | except ImportError: 23 | raise DependencyError( 24 | "FAISS is not installed. 
Please install it with 'pip install faiss-cpu' or 'pip install faiss-gpu'" 25 | ) 26 | 27 | try: 28 | from sentence_transformers import SentenceTransformer 29 | except ImportError: 30 | raise DependencyError( 31 | "SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'." 32 | ) 33 | 34 | self.path = config.get("path", ".") 35 | self.embedding_dim = config.get('embedding_dim', 384) 36 | self.n_results_sql = config.get('n_results_sql', config.get("n_results", 10)) 37 | self.n_results_ddl = config.get('n_results_ddl', config.get("n_results", 10)) 38 | self.n_results_documentation = config.get('n_results_documentation', config.get("n_results", 10)) 39 | self.curr_client = config.get("client", "persistent") 40 | 41 | if self.curr_client == 'persistent': 42 | self.sql_index = self._load_or_create_index('sql_index.faiss') 43 | self.ddl_index = self._load_or_create_index('ddl_index.faiss') 44 | self.doc_index = self._load_or_create_index('doc_index.faiss') 45 | elif self.curr_client == 'in-memory': 46 | self.sql_index = faiss.IndexFlatL2(self.embedding_dim) 47 | self.ddl_index = faiss.IndexFlatL2(self.embedding_dim) 48 | self.doc_index = faiss.IndexFlatL2(self.embedding_dim) 49 | elif isinstance(self.curr_client, list) and len(self.curr_client) == 3 and all(isinstance(idx, faiss.Index) for idx in self.curr_client): 50 | self.sql_index = self.curr_client[0] 51 | self.ddl_index = self.curr_client[1] 52 | self.doc_index = self.curr_client[2] 53 | else: 54 | raise ValueError(f"Unsupported storage type was set in config: {self.curr_client}") 55 | 56 | self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata('sql_metadata.json') 57 | self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata('ddl_metadata.json') 58 | self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata('doc_metadata.json') 59 | 60 | model_name = config.get('embedding_model', 'all-MiniLM-L6-v2') 61 | self.embedding_model = 
SentenceTransformer(model_name) 62 | 63 | def _load_or_create_index(self, filename): 64 | filepath = os.path.join(self.path, filename) 65 | if os.path.exists(filepath): 66 | return faiss.read_index(filepath) 67 | return faiss.IndexFlatL2(self.embedding_dim) 68 | 69 | def _load_or_create_metadata(self, filename): 70 | filepath = os.path.join(self.path, filename) 71 | if os.path.exists(filepath): 72 | with open(filepath, 'r') as f: 73 | return json.load(f) 74 | return [] 75 | 76 | def _save_index(self, index, filename): 77 | if self.curr_client == 'persistent': 78 | filepath = os.path.join(self.path, filename) 79 | faiss.write_index(index, filepath) 80 | 81 | def _save_metadata(self, metadata, filename): 82 | if self.curr_client == 'persistent': 83 | filepath = os.path.join(self.path, filename) 84 | with open(filepath, 'w') as f: 85 | json.dump(metadata, f) 86 | 87 | def generate_embedding(self, data: str, **kwargs) -> List[float]: 88 | embedding = self.embedding_model.encode(data) 89 | assert embedding.shape[0] == self.embedding_dim, \ 90 | f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}" 91 | return embedding.tolist() 92 | 93 | def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str: 94 | embedding = self.generate_embedding(text) 95 | index.add(np.array([embedding], dtype=np.float32)) 96 | entry_id = str(uuid.uuid4()) 97 | metadata_list.append({"id": entry_id, **(extra_metadata or {})}) 98 | return entry_id 99 | 100 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str: 101 | entry_id = self._add_to_index(self.sql_index, self.sql_metadata, question + " " + sql, {"question": question, "sql": sql}) 102 | self._save_index(self.sql_index, 'sql_index.faiss') 103 | self._save_metadata(self.sql_metadata, 'sql_metadata.json') 104 | return entry_id 105 | 106 | def add_ddl(self, ddl: str, **kwargs) -> str: 107 | entry_id = self._add_to_index(self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl}) 
108 | self._save_index(self.ddl_index, 'ddl_index.faiss') 109 | self._save_metadata(self.ddl_metadata, 'ddl_metadata.json') 110 | return entry_id 111 | 112 | def add_documentation(self, documentation: str, **kwargs) -> str: 113 | entry_id = self._add_to_index(self.doc_index, self.doc_metadata, documentation, {"documentation": documentation}) 114 | self._save_index(self.doc_index, 'doc_index.faiss') 115 | self._save_metadata(self.doc_metadata, 'doc_metadata.json') 116 | return entry_id 117 | 118 | def _get_similar(self, index, metadata_list, text, n_results) -> list: 119 | embedding = self.generate_embedding(text) 120 | D, I = index.search(np.array([embedding], dtype=np.float32), k=n_results) 121 | return [] if len(I[0]) == 0 or I[0][0] == -1 else [metadata_list[i] for i in I[0]] 122 | 123 | def get_similar_question_sql(self, question: str, **kwargs) -> list: 124 | return self._get_similar(self.sql_index, self.sql_metadata, question, self.n_results_sql) 125 | 126 | def get_related_ddl(self, question: str, **kwargs) -> list: 127 | return [metadata["ddl"] for metadata in self._get_similar(self.ddl_index, self.ddl_metadata, question, self.n_results_ddl)] 128 | 129 | def get_related_documentation(self, question: str, **kwargs) -> list: 130 | return [metadata["documentation"] for metadata in self._get_similar(self.doc_index, self.doc_metadata, question, self.n_results_documentation)] 131 | 132 | def get_training_data(self, **kwargs) -> pd.DataFrame: 133 | sql_data = pd.DataFrame(self.sql_metadata) 134 | sql_data['training_data_type'] = 'sql' 135 | 136 | ddl_data = pd.DataFrame(self.ddl_metadata) 137 | ddl_data['training_data_type'] = 'ddl' 138 | 139 | doc_data = pd.DataFrame(self.doc_metadata) 140 | doc_data['training_data_type'] = 'documentation' 141 | 142 | return pd.concat([sql_data, ddl_data, doc_data], ignore_index=True) 143 | 144 | def remove_training_data(self, id: str, **kwargs) -> bool: 145 | for metadata_list, index, index_name in [ 146 | (self.sql_metadata, 
self.sql_index, 'sql_index.faiss'), 147 | (self.ddl_metadata, self.ddl_index, 'ddl_index.faiss'), 148 | (self.doc_metadata, self.doc_index, 'doc_index.faiss') 149 | ]: 150 | for i, item in enumerate(metadata_list): 151 | if item['id'] == id: 152 | del metadata_list[i] 153 | new_index = faiss.IndexFlatL2(self.embedding_dim) 154 | embeddings = [self.generate_embedding(m["question"] + " " + m["sql"] if "sql" in m else m.get("ddl", m.get("documentation", ""))) for m in metadata_list] 155 | if embeddings: 156 | new_index.add(np.array(embeddings, dtype=np.float32)) 157 | setattr(self, index_name.split('.')[0], new_index) 158 | 159 | if self.curr_client == 'persistent': 160 | self._save_index(new_index, index_name) 161 | self._save_metadata(metadata_list, f"{index_name.split('_')[0]}_metadata.json") 162 | 163 | return True 164 | return False 165 | 166 | def remove_collection(self, collection_name: str) -> bool: 167 | if collection_name in ["sql", "ddl", "documentation"]: 168 | prefix = "doc" if collection_name == "documentation" else collection_name 169 | setattr(self, f"{prefix}_index", faiss.IndexFlatL2(self.embedding_dim)) 170 | setattr(self, f"{prefix}_metadata", []) 171 | 172 | if self.curr_client == 'persistent': 173 | self._save_index(getattr(self, f"{prefix}_index"), f"{prefix}_index.faiss") 174 | self._save_metadata([], f"{prefix}_metadata.json") 175 | 176 | return True 177 | return False -------------------------------------------------------------------------------- /src/vanna/flask/auth.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import flask 4 | 5 | 6 | class AuthInterface(ABC): 7 | @abstractmethod 8 | def get_user(self, flask_request) -> any: 9 | pass 10 | 11 | @abstractmethod 12 | def is_logged_in(self, user: any) -> bool: 13 | pass 14 | 15 | @abstractmethod 16 | def override_config_for_user(self, user: any, config: dict) -> dict: 17 | pass 18 | 19 | @abstractmethod 20 | def login_form(self) -> str: 21 | pass 22 | 23 | @abstractmethod 24 | def login_handler(self, flask_request)
-> str: 25 | pass 26 | 27 | @abstractmethod 28 | def callback_handler(self, flask_request) -> str: 29 | pass 30 | 31 | @abstractmethod 32 | def logout_handler(self, flask_request) -> str: 33 | pass 34 | 35 | class NoAuth(AuthInterface): 36 | def get_user(self, flask_request) -> any: 37 | return {} 38 | 39 | def is_logged_in(self, user: any) -> bool: 40 | return True 41 | 42 | def override_config_for_user(self, user: any, config: dict) -> dict: 43 | return config 44 | 45 | def login_form(self) -> str: 46 | return '' 47 | 48 | def login_handler(self, flask_request) -> str: 49 | return 'No login required' 50 | 51 | def callback_handler(self, flask_request) -> str: 52 | return 'No login required' 53 | 54 | def logout_handler(self, flask_request) -> str: 55 | return 'No login required' 56 | -------------------------------------------------------------------------------- /src/vanna/google/__init__.py: -------------------------------------------------------------------------------- 1 | from .bigquery_vector import BigQuery_VectorStore 2 | from .gemini_chat import GoogleGeminiChat 3 | -------------------------------------------------------------------------------- /src/vanna/google/bigquery_vector.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import uuid 4 | from typing import List, Optional 5 | from vertexai.language_models import ( 6 | TextEmbeddingInput, 7 | TextEmbeddingModel 8 | ) 9 | 10 | import pandas as pd 11 | from google.cloud import bigquery 12 | 13 | from ..base import VannaBase 14 | 15 | 16 | class BigQuery_VectorStore(VannaBase): 17 | def __init__(self, config: dict, **kwargs): 18 | self.config = config 19 | 20 | self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10)) 21 | self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10)) 22 | self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10)) 23 | 24 | 
if "api_key" in config or os.getenv("GOOGLE_API_KEY"): 25 | """ 26 | If Google api_key is provided through config 27 | or set as an environment variable, assign it. 28 | """ 29 | print("Configuring genai") 30 | self.type = "GEMINI" 31 | import google.generativeai as genai 32 | 33 | genai.configure(api_key=config.get("api_key", os.getenv("GOOGLE_API_KEY"))) 34 | 35 | self.genai = genai 36 | else: 37 | self.type = "VERTEX_AI" 38 | # Authenticate using VertexAI 39 | 40 | if self.config.get("project_id"): 41 | self.project_id = self.config.get("project_id") 42 | else: 43 | self.project_id = os.getenv("GOOGLE_CLOUD_PROJECT") 44 | 45 | if self.project_id is None: 46 | raise ValueError("Project ID is not set") 47 | 48 | self.conn = bigquery.Client(project=self.project_id) 49 | 50 | dataset_name = self.config.get('bigquery_dataset_name', 'vanna_managed') 51 | self.dataset_id = f"{self.project_id}.{dataset_name}" 52 | dataset = bigquery.Dataset(self.dataset_id) 53 | 54 | try: 55 | self.conn.get_dataset(self.dataset_id) # Make an API request. 56 | print(f"Dataset {self.dataset_id} already exists") 57 | except Exception: 58 | # Dataset does not exist, create it 59 | dataset.location = "US" 60 | self.conn.create_dataset(dataset, timeout=30) # Make an API request.
61 | print(f"Created dataset {self.dataset_id}") 62 | 63 | # Create a table called training_data in the dataset that contains the columns: 64 | # id, training_data_type, question, content, embedding, created_at 65 | 66 | self.table_id = f"{self.dataset_id}.training_data" 67 | schema = [ 68 | bigquery.SchemaField("id", "STRING", mode="REQUIRED"), 69 | bigquery.SchemaField("training_data_type", "STRING", mode="REQUIRED"), 70 | bigquery.SchemaField("question", "STRING", mode="REQUIRED"), 71 | bigquery.SchemaField("content", "STRING", mode="REQUIRED"), 72 | bigquery.SchemaField("embedding", "FLOAT64", mode="REPEATED"), 73 | bigquery.SchemaField("created_at", "TIMESTAMP", mode="REQUIRED"), 74 | ] 75 | 76 | table = bigquery.Table(self.table_id, schema=schema) 77 | 78 | try: 79 | self.conn.get_table(self.table_id) # Make an API request. 80 | print(f"Table {self.table_id} already exists") 81 | except Exception: 82 | # Table does not exist, create it 83 | self.conn.create_table(table, timeout=30) # Make an API request. 84 | print(f"Created table {self.table_id}") 85 | 86 | # Create VECTOR INDEX IF NOT EXISTS 87 | # TODO: This requires 5000 rows before it can be created 88 | # vector_index_query = f""" 89 | # CREATE VECTOR INDEX IF NOT EXISTS my_index 90 | # ON `{self.table_id}`(embedding) 91 | # OPTIONS( 92 | # distance_type='COSINE', 93 | # index_type='IVF', 94 | # ivf_options='{{"num_lists": 1000}}' 95 | # ) 96 | # """ 97 | 98 | # try: 99 | # self.conn.query(vector_index_query).result() # Make an API request. 
100 | # print(f"Vector index on {self.table_id} created or already exists") 101 | # except Exception as e: 102 | # print(f"Failed to create vector index: {e}") 103 | 104 | def store_training_data(self, training_data_type: str, question: str, content: str, embedding: List[float], **kwargs) -> str: 105 | id = str(uuid.uuid4()) 106 | created_at = datetime.datetime.now() 107 | self.conn.insert_rows_json(self.table_id, [{ 108 | "id": id, 109 | "training_data_type": training_data_type, 110 | "question": question, 111 | "content": content, 112 | "embedding": embedding, 113 | "created_at": created_at.isoformat() 114 | }]) 115 | 116 | return id 117 | 118 | def fetch_similar_training_data(self, training_data_type: str, question: str, n_results, **kwargs) -> pd.DataFrame: 119 | question_embedding = self.generate_question_embedding(question) 120 | 121 | query = f""" 122 | SELECT 123 | base.id as id, 124 | base.question as question, 125 | base.training_data_type as training_data_type, 126 | base.content as content, 127 | distance 128 | FROM 129 | VECTOR_SEARCH( 130 | TABLE `{self.table_id}`, 131 | 'embedding', 132 | (SELECT * FROM UNNEST([STRUCT({question_embedding})])), 133 | top_k => 5, 134 | distance_type => 'COSINE', 135 | options => '{{"use_brute_force":true}}' 136 | ) 137 | WHERE 138 | base.training_data_type = '{training_data_type}' 139 | """ 140 | 141 | results = self.conn.query(query).result().to_dataframe() 142 | return results 143 | 144 | def get_embeddings(self, data: str, task: str) -> List[float]: 145 | embeddings = None 146 | 147 | if self.type == "VERTEX_AI": 148 | input = [TextEmbeddingInput(data, task)] 149 | model = TextEmbeddingModel.from_pretrained("text-embedding-004") 150 | 151 | result = model.get_embeddings(input) 152 | 153 | if len(result) > 0: 154 | embeddings = result[0].values 155 | else: 156 | # Use Gemini Consumer API 157 | result = self.genai.embed_content( 158 | model="models/text-embedding-004", 159 | content=data, 160 | task_type=task) 161 | 
162 | if 'embedding' in result: 163 | embeddings = result['embedding'] 164 | 165 | return embeddings 166 | 167 | def generate_question_embedding(self, data: str, **kwargs) -> List[float]: 168 | result = self.get_embeddings(data, "RETRIEVAL_QUERY") 169 | 170 | if result != None: 171 | return result 172 | else: 173 | raise ValueError("No embeddings returned") 174 | 175 | def generate_storage_embedding(self, data: str, **kwargs) -> List[float]: 176 | result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT") 177 | 178 | if result != None: 179 | return result 180 | else: 181 | raise ValueError("No embeddings returned") 182 | 183 | # task = "RETRIEVAL_DOCUMENT" 184 | # inputs = [TextEmbeddingInput(data, task)] 185 | # embeddings = self.vertex_embedding_model.get_embeddings(inputs) 186 | 187 | # if len(embeddings) == 0: 188 | # raise ValueError("No embeddings returned") 189 | 190 | # return embeddings[0].values 191 | 192 | return result 193 | 194 | def generate_embedding(self, data: str, **kwargs) -> List[float]: 195 | return self.generate_storage_embedding(data, **kwargs) 196 | 197 | def get_similar_question_sql(self, question: str, **kwargs) -> list: 198 | df = self.fetch_similar_training_data(training_data_type="sql", question=question, n_results=self.n_results_sql) 199 | 200 | # Return a list of dictionaries with only question, sql fields. 
The content field needs to be renamed to sql 201 | return df.rename(columns={"content": "sql"})[["question", "sql"]].to_dict(orient="records") 202 | 203 | def get_related_ddl(self, question: str, **kwargs) -> list: 204 | df = self.fetch_similar_training_data(training_data_type="ddl", question=question, n_results=self.n_results_ddl) 205 | 206 | # Return a list of strings of the content 207 | return df["content"].tolist() 208 | 209 | def get_related_documentation(self, question: str, **kwargs) -> list: 210 | df = self.fetch_similar_training_data(training_data_type="documentation", question=question, n_results=self.n_results_documentation) 211 | 212 | # Return a list of strings of the content 213 | return df["content"].tolist() 214 | 215 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str: 216 | doc = { 217 | "question": question, 218 | "sql": sql 219 | } 220 | 221 | embedding = self.generate_embedding(str(doc)) 222 | 223 | return self.store_training_data(training_data_type="sql", question=question, content=sql, embedding=embedding) 224 | 225 | def add_ddl(self, ddl: str, **kwargs) -> str: 226 | embedding = self.generate_embedding(ddl) 227 | 228 | return self.store_training_data(training_data_type="ddl", question="", content=ddl, embedding=embedding) 229 | 230 | def add_documentation(self, documentation: str, **kwargs) -> str: 231 | embedding = self.generate_embedding(documentation) 232 | 233 | return self.store_training_data(training_data_type="documentation", question="", content=documentation, embedding=embedding) 234 | 235 | def get_training_data(self, **kwargs) -> pd.DataFrame: 236 | query = f"SELECT id, training_data_type, question, content FROM `{self.table_id}`" 237 | 238 | return self.conn.query(query).result().to_dataframe() 239 | 240 | def remove_training_data(self, id: str, **kwargs) -> bool: 241 | query = f"DELETE FROM `{self.table_id}` WHERE id = '{id}'" 242 | 243 | try: 244 | self.conn.query(query).result() 245 | return True 246 | 247 
# ---- /src/vanna/google/gemini_chat.py ----
import os

from ..base import VannaBase


class GoogleGeminiChat(VannaBase):
    """Chat backend using Gemini via the consumer API (with API key) or Vertex AI."""

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

        # BUG FIX: the original indexed `config` directly, crashing with a
        # TypeError when config was None (the documented default).
        if config is None:
            config = {}

        # default temperature - can be overridden using config
        self.temperature = 0.7

        if "temperature" in config:
            self.temperature = config["temperature"]

        if "model_name" in config:
            model_name = config["model_name"]
        else:
            model_name = "gemini-1.5-pro"

        self.google_api_key = None

        if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
            # An API key is available: use the Gemini consumer API.
            import google.generativeai as genai

            # BUG FIX: fall back to GOOGLE_API_KEY instead of raising KeyError
            # when the key is only provided through the environment.
            genai.configure(api_key=config.get("api_key", os.getenv("GOOGLE_API_KEY")))
            self.chat_model = genai.GenerativeModel(model_name)
        else:
            # Authenticate using VertexAI
            import google.auth
            import vertexai
            from vertexai.generative_models import GenerativeModel

            json_file_path = config.get("google_credentials")  # Assuming the JSON file path is provided in the config

            if not json_file_path or not os.path.exists(json_file_path):
                raise FileNotFoundError(f"JSON credentials file not found at: {json_file_path}")

            try:
                # Validate and set the JSON file path for GOOGLE_APPLICATION_CREDENTIALS
                os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = json_file_path

                # Initialize VertexAI with the credentials
                credentials, _ = google.auth.default()
                vertexai.init(credentials=credentials)
                self.chat_model = GenerativeModel(model_name)
            except google.auth.exceptions.DefaultCredentialsError as e:
                raise RuntimeError(f"Default credentials error: {e}")
            except google.auth.exceptions.TransportError as e:
                raise RuntimeError(f"Transport error during authentication: {e}")
            except Exception as e:
                raise RuntimeError(f"Failed to authenticate using JSON file: {e}")

    # Gemini takes plain strings rather than role-tagged dicts.
    def system_message(self, message: str) -> any:
        return message

    def user_message(self, message: str) -> any:
        return message

    def assistant_message(self, message: str) -> any:
        return message

    def submit_prompt(self, prompt, **kwargs) -> str:
        response = self.chat_model.generate_content(
            prompt,
            generation_config={
                "temperature": self.temperature,
            },
        )
        return response.text

# ---- /src/vanna/hf/__init__.py ----
from .hf import Hf

# ---- /src/vanna/hf/hf.py ----
import re

from transformers import AutoTokenizer, AutoModelForCausalLM

from ..base import VannaBase


class Hf(VannaBase):
    def __init__(self, config=None):
        # BUG FIX: the original read self.config without ever calling
        # VannaBase.__init__, so self.config did not exist and __init__
        # raised AttributeError. Initialize the base class first.
        VannaBase.__init__(self, config=config)

        # e.g. meta-llama/Meta-Llama-3-8B-Instruct or local path to the model checkpoint files
        model_name_or_path = self.config.get(
            "model_name_or_path", None
        )
        # list of quantization methods supported by transformers package:
        # https://huggingface.co/docs/transformers/main/en/quantization/overview
        quantization_config = self.config.get("quantization_config", None)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            quantization_config=quantization_config,
            device_map="auto",
        )
meta-llama/Meta-Llama-3-8B-Instruct or local path to the model checkpoint files 12 | # list of quantization methods supported by transformers package: https://huggingface.co/docs/transformers/main/en/quantization/overview 13 | quantization_config = self.config.get("quantization_config", None) 14 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 15 | self.model = AutoModelForCausalLM.from_pretrained( 16 | model_name_or_path, 17 | quantization_config=quantization_config, 18 | device_map="auto", 19 | ) 20 | 21 | def system_message(self, message: str) -> any: 22 | return {"role": "system", "content": message} 23 | 24 | def user_message(self, message: str) -> any: 25 | return {"role": "user", "content": message} 26 | 27 | def assistant_message(self, message: str) -> any: 28 | return {"role": "assistant", "content": message} 29 | 30 | def extract_sql_query(self, text): 31 | """ 32 | Extracts the first SQL statement after the word 'select', ignoring case, 33 | matches until the first semicolon, three backticks, or the end of the string, 34 | and removes three backticks if they exist in the extracted string. 35 | 36 | Args: 37 | - text (str): The string to search within for an SQL statement. 38 | 39 | Returns: 40 | - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found. 
41 | """ 42 | # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string 43 | pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL) 44 | 45 | match = pattern.search(text) 46 | if match: 47 | # Remove three backticks from the matched string if they exist 48 | return match.group(0).replace("```", "") 49 | else: 50 | return text 51 | 52 | def generate_sql(self, question: str, **kwargs) -> str: 53 | # Use the super generate_sql 54 | sql = super().generate_sql(question, **kwargs) 55 | 56 | # Replace "\_" with "_" 57 | sql = sql.replace("\\_", "_") 58 | 59 | sql = sql.replace("\\", "") 60 | 61 | return self.extract_sql_query(sql) 62 | 63 | def submit_prompt(self, prompt, **kwargs) -> str: 64 | 65 | input_ids = self.tokenizer.apply_chat_template( 66 | prompt, add_generation_prompt=True, return_tensors="pt" 67 | ).to(self.model.device) 68 | 69 | outputs = self.model.generate( 70 | input_ids, 71 | max_new_tokens=512, 72 | eos_token_id=self.tokenizer.eos_token_id, 73 | do_sample=True, 74 | temperature=1, 75 | top_p=0.9, 76 | ) 77 | response = outputs[0][input_ids.shape[-1] :] 78 | response = self.tokenizer.decode(response, skip_special_tokens=True) 79 | self.log(response) 80 | 81 | return response 82 | -------------------------------------------------------------------------------- /src/vanna/local.py: -------------------------------------------------------------------------------- 1 | from .chromadb.chromadb_vector import ChromaDB_VectorStore 2 | from .openai.openai_chat import OpenAI_Chat 3 | 4 | 5 | class LocalContext_OpenAI(ChromaDB_VectorStore, OpenAI_Chat): 6 | def __init__(self, config=None): 7 | ChromaDB_VectorStore.__init__(self, config=config) 8 | OpenAI_Chat.__init__(self, config=config) 9 | -------------------------------------------------------------------------------- /src/vanna/marqo/__init__.py: -------------------------------------------------------------------------------- 1 | from .marqo import 
# ---- /src/vanna/marqo/marqo.py ----
import uuid

import marqo
import pandas as pd

from ..base import VannaBase


class Marqo_VectorStore(VannaBase):
    """Vector store backed by a Marqo server using the vanna-sql/-ddl/-doc indexes."""

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

        if config is not None and "marqo_url" in config:
            marqo_url = config["marqo_url"]
        else:
            marqo_url = "http://localhost:8882"

        if config is not None and "marqo_model" in config:
            marqo_model = config["marqo_model"]
        else:
            marqo_model = "hf/all_datasets_v4_MiniLM-L6"

        self.mq = marqo.Client(url=marqo_url)

        # Marqo computes embeddings server-side; create each index up front.
        # create_index raises when the index exists, which is fine.
        for index_name in ["vanna-sql", "vanna-ddl", "vanna-doc"]:
            try:
                self.mq.create_index(index_name, model=marqo_model)
            except Exception as e:
                print(e)
                print(f"Marqo index {index_name} already exists")

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        # Marqo doesn't need to generate embeddings
        pass

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        doc_id = str(uuid.uuid4()) + "-sql"
        self.mq.index("vanna-sql").add_documents(
            [{"question": question, "sql": sql, "_id": doc_id}],
            tensor_fields=["question", "sql"],
        )
        return doc_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        doc_id = str(uuid.uuid4()) + "-ddl"
        self.mq.index("vanna-ddl").add_documents(
            [{"ddl": ddl, "_id": doc_id}],
            tensor_fields=["ddl"],
        )
        return doc_id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        doc_id = str(uuid.uuid4()) + "-doc"
        self.mq.index("vanna-doc").add_documents(
            [{"doc": documentation, "_id": doc_id}],
            tensor_fields=["doc"],
        )
        return doc_id

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        # Pull up to 1000 hits per index and normalize to a common schema.
        rows = []

        for hit in self.mq.index("vanna-doc").search("", limit=1000)["hits"]:
            rows.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "documentation",
                    "question": "",
                    "content": hit["doc"],
                }
            )

        for hit in self.mq.index("vanna-ddl").search("", limit=1000)["hits"]:
            rows.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "ddl",
                    "question": "",
                    "content": hit["ddl"],
                }
            )

        for hit in self.mq.index("vanna-sql").search("", limit=1000)["hits"]:
            rows.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "sql",
                    "question": hit["question"],
                    "content": hit["sql"],
                }
            )

        return pd.DataFrame(rows)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        # The id suffix encodes which index the document lives in.
        for suffix, index_name in (
            ("-sql", "vanna-sql"),
            ("-ddl", "vanna-ddl"),
            ("-doc", "vanna-doc"),
        ):
            if id.endswith(suffix):
                self.mq.index(index_name).delete_documents(ids=[id])
                return True
        return False

    # Static method to extract the documents from the results of a query
    @staticmethod
    def _extract_documents(data) -> list:
        # Return an empty list if 'hits' is not found or not a list.
        if "hits" not in data or not isinstance(data["hits"], list):
            return []

        hits = data["hits"]
        if len(hits) == 0:
            return []

        # If there is a "doc" key, return the value of that key.
        if "doc" in hits[0]:
            return [hit["doc"] for hit in hits]

        # If there is a "ddl" key, return the value of that key.
        if "ddl" in hits[0]:
            return [hit["ddl"] for hit in hits]

        # Otherwise strip Marqo's internal fields and return the whole hit.
        return [
            {key: value for key, value in hit.items() if not key.startswith("_")}
            for hit in hits
        ]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        return Marqo_VectorStore._extract_documents(
            self.mq.index("vanna-sql").search(question)
        )

    def get_related_ddl(self, question: str, **kwargs) -> list:
        return Marqo_VectorStore._extract_documents(
            self.mq.index("vanna-ddl").search(question)
        )

    def get_related_documentation(self, question: str, **kwargs) -> list:
        return Marqo_VectorStore._extract_documents(
            self.mq.index("vanna-doc").search(question)
        )

# ---- /src/vanna/milvus/__init__.py ----
from .milvus_vector import Milvus_VectorStore

# ---- /src/vanna/mistral/__init__.py ----
from .mistral import Mistral

# ---- /src/vanna/mistral/mistral.py ----
import os

from mistralai import Mistral as MistralClient
from mistralai import UserMessage

from ..base import VannaBase


class Mistral(VannaBase):
    def __init__(self, config=None):
        # The Mistral backend has no usable defaults: both the API key and
        # the model name must be supplied explicitly.
        if config is None:
            raise ValueError(
                "For Mistral, config must be provided with an api_key and model"
            )

        if "api_key" not in config:
            raise ValueError("config must contain a Mistral api_key")

        if "model" not in config:
            raise ValueError("config must contain a Mistral model")

        self.client = MistralClient(api_key=config["api_key"])
        self.model = config["model"]
a Mistral api_key") 18 | 19 | if "model" not in config: 20 | raise ValueError("config must contain a Mistral model") 21 | 22 | api_key = config["api_key"] 23 | model = config["model"] 24 | self.client = MistralClient(api_key=api_key) 25 | self.model = model 26 | 27 | def system_message(self, message: str) -> any: 28 | return {"role": "system", "content": message} 29 | 30 | def user_message(self, message: str) -> any: 31 | return {"role": "user", "content": message} 32 | 33 | def assistant_message(self, message: str) -> any: 34 | return {"role": "assistant", "content": message} 35 | 36 | def generate_sql(self, question: str, **kwargs) -> str: 37 | # Use the super generate_sql 38 | sql = super().generate_sql(question, **kwargs) 39 | 40 | # Replace "\_" with "_" 41 | sql = sql.replace("\\_", "_") 42 | 43 | return sql 44 | 45 | def submit_prompt(self, prompt, **kwargs) -> str: 46 | chat_response = self.client.chat.complete( 47 | model=self.model, 48 | messages=prompt, 49 | ) 50 | 51 | return chat_response.choices[0].message.content 52 | -------------------------------------------------------------------------------- /src/vanna/mock/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding import MockEmbedding 2 | from .llm import MockLLM 3 | from .vectordb import MockVectorDB 4 | -------------------------------------------------------------------------------- /src/vanna/mock/embedding.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..base import VannaBase 4 | 5 | 6 | class MockEmbedding(VannaBase): 7 | def __init__(self, config=None): 8 | pass 9 | 10 | def generate_embedding(self, data: str, **kwargs) -> List[float]: 11 | return [1.0, 2.0, 3.0, 4.0, 5.0] 12 | -------------------------------------------------------------------------------- /src/vanna/mock/llm.py: 
-------------------------------------------------------------------------------- 1 | 2 | from ..base import VannaBase 3 | 4 | 5 | class MockLLM(VannaBase): 6 | def __init__(self, config=None): 7 | pass 8 | 9 | def system_message(self, message: str) -> any: 10 | return {"role": "system", "content": message} 11 | 12 | def user_message(self, message: str) -> any: 13 | return {"role": "user", "content": message} 14 | 15 | def assistant_message(self, message: str) -> any: 16 | return {"role": "assistant", "content": message} 17 | 18 | def submit_prompt(self, prompt, **kwargs) -> str: 19 | return "Mock LLM response" 20 | -------------------------------------------------------------------------------- /src/vanna/mock/vectordb.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from ..base import VannaBase 4 | 5 | 6 | class MockVectorDB(VannaBase): 7 | def __init__(self, config=None): 8 | pass 9 | 10 | def _get_id(self, value: str, **kwargs) -> str: 11 | # Hash the value and return the ID 12 | return str(hash(value)) 13 | 14 | def add_ddl(self, ddl: str, **kwargs) -> str: 15 | return self._get_id(ddl) 16 | 17 | def add_documentation(self, doc: str, **kwargs) -> str: 18 | return self._get_id(doc) 19 | 20 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str: 21 | return self._get_id(question) 22 | 23 | def get_related_ddl(self, question: str, **kwargs) -> list: 24 | return [] 25 | 26 | def get_related_documentation(self, question: str, **kwargs) -> list: 27 | return [] 28 | 29 | def get_similar_question_sql(self, question: str, **kwargs) -> list: 30 | return [] 31 | 32 | def get_training_data(self, **kwargs) -> pd.DataFrame: 33 | return pd.DataFrame({'id': {0: '19546-ddl', 34 | 1: '91597-sql', 35 | 2: '133976-sql', 36 | 3: '59851-doc', 37 | 4: '73046-sql'}, 38 | 'training_data_type': {0: 'ddl', 39 | 1: 'sql', 40 | 2: 'sql', 41 | 3: 'documentation', 42 | 4: 'sql'}, 43 | 'question': {0: None, 44 
| 1: 'What are the top selling genres?', 45 | 2: 'What are the low 7 artists by sales?', 46 | 3: None, 47 | 4: 'What is the total sales for each customer?'}, 48 | 'content': {0: 'CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)', 49 | 1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;', 50 | 2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;', 51 | 3: 'This is a SQLite database. 
For dates rememeber to use SQLite syntax.', 52 | 4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}}) 53 | 54 | def remove_training_data(id: str, **kwargs) -> bool: 55 | return True 56 | -------------------------------------------------------------------------------- /src/vanna/ollama/__init__.py: -------------------------------------------------------------------------------- 1 | from .ollama import Ollama 2 | -------------------------------------------------------------------------------- /src/vanna/ollama/ollama.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | from httpx import Timeout 5 | 6 | from ..base import VannaBase 7 | from ..exceptions import DependencyError 8 | 9 | 10 | class Ollama(VannaBase): 11 | def __init__(self, config=None): 12 | 13 | try: 14 | ollama = __import__("ollama") 15 | except ImportError: 16 | raise DependencyError( 17 | "You need to install required dependencies to execute this method, run command:" 18 | " \npip install ollama" 19 | ) 20 | 21 | if not config: 22 | raise ValueError("config must contain at least Ollama model") 23 | if 'model' not in config.keys(): 24 | raise ValueError("config must contain at least Ollama model") 25 | self.host = config.get("ollama_host", "http://localhost:11434") 26 | self.model = config["model"] 27 | if ":" not in self.model: 28 | self.model += ":latest" 29 | 30 | self.ollama_timeout = config.get("ollama_timeout", 240.0) 31 | 32 | self.ollama_client = ollama.Client(self.host, timeout=Timeout(self.ollama_timeout)) 33 | self.keep_alive = config.get('keep_alive', None) 34 | self.ollama_options = config.get('options', {}) 35 | self.num_ctx = self.ollama_options.get('num_ctx', 2048) 36 | self.__pull_model_if_ne(self.ollama_client, self.model) 37 | 38 | @staticmethod 39 | def 
__pull_model_if_ne(ollama_client, model): 40 | model_response = ollama_client.list() 41 | model_lists = [model_element['model'] for model_element in 42 | model_response.get('models', [])] 43 | if model not in model_lists: 44 | ollama_client.pull(model) 45 | 46 | def system_message(self, message: str) -> any: 47 | return {"role": "system", "content": message} 48 | 49 | def user_message(self, message: str) -> any: 50 | return {"role": "user", "content": message} 51 | 52 | def assistant_message(self, message: str) -> any: 53 | return {"role": "assistant", "content": message} 54 | 55 | def extract_sql(self, llm_response): 56 | """ 57 | Extracts the first SQL statement after the word 'select', ignoring case, 58 | matches until the first semicolon, three backticks, or the end of the string, 59 | and removes three backticks if they exist in the extracted string. 60 | 61 | Args: 62 | - llm_response (str): The string to search within for an SQL statement. 63 | 64 | Returns: 65 | - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found. 
66 | """ 67 | # Remove ollama-generated extra characters 68 | llm_response = llm_response.replace("\\_", "_") 69 | llm_response = llm_response.replace("\\", "") 70 | 71 | # Regular expression to find ```sql' and capture until '```' 72 | sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL) 73 | # Regular expression to find 'select, with (ignoring case) and capture until ';', [ (this happens in case of mistral) or end of string 74 | select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)', 75 | llm_response, 76 | re.IGNORECASE | re.DOTALL) 77 | if sql: 78 | self.log( 79 | f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}") 80 | return sql.group(1).replace("```", "") 81 | elif select_with: 82 | self.log( 83 | f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}") 84 | return select_with.group(0) 85 | else: 86 | return llm_response 87 | 88 | def submit_prompt(self, prompt, **kwargs) -> str: 89 | self.log( 90 | f"Ollama parameters:\n" 91 | f"model={self.model},\n" 92 | f"options={self.ollama_options},\n" 93 | f"keep_alive={self.keep_alive}") 94 | self.log(f"Prompt Content:\n{json.dumps(prompt, ensure_ascii=False)}") 95 | response_dict = self.ollama_client.chat(model=self.model, 96 | messages=prompt, 97 | stream=False, 98 | options=self.ollama_options, 99 | keep_alive=self.keep_alive) 100 | 101 | self.log(f"Ollama Response:\n{str(response_dict)}") 102 | 103 | return response_dict['message']['content'] 104 | -------------------------------------------------------------------------------- /src/vanna/openai/__init__.py: -------------------------------------------------------------------------------- 1 | from .openai_chat import OpenAI_Chat 2 | from .openai_embeddings import OpenAI_Embeddings 3 | -------------------------------------------------------------------------------- /src/vanna/openai/openai_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 
class OpenAI_Chat(VannaBase):
    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        # default parameters - can be overridden using config
        self.temperature = 0.7

        # BUG FIX: every config lookup below is guarded — the original
        # indexed a possibly-None config (the documented default) and
        # crashed with a TypeError before its own `config is None` branch
        # could ever run.
        if config is not None and "temperature" in config:
            self.temperature = config["temperature"]

        if config is not None and "api_type" in config:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if config is not None and "api_base" in config:
            raise Exception(
                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
            )

        if config is not None and "api_version" in config:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if config is None:
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return

        if "api_key" in config:
            self.client = OpenAI(api_key=config["api_key"])
        else:
            # BUG FIX: previously self.client was never assigned when a config
            # was given without "api_key", causing AttributeError on first use.
            # Fall back to the environment variable, matching the config=None path.
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send the chat messages to OpenAI, choosing the model/engine by priority:
        kwargs["model"] > kwargs["engine"] > config["engine"] > config["model"] >
        a token-count-based default."""
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log
        # Use 4 as an approximation for the number of characters per token
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            print(
                f"Using model {model} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        elif kwargs.get("engine", None) is not None:
            # "engine" is the legacy Azure-style parameter name.
            engine = kwargs.get("engine", None)
            print(
                f"Using model {engine} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=engine,
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "engine" in self.config:
            print(
                f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=self.config["engine"],
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=self.config["model"],
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        else:
            # No model configured: pick by approximate prompt size.
            if num_tokens > 3500:
                model = "gpt-3.5-turbo-16k"
            else:
                model = "gpt-3.5-turbo"

            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )

        # Find the first response from the chatbot that has text in it (some responses may not have text)
        # NOTE(review): with the v1 client, `choice` is a typed object, so this
        # membership test never matches and we always fall through — confirm
        # whether this legacy branch can be removed.
        for choice in response.choices:
            if "text" in choice:
                return choice.text

        # If no response with text is found, return the first response's content (which may be empty)
        return response.choices[0].message.content
class OpenAI_Embeddings(VannaBase):
    def __init__(self, client=None, config=None):
        """Embedding generator backed by an OpenAI client.

        Args:
            client: A pre-configured OpenAI client; used as-is if given.
            config: Optional dict with ``api_type`` / ``api_base`` /
                ``api_version`` / ``api_key`` overrides applied to the client.
        """
        VannaBase.__init__(self, config=config)

        if client is not None:
            self.client = client
            return

        # The original read `self.client` unconditionally here, which raised
        # AttributeError because nothing had assigned it yet; probe safely so
        # a client set by a cooperating base/mixin class is still honored.
        if getattr(self, "client", None) is not None:
            return

        self.client = OpenAI()

        if config is None:
            return

        if "api_type" in config:
            self.client.api_type = config["api_type"]

        if "api_base" in config:
            self.client.api_base = config["api_base"]

        if "api_version" in config:
            self.client.api_version = config["api_version"]

        if "api_key" in config:
            self.client.api_key = config["api_key"]

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        """Embed `data` using config["engine"] if set, else text-embedding-ada-002.

        NOTE(review): the `embedding.get("data")[0]["embedding"]` access assumes
        a dict-style (pre-v1 SDK) response — confirm against the pinned openai
        version before changing.
        """
        if self.config is not None and "engine" in self.config:
            embedding = self.client.embeddings.create(
                engine=self.config["engine"],
                input=data,
            )
        else:
            embedding = self.client.embeddings.create(
                model="text-embedding-ada-002",
                input=data,
            )

        return embedding.get("data")[0]["embedding"]
class OpenSearch_Semantic_VectorStore(VannaBase):
    """Vector store keeping documentation, DDL, and question/SQL pairs in
    three OpenSearch indices via langchain's OpenSearchVectorSearch."""

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)
        if config is None:
            config = {}

        if "embedding_function" in config:
            self.embedding_function = config.get("embedding_function")
        else:
            # Default to a small local sentence-transformers model.
            from langchain_huggingface import HuggingFaceEmbeddings
            self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        # Per-category result counts; "n_results" acts as a shared fallback.
        self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
        self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
        self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))

        self.document_index = config.get("es_document_index", "vanna_document_index")
        self.ddl_index = config.get("es_ddl_index", "vanna_ddl_index")
        self.question_sql_index = config.get("es_question_sql_index", "vanna_questions_sql_index")

        self.log(f"OpenSearch_Semantic_VectorStore initialized with document_index: {self.document_index}, ddl_index: {self.ddl_index}, question_sql_index: {self.question_sql_index}")

        es_urls = config.get("es_urls", "https://localhost:9200")
        ssl = config.get("es_ssl", True)
        verify_certs = config.get("es_verify_certs", True)

        if "es_user" in config:
            auth = (config["es_user"], config["es_password"])
        else:
            auth = None

        headers = config.get("es_headers", None)
        timeout = config.get("es_timeout", 60)
        max_retries = config.get("es_max_retries", 10)

        # Shared connection settings for all three index handles.
        common_args = {
            "opensearch_url": es_urls,
            "embedding_function": self.embedding_function,
            "engine": "faiss",
            "http_auth": auth,
            "use_ssl": ssl,
            "verify_certs": verify_certs,
            "timeout": timeout,
            "max_retries": max_retries,
            "retry_on_timeout": True,
            "headers": headers,
        }

        self.documentation_store = OpenSearchVectorSearch(index_name=self.document_index, **common_args)
        self.ddl_store = OpenSearchVectorSearch(index_name=self.ddl_index, **common_args)
        self.sql_store = OpenSearchVectorSearch(index_name=self.question_sql_index, **common_args)

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns its deterministic '-ddl' id."""
        _id = deterministic_uuid(ddl) + "-ddl"
        self.ddl_store.add_texts(texts=[ddl], ids=[_id], **kwargs)
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation snippet; returns its deterministic '-doc' id."""
        _id = deterministic_uuid(documentation) + "-doc"
        self.documentation_store.add_texts(texts=[documentation], ids=[_id], **kwargs)
        return _id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair as JSON; returns its deterministic '-sql' id."""
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )

        _id = deterministic_uuid(question_sql_json) + "-sql"
        self.sql_store.add_texts(texts=[question_sql_json], ids=[_id], **kwargs)
        return _id

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the page contents of the top-k DDL matches."""
        documents = self.ddl_store.similarity_search(query=question, k=self.n_results_ddl)
        return [document.page_content for document in documents]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the page contents of the top-k documentation matches."""
        documents = self.documentation_store.similarity_search(query=question, k=self.n_results_documentation)
        return [document.page_content for document in documents]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return the top-k question/SQL pairs, decoded from stored JSON."""
        documents = self.sql_store.similarity_search(query=question, k=self.n_results_sql)
        return [json.loads(document.page_content) for document in documents]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Scroll through all three indices and return their contents as a
        DataFrame with columns id / training_data_type / question / content."""
        data = []
        query = {
            "query": {
                "match_all": {}
            }
        }

        indices = [
            {"index": self.document_index, "type": "documentation"},
            {"index": self.question_sql_index, "type": "sql"},
            {"index": self.ddl_index, "type": "ddl"},
        ]

        # Use documentation_store.client consistently for search on all indices
        opensearch_client = self.documentation_store.client

        for index_info in indices:
            index_name = index_info["index"]
            training_data_type = index_info["type"]
            scroll = '1m'  # keep scroll context for 1 minute
            response = opensearch_client.search(
                index=index_name,
                ignore_unavailable=True,
                body=query,
                scroll=scroll,
                size=1000
            )

            scroll_id = response.get('_scroll_id')

            try:
                while scroll_id:
                    hits = response['hits']['hits']
                    if not hits:
                        break  # No more hits, exit loop

                    for hit in hits:
                        source = hit['_source']
                        if training_data_type == "sql":
                            try:
                                doc_dict = json.loads(source['text'])
                                content = doc_dict.get("sql")
                                question = doc_dict.get("question")
                            except json.JSONDecodeError as e:
                                self.log(f"Skipping row with custom_id {hit['_id']} due to JSON parsing error: {e}", "Error")
                                continue
                        else:  # documentation or ddl
                            content = source['text']
                            question = None

                        data.append({
                            "id": hit["_id"],
                            "training_data_type": training_data_type,
                            "question": question,
                            "content": content,
                        })

                    # Get next batch of results via the scroll API.
                    response = opensearch_client.scroll(scroll_id=scroll_id, scroll=scroll)
                    scroll_id = response.get('_scroll_id')
            finally:
                # The original never released the server-side scroll context,
                # leaking it until the 1-minute timeout. Best-effort cleanup.
                if scroll_id:
                    try:
                        opensearch_client.clear_scroll(scroll_id=scroll_id)
                    except Exception:
                        pass

        return pd.DataFrame(data)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete a training record by id; the id suffix picks the store."""
        try:
            if id.endswith("-sql"):
                return self.sql_store.delete(ids=[id], **kwargs)
            elif id.endswith("-ddl"):
                return self.ddl_store.delete(ids=[id], **kwargs)
            elif id.endswith("-doc"):
                return self.documentation_store.delete(ids=[id], **kwargs)
            else:
                return False
        except Exception as e:
            # Original log text was accidentally duplicated
            # ("Error deleting training dataError deleting training data").
            self.log(f"Error deleting training data: {e}", "Error")
            return False

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        # Embeddings are handled by the configured embedding_function; this
        # VannaBase hook is intentionally a no-op here.
        pass
import ValidationError 12 | from ..base import VannaBase 13 | from ..types import TrainingPlan, TrainingPlanItem 14 | 15 | 16 | class PG_VectorStore(VannaBase): 17 | def __init__(self, config=None): 18 | if not config or "connection_string" not in config: 19 | raise ValueError( 20 | "A valid 'config' dictionary with a 'connection_string' is required.") 21 | 22 | VannaBase.__init__(self, config=config) 23 | 24 | if config and "connection_string" in config: 25 | self.connection_string = config.get("connection_string") 26 | self.n_results = config.get("n_results", 10) 27 | 28 | if config and "embedding_function" in config: 29 | self.embedding_function = config.get("embedding_function") 30 | else: 31 | from langchain_huggingface import HuggingFaceEmbeddings 32 | self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2") 33 | 34 | self.sql_collection = PGVector( 35 | embeddings=self.embedding_function, 36 | collection_name="sql", 37 | connection=self.connection_string, 38 | ) 39 | self.ddl_collection = PGVector( 40 | embeddings=self.embedding_function, 41 | collection_name="ddl", 42 | connection=self.connection_string, 43 | ) 44 | self.documentation_collection = PGVector( 45 | embeddings=self.embedding_function, 46 | collection_name="documentation", 47 | connection=self.connection_string, 48 | ) 49 | 50 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str: 51 | question_sql_json = json.dumps( 52 | { 53 | "question": question, 54 | "sql": sql, 55 | }, 56 | ensure_ascii=False, 57 | ) 58 | id = str(uuid.uuid4()) + "-sql" 59 | createdat = kwargs.get("createdat") 60 | doc = Document( 61 | page_content=question_sql_json, 62 | metadata={"id": id, "createdat": createdat}, 63 | ) 64 | self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]]) 65 | 66 | return id 67 | 68 | def add_ddl(self, ddl: str, **kwargs) -> str: 69 | _id = str(uuid.uuid4()) + "-ddl" 70 | doc = Document( 71 | page_content=ddl, 72 | metadata={"id": _id}, 73 | ) 74 
| self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]]) 75 | return _id 76 | 77 | def add_documentation(self, documentation: str, **kwargs) -> str: 78 | _id = str(uuid.uuid4()) + "-doc" 79 | doc = Document( 80 | page_content=documentation, 81 | metadata={"id": _id}, 82 | ) 83 | self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]]) 84 | return _id 85 | 86 | def get_collection(self, collection_name): 87 | match collection_name: 88 | case "sql": 89 | return self.sql_collection 90 | case "ddl": 91 | return self.ddl_collection 92 | case "documentation": 93 | return self.documentation_collection 94 | case _: 95 | raise ValueError("Specified collection does not exist.") 96 | 97 | def get_similar_question_sql(self, question: str) -> list: 98 | documents = self.sql_collection.similarity_search(query=question, k=self.n_results) 99 | return [ast.literal_eval(document.page_content) for document in documents] 100 | 101 | def get_related_ddl(self, question: str, **kwargs) -> list: 102 | documents = self.ddl_collection.similarity_search(query=question, k=self.n_results) 103 | return [document.page_content for document in documents] 104 | 105 | def get_related_documentation(self, question: str, **kwargs) -> list: 106 | documents = self.documentation_collection.similarity_search(query=question, k=self.n_results) 107 | return [document.page_content for document in documents] 108 | 109 | def train( 110 | self, 111 | question: str | None = None, 112 | sql: str | None = None, 113 | ddl: str | None = None, 114 | documentation: str | None = None, 115 | plan: TrainingPlan | None = None, 116 | createdat: str | None = None, 117 | ): 118 | if question and not sql: 119 | raise ValidationError("Please provide a SQL query.") 120 | 121 | if documentation: 122 | logging.info(f"Adding documentation: {documentation}") 123 | return self.add_documentation(documentation) 124 | 125 | if sql and question: 126 | return self.add_question_sql(question=question, sql=sql, 
createdat=createdat) 127 | 128 | if ddl: 129 | logging.info(f"Adding ddl: {ddl}") 130 | return self.add_ddl(ddl) 131 | 132 | if plan: 133 | for item in plan._plan: 134 | if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL: 135 | self.add_ddl(item.item_value) 136 | elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS: 137 | self.add_documentation(item.item_value) 138 | elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name: 139 | self.add_question_sql(question=item.item_name, sql=item.item_value) 140 | 141 | def get_training_data(self, **kwargs) -> pd.DataFrame: 142 | # Establishing the connection 143 | engine = create_engine(self.connection_string) 144 | 145 | # Querying the 'langchain_pg_embedding' table 146 | query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding" 147 | df_embedding = pd.read_sql(query_embedding, engine) 148 | 149 | # List to accumulate the processed rows 150 | processed_rows = [] 151 | 152 | # Process each row in the DataFrame 153 | for _, row in df_embedding.iterrows(): 154 | custom_id = row["cmetadata"]["id"] 155 | document = row["document"] 156 | training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:] 157 | 158 | if training_data_type == "sql": 159 | # Convert the document string to a dictionary 160 | try: 161 | doc_dict = ast.literal_eval(document) 162 | question = doc_dict.get("question") 163 | content = doc_dict.get("sql") 164 | except (ValueError, SyntaxError): 165 | logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.") 166 | continue 167 | elif training_data_type in ["documentation", "ddl"]: 168 | question = None # Default value for question 169 | content = document 170 | else: 171 | # If the suffix is not recognized, skip this row 172 | logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.") 173 | continue 174 | 175 | # Append the processed data to the list 176 | processed_rows.append( 177 | {"id": 
custom_id, "question": question, "content": content, "training_data_type": training_data_type} 178 | ) 179 | 180 | # Create a DataFrame from the list of processed rows 181 | df_processed = pd.DataFrame(processed_rows) 182 | 183 | return df_processed 184 | 185 | def remove_training_data(self, id: str, **kwargs) -> bool: 186 | # Create the database engine 187 | engine = create_engine(self.connection_string) 188 | 189 | # SQL DELETE statement 190 | delete_statement = text( 191 | """ 192 | DELETE FROM langchain_pg_embedding 193 | WHERE cmetadata ->> 'id' = :id 194 | """ 195 | ) 196 | 197 | # Connect to the database and execute the delete statement 198 | with engine.connect() as connection: 199 | # Start a transaction 200 | with connection.begin() as transaction: 201 | try: 202 | result = connection.execute(delete_statement, {"id": id}) 203 | # Commit the transaction if the delete was successful 204 | transaction.commit() 205 | # Check if any row was deleted and return True or False accordingly 206 | return result.rowcount > 0 207 | except Exception as e: 208 | # Rollback the transaction in case of error 209 | logging.error(f"An error occurred: {e}") 210 | transaction.rollback() 211 | return False 212 | 213 | def remove_collection(self, collection_name: str) -> bool: 214 | engine = create_engine(self.connection_string) 215 | 216 | # Determine the suffix to look for based on the collection name 217 | suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"} 218 | suffix = suffix_map.get(collection_name) 219 | 220 | if not suffix: 221 | logging.info("Invalid collection name. 
Choose from 'ddl', 'sql', or 'documentation'.") 222 | return False 223 | 224 | # SQL query to delete rows based on the condition 225 | query = text( 226 | f""" 227 | DELETE FROM langchain_pg_embedding 228 | WHERE cmetadata->>'id' LIKE '%{suffix}' 229 | """ 230 | ) 231 | 232 | # Execute the deletion within a transaction block 233 | with engine.connect() as connection: 234 | with connection.begin() as transaction: 235 | try: 236 | result = connection.execute(query) 237 | transaction.commit() # Explicitly commit the transaction 238 | if result.rowcount > 0: 239 | logging.info( 240 | f"Deleted {result.rowcount} rows from " 241 | f"langchain_pg_embedding where collection is {collection_name}." 242 | ) 243 | return True 244 | else: 245 | logging.info(f"No rows deleted for collection {collection_name}.") 246 | return False 247 | except Exception as e: 248 | logging.error(f"An error occurred: {e}") 249 | transaction.rollback() # Rollback in case of error 250 | return False 251 | 252 | def generate_embedding(self, *args, **kwargs): 253 | pass 254 | -------------------------------------------------------------------------------- /src/vanna/pinecone/__init__.py: -------------------------------------------------------------------------------- 1 | from .pinecone_vector import PineconeDB_VectorStore 2 | 3 | __all__ = ["PineconeDB_VectorStore"] 4 | -------------------------------------------------------------------------------- /src/vanna/qdrant/__init__.py: -------------------------------------------------------------------------------- 1 | from .qdrant import Qdrant_VectorStore 2 | 3 | __all__ = ["Qdrant_VectorStore"] 4 | -------------------------------------------------------------------------------- /src/vanna/qianfan/Qianfan_Chat.py: -------------------------------------------------------------------------------- 1 | import qianfan 2 | 3 | from ..base import VannaBase 4 | 5 | 6 | class Qianfan_Chat(VannaBase): 7 | def __init__(self, client=None, config=None): 8 | 
class Qianfan_Chat(VannaBase):
    def __init__(self, client=None, config=None):
        """Baidu Qianfan (ERNIE) chat LLM.

        Args:
            client: A pre-configured ``qianfan.ChatCompletion`` client.
            config: Dict requiring ``api_key`` and ``secret_key``; optional
                ``temperature`` (default 0.9), ``max_tokens`` (default 1024),
                and ``model`` (default "ERNIE-Speed").

        Raises:
            Exception: if api_key or secret_key is missing from config.
        """
        VannaBase.__init__(self, config=config)

        if "api_key" not in config:
            raise Exception("Missing api_key in config")
        self.api_key = config["api_key"]

        if "secret_key" not in config:
            raise Exception("Missing secret_key in config")
        self.secret_key = config["secret_key"]

        # default parameters - can be overridden using config
        self.temperature = 0.9
        self.max_tokens = 1024

        if "temperature" in config:
            self.temperature = config["temperature"]

        if "max_tokens" in config:
            self.max_tokens = config["max_tokens"]

        self.model = config["model"] if "model" in config else "ERNIE-Speed"

        if client is not None:
            self.client = client
            return

        self.client = qianfan.ChatCompletion(ak=self.api_key,
                                             sk=self.secret_key)

    def system_message(self, message: str) -> any:
        """Wrap text as a 'system' chat message dict."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """Wrap text as a 'user' chat message dict."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """Wrap text as an 'assistant' chat message dict."""
        return {"role": "assistant", "content": message}

    def get_sql_prompt(
        self,
        initial_prompt: str,
        question: str,
        question_sql_list: list,
        ddl_list: list,
        doc_list: list,
        **kwargs,
    ):
        """
        Example:
        ```python
        vn.get_sql_prompt(
            question="What are the top 10 customers by sales?",
            question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
            ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
            doc_list=["The customers table contains information about customers and their sales."],
        )

        ```

        This method is used to generate a prompt for the LLM to generate SQL.

        Args:
            question (str): The question to generate SQL for.
            question_sql_list (list): A list of questions and their corresponding SQL statements.
            ddl_list (list): A list of DDL statements.
            doc_list (list): A list of documentation.

        Returns:
            any: The prompt for the LLM to generate SQL.
        """

        if initial_prompt is None:
            initial_prompt = f"You are a {self.dialect} expert. " + \
                "Please help to generate a SQL to answer the question based on some context.Please don't give any explanation for your answer. Just only generate a SQL \n"

        initial_prompt = self.add_ddl_to_prompt(
            initial_prompt, ddl_list, max_tokens=self.max_tokens
        )

        if self.static_documentation != "":
            doc_list.append(self.static_documentation)

        initial_prompt = self.add_documentation_to_prompt(
            initial_prompt, doc_list, max_tokens=self.max_tokens
        )
        message_log = []

        if question_sql_list is None or len(question_sql_list) == 0:
            initial_prompt = initial_prompt + f"question: {question}"
            message_log.append(self.user_message(initial_prompt))
        else:
            # BUG FIX: the original iterated `for i, example in question_sql_list`,
            # which tried to unpack each example itself — enumerate was missing.
            for i, example in enumerate(question_sql_list):
                if example is None:
                    print("example is None")
                else:
                    if "question" in example and "sql" in example:
                        if i == 0:
                            # The first example is folded into the initial prompt.
                            initial_prompt = initial_prompt + f"question: {example['question']}"
                            message_log.append(self.user_message(initial_prompt))
                        else:
                            message_log.append(self.user_message(example["question"]))
                        message_log.append(self.assistant_message(example["sql"]))

        message_log.append(self.user_message(question))
        return message_log

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a message log to Qianfan and return the result text.

        Model precedence: kwargs["model"] > config["model"] > an ERNIE-Speed
        default chosen by approximate prompt size.

        Raises:
            Exception: if prompt is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log.
        # Use 4 as an approximation for the number of characters per token.
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.do(
                # BUG FIX: the kwargs-supplied model was previously ignored
                # (self.model was passed instead).
                model=model,
                messages=prompt,
                max_output_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
            response = self.client.do(
                model=self.config.get("model"),
                messages=prompt,
                max_output_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        else:
            # No model configured anywhere: pick by approximate prompt size.
            if num_tokens > 3500:
                model = "ERNIE-Speed-128K"
            else:
                model = "ERNIE-Speed-8K"

            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.do(
                model=model,
                messages=prompt,
                max_output_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )

        return response.body.get("result")
class Qianfan_Embeddings(VannaBase):
    def __init__(self, client=None, config=None):
        """Baidu Qianfan embedding generator.

        Args:
            client: A pre-configured ``qianfan.Embedding`` client; used as-is.
            config: Dict requiring ``api_key`` and ``secret_key`` when no
                client is supplied; optional ``model`` overrides the default.

        Raises:
            Exception: if api_key or secret_key is missing from config.
        """
        VannaBase.__init__(self, config=config)

        if client is not None:
            self.client = client
            return

        if "api_key" not in config:
            raise Exception("Missing api_key in config")
        self.api_key = config["api_key"]

        if "secret_key" not in config:
            raise Exception("Missing secret_key in config")
        self.secret_key = config["secret_key"]

        self.client = qianfan.Embedding(ak=self.api_key, sk=self.secret_key)

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        """Embed `data`, using config["model"] if set, else "bge-large-zh"."""
        model_name = "bge-large-zh"
        if self.config is not None and "model" in self.config:
            model_name = self.config["model"]

        result = self.client.do(
            model=model_name,
            input=[data],
        )
        return result.get("data")[0]["embedding"]
class QianWenAI_Chat(VannaBase):
    def __init__(self, client=None, config=None):
        """Alibaba Qianwen chat LLM via the OpenAI-compatible DashScope endpoint.

        Args:
            client: A pre-configured ``openai.OpenAI`` client; used as-is.
            config: Optional dict. Recognized keys: ``temperature`` (default
                0.7), ``api_key``, and ``base_url`` (default DashScope
                compatible-mode URL). Legacy ``api_type`` / ``api_base`` /
                ``api_version`` keys are rejected.

        Raises:
            Exception: if any deprecated ``api_*`` key is present in config.
        """
        VannaBase.__init__(self, config=config)

        if client is not None:
            self.client = client
            # Still validate config below? No: match the original, which
            # returned after ignoring config when a client was provided —
            # except that it crashed on `"temperature" in None`. Guard first.

        # Normalize so membership checks are safe when config is None.
        config = config or {}

        # default parameters - can be overridden using config
        self.temperature = config.get("temperature", 0.7)

        # These messages are identical per-key to the original literals.
        for deprecated_key in ("api_type", "api_base", "api_version"):
            if deprecated_key in config:
                raise Exception(
                    f"Passing {deprecated_key} is now deprecated. Please pass an OpenAI client instead."
                )

        if client is not None:
            return

        if "api_key" in config:
            self.client = OpenAI(
                api_key=config["api_key"],
                base_url=config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
            )
        else:
            # No explicit key: fall back to the OPENAI_API_KEY environment
            # variable. (The original left self.client unset when a non-empty
            # config lacked "api_key", causing AttributeError on first use.)
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def system_message(self, message: str) -> any:
        """Wrap text as a 'system' chat message dict."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """Wrap text as a 'user' chat message dict."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """Wrap text as an 'assistant' chat message dict."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a message log to the chat-completions API and return the text.

        Model precedence: kwargs["model"] > kwargs["engine"] >
        config["engine"] > config["model"] > a qwen default chosen by
        approximate prompt size.

        Raises:
            Exception: if prompt is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log.
        # Use 4 as an approximation for the number of characters per token.
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            print(f"Using model {engine} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                engine=engine,
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "engine" in self.config:
            print(
                f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=self.config["engine"],
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=self.config["model"],
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        else:
            # No model configured anywhere: pick by approximate prompt size.
            if num_tokens > 3500:
                model = "qwen-long"
            else:
                model = "qwen-plus"

            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )

        # Prefer a choice carrying legacy completion-style `text`. The v1 SDK
        # response objects do not support `"text" in choice` (the original
        # raised TypeError there), so probe with getattr instead.
        for choice in response.choices:
            text = getattr(choice, "text", None)
            if text is not None:
                return text

        # Otherwise return the first choice's chat content (may be empty).
        return response.choices[0].message.content
class QianWenAI_Embeddings(VannaBase):
    def __init__(self, client=None, config=None):
        """Qianwen embedding generator built on an OpenAI-compatible client.

        Args:
            client: A pre-configured client; used as-is if given.
            config: Optional dict with ``api_type`` / ``api_base`` /
                ``api_version`` / ``api_key`` overrides applied to the client.
        """
        VannaBase.__init__(self, config=config)

        if client is not None:
            self.client = client
            return

        # The original read `self.client` unconditionally here, which raised
        # AttributeError because nothing had assigned it yet; probe safely so
        # a client set by a cooperating base/mixin class is still honored.
        if getattr(self, "client", None) is not None:
            return

        self.client = OpenAI()

        if config is None:
            return

        if "api_type" in config:
            self.client.api_type = config["api_type"]

        if "api_base" in config:
            self.client.api_base = config["api_base"]

        if "api_version" in config:
            self.client.api_version = config["api_version"]

        if "api_key" in config:
            self.client.api_key = config["api_key"]

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        """Embed `data` using config["engine"] if set, else "bge-large-zh".

        NOTE(review): the `embedding.get("data")[0]["embedding"]` access assumes
        a dict-style (pre-v1 SDK) response — confirm against the pinned openai
        version before changing.
        """
        if self.config is not None and "engine" in self.config:
            embedding = self.client.embeddings.create(
                engine=self.config["engine"],
                input=data,
            )
        else:
            embedding = self.client.embeddings.create(
                model="bge-large-zh",
                input=data,
            )

        return embedding.get("data")[0]["embedding"]
VannaBase.__init__(self, config=config) 43 | VannaDB_VectorStore.__init__(self, vanna_model=model, vanna_api_key=api_key, config=config) 44 | 45 | self._model = model 46 | self._api_key = api_key 47 | 48 | self._endpoint = ( 49 | "https://ask.vanna.ai/rpc" 50 | if config is None or "endpoint" not in config 51 | else config["endpoint"] 52 | ) 53 | 54 | def system_message(self, message: str) -> any: 55 | return {"role": "system", "content": message} 56 | 57 | def user_message(self, message: str) -> any: 58 | return {"role": "user", "content": message} 59 | 60 | def assistant_message(self, message: str) -> any: 61 | return {"role": "assistant", "content": message} 62 | 63 | def submit_prompt(self, prompt, **kwargs) -> str: 64 | # JSON-ify the prompt 65 | json_prompt = json.dumps(prompt, ensure_ascii=False) 66 | 67 | params = [StringData(data=json_prompt)] 68 | 69 | d = self._rpc_call(method="submit_prompt", params=params) 70 | 71 | if "result" not in d: 72 | return None 73 | 74 | # Load the result into a dataclass 75 | results = StringData(**d["result"]) 76 | 77 | return results.data 78 | -------------------------------------------------------------------------------- /src/vanna/types/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import Dict, List, Union 5 | 6 | 7 | @dataclass 8 | class Status: 9 | success: bool 10 | message: str 11 | 12 | 13 | @dataclass 14 | class StatusWithId: 15 | success: bool 16 | message: str 17 | id: str 18 | 19 | 20 | @dataclass 21 | class QuestionList: 22 | questions: List[FullQuestionDocument] 23 | 24 | 25 | @dataclass 26 | class FullQuestionDocument: 27 | id: QuestionId 28 | question: Question 29 | answer: SQLAnswer | None 30 | data: DataResult | None 31 | plotly: PlotlyResult | None 32 | 33 | 34 | @dataclass 35 | class QuestionSQLPair: 36 | question: str 37 | sql: str 38 | tag: Union[str, None] 
39 | 40 | 41 | @dataclass 42 | class Organization: 43 | name: str 44 | user: str | None 45 | connection: Connection | None 46 | 47 | 48 | @dataclass 49 | class OrganizationList: 50 | organizations: List[str] 51 | 52 | 53 | @dataclass 54 | class QuestionStringList: 55 | questions: List[str] 56 | 57 | 58 | @dataclass 59 | class Visibility: 60 | visibility: bool 61 | 62 | 63 | @dataclass 64 | class UserEmail: 65 | email: str 66 | 67 | 68 | @dataclass 69 | class NewOrganization: 70 | org_name: str 71 | db_type: str 72 | 73 | 74 | @dataclass 75 | class NewOrganizationMember: 76 | org_name: str 77 | email: str 78 | is_admin: bool 79 | 80 | 81 | @dataclass 82 | class UserOTP: 83 | email: str 84 | otp: str 85 | 86 | 87 | @dataclass 88 | class ApiKey: 89 | key: str 90 | 91 | 92 | @dataclass 93 | class QuestionId: 94 | id: str 95 | 96 | 97 | @dataclass 98 | class Question: 99 | question: str 100 | 101 | 102 | @dataclass 103 | class QuestionCategory: 104 | question: str 105 | category: str 106 | 107 | NO_SQL_GENERATED = "No SQL Generated" 108 | SQL_UNABLE_TO_RUN = "SQL Unable to Run" 109 | BOOTSTRAP_TRAINING_QUERY = "Bootstrap Training Query" 110 | SQL_RAN = "SQL Ran Successfully" 111 | FLAGGED_FOR_REVIEW = "Flagged for Review" 112 | REVIEWED_AND_APPROVED = "Reviewed and Approved" 113 | REVIEWED_AND_REJECTED = "Reviewed and Rejected" 114 | REVIEWED_AND_UPDATED = "Reviewed and Updated" 115 | 116 | 117 | @dataclass 118 | class AccuracyStats: 119 | num_questions: int 120 | data: Dict[str, int] 121 | 122 | 123 | @dataclass 124 | class Followup: 125 | followup: str 126 | 127 | 128 | @dataclass 129 | class QuestionEmbedding: 130 | question: Question 131 | embedding: List[float] 132 | 133 | 134 | @dataclass 135 | class Connection: 136 | # TODO: implement 137 | pass 138 | 139 | 140 | @dataclass 141 | class SQLAnswer: 142 | raw_answer: str 143 | prefix: str 144 | postfix: str 145 | sql: str 146 | 147 | 148 | @dataclass 149 | class Explanation: 150 | explanation: str 151 | 152 | 153 | 
@dataclass
class TrainingPlanItem:
    """One unit of training data: a SQL pair, a DDL statement, or an
    information-schema entry, distinguished by ``item_type``."""

    item_type: str
    item_group: str
    item_name: str
    item_value: str

    def __str__(self):
        # Map each known item type to its display label; unknown types
        # fall through to None, matching prior behavior.
        labels = {
            self.ITEM_TYPE_SQL: "Train on SQL",
            self.ITEM_TYPE_DDL: "Train on DDL",
            self.ITEM_TYPE_IS: "Train on Information Schema",
        }
        label = labels.get(self.item_type)
        if label is None:
            return None
        return f"{label}: {self.item_group} {self.item_name}"

    ITEM_TYPE_SQL = "sql"
    ITEM_TYPE_DDL = "ddl"
    ITEM_TYPE_IS = "is"


class TrainingPlan:
    """A viewable, editable plan of training items.

    Inspect the plan with ``get_summary()`` and drop unwanted entries with
    ``remove_item()`` before training.

    **Example:**
    ```python
    plan = vn.get_training_plan()

    plan.get_summary()
    ```

    """

    _plan: List[TrainingPlanItem]

    def __init__(self, plan: List[TrainingPlanItem]):
        self._plan = plan

    def __str__(self):
        return "\n".join(self.get_summary())

    def __repr__(self):
        return self.__str__()

    def get_summary(self) -> List[str]:
        """Return one descriptive line per item in the plan.

        **Example:**
        ```python
        plan = vn.get_training_plan()

        plan.get_summary()
        ```

        Returns:
            List[str]: A list of strings describing the training plan.
        """
        return [f"{entry}" for entry in self._plan]

    def remove_item(self, item: str):
        """Remove the first plan entry whose string form equals ``item``.

        **Example:**
        ```python
        plan = vn.get_training_plan()

        plan.remove_item("Train on SQL: What is the average salary of employees?")
        ```

        Args:
            item (str): The summary string of the entry to remove.
        """
        for index, entry in enumerate(self._plan):
            if str(entry) == item:
                del self._plan[index]
                break
def sanitize_model_name(model_name):
    """Normalize a model name into a lowercase, separator-trimmed identifier.

    Steps:
      1. Lowercase and replace spaces with hyphens.
      2. If the name contains hyphens, collapse runs of hyphens; if it then
         also contains underscores, convert all underscores to hyphens.
      3. Drop any character other than alphanumerics, '-' and '_'.
      4. Strip ALL leading/trailing hyphens/underscores (the original code
         removed at most one character from each end, and raised an
         opaque IndexError when sanitization emptied the string).

    Args:
        model_name (str): Raw model name.

    Returns:
        str: The sanitized model name.

    Raises:
        ValidationError: If the input is not a string, or if nothing
            remains after sanitization.
    """
    try:
        model_name = model_name.lower()

        # Replace spaces with a hyphen
        model_name = model_name.replace(" ", "-")

        if '-' in model_name:
            # Collapse runs of consecutive hyphens into a single hyphen
            model_name = re.sub(r"-+", "-", model_name)
            if '_' in model_name:
                # Name mixes underscores and hyphens: standardize on hyphens
                model_name = re.sub(r'_', '-', model_name)

        # Remove special characters; only alphanumerics, '-' and '_' survive
        model_name = re.sub(r"[^a-zA-Z0-9-_]", "", model_name)

        # Strip every leading/trailing separator, not just one each
        model_name = model_name.strip("-_")

        if not model_name:
            # Preserve the original failure mode (ValidationError) but with
            # a clear message instead of a wrapped IndexError.
            raise ValueError("Model name is empty after sanitization")

        return model_name
    except Exception as e:
        raise ValidationError(e)
64 | """ 65 | if isinstance(content, str): 66 | content_bytes = content.encode("utf-8") 67 | elif isinstance(content, bytes): 68 | content_bytes = content 69 | else: 70 | raise ValueError(f"Content type {type(content)} not supported !") 71 | 72 | hash_object = hashlib.sha256(content_bytes) 73 | hash_hex = hash_object.hexdigest() 74 | namespace = uuid.UUID("00000000-0000-0000-0000-000000000000") 75 | content_uuid = str(uuid.uuid5(namespace, hash_hex)) 76 | 77 | return content_uuid 78 | -------------------------------------------------------------------------------- /src/vanna/vannadb/__init__.py: -------------------------------------------------------------------------------- 1 | from .vannadb_vector import VannaDB_VectorStore 2 | -------------------------------------------------------------------------------- /src/vanna/vllm/__init__.py: -------------------------------------------------------------------------------- 1 | from .vllm import Vllm 2 | -------------------------------------------------------------------------------- /src/vanna/vllm/vllm.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import requests 4 | 5 | from ..base import VannaBase 6 | 7 | 8 | class Vllm(VannaBase): 9 | def __init__(self, config=None): 10 | if config is None or "vllm_host" not in config: 11 | self.host = "http://localhost:8000" 12 | else: 13 | self.host = config["vllm_host"] 14 | 15 | if config is None or "model" not in config: 16 | raise ValueError("check the config for vllm") 17 | else: 18 | self.model = config["model"] 19 | 20 | if "auth-key" in config: 21 | self.auth_key = config["auth-key"] 22 | else: 23 | self.auth_key = None 24 | 25 | if "temperature" in config: 26 | self.temperature = config["temperature"] 27 | else: 28 | # default temperature - can be overrided using config 29 | self.temperature = 0.7 30 | 31 | def system_message(self, message: str) -> any: 32 | return {"role": "system", "content": message} 33 | 34 | 
def user_message(self, message: str) -> any: 35 | return {"role": "user", "content": message} 36 | 37 | def assistant_message(self, message: str) -> any: 38 | return {"role": "assistant", "content": message} 39 | 40 | def extract_sql_query(self, text): 41 | """ 42 | Extracts the first SQL statement after the word 'select', ignoring case, 43 | matches until the first semicolon, three backticks, or the end of the string, 44 | and removes three backticks if they exist in the extracted string. 45 | 46 | Args: 47 | - text (str): The string to search within for an SQL statement. 48 | 49 | Returns: 50 | - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found. 51 | """ 52 | # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string 53 | pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL) 54 | 55 | match = pattern.search(text) 56 | if match: 57 | # Remove three backticks from the matched string if they exist 58 | return match.group(0).replace("```", "") 59 | else: 60 | return text 61 | 62 | def generate_sql(self, question: str, **kwargs) -> str: 63 | # Use the super generate_sql 64 | sql = super().generate_sql(question, **kwargs) 65 | 66 | # Replace "\_" with "_" 67 | sql = sql.replace("\\_", "_") 68 | 69 | sql = sql.replace("\\", "") 70 | 71 | return self.extract_sql_query(sql) 72 | 73 | def submit_prompt(self, prompt, **kwargs) -> str: 74 | url = f"{self.host}/v1/chat/completions" 75 | data = { 76 | "model": self.model, 77 | "temperature": self.temperature, 78 | "stream": False, 79 | "messages": prompt, 80 | } 81 | 82 | if self.auth_key is not None: 83 | headers = { 84 | 'Content-Type': 'application/json', 85 | 'Authorization': f'Bearer {self.auth_key}' 86 | } 87 | 88 | response = requests.post(url, headers=headers,json=data) 89 | 90 | 91 | else: 92 | response = requests.post(url, json=data) 93 | 94 | response_dict = response.json() 95 | 96 | 
self.log(response.text) 97 | 98 | return response_dict['choices'][0]['message']['content'] 99 | -------------------------------------------------------------------------------- /src/vanna/weaviate/__init__.py: -------------------------------------------------------------------------------- 1 | from .weaviate_vector import WeaviateDatabase 2 | -------------------------------------------------------------------------------- /src/vanna/weaviate/weaviate_vector.py: -------------------------------------------------------------------------------- 1 | import weaviate 2 | import weaviate.classes as wvc 3 | from fastembed import TextEmbedding 4 | 5 | from vanna.base import VannaBase 6 | 7 | 8 | class WeaviateDatabase(VannaBase): 9 | 10 | def __init__(self, config=None): 11 | """ 12 | Initialize the VannaEnhanced class with the provided configuration. 13 | 14 | :param config: Dictionary containing configuration parameters. 15 | 16 | params: 17 | weaviate_url (str): Weaviate cluster URL while using weaviate cloud, 18 | weaviate_api_key (str): Weaviate API key while using weaviate cloud, 19 | weaviate_port (num): Weaviate port while using local weaviate, 20 | weaviate_grpc (num): Weaviate gRPC port while using local weaviate, 21 | fastembed_model (str): Fastembed model name for text embeddings. BAAI/bge-small-en-v1.5 by default. 
22 | 23 | """ 24 | super().__init__(config=config) 25 | 26 | if config is None: 27 | raise ValueError("config is required") 28 | 29 | self.n_results = config.get("n_results", 3) 30 | self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5") 31 | self.weaviate_api_key = config.get("weaviate_api_key") 32 | self.weaviate_url = config.get("weaviate_url") 33 | self.weaviate_port = config.get("weaviate_port") 34 | self.weaviate_grpc_port = config.get("weaviate_grpc", 50051) 35 | 36 | if not self.weaviate_api_key and not self.weaviate_port: 37 | raise ValueError("Add proper credentials to connect to weaviate") 38 | 39 | self.weaviate_client = self._initialize_weaviate_client() 40 | self.embeddings = TextEmbedding(model_name=self.fastembed_model) 41 | 42 | self.training_data_cluster = { 43 | "sql": "SQLTrainingDataEntry", 44 | "ddl": "DDLEntry", 45 | "doc": "DocumentationEntry" 46 | } 47 | 48 | self._create_collections_if_not_exist() 49 | 50 | def _create_collections_if_not_exist(self): 51 | properties_dict = { 52 | self.training_data_cluster['ddl']: [ 53 | wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT), 54 | ], 55 | self.training_data_cluster['doc']: [ 56 | wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT), 57 | ], 58 | self.training_data_cluster['sql']: [ 59 | wvc.config.Property(name="sql", data_type=wvc.config.DataType.TEXT), 60 | wvc.config.Property(name="natural_language_question", data_type=wvc.config.DataType.TEXT), 61 | ] 62 | } 63 | 64 | for cluster, properties in properties_dict.items(): 65 | if not self.weaviate_client.collections.exists(cluster): 66 | self.weaviate_client.collections.create( 67 | name=cluster, 68 | properties=properties 69 | ) 70 | 71 | def _initialize_weaviate_client(self): 72 | if self.weaviate_api_key: 73 | return weaviate.connect_to_wcs( 74 | cluster_url=self.weaviate_url, 75 | auth_credentials=weaviate.auth.AuthApiKey(self.weaviate_api_key), 76 | 
additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)), 77 | skip_init_checks=True 78 | ) 79 | else: 80 | return weaviate.connect_to_local( 81 | port=self.weaviate_port, 82 | grpc_port=self.weaviate_grpc_port, 83 | additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)), 84 | skip_init_checks=True 85 | ) 86 | 87 | def generate_embedding(self, data: str, **kwargs): 88 | embedding_model = TextEmbedding(model_name=self.fastembed_model) 89 | embedding = next(embedding_model.embed(data)) 90 | return embedding.tolist() 91 | 92 | 93 | def _insert_data(self, cluster_key: str, data_object: dict, vector: list) -> str: 94 | self.weaviate_client.connect() 95 | response = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key]).data.insert( 96 | properties=data_object, 97 | vector=vector 98 | ) 99 | self.weaviate_client.close() 100 | return response 101 | 102 | def add_ddl(self, ddl: str, **kwargs) -> str: 103 | data_object = { 104 | "description": ddl, 105 | } 106 | response = self._insert_data('ddl', data_object, self.generate_embedding(ddl)) 107 | return f'{response}-ddl' 108 | 109 | def add_documentation(self, doc: str, **kwargs) -> str: 110 | data_object = { 111 | "description": doc, 112 | } 113 | response = self._insert_data('doc', data_object, self.generate_embedding(doc)) 114 | return f'{response}-doc' 115 | 116 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str: 117 | data_object = { 118 | "sql": sql, 119 | "natural_language_question": question, 120 | } 121 | response = self._insert_data('sql', data_object, self.generate_embedding(question)) 122 | return f'{response}-sql' 123 | 124 | def _query_collection(self, cluster_key: str, vector_input: list, return_properties: list) -> list: 125 | self.weaviate_client.connect() 126 | collection = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key]) 127 | response = collection.query.near_vector( 128 | near_vector=vector_input, 129 | 
limit=self.n_results, 130 | return_properties=return_properties 131 | ) 132 | response_list = [item.properties for item in response.objects] 133 | self.weaviate_client.close() 134 | return response_list 135 | 136 | def get_related_ddl(self, question: str, **kwargs) -> list: 137 | vector_input = self.generate_embedding(question) 138 | response_list = self._query_collection('ddl', vector_input, ["description"]) 139 | return [item["description"] for item in response_list] 140 | 141 | def get_related_documentation(self, question: str, **kwargs) -> list: 142 | vector_input = self.generate_embedding(question) 143 | response_list = self._query_collection('doc', vector_input, ["description"]) 144 | return [item["description"] for item in response_list] 145 | 146 | def get_similar_question_sql(self, question: str, **kwargs) -> list: 147 | vector_input = self.generate_embedding(question) 148 | response_list = self._query_collection('sql', vector_input, ["sql", "natural_language_question"]) 149 | return [{"question": item["natural_language_question"], "sql": item["sql"]} for item in response_list] 150 | 151 | def get_training_data(self, **kwargs) -> list: 152 | self.weaviate_client.connect() 153 | combined_response_list = [] 154 | for collection_name in self.training_data_cluster.values(): 155 | if self.weaviate_client.collections.exists(collection_name): 156 | collection = self.weaviate_client.collections.get(collection_name) 157 | response_list = [item.properties for item in collection.iterator()] 158 | combined_response_list.extend(response_list) 159 | self.weaviate_client.close() 160 | return combined_response_list 161 | 162 | def remove_training_data(self, id: str, **kwargs) -> bool: 163 | self.weaviate_client.connect() 164 | success = False 165 | if id.endswith("-sql"): 166 | id = id.replace('-sql', '') 167 | success = self.weaviate_client.collections.get(self.training_data_cluster['sql']).data.delete_by_id(id) 168 | elif id.endswith("-ddl"): 169 | id = 
id.replace('-ddl', '') 170 | success = self.weaviate_client.collections.get(self.training_data_cluster['ddl']).data.delete_by_id(id) 171 | elif id.endswith("-doc"): 172 | id = id.replace('-doc', '') 173 | success = self.weaviate_client.collections.get(self.training_data_cluster['doc']).data.delete_by_id(id) 174 | self.weaviate_client.close() 175 | return success 176 | -------------------------------------------------------------------------------- /src/vanna/xinference/__init__.py: -------------------------------------------------------------------------------- 1 | from .xinference import Xinference 2 | -------------------------------------------------------------------------------- /src/vanna/xinference/xinference.py: -------------------------------------------------------------------------------- 1 | from xinference_client.client.restful.restful_client import ( 2 | Client, 3 | RESTfulChatModelHandle, 4 | ) 5 | 6 | from ..base import VannaBase 7 | 8 | 9 | class Xinference(VannaBase): 10 | def __init__(self, config=None): 11 | VannaBase.__init__(self, config=config) 12 | 13 | if not config or "base_url" not in config: 14 | raise ValueError("config must contain at least Xinference base_url") 15 | 16 | base_url = config["base_url"] 17 | api_key = config.get("api_key", "not empty") 18 | self.xinference_client = Client(base_url=base_url, api_key=api_key) 19 | 20 | def system_message(self, message: str) -> any: 21 | return {"role": "system", "content": message} 22 | 23 | def user_message(self, message: str) -> any: 24 | return {"role": "user", "content": message} 25 | 26 | def assistant_message(self, message: str) -> any: 27 | return {"role": "assistant", "content": message} 28 | 29 | def submit_prompt(self, prompt, **kwargs) -> str: 30 | if prompt is None: 31 | raise Exception("Prompt is None") 32 | 33 | if len(prompt) == 0: 34 | raise Exception("Prompt is empty") 35 | 36 | num_tokens = 0 37 | for message in prompt: 38 | num_tokens += len(message["content"]) / 4 39 | 
40 | model_uid = kwargs.get("model_uid") or self.config.get("model_uid", None) 41 | if model_uid is None: 42 | raise ValueError("model_uid is required") 43 | 44 | xinference_model = self.xinference_client.get_model(model_uid) 45 | if isinstance(xinference_model, RESTfulChatModelHandle): 46 | print( 47 | f"Using model_uid {model_uid} for {num_tokens} tokens (approx)" 48 | ) 49 | 50 | response = xinference_model.chat(prompt) 51 | return response["choices"][0]["message"]["content"] 52 | else: 53 | raise NotImplementedError(f"Xinference model handle type {type(xinference_model)} is not supported, required RESTfulChatModelHandle") 54 | -------------------------------------------------------------------------------- /tests/test_imports.py: -------------------------------------------------------------------------------- 1 | def test_regular_imports(): 2 | from vanna.anthropic.anthropic_chat import Anthropic_Chat 3 | from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore 4 | from vanna.base.base import VannaBase 5 | from vanna.bedrock.bedrock_converse import Bedrock_Converse 6 | from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore 7 | from vanna.cohere.cohere_chat import Cohere_Chat 8 | from vanna.cohere.cohere_embeddings import Cohere_Embeddings 9 | from vanna.faiss.faiss import FAISS 10 | from vanna.google.bigquery_vector import BigQuery_VectorStore 11 | from vanna.google.gemini_chat import GoogleGeminiChat 12 | from vanna.hf.hf import Hf 13 | from vanna.local import LocalContext_OpenAI 14 | from vanna.marqo.marqo import Marqo_VectorStore 15 | from vanna.milvus.milvus_vector import Milvus_VectorStore 16 | from vanna.mistral.mistral import Mistral 17 | from vanna.ollama.ollama import Ollama 18 | from vanna.openai.openai_chat import OpenAI_Chat 19 | from vanna.openai.openai_embeddings import OpenAI_Embeddings 20 | from vanna.opensearch.opensearch_vector import OpenSearch_VectorStore 21 | from vanna.opensearch.opensearch_vector_semantic import 
( 22 | OpenSearch_Semantic_VectorStore, 23 | ) 24 | from vanna.pgvector.pgvector import PG_VectorStore 25 | from vanna.pinecone.pinecone_vector import PineconeDB_VectorStore 26 | from vanna.qdrant.qdrant import Qdrant_VectorStore 27 | from vanna.qianfan.Qianfan_Chat import Qianfan_Chat 28 | from vanna.qianfan.Qianfan_embeddings import Qianfan_Embeddings 29 | from vanna.qianwen.QianwenAI_chat import QianWenAI_Chat 30 | from vanna.qianwen.QianwenAI_embeddings import QianWenAI_Embeddings 31 | from vanna.remote import VannaDefault 32 | from vanna.vannadb.vannadb_vector import VannaDB_VectorStore 33 | from vanna.weaviate.weaviate_vector import WeaviateDatabase 34 | from vanna.xinference.xinference import Xinference 35 | from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat 36 | from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings 37 | 38 | def test_shortcut_imports(): 39 | from vanna.anthropic import Anthropic_Chat 40 | from vanna.azuresearch import AzureAISearch_VectorStore 41 | from vanna.base import VannaBase 42 | from vanna.chromadb import ChromaDB_VectorStore 43 | from vanna.cohere import Cohere_Chat, Cohere_Embeddings 44 | from vanna.faiss import FAISS 45 | from vanna.hf import Hf 46 | from vanna.marqo import Marqo_VectorStore 47 | from vanna.milvus import Milvus_VectorStore 48 | from vanna.mistral import Mistral 49 | from vanna.ollama import Ollama 50 | from vanna.openai import OpenAI_Chat, OpenAI_Embeddings 51 | from vanna.opensearch import ( 52 | OpenSearch_Semantic_VectorStore, 53 | OpenSearch_VectorStore, 54 | ) 55 | from vanna.pgvector import PG_VectorStore 56 | from vanna.pinecone import PineconeDB_VectorStore 57 | from vanna.qdrant import Qdrant_VectorStore 58 | from vanna.qianfan import Qianfan_Chat, Qianfan_Embeddings 59 | from vanna.qianwen import QianWenAI_Chat, QianWenAI_Embeddings 60 | from vanna.vannadb import VannaDB_VectorStore 61 | from vanna.vllm import Vllm 62 | from vanna.weaviate import WeaviateDatabase 63 | from vanna.xinference 
import Xinference 64 | from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings 65 | -------------------------------------------------------------------------------- /tests/test_instantiation.py: -------------------------------------------------------------------------------- 1 | from vanna.mock import MockEmbedding, MockLLM, MockVectorDB 2 | -------------------------------------------------------------------------------- /tests/test_pgvector.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dotenv import load_dotenv 4 | 5 | # from vanna.pgvector import PG_VectorStore 6 | # from vanna.openai import OpenAI_Chat 7 | 8 | # assume .env file placed next to file with provided env vars 9 | load_dotenv() 10 | 11 | # def get_vanna_connection_string(): 12 | # server = os.environ.get("PG_SERVER") 13 | # driver = "psycopg" 14 | # port = os.environ.get("PG_PORT", 5432) 15 | # database = os.environ.get("PG_DATABASE") 16 | # username = os.environ.get("PG_USERNAME") 17 | # password = os.environ.get("PG_PASSWORD") 18 | 19 | # def test_pgvector_e2e(): 20 | # # configure Vanna to use OpenAI and PGVector 21 | # class VannaCustom(PG_VectorStore, OpenAI_Chat): 22 | # def __init__(self, config=None): 23 | # PG_VectorStore.__init__(self, config=config) 24 | # OpenAI_Chat.__init__(self, config=config) 25 | 26 | # vn = VannaCustom(config={ 27 | # 'api_key': os.environ['OPENAI_API_KEY'], 28 | # 'model': 'gpt-3.5-turbo', 29 | # "connection_string": get_vanna_connection_string(), 30 | # }) 31 | 32 | # # connect to SQLite database 33 | # vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 34 | 35 | # # train Vanna on DDLs 36 | # df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") 37 | # for ddl in df_ddl['sql'].to_list(): 38 | # vn.train(ddl=ddl) 39 | # assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default 40 | 41 | # question = "What are the top 7 
customers by sales?" 42 | # sql = vn.generate_sql(question) 43 | # df = vn.run_sql(sql) 44 | # assert len(df) == 7 45 | 46 | # # test if Vanna can generate an answer 47 | # answer = vn.ask(question) 48 | # assert answer is not None 49 | 50 | -------------------------------------------------------------------------------- /tests/test_vanna.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from vanna.anthropic.anthropic_chat import Anthropic_Chat 4 | from vanna.cohere.cohere_chat import Cohere_Chat 5 | from vanna.google import GoogleGeminiChat 6 | from vanna.mistral.mistral import Mistral 7 | from vanna.openai.openai_chat import OpenAI_Chat 8 | from vanna.remote import VannaDefault 9 | from vanna.vannadb.vannadb_vector import VannaDB_VectorStore 10 | 11 | try: 12 | print("Trying to load .env") 13 | from dotenv import load_dotenv 14 | load_dotenv() 15 | except Exception as e: 16 | print(f"Failed to load .env {e}") 17 | pass 18 | 19 | MY_VANNA_MODEL = 'chinook' 20 | ANTHROPIC_Model = 'claude-3-sonnet-20240229' 21 | MY_VANNA_API_KEY = os.environ['VANNA_API_KEY'] 22 | OPENAI_API_KEY = os.environ['OPENAI_API_KEY'] 23 | MISTRAL_API_KEY = os.environ['MISTRAL_API_KEY'] 24 | ANTHROPIC_API_KEY = os.environ['ANTHROPIC_API_KEY'] 25 | SNOWFLAKE_ACCOUNT = os.environ['SNOWFLAKE_ACCOUNT'] 26 | SNOWFLAKE_USERNAME = os.environ['SNOWFLAKE_USERNAME'] 27 | SNOWFLAKE_PASSWORD = os.environ['SNOWFLAKE_PASSWORD'] 28 | # AZURE_SEARCH_API_KEY = os.environ['AZURE_SEARCH_API_KEY'] 29 | 30 | class VannaOpenAI(VannaDB_VectorStore, OpenAI_Chat): 31 | def __init__(self, config=None): 32 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) 33 | OpenAI_Chat.__init__(self, config=config) 34 | 35 | vn_openai = VannaOpenAI(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) 36 | vn_openai.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 37 | 38 | def test_vn_openai(): 39 | sql 
= vn_openai.generate_sql("What are the top 4 customers by sales?") 40 | df = vn_openai.run_sql(sql) 41 | assert len(df) == 4 42 | 43 | class VannaMistral(VannaDB_VectorStore, Mistral): 44 | def __init__(self, config=None): 45 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) 46 | Mistral.__init__(self, config={'api_key': MISTRAL_API_KEY, 'model': 'mistral-tiny'}) 47 | 48 | vn_mistral = VannaMistral() 49 | vn_mistral.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 50 | 51 | def test_vn_mistral(): 52 | sql = vn_mistral.generate_sql("What are the top 5 customers by sales?") 53 | df = vn_mistral.run_sql(sql) 54 | assert len(df) == 5 55 | 56 | vn_default = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY) 57 | vn_default.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 58 | 59 | def test_vn_default(): 60 | sql = vn_default.generate_sql("What are the top 6 customers by sales?") 61 | df = vn_default.run_sql(sql) 62 | assert len(df) == 6 63 | 64 | from vanna.qdrant import Qdrant_VectorStore 65 | 66 | 67 | class VannaQdrant(Qdrant_VectorStore, OpenAI_Chat): 68 | def __init__(self, config=None): 69 | Qdrant_VectorStore.__init__(self, config=config) 70 | OpenAI_Chat.__init__(self, config=config) 71 | 72 | from qdrant_client import QdrantClient 73 | 74 | qdrant_memory_client = QdrantClient(":memory:") 75 | 76 | vn_qdrant = VannaQdrant(config={'client': qdrant_memory_client, 'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) 77 | vn_qdrant.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 78 | 79 | def test_vn_qdrant(): 80 | df_ddl = vn_qdrant.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") 81 | 82 | for ddl in df_ddl['sql'].to_list(): 83 | vn_qdrant.train(ddl=ddl) 84 | 85 | sql = vn_qdrant.generate_sql("What are the top 7 customers by sales?") 86 | df = vn_qdrant.run_sql(sql) 87 | assert len(df) == 7 88 | 89 | from vanna.chromadb.chromadb_vector import 
ChromaDB_VectorStore 90 | from vanna.openai.openai_chat import OpenAI_Chat 91 | 92 | 93 | class MyVanna(ChromaDB_VectorStore, OpenAI_Chat): 94 | def __init__(self, config=None): 95 | ChromaDB_VectorStore.__init__(self, config=config) 96 | OpenAI_Chat.__init__(self, config=config) 97 | 98 | vn_chroma = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) 99 | vn_chroma.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 100 | 101 | def test_vn_chroma(): 102 | existing_training_data = vn_chroma.get_training_data() 103 | if len(existing_training_data) > 0: 104 | for _, training_data in existing_training_data.iterrows(): 105 | vn_chroma.remove_training_data(training_data['id']) 106 | 107 | df_ddl = vn_chroma.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") 108 | 109 | for ddl in df_ddl['sql'].to_list(): 110 | vn_chroma.train(ddl=ddl) 111 | 112 | sql = vn_chroma.generate_sql("What are the top 7 customers by sales?") 113 | df = vn_chroma.run_sql(sql) 114 | assert len(df) == 7 115 | 116 | # from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore 117 | 118 | 119 | # class VannaAzureSearch(AzureAISearch_VectorStore, OpenAI_Chat): 120 | # def __init__(self, config=None): 121 | # AzureAISearch_VectorStore.__init__(self, config=config) 122 | # OpenAI_Chat.__init__(self, config=config) 123 | 124 | # vn_azure_search = VannaAzureSearch(config={'azure_search_api_key': AZURE_SEARCH_API_KEY,'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) 125 | # vn_azure_search.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 126 | 127 | # def test_vn_azure_search(): 128 | # existing_training_data = vn_azure_search.get_training_data() 129 | # print(existing_training_data) 130 | # if len(existing_training_data) > 0: 131 | # for _, training_data in existing_training_data.iterrows(): 132 | # vn_azure_search.remove_training_data(training_data['id']) 133 | 134 | # df_ddl = vn_azure_search.run_sql("SELECT type, sql FROM sqlite_master WHERE 
sql is not null") 135 | # for ddl in df_ddl['sql'].to_list(): 136 | # vn_azure_search.train(ddl=ddl) 137 | 138 | # sql = vn_azure_search.generate_sql("What are the top 7 customers by sales?") 139 | # df = vn_azure_search.run_sql(sql) 140 | # assert len(df) == 7 141 | 142 | from vanna.milvus import Milvus_VectorStore 143 | 144 | 145 | class VannaMilvus(Milvus_VectorStore, OpenAI_Chat): 146 | def __init__(self, config=None): 147 | Milvus_VectorStore.__init__(self, config=config) 148 | OpenAI_Chat.__init__(self, config=config) 149 | 150 | vn_milvus = VannaMilvus(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) 151 | vn_milvus.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 152 | 153 | def test_vn_milvus(): 154 | existing_training_data = vn_milvus.get_training_data() 155 | if len(existing_training_data) > 0: 156 | for _, training_data in existing_training_data.iterrows(): 157 | vn_milvus.remove_training_data(training_data['id']) 158 | 159 | df_ddl = vn_milvus.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") 160 | 161 | for ddl in df_ddl['sql'].to_list(): 162 | vn_milvus.train(ddl=ddl) 163 | 164 | sql = vn_milvus.generate_sql("What are the top 7 customers by sales?") 165 | df = vn_milvus.run_sql(sql) 166 | assert len(df) == 7 167 | 168 | 169 | class VannaNumResults(ChromaDB_VectorStore, OpenAI_Chat): 170 | def __init__(self, config=None): 171 | ChromaDB_VectorStore.__init__(self, config=config) 172 | OpenAI_Chat.__init__(self, config=config) 173 | 174 | vn_chroma_n_results = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results': 1}) 175 | vn_chroma_n_results_ddl = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results_ddl': 2}) 176 | vn_chroma_n_results_sql = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results_sql': 3}) 177 | vn_chroma_n_results_documentation = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 
'n_results_documentation': 4}) 178 | 179 | def test_n_results(): 180 | for i in range(1, 10): 181 | vn_chroma.train(question=f"What are the total sales for customer {i}?", sql=f"SELECT SUM(sales) FROM example_sales WHERE customer_id = {i}") 182 | 183 | for i in range(1, 10): 184 | vn_chroma.train(documentation=f"Sample documentation {i}") 185 | 186 | question = "Whare are the top 5 customers by sales?" 187 | assert len(vn_chroma_n_results.get_related_ddl(question)) == 1 188 | assert len(vn_chroma_n_results.get_related_documentation(question)) == 1 189 | assert len(vn_chroma_n_results.get_similar_question_sql(question)) == 1 190 | 191 | assert len(vn_chroma_n_results_ddl.get_related_ddl(question)) == 2 192 | assert len(vn_chroma_n_results_ddl.get_related_documentation(question)) != 2 193 | assert len(vn_chroma_n_results_ddl.get_similar_question_sql(question)) != 2 194 | 195 | assert len(vn_chroma_n_results_sql.get_related_ddl(question)) != 3 196 | assert len(vn_chroma_n_results_sql.get_related_documentation(question)) != 3 197 | assert len(vn_chroma_n_results_sql.get_similar_question_sql(question)) == 3 198 | 199 | assert len(vn_chroma_n_results_documentation.get_related_ddl(question)) != 4 200 | assert len(vn_chroma_n_results_documentation.get_related_documentation(question)) == 4 201 | assert len(vn_chroma_n_results_documentation.get_similar_question_sql(question)) != 4 202 | 203 | class VannaClaude(VannaDB_VectorStore, Anthropic_Chat): 204 | def __init__(self, config=None): 205 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) 206 | Anthropic_Chat.__init__(self, config={'api_key': ANTHROPIC_API_KEY, 'model': ANTHROPIC_Model}) 207 | 208 | 209 | vn_claude = VannaClaude() 210 | vn_claude.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 211 | 212 | 213 | def test_vn_claude(): 214 | sql = vn_claude.generate_sql("What are the top 8 customers by sales?") 215 | df = vn_claude.run_sql(sql) 216 | assert 
len(df) == 8 217 | 218 | class VannaGemini(VannaDB_VectorStore, GoogleGeminiChat): 219 | def __init__(self, config=None): 220 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) 221 | GoogleGeminiChat.__init__(self, config=config) 222 | 223 | vn_gemini = VannaGemini(config={'api_key': os.environ['GEMINI_API_KEY']}) 224 | vn_gemini.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 225 | 226 | def test_vn_gemini(): 227 | sql = vn_gemini.generate_sql("What are the top 9 customers by sales?") 228 | df = vn_gemini.run_sql(sql) 229 | assert len(df) == 9 230 | 231 | class VannaCohere(VannaDB_VectorStore, Cohere_Chat): 232 | def __init__(self, config=None): 233 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) 234 | Cohere_Chat.__init__(self, config=config) 235 | 236 | try: 237 | COHERE_API_KEY = os.environ['COHERE_API_KEY'] 238 | vn_cohere = VannaCohere(config={'api_key': COHERE_API_KEY, 'model': 'command-a-03-2025'}) 239 | vn_cohere.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') 240 | 241 | def test_vn_cohere(): 242 | sql = vn_cohere.generate_sql("What are the top 10 customers by sales?") 243 | df = vn_cohere.run_sql(sql) 244 | assert len(df) == 10 245 | except KeyError: 246 | print("Skipping Cohere tests - COHERE_API_KEY not found in environment variables") 247 | 248 | def test_training_plan(): 249 | vn_dummy = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY) 250 | 251 | vn_dummy.connect_to_snowflake( 252 | account=SNOWFLAKE_ACCOUNT, 253 | username=SNOWFLAKE_USERNAME, 254 | password=SNOWFLAKE_PASSWORD, 255 | database='SNOWFLAKE_SAMPLE_DATA', 256 | ) 257 | 258 | df_information_schema = vn_dummy.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = 'TPCH_SF1' ") 259 | 260 | plan = vn_dummy.get_training_plan_generic(df_information_schema) 261 | assert len(plan._plan) == 8 262 | 
-------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = 3 | py310, 4 | mac, 5 | flake8, 6 | 7 | [py] 8 | deps= 9 | pytest-cov 10 | pytest-remove-stale-bytecode 11 | 12 | [testenv:py310] 13 | deps= 14 | {[py]deps} 15 | extras = all 16 | passenv = * 17 | basepython = python3.10 18 | commands = pytest -v --cov=tests/ --cov-report=term --cov-report=html 19 | 20 | [testenv:mac] 21 | deps= 22 | {[py]deps} 23 | python-dotenv 24 | extras = all 25 | basepython = python 26 | commands = 27 | pytest -x -v --cov=tests/ --cov-report=term --cov-report=html 28 | 29 | [testenv:flake8] 30 | exclude = .tox/* 31 | deps = flake8 32 | commands = flake8 src 33 | -------------------------------------------------------------------------------- /training_data/sample-imdb/questions.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question":"what are 5 most grossing movies in IMDB top 1000 ", 4 | "answer":"SELECT series_title,\n gross\nFROM imdb.public.movies\nORDER BY gross desc limit 5;" 5 | }, 6 | { 7 | "question":"what are the top 5 movies and their ratings basis IMDB rating ? ", 8 | "answer":"SELECT series_title,\n imdb_rating\nFROM imdb.public.movies\nORDER BY imdb_rating desc limit 5" 9 | }, 10 | { 11 | "question":"which 5 director have the most number of movies in the IMDB top 1000 ? ", 12 | "answer":"SELECT director,\r\n count(*) as num_of_movies\r\nFROM imdb.public.movies\r\nGROUP BY director\r\nORDER BY num_of_movies desc limit 5;" 13 | }, 14 | { 15 | "question":"what are the top 5 movies and their ratings basis IMDB ? 
", 16 | "answer":"SELECT series_title,\n imdb_rating\nFROM imdb.public.movies\nORDER BY imdb_rating desc limit 5" 17 | }, 18 | { 19 | "question":"what is the distribution of imdb top 1000 movies across the release year ?", 20 | "answer":"SELECT released_year, count(*) as num_movies\r\nFROM imdb.public.movies\r\nGROUP BY 1\r\nORDER BY 1;" 21 | }, 22 | { 23 | "question":"What are the 5 best rated action movies ? ", 24 | "answer":"SELECT series_title,\n imdb_rating\nFROM imdb.public.movies\nWHERE genre = 'Action'\nORDER BY imdb_rating desc limit 5;" 25 | }, 26 | { 27 | "question":"What are the top 2 rated movies of ingmar bergman in drama genre and what are the lead actors in the movies ?", 28 | "answer":"SELECT series_title,\n star1,\n imdb_rating\nFROM imdb.public.movies\nWHERE director = 'Ingmar Bergman'\n and genre = 'Drama'\nORDER BY imdb_rating desc limit 2;;" 29 | }, 30 | { 31 | "question":"what are the top 5 movies basis IMDB rating ? ", 32 | "answer":"SELECT series_title\r\nFROM IMDB.PUBLIC.MOVIES\r\nORDER BY imdb_rating DESC\r\nLIMIT 5" 33 | }, 34 | { 35 | "question":"which 5 director has the most number of movies in the IMDB top 1000 ? ", 36 | "answer":"```\nSELECT director,\n count(*) as num_movies\nFROM imdb.public.movies\nWHERE imdb_rating <= 1000\nGROUP BY director\nORDER BY num_movies desc limit 5;\n```" 37 | }, 38 | { 39 | "question":"what is the average IMDB rating for each genre ? ", 40 | "answer":"SELECT genre,\n avg(imdb_rating) as avg_imdb_rating\nFROM imdb.public.movies\nGROUP BY genre;" 41 | }, 42 | { 43 | "question":"what is the genre wise, average IMDB rating ? 
", 44 | "answer":"SELECT genre,\n avg(imdb_rating) as avg_imdb_rating\nFROM imdb.public.movies\nGROUP BY genre;" 45 | }, 46 | { 47 | "question":"What is the runtime of Forest Gump, Saving Private Ryan and The Green Mile", 48 | "answer":"SELECT series_title,\n runtime\nFROM imdb.public.movies\nWHERE series_title in ('Forrest Gump', 'Saving Private Ryan', 'The Green Mile');" 49 | }, 50 | { 51 | "question":"what is the distribution of titles on IMDB across genres ? ", 52 | "answer":"SELECT genre,\n count(*) as num_titles\nFROM imdb.public.movies\nGROUP BY genre\nORDER BY num_titles desc;" 53 | } 54 | ] -------------------------------------------------------------------------------- /training_data/snowflake-cost/questions.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question":"What are the daily costs for the last 30 days?", 4 | "answer":"SELECT usage_date,\n sum(usage_in_currency) as daily_cost\nFROM snowflake.organization_usage.preview_usage_in_currency_daily\nWHERE currency = 'USD'\n and usage_date >= dateadd(day, -30, current_date())\nGROUP BY usage_date\nORDER BY usage_date" 5 | }, 6 | { 7 | "question":"What are the first 10 rows in the USAGE_IN_CURRENCY_DAILY table?", 8 | "answer":"SELECT ORGANIZATION_NAME, CONTRACT_NUMBER, ACCOUNT_NAME, ACCOUNT_LOCATOR, REGION, SERVICE_LEVEL, USAGE_DATE, USAGE_TYPE, CURRENCY, USAGE, USAGE_IN_CURRENCY, BALANCE_SOURCE\r\nFROM SNOWFLAKE.ORGANIZATION_USAGE.USAGE_IN_CURRENCY_DAILY\r\nLIMIT 10" 9 | }, 10 | { 11 | "question":"Total usage costs in dollars for the organization, broken down by account", 12 | "answer":"SELECT account_name,\n sum(usage_in_currency) as total_cost\nFROM snowflake.organization_usage.preview_usage_in_currency_daily\nWHERE currency = 'USD'\nGROUP BY account_name" 13 | }, 14 | { 15 | "question":"What is the daily cost by usage type?", 16 | "answer":"SELECT usage_date,\n usage_type,\n sum(usage_in_currency) as daily_cost\nFROM 
snowflake.organization_usage.preview_usage_in_currency_daily\nWHERE currency = 'USD'\nGROUP BY usage_date, usage_type\nORDER BY usage_date, daily_cost desc" 17 | }, 18 | { 19 | "question":"What are the first 10 rows in the PREVIEW_USAGE_IN_CURRENCY_DAILY table?", 20 | "answer":"SELECT ORGANIZATION_NAME, CONTRACT_NUMBER, ACCOUNT_NAME, ACCOUNT_LOCATOR, REGION, SERVICE_LEVEL, USAGE_DATE, USAGE_TYPE, CURRENCY, USAGE, USAGE_IN_CURRENCY\r\nFROM SNOWFLAKE.ORGANIZATION_USAGE.PREVIEW_USAGE_IN_CURRENCY_DAILY\r\nLIMIT 10" 21 | }, 22 | { 23 | "question":"Daily usage cost by account", 24 | "answer":"SELECT account_name,\n usage_date,\n sum(usage_in_currency) as daily_usage_cost\nFROM snowflake.organization_usage.preview_usage_in_currency_daily\nWHERE currency = 'USD'\nGROUP BY account_name, usage_date\nORDER BY usage_date, daily_usage_cost desc" 25 | } 26 | ] --------------------------------------------------------------------------------