├── .gitattributes
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows
│       ├── python-publish.yaml
│       └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── img
│   ├── top-10-customers.png
│   └── vanna-readme-diagram.png
├── papers
│   ├── ai-sql-accuracy-2023-08-17.md
│   └── img
│       ├── accuracy-by-llm.png
│       ├── accuracy-using-contextual-examples.png
│       ├── accuracy-using-schema-only.png
│       ├── accuracy-using-static-examples.png
│       ├── chat-gpt-question.png
│       ├── chatgpt-results.png
│       ├── framework-for-sql-generation.png
│       ├── question-flow.png
│       ├── schema-only.png
│       ├── sql-error.png
│       ├── summary-table.png
│       ├── summary.png
│       ├── test-architecture.png
│       ├── test-levers.png
│       ├── using-contextually-relevant-examples.png
│       └── using-sql-examples.png
├── pyproject.toml
├── setup.cfg
├── src
│   ├── .editorconfig
│   └── vanna
│       ├── ZhipuAI
│       │   ├── ZhipuAI_Chat.py
│       │   ├── ZhipuAI_embeddings.py
│       │   └── __init__.py
│       ├── __init__.py
│       ├── advanced
│       │   └── __init__.py
│       ├── anthropic
│       │   ├── __init__.py
│       │   └── anthropic_chat.py
│       ├── azuresearch
│       │   ├── __init__.py
│       │   └── azuresearch_vector.py
│       ├── base
│       │   ├── __init__.py
│       │   └── base.py
│       ├── bedrock
│       │   ├── __init__.py
│       │   └── bedrock_converse.py
│       ├── chromadb
│       │   ├── __init__.py
│       │   └── chromadb_vector.py
│       ├── cohere
│       │   ├── __init__.py
│       │   ├── cohere_chat.py
│       │   └── cohere_embeddings.py
│       ├── deepseek
│       │   ├── __init__.py
│       │   └── deepseek_chat.py
│       ├── exceptions
│       │   └── __init__.py
│       ├── faiss
│       │   ├── __init__.py
│       │   └── faiss.py
│       ├── flask
│       │   ├── __init__.py
│       │   ├── assets.py
│       │   └── auth.py
│       ├── google
│       │   ├── __init__.py
│       │   ├── bigquery_vector.py
│       │   └── gemini_chat.py
│       ├── hf
│       │   ├── __init__.py
│       │   └── hf.py
│       ├── local.py
│       ├── marqo
│       │   ├── __init__.py
│       │   └── marqo.py
│       ├── milvus
│       │   ├── __init__.py
│       │   └── milvus_vector.py
│       ├── mistral
│       │   ├── __init__.py
│       │   └── mistral.py
│       ├── mock
│       │   ├── __init__.py
│       │   ├── embedding.py
│       │   ├── llm.py
│       │   └── vectordb.py
│       ├── ollama
│       │   ├── __init__.py
│       │   └── ollama.py
│       ├── openai
│       │   ├── __init__.py
│       │   ├── openai_chat.py
│       │   └── openai_embeddings.py
│       ├── opensearch
│       │   ├── __init__.py
│       │   ├── opensearch_vector.py
│       │   └── opensearch_vector_semantic.py
│       ├── oracle
│       │   ├── __init__.py
│       │   └── oracle_vector.py
│       ├── pgvector
│       │   ├── __init__.py
│       │   └── pgvector.py
│       ├── pinecone
│       │   ├── __init__.py
│       │   └── pinecone_vector.py
│       ├── qdrant
│       │   ├── __init__.py
│       │   └── qdrant.py
│       ├── qianfan
│       │   ├── Qianfan_Chat.py
│       │   ├── Qianfan_embeddings.py
│       │   └── __init__.py
│       ├── qianwen
│       │   ├── QianwenAI_chat.py
│       │   ├── QianwenAI_embeddings.py
│       │   └── __init__.py
│       ├── remote.py
│       ├── types
│       │   └── __init__.py
│       ├── utils.py
│       ├── vannadb
│       │   ├── __init__.py
│       │   └── vannadb_vector.py
│       ├── vllm
│       │   ├── __init__.py
│       │   └── vllm.py
│       ├── weaviate
│       │   ├── __init__.py
│       │   └── weaviate_vector.py
│       └── xinference
│           ├── __init__.py
│           └── xinference.py
├── tests
│   ├── test_imports.py
│   ├── test_instantiation.py
│   ├── test_pgvector.py
│   └── test_vanna.py
├── tox.ini
└── training_data
    ├── cybersyn-data-commons
    │   └── questions.json
    ├── cybersyn-financial-data
    │   └── questions.json
    ├── cybersyn-us-global-public
    │   └── questions.json
    ├── fivetran-ads-snowflake
    │   └── questions.json
    ├── sample-fraud
    │   └── questions.json
    ├── sample-imdb
    │   └── questions.json
    ├── sample-retention
    │   └── questions.json
    ├── sample-salaries
    │   └── questions.json
    ├── similarweb
    │   └── questions.json
    ├── snowflake-cost
    │   └── questions.json
    └── tpc-h
        └── questions.json
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-detectable=false
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ["bug"]
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Error logs/Screenshots**
24 | If applicable, add logs/screenshots to give more information about the issue.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. Ubuntu]
28 | - Version: [e.g. 20.04]
29 |  - Python: [e.g. 3.9]
30 |  - Vanna: [e.g. 2.8.0]
31 |
32 | **Additional context**
33 | Add any other context about the problem here.
34 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ["enhancements"]
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yaml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v3
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Basic Integration Tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | permissions:
9 | contents: read
10 |
11 | jobs:
12 | build:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 | - name: Set up Python 3.10
17 | uses: actions/setup-python@v5
18 | with:
19 | python-version: "3.10"
20 |       - name: Install tox
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install tox
24 | - name: Run tests
25 | env:
26 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: python
27 | VANNA_API_KEY: ${{ secrets.VANNA_API_KEY }}
28 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
29 | MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
30 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
31 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
32 | SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }}
33 | SNOWFLAKE_USERNAME: ${{ secrets.SNOWFLAKE_USERNAME }}
34 | SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }}
35 | run: tox -e py310
36 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | **.egg-info
3 | venv
4 | .DS_Store
5 | notebooks/*
6 | tests/__pycache__
7 | __pycache__/
8 | .idea
9 | .coverage
10 | docs/*.html
11 | .ipynb_checkpoints/
12 | .tox/
13 | notebooks/chroma.sqlite3
14 | dist
15 | .env
16 | *.sqlite
17 | htmlcov
18 | chroma.sqlite3
19 | *.bin
20 | .coverage.*
21 | milvus.db
22 | .milvus.db.lock
23 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: 'docs|node_modules|migrations|.git|.tox|assets.py'
2 | default_stages: [ commit ]
3 | fail_fast: true
4 |
5 | repos:
6 | - repo: https://github.com/pre-commit/pre-commit-hooks
7 | rev: v3.2.0
8 | hooks:
9 | - id: trailing-whitespace
10 | - id: end-of-file-fixer
11 | - id: check-merge-conflict
12 | - id: debug-statements
13 | - id: mixed-line-ending
14 |
15 | - repo: https://github.com/pycqa/isort
16 | rev: 5.12.0
17 | hooks:
18 | - id: isort
19 | args: [ "--profile", "black", "--filter-files" ]
20 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | ## Setup
4 | ```bash
5 | git clone https://github.com/vanna-ai/vanna.git
6 | cd vanna/
7 |
8 | python3 -m venv venv
9 | source venv/bin/activate
10 |
11 | # install package in editable mode
12 | pip install -e '.[all]' tox pre-commit
13 |
14 | # Setup pre-commit hooks
15 | pre-commit install
16 |
17 | # List dev targets
18 | tox list
19 |
20 | # Run tests
21 | tox -e py310
22 | ```
23 |
24 | ## Running the tests on a Mac
25 | ```bash
26 | tox -e mac
27 | ```
28 |
29 | ## Do this before you submit a PR:
30 |
31 | Find the most relevant sample notebook and then replace the install command with:
32 |
33 | ```bash
34 | %pip install 'git+https://github.com/vanna-ai/vanna@your-branch#egg=vanna[chromadb,snowflake,openai]'
35 | ```
36 |
37 | Run the necessary cells and verify that it works as expected in a real-world scenario.
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Vanna.AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | | GitHub | PyPI | Documentation | Gurubase |
4 | | ------ | ---- | ------------- | -------- |
5 | | [](https://github.com/vanna-ai/vanna) | [](https://pypi.org/project/vanna/) | [](https://vanna.ai/docs/) | [](https://gurubase.io/g/vanna) |
6 |
7 | # Vanna
8 | Vanna is an MIT-licensed open-source Python RAG (Retrieval-Augmented Generation) framework for SQL generation and related functionality.
9 |
10 | https://github.com/vanna-ai/vanna/assets/7146154/1901f47a-515d-4982-af50-f12761a3b2ce
11 |
12 | 
13 |
14 | ## How Vanna works
15 |
16 | 
17 |
18 |
19 | Vanna works in two easy steps: train a RAG "model" on your data, and then ask questions that return SQL queries, which can be set up to run automatically on your database.
20 |
21 | 1. **Train a RAG "model" on your data**.
22 | 2. **Ask questions**.
23 |
24 | 
25 |
26 | If you don't know what RAG is, don't worry -- you don't need to know how this works under the hood to use it. You just need to know that you "train" a model, which stores some metadata, and then you use it to "ask" questions.
27 |
28 | See the [base class](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) for more details on how this works under the hood.
29 |
30 | ## User Interfaces
31 | These are some of the user interfaces that we've built using Vanna. You can use these as-is or as a starting point for your own custom interface.
32 |
33 | - [Jupyter Notebook](https://vanna.ai/docs/postgres-openai-vanna-vannadb/)
34 | - [vanna-ai/vanna-streamlit](https://github.com/vanna-ai/vanna-streamlit)
35 | - [vanna-ai/vanna-flask](https://github.com/vanna-ai/vanna-flask)
36 | - [vanna-ai/vanna-slack](https://github.com/vanna-ai/vanna-slack)
37 |
38 | ## Supported LLMs
39 |
40 | - [OpenAI](https://github.com/vanna-ai/vanna/tree/main/src/vanna/openai)
41 | - [Anthropic](https://github.com/vanna-ai/vanna/tree/main/src/vanna/anthropic)
42 | - [Gemini](https://github.com/vanna-ai/vanna/blob/main/src/vanna/google/gemini_chat.py)
43 | - [HuggingFace](https://github.com/vanna-ai/vanna/blob/main/src/vanna/hf/hf.py)
44 | - [AWS Bedrock](https://github.com/vanna-ai/vanna/tree/main/src/vanna/bedrock)
45 | - [Ollama](https://github.com/vanna-ai/vanna/tree/main/src/vanna/ollama)
46 | - [Qianwen](https://github.com/vanna-ai/vanna/tree/main/src/vanna/qianwen)
47 | - [Qianfan](https://github.com/vanna-ai/vanna/tree/main/src/vanna/qianfan)
48 | - [Zhipu](https://github.com/vanna-ai/vanna/tree/main/src/vanna/ZhipuAI)
49 |
50 | ## Supported VectorStores
51 |
52 | - [AzureSearch](https://github.com/vanna-ai/vanna/tree/main/src/vanna/azuresearch)
53 | - [Opensearch](https://github.com/vanna-ai/vanna/tree/main/src/vanna/opensearch)
54 | - [PgVector](https://github.com/vanna-ai/vanna/tree/main/src/vanna/pgvector)
55 | - [PineCone](https://github.com/vanna-ai/vanna/tree/main/src/vanna/pinecone)
56 | - [ChromaDB](https://github.com/vanna-ai/vanna/tree/main/src/vanna/chromadb)
57 | - [FAISS](https://github.com/vanna-ai/vanna/tree/main/src/vanna/faiss)
58 | - [Marqo](https://github.com/vanna-ai/vanna/tree/main/src/vanna/marqo)
59 | - [Milvus](https://github.com/vanna-ai/vanna/tree/main/src/vanna/milvus)
60 | - [Qdrant](https://github.com/vanna-ai/vanna/tree/main/src/vanna/qdrant)
61 | - [Weaviate](https://github.com/vanna-ai/vanna/tree/main/src/vanna/weaviate)
62 | - [Oracle](https://github.com/vanna-ai/vanna/tree/main/src/vanna/oracle)
63 |
64 | ## Supported Databases
65 |
66 | - [PostgreSQL](https://www.postgresql.org/)
67 | - [MySQL](https://www.mysql.com/)
68 | - [PrestoDB](https://prestodb.io/)
69 | - [Apache Hive](https://hive.apache.org/)
70 | - [ClickHouse](https://clickhouse.com/)
71 | - [Snowflake](https://www.snowflake.com/en/)
72 | - [Oracle](https://www.oracle.com/)
73 | - [Microsoft SQL Server](https://www.microsoft.com/en-us/sql-server/sql-server-downloads)
74 | - [BigQuery](https://cloud.google.com/bigquery)
75 | - [SQLite](https://www.sqlite.org/)
76 | - [DuckDB](https://duckdb.org/)
77 |
78 |
79 | ## Getting started
80 | See the [documentation](https://vanna.ai/docs/) for specifics on your desired database, LLM, etc.
81 |
82 | If you want to get a feel for how it works after training, you can try this [Colab notebook](https://vanna.ai/docs/app/).
83 |
84 |
85 | ### Install
86 | ```bash
87 | pip install vanna
88 | ```
89 |
90 | There are a number of optional packages that can be installed; see the [documentation](https://vanna.ai/docs/) for more details.
91 |
92 | ### Import
93 | See the [documentation](https://vanna.ai/docs/) if you're customizing the LLM or vector database.
94 |
95 | ```python
96 | # The import statement will vary depending on your LLM and vector database. This is an example for OpenAI + ChromaDB
97 |
98 | from vanna.openai.openai_chat import OpenAI_Chat
99 | from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
100 |
101 | class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
102 | def __init__(self, config=None):
103 | ChromaDB_VectorStore.__init__(self, config=config)
104 | OpenAI_Chat.__init__(self, config=config)
105 |
106 | vn = MyVanna(config={'api_key': 'sk-...', 'model': 'gpt-4-...'})
107 |
108 | # See the documentation for other options
109 |
110 | ```
111 |
112 |
113 | ## Training
114 | You may or may not need to run these `vn.train` commands depending on your use case. See the [documentation](https://vanna.ai/docs/) for more details.
115 |
116 | These statements are shown to give you a feel for how it works.
117 |
118 | ### Train with DDL Statements
119 | DDL statements contain information about the table names, columns, data types, and relationships in your database.
120 |
121 | ```python
122 | vn.train(ddl="""
123 |     CREATE TABLE IF NOT EXISTS my_table (
124 | id INT PRIMARY KEY,
125 | name VARCHAR(100),
126 | age INT
127 | )
128 | """)
129 | ```
130 |
131 | ### Train with Documentation
132 | Sometimes you may want to add documentation about your business terminology or definitions.
133 |
134 | ```python
135 | vn.train(documentation="Our business defines XYZ as ...")
136 | ```
137 |
138 | ### Train with SQL
139 | You can also add SQL queries to your training data. This is useful if you have some queries already lying around. You can just copy and paste those from your editor to begin generating new SQL.
140 |
141 | ```python
142 | vn.train(sql="SELECT name, age FROM my_table WHERE name = 'John Doe'")
143 | ```
144 |
145 |
146 | ## Asking questions
147 | ```python
148 | vn.ask("What are the top 10 customers by sales?")
149 | ```
150 |
151 | You'll get SQL:
152 | ```sql
153 | SELECT c.c_name as customer_name,
154 | sum(l.l_extendedprice * (1 - l.l_discount)) as total_sales
155 | FROM snowflake_sample_data.tpch_sf1.lineitem l join snowflake_sample_data.tpch_sf1.orders o
156 | ON l.l_orderkey = o.o_orderkey join snowflake_sample_data.tpch_sf1.customer c
157 | ON o.o_custkey = c.c_custkey
158 | GROUP BY customer_name
159 | ORDER BY total_sales desc limit 10;
160 | ```
161 |
162 | If you've connected to a database, you'll get the table:
163 |
164 | |    | CUSTOMER_NAME      | TOTAL_SALES  |
165 | |---:|--------------------|-------------:|
166 | |  0 | Customer#000143500 | 6757566.0218 |
167 | |  1 | Customer#000095257 | 6294115.3340 |
168 | |  2 | Customer#000087115 | 6184649.5176 |
169 | |  3 | Customer#000131113 | 6080943.8305 |
170 | |  4 | Customer#000134380 | 6075141.9635 |
171 | |  5 | Customer#000103834 | 6059770.3232 |
172 | |  6 | Customer#000069682 | 6057779.0348 |
173 | |  7 | Customer#000102022 | 6039653.6335 |
174 | |  8 | Customer#000098587 | 6027021.5855 |
175 | |  9 | Customer#000064660 | 5905659.6159 |
176 |
227 | You'll also get an automated Plotly chart:
228 | 
229 |
230 | ## RAG vs. Fine-Tuning
231 | RAG
232 | - Portable across LLMs
233 | - Easy to remove training data if any of it becomes obsolete
234 | - Much cheaper to run than fine-tuning
235 | - More future-proof -- if a better LLM comes out, you can just swap it out
236 |
237 | Fine-Tuning
238 | - Good if you need to minimize tokens in the prompt
239 | - Slow to get started
240 | - Expensive to train and run (generally)
241 |
242 | ## Why Vanna?
243 |
244 | 1. **High accuracy on complex datasets.**
245 | - Vanna’s capabilities are tied to the training data you give it
246 | - More training data means better accuracy for large and complex datasets
247 | 2. **Secure and private.**
248 | - Your database contents are never sent to the LLM or the vector database
249 | - SQL execution happens in your local environment
250 | 3. **Self learning.**
251 |     - If you use it via Jupyter, you can choose to "auto-train" it on queries that execute successfully
252 |     - If you use it via other interfaces, you can have the interface prompt the user for feedback on the results
253 |     - Correct question-to-SQL pairs are stored for future reference and make future results more accurate (see the sketch after this list)
254 | 4. **Supports any SQL database.**
255 | - The package allows you to connect to any SQL database that you can otherwise connect to with Python
256 | 5. **Choose your front end.**
257 | - Most people start in a Jupyter Notebook.
258 | - Expose to your end users via Slackbot, web app, Streamlit app, or a custom front end.
259 |
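A minimal sketch of the self-learning loop described in point 3 above, assuming `vn` is already configured and connected to a database; `user_confirmed_correct` is a hypothetical stand-in for whatever feedback signal your interface collects:

```python
# Illustrative sketch of the self-learning loop -- not a prescribed API flow.
# user_confirmed_correct is a hypothetical feedback signal from your interface.
question = "What are the top 10 customers by sales?"
sql = vn.generate_sql(question)
df = vn.run_sql(sql)

user_confirmed_correct = True  # e.g. collected from a UI prompt
if user_confirmed_correct:
    # Store the verified question -> SQL pair so future answers improve
    vn.train(question=question, sql=sql)
```
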
260 | ## Extending Vanna
261 | Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details.
262 |
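For example, here is a minimal sketch of plugging in your own LLM (the `MyCustomLLM` class and its internals are hypothetical; the method names are the LLM-side abstract methods of `VannaBase`):

```python
# A sketch of extending Vanna with a custom LLM, using the standard
# mixin pattern. MyCustomLLM is hypothetical -- wire submit_prompt up
# to whatever model API you actually use.
from vanna.base import VannaBase
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore

class MyCustomLLM(VannaBase):
    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

    def system_message(self, message: str) -> dict:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        # Call your model's API here and return the response text.
        raise NotImplementedError("Connect this to your LLM")

class MyVanna(ChromaDB_VectorStore, MyCustomLLM):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        MyCustomLLM.__init__(self, config=config)
```
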
263 | ## Vanna in 100 Seconds
264 |
265 | https://github.com/vanna-ai/vanna/assets/7146154/eb90ee1e-aa05-4740-891a-4fc10e611cab
266 |
267 | ## More resources
268 | - [Full Documentation](https://vanna.ai/docs/)
269 | - [Website](https://vanna.ai)
270 | - [Discord group for support](https://discord.gg/qUZYKHremx)
271 |
--------------------------------------------------------------------------------
/img/top-10-customers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/img/top-10-customers.png
--------------------------------------------------------------------------------
/img/vanna-readme-diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/img/vanna-readme-diagram.png
--------------------------------------------------------------------------------
/papers/img/accuracy-by-llm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/accuracy-by-llm.png
--------------------------------------------------------------------------------
/papers/img/accuracy-using-contextual-examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/accuracy-using-contextual-examples.png
--------------------------------------------------------------------------------
/papers/img/accuracy-using-schema-only.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/accuracy-using-schema-only.png
--------------------------------------------------------------------------------
/papers/img/accuracy-using-static-examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/accuracy-using-static-examples.png
--------------------------------------------------------------------------------
/papers/img/chat-gpt-question.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/chat-gpt-question.png
--------------------------------------------------------------------------------
/papers/img/chatgpt-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/chatgpt-results.png
--------------------------------------------------------------------------------
/papers/img/framework-for-sql-generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/framework-for-sql-generation.png
--------------------------------------------------------------------------------
/papers/img/question-flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/question-flow.png
--------------------------------------------------------------------------------
/papers/img/schema-only.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/schema-only.png
--------------------------------------------------------------------------------
/papers/img/sql-error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/sql-error.png
--------------------------------------------------------------------------------
/papers/img/summary-table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/summary-table.png
--------------------------------------------------------------------------------
/papers/img/summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/summary.png
--------------------------------------------------------------------------------
/papers/img/test-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/test-architecture.png
--------------------------------------------------------------------------------
/papers/img/test-levers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/test-levers.png
--------------------------------------------------------------------------------
/papers/img/using-contextually-relevant-examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/using-contextually-relevant-examples.png
--------------------------------------------------------------------------------
/papers/img/using-sql-examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanna-ai/vanna/4da8dea0ce14a0d1db5a0692a7921d873be91c5f/papers/img/using-sql-examples.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["flit_core >=3.2,<4"]
3 | build-backend = "flit_core.buildapi"
4 |
5 | [project]
6 | name = "vanna"
7 | version = "0.7.9"
8 | authors = [
9 | { name="Zain Hoda", email="zain@vanna.ai" },
10 | ]
11 |
12 | description = "Generate SQL queries from natural language"
13 | readme = "README.md"
14 | requires-python = ">=3.9"
15 | classifiers = [
16 | "Programming Language :: Python :: 3",
17 | "License :: OSI Approved :: MIT License",
18 | "Operating System :: OS Independent",
19 | ]
20 | dependencies = [
21 | "requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "flask-sock", "flasgger", "sqlalchemy"
22 | ]
23 |
24 | [project.urls]
25 | "Homepage" = "https://github.com/vanna-ai/vanna"
26 | "Bug Tracker" = "https://github.com/vanna-ai/vanna/issues"
27 |
28 | [project.optional-dependencies]
29 | postgres = ["psycopg2-binary", "db-dtypes"]
30 | mysql = ["PyMySQL"]
31 | clickhouse = ["clickhouse_connect"]
32 | bigquery = ["google-cloud-bigquery"]
33 | snowflake = ["snowflake-connector-python"]
34 | duckdb = ["duckdb"]
35 | google = ["google-generativeai", "google-cloud-aiplatform"]
36 | all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb<1.0.0", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "langchain-community", "langchain-huggingface", "xinference-client"]
37 | test = ["tox"]
38 | chromadb = ["chromadb<1.0.0"]
39 | openai = ["openai"]
40 | qianfan = ["qianfan"]
41 | mistralai = ["mistralai>=1.0.0"]
42 | anthropic = ["anthropic"]
43 | gemini = ["google-generativeai"]
44 | marqo = ["marqo"]
45 | zhipuai = ["zhipuai"]
46 | ollama = ["ollama", "httpx"]
47 | qdrant = ["qdrant-client", "fastembed"]
48 | vllm = ["vllm"]
49 | pinecone = ["pinecone", "fastembed"]
50 | opensearch = ["opensearch-py", "opensearch-dsl", "langchain-community", "langchain-huggingface"]
51 | hf = ["transformers"]
52 | milvus = ["pymilvus[model]"]
53 | bedrock = ["boto3", "botocore"]
54 | weaviate = ["weaviate-client"]
55 | azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
56 | pgvector = ["langchain-postgres>=0.0.12"]
57 | faiss-cpu = ["faiss-cpu"]
58 | faiss-gpu = ["faiss-gpu"]
59 | xinference-client = ["xinference-client"]
60 | oracle = ["oracledb", "chromadb<1.0.0"]
61 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = BLK100,W503,E203,E722,F821,F841
3 | max-line-length = 100
4 | exclude = .tox,.git,docs,venv,jupyter_notebook_config.py,jupyter_lab_config.py,assets.py
5 |
6 | [tool:brunette]
7 | verbose = true
8 | single-quotes = false
9 | target-version = py39
10 | exclude = .tox,.git,docs,venv,assets.py
11 |
--------------------------------------------------------------------------------
/src/.editorconfig:
--------------------------------------------------------------------------------
1 | # top-most EditorConfig file
2 | root = true
3 |
4 | # Python files
5 | [*.py]
6 | # Indentation style: space
7 | indent_style = space
8 |
9 | # Indentation size: Use 2 spaces
10 | indent_size = 2
11 |
12 | # Newline character at the end of file
13 | insert_final_newline = true
14 |
15 | # Charset: utf-8
16 | charset = utf-8
17 |
18 | # Trim trailing whitespace
19 | trim_trailing_whitespace = true
20 |
21 | # Max line length: 79 characters as per PEP 8 guidelines
22 | max_line_length = 79
23 |
24 | # Set end of line format to LF
25 | end_of_line = lf
26 |
27 | # Exclude specific files or directories
28 | exclude = 'docs|node_modules|migrations|.git|.tox'
29 |
--------------------------------------------------------------------------------
/src/vanna/ZhipuAI/ZhipuAI_Chat.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import List
3 |
4 | import pandas as pd
5 | from zhipuai import ZhipuAI
6 |
7 | from ..base import VannaBase
8 |
9 |
10 | class ZhipuAI_Chat(VannaBase):
11 |   def __init__(self, config=None):
12 |     VannaBase.__init__(self, config=config)
13 |     self.model = (config or {}).get("model", "glm-4")  # default model is glm-4
14 |     if config is None:
15 |       return
16 |     if "api_key" not in config:
17 |       raise Exception("Missing api_key in config")
18 |     self.api_key = config["api_key"]
19 |     self.api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
20 |
21 |   # Static helper methods for message formatting and prompt construction
22 | @staticmethod
23 | def system_message(message: str) -> dict:
24 | return {"role": "system", "content": message}
25 |
26 | @staticmethod
27 | def user_message(message: str) -> dict:
28 | return {"role": "user", "content": message}
29 |
30 | @staticmethod
31 | def assistant_message(message: str) -> dict:
32 | return {"role": "assistant", "content": message}
33 |
34 | @staticmethod
35 | def str_to_approx_token_count(string: str) -> int:
36 |     return len(string) // 4  # integer division to match the declared int return type
37 |
38 | @staticmethod
39 | def add_ddl_to_prompt(
40 | initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
41 | ) -> str:
42 | if len(ddl_list) > 0:
43 | initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
44 |
45 | for ddl in ddl_list:
46 | if (
47 | ZhipuAI_Chat.str_to_approx_token_count(initial_prompt)
48 | + ZhipuAI_Chat.str_to_approx_token_count(ddl)
49 | < max_tokens
50 | ):
51 | initial_prompt += f"{ddl}\n\n"
52 |
53 | return initial_prompt
54 |
55 | @staticmethod
56 | def add_documentation_to_prompt(
57 | initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
58 | ) -> str:
59 | if len(documentation_List) > 0:
60 | initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
61 |
62 | for documentation in documentation_List:
63 | if (
64 | ZhipuAI_Chat.str_to_approx_token_count(initial_prompt)
65 | + ZhipuAI_Chat.str_to_approx_token_count(documentation)
66 | < max_tokens
67 | ):
68 | initial_prompt += f"{documentation}\n\n"
69 |
70 | return initial_prompt
71 |
72 | @staticmethod
73 | def add_sql_to_prompt(
74 | initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
75 | ) -> str:
76 | if len(sql_List) > 0:
77 | initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
78 |
79 | for question in sql_List:
80 | if (
81 | ZhipuAI_Chat.str_to_approx_token_count(initial_prompt)
82 | + ZhipuAI_Chat.str_to_approx_token_count(question["sql"])
83 | < max_tokens
84 | ):
85 | initial_prompt += f"{question['question']}\n{question['sql']}\n\n"
86 |
87 | return initial_prompt
88 |
89 | def get_sql_prompt(
90 | self,
91 | question: str,
92 | question_sql_list: List,
93 | ddl_list: List,
94 | doc_list: List,
95 | **kwargs,
96 | ):
97 | initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"
98 |
99 | initial_prompt = ZhipuAI_Chat.add_ddl_to_prompt(
100 | initial_prompt, ddl_list, max_tokens=14000
101 | )
102 |
103 | initial_prompt = ZhipuAI_Chat.add_documentation_to_prompt(
104 | initial_prompt, doc_list, max_tokens=14000
105 | )
106 |
107 | message_log = [ZhipuAI_Chat.system_message(initial_prompt)]
108 |
109 |     for example in question_sql_list:
110 |       if example is None:
111 |         print("example is None")
112 |       elif "question" in example and "sql" in example:
113 |         message_log.append(ZhipuAI_Chat.user_message(example["question"]))
114 |         message_log.append(ZhipuAI_Chat.assistant_message(example["sql"]))
116 |
117 | message_log.append({"role": "user", "content": question})
118 |
119 | return message_log
120 |
121 | def get_followup_questions_prompt(
122 | self,
123 | question: str,
124 | df: pd.DataFrame,
125 | question_sql_list: List,
126 | ddl_list: List,
127 | doc_list: List,
128 | **kwargs,
129 | ):
130 | initial_prompt = f"The user initially asked the question: '{question}': \n\n"
131 |
132 | initial_prompt = ZhipuAI_Chat.add_ddl_to_prompt(
133 | initial_prompt, ddl_list, max_tokens=14000
134 | )
135 |
136 | initial_prompt = ZhipuAI_Chat.add_documentation_to_prompt(
137 | initial_prompt, doc_list, max_tokens=14000
138 | )
139 |
140 | initial_prompt = ZhipuAI_Chat.add_sql_to_prompt(
141 | initial_prompt, question_sql_list, max_tokens=14000
142 | )
143 |
144 | message_log = [ZhipuAI_Chat.system_message(initial_prompt)]
145 | message_log.append(
146 | ZhipuAI_Chat.user_message(
147 | "Generate a List of followup questions that the user might ask about this data. Respond with a List of questions, one per line. Do not answer with any explanations -- just the questions."
148 | )
149 | )
150 |
151 | return message_log
152 |
153 | def generate_question(self, sql: str, **kwargs) -> str:
154 | response = self.submit_prompt(
155 | [
156 | self.system_message(
157 | "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
158 | ),
159 | self.user_message(sql),
160 | ],
161 | **kwargs,
162 | )
163 |
164 | return response
165 |
166 | def _extract_python_code(self, markdown_string: str) -> str:
167 | # Regex pattern to match Python code blocks
168 | pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"
169 |
170 | # Find all matches in the markdown string
171 | matches = re.findall(pattern, markdown_string, re.IGNORECASE)
172 |
173 | # Extract the Python code from the matches
174 | python_code = []
175 | for match in matches:
176 | python = match[0] if match[0] else match[1]
177 | python_code.append(python.strip())
178 |
179 | if len(python_code) == 0:
180 | return markdown_string
181 |
182 | return python_code[0]
183 |
184 | def _sanitize_plotly_code(self, raw_plotly_code: str) -> str:
185 | # Remove the fig.show() statement from the plotly code
186 | plotly_code = raw_plotly_code.replace("fig.show()", "")
187 |
188 | return plotly_code
189 |
190 | def generate_plotly_code(
191 | self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs
192 | ) -> str:
193 | if question is not None:
194 | system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'"
195 | else:
196 | system_msg = "The following is a pandas DataFrame "
197 |
198 | if sql is not None:
199 | system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n"
200 |
201 | system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}"
202 |
203 | message_log = [
204 | self.system_message(system_msg),
205 | self.user_message(
206 | "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
207 | ),
208 | ]
209 |
210 |     plotly_code = self.submit_prompt(message_log, **kwargs)
211 |
212 | return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
213 |
214 | def submit_prompt(
215 | self, prompt, max_tokens=500, temperature=0.7, top_p=0.7, stop=None, **kwargs
216 | ):
217 | if prompt is None:
218 | raise Exception("Prompt is None")
219 |
220 | if len(prompt) == 0:
221 | raise Exception("Prompt is empty")
222 |
223 | client = ZhipuAI(api_key=self.api_key)
224 | response = client.chat.completions.create(
225 |       model=self.model,
226 | max_tokens=max_tokens,
227 | temperature=temperature,
228 | top_p=top_p,
229 | stop=stop,
230 | messages=prompt,
231 | )
232 |
233 | return response.choices[0].message.content
234 |
--------------------------------------------------------------------------------
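A minimal usage sketch for the `ZhipuAI_Chat` class above (not part of the source tree), mirroring the mixin pattern shown in the README and in the `ZhipuAI_embeddings.py` docstring below; the API key value is a placeholder:

```python
from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore

class MyVanna(ChromaDB_VectorStore, ZhipuAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        ZhipuAI_Chat.__init__(self, config=config)

# The api_key is a placeholder; glm-4 is the class's default model.
vn = MyVanna(config={"api_key": "your-zhipuai-key", "model": "glm-4"})
```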
/src/vanna/ZhipuAI/ZhipuAI_embeddings.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from zhipuai import ZhipuAI
3 | from chromadb import Documents, EmbeddingFunction, Embeddings
4 | from ..base import VannaBase
5 |
6 | class ZhipuAI_Embeddings(VannaBase):
7 | """
8 | [future functionality] This function is used to generate embeddings from ZhipuAI.
9 |
10 | Args:
11 | VannaBase (_type_): _description_
12 | """
13 | def __init__(self, config=None):
14 | VannaBase.__init__(self, config=config)
15 | if "api_key" not in config:
16 | raise Exception("Missing api_key in config")
17 | self.api_key = config["api_key"]
18 | self.client = ZhipuAI(api_key=self.api_key)
19 |
20 | def generate_embedding(self, data: str, **kwargs) -> List[float]:
21 |
22 | embedding = self.client.embeddings.create(
23 | model="embedding-2",
24 | input=data,
25 | )
26 |
27 | return embedding.data[0].embedding
28 |
29 |
30 |
31 | class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]):
32 | """
33 |   An EmbeddingFunction that uses ZhipuAI to generate embeddings, for use with ChromaDB.
34 |   Usage:
35 | class MyVanna(ChromaDB_VectorStore, ZhipuAI_Chat):
36 | def __init__(self, config=None):
37 | ChromaDB_VectorStore.__init__(self, config=config)
38 | ZhipuAI_Chat.__init__(self, config=config)
39 |
40 | config={'api_key': 'xxx'}
41 | zhipu_embedding_function = ZhipuAIEmbeddingFunction(config=config)
42 | config = {"api_key": "xxx", "model": "glm-4","path":"xy","embedding_function":zhipu_embedding_function}
43 |
44 | vn = MyVanna(config)
45 |
46 | """
47 | def __init__(self, config=None):
48 | if config is None or "api_key" not in config:
49 | raise ValueError("Missing 'api_key' in config")
50 |
51 | self.api_key = config["api_key"]
52 | self.model_name = config.get("model_name", "embedding-2")
53 |
54 | try:
55 | self.client = ZhipuAI(api_key=self.api_key)
56 | except Exception as e:
57 | raise ValueError(f"Error initializing ZhipuAI client: {e}")
58 |
59 | def __call__(self, input: Documents) -> Embeddings:
60 | # Replace newlines, which can negatively affect performance.
61 | input = [t.replace("\n", " ") for t in input]
62 | all_embeddings = []
63 | print(f"Generating embeddings for {len(input)} documents")
64 |
65 | # Iterating over each document for individual API calls
66 | for document in input:
67 | try:
68 | response = self.client.embeddings.create(
69 | model=self.model_name,
70 | input=document
71 | )
72 | # print(response)
73 | embedding = response.data[0].embedding
74 | all_embeddings.append(embedding)
75 | # print(f"Cost required: {response.usage.total_tokens}")
76 | except Exception as e:
77 | raise ValueError(f"Error generating embedding for document: {e}")
78 |
79 | return all_embeddings
--------------------------------------------------------------------------------
/src/vanna/ZhipuAI/__init__.py:
--------------------------------------------------------------------------------
1 | from .ZhipuAI_Chat import ZhipuAI_Chat
2 | from .ZhipuAI_embeddings import ZhipuAI_Embeddings, ZhipuAIEmbeddingFunction
3 |
--------------------------------------------------------------------------------
/src/vanna/__init__.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import json
3 | import os
4 | from dataclasses import dataclass
5 | from typing import Callable, List, Tuple, Union
6 |
7 | import pandas as pd
8 | import requests
9 | import plotly.graph_objs
10 |
11 | from .exceptions import (
12 | OTPCodeError,
13 | ValidationError,
14 | )
15 | from .types import (
16 | ApiKey,
17 | Status,
18 | TrainingData,
19 | UserEmail,
20 | UserOTP,
21 | )
22 | from .utils import sanitize_model_name, validate_config_path
23 |
24 | api_key: Union[str, None] = None # API key for Vanna.AI
25 |
26 | fig_as_img: bool = False # Whether or not to return Plotly figures as images
27 |
28 | run_sql: Union[
29 | Callable[[str], pd.DataFrame], None
30 | ] = None # Function to convert SQL to a Pandas DataFrame
31 | """
32 | **Example**
33 | ```python
34 | vn.run_sql = lambda sql: pd.read_sql(sql, engine)
35 | ```
36 |
37 | Set the SQL to DataFrame function for Vanna.AI. This is used in the [`vn.ask(...)`][vanna.ask] function.
38 | Instead of setting this directly you can also use [`vn.connect_to_snowflake(...)`][vanna.connect_to_snowflake] to set this.
39 |
40 | """
41 |
42 | __org: Union[str, None] = None # Organization name for Vanna.AI
43 |
44 | _unauthenticated_endpoint = "https://ask.vanna.ai/unauthenticated_rpc"
45 |
46 | def error_deprecation():
47 | raise Exception("""
48 | Please switch to the following method for initializing Vanna:
49 |
50 | from vanna.remote import VannaDefault
51 |
52 | api_key = # Your API key from https://vanna.ai/account/profile
53 | vanna_model_name = # Your model name from https://vanna.ai/account/profile
54 |
55 | vn = VannaDefault(model=vanna_model_name, api_key=api_key)
56 | """)
57 |
58 | def __unauthenticated_rpc_call(method, params):
59 | headers = {
60 | "Content-Type": "application/json",
61 | }
62 | data = {"method": method, "params": [__dataclass_to_dict(obj) for obj in params]}
63 |
64 | response = requests.post(
65 | _unauthenticated_endpoint, headers=headers, data=json.dumps(data)
66 | )
67 | return response.json()
68 |
69 |
70 |
71 | def __dataclass_to_dict(obj):
72 | return dataclasses.asdict(obj)
73 |
74 |
75 | def get_api_key(email: str, otp_code: Union[str, None] = None) -> str:
76 | """
77 | **Example:**
78 | ```python
79 | vn.get_api_key(email="my-email@example.com")
80 | ```
81 |
82 | Login to the Vanna.AI API.
83 |
84 | Args:
85 | email (str): The email address to login with.
86 | otp_code (Union[str, None]): The OTP code to login with. If None, an OTP code will be sent to the email address.
87 |
88 | Returns:
89 | str: The API key.
90 | """
91 | vanna_api_key = os.environ.get("VANNA_API_KEY", None)
92 |
93 | if vanna_api_key is not None:
94 | return vanna_api_key
95 |
96 | if email == "my-email@example.com":
97 | raise ValidationError(
98 | "Please replace 'my-email@example.com' with your email address."
99 | )
100 |
101 | if otp_code is None:
102 | params = [UserEmail(email=email)]
103 |
104 | d = __unauthenticated_rpc_call(method="send_otp", params=params)
105 |
106 | if "result" not in d:
107 | raise OTPCodeError("Error sending OTP code.")
108 |
109 | status = Status(**d["result"])
110 |
111 | if not status.success:
112 | raise OTPCodeError(f"Error sending OTP code: {status.message}")
113 |
114 | otp_code = input("Check your email for the code and enter it here: ")
115 |
116 | params = [UserOTP(email=email, otp=otp_code)]
117 |
118 | d = __unauthenticated_rpc_call(method="verify_otp", params=params)
119 |
120 | if "result" not in d:
121 | raise OTPCodeError("Error verifying OTP code.")
122 |
123 | key = ApiKey(**d["result"])
124 |
125 | if key is None:
126 | raise OTPCodeError("Error verifying OTP code.")
127 |
128 | api_key = key.key
129 |
130 | return api_key
131 |
132 |
133 | def set_api_key(key: str) -> None:
134 | error_deprecation()
135 |
136 |
137 | def get_models() -> List[str]:
138 | error_deprecation()
139 |
140 |
141 | def create_model(model: str, db_type: str) -> bool:
142 | error_deprecation()
143 |
144 |
145 | def add_user_to_model(model: str, email: str, is_admin: bool) -> bool:
146 | error_deprecation()
147 |
148 |
149 | def update_model_visibility(public: bool) -> bool:
150 | error_deprecation()
151 |
152 |
153 | def set_model(model: str):
154 | error_deprecation()
155 |
156 |
157 | def add_sql(
158 | question: str, sql: str, tag: Union[str, None] = "Manually Trained"
159 | ) -> bool:
160 | error_deprecation()
161 |
162 |
163 | def add_ddl(ddl: str) -> bool:
164 | error_deprecation()
165 |
166 |
167 | def add_documentation(documentation: str) -> bool:
168 | error_deprecation()
169 |
170 |
171 | @dataclass
172 | class TrainingPlanItem:
173 | item_type: str
174 | item_group: str
175 | item_name: str
176 | item_value: str
177 |
178 | def __str__(self):
179 | if self.item_type == self.ITEM_TYPE_SQL:
180 | return f"Train on SQL: {self.item_group} {self.item_name}"
181 | elif self.item_type == self.ITEM_TYPE_DDL:
182 | return f"Train on DDL: {self.item_group} {self.item_name}"
183 | elif self.item_type == self.ITEM_TYPE_IS:
184 | return f"Train on Information Schema: {self.item_group} {self.item_name}"
185 |
186 | ITEM_TYPE_SQL = "sql"
187 | ITEM_TYPE_DDL = "ddl"
188 | ITEM_TYPE_IS = "is"
189 |
190 |
191 | class TrainingPlan:
192 | """
193 | A class representing a training plan. You can see what's in it, and remove items from it that you don't want trained.
194 |
195 | **Example:**
196 | ```python
197 | plan = vn.get_training_plan()
198 |
199 | plan.get_summary()
200 | ```
201 |
202 | """
203 |
204 | _plan: List[TrainingPlanItem]
205 |
206 | def __init__(self, plan: List[TrainingPlanItem]):
207 | self._plan = plan
208 |
209 | def __str__(self):
210 | return "\n".join(self.get_summary())
211 |
212 | def __repr__(self):
213 | return self.__str__()
214 |
215 | def get_summary(self) -> List[str]:
216 | """
217 | **Example:**
218 | ```python
219 | plan = vn.get_training_plan()
220 |
221 | plan.get_summary()
222 | ```
223 |
224 | Get a summary of the training plan.
225 |
226 | Returns:
227 | List[str]: A list of strings describing the training plan.
228 | """
229 |
230 | return [f"{item}" for item in self._plan]
231 |
232 | def remove_item(self, item: str):
233 | """
234 | **Example:**
235 | ```python
236 | plan = vn.get_training_plan()
237 |
238 | plan.remove_item("Train on SQL: What is the average salary of employees?")
239 | ```
240 |
241 | Remove an item from the training plan.
242 |
243 | Args:
244 | item (str): The item to remove.
245 | """
246 | for plan_item in self._plan:
247 | if str(plan_item) == item:
248 | self._plan.remove(plan_item)
249 | break
250 |
251 |
252 | def get_training_plan_postgres(
253 | filter_databases: Union[List[str], None] = None,
254 | filter_schemas: Union[List[str], None] = None,
255 | include_information_schema: bool = False,
256 | use_historical_queries: bool = True,
257 | ) -> TrainingPlan:
258 | error_deprecation()
259 |
260 |
261 | def get_training_plan_generic(df) -> TrainingPlan:
262 | error_deprecation()
263 |
264 |
265 | def get_training_plan_experimental(
266 | filter_databases: Union[List[str], None] = None,
267 | filter_schemas: Union[List[str], None] = None,
268 | include_information_schema: bool = False,
269 | use_historical_queries: bool = True,
270 | ) -> TrainingPlan:
271 | error_deprecation()
272 |
273 |
274 | def train(
275 | question: str = None,
276 | sql: str = None,
277 | ddl: str = None,
278 | documentation: str = None,
279 | json_file: str = None,
280 | sql_file: str = None,
281 | plan: TrainingPlan = None,
282 | ) -> bool:
283 | error_deprecation()
284 |
285 |
286 | def flag_sql_for_review(
287 | question: str, sql: Union[str, None] = None, error_msg: Union[str, None] = None
288 | ) -> bool:
289 | error_deprecation()
290 |
291 |
292 | def remove_sql(question: str) -> bool:
293 | error_deprecation()
294 |
295 |
296 | def remove_training_data(id: str) -> bool:
297 | error_deprecation()
298 |
299 |
300 | def generate_sql(question: str) -> str:
301 | error_deprecation()
302 |
303 |
304 | def get_related_training_data(question: str) -> TrainingData:
305 | error_deprecation()
306 |
307 |
308 | def generate_meta(question: str) -> str:
309 | error_deprecation()
310 |
311 |
312 | def generate_followup_questions(question: str, df: pd.DataFrame) -> List[str]:
313 | error_deprecation()
314 |
315 |
316 | def generate_questions() -> List[str]:
317 | error_deprecation()
318 |
319 |
320 | def ask(
321 | question: Union[str, None] = None,
322 | print_results: bool = True,
323 | auto_train: bool = True,
324 | generate_followups: bool = True,
325 | ) -> Union[
326 | Tuple[
327 | Union[str, None],
328 | Union[pd.DataFrame, None],
329 | Union[plotly.graph_objs.Figure, None],
330 | Union[List[str], None],
331 | ],
332 | None,
333 | ]:
334 | error_deprecation()
335 |
336 | def generate_plotly_code(
337 | question: Union[str, None],
338 | sql: Union[str, None],
339 | df: pd.DataFrame,
340 | chart_instructions: Union[str, None] = None,
341 | ) -> str:
342 | error_deprecation()
343 |
344 |
345 | def get_plotly_figure(
346 | plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
347 | ) -> plotly.graph_objs.Figure:
348 | error_deprecation()
349 |
350 |
351 | def get_results(cs, default_database: str, sql: str) -> pd.DataFrame:
352 | error_deprecation()
353 |
354 |
355 | def generate_explanation(sql: str) -> str:
356 | error_deprecation()
357 |
358 |
359 | def generate_question(sql: str) -> str:
360 | error_deprecation()
361 |
362 |
363 | def get_all_questions() -> pd.DataFrame:
364 | error_deprecation()
365 |
366 |
367 | def get_training_data() -> pd.DataFrame:
368 | error_deprecation()
369 |
370 |
371 | def connect_to_sqlite(url: str):
372 | error_deprecation()
373 |
374 |
375 | def connect_to_snowflake(
376 | account: str,
377 | username: str,
378 | password: str,
379 | database: str,
380 | schema: Union[str, None] = None,
381 | role: Union[str, None] = None,
382 | ):
383 | error_deprecation()
384 |
385 |
386 | def connect_to_postgres(
387 | host: str = None,
388 | dbname: str = None,
389 | user: str = None,
390 | password: str = None,
391 | port: int = None,
392 | ):
393 | error_deprecation()
394 |
395 |
396 | def connect_to_bigquery(cred_file_path: str = None, project_id: str = None):
397 | error_deprecation()
398 |
399 | def connect_to_duckdb(url: str="memory", init_sql: str = None):
400 | error_deprecation()
--------------------------------------------------------------------------------
/src/vanna/advanced/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class VannaAdvanced(ABC):
5 | def __init__(self, config=None):
6 | self.config = config
7 |
8 | @abstractmethod
9 | def get_function(self, question: str, additional_data: dict = {}) -> dict:
10 | pass
11 |
12 | @abstractmethod
13 | def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
14 | pass
15 |
16 | @abstractmethod
17 | def update_function(self, old_function_name: str, updated_function: dict) -> bool:
18 | pass
19 |
20 | @abstractmethod
21 | def delete_function(self, function_name: str) -> bool:
22 | pass
23 |
24 | @abstractmethod
25 | def get_all_functions(self) -> list:
26 | pass
27 |
--------------------------------------------------------------------------------
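To make the contract concrete, here is a minimal in-memory sketch of implementing `VannaAdvanced` (illustrative only; the class name and storage strategy are assumptions, not something the package ships):

```python
from vanna.advanced import VannaAdvanced

class InMemoryFunctionStore(VannaAdvanced):
    """Illustrative in-memory implementation of the VannaAdvanced contract."""

    def __init__(self, config=None):
        super().__init__(config=config)
        self._functions = {}  # question -> function dict

    def get_function(self, question: str, additional_data: dict = {}) -> dict:
        return self._functions.get(question, {})

    def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
        fn = {"question": question, "sql": sql, "plotly_code": plotly_code}
        self._functions[question] = fn
        return fn

    def update_function(self, old_function_name: str, updated_function: dict) -> bool:
        if old_function_name not in self._functions:
            return False
        self._functions[old_function_name] = updated_function
        return True

    def delete_function(self, function_name: str) -> bool:
        return self._functions.pop(function_name, None) is not None

    def get_all_functions(self) -> list:
        return list(self._functions.values())
```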
/src/vanna/anthropic/__init__.py:
--------------------------------------------------------------------------------
1 | from .anthropic_chat import Anthropic_Chat
2 |
--------------------------------------------------------------------------------
/src/vanna/anthropic/anthropic_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import anthropic
4 |
5 | from ..base import VannaBase
6 |
7 |
8 | class Anthropic_Chat(VannaBase):
9 | def __init__(self, client=None, config=None):
10 | VannaBase.__init__(self, config=config)
11 |
12 |     # default parameters - can be overridden via config
13 |     self.temperature = 0.7
14 |     self.max_tokens = 500
15 |
16 |     if config is not None and "temperature" in config:
17 |       self.temperature = config["temperature"]
18 |
19 |     if config is not None and "max_tokens" in config:
20 |       self.max_tokens = config["max_tokens"]
21 |
22 | if client is not None:
23 | self.client = client
24 | return
25 |
26 | if config is None and client is None:
27 | self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
28 | return
29 |
30 | if "api_key" in config:
31 | self.client = anthropic.Anthropic(api_key=config["api_key"])
32 |
33 | def system_message(self, message: str) -> any:
34 | return {"role": "system", "content": message}
35 |
36 | def user_message(self, message: str) -> any:
37 | return {"role": "user", "content": message}
38 |
39 | def assistant_message(self, message: str) -> any:
40 | return {"role": "assistant", "content": message}
41 |
42 | def submit_prompt(self, prompt, **kwargs) -> str:
43 | if prompt is None:
44 | raise Exception("Prompt is None")
45 |
46 | if len(prompt) == 0:
47 | raise Exception("Prompt is empty")
48 |
49 | # Count the number of tokens in the message log
50 | # Use 4 as an approximation for the number of characters per token
51 | num_tokens = 0
52 | for message in prompt:
53 | num_tokens += len(message["content"]) / 4
54 |
55 | if self.config is not None and "model" in self.config:
56 | print(
57 | f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
58 | )
59 |         # Claude requires the system message to be passed as a single field
60 | # https://docs.anthropic.com/claude/reference/messages_post
61 | system_message = ''
62 | no_system_prompt = []
63 | for prompt_message in prompt:
64 | role = prompt_message['role']
65 | if role == 'system':
66 | system_message = prompt_message['content']
67 | else:
68 | no_system_prompt.append({"role": role, "content": prompt_message['content']})
69 |
70 | response = self.client.messages.create(
71 | model=self.config["model"],
72 | messages=no_system_prompt,
73 | system=system_message,
74 | max_tokens=self.max_tokens,
75 | temperature=self.temperature,
76 | )
77 |
78 | return response.content[0].text
79 |
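80 | # Usage sketch (an assumption, mirroring the commented DeepSeek example
81 | # elsewhere in this repo): pair Anthropic_Chat with a vector store via
82 | # multiple inheritance. The model name below is illustrative.
83 | #
84 | # from vanna.chromadb import ChromaDB_VectorStore
85 | #
86 | # class AnthropicVanna(ChromaDB_VectorStore, Anthropic_Chat):
87 | #     def __init__(self, config=None):
88 | #         ChromaDB_VectorStore.__init__(self, config=config)
89 | #         Anthropic_Chat.__init__(self, config=config)
90 | #
91 | # vn = AnthropicVanna(config={"api_key": "sk-...", "model": "claude-3-5-sonnet-20240620"})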
--------------------------------------------------------------------------------
/src/vanna/azuresearch/__init__.py:
--------------------------------------------------------------------------------
1 | from .azuresearch_vector import AzureAISearch_VectorStore
2 |
--------------------------------------------------------------------------------
/src/vanna/azuresearch/azuresearch_vector.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import json
3 | from typing import List
4 |
5 | import pandas as pd
6 | from azure.core.credentials import AzureKeyCredential
7 | from azure.search.documents import SearchClient
8 | from azure.search.documents.indexes import SearchIndexClient
9 | from azure.search.documents.indexes.models import (
10 | ExhaustiveKnnAlgorithmConfiguration,
11 | ExhaustiveKnnParameters,
12 | SearchableField,
13 | SearchField,
14 | SearchFieldDataType,
15 | SearchIndex,
16 | VectorSearch,
17 | VectorSearchAlgorithmKind,
18 | VectorSearchAlgorithmMetric,
19 | VectorSearchProfile,
20 | )
21 | from azure.search.documents.models import VectorFilterMode, VectorizedQuery
22 | from fastembed import TextEmbedding
23 |
24 | from ..base import VannaBase
25 | from ..utils import deterministic_uuid
26 |
27 |
28 | class AzureAISearch_VectorStore(VannaBase):
29 | """
30 | AzureAISearch_VectorStore is a class that provides a vector store for Azure AI Search.
31 |
32 | Args:
33 | config (dict): Configuration dictionary. Defaults to {}. You must provide an API key in the config.
34 | - azure_search_endpoint (str, optional): Azure Search endpoint. Defaults to "https://azcognetive.search.windows.net".
35 | - azure_search_api_key (str): Azure Search API key.
36 | - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
37 | - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
38 | - index_name (str, optional): Name of the index. Defaults to "vanna-index".
39 | - n_results (int, optional): Number of results to return. Defaults to 10.
40 | - n_results_ddl (int, optional): Number of results to return for DDL queries. Defaults to the value of n_results.
41 | - n_results_sql (int, optional): Number of results to return for SQL queries. Defaults to the value of n_results.
42 | - n_results_documentation (int, optional): Number of results to return for documentation queries. Defaults to the value of n_results.
43 |
44 | Raises:
45 | ValueError: If config is None, or if 'azure_search_api_key' is not provided in the config.
46 | """
47 | def __init__(self, config=None):
48 | VannaBase.__init__(self, config=config)
49 |
50 | self.config = config or None
51 |
52 | if config is None:
53 | raise ValueError(
54 | "config is required, pass an API key, 'azure_search_api_key', in the config."
55 | )
56 |
57 | azure_search_endpoint = config.get("azure_search_endpoint", "https://azcognetive.search.windows.net")
58 | azure_search_api_key = config.get("azure_search_api_key")
59 |
60 | self.dimensions = config.get("dimensions", 384)
61 | self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
62 |
63 | self.index_name = config.get("index_name", "vanna-index")
64 |
65 | self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
66 | self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
67 | self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
68 |
69 | if not azure_search_api_key:
70 | raise ValueError(
71 | "'azure_search_api_key' is required in config to use AzureAISearch_VectorStore"
72 | )
73 |
74 | self.index_client = SearchIndexClient(
75 | endpoint=azure_search_endpoint,
76 | credential=AzureKeyCredential(azure_search_api_key)
77 | )
78 |
79 | self.search_client = SearchClient(
80 | endpoint=azure_search_endpoint,
81 | index_name=self.index_name,
82 | credential=AzureKeyCredential(azure_search_api_key)
83 | )
84 |
85 | if self.index_name not in self._get_indexes():
86 | self._create_index()
87 |
88 |     def _create_index(self) -> None:
89 | fields = [
90 | SearchableField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
91 | SearchableField(name="document", type=SearchFieldDataType.String, searchable=True, filterable=True),
92 | SearchField(name="type", type=SearchFieldDataType.String, filterable=True, searchable=True),
93 | SearchField(name="document_vector", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), searchable=True, vector_search_dimensions=self.dimensions, vector_search_profile_name="ExhaustiveKnnProfile"),
94 | ]
95 |
96 | vector_search = VectorSearch(
97 | algorithms=[
98 | ExhaustiveKnnAlgorithmConfiguration(
99 | name="ExhaustiveKnn",
100 | kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
101 | parameters=ExhaustiveKnnParameters(
102 | metric=VectorSearchAlgorithmMetric.COSINE
103 | )
104 | )
105 | ],
106 | profiles=[
107 | VectorSearchProfile(
108 | name="ExhaustiveKnnProfile",
109 | algorithm_configuration_name="ExhaustiveKnn",
110 | )
111 | ]
112 | )
113 |
114 | index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
115 | result = self.index_client.create_or_update_index(index)
116 | print(f'{result.name} created')
117 |
118 | def _get_indexes(self) -> list:
119 | return [index for index in self.index_client.list_index_names()]
120 |
121 | def add_ddl(self, ddl: str) -> str:
122 | id = deterministic_uuid(ddl) + "-ddl"
123 | document = {
124 | "id": id,
125 | "document": ddl,
126 | "type": "ddl",
127 | "document_vector": self.generate_embedding(ddl)
128 | }
129 | self.search_client.upload_documents(documents=[document])
130 | return id
131 |
132 | def add_documentation(self, doc: str) -> str:
133 | id = deterministic_uuid(doc) + "-doc"
134 | document = {
135 | "id": id,
136 | "document": doc,
137 | "type": "doc",
138 | "document_vector": self.generate_embedding(doc)
139 | }
140 | self.search_client.upload_documents(documents=[document])
141 | return id
142 |
143 | def add_question_sql(self, question: str, sql: str) -> str:
144 | question_sql_json = json.dumps({"question": question, "sql": sql}, ensure_ascii=False)
145 | id = deterministic_uuid(question_sql_json) + "-sql"
146 | document = {
147 | "id": id,
148 | "document": question_sql_json,
149 | "type": "sql",
150 | "document_vector": self.generate_embedding(question_sql_json)
151 | }
152 | self.search_client.upload_documents(documents=[document])
153 | return id
154 |
155 | def get_related_ddl(self, text: str) -> List[str]:
156 | result = []
157 | vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
158 | df = pd.DataFrame(
159 | self.search_client.search(
160 | top=self.n_results_ddl,
161 | vector_queries=[vector_query],
162 | select=["id", "document", "type"],
163 |                 filter="type eq 'ddl'"
164 | )
165 | )
166 |
167 | if len(df):
168 | result = df["document"].tolist()
169 | return result
170 |
171 | def get_related_documentation(self, text: str) -> List[str]:
172 | result = []
173 | vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
174 |
175 | df = pd.DataFrame(
176 | self.search_client.search(
177 | top=self.n_results_documentation,
178 | vector_queries=[vector_query],
179 | select=["id", "document", "type"],
180 |                 filter="type eq 'doc'",
181 | vector_filter_mode=VectorFilterMode.PRE_FILTER
182 | )
183 | )
184 |
185 | if len(df):
186 | result = df["document"].tolist()
187 | return result
188 |
189 | def get_similar_question_sql(self, question: str) -> List[str]:
190 | result = []
191 | # Vectorize the text
192 | vector_query = VectorizedQuery(vector=self.generate_embedding(question), fields="document_vector")
193 | df = pd.DataFrame(
194 | self.search_client.search(
195 | top=self.n_results_sql,
196 | vector_queries=[vector_query],
197 | select=["id", "document", "type"],
198 |                 filter="type eq 'sql'"
199 | )
200 | )
201 |
202 |         if len(df):  # Check whether there are similar queries and the result is not empty
203 | result = [ast.literal_eval(element) for element in df["document"].tolist()]
204 |
205 | return result
206 |
207 |     def get_training_data(self) -> pd.DataFrame:
208 |
209 | search = self.search_client.search(
210 | search_text="*",
211 | select=['id', 'document', 'type'],
212 |             filter="(type eq 'sql') or (type eq 'ddl') or (type eq 'doc')"
213 | ).by_page()
214 |
215 | df = pd.DataFrame([item for page in search for item in page])
216 |
217 | if len(df):
218 | df.loc[df["type"] == "sql", "question"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["question"])
219 | df.loc[df["type"] == "sql", "content"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["sql"])
220 | df.loc[df["type"] != "sql", "content"] = df.loc[df["type"] != "sql"]["document"]
221 |
222 | return df[["id", "question", "content", "type"]]
223 |
224 | return pd.DataFrame()
225 |
226 | def remove_training_data(self, id: str) -> bool:
227 | result = self.search_client.delete_documents(documents=[{'id':id}])
228 | return result[0].succeeded
229 |
230 | def remove_index(self):
231 | self.index_client.delete_index(self.index_name)
232 |
233 | def generate_embedding(self, data: str, **kwargs) -> List[float]:
234 | embedding_model = TextEmbedding(model_name=self.fastembed_model)
235 | embedding = next(embedding_model.embed(data))
236 | return embedding.tolist()
237 |
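238 | # Usage sketch (endpoint and key are placeholders): only
239 | # 'azure_search_api_key' is mandatory; the other settings fall back to the
240 | # defaults documented in the class docstring.
241 | #
242 | # store = AzureAISearch_VectorStore(config={
243 | #     "azure_search_endpoint": "https://<service>.search.windows.net",
244 | #     "azure_search_api_key": "<api-key>",
245 | #     "index_name": "vanna-index",
246 | # })
247 | # store.add_ddl("CREATE TABLE customers (id INT, name VARCHAR(100))")
248 | # print(store.get_related_ddl("Which customers do we have?"))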
--------------------------------------------------------------------------------
/src/vanna/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import VannaBase
2 |
--------------------------------------------------------------------------------
/src/vanna/bedrock/__init__.py:
--------------------------------------------------------------------------------
1 | from .bedrock_converse import Bedrock_Converse
--------------------------------------------------------------------------------
/src/vanna/bedrock/bedrock_converse.py:
--------------------------------------------------------------------------------
1 | from ..base import VannaBase
2 |
3 | try:
4 | import boto3
5 | from botocore.exceptions import ClientError
6 | except ImportError:
7 | raise ImportError("Please install boto3 and botocore to use Amazon Bedrock models")
8 |
9 | class Bedrock_Converse(VannaBase):
10 | def __init__(self, client=None, config=None):
11 | VannaBase.__init__(self, config=config)
12 |
13 | # default parameters
14 | self.temperature = 0.0
15 | self.max_tokens = 500
16 |
17 | if client is None:
18 | raise ValueError(
19 | "A valid Bedrock runtime client must be provided to invoke Bedrock models"
20 | )
21 | else:
22 | self.client = client
23 |
24 | if config is None:
25 | raise ValueError(
26 | "Config is required with model_id and inference parameters"
27 | )
28 |
29 | if "modelId" not in config:
30 | raise ValueError(
31 | "config must contain a modelId to invoke"
32 | )
33 | else:
34 | self.model = config["modelId"]
35 |
36 | if "temperature" in config:
37 | self.temperature = config["temperature"]
38 |
39 | if "max_tokens" in config:
40 | self.max_tokens = config["max_tokens"]
41 |
42 | def system_message(self, message: str) -> dict:
43 | return {"role": "system", "content": message}
44 |
45 | def user_message(self, message: str) -> dict:
46 | return {"role": "user", "content": message}
47 |
48 | def assistant_message(self, message: str) -> dict:
49 | return {"role": "assistant", "content": message}
50 |
51 | def submit_prompt(self, prompt, **kwargs) -> str:
52 | inference_config = {
53 | "temperature": self.temperature,
54 | "maxTokens": self.max_tokens
55 | }
56 | additional_model_fields = {
57 | "top_p": 1, # setting top_p value for nucleus sampling
58 | }
59 |
60 | system_message = None
61 | no_system_prompt = []
62 | for prompt_message in prompt:
63 | role = prompt_message["role"]
64 | if role == "system":
65 | system_message = prompt_message["content"]
66 | else:
67 | no_system_prompt.append({"role": role, "content":[{"text": prompt_message["content"]}]})
68 |
69 | converse_api_params = {
70 | "modelId": self.model,
71 | "messages": no_system_prompt,
72 | "inferenceConfig": inference_config,
73 | "additionalModelRequestFields": additional_model_fields
74 | }
75 |
76 | if system_message:
77 | converse_api_params["system"] = [{"text": system_message}]
78 |
79 | try:
80 | response = self.client.converse(**converse_api_params)
81 | text_content = response["output"]["message"]["content"][0]["text"]
82 | return text_content
83 | except ClientError as err:
84 | message = err.response["Error"]["Message"]
85 | raise Exception(f"A Bedrock client error occurred: {message}")
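86 |
87 | # Usage sketch (region and model id are assumptions): Bedrock_Converse
88 | # requires a ready-made bedrock-runtime client and a config with "modelId".
89 | #
90 | # import boto3
91 | # client = boto3.client("bedrock-runtime", region_name="us-east-1")
92 | # chat = Bedrock_Converse(client=client, config={"modelId": "anthropic.claude-3-sonnet-20240229-v1:0"})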
--------------------------------------------------------------------------------
/src/vanna/chromadb/__init__.py:
--------------------------------------------------------------------------------
1 | from .chromadb_vector import ChromaDB_VectorStore
2 |
--------------------------------------------------------------------------------
/src/vanna/chromadb/chromadb_vector.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import List
3 |
4 | import chromadb
5 | import pandas as pd
6 | from chromadb.config import Settings
7 | from chromadb.utils import embedding_functions
8 |
9 | from ..base import VannaBase
10 | from ..utils import deterministic_uuid
11 |
12 | default_ef = embedding_functions.DefaultEmbeddingFunction()
13 |
14 |
15 | class ChromaDB_VectorStore(VannaBase):
16 | def __init__(self, config=None):
17 | VannaBase.__init__(self, config=config)
18 | if config is None:
19 | config = {}
20 |
21 | path = config.get("path", ".")
22 | self.embedding_function = config.get("embedding_function", default_ef)
23 | curr_client = config.get("client", "persistent")
24 | collection_metadata = config.get("collection_metadata", None)
25 | self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
26 | self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
27 | self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
28 |
29 | if curr_client == "persistent":
30 | self.chroma_client = chromadb.PersistentClient(
31 | path=path, settings=Settings(anonymized_telemetry=False)
32 | )
33 | elif curr_client == "in-memory":
34 | self.chroma_client = chromadb.EphemeralClient(
35 | settings=Settings(anonymized_telemetry=False)
36 | )
37 | elif isinstance(curr_client, chromadb.api.client.Client):
38 | # allow providing client directly
39 | self.chroma_client = curr_client
40 | else:
41 | raise ValueError(f"Unsupported client was set in config: {curr_client}")
42 |
43 | self.documentation_collection = self.chroma_client.get_or_create_collection(
44 | name="documentation",
45 | embedding_function=self.embedding_function,
46 | metadata=collection_metadata,
47 | )
48 | self.ddl_collection = self.chroma_client.get_or_create_collection(
49 | name="ddl",
50 | embedding_function=self.embedding_function,
51 | metadata=collection_metadata,
52 | )
53 | self.sql_collection = self.chroma_client.get_or_create_collection(
54 | name="sql",
55 | embedding_function=self.embedding_function,
56 | metadata=collection_metadata,
57 | )
58 |
59 | def generate_embedding(self, data: str, **kwargs) -> List[float]:
60 | embedding = self.embedding_function([data])
61 | if len(embedding) == 1:
62 | return embedding[0]
63 | return embedding
64 |
65 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
66 | question_sql_json = json.dumps(
67 | {
68 | "question": question,
69 | "sql": sql,
70 | },
71 | ensure_ascii=False,
72 | )
73 | id = deterministic_uuid(question_sql_json) + "-sql"
74 | self.sql_collection.add(
75 | documents=question_sql_json,
76 | embeddings=self.generate_embedding(question_sql_json),
77 | ids=id,
78 | )
79 |
80 | return id
81 |
82 | def add_ddl(self, ddl: str, **kwargs) -> str:
83 | id = deterministic_uuid(ddl) + "-ddl"
84 | self.ddl_collection.add(
85 | documents=ddl,
86 | embeddings=self.generate_embedding(ddl),
87 | ids=id,
88 | )
89 | return id
90 |
91 | def add_documentation(self, documentation: str, **kwargs) -> str:
92 | id = deterministic_uuid(documentation) + "-doc"
93 | self.documentation_collection.add(
94 | documents=documentation,
95 | embeddings=self.generate_embedding(documentation),
96 | ids=id,
97 | )
98 | return id
99 |
100 | def get_training_data(self, **kwargs) -> pd.DataFrame:
101 | sql_data = self.sql_collection.get()
102 |
103 | df = pd.DataFrame()
104 |
105 | if sql_data is not None:
106 | # Extract the documents and ids
107 | documents = [json.loads(doc) for doc in sql_data["documents"]]
108 | ids = sql_data["ids"]
109 |
110 | # Create a DataFrame
111 | df_sql = pd.DataFrame(
112 | {
113 | "id": ids,
114 | "question": [doc["question"] for doc in documents],
115 | "content": [doc["sql"] for doc in documents],
116 | }
117 | )
118 |
119 | df_sql["training_data_type"] = "sql"
120 |
121 | df = pd.concat([df, df_sql])
122 |
123 | ddl_data = self.ddl_collection.get()
124 |
125 | if ddl_data is not None:
126 | # Extract the documents and ids
127 | documents = [doc for doc in ddl_data["documents"]]
128 | ids = ddl_data["ids"]
129 |
130 | # Create a DataFrame
131 | df_ddl = pd.DataFrame(
132 | {
133 | "id": ids,
134 | "question": [None for doc in documents],
135 | "content": [doc for doc in documents],
136 | }
137 | )
138 |
139 | df_ddl["training_data_type"] = "ddl"
140 |
141 | df = pd.concat([df, df_ddl])
142 |
143 | doc_data = self.documentation_collection.get()
144 |
145 | if doc_data is not None:
146 | # Extract the documents and ids
147 | documents = [doc for doc in doc_data["documents"]]
148 | ids = doc_data["ids"]
149 |
150 | # Create a DataFrame
151 | df_doc = pd.DataFrame(
152 | {
153 | "id": ids,
154 | "question": [None for doc in documents],
155 | "content": [doc for doc in documents],
156 | }
157 | )
158 |
159 | df_doc["training_data_type"] = "documentation"
160 |
161 | df = pd.concat([df, df_doc])
162 |
163 | return df
164 |
165 | def remove_training_data(self, id: str, **kwargs) -> bool:
166 | if id.endswith("-sql"):
167 | self.sql_collection.delete(ids=id)
168 | return True
169 | elif id.endswith("-ddl"):
170 | self.ddl_collection.delete(ids=id)
171 | return True
172 | elif id.endswith("-doc"):
173 | self.documentation_collection.delete(ids=id)
174 | return True
175 | else:
176 | return False
177 |
178 | def remove_collection(self, collection_name: str) -> bool:
179 | """
180 | This function can reset the collection to empty state.
181 |
182 | Args:
183 | collection_name (str): sql or ddl or documentation
184 |
185 | Returns:
186 | bool: True if collection is deleted, False otherwise
187 | """
188 | if collection_name == "sql":
189 | self.chroma_client.delete_collection(name="sql")
190 | self.sql_collection = self.chroma_client.get_or_create_collection(
191 | name="sql", embedding_function=self.embedding_function
192 | )
193 | return True
194 | elif collection_name == "ddl":
195 | self.chroma_client.delete_collection(name="ddl")
196 | self.ddl_collection = self.chroma_client.get_or_create_collection(
197 | name="ddl", embedding_function=self.embedding_function
198 | )
199 | return True
200 | elif collection_name == "documentation":
201 | self.chroma_client.delete_collection(name="documentation")
202 | self.documentation_collection = self.chroma_client.get_or_create_collection(
203 | name="documentation", embedding_function=self.embedding_function
204 | )
205 | return True
206 | else:
207 | return False
208 |
209 | @staticmethod
210 | def _extract_documents(query_results) -> list:
211 | """
212 | Static method to extract the documents from the results of a query.
213 |
214 | Args:
215 |             query_results (dict): The raw query results returned by a ChromaDB collection query.
216 |
217 | Returns:
218 | List[str] or None: The extracted documents, or an empty list or
219 | single document if an error occurred.
220 | """
221 | if query_results is None:
222 | return []
223 |
224 | if "documents" in query_results:
225 | documents = query_results["documents"]
226 |
227 | if len(documents) == 1 and isinstance(documents[0], list):
228 | try:
229 | documents = [json.loads(doc) for doc in documents[0]]
230 |                 except Exception:
231 | return documents[0]
232 |
233 | return documents
234 |
235 | def get_similar_question_sql(self, question: str, **kwargs) -> list:
236 | return ChromaDB_VectorStore._extract_documents(
237 | self.sql_collection.query(
238 | query_texts=[question],
239 | n_results=self.n_results_sql,
240 | )
241 | )
242 |
243 | def get_related_ddl(self, question: str, **kwargs) -> list:
244 | return ChromaDB_VectorStore._extract_documents(
245 | self.ddl_collection.query(
246 | query_texts=[question],
247 | n_results=self.n_results_ddl,
248 | )
249 | )
250 |
251 | def get_related_documentation(self, question: str, **kwargs) -> list:
252 | return ChromaDB_VectorStore._extract_documents(
253 | self.documentation_collection.query(
254 | query_texts=[question],
255 | n_results=self.n_results_documentation,
256 | )
257 | )
258 |
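259 | # Usage sketch (paths and values are illustrative): the supported "client"
260 | # modes, plus adding one training example.
261 | #
262 | # vs = ChromaDB_VectorStore(config={"client": "persistent", "path": "./chroma"})
263 | # vs = ChromaDB_VectorStore(config={"client": "in-memory"})
264 | # vs.add_question_sql("How many users signed up?", "SELECT COUNT(*) FROM users")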
--------------------------------------------------------------------------------
/src/vanna/cohere/__init__.py:
--------------------------------------------------------------------------------
1 | from .cohere_chat import Cohere_Chat
2 | from .cohere_embeddings import Cohere_Embeddings
--------------------------------------------------------------------------------
/src/vanna/cohere/cohere_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from openai import OpenAI
4 |
5 | from ..base import VannaBase
6 |
7 |
8 | class Cohere_Chat(VannaBase):
9 | def __init__(self, client=None, config=None):
10 | VannaBase.__init__(self, config=config)
11 |
12 | # default parameters - can be overridden using config
13 | self.temperature = 0.2 # Lower temperature for more precise SQL generation
14 | self.model = "command-a-03-2025" # Cohere's default model
15 |
16 | if config is not None:
17 | if "temperature" in config:
18 | self.temperature = config["temperature"]
19 | if "model" in config:
20 | self.model = config["model"]
21 |
22 | if client is not None:
23 | self.client = client
24 | return
25 |
26 | # Check for API key in environment variable
27 | api_key = os.getenv("COHERE_API_KEY")
28 |
29 | # Check for API key in config
30 | if config is not None and "api_key" in config:
31 | api_key = config["api_key"]
32 |
33 | # Validate API key
34 | if not api_key:
35 | raise ValueError("Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable.")
36 |
37 | # Initialize client with validated API key
38 | self.client = OpenAI(
39 | base_url="https://api.cohere.ai/compatibility/v1",
40 | api_key=api_key,
41 | )
42 |
43 | def system_message(self, message: str) -> any:
44 | return {"role": "developer", "content": message} # Cohere uses 'developer' for system role
45 |
46 | def user_message(self, message: str) -> any:
47 | return {"role": "user", "content": message}
48 |
49 | def assistant_message(self, message: str) -> any:
50 | return {"role": "assistant", "content": message}
51 |
52 | def submit_prompt(self, prompt, **kwargs) -> str:
53 | if prompt is None:
54 | raise Exception("Prompt is None")
55 |
56 | if len(prompt) == 0:
57 | raise Exception("Prompt is empty")
58 |
59 | # Count the number of tokens in the message log
60 | # Use 4 as an approximation for the number of characters per token
61 | num_tokens = 0
62 | for message in prompt:
63 | num_tokens += len(message["content"]) / 4
64 |
65 | # Use model from kwargs, config, or default
66 | model = kwargs.get("model", self.model)
67 | if self.config is not None and "model" in self.config and model == self.model:
68 | model = self.config["model"]
69 |
70 | print(f"Using model {model} for {num_tokens} tokens (approx)")
71 | try:
72 | response = self.client.chat.completions.create(
73 | model=model,
74 | messages=prompt,
75 | temperature=self.temperature,
76 | )
77 |
78 | # Check if response has expected structure
79 | if not response or not hasattr(response, 'choices') or not response.choices:
80 | raise ValueError("Received empty or malformed response from API")
81 |
82 | if not response.choices[0] or not hasattr(response.choices[0], 'message'):
83 | raise ValueError("Response is missing expected 'message' field")
84 |
85 | if not hasattr(response.choices[0].message, 'content'):
86 | raise ValueError("Response message is missing expected 'content' field")
87 |
88 | return response.choices[0].message.content
89 |
90 | except Exception as e:
91 | # Log the error and raise a more informative exception
92 | error_msg = f"Error processing Cohere chat response: {str(e)}"
93 | print(error_msg)
94 | raise Exception(error_msg)
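95 |
96 | # Usage sketch (the key is a placeholder): Cohere_Chat talks to Cohere's
97 | # OpenAI-compatibility endpoint, so only a config dict is needed.
98 | #
99 | # chat = Cohere_Chat(config={"api_key": "co-...", "model": "command-a-03-2025"})
100 | # print(chat.submit_prompt([chat.user_message("Say hello")]))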
--------------------------------------------------------------------------------
/src/vanna/cohere/cohere_embeddings.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from openai import OpenAI
4 |
5 | from ..base import VannaBase
6 |
7 |
8 | class Cohere_Embeddings(VannaBase):
9 | def __init__(self, client=None, config=None):
10 | VannaBase.__init__(self, config=config)
11 |
12 | # Default embedding model
13 | self.model = "embed-multilingual-v3.0"
14 |
15 | if config is not None and "model" in config:
16 | self.model = config["model"]
17 |
18 | if client is not None:
19 | self.client = client
20 | return
21 |
22 | # Check for API key in environment variable
23 | api_key = os.getenv("COHERE_API_KEY")
24 |
25 | # Check for API key in config
26 | if config is not None and "api_key" in config:
27 | api_key = config["api_key"]
28 |
29 | # Validate API key
30 | if not api_key:
31 | raise ValueError("Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable.")
32 |
33 | # Initialize client with validated API key
34 | self.client = OpenAI(
35 | base_url="https://api.cohere.ai/compatibility/v1",
36 | api_key=api_key,
37 | )
38 |
39 | def generate_embedding(self, data: str, **kwargs) -> list[float]:
40 | if not data:
41 | raise ValueError("Cannot generate embedding for empty input data")
42 |
43 | # Use model from kwargs, config, or default
44 | model = kwargs.get("model", self.model)
45 | if self.config is not None and "model" in self.config and model == self.model:
46 | model = self.config["model"]
47 |
48 | try:
49 | embedding = self.client.embeddings.create(
50 | model=model,
51 | input=data,
52 | encoding_format="float", # Ensure we get float values
53 | )
54 |
55 | # Check if response has expected structure
56 | if not embedding or not hasattr(embedding, 'data') or not embedding.data:
57 | raise ValueError("Received empty or malformed embedding response from API")
58 |
59 | if not embedding.data[0] or not hasattr(embedding.data[0], 'embedding'):
60 | raise ValueError("Embedding response is missing expected 'embedding' field")
61 |
62 | if not embedding.data[0].embedding:
63 | raise ValueError("Received empty embedding vector")
64 |
65 | return embedding.data[0].embedding
66 |
67 | except Exception as e:
68 | # Log the error and raise a more informative exception
69 | error_msg = f"Error generating embedding with Cohere: {str(e)}"
70 | print(error_msg)
71 | raise Exception(error_msg)
--------------------------------------------------------------------------------
/src/vanna/deepseek/__init__.py:
--------------------------------------------------------------------------------
1 | from .deepseek_chat import DeepSeekChat
2 |
--------------------------------------------------------------------------------
/src/vanna/deepseek/deepseek_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from openai import OpenAI
4 |
5 | from ..base import VannaBase
6 |
7 |
8 |
9 | # from vanna.chromadb import ChromaDB_VectorStore
10 |
11 | # class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
12 | # def __init__(self, config=None):
13 | # ChromaDB_VectorStore.__init__(self, config=config)
14 | # DeepSeekChat.__init__(self, config=config)
15 |
16 | # vn = DeepSeekVanna(config={"api_key": "sk-************", "model": "deepseek-chat"})
17 |
18 |
19 | class DeepSeekChat(VannaBase):
20 | def __init__(self, config=None):
21 | if config is None:
22 | raise ValueError(
23 | "For DeepSeek, config must be provided with an api_key and model"
24 | )
25 | if "api_key" not in config:
26 | raise ValueError("config must contain a DeepSeek api_key")
27 |
28 | if "model" not in config:
29 | raise ValueError("config must contain a DeepSeek model")
30 |
31 | api_key = config["api_key"]
32 | model = config["model"]
33 | self.model = model
34 | self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1")
35 |
36 | def system_message(self, message: str) -> any:
37 | return {"role": "system", "content": message}
38 |
39 | def user_message(self, message: str) -> any:
40 | return {"role": "user", "content": message}
41 |
42 | def assistant_message(self, message: str) -> any:
43 | return {"role": "assistant", "content": message}
44 |
45 | def generate_sql(self, question: str, **kwargs) -> str:
46 |         # Use the parent class's generate_sql
47 | sql = super().generate_sql(question, **kwargs)
48 |
49 |         # Replace "\_" with "_"
50 | sql = sql.replace("\\_", "_")
51 |
52 | return sql
53 |
54 | def submit_prompt(self, prompt, **kwargs) -> str:
55 | chat_response = self.client.chat.completions.create(
56 | model=self.model,
57 | messages=prompt,
58 | )
59 |
60 | return chat_response.choices[0].message.content
61 |
--------------------------------------------------------------------------------
/src/vanna/exceptions/__init__.py:
--------------------------------------------------------------------------------
1 | class ImproperlyConfigured(Exception):
2 | """Raise for incorrect configuration."""
3 |
4 | pass
5 |
6 |
7 | class DependencyError(Exception):
8 | """Raise for missing dependencies."""
9 |
10 | pass
11 |
12 |
13 | class ConnectionError(Exception):
14 | """Raise for connection"""
15 |
16 | pass
17 |
18 |
19 | class OTPCodeError(Exception):
20 | """Raise for invalid otp or not able to send it"""
21 |
22 | pass
23 |
24 |
25 | class SQLRemoveError(Exception):
26 | """Raise when not able to remove SQL"""
27 |
28 | pass
29 |
30 |
31 | class ExecutionError(Exception):
32 | """Raise when not able to execute Code"""
33 |
34 | pass
35 |
36 |
37 | class ValidationError(Exception):
38 | """Raise for validations"""
39 |
40 | pass
41 |
42 |
43 | class APIError(Exception):
44 | """Raise for API errors"""
45 |
46 | pass
47 |
--------------------------------------------------------------------------------
/src/vanna/faiss/__init__.py:
--------------------------------------------------------------------------------
1 | from .faiss import FAISS
--------------------------------------------------------------------------------
/src/vanna/faiss/faiss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import uuid
4 | from typing import List, Dict, Any
5 |
6 | import faiss
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from ..base import VannaBase
11 | from ..exceptions import DependencyError
12 |
13 | class FAISS(VannaBase):
14 | def __init__(self, config=None):
15 | if config is None:
16 | config = {}
17 |
18 | VannaBase.__init__(self, config=config)
19 |
20 | try:
21 | import faiss
22 | except ImportError:
23 | raise DependencyError(
24 | "FAISS is not installed. Please install it with 'pip install faiss-cpu' or 'pip install faiss-gpu'"
25 | )
26 |
27 | try:
28 | from sentence_transformers import SentenceTransformer
29 | except ImportError:
30 | raise DependencyError(
31 | "SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'."
32 | )
33 |
34 | self.path = config.get("path", ".")
35 | self.embedding_dim = config.get('embedding_dim', 384)
36 | self.n_results_sql = config.get('n_results_sql', config.get("n_results", 10))
37 | self.n_results_ddl = config.get('n_results_ddl', config.get("n_results", 10))
38 | self.n_results_documentation = config.get('n_results_documentation', config.get("n_results", 10))
39 | self.curr_client = config.get("client", "persistent")
40 |
41 | if self.curr_client == 'persistent':
42 | self.sql_index = self._load_or_create_index('sql_index.faiss')
43 | self.ddl_index = self._load_or_create_index('ddl_index.faiss')
44 | self.doc_index = self._load_or_create_index('doc_index.faiss')
45 | elif self.curr_client == 'in-memory':
46 | self.sql_index = faiss.IndexFlatL2(self.embedding_dim)
47 | self.ddl_index = faiss.IndexFlatL2(self.embedding_dim)
48 | self.doc_index = faiss.IndexFlatL2(self.embedding_dim)
49 | elif isinstance(self.curr_client, list) and len(self.curr_client) == 3 and all(isinstance(idx, faiss.Index) for idx in self.curr_client):
50 | self.sql_index = self.curr_client[0]
51 | self.ddl_index = self.curr_client[1]
52 | self.doc_index = self.curr_client[2]
53 | else:
54 | raise ValueError(f"Unsupported storage type was set in config: {self.curr_client}")
55 |
56 | self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata('sql_metadata.json')
57 | self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata('ddl_metadata.json')
58 | self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata('doc_metadata.json')
59 |
60 | model_name = config.get('embedding_model', 'all-MiniLM-L6-v2')
61 | self.embedding_model = SentenceTransformer(model_name)
62 |
63 | def _load_or_create_index(self, filename):
64 | filepath = os.path.join(self.path, filename)
65 | if os.path.exists(filepath):
66 | return faiss.read_index(filepath)
67 | return faiss.IndexFlatL2(self.embedding_dim)
68 |
69 | def _load_or_create_metadata(self, filename):
70 | filepath = os.path.join(self.path, filename)
71 | if os.path.exists(filepath):
72 | with open(filepath, 'r') as f:
73 | return json.load(f)
74 | return []
75 |
76 | def _save_index(self, index, filename):
77 | if self.curr_client == 'persistent':
78 | filepath = os.path.join(self.path, filename)
79 | faiss.write_index(index, filepath)
80 |
81 | def _save_metadata(self, metadata, filename):
82 | if self.curr_client == 'persistent':
83 | filepath = os.path.join(self.path, filename)
84 | with open(filepath, 'w') as f:
85 | json.dump(metadata, f)
86 |
87 | def generate_embedding(self, data: str, **kwargs) -> List[float]:
88 | embedding = self.embedding_model.encode(data)
89 | assert embedding.shape[0] == self.embedding_dim, \
90 | f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
91 | return embedding.tolist()
92 |
93 | def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
94 | embedding = self.generate_embedding(text)
95 | index.add(np.array([embedding], dtype=np.float32))
96 | entry_id = str(uuid.uuid4())
97 | metadata_list.append({"id": entry_id, **(extra_metadata or {})})
98 | return entry_id
99 |
100 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
101 | entry_id = self._add_to_index(self.sql_index, self.sql_metadata, question + " " + sql, {"question": question, "sql": sql})
102 | self._save_index(self.sql_index, 'sql_index.faiss')
103 | self._save_metadata(self.sql_metadata, 'sql_metadata.json')
104 | return entry_id
105 |
106 | def add_ddl(self, ddl: str, **kwargs) -> str:
107 | entry_id = self._add_to_index(self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl})
108 | self._save_index(self.ddl_index, 'ddl_index.faiss')
109 | self._save_metadata(self.ddl_metadata, 'ddl_metadata.json')
110 | return entry_id
111 |
112 | def add_documentation(self, documentation: str, **kwargs) -> str:
113 | entry_id = self._add_to_index(self.doc_index, self.doc_metadata, documentation, {"documentation": documentation})
114 | self._save_index(self.doc_index, 'doc_index.faiss')
115 | self._save_metadata(self.doc_metadata, 'doc_metadata.json')
116 | return entry_id
117 |
118 | def _get_similar(self, index, metadata_list, text, n_results) -> list:
119 | embedding = self.generate_embedding(text)
120 | D, I = index.search(np.array([embedding], dtype=np.float32), k=n_results)
121 |         return [metadata_list[i] for i in I[0] if i != -1]  # FAISS pads missing neighbors with -1; skip them
122 |
123 | def get_similar_question_sql(self, question: str, **kwargs) -> list:
124 | return self._get_similar(self.sql_index, self.sql_metadata, question, self.n_results_sql)
125 |
126 | def get_related_ddl(self, question: str, **kwargs) -> list:
127 | return [metadata["ddl"] for metadata in self._get_similar(self.ddl_index, self.ddl_metadata, question, self.n_results_ddl)]
128 |
129 | def get_related_documentation(self, question: str, **kwargs) -> list:
130 | return [metadata["documentation"] for metadata in self._get_similar(self.doc_index, self.doc_metadata, question, self.n_results_documentation)]
131 |
132 | def get_training_data(self, **kwargs) -> pd.DataFrame:
133 | sql_data = pd.DataFrame(self.sql_metadata)
134 | sql_data['training_data_type'] = 'sql'
135 |
136 | ddl_data = pd.DataFrame(self.ddl_metadata)
137 | ddl_data['training_data_type'] = 'ddl'
138 |
139 | doc_data = pd.DataFrame(self.doc_metadata)
140 | doc_data['training_data_type'] = 'documentation'
141 |
142 | return pd.concat([sql_data, ddl_data, doc_data], ignore_index=True)
143 |
144 | def remove_training_data(self, id: str, **kwargs) -> bool:
145 | for metadata_list, index, index_name in [
146 | (self.sql_metadata, self.sql_index, 'sql_index.faiss'),
147 | (self.ddl_metadata, self.ddl_index, 'ddl_index.faiss'),
148 | (self.doc_metadata, self.doc_index, 'doc_index.faiss')
149 | ]:
150 | for i, item in enumerate(metadata_list):
151 | if item['id'] == id:
152 | del metadata_list[i]
153 | new_index = faiss.IndexFlatL2(self.embedding_dim)
154 |                     embeddings = [self.generate_embedding((m["question"] + " " + m["sql"]) if "sql" in m else m.get("ddl") or m["documentation"]) for m in metadata_list]  # re-embed the original text, matching _add_to_index
155 | if embeddings:
156 | new_index.add(np.array(embeddings, dtype=np.float32))
157 | setattr(self, index_name.split('.')[0], new_index)
158 |
159 | if self.curr_client == 'persistent':
160 | self._save_index(new_index, index_name)
161 |                         self._save_metadata(metadata_list, f"{index_name.split('_')[0]}_metadata.json")  # e.g. sql_metadata.json, matching the filenames used at load time
162 |
163 | return True
164 | return False
165 |
166 | def remove_collection(self, collection_name: str) -> bool:
167 |         if collection_name in ["sql", "ddl", "documentation"]:
168 |             # the documentation collection is stored under the "doc" attribute/file prefix
169 |             prefix = "doc" if collection_name == "documentation" else collection_name
170 |             setattr(self, f"{prefix}_index", faiss.IndexFlatL2(self.embedding_dim))
171 |             setattr(self, f"{prefix}_metadata", [])
172 |
173 |             if self.curr_client == 'persistent':
174 |                 self._save_index(getattr(self, f"{prefix}_index"), f"{prefix}_index.faiss")
175 |                 self._save_metadata([], f"{prefix}_metadata.json")
176 |
177 |             return True
178 |         return False
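179 |
180 | # Usage sketch (path is illustrative): persistent mode writes *_index.faiss
181 | # and *_metadata.json files under "path"; in-memory keeps everything in RAM.
182 | #
183 | # store = FAISS(config={"client": "persistent", "path": "./faiss_store"})
184 | # store.add_documentation("Orders live in the orders table.")
185 | # print(store.get_related_documentation("Where are orders stored?"))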
--------------------------------------------------------------------------------
/src/vanna/flask/auth.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import flask
4 |
5 |
6 | class AuthInterface(ABC):
7 | @abstractmethod
8 | def get_user(self, flask_request) -> any:
9 | pass
10 |
11 | @abstractmethod
12 | def is_logged_in(self, user: any) -> bool:
13 | pass
14 |
15 | @abstractmethod
16 | def override_config_for_user(self, user: any, config: dict) -> dict:
17 | pass
18 |
19 | @abstractmethod
20 | def login_form(self) -> str:
21 | pass
22 |
23 | @abstractmethod
24 | def login_handler(self, flask_request) -> str:
25 | pass
26 |
27 | @abstractmethod
28 | def callback_handler(self, flask_request) -> str:
29 | pass
30 |
31 | @abstractmethod
32 | def logout_handler(self, flask_request) -> str:
33 | pass
34 |
35 | class NoAuth(AuthInterface):
36 | def get_user(self, flask_request) -> any:
37 | return {}
38 |
39 | def is_logged_in(self, user: any) -> bool:
40 | return True
41 |
42 | def override_config_for_user(self, user: any, config: dict) -> dict:
43 | return config
44 |
45 | def login_form(self) -> str:
46 | return ''
47 |
48 | def login_handler(self, flask_request) -> str:
49 | return 'No login required'
50 |
51 | def callback_handler(self, flask_request) -> str:
52 | return 'No login required'
53 |
54 | def logout_handler(self, flask_request) -> str:
55 | return 'No login required'
56 |
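57 | # Illustrative sketch (not part of the library): an AuthInterface that
58 | # trusts a user header set by a reverse proxy. The header name is an
59 | # assumption; the login/logout handlers just return static messages.
60 | #
61 | # class HeaderAuth(AuthInterface):
62 | #     def get_user(self, flask_request) -> any:
63 | #         return flask_request.headers.get("X-Forwarded-User")
64 | #
65 | #     def is_logged_in(self, user: any) -> bool:
66 | #         return user is not None
67 | #
68 | #     def override_config_for_user(self, user: any, config: dict) -> dict:
69 | #         return config
70 | #
71 | #     def login_form(self) -> str:
72 | #         return ''
73 | #
74 | #     def login_handler(self, flask_request) -> str:
75 | #         return 'Login is handled by the proxy'
76 | #
77 | #     def callback_handler(self, flask_request) -> str:
78 | #         return 'Login is handled by the proxy'
79 | #
80 | #     def logout_handler(self, flask_request) -> str:
81 | #         return 'Logout is handled by the proxy'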
--------------------------------------------------------------------------------
/src/vanna/google/__init__.py:
--------------------------------------------------------------------------------
1 | from .bigquery_vector import BigQuery_VectorStore
2 | from .gemini_chat import GoogleGeminiChat
3 |
--------------------------------------------------------------------------------
/src/vanna/google/bigquery_vector.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import uuid
4 | from typing import List, Optional
5 | from vertexai.language_models import (
6 | TextEmbeddingInput,
7 | TextEmbeddingModel
8 | )
9 |
10 | import pandas as pd
11 | from google.cloud import bigquery
12 |
13 | from ..base import VannaBase
14 |
15 |
16 | class BigQuery_VectorStore(VannaBase):
17 | def __init__(self, config: dict, **kwargs):
18 | self.config = config
19 |
20 | self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
21 | self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
22 | self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
23 |
24 | if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
25 | """
26 | If Google api_key is provided through config
27 | or set as an environment variable, assign it.
28 | """
29 | print("Configuring genai")
30 | self.type = "GEMINI"
31 | import google.generativeai as genai
32 |
33 |             genai.configure(api_key=config.get("api_key", os.getenv("GOOGLE_API_KEY")))
34 |
35 | self.genai = genai
36 | else:
37 | self.type = "VERTEX_AI"
38 | # Authenticate using VertexAI
39 |
40 | if self.config.get("project_id"):
41 | self.project_id = self.config.get("project_id")
42 | else:
43 | self.project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
44 |
45 | if self.project_id is None:
46 | raise ValueError("Project ID is not set")
47 |
48 | self.conn = bigquery.Client(project=self.project_id)
49 |
50 | dataset_name = self.config.get('bigquery_dataset_name', 'vanna_managed')
51 | self.dataset_id = f"{self.project_id}.{dataset_name}"
52 | dataset = bigquery.Dataset(self.dataset_id)
53 |
54 | try:
55 | self.conn.get_dataset(self.dataset_id) # Make an API request.
56 | print(f"Dataset {self.dataset_id} already exists")
57 | except Exception:
58 | # Dataset does not exist, create it
59 | dataset.location = "US"
60 | self.conn.create_dataset(dataset, timeout=30) # Make an API request.
61 | print(f"Created dataset {self.dataset_id}")
62 |
63 | # Create a table called training_data in the dataset that contains the columns:
64 | # id, training_data_type, question, content, embedding, created_at
65 |
66 | self.table_id = f"{self.dataset_id}.training_data"
67 | schema = [
68 | bigquery.SchemaField("id", "STRING", mode="REQUIRED"),
69 | bigquery.SchemaField("training_data_type", "STRING", mode="REQUIRED"),
70 | bigquery.SchemaField("question", "STRING", mode="REQUIRED"),
71 | bigquery.SchemaField("content", "STRING", mode="REQUIRED"),
72 | bigquery.SchemaField("embedding", "FLOAT64", mode="REPEATED"),
73 | bigquery.SchemaField("created_at", "TIMESTAMP", mode="REQUIRED"),
74 | ]
75 |
76 | table = bigquery.Table(self.table_id, schema=schema)
77 |
78 | try:
79 | self.conn.get_table(self.table_id) # Make an API request.
80 | print(f"Table {self.table_id} already exists")
81 | except Exception:
82 | # Table does not exist, create it
83 | self.conn.create_table(table, timeout=30) # Make an API request.
84 | print(f"Created table {self.table_id}")
85 |
86 | # Create VECTOR INDEX IF NOT EXISTS
87 | # TODO: This requires 5000 rows before it can be created
88 | # vector_index_query = f"""
89 | # CREATE VECTOR INDEX IF NOT EXISTS my_index
90 | # ON `{self.table_id}`(embedding)
91 | # OPTIONS(
92 | # distance_type='COSINE',
93 | # index_type='IVF',
94 | # ivf_options='{{"num_lists": 1000}}'
95 | # )
96 | # """
97 |
98 | # try:
99 | # self.conn.query(vector_index_query).result() # Make an API request.
100 | # print(f"Vector index on {self.table_id} created or already exists")
101 | # except Exception as e:
102 | # print(f"Failed to create vector index: {e}")
103 |
104 | def store_training_data(self, training_data_type: str, question: str, content: str, embedding: List[float], **kwargs) -> str:
105 | id = str(uuid.uuid4())
106 | created_at = datetime.datetime.now()
107 | self.conn.insert_rows_json(self.table_id, [{
108 | "id": id,
109 | "training_data_type": training_data_type,
110 | "question": question,
111 | "content": content,
112 | "embedding": embedding,
113 | "created_at": created_at.isoformat()
114 | }])
115 |
116 | return id
117 |
118 | def fetch_similar_training_data(self, training_data_type: str, question: str, n_results, **kwargs) -> pd.DataFrame:
119 | question_embedding = self.generate_question_embedding(question)
120 |
121 | query = f"""
122 | SELECT
123 | base.id as id,
124 | base.question as question,
125 | base.training_data_type as training_data_type,
126 | base.content as content,
127 | distance
128 | FROM
129 | VECTOR_SEARCH(
130 | TABLE `{self.table_id}`,
131 | 'embedding',
132 | (SELECT * FROM UNNEST([STRUCT({question_embedding})])),
133 |             top_k => {n_results},
134 | distance_type => 'COSINE',
135 | options => '{{"use_brute_force":true}}'
136 | )
137 | WHERE
138 | base.training_data_type = '{training_data_type}'
139 | """
140 |
141 | results = self.conn.query(query).result().to_dataframe()
142 | return results
143 |
144 | def get_embeddings(self, data: str, task: str) -> List[float]:
145 | embeddings = None
146 |
147 | if self.type == "VERTEX_AI":
148 | input = [TextEmbeddingInput(data, task)]
149 | model = TextEmbeddingModel.from_pretrained("text-embedding-004")
150 |
151 | result = model.get_embeddings(input)
152 |
153 | if len(result) > 0:
154 | embeddings = result[0].values
155 | else:
156 | # Use Gemini Consumer API
157 | result = self.genai.embed_content(
158 | model="models/text-embedding-004",
159 | content=data,
160 | task_type=task)
161 |
162 | if 'embedding' in result:
163 | embeddings = result['embedding']
164 |
165 | return embeddings
166 |
167 | def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
168 | result = self.get_embeddings(data, "RETRIEVAL_QUERY")
169 |
170 |         if result is not None:
171 | return result
172 | else:
173 | raise ValueError("No embeddings returned")
174 |
175 | def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
176 | result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT")
177 |
178 |         if result is not None:
179 | return result
180 | else:
181 | raise ValueError("No embeddings returned")
182 |
183 | # task = "RETRIEVAL_DOCUMENT"
184 | # inputs = [TextEmbeddingInput(data, task)]
185 | # embeddings = self.vertex_embedding_model.get_embeddings(inputs)
186 |
187 | # if len(embeddings) == 0:
188 | # raise ValueError("No embeddings returned")
189 |
190 | # return embeddings[0].values
191 |
192 |         # return result  (unreachable: both branches above already return or raise)
193 |
194 | def generate_embedding(self, data: str, **kwargs) -> List[float]:
195 | return self.generate_storage_embedding(data, **kwargs)
196 |
197 | def get_similar_question_sql(self, question: str, **kwargs) -> list:
198 | df = self.fetch_similar_training_data(training_data_type="sql", question=question, n_results=self.n_results_sql)
199 |
200 | # Return a list of dictionaries with only question, sql fields. The content field needs to be renamed to sql
201 | return df.rename(columns={"content": "sql"})[["question", "sql"]].to_dict(orient="records")
202 |
203 | def get_related_ddl(self, question: str, **kwargs) -> list:
204 | df = self.fetch_similar_training_data(training_data_type="ddl", question=question, n_results=self.n_results_ddl)
205 |
206 | # Return a list of strings of the content
207 | return df["content"].tolist()
208 |
209 | def get_related_documentation(self, question: str, **kwargs) -> list:
210 | df = self.fetch_similar_training_data(training_data_type="documentation", question=question, n_results=self.n_results_documentation)
211 |
212 | # Return a list of strings of the content
213 | return df["content"].tolist()
214 |
215 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
216 | doc = {
217 | "question": question,
218 | "sql": sql
219 | }
220 |
221 | embedding = self.generate_embedding(str(doc))
222 |
223 | return self.store_training_data(training_data_type="sql", question=question, content=sql, embedding=embedding)
224 |
225 | def add_ddl(self, ddl: str, **kwargs) -> str:
226 | embedding = self.generate_embedding(ddl)
227 |
228 | return self.store_training_data(training_data_type="ddl", question="", content=ddl, embedding=embedding)
229 |
230 | def add_documentation(self, documentation: str, **kwargs) -> str:
231 | embedding = self.generate_embedding(documentation)
232 |
233 | return self.store_training_data(training_data_type="documentation", question="", content=documentation, embedding=embedding)
234 |
235 | def get_training_data(self, **kwargs) -> pd.DataFrame:
236 | query = f"SELECT id, training_data_type, question, content FROM `{self.table_id}`"
237 |
238 | return self.conn.query(query).result().to_dataframe()
239 |
240 | def remove_training_data(self, id: str, **kwargs) -> bool:
241 | query = f"DELETE FROM `{self.table_id}` WHERE id = '{id}'"
242 |
243 | try:
244 | self.conn.query(query).result()
245 | return True
246 |
247 | except Exception as e:
248 | print(f"Failed to remove training data: {e}")
249 | return False
250 |
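251 | # Usage sketch (project and dataset names are placeholders): without an
252 | # "api_key" the store authenticates through Vertex AI default credentials.
253 | #
254 | # vs = BigQuery_VectorStore(config={
255 | #     "project_id": "my-gcp-project",
256 | #     "bigquery_dataset_name": "vanna_managed",
257 | # })
258 | # vs.add_ddl("CREATE TABLE sales (id INT64, amount FLOAT64)")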
--------------------------------------------------------------------------------
/src/vanna/google/gemini_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from ..base import VannaBase
4 |
5 |
6 | class GoogleGeminiChat(VannaBase):
7 | def __init__(self, config=None):
8 | VannaBase.__init__(self, config=config)
9 |
10 |         # default temperature - can be overridden using config
11 | self.temperature = 0.7
12 |
13 | if "temperature" in config:
14 | self.temperature = config["temperature"]
15 |
16 | if "model_name" in config:
17 | model_name = config["model_name"]
18 | else:
19 | model_name = "gemini-1.5-pro"
20 |
21 | self.google_api_key = None
22 |
23 | if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
24 | """
25 | If Google api_key is provided through config
26 | or set as an environment variable, assign it.
27 | """
28 | import google.generativeai as genai
29 |
30 |             genai.configure(api_key=config.get("api_key", os.getenv("GOOGLE_API_KEY")))
31 | self.chat_model = genai.GenerativeModel(model_name)
32 | else:
33 | # Authenticate using VertexAI
34 | import google.auth
35 | import vertexai
36 | from vertexai.generative_models import GenerativeModel
37 |
38 | json_file_path = config.get("google_credentials") # Assuming the JSON file path is provided in the config
39 |
40 | if not json_file_path or not os.path.exists(json_file_path):
41 | raise FileNotFoundError(f"JSON credentials file not found at: {json_file_path}")
42 |
43 | try:
44 | # Validate and set the JSON file path for GOOGLE_APPLICATION_CREDENTIALS
45 | os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = json_file_path
46 |
47 | # Initialize VertexAI with the credentials
48 | credentials, _ = google.auth.default()
49 | vertexai.init(credentials=credentials)
50 | self.chat_model = GenerativeModel(model_name)
51 | except google.auth.exceptions.DefaultCredentialsError as e:
52 | raise RuntimeError(f"Default credentials error: {e}")
53 | except google.auth.exceptions.TransportError as e:
54 | raise RuntimeError(f"Transport error during authentication: {e}")
55 | except Exception as e:
56 | raise RuntimeError(f"Failed to authenticate using JSON file: {e}")
57 |
58 | def system_message(self, message: str) -> any:
59 | return message
60 |
61 | def user_message(self, message: str) -> any:
62 | return message
63 |
64 | def assistant_message(self, message: str) -> any:
65 | return message
66 |
67 | def submit_prompt(self, prompt, **kwargs) -> str:
68 | response = self.chat_model.generate_content(
69 | prompt,
70 | generation_config={
71 | "temperature": self.temperature,
72 | },
73 | )
74 | return response.text
75 |
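76 | # Usage sketch (the key is a placeholder): with "api_key" set the consumer
77 | # Gemini API is used; otherwise "google_credentials" must point to a
78 | # service-account JSON file and Vertex AI is used instead.
79 | #
80 | # chat = GoogleGeminiChat(config={"api_key": "AIza...", "model_name": "gemini-1.5-pro"})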
--------------------------------------------------------------------------------
/src/vanna/hf/__init__.py:
--------------------------------------------------------------------------------
1 | from .hf import Hf
2 |
--------------------------------------------------------------------------------
/src/vanna/hf/hf.py:
--------------------------------------------------------------------------------
1 | import re
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 |
4 | from ..base import VannaBase
5 |
6 |
7 | class Hf(VannaBase):
8 |     def __init__(self, config=None):
9 |         VannaBase.__init__(self, config=config)  # sets self.config, which is used below
10 |         model_name_or_path = self.config.get("model_name_or_path", None)
11 |         # e.g. meta-llama/Meta-Llama-3-8B-Instruct or a local path to the model checkpoint files
12 | # list of quantization methods supported by transformers package: https://huggingface.co/docs/transformers/main/en/quantization/overview
13 | quantization_config = self.config.get("quantization_config", None)
14 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
15 | self.model = AutoModelForCausalLM.from_pretrained(
16 | model_name_or_path,
17 | quantization_config=quantization_config,
18 | device_map="auto",
19 | )
20 |
21 | def system_message(self, message: str) -> any:
22 | return {"role": "system", "content": message}
23 |
24 | def user_message(self, message: str) -> any:
25 | return {"role": "user", "content": message}
26 |
27 | def assistant_message(self, message: str) -> any:
28 | return {"role": "assistant", "content": message}
29 |
30 | def extract_sql_query(self, text):
31 | """
32 | Extracts the first SQL statement after the word 'select', ignoring case,
33 | matches until the first semicolon, three backticks, or the end of the string,
34 | and removes three backticks if they exist in the extracted string.
35 |
36 | Args:
37 | - text (str): The string to search within for an SQL statement.
38 |
39 | Returns:
40 | - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
41 | """
42 | # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
43 | pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
44 |
45 | match = pattern.search(text)
46 | if match:
47 | # Remove three backticks from the matched string if they exist
48 | return match.group(0).replace("```", "")
49 | else:
50 | return text
51 |
52 | def generate_sql(self, question: str, **kwargs) -> str:
53 | # Use the super generate_sql
54 | sql = super().generate_sql(question, **kwargs)
55 |
56 | # Replace "\_" with "_"
57 | sql = sql.replace("\\_", "_")
58 |
59 | sql = sql.replace("\\", "")
60 |
61 | return self.extract_sql_query(sql)
62 |
63 | def submit_prompt(self, prompt, **kwargs) -> str:
64 |
65 | input_ids = self.tokenizer.apply_chat_template(
66 | prompt, add_generation_prompt=True, return_tensors="pt"
67 | ).to(self.model.device)
68 |
69 | outputs = self.model.generate(
70 | input_ids,
71 | max_new_tokens=512,
72 | eos_token_id=self.tokenizer.eos_token_id,
73 | do_sample=True,
74 | temperature=1,
75 | top_p=0.9,
76 | )
77 | response = outputs[0][input_ids.shape[-1] :]
78 | response = self.tokenizer.decode(response, skip_special_tokens=True)
79 | self.log(response)
80 |
81 | return response
82 |
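83 | # Usage sketch (model id is illustrative): any causal LM with a chat
84 | # template works; pass "quantization_config" to load a quantized model.
85 | #
86 | # hf = Hf(config={"model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct"})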
--------------------------------------------------------------------------------
/src/vanna/local.py:
--------------------------------------------------------------------------------
1 | from .chromadb.chromadb_vector import ChromaDB_VectorStore
2 | from .openai.openai_chat import OpenAI_Chat
3 |
4 |
5 | class LocalContext_OpenAI(ChromaDB_VectorStore, OpenAI_Chat):
6 | def __init__(self, config=None):
7 | ChromaDB_VectorStore.__init__(self, config=config)
8 | OpenAI_Chat.__init__(self, config=config)
9 |
--------------------------------------------------------------------------------
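`LocalContext_OpenAI` simply composes the Chroma-backed vector store with the OpenAI chat model through cooperative `__init__` calls. A minimal usage sketch (the API key, model name, and DDL are placeholders):

```python
from vanna.local import LocalContext_OpenAI

vn = LocalContext_OpenAI(config={"api_key": "sk-...", "model": "gpt-4"})  # placeholder key
vn.train(ddl="CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)")
print(vn.generate_sql("What are the top 10 customers by sales?"))
```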
/src/vanna/marqo/__init__.py:
--------------------------------------------------------------------------------
1 | from .marqo import Marqo_VectorStore
2 |
--------------------------------------------------------------------------------
/src/vanna/marqo/marqo.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 | import marqo
4 | import pandas as pd
5 |
6 | from ..base import VannaBase
7 |
8 |
9 | class Marqo_VectorStore(VannaBase):
10 | def __init__(self, config=None):
11 | VannaBase.__init__(self, config=config)
12 |
13 | if config is not None and "marqo_url" in config:
14 | marqo_url = config["marqo_url"]
15 | else:
16 | marqo_url = "http://localhost:8882"
17 |
18 | if config is not None and "marqo_model" in config:
19 | marqo_model = config["marqo_model"]
20 | else:
21 | marqo_model = "hf/all_datasets_v4_MiniLM-L6"
22 |
23 | self.mq = marqo.Client(url=marqo_url)
24 |
25 | for index in ["vanna-sql", "vanna-ddl", "vanna-doc"]:
26 | try:
27 | self.mq.create_index(index, model=marqo_model)
28 | except Exception as e:
29 |                 # create_index raises if the index already exists; safe to continue
30 |                 print(e)
31 |                 print(f"Marqo index {index} already exists")
32 |
33 | def generate_embedding(self, data: str, **kwargs) -> list[float]:
34 | # Marqo doesn't need to generate embeddings
35 | pass
36 |
37 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
38 | id = str(uuid.uuid4()) + "-sql"
39 | question_sql_dict = {
40 | "question": question,
41 | "sql": sql,
42 | "_id": id,
43 | }
44 |
45 | self.mq.index("vanna-sql").add_documents(
46 | [question_sql_dict],
47 | tensor_fields=["question", "sql"],
48 | )
49 |
50 | return id
51 |
52 | def add_ddl(self, ddl: str, **kwargs) -> str:
53 | id = str(uuid.uuid4()) + "-ddl"
54 | ddl_dict = {
55 | "ddl": ddl,
56 | "_id": id,
57 | }
58 |
59 | self.mq.index("vanna-ddl").add_documents(
60 | [ddl_dict],
61 | tensor_fields=["ddl"],
62 | )
63 | return id
64 |
65 | def add_documentation(self, documentation: str, **kwargs) -> str:
66 | id = str(uuid.uuid4()) + "-doc"
67 | doc_dict = {
68 | "doc": documentation,
69 | "_id": id,
70 | }
71 |
72 | self.mq.index("vanna-doc").add_documents(
73 | [doc_dict],
74 | tensor_fields=["doc"],
75 | )
76 | return id
77 |
78 | def get_training_data(self, **kwargs) -> pd.DataFrame:
79 | data = []
80 |
81 | for hit in self.mq.index("vanna-doc").search("", limit=1000)["hits"]:
82 | data.append(
83 | {
84 | "id": hit["_id"],
85 | "training_data_type": "documentation",
86 | "question": "",
87 | "content": hit["doc"],
88 | }
89 | )
90 |
91 | for hit in self.mq.index("vanna-ddl").search("", limit=1000)["hits"]:
92 | data.append(
93 | {
94 | "id": hit["_id"],
95 | "training_data_type": "ddl",
96 | "question": "",
97 | "content": hit["ddl"],
98 | }
99 | )
100 |
101 | for hit in self.mq.index("vanna-sql").search("", limit=1000)["hits"]:
102 | data.append(
103 | {
104 | "id": hit["_id"],
105 | "training_data_type": "sql",
106 | "question": hit["question"],
107 | "content": hit["sql"],
108 | }
109 | )
110 |
111 | df = pd.DataFrame(data)
112 |
113 | return df
114 |
115 | def remove_training_data(self, id: str, **kwargs) -> bool:
116 | if id.endswith("-sql"):
117 | self.mq.index("vanna-sql").delete_documents(ids=[id])
118 | return True
119 | elif id.endswith("-ddl"):
120 | self.mq.index("vanna-ddl").delete_documents(ids=[id])
121 | return True
122 | elif id.endswith("-doc"):
123 | self.mq.index("vanna-doc").delete_documents(ids=[id])
124 | return True
125 | else:
126 | return False
127 |
128 | # Static method to extract the documents from the results of a query
129 | @staticmethod
130 | def _extract_documents(data) -> list:
131 | # Check if 'hits' key is in the dictionary and if it's a list
132 | if "hits" in data and isinstance(data["hits"], list):
133 | # Iterate over each item in 'hits'
134 |
135 | if len(data["hits"]) == 0:
136 | return []
137 |
138 | # If there is a "doc" key, return the value of that key
139 | if "doc" in data["hits"][0]:
140 | return [hit["doc"] for hit in data["hits"]]
141 |
142 | # If there is a "ddl" key, return the value of that key
143 | if "ddl" in data["hits"][0]:
144 | return [hit["ddl"] for hit in data["hits"]]
145 |
146 | # Otherwise return the entire hit
147 | return [
148 | {key: value for key, value in hit.items() if not key.startswith("_")}
149 | for hit in data["hits"]
150 | ]
151 | else:
152 | # Return an empty list if 'hits' is not found or not a list
153 | return []
154 |
155 | def get_similar_question_sql(self, question: str, **kwargs) -> list:
156 | return Marqo_VectorStore._extract_documents(
157 | self.mq.index("vanna-sql").search(question)
158 | )
159 |
160 | def get_related_ddl(self, question: str, **kwargs) -> list:
161 | return Marqo_VectorStore._extract_documents(
162 | self.mq.index("vanna-ddl").search(question)
163 | )
164 |
165 | def get_related_documentation(self, question: str, **kwargs) -> list:
166 | return Marqo_VectorStore._extract_documents(
167 | self.mq.index("vanna-doc").search(question)
168 | )
169 |
--------------------------------------------------------------------------------
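The constructor above defaults to a local Marqo server and the `hf/all_datasets_v4_MiniLM-L6` embedding model, and eagerly creates the three indexes on instantiation. A configuration sketch with those defaults spelled out explicitly:

```python
from vanna.marqo.marqo import Marqo_VectorStore

# Both keys are optional; the values shown are the defaults from __init__.
# Instantiation also attempts to create the vanna-sql/-ddl/-doc indexes.
vs = Marqo_VectorStore(config={
    "marqo_url": "http://localhost:8882",
    "marqo_model": "hf/all_datasets_v4_MiniLM-L6",
})
vs.add_ddl("CREATE TABLE customers (id INT, name TEXT)")
```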
/src/vanna/milvus/__init__.py:
--------------------------------------------------------------------------------
1 | from .milvus_vector import Milvus_VectorStore
2 |
--------------------------------------------------------------------------------
/src/vanna/mistral/__init__.py:
--------------------------------------------------------------------------------
1 | from .mistral import Mistral
2 |
--------------------------------------------------------------------------------
/src/vanna/mistral/mistral.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from mistralai import Mistral as MistralClient
4 | from mistralai import UserMessage
5 |
6 | from ..base import VannaBase
7 |
8 |
9 | class Mistral(VannaBase):
10 | def __init__(self, config=None):
11 | if config is None:
12 | raise ValueError(
13 | "For Mistral, config must be provided with an api_key and model"
14 | )
15 |
16 | if "api_key" not in config:
17 | raise ValueError("config must contain a Mistral api_key")
18 |
19 | if "model" not in config:
20 | raise ValueError("config must contain a Mistral model")
21 |
22 | api_key = config["api_key"]
23 | model = config["model"]
24 | self.client = MistralClient(api_key=api_key)
25 | self.model = model
26 |
27 | def system_message(self, message: str) -> any:
28 | return {"role": "system", "content": message}
29 |
30 | def user_message(self, message: str) -> any:
31 | return {"role": "user", "content": message}
32 |
33 | def assistant_message(self, message: str) -> any:
34 | return {"role": "assistant", "content": message}
35 |
36 | def generate_sql(self, question: str, **kwargs) -> str:
37 | # Use the super generate_sql
38 | sql = super().generate_sql(question, **kwargs)
39 |
40 | # Replace "\_" with "_"
41 | sql = sql.replace("\\_", "_")
42 |
43 | return sql
44 |
45 | def submit_prompt(self, prompt, **kwargs) -> str:
46 | chat_response = self.client.chat.complete(
47 | model=self.model,
48 | messages=prompt,
49 | )
50 |
51 | return chat_response.choices[0].message.content
52 |
--------------------------------------------------------------------------------
/src/vanna/mock/__init__.py:
--------------------------------------------------------------------------------
1 | from .embedding import MockEmbedding
2 | from .llm import MockLLM
3 | from .vectordb import MockVectorDB
4 |
--------------------------------------------------------------------------------
/src/vanna/mock/embedding.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from ..base import VannaBase
4 |
5 |
6 | class MockEmbedding(VannaBase):
7 | def __init__(self, config=None):
8 | pass
9 |
10 | def generate_embedding(self, data: str, **kwargs) -> List[float]:
11 | return [1.0, 2.0, 3.0, 4.0, 5.0]
12 |
--------------------------------------------------------------------------------
/src/vanna/mock/llm.py:
--------------------------------------------------------------------------------
1 |
2 | from ..base import VannaBase
3 |
4 |
5 | class MockLLM(VannaBase):
6 | def __init__(self, config=None):
7 | pass
8 |
9 | def system_message(self, message: str) -> any:
10 | return {"role": "system", "content": message}
11 |
12 | def user_message(self, message: str) -> any:
13 | return {"role": "user", "content": message}
14 |
15 | def assistant_message(self, message: str) -> any:
16 | return {"role": "assistant", "content": message}
17 |
18 | def submit_prompt(self, prompt, **kwargs) -> str:
19 | return "Mock LLM response"
20 |
--------------------------------------------------------------------------------
/src/vanna/mock/vectordb.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from ..base import VannaBase
4 |
5 |
6 | class MockVectorDB(VannaBase):
7 | def __init__(self, config=None):
8 | pass
9 |
10 | def _get_id(self, value: str, **kwargs) -> str:
11 | # Hash the value and return the ID
12 | return str(hash(value))
13 |
14 | def add_ddl(self, ddl: str, **kwargs) -> str:
15 | return self._get_id(ddl)
16 |
17 | def add_documentation(self, doc: str, **kwargs) -> str:
18 | return self._get_id(doc)
19 |
20 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
21 | return self._get_id(question)
22 |
23 | def get_related_ddl(self, question: str, **kwargs) -> list:
24 | return []
25 |
26 | def get_related_documentation(self, question: str, **kwargs) -> list:
27 | return []
28 |
29 | def get_similar_question_sql(self, question: str, **kwargs) -> list:
30 | return []
31 |
32 | def get_training_data(self, **kwargs) -> pd.DataFrame:
33 | return pd.DataFrame({'id': {0: '19546-ddl',
34 | 1: '91597-sql',
35 | 2: '133976-sql',
36 | 3: '59851-doc',
37 | 4: '73046-sql'},
38 | 'training_data_type': {0: 'ddl',
39 | 1: 'sql',
40 | 2: 'sql',
41 | 3: 'documentation',
42 | 4: 'sql'},
43 | 'question': {0: None,
44 | 1: 'What are the top selling genres?',
45 |                                     2: 'What are the bottom 7 artists by sales?',
46 | 3: None,
47 | 4: 'What is the total sales for each customer?'},
48 | 'content': {0: 'CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)',
49 | 1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;',
50 | 2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;',
51 |                                     3: 'This is a SQLite database. For dates remember to use SQLite syntax.',
52 | 4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}})
53 |
54 |     def remove_training_data(self, id: str, **kwargs) -> bool:
55 | return True
56 |
--------------------------------------------------------------------------------
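The three mocks implement just enough of the `VannaBase` surface to be combined into a test double the same way the real store/LLM pairs are combined. An illustrative sketch (the `MockVanna` class name is hypothetical, not part of the package):

```python
from vanna.mock import MockEmbedding, MockLLM, MockVectorDB

class MockVanna(MockEmbedding, MockVectorDB, MockLLM):  # hypothetical test double
    def __init__(self, config=None):
        MockEmbedding.__init__(self, config=config)
        MockVectorDB.__init__(self, config=config)
        MockLLM.__init__(self, config=config)

vn = MockVanna()
print(vn.submit_prompt([vn.user_message("ping")]))  # "Mock LLM response"
print(vn.get_training_data().shape)                 # (5, 4) canned DataFrame
```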
/src/vanna/ollama/__init__.py:
--------------------------------------------------------------------------------
1 | from .ollama import Ollama
2 |
--------------------------------------------------------------------------------
/src/vanna/ollama/ollama.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 |
4 | from httpx import Timeout
5 |
6 | from ..base import VannaBase
7 | from ..exceptions import DependencyError
8 |
9 |
10 | class Ollama(VannaBase):
11 | def __init__(self, config=None):
12 |
13 | try:
14 | ollama = __import__("ollama")
15 | except ImportError:
16 | raise DependencyError(
17 | "You need to install required dependencies to execute this method, run command:"
18 | " \npip install ollama"
19 | )
20 |
21 | if not config:
22 | raise ValueError("config must contain at least Ollama model")
23 | if 'model' not in config.keys():
24 | raise ValueError("config must contain at least Ollama model")
25 | self.host = config.get("ollama_host", "http://localhost:11434")
26 | self.model = config["model"]
27 | if ":" not in self.model:
28 | self.model += ":latest"
29 |
30 | self.ollama_timeout = config.get("ollama_timeout", 240.0)
31 |
32 | self.ollama_client = ollama.Client(self.host, timeout=Timeout(self.ollama_timeout))
33 | self.keep_alive = config.get('keep_alive', None)
34 | self.ollama_options = config.get('options', {})
35 | self.num_ctx = self.ollama_options.get('num_ctx', 2048)
36 | self.__pull_model_if_ne(self.ollama_client, self.model)
37 |
38 | @staticmethod
39 | def __pull_model_if_ne(ollama_client, model):
40 | model_response = ollama_client.list()
41 | model_lists = [model_element['model'] for model_element in
42 | model_response.get('models', [])]
43 | if model not in model_lists:
44 | ollama_client.pull(model)
45 |
46 | def system_message(self, message: str) -> any:
47 | return {"role": "system", "content": message}
48 |
49 | def user_message(self, message: str) -> any:
50 | return {"role": "user", "content": message}
51 |
52 | def assistant_message(self, message: str) -> any:
53 | return {"role": "assistant", "content": message}
54 |
55 | def extract_sql(self, llm_response):
56 | """
57 | Extracts the first SQL statement after the word 'select', ignoring case,
58 | matches until the first semicolon, three backticks, or the end of the string,
59 | and removes three backticks if they exist in the extracted string.
60 |
61 | Args:
62 | - llm_response (str): The string to search within for an SQL statement.
63 |
64 | Returns:
65 |     - str: The first SQL statement found, with three backticks removed, or the original response if no match is found.
66 | """
67 | # Remove ollama-generated extra characters
68 | llm_response = llm_response.replace("\\_", "_")
69 | llm_response = llm_response.replace("\\", "")
70 |
71 |     # Regular expression to find a '```sql' fenced block and capture until ';', '[', or the closing '```'
72 | sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL)
73 |     # Regular expression to find 'select' or 'with ... as (' (ignoring case) and capture until ';', '[' (emitted by Mistral in some cases), or '```'
74 | select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)',
75 | llm_response,
76 | re.IGNORECASE | re.DOTALL)
77 | if sql:
78 | self.log(
79 | f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
80 | return sql.group(1).replace("```", "")
81 | elif select_with:
82 | self.log(
83 | f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}")
84 | return select_with.group(0)
85 | else:
86 | return llm_response
87 |
88 | def submit_prompt(self, prompt, **kwargs) -> str:
89 | self.log(
90 | f"Ollama parameters:\n"
91 | f"model={self.model},\n"
92 | f"options={self.ollama_options},\n"
93 | f"keep_alive={self.keep_alive}")
94 | self.log(f"Prompt Content:\n{json.dumps(prompt, ensure_ascii=False)}")
95 | response_dict = self.ollama_client.chat(model=self.model,
96 | messages=prompt,
97 | stream=False,
98 | options=self.ollama_options,
99 | keep_alive=self.keep_alive)
100 |
101 | self.log(f"Ollama Response:\n{str(response_dict)}")
102 |
103 | return response_dict['message']['content']
104 |
--------------------------------------------------------------------------------
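`Ollama.__init__` normalizes the model tag and pulls the model from the server when it is not already present, so constructing the class can trigger a download. A configuration sketch with the defaults spelled out (only `model` is required):

```python
from vanna.ollama import Ollama

llm = Ollama(config={
    "model": "llama3",                        # ":latest" is appended if no tag is given
    "ollama_host": "http://localhost:11434",  # default
    "ollama_timeout": 240.0,                  # default, in seconds
    "options": {"num_ctx": 2048},             # passed through to the ollama client
})
```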
/src/vanna/openai/__init__.py:
--------------------------------------------------------------------------------
1 | from .openai_chat import OpenAI_Chat
2 | from .openai_embeddings import OpenAI_Embeddings
3 |
--------------------------------------------------------------------------------
/src/vanna/openai/openai_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from openai import OpenAI
4 |
5 | from ..base import VannaBase
6 |
7 |
8 | class OpenAI_Chat(VannaBase):
9 | def __init__(self, client=None, config=None):
10 | VannaBase.__init__(self, config=config)
11 |
12 |         # default parameters - can be overridden using config
13 | self.temperature = 0.7
14 |
15 | if "temperature" in config:
16 | self.temperature = config["temperature"]
17 |
18 | if "api_type" in config:
19 | raise Exception(
20 | "Passing api_type is now deprecated. Please pass an OpenAI client instead."
21 | )
22 |
23 | if "api_base" in config:
24 | raise Exception(
25 | "Passing api_base is now deprecated. Please pass an OpenAI client instead."
26 | )
27 |
28 | if "api_version" in config:
29 | raise Exception(
30 | "Passing api_version is now deprecated. Please pass an OpenAI client instead."
31 | )
32 |
33 | if client is not None:
34 | self.client = client
35 | return
36 |
37 | if config is None and client is None:
38 | self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
39 | return
40 |
41 | if "api_key" in config:
42 | self.client = OpenAI(api_key=config["api_key"])
43 |
44 | def system_message(self, message: str) -> any:
45 | return {"role": "system", "content": message}
46 |
47 | def user_message(self, message: str) -> any:
48 | return {"role": "user", "content": message}
49 |
50 | def assistant_message(self, message: str) -> any:
51 | return {"role": "assistant", "content": message}
52 |
53 | def submit_prompt(self, prompt, **kwargs) -> str:
54 | if prompt is None:
55 | raise Exception("Prompt is None")
56 |
57 | if len(prompt) == 0:
58 | raise Exception("Prompt is empty")
59 |
60 | # Count the number of tokens in the message log
61 | # Use 4 as an approximation for the number of characters per token
62 | num_tokens = 0
63 | for message in prompt:
64 | num_tokens += len(message["content"]) / 4
65 |
66 | if kwargs.get("model", None) is not None:
67 | model = kwargs.get("model", None)
68 | print(
69 | f"Using model {model} for {num_tokens} tokens (approx)"
70 | )
71 | response = self.client.chat.completions.create(
72 | model=model,
73 | messages=prompt,
74 | stop=None,
75 | temperature=self.temperature,
76 | )
77 | elif kwargs.get("engine", None) is not None:
78 | engine = kwargs.get("engine", None)
79 | print(
80 | f"Using model {engine} for {num_tokens} tokens (approx)"
81 | )
82 | response = self.client.chat.completions.create(
83 | engine=engine,
84 | messages=prompt,
85 | stop=None,
86 | temperature=self.temperature,
87 | )
88 | elif self.config is not None and "engine" in self.config:
89 | print(
90 | f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
91 | )
92 | response = self.client.chat.completions.create(
93 | engine=self.config["engine"],
94 | messages=prompt,
95 | stop=None,
96 | temperature=self.temperature,
97 | )
98 | elif self.config is not None and "model" in self.config:
99 | print(
100 | f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
101 | )
102 | response = self.client.chat.completions.create(
103 | model=self.config["model"],
104 | messages=prompt,
105 | stop=None,
106 | temperature=self.temperature,
107 | )
108 | else:
109 | if num_tokens > 3500:
110 | model = "gpt-3.5-turbo-16k"
111 | else:
112 | model = "gpt-3.5-turbo"
113 |
114 | print(f"Using model {model} for {num_tokens} tokens (approx)")
115 | response = self.client.chat.completions.create(
116 | model=model,
117 | messages=prompt,
118 | stop=None,
119 | temperature=self.temperature,
120 | )
121 |
122 | # Find the first response from the chatbot that has text in it (some responses may not have text)
123 | for choice in response.choices:
124 | if "text" in choice:
125 | return choice.text
126 |
127 | # If no response with text is found, return the first response's content (which may be empty)
128 | return response.choices[0].message.content
129 |
--------------------------------------------------------------------------------
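When neither the call nor the config names a model, `submit_prompt` above falls back to a characters-divided-by-four token estimate to pick one. The heuristic in isolation:

```python
# Standalone version of the fallback model-selection heuristic.
prompt = [
    {"role": "system", "content": "You are a SQL expert."},
    {"role": "user", "content": "What are the top 10 customers by sales?"},
]
num_tokens = sum(len(m["content"]) / 4 for m in prompt)  # ~4 chars per token
model = "gpt-3.5-turbo-16k" if num_tokens > 3500 else "gpt-3.5-turbo"
print(f"Using model {model} for {num_tokens} tokens (approx)")
```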
/src/vanna/openai/openai_embeddings.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 |
3 | from ..base import VannaBase
4 |
5 |
6 | class OpenAI_Embeddings(VannaBase):
7 | def __init__(self, client=None, config=None):
8 | VannaBase.__init__(self, config=config)
9 |
10 | if client is not None:
11 | self.client = client
12 | return
13 |
14 |         if getattr(self, "client", None) is not None:
15 | return
16 |
17 | self.client = OpenAI()
18 |
19 | if config is None:
20 | return
21 |
22 | if "api_type" in config:
23 | self.client.api_type = config["api_type"]
24 |
25 | if "api_base" in config:
26 | self.client.api_base = config["api_base"]
27 |
28 | if "api_version" in config:
29 | self.client.api_version = config["api_version"]
30 |
31 | if "api_key" in config:
32 | self.client.api_key = config["api_key"]
33 |
34 | def generate_embedding(self, data: str, **kwargs) -> list[float]:
35 | if self.config is not None and "engine" in self.config:
36 | embedding = self.client.embeddings.create(
37 | engine=self.config["engine"],
38 | input=data,
39 | )
40 | else:
41 | embedding = self.client.embeddings.create(
42 | model="text-embedding-ada-002",
43 | input=data,
44 | )
45 |
46 |         return embedding.data[0].embedding
47 |
--------------------------------------------------------------------------------
/src/vanna/opensearch/__init__.py:
--------------------------------------------------------------------------------
1 | from .opensearch_vector import OpenSearch_VectorStore
2 | from .opensearch_vector_semantic import OpenSearch_Semantic_VectorStore
3 |
--------------------------------------------------------------------------------
/src/vanna/opensearch/opensearch_vector_semantic.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import pandas as pd
4 | from langchain_community.vectorstores import OpenSearchVectorSearch
5 |
6 | from ..base import VannaBase
7 | from ..utils import deterministic_uuid
8 |
9 |
10 | class OpenSearch_Semantic_VectorStore(VannaBase):
11 | def __init__(self, config=None):
12 | VannaBase.__init__(self, config=config)
13 | if config is None:
14 | config = {}
15 |
16 | if "embedding_function" in config:
17 | self.embedding_function = config.get("embedding_function")
18 | else:
19 | from langchain_huggingface import HuggingFaceEmbeddings
20 | self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
21 |
22 | self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
23 | self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
24 | self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
25 |
26 | self.document_index = config.get("es_document_index", "vanna_document_index")
27 | self.ddl_index = config.get("es_ddl_index", "vanna_ddl_index")
28 | self.question_sql_index = config.get("es_question_sql_index", "vanna_questions_sql_index")
29 |
30 | self.log(f"OpenSearch_Semantic_VectorStore initialized with document_index: {self.document_index}, ddl_index: {self.ddl_index}, question_sql_index: {self.question_sql_index}")
31 |
32 | es_urls = config.get("es_urls", "https://localhost:9200")
33 | ssl = config.get("es_ssl", True)
34 | verify_certs = config.get("es_verify_certs", True)
35 |
36 | if "es_user" in config:
37 | auth = (config["es_user"], config["es_password"])
38 | else:
39 | auth = None
40 |
41 | headers = config.get("es_headers", None)
42 | timeout = config.get("es_timeout", 60)
43 | max_retries = config.get("es_max_retries", 10)
44 |
45 | common_args = {
46 | "opensearch_url": es_urls,
47 | "embedding_function": self.embedding_function,
48 | "engine": "faiss",
49 | "http_auth": auth,
50 | "use_ssl": ssl,
51 | "verify_certs": verify_certs,
52 | "timeout": timeout,
53 | "max_retries": max_retries,
54 | "retry_on_timeout": True,
55 | "headers": headers,
56 | }
57 |
58 | self.documentation_store = OpenSearchVectorSearch(index_name=self.document_index, **common_args)
59 | self.ddl_store = OpenSearchVectorSearch(index_name=self.ddl_index, **common_args)
60 | self.sql_store = OpenSearchVectorSearch(index_name=self.question_sql_index, **common_args)
61 |
62 | def add_ddl(self, ddl: str, **kwargs) -> str:
63 | _id = deterministic_uuid(ddl) + "-ddl"
64 | self.ddl_store.add_texts(texts=[ddl], ids=[_id], **kwargs)
65 | return _id
66 |
67 | def add_documentation(self, documentation: str, **kwargs) -> str:
68 | _id = deterministic_uuid(documentation) + "-doc"
69 | self.documentation_store.add_texts(texts=[documentation], ids=[_id], **kwargs)
70 | return _id
71 |
72 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
73 | question_sql_json = json.dumps(
74 | {
75 | "question": question,
76 | "sql": sql,
77 | },
78 | ensure_ascii=False,
79 | )
80 |
81 | _id = deterministic_uuid(question_sql_json) + "-sql"
82 | self.sql_store.add_texts(texts=[question_sql_json], ids=[_id], **kwargs)
83 | return _id
84 |
85 | def get_related_ddl(self, question: str, **kwargs) -> list:
86 | documents = self.ddl_store.similarity_search(query=question, k=self.n_results_ddl)
87 | return [document.page_content for document in documents]
88 |
89 | def get_related_documentation(self, question: str, **kwargs) -> list:
90 | documents = self.documentation_store.similarity_search(query=question, k=self.n_results_documentation)
91 | return [document.page_content for document in documents]
92 |
93 | def get_similar_question_sql(self, question: str, **kwargs) -> list:
94 | documents = self.sql_store.similarity_search(query=question, k=self.n_results_sql)
95 | return [json.loads(document.page_content) for document in documents]
96 |
97 | def get_training_data(self, **kwargs) -> pd.DataFrame:
98 | data = []
99 | query = {
100 | "query": {
101 | "match_all": {}
102 | }
103 | }
104 |
105 | indices = [
106 | {"index": self.document_index, "type": "documentation"},
107 | {"index": self.question_sql_index, "type": "sql"},
108 | {"index": self.ddl_index, "type": "ddl"},
109 | ]
110 |
111 | # Use documentation_store.client consistently for search on all indices
112 | opensearch_client = self.documentation_store.client
113 |
114 | for index_info in indices:
115 | index_name = index_info["index"]
116 | training_data_type = index_info["type"]
117 | scroll = '1m' # keep scroll context for 1 minute
118 | response = opensearch_client.search(
119 | index=index_name,
120 | ignore_unavailable=True,
121 | body=query,
122 | scroll=scroll,
123 | size=1000
124 | )
125 |
126 | scroll_id = response.get('_scroll_id')
127 |
128 | while scroll_id:
129 | hits = response['hits']['hits']
130 | if not hits:
131 | break # No more hits, exit loop
132 |
133 | for hit in hits:
134 | source = hit['_source']
135 | if training_data_type == "sql":
136 | try:
137 | doc_dict = json.loads(source['text'])
138 | content = doc_dict.get("sql")
139 | question = doc_dict.get("question")
140 | except json.JSONDecodeError as e:
141 |                 self.log(f"Skipping row with custom_id {hit['_id']} due to JSON parsing error: {e}", "Error")
142 | continue
143 | else: # documentation or ddl
144 | content = source['text']
145 | question = None
146 |
147 | data.append({
148 | "id": hit["_id"],
149 | "training_data_type": training_data_type,
150 | "question": question,
151 | "content": content,
152 | })
153 |
154 | # Get next batch of results, using documentation_store.client.scroll
155 | response = opensearch_client.scroll(scroll_id=scroll_id, scroll=scroll)
156 | scroll_id = response.get('_scroll_id')
157 |
158 | return pd.DataFrame(data)
159 |
160 | def remove_training_data(self, id: str, **kwargs) -> bool:
161 | try:
162 | if id.endswith("-sql"):
163 | return self.sql_store.delete(ids=[id], **kwargs)
164 | elif id.endswith("-ddl"):
165 | return self.ddl_store.delete(ids=[id], **kwargs)
166 | elif id.endswith("-doc"):
167 | return self.documentation_store.delete(ids=[id], **kwargs)
168 | else:
169 | return False
170 | except Exception as e:
171 |             self.log(f"Error deleting training data: {e}", "Error")
172 | return False
173 |
174 | def generate_embedding(self, data: str, **kwargs) -> list[float]:
175 | pass
176 |
--------------------------------------------------------------------------------
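Every connection setting comes from the config dict with the defaults shown in `__init__`. A configuration sketch (credentials are placeholders; note that `es_password` is read unconditionally whenever `es_user` is present):

```python
from vanna.opensearch import OpenSearch_Semantic_VectorStore

vs = OpenSearch_Semantic_VectorStore(config={
    "es_urls": "https://localhost:9200",  # default
    "es_user": "admin",                   # placeholder; requires es_password too
    "es_password": "admin",               # placeholder
    "es_verify_certs": False,             # e.g. for a local self-signed cert
    "n_results": 10,                      # shared default for sql/ddl/doc lookups
})
```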
/src/vanna/oracle/__init__.py:
--------------------------------------------------------------------------------
1 | from .oracle_vector import Oracle_VectorStore
2 |
--------------------------------------------------------------------------------
/src/vanna/pgvector/__init__.py:
--------------------------------------------------------------------------------
1 | from .pgvector import PG_VectorStore
2 |
--------------------------------------------------------------------------------
/src/vanna/pgvector/pgvector.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import json
3 | import logging
4 | import uuid
5 |
6 | import pandas as pd
7 | from langchain_core.documents import Document
8 | from langchain_postgres.vectorstores import PGVector
9 | from sqlalchemy import create_engine, text
10 |
11 | from .. import ValidationError
12 | from ..base import VannaBase
13 | from ..types import TrainingPlan, TrainingPlanItem
14 |
15 |
16 | class PG_VectorStore(VannaBase):
17 | def __init__(self, config=None):
18 | if not config or "connection_string" not in config:
19 | raise ValueError(
20 | "A valid 'config' dictionary with a 'connection_string' is required.")
21 |
22 | VannaBase.__init__(self, config=config)
23 |
24 | if config and "connection_string" in config:
25 | self.connection_string = config.get("connection_string")
26 | self.n_results = config.get("n_results", 10)
27 |
28 | if config and "embedding_function" in config:
29 | self.embedding_function = config.get("embedding_function")
30 | else:
31 | from langchain_huggingface import HuggingFaceEmbeddings
32 | self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
33 |
34 | self.sql_collection = PGVector(
35 | embeddings=self.embedding_function,
36 | collection_name="sql",
37 | connection=self.connection_string,
38 | )
39 | self.ddl_collection = PGVector(
40 | embeddings=self.embedding_function,
41 | collection_name="ddl",
42 | connection=self.connection_string,
43 | )
44 | self.documentation_collection = PGVector(
45 | embeddings=self.embedding_function,
46 | collection_name="documentation",
47 | connection=self.connection_string,
48 | )
49 |
50 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
51 | question_sql_json = json.dumps(
52 | {
53 | "question": question,
54 | "sql": sql,
55 | },
56 | ensure_ascii=False,
57 | )
58 | id = str(uuid.uuid4()) + "-sql"
59 | createdat = kwargs.get("createdat")
60 | doc = Document(
61 | page_content=question_sql_json,
62 | metadata={"id": id, "createdat": createdat},
63 | )
64 | self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
65 |
66 | return id
67 |
68 | def add_ddl(self, ddl: str, **kwargs) -> str:
69 | _id = str(uuid.uuid4()) + "-ddl"
70 | doc = Document(
71 | page_content=ddl,
72 | metadata={"id": _id},
73 | )
74 | self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
75 | return _id
76 |
77 | def add_documentation(self, documentation: str, **kwargs) -> str:
78 | _id = str(uuid.uuid4()) + "-doc"
79 | doc = Document(
80 | page_content=documentation,
81 | metadata={"id": _id},
82 | )
83 | self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
84 | return _id
85 |
86 | def get_collection(self, collection_name):
87 | match collection_name:
88 | case "sql":
89 | return self.sql_collection
90 | case "ddl":
91 | return self.ddl_collection
92 | case "documentation":
93 | return self.documentation_collection
94 | case _:
95 | raise ValueError("Specified collection does not exist.")
96 |
97 |     def get_similar_question_sql(self, question: str, **kwargs) -> list:
98 | documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
99 | return [ast.literal_eval(document.page_content) for document in documents]
100 |
101 | def get_related_ddl(self, question: str, **kwargs) -> list:
102 | documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
103 | return [document.page_content for document in documents]
104 |
105 | def get_related_documentation(self, question: str, **kwargs) -> list:
106 | documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
107 | return [document.page_content for document in documents]
108 |
109 | def train(
110 | self,
111 | question: str | None = None,
112 | sql: str | None = None,
113 | ddl: str | None = None,
114 | documentation: str | None = None,
115 | plan: TrainingPlan | None = None,
116 | createdat: str | None = None,
117 | ):
118 | if question and not sql:
119 | raise ValidationError("Please provide a SQL query.")
120 |
121 | if documentation:
122 | logging.info(f"Adding documentation: {documentation}")
123 | return self.add_documentation(documentation)
124 |
125 | if sql and question:
126 | return self.add_question_sql(question=question, sql=sql, createdat=createdat)
127 |
128 | if ddl:
129 | logging.info(f"Adding ddl: {ddl}")
130 | return self.add_ddl(ddl)
131 |
132 | if plan:
133 | for item in plan._plan:
134 | if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
135 | self.add_ddl(item.item_value)
136 | elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
137 | self.add_documentation(item.item_value)
138 | elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
139 | self.add_question_sql(question=item.item_name, sql=item.item_value)
140 |
141 | def get_training_data(self, **kwargs) -> pd.DataFrame:
142 | # Establishing the connection
143 | engine = create_engine(self.connection_string)
144 |
145 | # Querying the 'langchain_pg_embedding' table
146 | query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
147 | df_embedding = pd.read_sql(query_embedding, engine)
148 |
149 | # List to accumulate the processed rows
150 | processed_rows = []
151 |
152 | # Process each row in the DataFrame
153 | for _, row in df_embedding.iterrows():
154 | custom_id = row["cmetadata"]["id"]
155 | document = row["document"]
156 | training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]
157 |
158 | if training_data_type == "sql":
159 | # Convert the document string to a dictionary
160 | try:
161 | doc_dict = ast.literal_eval(document)
162 | question = doc_dict.get("question")
163 | content = doc_dict.get("sql")
164 | except (ValueError, SyntaxError):
165 | logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
166 | continue
167 | elif training_data_type in ["documentation", "ddl"]:
168 | question = None # Default value for question
169 | content = document
170 | else:
171 | # If the suffix is not recognized, skip this row
172 | logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
173 | continue
174 |
175 | # Append the processed data to the list
176 | processed_rows.append(
177 | {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
178 | )
179 |
180 | # Create a DataFrame from the list of processed rows
181 | df_processed = pd.DataFrame(processed_rows)
182 |
183 | return df_processed
184 |
185 | def remove_training_data(self, id: str, **kwargs) -> bool:
186 | # Create the database engine
187 | engine = create_engine(self.connection_string)
188 |
189 | # SQL DELETE statement
190 | delete_statement = text(
191 | """
192 | DELETE FROM langchain_pg_embedding
193 | WHERE cmetadata ->> 'id' = :id
194 | """
195 | )
196 |
197 | # Connect to the database and execute the delete statement
198 | with engine.connect() as connection:
199 | # Start a transaction
200 | with connection.begin() as transaction:
201 | try:
202 | result = connection.execute(delete_statement, {"id": id})
203 | # Commit the transaction if the delete was successful
204 | transaction.commit()
205 | # Check if any row was deleted and return True or False accordingly
206 | return result.rowcount > 0
207 | except Exception as e:
208 | # Rollback the transaction in case of error
209 | logging.error(f"An error occurred: {e}")
210 | transaction.rollback()
211 | return False
212 |
213 | def remove_collection(self, collection_name: str) -> bool:
214 | engine = create_engine(self.connection_string)
215 |
216 | # Determine the suffix to look for based on the collection name
217 | suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
218 | suffix = suffix_map.get(collection_name)
219 |
220 | if not suffix:
221 | logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
222 | return False
223 |
224 | # SQL query to delete rows based on the condition
225 | query = text(
226 | f"""
227 | DELETE FROM langchain_pg_embedding
228 | WHERE cmetadata->>'id' LIKE '%{suffix}'
229 | """
230 | )
231 |
232 | # Execute the deletion within a transaction block
233 | with engine.connect() as connection:
234 | with connection.begin() as transaction:
235 | try:
236 | result = connection.execute(query)
237 | transaction.commit() # Explicitly commit the transaction
238 | if result.rowcount > 0:
239 | logging.info(
240 | f"Deleted {result.rowcount} rows from "
241 | f"langchain_pg_embedding where collection is {collection_name}."
242 | )
243 | return True
244 | else:
245 | logging.info(f"No rows deleted for collection {collection_name}.")
246 | return False
247 | except Exception as e:
248 | logging.error(f"An error occurred: {e}")
249 | transaction.rollback() # Rollback in case of error
250 | return False
251 |
252 | def generate_embedding(self, *args, **kwargs):
253 | pass
254 |
--------------------------------------------------------------------------------
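`PG_VectorStore` requires only a SQLAlchemy-style connection string; the embedding function defaults to `all-MiniLM-L6-v2` via `langchain_huggingface`. A minimal sketch (credentials are placeholders; `langchain_postgres` expects the `postgresql+psycopg` driver prefix):

```python
from vanna.pgvector import PG_VectorStore

vs = PG_VectorStore(config={
    "connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna",  # placeholder
    "n_results": 10,  # default
})
vs.train(ddl="CREATE TABLE customers (id INT, name TEXT)")
```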
/src/vanna/pinecone/__init__.py:
--------------------------------------------------------------------------------
1 | from .pinecone_vector import PineconeDB_VectorStore
2 |
3 | __all__ = ["PineconeDB_VectorStore"]
4 |
--------------------------------------------------------------------------------
/src/vanna/qdrant/__init__.py:
--------------------------------------------------------------------------------
1 | from .qdrant import Qdrant_VectorStore
2 |
3 | __all__ = ["Qdrant_VectorStore"]
4 |
--------------------------------------------------------------------------------
/src/vanna/qianfan/Qianfan_Chat.py:
--------------------------------------------------------------------------------
1 | import qianfan
2 |
3 | from ..base import VannaBase
4 |
5 |
6 | class Qianfan_Chat(VannaBase):
7 | def __init__(self, client=None, config=None):
8 | VannaBase.__init__(self, config=config)
9 |
10 | if "api_key" not in config:
11 | raise Exception("Missing api_key in config")
12 | self.api_key = config["api_key"]
13 |
14 | if "secret_key" not in config:
15 | raise Exception("Missing secret_key in config")
16 | self.secret_key = config["secret_key"]
17 |
18 | # default parameters - can be overrided using config
19 | self.temperature = 0.9
20 | self.max_tokens = 1024
21 |
22 | if "temperature" in config:
23 | self.temperature = config["temperature"]
24 |
25 | if "max_tokens" in config:
26 | self.max_tokens = config["max_tokens"]
27 |
28 | self.model = config["model"] if "model" in config else "ERNIE-Speed"
29 |
30 | if client is not None:
31 | self.client = client
32 | return
33 |
34 | self.client = qianfan.ChatCompletion(ak=self.api_key,
35 | sk=self.secret_key)
36 |
37 | def system_message(self, message: str) -> any:
38 | return {"role": "system", "content": message}
39 |
40 | def user_message(self, message: str) -> any:
41 | return {"role": "user", "content": message}
42 |
43 | def assistant_message(self, message: str) -> any:
44 | return {"role": "assistant", "content": message}
45 |
46 | def get_sql_prompt(
47 | self,
48 | initial_prompt: str,
49 | question: str,
50 | question_sql_list: list,
51 | ddl_list: list,
52 | doc_list: list,
53 | **kwargs,
54 | ):
55 | """
56 | Example:
57 | ```python
58 | vn.get_sql_prompt(
59 | question="What are the top 10 customers by sales?",
60 | question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
61 | ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
62 | doc_list=["The customers table contains information about customers and their sales."],
63 | )
64 |
65 | ```
66 |
67 | This method is used to generate a prompt for the LLM to generate SQL.
68 |
69 | Args:
70 | question (str): The question to generate SQL for.
71 | question_sql_list (list): A list of questions and their corresponding SQL statements.
72 | ddl_list (list): A list of DDL statements.
73 | doc_list (list): A list of documentation.
74 |
75 | Returns:
76 | any: The prompt for the LLM to generate SQL.
77 | """
78 |
79 | if initial_prompt is None:
80 | initial_prompt = f"You are a {self.dialect} expert. " + \
81 |       "Please help to generate a SQL query to answer the question based on the provided context. Please don't give any explanation for your answer; just generate the SQL. \n"
82 |
83 | initial_prompt = self.add_ddl_to_prompt(
84 | initial_prompt, ddl_list, max_tokens=self.max_tokens
85 | )
86 |
87 | if self.static_documentation != "":
88 | doc_list.append(self.static_documentation)
89 |
90 | initial_prompt = self.add_documentation_to_prompt(
91 | initial_prompt, doc_list, max_tokens=self.max_tokens
92 | )
93 | message_log = []
94 |
95 | if question_sql_list is None or len(question_sql_list) == 0:
96 | initial_prompt = initial_prompt + f"question: {question}"
97 | message_log.append(self.user_message(initial_prompt))
98 | else:
99 |       for i, example in enumerate(question_sql_list):
100 | if example is None:
101 | print("example is None")
102 | else:
103 |           if "question" in example and "sql" in example:
104 | if i == 0:
105 | initial_prompt = initial_prompt + f"question: {example['question']}"
106 | message_log.append(self.user_message(initial_prompt))
107 | else:
108 | message_log.append(self.user_message(example["question"]))
109 | message_log.append(self.assistant_message(example["sql"]))
110 |
111 | message_log.append(self.user_message(question))
112 | return message_log
113 |
114 | def submit_prompt(self, prompt, **kwargs) -> str:
115 | if prompt is None:
116 | raise Exception("Prompt is None")
117 |
118 | if len(prompt) == 0:
119 | raise Exception("Prompt is empty")
120 |
121 | # Count the number of tokens in the message log
122 | # Use 4 as an approximation for the number of characters per token
123 | num_tokens = 0
124 | for message in prompt:
125 | num_tokens += len(message["content"]) / 4
126 |
127 | if kwargs.get("model", None) is not None:
128 | model = kwargs.get("model", None)
129 | print(
130 | f"Using model {model} for {num_tokens} tokens (approx)"
131 | )
132 | response = self.client.do(
133 |         model=model,
134 | messages=prompt,
135 | max_output_tokens=self.max_tokens,
136 | stop=None,
137 | temperature=self.temperature,
138 | )
139 | elif self.config is not None and "model" in self.config:
140 | print(
141 | f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
142 | )
143 | response = self.client.do(
144 | model=self.config.get("model"),
145 | messages=prompt,
146 | max_output_tokens=self.max_tokens,
147 | stop=None,
148 | temperature=self.temperature,
149 | )
150 | else:
151 | if num_tokens > 3500:
152 | model = "ERNIE-Speed-128K"
153 | else:
154 | model = "ERNIE-Speed-8K"
155 |
156 | print(f"Using model {model} for {num_tokens} tokens (approx)")
157 | response = self.client.do(
158 | model=model,
159 | messages=prompt,
160 | max_output_tokens=self.max_tokens,
161 | stop=None,
162 | temperature=self.temperature,
163 | )
164 |
165 | return response.body.get("result")
166 |
--------------------------------------------------------------------------------
/src/vanna/qianfan/Qianfan_embeddings.py:
--------------------------------------------------------------------------------
1 | import qianfan
2 |
3 | from ..base import VannaBase
4 |
5 |
6 | class Qianfan_Embeddings(VannaBase):
7 | def __init__(self, client=None, config=None):
8 | VannaBase.__init__(self, config=config)
9 |
10 | if client is not None:
11 | self.client = client
12 | return
13 |
14 | if "api_key" not in config:
15 | raise Exception("Missing api_key in config")
16 | self.api_key = config["api_key"]
17 |
18 | if "secret_key" not in config:
19 | raise Exception("Missing secret_key in config")
20 | self.secret_key = config["secret_key"]
21 |
22 | self.client = qianfan.Embedding(ak=self.api_key, sk=self.secret_key)
23 |
24 | def generate_embedding(self, data: str, **kwargs) -> list[float]:
25 | if self.config is not None and "model" in self.config:
26 | embedding = self.client.do(
27 | model=self.config["model"],
28 | input=[data],
29 | )
30 | else:
31 | embedding = self.client.do(
32 | model="bge-large-zh",
33 | input=[data],
34 | )
35 |
36 | return embedding.get("data")[0]["embedding"]
37 |
--------------------------------------------------------------------------------
/src/vanna/qianfan/__init__.py:
--------------------------------------------------------------------------------
1 | from .Qianfan_Chat import Qianfan_Chat
2 | from .Qianfan_embeddings import Qianfan_Embeddings
3 |
--------------------------------------------------------------------------------
/src/vanna/qianwen/QianwenAI_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from openai import OpenAI
4 |
5 | from ..base import VannaBase
6 |
7 |
8 | class QianWenAI_Chat(VannaBase):
9 | def __init__(self, client=None, config=None):
10 | VannaBase.__init__(self, config=config)
11 |
12 |         # default parameters - can be overridden using config
13 | self.temperature = 0.7
14 |
15 | if "temperature" in config:
16 | self.temperature = config["temperature"]
17 |
18 | if "api_type" in config:
19 | raise Exception(
20 | "Passing api_type is now deprecated. Please pass an OpenAI client instead."
21 | )
22 |
23 | if "api_base" in config:
24 | raise Exception(
25 | "Passing api_base is now deprecated. Please pass an OpenAI client instead."
26 | )
27 |
28 | if "api_version" in config:
29 | raise Exception(
30 | "Passing api_version is now deprecated. Please pass an OpenAI client instead."
31 | )
32 |
33 | if client is not None:
34 | self.client = client
35 | return
36 |
37 | if config is None and client is None:
38 | self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
39 | return
40 |
41 | if "api_key" in config:
42 | if "base_url" not in config:
43 | self.client = OpenAI(api_key=config["api_key"],
44 | base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
45 | else:
46 | self.client = OpenAI(api_key=config["api_key"],
47 | base_url=config["base_url"])
48 |
49 | def system_message(self, message: str) -> any:
50 | return {"role": "system", "content": message}
51 |
52 | def user_message(self, message: str) -> any:
53 | return {"role": "user", "content": message}
54 |
55 | def assistant_message(self, message: str) -> any:
56 | return {"role": "assistant", "content": message}
57 |
58 | def submit_prompt(self, prompt, **kwargs) -> str:
59 | if prompt is None:
60 | raise Exception("Prompt is None")
61 |
62 | if len(prompt) == 0:
63 | raise Exception("Prompt is empty")
64 |
65 | # Count the number of tokens in the message log
66 | # Use 4 as an approximation for the number of characters per token
67 | num_tokens = 0
68 | for message in prompt:
69 | num_tokens += len(message["content"]) / 4
70 |
71 | if kwargs.get("model", None) is not None:
72 | model = kwargs.get("model", None)
73 | print(
74 | f"Using model {model} for {num_tokens} tokens (approx)"
75 | )
76 | response = self.client.chat.completions.create(
77 | model=model,
78 | messages=prompt,
79 | stop=None,
80 | temperature=self.temperature,
81 | )
82 | elif kwargs.get("engine", None) is not None:
83 | engine = kwargs.get("engine", None)
84 | print(
85 | f"Using model {engine} for {num_tokens} tokens (approx)"
86 | )
87 | response = self.client.chat.completions.create(
88 | engine=engine,
89 | messages=prompt,
90 | stop=None,
91 | temperature=self.temperature,
92 | )
93 | elif self.config is not None and "engine" in self.config:
94 | print(
95 | f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
96 | )
97 | response = self.client.chat.completions.create(
98 | engine=self.config["engine"],
99 | messages=prompt,
100 | stop=None,
101 | temperature=self.temperature,
102 | )
103 | elif self.config is not None and "model" in self.config:
104 | print(
105 | f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
106 | )
107 | response = self.client.chat.completions.create(
108 | model=self.config["model"],
109 | messages=prompt,
110 | stop=None,
111 | temperature=self.temperature,
112 | )
113 | else:
114 | if num_tokens > 3500:
115 | model = "qwen-long"
116 | else:
117 | model = "qwen-plus"
118 |
119 | print(f"Using model {model} for {num_tokens} tokens (approx)")
120 | response = self.client.chat.completions.create(
121 | model=model,
122 | messages=prompt,
123 | stop=None,
124 | temperature=self.temperature,
125 | )
126 |
127 | # Find the first response from the chatbot that has text in it (some responses may not have text)
128 | for choice in response.choices:
129 | if "text" in choice:
130 | return choice.text
131 |
132 | # If no response with text is found, return the first response's content (which may be empty)
133 | return response.choices[0].message.content
134 |
--------------------------------------------------------------------------------
/src/vanna/qianwen/QianwenAI_embeddings.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 |
3 | from ..base import VannaBase
4 |
5 |
6 | class QianWenAI_Embeddings(VannaBase):
7 | def __init__(self, client=None, config=None):
8 | VannaBase.__init__(self, config=config)
9 |
10 | if client is not None:
11 | self.client = client
12 | return
13 |
14 |         if getattr(self, "client", None) is not None:
15 | return
16 |
17 | self.client = OpenAI()
18 |
19 | if config is None:
20 | return
21 |
22 | if "api_type" in config:
23 | self.client.api_type = config["api_type"]
24 |
25 | if "api_base" in config:
26 | self.client.api_base = config["api_base"]
27 |
28 | if "api_version" in config:
29 | self.client.api_version = config["api_version"]
30 |
31 | if "api_key" in config:
32 | self.client.api_key = config["api_key"]
33 |
34 | def generate_embedding(self, data: str, **kwargs) -> list[float]:
35 | if self.config is not None and "engine" in self.config:
36 | embedding = self.client.embeddings.create(
37 | engine=self.config["engine"],
38 | input=data,
39 | )
40 | else:
41 | embedding = self.client.embeddings.create(
42 | model="bge-large-zh",
43 | input=data,
44 | )
45 |
46 |         return embedding.data[0].embedding
47 |
--------------------------------------------------------------------------------
/src/vanna/qianwen/__init__.py:
--------------------------------------------------------------------------------
1 | from .QianwenAI_chat import QianWenAI_Chat
2 | from .QianwenAI_embeddings import QianWenAI_Embeddings
3 |
--------------------------------------------------------------------------------
/src/vanna/remote.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import json
3 | from io import StringIO
4 | from typing import Callable, List, Tuple, Union
5 |
6 | import pandas as pd
7 | import requests
8 |
9 | from .base import VannaBase
10 | from .types import (
11 | AccuracyStats,
12 | ApiKey,
13 | DataFrameJSON,
14 | DataResult,
15 | Explanation,
16 | FullQuestionDocument,
17 | NewOrganization,
18 | NewOrganizationMember,
19 | Organization,
20 | OrganizationList,
21 | PlotlyResult,
22 | Question,
23 | QuestionCategory,
24 | QuestionId,
25 | QuestionList,
26 | QuestionSQLPair,
27 | QuestionStringList,
28 | SQLAnswer,
29 | Status,
30 | StatusWithId,
31 | StringData,
32 | TrainingData,
33 | UserEmail,
34 | UserOTP,
35 | Visibility,
36 | )
37 | from .vannadb import VannaDB_VectorStore
38 |
39 |
40 | class VannaDefault(VannaDB_VectorStore):
41 | def __init__(self, model: str, api_key: str, config=None):
42 | VannaBase.__init__(self, config=config)
43 | VannaDB_VectorStore.__init__(self, vanna_model=model, vanna_api_key=api_key, config=config)
44 |
45 | self._model = model
46 | self._api_key = api_key
47 |
48 | self._endpoint = (
49 | "https://ask.vanna.ai/rpc"
50 | if config is None or "endpoint" not in config
51 | else config["endpoint"]
52 | )
53 |
54 | def system_message(self, message: str) -> any:
55 | return {"role": "system", "content": message}
56 |
57 | def user_message(self, message: str) -> any:
58 | return {"role": "user", "content": message}
59 |
60 | def assistant_message(self, message: str) -> any:
61 | return {"role": "assistant", "content": message}
62 |
63 | def submit_prompt(self, prompt, **kwargs) -> str:
64 | # JSON-ify the prompt
65 | json_prompt = json.dumps(prompt, ensure_ascii=False)
66 |
67 | params = [StringData(data=json_prompt)]
68 |
69 | d = self._rpc_call(method="submit_prompt", params=params)
70 |
71 | if "result" not in d:
72 | return None
73 |
74 | # Load the result into a dataclass
75 | results = StringData(**d["result"])
76 |
77 | return results.data
78 |
--------------------------------------------------------------------------------
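`VannaDefault` is the hosted variant: both the vector store and the LLM sit behind the RPC endpoint, so only a model name and API key are needed. A usage sketch (both values are placeholders):

```python
from vanna.remote import VannaDefault

vn = VannaDefault(model="my-model", api_key="vn-...")  # placeholders
print(vn.generate_sql("What are the top selling genres?"))
```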
/src/vanna/types/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from typing import Dict, List, Union
5 |
6 |
7 | @dataclass
8 | class Status:
9 | success: bool
10 | message: str
11 |
12 |
13 | @dataclass
14 | class StatusWithId:
15 | success: bool
16 | message: str
17 | id: str
18 |
19 |
20 | @dataclass
21 | class QuestionList:
22 | questions: List[FullQuestionDocument]
23 |
24 |
25 | @dataclass
26 | class FullQuestionDocument:
27 | id: QuestionId
28 | question: Question
29 | answer: SQLAnswer | None
30 | data: DataResult | None
31 | plotly: PlotlyResult | None
32 |
33 |
34 | @dataclass
35 | class QuestionSQLPair:
36 | question: str
37 | sql: str
38 | tag: Union[str, None]
39 |
40 |
41 | @dataclass
42 | class Organization:
43 | name: str
44 | user: str | None
45 | connection: Connection | None
46 |
47 |
48 | @dataclass
49 | class OrganizationList:
50 | organizations: List[str]
51 |
52 |
53 | @dataclass
54 | class QuestionStringList:
55 | questions: List[str]
56 |
57 |
58 | @dataclass
59 | class Visibility:
60 | visibility: bool
61 |
62 |
63 | @dataclass
64 | class UserEmail:
65 | email: str
66 |
67 |
68 | @dataclass
69 | class NewOrganization:
70 | org_name: str
71 | db_type: str
72 |
73 |
74 | @dataclass
75 | class NewOrganizationMember:
76 | org_name: str
77 | email: str
78 | is_admin: bool
79 |
80 |
81 | @dataclass
82 | class UserOTP:
83 | email: str
84 | otp: str
85 |
86 |
87 | @dataclass
88 | class ApiKey:
89 | key: str
90 |
91 |
92 | @dataclass
93 | class QuestionId:
94 | id: str
95 |
96 |
97 | @dataclass
98 | class Question:
99 | question: str
100 |
101 |
102 | @dataclass
103 | class QuestionCategory:
104 | question: str
105 | category: str
106 |
107 | NO_SQL_GENERATED = "No SQL Generated"
108 | SQL_UNABLE_TO_RUN = "SQL Unable to Run"
109 | BOOTSTRAP_TRAINING_QUERY = "Bootstrap Training Query"
110 | SQL_RAN = "SQL Ran Successfully"
111 | FLAGGED_FOR_REVIEW = "Flagged for Review"
112 | REVIEWED_AND_APPROVED = "Reviewed and Approved"
113 | REVIEWED_AND_REJECTED = "Reviewed and Rejected"
114 | REVIEWED_AND_UPDATED = "Reviewed and Updated"
115 |
116 |
117 | @dataclass
118 | class AccuracyStats:
119 | num_questions: int
120 | data: Dict[str, int]
121 |
122 |
123 | @dataclass
124 | class Followup:
125 | followup: str
126 |
127 |
128 | @dataclass
129 | class QuestionEmbedding:
130 | question: Question
131 | embedding: List[float]
132 |
133 |
134 | @dataclass
135 | class Connection:
136 | # TODO: implement
137 | pass
138 |
139 |
140 | @dataclass
141 | class SQLAnswer:
142 | raw_answer: str
143 | prefix: str
144 | postfix: str
145 | sql: str
146 |
147 |
148 | @dataclass
149 | class Explanation:
150 | explanation: str
151 |
152 |
153 | @dataclass
154 | class DataResult:
155 | question: str | None
156 | sql: str | None
157 | table_markdown: str
158 | error: str | None
159 | correction_attempts: int
160 |
161 |
162 | @dataclass
163 | class PlotlyResult:
164 | plotly_code: str
165 |
166 |
167 | @dataclass
168 | class WarehouseDefinition:
169 | name: str
170 | tables: List[TableDefinition]
171 |
172 |
173 | @dataclass
174 | class TableDefinition:
175 | schema_name: str
176 | table_name: str
177 | ddl: str | None
178 | columns: List[ColumnDefinition]
179 |
180 |
181 | @dataclass
182 | class ColumnDefinition:
183 | name: str
184 | type: str
185 | is_primary_key: bool
186 | is_foreign_key: bool
187 | foreign_key_table: str
188 | foreign_key_column: str
189 |
190 |
191 | @dataclass
192 | class Diagram:
193 | raw: str
194 | mermaid_code: str
195 |
196 |
197 | @dataclass
198 | class StringData:
199 | data: str
200 |
201 |
202 | @dataclass
203 | class DataFrameJSON:
204 | data: str
205 |
206 |
207 | @dataclass
208 | class TrainingData:
209 | questions: List[dict]
210 | ddl: List[str]
211 | documentation: List[str]
212 |
213 |
214 | @dataclass
215 | class TrainingPlanItem:
216 | item_type: str
217 | item_group: str
218 | item_name: str
219 | item_value: str
220 |
221 | def __str__(self):
222 | if self.item_type == self.ITEM_TYPE_SQL:
223 | return f"Train on SQL: {self.item_group} {self.item_name}"
224 | elif self.item_type == self.ITEM_TYPE_DDL:
225 | return f"Train on DDL: {self.item_group} {self.item_name}"
226 | elif self.item_type == self.ITEM_TYPE_IS:
227 | return f"Train on Information Schema: {self.item_group} {self.item_name}"
228 |
229 | ITEM_TYPE_SQL = "sql"
230 | ITEM_TYPE_DDL = "ddl"
231 | ITEM_TYPE_IS = "is"
232 |
233 |
234 | class TrainingPlan:
235 | """
236 | A class representing a training plan. You can see what's in it, and remove items from it that you don't want trained.
237 |
238 | **Example:**
239 | ```python
240 | plan = vn.get_training_plan()
241 |
242 | plan.get_summary()
243 | ```
244 |
245 | """
246 |
247 | _plan: List[TrainingPlanItem]
248 |
249 | def __init__(self, plan: List[TrainingPlanItem]):
250 | self._plan = plan
251 |
252 | def __str__(self):
253 | return "\n".join(self.get_summary())
254 |
255 | def __repr__(self):
256 | return self.__str__()
257 |
258 | def get_summary(self) -> List[str]:
259 | """
260 | **Example:**
261 | ```python
262 | plan = vn.get_training_plan()
263 |
264 | plan.get_summary()
265 | ```
266 |
267 | Get a summary of the training plan.
268 |
269 | Returns:
270 | List[str]: A list of strings describing the training plan.
271 | """
272 |
273 | return [f"{item}" for item in self._plan]
274 |
275 | def remove_item(self, item: str):
276 | """
277 | **Example:**
278 | ```python
279 | plan = vn.get_training_plan()
280 |
281 | plan.remove_item("Train on SQL: What is the average salary of employees?")
282 | ```
283 |
284 | Remove an item from the training plan.
285 |
286 | Args:
287 | item (str): The item to remove.
288 | """
289 | for plan_item in self._plan:
290 | if str(plan_item) == item:
291 | self._plan.remove(plan_item)
292 | break
293 |
--------------------------------------------------------------------------------
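A quick sketch of how these dataclasses nest when assembled by hand; all sample values are invented:

```python
from vanna.types import (
    DataResult,
    FullQuestionDocument,
    Question,
    QuestionId,
    SQLAnswer,
)

doc = FullQuestionDocument(
    id=QuestionId(id="q-123"),
    question=Question(question="What are total sales by region?"),
    answer=SQLAnswer(
        raw_answer="SELECT region, SUM(sales) FROM orders GROUP BY region",
        prefix="",
        postfix="",
        sql="SELECT region, SUM(sales) FROM orders GROUP BY region",
    ),
    data=DataResult(
        question=None,
        sql=None,
        table_markdown="| region | sum |\n|---|---|",
        error=None,
        correction_attempts=0,
    ),
    plotly=None,  # Optional: PlotlyResult(plotly_code=...) when a chart exists
)
```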
/src/vanna/utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import re
4 | import uuid
5 | from typing import Union
6 |
7 | from .exceptions import ImproperlyConfigured, ValidationError
8 |
9 |
10 | def validate_config_path(path):
11 | if not os.path.exists(path):
12 | raise ImproperlyConfigured(
13 | f'No such configuration file: {path}'
14 | )
15 |
16 | if not os.path.isfile(path):
17 | raise ImproperlyConfigured(
18 | f'Config should be a file: {path}'
19 | )
20 |
21 | if not os.access(path, os.R_OK):
22 | raise ImproperlyConfigured(
23 | f'Cannot read the config file. Please grant read privileges: {path}'
24 | )
25 |
26 |
27 | def sanitize_model_name(model_name):
28 | try:
29 | model_name = model_name.lower()
30 |
31 | # Replace spaces with a hyphen
32 | model_name = model_name.replace(" ", "-")
33 |
34 | if '-' in model_name:
35 |
36 | # Collapse consecutive hyphens into a single hyphen
37 | model_name = re.sub(r"-+", "-", model_name)
38 | if '_' in model_name:
39 | # If the name contains both underscores and hyphens, replace all underscores with hyphens
40 | model_name = re.sub(r'_', '-', model_name)
41 |
42 | # Remove special characters, allowing only alphanumerics, hyphens, and underscores
43 | model_name = re.sub(r"[^a-zA-Z0-9-_]", "", model_name)
44 |
45 | # Strip a leading or trailing hyphen or underscore
46 | if model_name[-1] in ("-", "_"):
47 | model_name = model_name[:-1]
48 | if model_name[0] in ("-", "_"):
49 | model_name = model_name[1:]
50 |
51 | return model_name
52 | except Exception as e:
53 | raise ValidationError(e)
54 |
55 |
56 | def deterministic_uuid(content: Union[str, bytes]) -> str:
57 | """Creates deterministic UUID on hash value of string or byte content.
58 |
59 | Args:
60 | content: String or byte representation of data.
61 |
62 | Returns:
63 | UUID of the content.
64 | """
65 | if isinstance(content, str):
66 | content_bytes = content.encode("utf-8")
67 | elif isinstance(content, bytes):
68 | content_bytes = content
69 | else:
70 | raise ValueError(f"Content type {type(content)} not supported!")
71 |
72 | hash_object = hashlib.sha256(content_bytes)
73 | hash_hex = hash_object.hexdigest()
74 | namespace = uuid.UUID("00000000-0000-0000-0000-000000000000")
75 | content_uuid = str(uuid.uuid5(namespace, hash_hex))
76 |
77 | return content_uuid
78 |
--------------------------------------------------------------------------------
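Two small checks tracing the documented behavior of the helpers above:

```python
from vanna.utils import deterministic_uuid, sanitize_model_name

# Identical content always maps to the same UUIDv5, whether passed as str or
# bytes, since both are reduced to the same SHA-256 hex digest first.
assert deterministic_uuid("SELECT 1") == deterministic_uuid(b"SELECT 1")

# Tracing sanitize_model_name: lowercase -> spaces to hyphens -> collapse
# hyphen runs -> underscores to hyphens -> strip leading/trailing separators.
assert sanitize_model_name("My  Model_Name-") == "my-model-name"
```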
/src/vanna/vannadb/__init__.py:
--------------------------------------------------------------------------------
1 | from .vannadb_vector import VannaDB_VectorStore
2 |
--------------------------------------------------------------------------------
/src/vanna/vllm/__init__.py:
--------------------------------------------------------------------------------
1 | from .vllm import Vllm
2 |
--------------------------------------------------------------------------------
/src/vanna/vllm/vllm.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import requests
4 |
5 | from ..base import VannaBase
6 |
7 |
8 | class Vllm(VannaBase):
9 | def __init__(self, config=None):
10 | if config is None or "vllm_host" not in config:
11 | self.host = "http://localhost:8000"
12 | else:
13 | self.host = config["vllm_host"]
14 |
15 | if config is None or "model" not in config:
16 | raise ValueError("config for vllm must include a 'model' key")
17 | else:
18 | self.model = config["model"]
19 |
20 | if "auth-key" in config:
21 | self.auth_key = config["auth-key"]
22 | else:
23 | self.auth_key = None
24 |
25 | if "temperature" in config:
26 | self.temperature = config["temperature"]
27 | else:
28 | # default temperature - can be overridden via config
29 | self.temperature = 0.7
30 |
31 | def system_message(self, message: str) -> any:
32 | return {"role": "system", "content": message}
33 |
34 | def user_message(self, message: str) -> any:
35 | return {"role": "user", "content": message}
36 |
37 | def assistant_message(self, message: str) -> any:
38 | return {"role": "assistant", "content": message}
39 |
40 | def extract_sql_query(self, text):
41 | """
42 | Extracts the first SQL statement starting at the word 'select' (case-insensitive),
43 | matching up to the first semicolon, triple backticks, or the end of the string,
44 | and strips any triple backticks from the extracted text.
45 |
46 | Args:
47 | - text (str): The string to search within for an SQL statement.
48 |
49 | Returns:
50 | - str: The first SQL statement found, with triple backticks removed, or the original text if no match is found.
51 | """
52 | # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
53 | pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
54 |
55 | match = pattern.search(text)
56 | if match:
57 | # Remove three backticks from the matched string if they exist
58 | return match.group(0).replace("```", "")
59 | else:
60 | return text
61 |
62 | def generate_sql(self, question: str, **kwargs) -> str:
63 | # Use the super generate_sql
64 | sql = super().generate_sql(question, **kwargs)
65 |
66 | # Replace "\_" with "_"
67 | sql = sql.replace("\\_", "_")
68 |
69 | sql = sql.replace("\\", "")
70 |
71 | return self.extract_sql_query(sql)
72 |
73 | def submit_prompt(self, prompt, **kwargs) -> str:
74 | url = f"{self.host}/v1/chat/completions"
75 | data = {
76 | "model": self.model,
77 | "temperature": self.temperature,
78 | "stream": False,
79 | "messages": prompt,
80 | }
81 |
82 | if self.auth_key is not None:
83 | headers = {
84 | 'Content-Type': 'application/json',
85 | 'Authorization': f'Bearer {self.auth_key}'
86 | }
87 |
88 | response = requests.post(url, headers=headers, json=data)
89 |
90 |
91 | else:
92 | response = requests.post(url, json=data)
93 |
94 | response_dict = response.json()
95 |
96 | self.log(response.text)
97 |
98 | return response_dict['choices'][0]['message']['content']
99 |
--------------------------------------------------------------------------------
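A hedged sketch of wiring `Vllm` into a runnable Vanna subclass, following the vector-store-plus-LLM composition pattern used in the tests; it assumes a vLLM server is running, and the model name and auth key are placeholders:

```python
from vanna.chromadb import ChromaDB_VectorStore
from vanna.vllm import Vllm

class MyVanna(ChromaDB_VectorStore, Vllm):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Vllm.__init__(self, config=config)

vn = MyVanna(config={
    "vllm_host": "http://localhost:8000",  # default when omitted
    "model": "my-served-model",            # required; __init__ raises without it
    "temperature": 0.2,                    # optional; defaults to 0.7
    "auth-key": "MY_VLLM_KEY",             # optional; sent as a Bearer token
})
sql = vn.generate_sql("What are the top 10 customers by sales?")
```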
/src/vanna/weaviate/__init__.py:
--------------------------------------------------------------------------------
1 | from .weaviate_vector import WeaviateDatabase
2 |
--------------------------------------------------------------------------------
/src/vanna/weaviate/weaviate_vector.py:
--------------------------------------------------------------------------------
1 | import weaviate
2 | import weaviate.classes as wvc
3 | from fastembed import TextEmbedding
4 |
5 | from vanna.base import VannaBase
6 |
7 |
8 | class WeaviateDatabase(VannaBase):
9 |
10 | def __init__(self, config=None):
11 | """
12 | Initialize the WeaviateDatabase class with the provided configuration.
13 |
14 | :param config: Dictionary containing configuration parameters.
15 |
16 | params:
17 | weaviate_url (str): Weaviate cluster URL when using Weaviate Cloud,
18 | weaviate_api_key (str): Weaviate API key when using Weaviate Cloud,
19 | weaviate_port (int): Weaviate port when using a local Weaviate instance,
20 | weaviate_grpc (int): Weaviate gRPC port when using a local Weaviate instance,
21 | fastembed_model (str): FastEmbed model name for text embeddings. BAAI/bge-small-en-v1.5 by default.
22 |
23 | """
24 | super().__init__(config=config)
25 |
26 | if config is None:
27 | raise ValueError("config is required")
28 |
29 | self.n_results = config.get("n_results", 3)
30 | self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
31 | self.weaviate_api_key = config.get("weaviate_api_key")
32 | self.weaviate_url = config.get("weaviate_url")
33 | self.weaviate_port = config.get("weaviate_port")
34 | self.weaviate_grpc_port = config.get("weaviate_grpc", 50051)
35 |
36 | if not self.weaviate_api_key and not self.weaviate_port:
37 | raise ValueError("Provide either weaviate_api_key (cloud) or weaviate_port (local) to connect to Weaviate")
38 |
39 | self.weaviate_client = self._initialize_weaviate_client()
40 | self.embeddings = TextEmbedding(model_name=self.fastembed_model)
41 |
42 | self.training_data_cluster = {
43 | "sql": "SQLTrainingDataEntry",
44 | "ddl": "DDLEntry",
45 | "doc": "DocumentationEntry"
46 | }
47 |
48 | self._create_collections_if_not_exist()
49 |
50 | def _create_collections_if_not_exist(self):
51 | properties_dict = {
52 | self.training_data_cluster['ddl']: [
53 | wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
54 | ],
55 | self.training_data_cluster['doc']: [
56 | wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
57 | ],
58 | self.training_data_cluster['sql']: [
59 | wvc.config.Property(name="sql", data_type=wvc.config.DataType.TEXT),
60 | wvc.config.Property(name="natural_language_question", data_type=wvc.config.DataType.TEXT),
61 | ]
62 | }
63 |
64 | for cluster, properties in properties_dict.items():
65 | if not self.weaviate_client.collections.exists(cluster):
66 | self.weaviate_client.collections.create(
67 | name=cluster,
68 | properties=properties
69 | )
70 |
71 | def _initialize_weaviate_client(self):
72 | if self.weaviate_api_key:
73 | return weaviate.connect_to_wcs(
74 | cluster_url=self.weaviate_url,
75 | auth_credentials=weaviate.auth.AuthApiKey(self.weaviate_api_key),
76 | additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
77 | skip_init_checks=True
78 | )
79 | else:
80 | return weaviate.connect_to_local(
81 | port=self.weaviate_port,
82 | grpc_port=self.weaviate_grpc_port,
83 | additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
84 | skip_init_checks=True
85 | )
86 |
87 | def generate_embedding(self, data: str, **kwargs):
88 | # Reuse the embedding model initialized in __init__ rather than re-loading it on every call
89 | embedding = next(self.embeddings.embed(data))
90 | return embedding.tolist()
91 |
92 |
93 | def _insert_data(self, cluster_key: str, data_object: dict, vector: list) -> str:
94 | self.weaviate_client.connect()
95 | response = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key]).data.insert(
96 | properties=data_object,
97 | vector=vector
98 | )
99 | self.weaviate_client.close()
100 | return response
101 |
102 | def add_ddl(self, ddl: str, **kwargs) -> str:
103 | data_object = {
104 | "description": ddl,
105 | }
106 | response = self._insert_data('ddl', data_object, self.generate_embedding(ddl))
107 | return f'{response}-ddl'
108 |
109 | def add_documentation(self, doc: str, **kwargs) -> str:
110 | data_object = {
111 | "description": doc,
112 | }
113 | response = self._insert_data('doc', data_object, self.generate_embedding(doc))
114 | return f'{response}-doc'
115 |
116 | def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
117 | data_object = {
118 | "sql": sql,
119 | "natural_language_question": question,
120 | }
121 | response = self._insert_data('sql', data_object, self.generate_embedding(question))
122 | return f'{response}-sql'
123 |
124 | def _query_collection(self, cluster_key: str, vector_input: list, return_properties: list) -> list:
125 | self.weaviate_client.connect()
126 | collection = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key])
127 | response = collection.query.near_vector(
128 | near_vector=vector_input,
129 | limit=self.n_results,
130 | return_properties=return_properties
131 | )
132 | response_list = [item.properties for item in response.objects]
133 | self.weaviate_client.close()
134 | return response_list
135 |
136 | def get_related_ddl(self, question: str, **kwargs) -> list:
137 | vector_input = self.generate_embedding(question)
138 | response_list = self._query_collection('ddl', vector_input, ["description"])
139 | return [item["description"] for item in response_list]
140 |
141 | def get_related_documentation(self, question: str, **kwargs) -> list:
142 | vector_input = self.generate_embedding(question)
143 | response_list = self._query_collection('doc', vector_input, ["description"])
144 | return [item["description"] for item in response_list]
145 |
146 | def get_similar_question_sql(self, question: str, **kwargs) -> list:
147 | vector_input = self.generate_embedding(question)
148 | response_list = self._query_collection('sql', vector_input, ["sql", "natural_language_question"])
149 | return [{"question": item["natural_language_question"], "sql": item["sql"]} for item in response_list]
150 |
151 | def get_training_data(self, **kwargs) -> list:
152 | self.weaviate_client.connect()
153 | combined_response_list = []
154 | for collection_name in self.training_data_cluster.values():
155 | if self.weaviate_client.collections.exists(collection_name):
156 | collection = self.weaviate_client.collections.get(collection_name)
157 | response_list = [item.properties for item in collection.iterator()]
158 | combined_response_list.extend(response_list)
159 | self.weaviate_client.close()
160 | return combined_response_list
161 |
162 | def remove_training_data(self, id: str, **kwargs) -> bool:
163 | self.weaviate_client.connect()
164 | success = False
165 | if id.endswith("-sql"):
166 | id = id.replace('-sql', '')
167 | success = self.weaviate_client.collections.get(self.training_data_cluster['sql']).data.delete_by_id(id)
168 | elif id.endswith("-ddl"):
169 | id = id.replace('-ddl', '')
170 | success = self.weaviate_client.collections.get(self.training_data_cluster['ddl']).data.delete_by_id(id)
171 | elif id.endswith("-doc"):
172 | id = id.replace('-doc', '')
173 | success = self.weaviate_client.collections.get(self.training_data_cluster['doc']).data.delete_by_id(id)
174 | self.weaviate_client.close()
175 | return success
176 |
--------------------------------------------------------------------------------
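A hedged sketch of the two connection modes `_initialize_weaviate_client` supports, combined with `OpenAI_Chat` since `WeaviateDatabase` alone leaves the LLM methods abstract; the URLs and keys are placeholders:

```python
from vanna.openai import OpenAI_Chat
from vanna.weaviate import WeaviateDatabase

class MyVanna(WeaviateDatabase, OpenAI_Chat):
    def __init__(self, config=None):
        WeaviateDatabase.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

# Weaviate Cloud: pass weaviate_url + weaviate_api_key.
vn = MyVanna(config={
    "weaviate_url": "https://my-cluster.weaviate.network",  # placeholder
    "weaviate_api_key": "MY_WEAVIATE_KEY",                  # placeholder
    "api_key": "MY_OPENAI_KEY",
    "model": "gpt-4",
})

# Local Weaviate: pass weaviate_port (and optionally weaviate_grpc).
vn_local = MyVanna(config={
    "weaviate_port": 8080,
    "weaviate_grpc": 50051,  # the default gRPC port
    "api_key": "MY_OPENAI_KEY",
    "model": "gpt-4",
})
```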
/src/vanna/xinference/__init__.py:
--------------------------------------------------------------------------------
1 | from .xinference import Xinference
2 |
--------------------------------------------------------------------------------
/src/vanna/xinference/xinference.py:
--------------------------------------------------------------------------------
1 | from xinference_client.client.restful.restful_client import (
2 | Client,
3 | RESTfulChatModelHandle,
4 | )
5 |
6 | from ..base import VannaBase
7 |
8 |
9 | class Xinference(VannaBase):
10 | def __init__(self, config=None):
11 | VannaBase.__init__(self, config=config)
12 |
13 | if not config or "base_url" not in config:
14 | raise ValueError("config must contain at least Xinference base_url")
15 |
16 | base_url = config["base_url"]
17 | api_key = config.get("api_key", "not empty")
18 | self.xinference_client = Client(base_url=base_url, api_key=api_key)
19 |
20 | def system_message(self, message: str) -> any:
21 | return {"role": "system", "content": message}
22 |
23 | def user_message(self, message: str) -> any:
24 | return {"role": "user", "content": message}
25 |
26 | def assistant_message(self, message: str) -> any:
27 | return {"role": "assistant", "content": message}
28 |
29 | def submit_prompt(self, prompt, **kwargs) -> str:
30 | if prompt is None:
31 | raise Exception("Prompt is None")
32 |
33 | if len(prompt) == 0:
34 | raise Exception("Prompt is empty")
35 |
36 | num_tokens = 0
37 | for message in prompt:
38 | num_tokens += len(message["content"]) / 4  # rough estimate: ~4 characters per token
39 |
40 | model_uid = kwargs.get("model_uid") or self.config.get("model_uid", None)
41 | if model_uid is None:
42 | raise ValueError("model_uid is required")
43 |
44 | xinference_model = self.xinference_client.get_model(model_uid)
45 | if isinstance(xinference_model, RESTfulChatModelHandle):
46 | print(
47 | f"Using model_uid {model_uid} for {num_tokens} tokens (approx)"
48 | )
49 |
50 | response = xinference_model.chat(prompt)
51 | return response["choices"][0]["message"]["content"]
52 | else:
53 | raise NotImplementedError(f"Xinference model handle type {type(xinference_model)} is not supported; RESTfulChatModelHandle is required")
54 |
--------------------------------------------------------------------------------
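A hedged sketch of using `Xinference` with a vector store; the `base_url` and `model_uid` values are placeholders:

```python
from vanna.chromadb import ChromaDB_VectorStore
from vanna.xinference import Xinference

class MyVanna(ChromaDB_VectorStore, Xinference):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Xinference.__init__(self, config=config)

vn = MyVanna(config={
    "base_url": "http://localhost:9997",  # required; __init__ raises without it
    "model_uid": "my-chat-model-uid",     # placeholder; resolved in submit_prompt
})

# model_uid can also be supplied per call through kwargs:
# vn.submit_prompt(prompt, model_uid="my-chat-model-uid")
```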
/tests/test_imports.py:
--------------------------------------------------------------------------------
1 | def test_regular_imports():
2 | from vanna.anthropic.anthropic_chat import Anthropic_Chat
3 | from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore
4 | from vanna.base.base import VannaBase
5 | from vanna.bedrock.bedrock_converse import Bedrock_Converse
6 | from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
7 | from vanna.cohere.cohere_chat import Cohere_Chat
8 | from vanna.cohere.cohere_embeddings import Cohere_Embeddings
9 | from vanna.faiss.faiss import FAISS
10 | from vanna.google.bigquery_vector import BigQuery_VectorStore
11 | from vanna.google.gemini_chat import GoogleGeminiChat
12 | from vanna.hf.hf import Hf
13 | from vanna.local import LocalContext_OpenAI
14 | from vanna.marqo.marqo import Marqo_VectorStore
15 | from vanna.milvus.milvus_vector import Milvus_VectorStore
16 | from vanna.mistral.mistral import Mistral
17 | from vanna.ollama.ollama import Ollama
18 | from vanna.openai.openai_chat import OpenAI_Chat
19 | from vanna.openai.openai_embeddings import OpenAI_Embeddings
20 | from vanna.opensearch.opensearch_vector import OpenSearch_VectorStore
21 | from vanna.opensearch.opensearch_vector_semantic import (
22 | OpenSearch_Semantic_VectorStore,
23 | )
24 | from vanna.pgvector.pgvector import PG_VectorStore
25 | from vanna.pinecone.pinecone_vector import PineconeDB_VectorStore
26 | from vanna.qdrant.qdrant import Qdrant_VectorStore
27 | from vanna.qianfan.Qianfan_Chat import Qianfan_Chat
28 | from vanna.qianfan.Qianfan_embeddings import Qianfan_Embeddings
29 | from vanna.qianwen.QianwenAI_chat import QianWenAI_Chat
30 | from vanna.qianwen.QianwenAI_embeddings import QianWenAI_Embeddings
31 | from vanna.remote import VannaDefault
32 | from vanna.vannadb.vannadb_vector import VannaDB_VectorStore
33 | from vanna.weaviate.weaviate_vector import WeaviateDatabase
34 | from vanna.xinference.xinference import Xinference
35 | from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
36 | from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings
37 |
38 | def test_shortcut_imports():
39 | from vanna.anthropic import Anthropic_Chat
40 | from vanna.azuresearch import AzureAISearch_VectorStore
41 | from vanna.base import VannaBase
42 | from vanna.chromadb import ChromaDB_VectorStore
43 | from vanna.cohere import Cohere_Chat, Cohere_Embeddings
44 | from vanna.faiss import FAISS
45 | from vanna.hf import Hf
46 | from vanna.marqo import Marqo_VectorStore
47 | from vanna.milvus import Milvus_VectorStore
48 | from vanna.mistral import Mistral
49 | from vanna.ollama import Ollama
50 | from vanna.openai import OpenAI_Chat, OpenAI_Embeddings
51 | from vanna.opensearch import (
52 | OpenSearch_Semantic_VectorStore,
53 | OpenSearch_VectorStore,
54 | )
55 | from vanna.pgvector import PG_VectorStore
56 | from vanna.pinecone import PineconeDB_VectorStore
57 | from vanna.qdrant import Qdrant_VectorStore
58 | from vanna.qianfan import Qianfan_Chat, Qianfan_Embeddings
59 | from vanna.qianwen import QianWenAI_Chat, QianWenAI_Embeddings
60 | from vanna.vannadb import VannaDB_VectorStore
61 | from vanna.vllm import Vllm
62 | from vanna.weaviate import WeaviateDatabase
63 | from vanna.xinference import Xinference
64 | from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings
65 |
--------------------------------------------------------------------------------
/tests/test_instantiation.py:
--------------------------------------------------------------------------------
1 | from vanna.mock import MockEmbedding, MockLLM, MockVectorDB
2 |
--------------------------------------------------------------------------------
/tests/test_pgvector.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dotenv import load_dotenv
4 |
5 | # from vanna.pgvector import PG_VectorStore
6 | # from vanna.openai import OpenAI_Chat
7 |
8 | # assume .env file placed next to file with provided env vars
9 | load_dotenv()
10 |
11 | # def get_vanna_connection_string():
12 | # server = os.environ.get("PG_SERVER")
13 | # driver = "psycopg"
14 | # port = os.environ.get("PG_PORT", 5432)
15 | # database = os.environ.get("PG_DATABASE")
16 | # username = os.environ.get("PG_USERNAME")
17 | # password = os.environ.get("PG_PASSWORD")
18 |
19 | # def test_pgvector_e2e():
20 | # # configure Vanna to use OpenAI and PGVector
21 | # class VannaCustom(PG_VectorStore, OpenAI_Chat):
22 | # def __init__(self, config=None):
23 | # PG_VectorStore.__init__(self, config=config)
24 | # OpenAI_Chat.__init__(self, config=config)
25 |
26 | # vn = VannaCustom(config={
27 | # 'api_key': os.environ['OPENAI_API_KEY'],
28 | # 'model': 'gpt-3.5-turbo',
29 | # "connection_string": get_vanna_connection_string(),
30 | # })
31 |
32 | # # connect to SQLite database
33 | # vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
34 |
35 | # # train Vanna on DDLs
36 | # df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
37 | # for ddl in df_ddl['sql'].to_list():
38 | # vn.train(ddl=ddl)
39 | # assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default
40 |
41 | # question = "What are the top 7 customers by sales?"
42 | # sql = vn.generate_sql(question)
43 | # df = vn.run_sql(sql)
44 | # assert len(df) == 7
45 |
46 | # # test if Vanna can generate an answer
47 | # answer = vn.ask(question)
48 | # assert answer is not None
49 |
50 |
--------------------------------------------------------------------------------
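The commented-out `get_vanna_connection_string` above never returns a value; a plausible completion, assuming the standard SQLAlchemy URL format that `PG_VectorStore` consumes via its `connection_string` config key:

```python
import os

def get_vanna_connection_string() -> str:
    server = os.environ.get("PG_SERVER")
    driver = "psycopg"
    port = os.environ.get("PG_PORT", 5432)
    database = os.environ.get("PG_DATABASE")
    username = os.environ.get("PG_USERNAME")
    password = os.environ.get("PG_PASSWORD")
    # SQLAlchemy-style URL for the psycopg (v3) driver.
    return f"postgresql+{driver}://{username}:{password}@{server}:{port}/{database}"
```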
/tests/test_vanna.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from vanna.anthropic.anthropic_chat import Anthropic_Chat
4 | from vanna.cohere.cohere_chat import Cohere_Chat
5 | from vanna.google import GoogleGeminiChat
6 | from vanna.mistral.mistral import Mistral
7 | from vanna.openai.openai_chat import OpenAI_Chat
8 | from vanna.remote import VannaDefault
9 | from vanna.vannadb.vannadb_vector import VannaDB_VectorStore
10 |
11 | try:
12 | print("Trying to load .env")
13 | from dotenv import load_dotenv
14 | load_dotenv()
15 | except Exception as e:
16 | print(f"Failed to load .env {e}")
17 | pass
18 |
19 | MY_VANNA_MODEL = 'chinook'
20 | ANTHROPIC_MODEL = 'claude-3-sonnet-20240229'
21 | MY_VANNA_API_KEY = os.environ['VANNA_API_KEY']
22 | OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
23 | MISTRAL_API_KEY = os.environ['MISTRAL_API_KEY']
24 | ANTHROPIC_API_KEY = os.environ['ANTHROPIC_API_KEY']
25 | SNOWFLAKE_ACCOUNT = os.environ['SNOWFLAKE_ACCOUNT']
26 | SNOWFLAKE_USERNAME = os.environ['SNOWFLAKE_USERNAME']
27 | SNOWFLAKE_PASSWORD = os.environ['SNOWFLAKE_PASSWORD']
28 | # AZURE_SEARCH_API_KEY = os.environ['AZURE_SEARCH_API_KEY']
29 |
30 | class VannaOpenAI(VannaDB_VectorStore, OpenAI_Chat):
31 | def __init__(self, config=None):
32 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
33 | OpenAI_Chat.__init__(self, config=config)
34 |
35 | vn_openai = VannaOpenAI(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
36 | vn_openai.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
37 |
38 | def test_vn_openai():
39 | sql = vn_openai.generate_sql("What are the top 4 customers by sales?")
40 | df = vn_openai.run_sql(sql)
41 | assert len(df) == 4
42 |
43 | class VannaMistral(VannaDB_VectorStore, Mistral):
44 | def __init__(self, config=None):
45 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
46 | Mistral.__init__(self, config={'api_key': MISTRAL_API_KEY, 'model': 'mistral-tiny'})
47 |
48 | vn_mistral = VannaMistral()
49 | vn_mistral.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
50 |
51 | def test_vn_mistral():
52 | sql = vn_mistral.generate_sql("What are the top 5 customers by sales?")
53 | df = vn_mistral.run_sql(sql)
54 | assert len(df) == 5
55 |
56 | vn_default = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY)
57 | vn_default.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
58 |
59 | def test_vn_default():
60 | sql = vn_default.generate_sql("What are the top 6 customers by sales?")
61 | df = vn_default.run_sql(sql)
62 | assert len(df) == 6
63 |
64 | from vanna.qdrant import Qdrant_VectorStore
65 |
66 |
67 | class VannaQdrant(Qdrant_VectorStore, OpenAI_Chat):
68 | def __init__(self, config=None):
69 | Qdrant_VectorStore.__init__(self, config=config)
70 | OpenAI_Chat.__init__(self, config=config)
71 |
72 | from qdrant_client import QdrantClient
73 |
74 | qdrant_memory_client = QdrantClient(":memory:")
75 |
76 | vn_qdrant = VannaQdrant(config={'client': qdrant_memory_client, 'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
77 | vn_qdrant.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
78 |
79 | def test_vn_qdrant():
80 | df_ddl = vn_qdrant.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
81 |
82 | for ddl in df_ddl['sql'].to_list():
83 | vn_qdrant.train(ddl=ddl)
84 |
85 | sql = vn_qdrant.generate_sql("What are the top 7 customers by sales?")
86 | df = vn_qdrant.run_sql(sql)
87 | assert len(df) == 7
88 |
89 | from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
90 | from vanna.openai.openai_chat import OpenAI_Chat
91 |
92 |
93 | class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
94 | def __init__(self, config=None):
95 | ChromaDB_VectorStore.__init__(self, config=config)
96 | OpenAI_Chat.__init__(self, config=config)
97 |
98 | vn_chroma = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
99 | vn_chroma.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
100 |
101 | def test_vn_chroma():
102 | existing_training_data = vn_chroma.get_training_data()
103 | if len(existing_training_data) > 0:
104 | for _, training_data in existing_training_data.iterrows():
105 | vn_chroma.remove_training_data(training_data['id'])
106 |
107 | df_ddl = vn_chroma.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
108 |
109 | for ddl in df_ddl['sql'].to_list():
110 | vn_chroma.train(ddl=ddl)
111 |
112 | sql = vn_chroma.generate_sql("What are the top 7 customers by sales?")
113 | df = vn_chroma.run_sql(sql)
114 | assert len(df) == 7
115 |
116 | # from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore
117 |
118 |
119 | # class VannaAzureSearch(AzureAISearch_VectorStore, OpenAI_Chat):
120 | # def __init__(self, config=None):
121 | # AzureAISearch_VectorStore.__init__(self, config=config)
122 | # OpenAI_Chat.__init__(self, config=config)
123 |
124 | # vn_azure_search = VannaAzureSearch(config={'azure_search_api_key': AZURE_SEARCH_API_KEY,'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
125 | # vn_azure_search.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
126 |
127 | # def test_vn_azure_search():
128 | # existing_training_data = vn_azure_search.get_training_data()
129 | # print(existing_training_data)
130 | # if len(existing_training_data) > 0:
131 | # for _, training_data in existing_training_data.iterrows():
132 | # vn_azure_search.remove_training_data(training_data['id'])
133 |
134 | # df_ddl = vn_azure_search.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
135 | # for ddl in df_ddl['sql'].to_list():
136 | # vn_azure_search.train(ddl=ddl)
137 |
138 | # sql = vn_azure_search.generate_sql("What are the top 7 customers by sales?")
139 | # df = vn_azure_search.run_sql(sql)
140 | # assert len(df) == 7
141 |
142 | from vanna.milvus import Milvus_VectorStore
143 |
144 |
145 | class VannaMilvus(Milvus_VectorStore, OpenAI_Chat):
146 | def __init__(self, config=None):
147 | Milvus_VectorStore.__init__(self, config=config)
148 | OpenAI_Chat.__init__(self, config=config)
149 |
150 | vn_milvus = VannaMilvus(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
151 | vn_milvus.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
152 |
153 | def test_vn_milvus():
154 | existing_training_data = vn_milvus.get_training_data()
155 | if len(existing_training_data) > 0:
156 | for _, training_data in existing_training_data.iterrows():
157 | vn_milvus.remove_training_data(training_data['id'])
158 |
159 | df_ddl = vn_milvus.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
160 |
161 | for ddl in df_ddl['sql'].to_list():
162 | vn_milvus.train(ddl=ddl)
163 |
164 | sql = vn_milvus.generate_sql("What are the top 7 customers by sales?")
165 | df = vn_milvus.run_sql(sql)
166 | assert len(df) == 7
167 |
168 |
169 | class VannaNumResults(ChromaDB_VectorStore, OpenAI_Chat):
170 | def __init__(self, config=None):
171 | ChromaDB_VectorStore.__init__(self, config=config)
172 | OpenAI_Chat.__init__(self, config=config)
173 |
174 | vn_chroma_n_results = VannaNumResults(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results': 1})
175 | vn_chroma_n_results_ddl = VannaNumResults(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results_ddl': 2})
176 | vn_chroma_n_results_sql = VannaNumResults(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results_sql': 3})
177 | vn_chroma_n_results_documentation = VannaNumResults(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results_documentation': 4})
178 |
179 | def test_n_results():
180 | for i in range(1, 10):
181 | vn_chroma.train(question=f"What are the total sales for customer {i}?", sql=f"SELECT SUM(sales) FROM example_sales WHERE customer_id = {i}")
182 |
183 | for i in range(1, 10):
184 | vn_chroma.train(documentation=f"Sample documentation {i}")
185 |
186 | question = "What are the top 5 customers by sales?"
187 | assert len(vn_chroma_n_results.get_related_ddl(question)) == 1
188 | assert len(vn_chroma_n_results.get_related_documentation(question)) == 1
189 | assert len(vn_chroma_n_results.get_similar_question_sql(question)) == 1
190 |
191 | assert len(vn_chroma_n_results_ddl.get_related_ddl(question)) == 2
192 | assert len(vn_chroma_n_results_ddl.get_related_documentation(question)) != 2
193 | assert len(vn_chroma_n_results_ddl.get_similar_question_sql(question)) != 2
194 |
195 | assert len(vn_chroma_n_results_sql.get_related_ddl(question)) != 3
196 | assert len(vn_chroma_n_results_sql.get_related_documentation(question)) != 3
197 | assert len(vn_chroma_n_results_sql.get_similar_question_sql(question)) == 3
198 |
199 | assert len(vn_chroma_n_results_documentation.get_related_ddl(question)) != 4
200 | assert len(vn_chroma_n_results_documentation.get_related_documentation(question)) == 4
201 | assert len(vn_chroma_n_results_documentation.get_similar_question_sql(question)) != 4
202 |
203 | class VannaClaude(VannaDB_VectorStore, Anthropic_Chat):
204 | def __init__(self, config=None):
205 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
206 | Anthropic_Chat.__init__(self, config={'api_key': ANTHROPIC_API_KEY, 'model': ANTHROPIC_MODEL})
207 |
208 |
209 | vn_claude = VannaClaude()
210 | vn_claude.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
211 |
212 |
213 | def test_vn_claude():
214 | sql = vn_claude.generate_sql("What are the top 8 customers by sales?")
215 | df = vn_claude.run_sql(sql)
216 | assert len(df) == 8
217 |
218 | class VannaGemini(VannaDB_VectorStore, GoogleGeminiChat):
219 | def __init__(self, config=None):
220 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
221 | GoogleGeminiChat.__init__(self, config=config)
222 |
223 | vn_gemini = VannaGemini(config={'api_key': os.environ['GEMINI_API_KEY']})
224 | vn_gemini.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
225 |
226 | def test_vn_gemini():
227 | sql = vn_gemini.generate_sql("What are the top 9 customers by sales?")
228 | df = vn_gemini.run_sql(sql)
229 | assert len(df) == 9
230 |
231 | class VannaCohere(VannaDB_VectorStore, Cohere_Chat):
232 | def __init__(self, config=None):
233 | VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
234 | Cohere_Chat.__init__(self, config=config)
235 |
236 | try:
237 | COHERE_API_KEY = os.environ['COHERE_API_KEY']
238 | vn_cohere = VannaCohere(config={'api_key': COHERE_API_KEY, 'model': 'command-a-03-2025'})
239 | vn_cohere.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
240 |
241 | def test_vn_cohere():
242 | sql = vn_cohere.generate_sql("What are the top 10 customers by sales?")
243 | df = vn_cohere.run_sql(sql)
244 | assert len(df) == 10
245 | except KeyError:
246 | print("Skipping Cohere tests - COHERE_API_KEY not found in environment variables")
247 |
248 | def test_training_plan():
249 | vn_dummy = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY)
250 |
251 | vn_dummy.connect_to_snowflake(
252 | account=SNOWFLAKE_ACCOUNT,
253 | username=SNOWFLAKE_USERNAME,
254 | password=SNOWFLAKE_PASSWORD,
255 | database='SNOWFLAKE_SAMPLE_DATA',
256 | )
257 |
258 | df_information_schema = vn_dummy.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = 'TPCH_SF1' ")
259 |
260 | plan = vn_dummy.get_training_plan_generic(df_information_schema)
261 | assert len(plan._plan) == 8
262 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist =
3 | py310,
4 | mac,
5 | flake8,
6 |
7 | [py]
8 | deps=
9 | pytest-cov
10 | pytest-remove-stale-bytecode
11 |
12 | [testenv:py310]
13 | deps=
14 | {[py]deps}
15 | extras = all
16 | passenv = *
17 | basepython = python3.10
18 | commands = pytest -v --cov=tests/ --cov-report=term --cov-report=html
19 |
20 | [testenv:mac]
21 | deps=
22 | {[py]deps}
23 | python-dotenv
24 | extras = all
25 | basepython = python
26 | commands =
27 | pytest -x -v --cov=tests/ --cov-report=term --cov-report=html
28 |
29 | [flake8]
30 | exclude = .tox/*
31 |
32 | [testenv:flake8]
33 | deps = flake8
34 | commands = flake8 src
35 |
--------------------------------------------------------------------------------
/training_data/sample-imdb/questions.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "question":"what are 5 most grossing movies in IMDB top 1000 ",
4 | "answer":"SELECT series_title,\n gross\nFROM imdb.public.movies\nORDER BY gross desc limit 5;"
5 | },
6 | {
7 | "question":"what are the top 5 movies and their ratings basis IMDB rating ? ",
8 | "answer":"SELECT series_title,\n imdb_rating\nFROM imdb.public.movies\nORDER BY imdb_rating desc limit 5"
9 | },
10 | {
11 | "question":"which 5 director have the most number of movies in the IMDB top 1000 ? ",
12 | "answer":"SELECT director,\r\n count(*) as num_of_movies\r\nFROM imdb.public.movies\r\nGROUP BY director\r\nORDER BY num_of_movies desc limit 5;"
13 | },
14 | {
15 | "question":"what are the top 5 movies and their ratings basis IMDB ? ",
16 | "answer":"SELECT series_title,\n imdb_rating\nFROM imdb.public.movies\nORDER BY imdb_rating desc limit 5"
17 | },
18 | {
19 | "question":"what is the distribution of imdb top 1000 movies across the release year ?",
20 | "answer":"SELECT released_year, count(*) as num_movies\r\nFROM imdb.public.movies\r\nGROUP BY 1\r\nORDER BY 1;"
21 | },
22 | {
23 | "question":"What are the 5 best rated action movies ? ",
24 | "answer":"SELECT series_title,\n imdb_rating\nFROM imdb.public.movies\nWHERE genre = 'Action'\nORDER BY imdb_rating desc limit 5;"
25 | },
26 | {
27 | "question":"What are the top 2 rated movies of ingmar bergman in drama genre and what are the lead actors in the movies ?",
28 | "answer":"SELECT series_title,\n star1,\n imdb_rating\nFROM imdb.public.movies\nWHERE director = 'Ingmar Bergman'\n and genre = 'Drama'\nORDER BY imdb_rating desc limit 2;"
29 | },
30 | {
31 | "question":"what are the top 5 movies basis IMDB rating ? ",
32 | "answer":"SELECT series_title\r\nFROM IMDB.PUBLIC.MOVIES\r\nORDER BY imdb_rating DESC\r\nLIMIT 5"
33 | },
34 | {
35 | "question":"which 5 director has the most number of movies in the IMDB top 1000 ? ",
36 | "answer":"SELECT director,\n count(*) as num_movies\nFROM imdb.public.movies\nGROUP BY director\nORDER BY num_movies desc limit 5;"
37 | },
38 | {
39 | "question":"what is the average IMDB rating for each genre ? ",
40 | "answer":"SELECT genre,\n avg(imdb_rating) as avg_imdb_rating\nFROM imdb.public.movies\nGROUP BY genre;"
41 | },
42 | {
43 | "question":"what is the genre wise, average IMDB rating ? ",
44 | "answer":"SELECT genre,\n avg(imdb_rating) as avg_imdb_rating\nFROM imdb.public.movies\nGROUP BY genre;"
45 | },
46 | {
47 | "question":"What is the runtime of Forest Gump, Saving Private Ryan and The Green Mile",
48 | "answer":"SELECT series_title,\n runtime\nFROM imdb.public.movies\nWHERE series_title in ('Forrest Gump', 'Saving Private Ryan', 'The Green Mile');"
49 | },
50 | {
51 | "question":"what is the distribution of titles on IMDB across genres ? ",
52 | "answer":"SELECT genre,\n count(*) as num_titles\nFROM imdb.public.movies\nGROUP BY genre\nORDER BY num_titles desc;"
53 | }
54 | ]
--------------------------------------------------------------------------------
/training_data/snowflake-cost/questions.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "question":"What are the daily costs for the last 30 days?",
4 | "answer":"SELECT usage_date,\n sum(usage_in_currency) as daily_cost\nFROM snowflake.organization_usage.preview_usage_in_currency_daily\nWHERE currency = 'USD'\n and usage_date >= dateadd(day, -30, current_date())\nGROUP BY usage_date\nORDER BY usage_date"
5 | },
6 | {
7 | "question":"What are the first 10 rows in the USAGE_IN_CURRENCY_DAILY table?",
8 | "answer":"SELECT ORGANIZATION_NAME, CONTRACT_NUMBER, ACCOUNT_NAME, ACCOUNT_LOCATOR, REGION, SERVICE_LEVEL, USAGE_DATE, USAGE_TYPE, CURRENCY, USAGE, USAGE_IN_CURRENCY, BALANCE_SOURCE\r\nFROM SNOWFLAKE.ORGANIZATION_USAGE.USAGE_IN_CURRENCY_DAILY\r\nLIMIT 10"
9 | },
10 | {
11 | "question":"Total usage costs in dollars for the organization, broken down by account",
12 | "answer":"SELECT account_name,\n sum(usage_in_currency) as total_cost\nFROM snowflake.organization_usage.preview_usage_in_currency_daily\nWHERE currency = 'USD'\nGROUP BY account_name"
13 | },
14 | {
15 | "question":"What is the daily cost by usage type?",
16 | "answer":"SELECT usage_date,\n usage_type,\n sum(usage_in_currency) as daily_cost\nFROM snowflake.organization_usage.preview_usage_in_currency_daily\nWHERE currency = 'USD'\nGROUP BY usage_date, usage_type\nORDER BY usage_date, daily_cost desc"
17 | },
18 | {
19 | "question":"What are the first 10 rows in the PREVIEW_USAGE_IN_CURRENCY_DAILY table?",
20 | "answer":"SELECT ORGANIZATION_NAME, CONTRACT_NUMBER, ACCOUNT_NAME, ACCOUNT_LOCATOR, REGION, SERVICE_LEVEL, USAGE_DATE, USAGE_TYPE, CURRENCY, USAGE, USAGE_IN_CURRENCY\r\nFROM SNOWFLAKE.ORGANIZATION_USAGE.PREVIEW_USAGE_IN_CURRENCY_DAILY\r\nLIMIT 10"
21 | },
22 | {
23 | "question":"Daily usage cost by account",
24 | "answer":"SELECT account_name,\n usage_date,\n sum(usage_in_currency) as daily_usage_cost\nFROM snowflake.organization_usage.preview_usage_in_currency_daily\nWHERE currency = 'USD'\nGROUP BY account_name, usage_date\nORDER BY usage_date, daily_usage_cost desc"
25 | }
26 | ]
--------------------------------------------------------------------------------