├── .gcloudignore ├── .github └── workflows │ └── main.yml ├── .gitignore ├── LICENSE ├── README.md ├── defog ├── __init__.py ├── admin_methods.py ├── async_admin_methods.py ├── async_generate_schema.py ├── async_health_methods.py ├── async_query_methods.py ├── cli.py ├── generate_schema.py ├── health_methods.py ├── llm │ ├── __init__.py │ ├── citations.py │ ├── code_interp.py │ ├── config │ │ ├── __init__.py │ │ ├── constants.py │ │ └── settings.py │ ├── cost │ │ ├── __init__.py │ │ ├── calculator.py │ │ └── models.py │ ├── exceptions │ │ ├── __init__.py │ │ └── llm_exceptions.py │ ├── llm_providers.py │ ├── mcp_readme.md │ ├── memory │ │ ├── __init__.py │ │ ├── compactifier.py │ │ ├── history_manager.py │ │ └── token_counter.py │ ├── models.py │ ├── orchestrator.py │ ├── providers │ │ ├── __init__.py │ │ ├── anthropic_provider.py │ │ ├── base.py │ │ ├── gemini_provider.py │ │ ├── openai_provider.py │ │ └── together_provider.py │ ├── tools │ │ ├── __init__.py │ │ └── handler.py │ ├── utils.py │ ├── utils_function_calling.py │ ├── utils_logging.py │ ├── utils_mcp.py │ ├── utils_memory.py │ ├── web_search.py │ └── youtube_transcript.py ├── query.py ├── query_methods.py ├── serve.py ├── static │ ├── 404.html │ ├── _next │ │ └── static │ │ │ ├── chunks │ │ │ ├── 0c428ae2.7f5ab17ef2110d84.js │ │ │ ├── 238-21e16f207d48d221.js │ │ │ ├── 283.503b46c22e64c702.js │ │ │ ├── 9b8d1757.9dd2c712a441d5f8.js │ │ │ ├── ea88be26.e203a40320569dfc.js │ │ │ ├── faae6fda.7ebb95670d232bb0.js │ │ │ ├── framework-02223fe42ab9321b.js │ │ │ ├── main-d30d248d262e39c4.js │ │ │ ├── pages │ │ │ │ ├── _app-db0976def6406e5e.js │ │ │ │ ├── _error-ee42a9921d95ff81.js │ │ │ │ ├── extract-metadata-2adf74ad3bcc8699.js │ │ │ │ ├── index-b60f249c1d54d3bf.js │ │ │ │ ├── instruct-model-d040be04cf7f21f2.js │ │ │ │ └── query-data-b8197e7950b177eb.js │ │ │ ├── polyfills-c67a75d1b6f99dc8.js │ │ │ └── webpack-1657be5a4830bbb9.js │ │ │ ├── css │ │ │ └── 321c398b2a784143.css │ │ │ └── lQfswjqOtyFH-kQc0_q8t │ │ │ ├── _buildManifest.js │ │ │ └── _ssgManifest.js │ ├── extract-metadata.html │ ├── favicon.ico │ ├── index.html │ ├── instruct-model.html │ ├── next.svg │ ├── query-data.html │ └── vercel.svg └── util.py ├── orchestrator_dynamic_example.py ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── test_async_defog.py ├── test_citations.py ├── test_code_interp.py ├── test_defog.py ├── test_llm.py ├── test_llm_response.py ├── test_llm_tool_calls.py ├── test_mcp_chat_async.py ├── test_memory.py ├── test_orchestrator_e2e.py ├── test_query.py ├── test_util.py ├── test_web_search.py └── test_youtube.py /.gcloudignore: -------------------------------------------------------------------------------- 1 | # This file specifies files that are *not* uploaded to Google Cloud 2 | # using gcloud. It follows the same syntax as .gitignore, with the addition of 3 | # "#!include" directives (which insert the entries of the given .gitignore-style 4 | # file at that point). 
5 | # 6 | # For more information, run: 7 | # $ gcloud topic gcloudignore 8 | # 9 | .gcloudignore 10 | # If you would like to upload your .git directory, .gitignore file or files 11 | # from your .gitignore file, remove the corresponding line 12 | # below: 13 | .git 14 | .gitignore 15 | 16 | node_modules 17 | #!include:.gitignore 18 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: [push] 4 | 5 | jobs: 6 | changes: 7 | runs-on: ubuntu-latest 8 | outputs: 9 | defog: ${{ steps.filter.outputs.defog }} 10 | tests: ${{ steps.filter.outputs.tests }} 11 | requirements: ${{ steps.filter.outputs.requirements }} 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: dorny/paths-filter@v2 15 | id: filter 16 | with: 17 | filters: | 18 | defog: 19 | - 'defog/**' 20 | tests: 21 | - 'tests/**' 22 | requirements: 23 | - 'requirements.txt' 24 | - 'setup.py' 25 | - 'setup.cfg' 26 | 27 | test: 28 | runs-on: ubuntu-latest 29 | needs: [changes] 30 | if: needs.changes.outputs.defog == 'true' || needs.changes.outputs.tests == 'true' || needs.changes.outputs.requirements == 'true' 31 | steps: 32 | - uses: actions/checkout@v2 33 | - name: Set up Python 34 | uses: actions/setup-python@v2 35 | with: 36 | python-version: '3.10' 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install -r requirements.txt 41 | - name: Run affected tests 42 | run: | 43 | if [ "${{ needs.changes.outputs.requirements }}" == "true" ]; then 44 | pytest tests 45 | elif [ "${{ needs.changes.outputs.defog }}" == "true" ] || [ "${{ needs.changes.outputs.tests }}" == "true" ]; then 46 | pytest tests 47 | fi 48 | env: 49 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 50 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 51 | TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }} 52 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Jupyter Notebook 55 | .ipynb_checkpoints 56 | 57 | # IPython 58 | profile_default/ 59 | ipython_config.py 60 | 61 | # pyenv 62 | .python-version 63 | 64 | *.DS_Store 65 | local_test.py 66 | 67 | defog_metadata.csv 68 | golden_queries.csv 69 | golden_queries.json 70 | glossary.txt 71 | 72 | # Ignore virtual environment directories 73 | .virtual/ 74 | myenv/ 75 | venv/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Defog.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Defog Python 2 | 3 | [![tests](https://github.com/defog-ai/defog-python/actions/workflows/main.yml/badge.svg)](https://github.com/defog-ai/defog-python/actions/workflows/main.yml) 4 | 5 | # TLDR 6 | 7 | This library used to be an SDK for accessing Defog's cloud-hosted text-to-SQL service. It has since evolved into a general-purpose library for: 8 | 9 | 1. Making convenient, cross-provider LLM calls (including server-hosted tools) 10 | 2. Easily extracting information from databases to make them easier to work with 11 | 12 | If you are looking for text-to-SQL or deep-research-like capabilities, check out [introspect](https://github.com/defog-ai/introspect), our open-source, MIT-licensed repo that uses this library as a dependency. 13 | 14 | # Using this library 15 | 16 | ## LLM Utilities (`defog.llm`) 17 | 18 | The `defog.llm` module provides cross-provider LLM functionality with support for function calling, structured output, and specialized tools. 19 | 20 | **Note:** As of the latest version, all LLM functions are async-only. Synchronous methods have been removed to improve performance and consistency.
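Because every call in `defog.llm` is a coroutine, it must be awaited inside an event loop. As a minimal sketch (the model name here is just an example), a synchronous script can drive `chat_async` with `asyncio.run`:

```python
import asyncio

from defog.llm.utils import chat_async


async def main():
    # chat_async is a coroutine, so it has to be awaited inside an event loop
    response = await chat_async(
        provider="openai",  # string provider names and LLMProvider enum members both work
        model="gpt-4o",  # example model; substitute any supported model
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.content)


if __name__ == "__main__":
    asyncio.run(main())
```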
21 | 22 | ### Core Chat Functions 23 | 24 | ```python 25 | from defog.llm.utils import chat_async, chat_async_legacy, LLMResponse 26 | from defog.llm.llm_providers import LLMProvider 27 | 28 | # Unified async interface with explicit provider specification 29 | response: LLMResponse = await chat_async( 30 | provider=LLMProvider.OPENAI, # or "openai", LLMProvider.ANTHROPIC, etc. 31 | model="gpt-4o", 32 | messages=[{"role": "user", "content": "Hello!"}], 33 | max_completion_tokens=1000, 34 | temperature=0.0 35 | ) 36 | 37 | print(response.content) # Response text 38 | print(f"Cost: ${response.cost_in_cents/100:.4f}") 39 | 40 | # Alternative: Legacy model-to-provider inference 41 | 42 | response = await chat_async_legacy( 43 | model="gpt-4o", 44 | messages=[{"role": "user", "content": "Hello!"}] 45 | ) 46 | ``` 47 | 48 | ### Provider-Specific Examples 49 | 50 | ```python 51 | from defog.llm.utils import chat_async 52 | from defog.llm.llm_providers import LLMProvider 53 | 54 | # OpenAI with function calling 55 | response = await chat_async( 56 | provider=LLMProvider.OPENAI, 57 | model="gpt-4o", 58 | messages=[{"role": "user", "content": "What's the weather in Paris?"}], 59 | tools=[my_function], # Optional function calling 60 | tool_choice="auto" 61 | ) 62 | 63 | # Anthropic with structured output 64 | response = await chat_async( 65 | provider=LLMProvider.ANTHROPIC, 66 | model="claude-3-5-sonnet", 67 | messages=[{"role": "user", "content": "Hello!"}], 68 | response_format=MyPydanticModel # Structured output 69 | ) 70 | 71 | # Gemini 72 | response = await chat_async( 73 | provider=LLMProvider.GEMINI, 74 | model="gemini-2.0-flash", 75 | messages=[{"role": "user", "content": "Hello!"}] 76 | ) 77 | ``` 78 | 79 | ### Code Interpreter Tool 80 | 81 | Execute Python code in sandboxed environments across providers: 82 | 83 | ```python 84 | from defog.llm.code_interp import code_interpreter_tool 85 | from defog.llm.llm_providers import LLMProvider 86 | 87 | result = await code_interpreter_tool( 88 | question="Analyze this CSV data and create a visualization", 89 | model="gpt-4o", 90 | provider=LLMProvider.OPENAI, 91 | csv_string="name,age\nAlice,25\nBob,30", 92 | instructions="You are a data analyst. Create clear visualizations." 
93 | ) 94 | 95 | print(result["code"]) # Generated Python code 96 | print(result["output"]) # Execution results 97 | ``` 98 | 99 | ### Web Search Tool 100 | 101 | Search the web for current information: 102 | 103 | ```python 104 | from defog.llm.web_search import web_search_tool 105 | from defog.llm.llm_providers import LLMProvider 106 | 107 | result = await web_search_tool( 108 | question="What are the latest developments in AI?", 109 | model="claude-3-5-sonnet", 110 | provider=LLMProvider.ANTHROPIC, 111 | max_tokens=2048 112 | ) 113 | 114 | print(result["search_results"]) # Search results text 115 | print(result["websites_cited"]) # Source citations 116 | ``` 117 | 118 | ### Function Calling 119 | 120 | Define tools for LLMs to call: 121 | 122 | ```python 123 | from pydantic import BaseModel 124 | from defog.llm.utils import chat_async 125 | 126 | class WeatherInput(BaseModel): 127 | location: str 128 | units: str = "celsius" 129 | 130 | def get_weather(input: WeatherInput) -> str: 131 | """Get current weather for a location""" 132 | return f"Weather in {input.location}: 22°{input.units[0].upper()}, sunny" 133 | 134 | response = await chat_async( 135 | model="gpt-4o", 136 | messages=[{"role": "user", "content": "What's the weather in Paris?"}], 137 | tools=[get_weather], 138 | tool_choice="auto" 139 | ) 140 | ``` 141 | 142 | ### Memory Compactification 143 | 144 | Automatically manage long conversations by intelligently summarizing older messages while preserving context: 145 | 146 | ```python 147 | from defog.llm import chat_async_with_memory, create_memory_manager, MemoryConfig 148 | 149 | # Create a memory manager with custom settings 150 | memory_manager = create_memory_manager( 151 | token_threshold=50000, # Compactify when reaching 50k tokens 152 | preserve_last_n_messages=10, # Keep last 10 messages intact 153 | summary_max_tokens=2000, # Max tokens for summary 154 | enabled=True 155 | ) 156 | 157 | # System messages are automatically preserved across compactifications 158 | response1 = await chat_async_with_memory( 159 | provider="openai", 160 | model="gpt-4o", 161 | messages=[ 162 | {"role": "system", "content": "You are a helpful Python tutor."}, 163 | {"role": "user", "content": "Tell me about Python"} 164 | ], 165 | memory_manager=memory_manager 166 | ) 167 | 168 | # Continue the conversation - memory is automatically managed 169 | response2 = await chat_async_with_memory( 170 | provider="openai", 171 | model="gpt-4o", 172 | messages=[{"role": "user", "content": "What about its use in data science?"}], 173 | memory_manager=memory_manager 174 | ) 175 | 176 | # The system message is preserved even after compactification! 
177 | # Check current conversation state: 178 | print(f"Total messages: {len(memory_manager.get_current_messages())}") 179 | print(f"Compactifications: {memory_manager.get_stats()['compactification_count']}") 180 | 181 | # Or use memory configuration without explicit manager 182 | response = await chat_async_with_memory( 183 | provider="anthropic", 184 | model="claude-3-5-sonnet", 185 | messages=[{"role": "user", "content": "Hello!"}], 186 | memory_config=MemoryConfig( 187 | enabled=True, 188 | token_threshold=100000, # 100k tokens before compactification 189 | preserve_last_n_messages=10, 190 | summary_max_tokens=4000 191 | ) 192 | ) 193 | ``` 194 | 195 | Key features: 196 | - **System message preservation**: System messages are always kept intact, never summarized 197 | - **Automatic summarization**: When token count exceeds threshold, older messages are intelligently summarized 198 | - **Context preservation**: Recent messages are kept intact for continuity 199 | - **Provider agnostic**: Works with all supported LLM providers 200 | - **Token counting**: Uses tiktoken for accurate OpenAI token counts, with intelligent fallbacks for other providers 201 | - **Flexible configuration**: Customize thresholds, preservation rules, and summary sizes 202 | 203 | How it works: 204 | 1. As conversation grows, token count is tracked 205 | 2. When threshold is exceeded, older messages (except system messages) are summarized 206 | 3. Summary is added as a user message with `[Previous conversation summary]` prefix 207 | 4. Recent messages + system messages are preserved for context 208 | 5. Process repeats as needed for very long conversations 209 | 210 | ### MCP (Model Context Protocol) Support 211 | 212 | Connect to MCP servers for extended tool capabilities: 213 | 214 | ```python 215 | from defog.llm.utils_mcp import initialize_mcp_client 216 | 217 | # Initialize with config file 218 | mcp_client = await initialize_mcp_client( 219 | config="path/to/mcp_config.json", 220 | model="claude-3-5-sonnet" 221 | ) 222 | 223 | # Process queries with MCP tools 224 | response, tool_outputs = await mcp_client.mcp_chat( 225 | "Use the calculator tool to compute 123 * 456" 226 | ) 227 | ``` 228 | 229 | # Testing 230 | For developers who want to test or add tests for this client, you can run: 231 | ``` 232 | pytest tests 233 | ``` 234 | Note that we will transfer the existing .defog/connection.json file over to /tmp (if at all), and transfer the original file back once the tests are done to avoid messing with the original config. 235 | If submitting a PR, please use the `black` linter to lint your code. You can add it as a git hook to your repo by running the command below: 236 | ```bash 237 | echo -e '#!/bin/sh\n#\n# Run linter before commit\nblack $(git rev-parse --show-toplevel)' > .git/hooks/pre-commit 238 | chmod +x .git/hooks/pre-commit 239 | ``` 240 | -------------------------------------------------------------------------------- /defog/async_health_methods.py: -------------------------------------------------------------------------------- 1 | from defog.util import make_async_post_request 2 | 3 | 4 | async def check_golden_queries_coverage(self, dev: bool = False): 5 | """ 6 | Check the number of tables and columns inside the metadata schema that are covered by the golden queries. 
7 | """ 8 | url = f"{self.base_url}/get_golden_queries_coverage" 9 | payload = {"api_key": self.api_key, "dev": dev} 10 | return await make_async_post_request(url, payload) 11 | 12 | 13 | async def check_md_valid(self, dev: bool = False): 14 | """ 15 | Check if the metadata schema is valid. 16 | """ 17 | url = f"{self.base_url}/check_md_valid" 18 | payload = {"api_key": self.api_key, "db_type": self.db_type, "dev": dev} 19 | return await make_async_post_request(url, payload) 20 | 21 | 22 | async def check_gold_queries_valid(self, dev: bool = False): 23 | """ 24 | Check if the golden queries are valid and can be executed on a given database without errors. 25 | """ 26 | url = f"{self.base_url}/check_gold_queries_valid" 27 | payload = {"api_key": self.api_key, "db_type": self.db_type, "dev": dev} 28 | return await make_async_post_request(url, payload) 29 | 30 | 31 | async def check_glossary_valid(self, dev: bool = False): 32 | """ 33 | Check if the glossary is valid by verifying if all schema, table, and column names referenced are present in the metadata. 34 | """ 35 | url = f"{self.base_url}/check_glossary_valid" 36 | payload = {"api_key": self.api_key, "dev": dev} 37 | return await make_async_post_request(url, payload) 38 | 39 | 40 | async def check_glossary_consistency(self, dev: bool = False): 41 | """ 42 | Check if all logic in the glossary is consistent and coherent. 43 | """ 44 | url = f"{self.base_url}/check_glossary_consistency" 45 | payload = {"api_key": self.api_key, "dev": dev} 46 | return await make_async_post_request(url, payload) 47 | -------------------------------------------------------------------------------- /defog/async_query_methods.py: -------------------------------------------------------------------------------- 1 | from defog.util import make_async_post_request 2 | from defog.query import async_execute_query 3 | from datetime import datetime 4 | 5 | 6 | async def get_query( 7 | self, 8 | question: str, 9 | hard_filters: str = "", 10 | previous_context: list = [], 11 | glossary: str = "", 12 | debug: bool = False, 13 | dev: bool = False, 14 | temp: bool = False, 15 | profile: bool = False, 16 | ignore_cache: bool = False, 17 | model: str = "", 18 | use_golden_queries: bool = True, 19 | subtable_pruning: bool = False, 20 | glossary_pruning: bool = False, 21 | prune_max_tokens: int = 2000, 22 | prune_bm25_num_columns: int = 10, 23 | prune_glossary_max_tokens: int = 1000, 24 | prune_glossary_num_cos_sim_units: int = 10, 25 | prune_glossary_bm25_units: int = 10, 26 | ): 27 | """ 28 | Asynchronously sends the query to the defog servers, and return the response. 29 | :param question: The question to be asked. 30 | :return: The response from the defog server. 
31 | """ 32 | try: 33 | data = { 34 | "question": question, 35 | "api_key": self.api_key, 36 | "previous_context": previous_context, 37 | "db_type": self.db_type if self.db_type != "databricks" else "postgres", 38 | "glossary": glossary, 39 | "hard_filters": hard_filters, 40 | "dev": dev, 41 | "temp": temp, 42 | "ignore_cache": ignore_cache, 43 | "model": model, 44 | "use_golden_queries": use_golden_queries, 45 | "subtable_pruning": subtable_pruning, 46 | "glossary_pruning": glossary_pruning, 47 | "prune_max_tokens": prune_max_tokens, 48 | "prune_bm25_num_columns": prune_bm25_num_columns, 49 | "prune_glossary_max_tokens": prune_glossary_max_tokens, 50 | "prune_glossary_num_cos_sim_units": prune_glossary_num_cos_sim_units, 51 | "prune_glossary_bm25_units": prune_glossary_bm25_units, 52 | } 53 | 54 | t_start = datetime.now() 55 | 56 | resp = await make_async_post_request( 57 | url=self.generate_query_url, payload=data, timeout=300 58 | ) 59 | 60 | t_end = datetime.now() 61 | time_taken = (t_end - t_start).total_seconds() 62 | query_generated = resp.get("sql", resp.get("query_generated")) 63 | ran_successfully = resp.get("ran_successfully") 64 | error_message = resp.get("error_message") 65 | query_db = self.db_type 66 | resp = { 67 | "query_generated": query_generated, 68 | "ran_successfully": ran_successfully, 69 | "error_message": error_message, 70 | "query_db": query_db, 71 | "previous_context": resp.get("previous_context"), 72 | "reason_for_query": resp.get("reason_for_query"), 73 | } 74 | if profile: 75 | resp["time_taken"] = time_taken 76 | 77 | return resp 78 | except Exception as e: 79 | if debug: 80 | print(e) 81 | return { 82 | "ran_successfully": False, 83 | "error_message": "Sorry :( Our server is at capacity right now and we are unable to process your query. Please try again in a few minutes?", 84 | } 85 | 86 | 87 | async def run_query( 88 | self, 89 | question: str, 90 | hard_filters: str = "", 91 | previous_context: list = [], 92 | glossary: str = "", 93 | query: dict = None, 94 | retries: int = 3, 95 | dev: bool = False, 96 | temp: bool = False, 97 | profile: bool = False, 98 | ignore_cache: bool = False, 99 | model: str = "", 100 | use_golden_queries: bool = True, 101 | subtable_pruning: bool = False, 102 | glossary_pruning: bool = False, 103 | prune_max_tokens: int = 2000, 104 | prune_bm25_num_columns: int = 10, 105 | prune_glossary_max_tokens: int = 1000, 106 | prune_glossary_num_cos_sim_units: int = 10, 107 | prune_glossary_bm25_units: int = 10, 108 | ): 109 | """ 110 | Asynchronously sends the question to the defog servers, executes the generated SQL, 111 | and returns the response. 112 | :param question: The question to be asked. 113 | :return: The response from the defog server. 
114 | """ 115 | if query is None: 116 | print(f"Generating the query for your question: {question}...") 117 | query = await self.get_query( 118 | question, 119 | hard_filters, 120 | previous_context, 121 | glossary=glossary, 122 | dev=dev, 123 | temp=temp, 124 | profile=profile, 125 | model=model, 126 | ignore_cache=ignore_cache, 127 | use_golden_queries=use_golden_queries, 128 | subtable_pruning=subtable_pruning, 129 | glossary_pruning=glossary_pruning, 130 | prune_max_tokens=prune_max_tokens, 131 | prune_bm25_num_columns=prune_bm25_num_columns, 132 | prune_glossary_max_tokens=prune_glossary_max_tokens, 133 | prune_glossary_num_cos_sim_units=prune_glossary_num_cos_sim_units, 134 | prune_glossary_bm25_units=prune_glossary_bm25_units, 135 | ) 136 | if query["ran_successfully"]: 137 | try: 138 | print("Query generated, now running it on your database...") 139 | tstart = datetime.now() 140 | colnames, result, executed_query = await async_execute_query( 141 | query=query["query_generated"], 142 | api_key=self.api_key, 143 | db_type=self.db_type, 144 | db_creds=self.db_creds, 145 | question=question, 146 | hard_filters=hard_filters, 147 | retries=retries, 148 | dev=dev, 149 | temp=temp, 150 | ) 151 | tend = datetime.now() 152 | time_taken = (tend - tstart).total_seconds() 153 | resp = { 154 | "columns": colnames, 155 | "data": result, 156 | "query_generated": executed_query, 157 | "ran_successfully": True, 158 | "reason_for_query": query.get("reason_for_query"), 159 | "previous_context": query.get("previous_context"), 160 | } 161 | if profile: 162 | resp["execution_time_taken"] = time_taken 163 | resp["generation_time_taken"] = query.get("time_taken") 164 | return resp 165 | except Exception as e: 166 | return { 167 | "ran_successfully": False, 168 | "error_message": str(e), 169 | "query_generated": query["query_generated"], 170 | } 171 | else: 172 | return {"ran_successfully": False, "error_message": query["error_message"]} 173 | -------------------------------------------------------------------------------- /defog/health_methods.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | def check_golden_queries_coverage(self, dev: bool = False): 5 | """ 6 | Check the number of tables and columns inside the metadata schema that are covered by the golden queries. 7 | """ 8 | try: 9 | r = requests.post( 10 | f"{self.base_url}/get_golden_queries_coverage", 11 | json={"api_key": self.api_key, "dev": dev}, 12 | verify=False, 13 | ) 14 | resp = r.json() 15 | return resp 16 | except Exception as e: 17 | return {"error": str(e)} 18 | 19 | 20 | def check_md_valid(self, dev: bool = False): 21 | """ 22 | Check if the metadata schema is valid. 
23 | """ 24 | try: 25 | r = requests.post( 26 | f"{self.base_url}/check_md_valid", 27 | json={"api_key": self.api_key, "db_type": self.db_type, "dev": dev}, 28 | verify=False, 29 | ) 30 | resp = r.json() 31 | return resp 32 | except Exception as e: 33 | return {"error": str(e)} 34 | 35 | 36 | def check_gold_queries_valid(self, dev: bool = False): 37 | """ 38 | Check if the golden queries are valid, and can actually be executed on a given database without errors 39 | """ 40 | r = requests.post( 41 | f"{self.base_url}/check_gold_queries_valid", 42 | json={"api_key": self.api_key, "db_type": self.db_type, "dev": dev}, 43 | verify=False, 44 | ) 45 | resp = r.json() 46 | return resp 47 | 48 | 49 | def check_glossary_valid(self, dev: bool = False): 50 | """ 51 | Check if glossary is valid by verifying if all schema, table and column names referenced are present in the metadata. 52 | """ 53 | r = requests.post( 54 | f"{self.base_url}/check_glossary_valid", 55 | json={"api_key": self.api_key, "dev": dev}, 56 | verify=False, 57 | ) 58 | resp = r.json() 59 | return resp 60 | 61 | 62 | def check_glossary_consistency(self, dev: bool = False): 63 | """ 64 | Check if all logic in the glossary is consistent and coherent. 65 | """ 66 | r = requests.post( 67 | f"{self.base_url}/check_glossary_consistency", 68 | json={"api_key": self.api_key, "dev": dev}, 69 | verify=False, 70 | ) 71 | resp = r.json() 72 | return resp 73 | -------------------------------------------------------------------------------- /defog/llm/__init__.py: -------------------------------------------------------------------------------- 1 | """LLM module with memory management capabilities.""" 2 | 3 | from .utils import chat_async, LLMResponse 4 | from .utils_memory import ( 5 | chat_async_with_memory, 6 | create_memory_manager, 7 | MemoryConfig, 8 | ) 9 | from .memory import ( 10 | MemoryManager, 11 | ConversationHistory, 12 | compactify_messages, 13 | TokenCounter, 14 | ) 15 | 16 | __all__ = [ 17 | # Core functions 18 | "chat_async", 19 | "chat_async_with_memory", 20 | "LLMResponse", 21 | # Memory management 22 | "MemoryManager", 23 | "ConversationHistory", 24 | "MemoryConfig", 25 | "create_memory_manager", 26 | "compactify_messages", 27 | "TokenCounter", 28 | ] 29 | -------------------------------------------------------------------------------- /defog/llm/citations.py: -------------------------------------------------------------------------------- 1 | from defog.llm.llm_providers import LLMProvider 2 | from defog.llm.utils_logging import ToolProgressTracker, SubTaskLogger 3 | import os 4 | import asyncio 5 | 6 | 7 | async def upload_document_to_openai_vector_store(document, store_id): 8 | from openai import AsyncOpenAI 9 | 10 | client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 11 | 12 | file_name = document["document_name"] 13 | if not file_name.endswith(".txt"): 14 | file_name = file_name + ".txt" 15 | file_content = document["document_content"] 16 | if isinstance(file_content, str): 17 | # convert to bytes 18 | file_content = file_content.encode("utf-8") 19 | 20 | # first, upload the file to the vector store 21 | file = await client.files.create( 22 | file=(file_name, file_content), purpose="assistants" 23 | ) 24 | 25 | # then add it to the vector store 26 | await client.vector_stores.files.create( 27 | vector_store_id=store_id, 28 | file_id=file.id, 29 | ) 30 | 31 | 32 | async def citations_tool( 33 | question: str, 34 | instructions: str, 35 | documents: list[dict], 36 | model: str, 37 | provider: LLMProvider, 38 | 
max_tokens: int = 16000, 39 | ): 40 | """ 41 | Use this tool to get an answer to a well-cited answer to a question, 42 | given a list of documents. 43 | """ 44 | async with ToolProgressTracker( 45 | "Citations Tool", f"Generating citations for {len(documents)} documents" 46 | ) as tracker: 47 | subtask_logger = SubTaskLogger() 48 | subtask_logger.log_provider_info( 49 | provider.value if hasattr(provider, "value") else str(provider), model 50 | ) 51 | 52 | if provider in [LLMProvider.OPENAI, LLMProvider.OPENAI.value]: 53 | from openai import AsyncOpenAI 54 | 55 | client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 56 | 57 | # create an ephemeral vector store 58 | store = await client.vector_stores.create() 59 | store_id = store.id 60 | 61 | # Upload all documents in parallel 62 | tracker.update(10, "Uploading documents to vector store") 63 | subtask_logger.log_subtask("Starting document uploads", "processing") 64 | 65 | upload_tasks = [] 66 | for idx, document in enumerate(documents, 1): 67 | subtask_logger.log_document_upload( 68 | document["document_name"], idx, len(documents) 69 | ) 70 | upload_tasks.append( 71 | upload_document_to_openai_vector_store(document, store_id) 72 | ) 73 | 74 | await asyncio.gather(*upload_tasks) 75 | tracker.update(40, "Documents uploaded") 76 | 77 | # keep polling until the vector store is ready 78 | is_ready = False 79 | while not is_ready: 80 | store = await client.vector_stores.files.list(vector_store_id=store_id) 81 | total_completed = sum( 82 | 1 for file in store.data if file.status == "completed" 83 | ) 84 | is_ready = total_completed == len(documents) 85 | 86 | # Update progress based on indexing status 87 | progress = 40 + (total_completed / len(documents) * 40) # 40-80% range 88 | tracker.update( 89 | progress, f"Indexing {total_completed}/{len(documents)} files" 90 | ) 91 | subtask_logger.log_vector_store_status(total_completed, len(documents)) 92 | 93 | if not is_ready: 94 | await asyncio.sleep(1) 95 | 96 | # get the answer 97 | tracker.update(80, "Generating citations") 98 | subtask_logger.log_subtask("Querying with file search", "processing") 99 | 100 | response = await client.responses.create( 101 | model=model, 102 | input=question, 103 | tools=[ 104 | { 105 | "type": "file_search", 106 | "vector_store_ids": [store_id], 107 | } 108 | ], 109 | tool_choice="required", 110 | instructions=instructions, 111 | max_output_tokens=max_tokens, 112 | ) 113 | 114 | # convert the response to a list of blocks 115 | # similar to a subset of the Anthropic citations API 116 | blocks = [] 117 | for part in response.output: 118 | if part.type == "message": 119 | contents = part.content 120 | for item in contents: 121 | if item.type == "output_text": 122 | blocks.append( 123 | { 124 | "text": item.text, 125 | "type": "text", 126 | "citations": [ 127 | {"document_title": i.filename} 128 | for i in item.annotations 129 | ], 130 | } 131 | ) 132 | tracker.update(95, "Processing results") 133 | subtask_logger.log_result_summary( 134 | "Citations", 135 | {"blocks_generated": len(blocks), "documents_processed": len(documents)}, 136 | ) 137 | 138 | return blocks 139 | 140 | elif provider in [LLMProvider.ANTHROPIC, LLMProvider.ANTHROPIC.value]: 141 | from anthropic import AsyncAnthropic 142 | 143 | client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) 144 | 145 | document_contents = [] 146 | for document in documents: 147 | document_contents.append( 148 | { 149 | "type": "document", 150 | "source": { 151 | "type": "text", 152 | "media_type": "text/plain", 
153 | "data": document["document_content"], 154 | }, 155 | "title": document["document_name"], 156 | "citations": {"enabled": True}, 157 | } 158 | ) 159 | 160 | # Create content messages with citations enabled for individual tool calls 161 | tracker.update(50, "Preparing document contents") 162 | subtask_logger.log_subtask( 163 | f"Processing {len(documents)} documents for Anthropic", "processing" 164 | ) 165 | 166 | messages = [ 167 | { 168 | "role": "user", 169 | "content": [ 170 | {"type": "text", "text": question}, 171 | # Add all individual document contents 172 | *document_contents, 173 | ], 174 | } 175 | ] 176 | 177 | tracker.update(70, "Generating citations") 178 | subtask_logger.log_subtask("Calling Anthropic API with citations", "processing") 179 | 180 | response = await client.messages.create( 181 | model=model, 182 | messages=messages, 183 | system=instructions, 184 | max_tokens=max_tokens, 185 | ) 186 | 187 | tracker.update(90, "Processing results") 188 | response_with_citations = [item.to_dict() for item in response.content] 189 | 190 | subtask_logger.log_result_summary( 191 | "Citations", 192 | { 193 | "content_blocks": len(response_with_citations), 194 | "documents_processed": len(documents), 195 | }, 196 | ) 197 | 198 | return response_with_citations 199 | 200 | # This else is outside the context manager, move it inside 201 | raise ValueError(f"Provider {provider} not supported for citations tool") 202 | -------------------------------------------------------------------------------- /defog/llm/code_interp.py: -------------------------------------------------------------------------------- 1 | from defog.llm.llm_providers import LLMProvider 2 | from defog.llm.utils_logging import ToolProgressTracker, SubTaskLogger 3 | import os 4 | from io import BytesIO 5 | 6 | 7 | async def code_interpreter_tool( 8 | question: str, 9 | model: str, 10 | provider: LLMProvider, 11 | csv_string: str = "", 12 | instructions: str = "You are a Python programmer. You are given a question and a CSV string of data. You need to answer the question using the data. You are also given a sandboxed server environment where you can run the code.", 13 | ): 14 | """ 15 | Creates a python script to answer the question, where the python script is executed in a sandboxed server environment. 16 | """ 17 | async with ToolProgressTracker( 18 | "Code Interpreter", 19 | f"Executing code to answer: {question[:50]}{'...' 
if len(question) > 50 else ''}", 20 | ) as tracker: 21 | subtask_logger = SubTaskLogger() 22 | subtask_logger.log_provider_info( 23 | provider.value if hasattr(provider, "value") else str(provider), model 24 | ) 25 | 26 | # create a csv file from the csv_string 27 | tracker.update(10, "Preparing data file") 28 | subtask_logger.log_subtask("Creating CSV file from string", "processing") 29 | csv_file = BytesIO(csv_string.encode("utf-8")) 30 | csv_file.name = "data.csv" 31 | 32 | if provider in [LLMProvider.OPENAI, LLMProvider.OPENAI.value]: 33 | from openai import AsyncOpenAI 34 | from openai.types.responses import ( 35 | ResponseCodeInterpreterToolCall, 36 | ResponseOutputMessage, 37 | ) 38 | 39 | client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 40 | 41 | tracker.update(20, "Uploading data file") 42 | subtask_logger.log_subtask("Uploading CSV to OpenAI", "processing") 43 | file = await client.files.create( 44 | file=csv_file, 45 | purpose="user_data", 46 | ) 47 | 48 | tracker.update(40, "Running code interpreter") 49 | subtask_logger.log_code_execution("python") 50 | response = await client.responses.create( 51 | model=model, 52 | tools=[ 53 | { 54 | "type": "code_interpreter", 55 | "container": {"type": "auto", "file_ids": [file.id]}, 56 | } 57 | ], 58 | tool_choice="required", 59 | input=[ 60 | { 61 | "role": "user", 62 | "content": [{"type": "input_text", "text": question}], 63 | } 64 | ], 65 | instructions=instructions, 66 | ) 67 | tracker.update(80, "Processing results") 68 | subtask_logger.log_subtask("Extracting code and output", "processing") 69 | 70 | code = "" 71 | output_text = "" 72 | 73 | for chunk in response.output: 74 | if isinstance(chunk, ResponseCodeInterpreterToolCall): 75 | code += chunk.code 76 | elif isinstance(chunk, ResponseOutputMessage): 77 | for content in chunk.content: 78 | output_text += content.text 79 | 80 | subtask_logger.log_result_summary( 81 | "Code Execution", 82 | {"code_length": len(code), "output_length": len(output_text)}, 83 | ) 84 | 85 | return {"code": code, "output": output_text} 86 | elif provider in [LLMProvider.ANTHROPIC, LLMProvider.ANTHROPIC.value]: 87 | from anthropic import AsyncAnthropic 88 | 89 | client = AsyncAnthropic( 90 | api_key=os.getenv("ANTHROPIC_API_KEY"), 91 | default_headers={"anthropic-beta": "code-execution-2025-05-22"}, 92 | ) 93 | 94 | tracker.update(20, "Uploading data file") 95 | subtask_logger.log_subtask("Uploading CSV to Anthropic", "processing") 96 | file_object = await client.beta.files.upload( 97 | file=csv_file, 98 | ) 99 | 100 | tracker.update(40, "Running code execution") 101 | subtask_logger.log_code_execution("python") 102 | response = await client.messages.create( 103 | model=model, 104 | max_tokens=8192, 105 | messages=[ 106 | { 107 | "role": "user", 108 | "content": [ 109 | { 110 | "type": "text", 111 | "text": instructions 112 | + "\n\nThe question you must answer is: " 113 | + question, 114 | }, 115 | {"type": "container_upload", "file_id": file_object.id}, 116 | ], 117 | } 118 | ], 119 | tools=[{"type": "code_execution_20250522", "name": "code_execution"}], 120 | tool_choice={"type": "any"}, 121 | ) 122 | tracker.update(80, "Processing results") 123 | subtask_logger.log_subtask("Extracting code and output", "processing") 124 | 125 | code = "" 126 | output_text = "" 127 | for chunk in response.content: 128 | if chunk.type == "server_tool_use": 129 | code += chunk.input["code"] 130 | elif chunk.type == "text": 131 | output_text += chunk.text 132 | 133 | subtask_logger.log_result_summary( 134 | 
"Code Execution", 135 | {"code_length": len(code), "output_length": len(output_text)}, 136 | ) 137 | 138 | return {"code": code, "output": output_text} 139 | elif provider in [LLMProvider.GEMINI, LLMProvider.GEMINI.value]: 140 | from google import genai 141 | from google.genai import types 142 | 143 | client = genai.Client(api_key=os.getenv("GEMINI_API_KEY")) 144 | 145 | tracker.update(20, "Uploading data file") 146 | subtask_logger.log_subtask("Uploading CSV to Gemini", "processing") 147 | file_csv = await client.aio.files.upload( 148 | file=csv_file, 149 | config=types.UploadFileConfig( 150 | mime_type="text/csv", 151 | display_name="data.csv", 152 | ), 153 | ) 154 | 155 | tracker.update(40, "Running code execution") 156 | subtask_logger.log_code_execution("python") 157 | response = await client.aio.models.generate_content( 158 | model=model, 159 | contents=[file_csv, question], 160 | config=types.GenerateContentConfig( 161 | tools=[types.Tool(code_execution=types.ToolCodeExecution())], 162 | ), 163 | ) 164 | 165 | tracker.update(80, "Processing results") 166 | subtask_logger.log_subtask("Extracting code and output", "processing") 167 | 168 | parts = response.candidates[0].content.parts 169 | 170 | code = "" 171 | output_text = "" 172 | 173 | for part in parts: 174 | if ( 175 | hasattr(part, "executable_code") 176 | and part.executable_code is not None 177 | ): 178 | code += part.executable_code.code 179 | if hasattr(part, "text") and part.text is not None: 180 | output_text += part.text 181 | 182 | subtask_logger.log_result_summary( 183 | "Code Execution", 184 | {"code_length": len(code), "output_length": len(output_text)}, 185 | ) 186 | 187 | return {"code": code, "output": output_text} 188 | -------------------------------------------------------------------------------- /defog/llm/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .settings import LLMConfig 2 | from .constants import DEFAULT_TIMEOUT, MAX_RETRIES, DEFAULT_TEMPERATURE 3 | 4 | __all__ = ["LLMConfig", "DEFAULT_TIMEOUT", "MAX_RETRIES", "DEFAULT_TEMPERATURE"] 5 | -------------------------------------------------------------------------------- /defog/llm/config/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_TIMEOUT = 100 # seconds 2 | MAX_RETRIES = 3 3 | DEFAULT_TEMPERATURE = 0.0 4 | 5 | # Provider-specific constants 6 | DEEPSEEK_BASE_URL = "https://api.deepseek.com" 7 | OPENAI_BASE_URL = "https://api.openai.com/v1/" 8 | 9 | # Model families that require special handling 10 | O_MODELS = ["o1-mini", "o1-preview", "o1", "o3-mini", "o3", "o4-mini"] 11 | 12 | DEEPSEEK_MODELS = ["deepseek-chat", "deepseek-reasoner"] 13 | 14 | MODELS_WITHOUT_TEMPERATURE = ["deepseek-reasoner"] + [model for model in O_MODELS] 15 | 16 | MODELS_WITHOUT_RESPONSE_FORMAT = [ 17 | "o1-mini", 18 | "o1-preview", 19 | "deepseek-chat", 20 | "deepseek-reasoner", 21 | ] 22 | 23 | MODELS_WITHOUT_TOOLS = ["o1-mini", "o1-preview", "deepseek-chat", "deepseek-reasoner"] 24 | 25 | MODELS_WITH_PARALLEL_TOOL_CALLS = ["gpt-4o", "gpt-4o-mini"] 26 | 27 | MODELS_WITH_PREDICTION = ["gpt-4o", "gpt-4o-mini"] 28 | -------------------------------------------------------------------------------- /defog/llm/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Dict, Any 3 | from .constants import ( 4 | DEFAULT_TIMEOUT, 5 | MAX_RETRIES, 6 | DEFAULT_TEMPERATURE, 7 | 
DEEPSEEK_BASE_URL, 8 | OPENAI_BASE_URL, 9 | ) 10 | 11 | 12 | class LLMConfig: 13 | """Configuration management for LLM providers.""" 14 | 15 | def __init__( 16 | self, 17 | timeout: int = DEFAULT_TIMEOUT, 18 | max_retries: int = MAX_RETRIES, 19 | default_temperature: float = DEFAULT_TEMPERATURE, 20 | api_keys: Optional[Dict[str, str]] = None, 21 | base_urls: Optional[Dict[str, str]] = None, 22 | enable_parallel_tool_calls: bool = True, 23 | ): 24 | self.timeout = timeout 25 | self.max_retries = max_retries 26 | self.default_temperature = default_temperature 27 | self.enable_parallel_tool_calls = enable_parallel_tool_calls 28 | 29 | # API keys with environment fallbacks 30 | self.api_keys = api_keys or {} 31 | self._setup_api_keys() 32 | 33 | # Base URLs with defaults 34 | self.base_urls = base_urls or {} 35 | self._setup_base_urls() 36 | 37 | def _setup_api_keys(self): 38 | """Setup API keys with environment variable fallbacks.""" 39 | key_mappings = { 40 | "openai": "OPENAI_API_KEY", 41 | "anthropic": "ANTHROPIC_API_KEY", 42 | "gemini": "GEMINI_API_KEY", 43 | "deepseek": "DEEPSEEK_API_KEY", 44 | "together": "TOGETHER_API_KEY", 45 | } 46 | 47 | for provider, env_var in key_mappings.items(): 48 | if provider not in self.api_keys: 49 | self.api_keys[provider] = os.getenv(env_var) 50 | 51 | def _setup_base_urls(self): 52 | """Setup base URLs with defaults.""" 53 | default_urls = { 54 | "openai": OPENAI_BASE_URL, 55 | "deepseek": DEEPSEEK_BASE_URL, 56 | } 57 | 58 | for provider, url in default_urls.items(): 59 | if provider not in self.base_urls: 60 | self.base_urls[provider] = url 61 | 62 | def get_api_key(self, provider: str) -> Optional[str]: 63 | """Get API key for a provider.""" 64 | return self.api_keys.get(provider) 65 | 66 | def get_base_url(self, provider: str) -> Optional[str]: 67 | """Get base URL for a provider.""" 68 | return self.base_urls.get(provider) 69 | 70 | def validate_provider_config(self, provider: str) -> bool: 71 | """Validate that a provider has the required configuration.""" 72 | api_key = self.get_api_key(provider) 73 | return api_key is not None and api_key != "" 74 | 75 | def update_config(self, **kwargs): 76 | """Update configuration values.""" 77 | for key, value in kwargs.items(): 78 | if hasattr(self, key): 79 | setattr(self, key, value) 80 | -------------------------------------------------------------------------------- /defog/llm/cost/__init__.py: -------------------------------------------------------------------------------- 1 | from .calculator import CostCalculator 2 | from .models import MODEL_COSTS 3 | 4 | __all__ = ["CostCalculator", "MODEL_COSTS"] 5 | -------------------------------------------------------------------------------- /defog/llm/cost/calculator.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from .models import MODEL_COSTS 3 | 4 | 5 | class CostCalculator: 6 | """Handles cost calculation for LLM usage.""" 7 | 8 | @staticmethod 9 | def calculate_cost( 10 | model: str, 11 | input_tokens: int, 12 | output_tokens: int, 13 | cached_input_tokens: Optional[int] = None, 14 | ) -> Optional[float]: 15 | """ 16 | Calculate cost in cents for the given token usage. 
17 | 18 | Returns: 19 | Cost in cents, or None if model pricing is not available 20 | """ 21 | # Find exact match first 22 | if model in MODEL_COSTS: 23 | model_name = model 24 | else: 25 | # Attempt partial matches if no exact match 26 | potential_model_names = [] 27 | for mname in MODEL_COSTS.keys(): 28 | if mname in model: 29 | potential_model_names.append(mname) 30 | 31 | if not potential_model_names: 32 | return None 33 | 34 | # Use the longest match 35 | model_name = max(potential_model_names, key=len) 36 | 37 | costs = MODEL_COSTS[model_name] 38 | 39 | # Calculate base cost 40 | cost_in_cents = ( 41 | input_tokens / 1000 * costs["input_cost_per1k"] 42 | + output_tokens / 1000 * costs["output_cost_per1k"] 43 | ) * 100 44 | 45 | # Add cached input cost if available 46 | if cached_input_tokens and "cached_input_cost_per1k" in costs: 47 | cost_in_cents += ( 48 | cached_input_tokens / 1000 * costs["cached_input_cost_per1k"] 49 | ) * 100 50 | 51 | return cost_in_cents 52 | 53 | @staticmethod 54 | def is_model_supported(model: str) -> bool: 55 | """Check if cost calculation is supported for the given model.""" 56 | if model in MODEL_COSTS: 57 | return True 58 | 59 | # Check for partial matches 60 | return any(mname in model for mname in MODEL_COSTS.keys()) 61 | -------------------------------------------------------------------------------- /defog/llm/cost/models.py: -------------------------------------------------------------------------------- 1 | MODEL_COSTS = { 2 | "chatgpt-4o": {"input_cost_per1k": 0.0025, "output_cost_per1k": 0.01}, 3 | "gpt-4o": { 4 | "input_cost_per1k": 0.0025, 5 | "cached_input_cost_per1k": 0.00125, 6 | "output_cost_per1k": 0.01, 7 | }, 8 | "gpt-4o-mini": { 9 | "input_cost_per1k": 0.00015, 10 | "cached_input_cost_per1k": 0.000075, 11 | "output_cost_per1k": 0.0006, 12 | }, 13 | "gpt-4.1": { 14 | "input_cost_per1k": 0.002, 15 | "cached_input_cost_per1k": 0.0005, 16 | "output_cost_per1k": 0.008, 17 | }, 18 | "gpt-4.1-mini": { 19 | "input_cost_per1k": 0.0004, 20 | "cached_input_cost_per1k": 0.0001, 21 | "output_cost_per1k": 0.0016, 22 | }, 23 | "gpt-4.1-nano": { 24 | "input_cost_per1k": 0.0001, 25 | "cached_input_cost_per1k": 0.000025, 26 | "output_cost_per1k": 0.0004, 27 | }, 28 | "o1": { 29 | "input_cost_per1k": 0.015, 30 | "cached_input_cost_per1k": 0.0075, 31 | "output_cost_per1k": 0.06, 32 | }, 33 | "o1-preview": {"input_cost_per1k": 0.015, "output_cost_per1k": 0.06}, 34 | "o1-mini": { 35 | "input_cost_per1k": 0.003, 36 | "cached_input_cost_per1k": 0.00055, 37 | "output_cost_per1k": 0.012, 38 | }, 39 | "o3-mini": { 40 | "input_cost_per1k": 0.0011, 41 | "cached_input_cost_per1k": 0.00055, 42 | "output_cost_per1k": 0.0044, 43 | }, 44 | "o3": { 45 | "input_cost_per1k": 0.01, 46 | "cached_input_cost_per1k": 0.0025, 47 | "output_cost_per1k": 0.04, 48 | }, 49 | "o4-mini": { 50 | "input_cost_per1k": 0.0011, 51 | "cached_input_cost_per1k": 0.000275, 52 | "output_cost_per1k": 0.0044, 53 | }, 54 | "gpt-4-turbo": {"input_cost_per1k": 0.01, "output_cost_per1k": 0.03}, 55 | "gpt-3.5-turbo": {"input_cost_per1k": 0.0005, "output_cost_per1k": 0.0015}, 56 | "claude-3-5-sonnet": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.015}, 57 | "claude-3-5-haiku": {"input_cost_per1k": 0.00025, "output_cost_per1k": 0.00125}, 58 | "claude-3-opus": {"input_cost_per1k": 0.015, "output_cost_per1k": 0.075}, 59 | "claude-3-sonnet": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.015}, 60 | "claude-3-haiku": {"input_cost_per1k": 0.00025, "output_cost_per1k": 0.00125}, 61 | 
"gemini-1.5-pro": {"input_cost_per1k": 0.00125, "output_cost_per1k": 0.005}, 62 | "gemini-1.5-flash": {"input_cost_per1k": 0.000075, "output_cost_per1k": 0.0003}, 63 | "gemini-1.5-flash-8b": { 64 | "input_cost_per1k": 0.0000375, 65 | "output_cost_per1k": 0.00015, 66 | }, 67 | "gemini-2.0-flash": { 68 | "input_cost_per1k": 0.00010, 69 | "output_cost_per1k": 0.0004, 70 | }, 71 | "gemini-2.5-flash": { 72 | "input_cost_per1k": 0.00015, 73 | "output_cost_per1k": 0.0035, 74 | }, 75 | "gemini-2.5-pro": { 76 | "input_cost_per1k": 0.00125, 77 | "output_cost_per1k": 0.01, 78 | }, 79 | "deepseek-chat": { 80 | "input_cost_per1k": 0.00027, 81 | "cached_input_cost_per1k": 0.00007, 82 | "output_cost_per1k": 0.0011, 83 | }, 84 | "deepseek-reasoner": { 85 | "input_cost_per1k": 0.00055, 86 | "cached_input_cost_per1k": 0.00014, 87 | "output_cost_per1k": 0.00219, 88 | }, 89 | } 90 | -------------------------------------------------------------------------------- /defog/llm/exceptions/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm_exceptions import ( 2 | LLMError, 3 | ProviderError, 4 | ToolError, 5 | MaxTokensError, 6 | ConfigurationError, 7 | AuthenticationError, 8 | APIError, 9 | ) 10 | 11 | __all__ = [ 12 | "LLMError", 13 | "ProviderError", 14 | "ToolError", 15 | "MaxTokensError", 16 | "ConfigurationError", 17 | "AuthenticationError", 18 | "APIError", 19 | ] 20 | -------------------------------------------------------------------------------- /defog/llm/exceptions/llm_exceptions.py: -------------------------------------------------------------------------------- 1 | class LLMError(Exception): 2 | """Base exception for all LLM-related errors.""" 3 | 4 | pass 5 | 6 | 7 | class ProviderError(LLMError): 8 | """Exception raised when there's an error with a specific LLM provider.""" 9 | 10 | def __init__(self, provider: str, message: str, original_error: Exception = None): 11 | self.provider = provider 12 | self.original_error = original_error 13 | super().__init__(f"Provider '{provider}': {message}") 14 | 15 | 16 | class ToolError(LLMError): 17 | """Exception raised when there's an error with tool calling.""" 18 | 19 | def __init__(self, tool_name: str, message: str, original_error: Exception = None): 20 | self.tool_name = tool_name 21 | self.original_error = original_error 22 | super().__init__(f"Tool '{tool_name}': {message}") 23 | 24 | 25 | class MaxTokensError(LLMError): 26 | """Exception raised when maximum tokens are reached.""" 27 | 28 | pass 29 | 30 | 31 | class ConfigurationError(LLMError): 32 | """Exception raised when there's a configuration error.""" 33 | 34 | pass 35 | 36 | 37 | class AuthenticationError(LLMError): 38 | """Exception raised when there's an authentication error with a provider.""" 39 | 40 | pass 41 | 42 | 43 | class APIError(LLMError): 44 | """Exception raised when there's a general API error.""" 45 | 46 | pass 47 | -------------------------------------------------------------------------------- /defog/llm/llm_providers.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class LLMProvider(Enum): 5 | OPENAI = "openai" 6 | ANTHROPIC = "anthropic" 7 | GEMINI = "gemini" 8 | GROK = "grok" 9 | DEEPSEEK = "deepseek" 10 | TOGETHER = "together" 11 | -------------------------------------------------------------------------------- /defog/llm/mcp_readme.md: -------------------------------------------------------------------------------- 1 | # MCP Servers Guide 2 | 3 
| This guide outlines the steps for setting up and using Model Context Protocol (MCP) servers, which allow external tools to be used by an LLM. 4 | 5 | ## 1. Set Up Your MCP Servers 6 | Set up your MCP servers by following the official [SDKs](https://github.com/modelcontextprotocol). 7 | If you're hosting them outside of your application, ensure that your servers operate with [SSE transport](https://modelcontextprotocol.io/docs/concepts/transports#server-sent-events-sse). 8 | 9 | ```python 10 | from mcp.server.fastmcp import FastMCP 11 | 12 | # Initialize FastMCP with an available port 13 | mcp = FastMCP( 14 | name="name_of_remote_server", 15 | port=3001, 16 | ) 17 | 18 | # Add your tools and prompts 19 | # @mcp.tool() 20 | # async def example_tool(): 21 | # ... 22 | 23 | # Start the server with SSE transport 24 | if __name__ == "__main__": 25 | mcp.run(transport="sse") 26 | ``` 27 | 28 | ## 2. Configure mcp_config.json 29 | 30 | After your servers are running, create `mcp_config.json` and add the servers to connect to them: 31 | 32 | ```json 33 | { 34 | "mcpServers": { 35 | "name_of_remote_server": { 36 | "command": "sse", 37 | "args": ["http://host.docker.internal:3001"] 38 | }, 39 | "name_of_local_server": { 40 | "command": "npx", 41 | "args": ["-y", "@modelcontextprotocol/package-name"] 42 | } 43 | } 44 | } 45 | ``` 46 | 47 | ### Configuration Options: 48 | 49 | 1. **Remote Servers (SSE):** 50 | - Use `"command": "sse"` 51 | - Use `"args": ["url_to_server"]` for remote servers 52 | - Use `"args": ["http://host.docker.internal:PORT/sse"]` if your application is running in a Docker container and the servers are running on your local machine 53 | - Make sure the port matches your server's configuration 54 | 55 | 2. **Local Servers in your application (stdio):** 56 | - Specify the command to run the server (e.g., `"npx"`, `"python"`) 57 | - Provide arguments in an array (e.g., `["-y", "package-name"]`) 58 | - Optionally provide environment variables with `"env": {}` 59 | 60 | ## 3. Using MCPClient 61 | 62 | Once your servers are configured, use the `MCPClient` class to interact with them: 63 | 64 | ### Step 1: Initialize and Connect 65 | 66 | ```python 67 | from defog.llm.utils_mcp import MCPClient 68 | import json 69 | 70 | # Load config 71 | with open("path/to/mcp_config.json", "r") as f: 72 | config = json.load(f) 73 | 74 | # Initialize client with your preferred model. Only Claude and OpenAI models are supported. 75 | mcp_client = MCPClient(model_name="claude-3-7-sonnet-20250219") 76 | 77 | # Connect to all servers defined in config 78 | await mcp_client.connect_to_server_from_config(config) 79 | ``` 80 | 81 | ### Step 2: Process Queries 82 | 83 | ```python 84 | # Send a query to be processed by the LLM with access to MCP tools 85 | # The response contains the final text output 86 | # tool_outputs is a list of all tool calls made during processing 87 | response, tool_outputs = await mcp_client.mcp_chat(query="What is the average sale in the northeast region?") 88 | ``` 89 | 90 | ### Step 3: Clean Up 91 | 92 | ```python 93 | # Always clean up when done to close connections 94 | await mcp_client.cleanup() 95 | ``` 96 | 97 | ## 4. 
Using Prompt Templates 98 | 99 | Prompt templates that are defined in the MCP servers can be invoked in queries using the `/command` syntax: 100 | 101 | ```python 102 | # This will apply the "gen_report" prompt template to "sales by region" 103 | response, _ = await mcp_client.mcp_chat(query="/gen_report sales by region") 104 | ``` 105 | If no further text is provided after the command, the prompt template will be applied to the output of the previous message in the message history. 106 | 107 | ## 5. Troubleshooting 108 | 109 | If you encounter the "unhandled errors in a TaskGroup" error: 110 | 111 | 1. **Check Server Status**: Make sure your server is running on the specified port 112 | 2. **Verify Configuration**: Ensure port numbers match in both server and config 113 | 3. **Host Binding**: Server must use `host="0.0.0.0"` (not `127.0.0.1`) 114 | 4. **Network Access**: For Docker, ensure host.docker.internal resolves correctly 115 | 5. **Port Exposure**: Make sure the port is accessible from the Docker container 116 | 117 | If using stdio servers, check for syntax errors in your command or arguments. 118 | 119 | ## 6. MCPClient Reference 120 | 121 | Key methods in `MCPClient`: 122 | 123 | - `connect_to_server_from_config(config)`: Connect to all servers in config json file 124 | - `mcp_chat(query)`: Process a query using LLM and available tools 125 | - Calls tools in a loop until no more tools are called or max tokens are reached 126 | - Stores messages in message history so the LLM can use it as context for following queries 127 | - `call_tool(tool_name, tool_args)`: Directly call a specific tool 128 | - `get_prompt(prompt_name, args)`: Retrieve a specific prompt template 129 | - `cleanup()`: Close all server connections -------------------------------------------------------------------------------- /defog/llm/memory/__init__.py: -------------------------------------------------------------------------------- 1 | """Memory management utilities for LLM conversations.""" 2 | 3 | from .history_manager import MemoryManager, ConversationHistory 4 | from .compactifier import compactify_messages 5 | from .token_counter import TokenCounter 6 | 7 | __all__ = [ 8 | "MemoryManager", 9 | "ConversationHistory", 10 | "compactify_messages", 11 | "TokenCounter", 12 | ] 13 | -------------------------------------------------------------------------------- /defog/llm/memory/compactifier.py: -------------------------------------------------------------------------------- 1 | """Message compactification utilities for conversation memory management.""" 2 | 3 | from typing import List, Dict, Any, Optional, Tuple 4 | from ..utils import chat_async 5 | from .token_counter import TokenCounter 6 | 7 | 8 | async def compactify_messages( 9 | system_messages: List[Dict[str, Any]], 10 | messages_to_summarize: List[Dict[str, Any]], 11 | preserved_messages: List[Dict[str, Any]], 12 | provider: str, 13 | model: str, 14 | max_summary_tokens: int = 2000, 15 | **kwargs, 16 | ) -> Tuple[List[Dict[str, Any]], int]: 17 | """ 18 | Compactify a conversation by summarizing older messages while preserving system messages. 
19 | 20 | Args: 21 | system_messages: System messages to preserve at the beginning 22 | messages_to_summarize: Messages to be summarized 23 | preserved_messages: Recent messages to keep as-is 24 | provider: LLM provider to use for summarization 25 | model: Model to use for summarization 26 | max_summary_tokens: Maximum tokens for the summary 27 | **kwargs: Additional arguments for the LLM call 28 | 29 | Returns: 30 | Tuple of (new_messages, total_token_count) 31 | """ 32 | if not messages_to_summarize: 33 | # Nothing to summarize, return system + preserved messages 34 | all_messages = system_messages + preserved_messages 35 | token_counter = TokenCounter() 36 | total_tokens = token_counter.count_tokens(all_messages, model, provider) 37 | return all_messages, total_tokens 38 | 39 | # Create a summary prompt 40 | summary_prompt = _create_summary_prompt(messages_to_summarize, max_summary_tokens) 41 | 42 | # Generate summary using the same provider/model 43 | summary_response = await chat_async( 44 | provider=provider, 45 | model=model, 46 | messages=[{"role": "user", "content": summary_prompt}], 47 | temperature=0.3, # Lower temperature for more consistent summaries 48 | **kwargs, 49 | ) 50 | 51 | # Create summary message 52 | summary_message = { 53 | "role": "user", # Summary as user context 54 | "content": f"[Previous conversation summary]\n{summary_response.content}", 55 | } 56 | 57 | # Combine: system messages + summary + preserved messages 58 | new_messages = system_messages + [summary_message] + preserved_messages 59 | 60 | # Calculate new token count 61 | token_counter = TokenCounter() 62 | total_tokens = token_counter.count_tokens(new_messages, model, provider) 63 | 64 | return new_messages, total_tokens 65 | 66 | 67 | def _create_summary_prompt(messages: List[Dict[str, Any]], max_tokens: int) -> str: 68 | """Create a prompt for summarizing conversation history.""" 69 | 70 | # Format messages for summary 71 | formatted_messages = [] 72 | for msg in messages: 73 | role = msg.get("role", "unknown") 74 | content = msg.get("content", "") 75 | 76 | # Handle different content types 77 | if isinstance(content, dict): 78 | # Tool calls or structured content 79 | content = f"[Structured content: {type(content).__name__}]" 80 | elif isinstance(content, list): 81 | # Multiple content items 82 | content = f"[Multiple content items: {len(content)}]" 83 | 84 | formatted_messages.append(f"{role.upper()}: {content}") 85 | 86 | conversation_text = "\n\n".join(formatted_messages) 87 | 88 | prompt = f"""Please provide a concise summary of the following conversation. 89 | Focus on: 90 | 1. Key topics discussed 91 | 2. Important decisions or conclusions reached 92 | 3. Any unresolved questions or ongoing tasks 93 | 4. Critical context that should be preserved 94 | 95 | Keep the summary under {max_tokens // 4} words (approximately {max_tokens} tokens). 96 | 97 | Conversation: 98 | {conversation_text} 99 | 100 | Summary:""" 101 | 102 | return prompt 103 | 104 | 105 | async def smart_compactify( 106 | memory_manager, provider: str, model: str, **kwargs 107 | ) -> Tuple[List[Dict[str, Any]], int]: 108 | """ 109 | Intelligently compactify messages using the memory manager. 110 | 111 | This is a convenience function that works with a MemoryManager instance. 
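If the manager's token threshold has not yet been reached, the current messages and the existing token count are returned unchanged.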
112 | 113 | Args: 114 | memory_manager: MemoryManager instance 115 | provider: LLM provider 116 | model: Model name 117 | **kwargs: Additional LLM arguments 118 | 119 | Returns: 120 | Tuple of (new_messages, total_token_count) 121 | """ 122 | if not memory_manager.should_compactify(): 123 | return ( 124 | memory_manager.get_current_messages(), 125 | memory_manager.history.total_tokens, 126 | ) 127 | 128 | # Get messages to summarize and preserve 129 | system_messages, messages_to_summarize, preserved_messages = ( 130 | memory_manager.get_messages_for_compactification() 131 | ) 132 | 133 | # Perform compactification 134 | new_messages, new_token_count = await compactify_messages( 135 | system_messages=system_messages, 136 | messages_to_summarize=messages_to_summarize, 137 | preserved_messages=preserved_messages, 138 | provider=provider, 139 | model=model, 140 | max_summary_tokens=memory_manager.summary_max_tokens, 141 | **kwargs, 142 | ) 143 | 144 | # Update memory manager 145 | if new_messages and len(new_messages) > len(system_messages): 146 | # Find the summary message (first non-system message in the result) 147 | summary_idx = len(system_messages) 148 | summary_message = new_messages[summary_idx] 149 | memory_manager.update_after_compactification( 150 | system_messages=system_messages, 151 | summary_message=summary_message, 152 | preserved_messages=preserved_messages, 153 | new_token_count=new_token_count, 154 | ) 155 | 156 | return new_messages, new_token_count 157 | -------------------------------------------------------------------------------- /defog/llm/memory/history_manager.py: -------------------------------------------------------------------------------- 1 | """Manages conversation history and memory for LLM interactions.""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import List, Dict, Optional, Any 5 | from datetime import datetime 6 | import copy 7 | 8 | 9 | @dataclass 10 | class ConversationHistory: 11 | """Container for conversation messages with metadata.""" 12 | 13 | messages: List[Dict[str, Any]] = field(default_factory=list) 14 | total_tokens: int = 0 15 | created_at: datetime = field(default_factory=datetime.now) 16 | last_compactified_at: Optional[datetime] = None 17 | compactification_count: int = 0 18 | 19 | def add_message(self, message: Dict[str, Any], tokens: int = 0) -> None: 20 | """Add a message to the history.""" 21 | self.messages.append(copy.deepcopy(message)) 22 | self.total_tokens += tokens 23 | 24 | def get_messages(self) -> List[Dict[str, Any]]: 25 | """Get a copy of all messages.""" 26 | return copy.deepcopy(self.messages) 27 | 28 | def clear(self) -> None: 29 | """Clear all messages and reset token count.""" 30 | self.messages.clear() 31 | self.total_tokens = 0 32 | 33 | def replace_messages( 34 | self, new_messages: List[Dict[str, Any]], new_token_count: int 35 | ) -> None: 36 | """Replace all messages with new ones.""" 37 | self.messages = copy.deepcopy(new_messages) 38 | self.total_tokens = new_token_count 39 | self.last_compactified_at = datetime.now() 40 | self.compactification_count += 1 41 | 42 | 43 | class MemoryManager: 44 | """Manages conversation memory with automatic compactification.""" 45 | 46 | def __init__( 47 | self, 48 | token_threshold: int = 50000, # Default to ~100k tokens before compactifying 49 | preserve_last_n_messages: int = 10, 50 | summary_max_tokens: int = 4000, 51 | enabled: bool = True, 52 | ): 53 | """ 54 | Initialize the memory manager. 
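Compactification is triggered once the tracked conversation reaches token_threshold tokens, while the most recent preserve_last_n_messages non-system messages are always kept verbatim.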
55 | 56 | Args: 57 | token_threshold: Number of tokens before triggering compactification 58 | preserve_last_n_messages: Number of recent messages to always preserve 59 | summary_max_tokens: Maximum tokens for the summary 60 | enabled: Whether memory management is enabled 61 | """ 62 | self.token_threshold = token_threshold 63 | self.preserve_last_n_messages = preserve_last_n_messages 64 | self.summary_max_tokens = summary_max_tokens 65 | self.enabled = enabled 66 | self.history = ConversationHistory() 67 | 68 | def should_compactify(self) -> bool: 69 | """Check if memory should be compactified based on token count.""" 70 | if not self.enabled: 71 | return False 72 | return self.history.total_tokens >= self.token_threshold 73 | 74 | def add_messages(self, messages: List[Dict[str, Any]], tokens: int) -> None: 75 | """Add multiple messages to history.""" 76 | for message in messages: 77 | self.history.add_message( 78 | message, tokens // len(messages) if messages else 0 79 | ) 80 | 81 | def get_messages_for_compactification( 82 | self, 83 | ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: 84 | """ 85 | Split messages into system messages, messages to summarize, and messages to preserve. 86 | 87 | System messages are always preserved at the beginning. 88 | Only user/assistant messages are eligible for summarization. 89 | 90 | Returns: 91 | Tuple of (system_messages, messages_to_summarize, messages_to_preserve) 92 | """ 93 | all_messages = self.history.get_messages() 94 | 95 | # Separate system messages from others 96 | system_messages = [] 97 | other_messages = [] 98 | 99 | for msg in all_messages: 100 | if msg.get("role") == "system": 101 | system_messages.append(msg) 102 | else: 103 | other_messages.append(msg) 104 | 105 | # Apply preservation logic only to non-system messages 106 | if len(other_messages) <= self.preserve_last_n_messages: 107 | return system_messages, [], other_messages 108 | 109 | split_index = len(other_messages) - self.preserve_last_n_messages 110 | return ( 111 | system_messages, 112 | other_messages[:split_index], 113 | other_messages[split_index:], 114 | ) 115 | 116 | def update_after_compactification( 117 | self, 118 | system_messages: List[Dict[str, Any]], 119 | summary_message: Dict[str, Any], 120 | preserved_messages: List[Dict[str, Any]], 121 | new_token_count: int, 122 | ) -> None: 123 | """Update history after compactification.""" 124 | # Combine: system messages + summary + preserved messages 125 | new_messages = system_messages + [summary_message] + preserved_messages 126 | self.history.replace_messages(new_messages, new_token_count) 127 | 128 | def get_current_messages(self) -> List[Dict[str, Any]]: 129 | """Get current conversation messages.""" 130 | return self.history.get_messages() 131 | 132 | def get_stats(self) -> Dict[str, Any]: 133 | """Get memory statistics.""" 134 | return { 135 | "total_tokens": self.history.total_tokens, 136 | "message_count": len(self.history.messages), 137 | "compactification_count": self.history.compactification_count, 138 | "last_compactified_at": ( 139 | self.history.last_compactified_at.isoformat() 140 | if self.history.last_compactified_at 141 | else None 142 | ), 143 | "enabled": self.enabled, 144 | "token_threshold": self.token_threshold, 145 | } 146 | -------------------------------------------------------------------------------- /defog/llm/memory/token_counter.py: -------------------------------------------------------------------------------- 1 | """Token counting utilities for different 
LLM providers.""" 2 | 3 | from typing import List, Dict, Any, Optional, Union 4 | import tiktoken 5 | from functools import lru_cache 6 | import json 7 | 8 | 9 | class TokenCounter: 10 | """ 11 | Accurate token counting for different LLM providers. 12 | 13 | Uses: 14 | - tiktoken for OpenAI models (and as fallback for all other providers) 15 | - API endpoints for Anthropic/Gemini when client is provided 16 | """ 17 | 18 | def __init__(self): 19 | self._encoding_cache = {} 20 | 21 | @lru_cache(maxsize=10) 22 | def _get_openai_encoding(self, model: str): 23 | """Get and cache tiktoken encoding for OpenAI models.""" 24 | try: 25 | # Try to get encoding for specific model 26 | return tiktoken.encoding_for_model(model) 27 | except KeyError: 28 | # Default encodings for different model families 29 | if "gpt-4o" in model: 30 | return tiktoken.get_encoding("o200k_base") 31 | else: 32 | # Default for gpt-4, gpt-3.5-turbo, and others 33 | return tiktoken.get_encoding("cl100k_base") 34 | 35 | def count_openai_tokens( 36 | self, messages: Union[str, List[Dict[str, Any]]], model: str = "gpt-4" 37 | ) -> int: 38 | """ 39 | Count tokens for OpenAI models using tiktoken. 40 | 41 | Args: 42 | messages: Either a string or list of message dicts 43 | model: OpenAI model name 44 | 45 | Returns: 46 | Token count 47 | """ 48 | encoding = self._get_openai_encoding(model) 49 | 50 | if isinstance(messages, str): 51 | return len(encoding.encode(messages)) 52 | 53 | # Handle message list format 54 | # Based on OpenAI's token counting guide 55 | tokens_per_message = ( 56 | 3 # Every message follows <|im_start|>{role}\n{content}<|im_end|>\n 57 | ) 58 | tokens_per_name = 1 # If there's a name, the role is omitted 59 | 60 | total_tokens = 0 61 | for message in messages: 62 | total_tokens += tokens_per_message 63 | 64 | for key, value in message.items(): 65 | if key == "role": 66 | total_tokens += len(encoding.encode(value)) 67 | elif key == "content": 68 | if isinstance(value, str): 69 | total_tokens += len(encoding.encode(value)) 70 | else: 71 | # Handle tool calls or other structured content 72 | total_tokens += len(encoding.encode(json.dumps(value))) 73 | elif key == "name": 74 | total_tokens += tokens_per_name 75 | total_tokens += len(encoding.encode(value)) 76 | 77 | total_tokens += 3 # Every reply is primed with <|im_start|>assistant<|im_sep|> 78 | return total_tokens 79 | 80 | async def count_anthropic_tokens( 81 | self, messages: List[Dict[str, Any]], model: str, client: Optional[Any] = None 82 | ) -> int: 83 | """ 84 | Count tokens for Anthropic models using their API. 85 | 86 | Args: 87 | messages: List of message dicts 88 | model: Anthropic model name 89 | client: Optional Anthropic client instance 90 | 91 | Returns: 92 | Token count 93 | """ 94 | if client is None: 95 | # Use OpenAI tokenizer as approximation 96 | return self.count_openai_tokens(messages, "gpt-4") 97 | 98 | try: 99 | # Use Anthropic's token counting endpoint 100 | response = await client.messages.count_tokens( 101 | model=model, messages=messages 102 | ) 103 | return response.input_tokens 104 | except Exception: 105 | # Fallback to OpenAI tokenizer 106 | return self.count_openai_tokens(messages, "gpt-4") 107 | 108 | def count_gemini_tokens( 109 | self, 110 | content: Union[str, List[Dict[str, Any]]], 111 | model: str, 112 | client: Optional[Any] = None, 113 | ) -> int: 114 | """ 115 | Count tokens for Gemini models. 
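When a Gemini client is provided, its count_tokens method is used; otherwise, or if that call fails, the OpenAI tokenizer is used as an approximation.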
116 | 117 | Args: 118 | content: Text or message list 119 | model: Gemini model name 120 | client: Optional Gemini client instance 121 | 122 | Returns: 123 | Token count 124 | """ 125 | if client is None: 126 | # Use OpenAI tokenizer as approximation 127 | return self.count_openai_tokens(content, "gpt-4") 128 | 129 | try: 130 | # Extract text content 131 | text = self._extract_text(content) 132 | 133 | # Use Gemini's count_tokens method 134 | response = client.count_tokens(text) 135 | return response.total_tokens 136 | except Exception: 137 | # Fallback to OpenAI tokenizer 138 | return self.count_openai_tokens(content, "gpt-4") 139 | 140 | def count_together_tokens( 141 | self, messages: Union[str, List[Dict[str, Any]]], model: str 142 | ) -> int: 143 | """ 144 | Count tokens for Together models using OpenAI tokenizer as approximation. 145 | 146 | Args: 147 | messages: Text or message list 148 | model: Together model name 149 | 150 | Returns: 151 | Estimated token count 152 | """ 153 | # Use OpenAI tokenizer as approximation 154 | return self.count_openai_tokens(messages, "gpt-4") 155 | 156 | def count_tokens( 157 | self, 158 | messages: Union[str, List[Dict[str, Any]]], 159 | model: str, 160 | provider: str, 161 | client: Optional[Any] = None, 162 | ) -> int: 163 | """ 164 | Universal token counting method. 165 | 166 | Args: 167 | messages: Text or message list 168 | model: Model name 169 | provider: Provider name (openai, anthropic, gemini, together) 170 | client: Optional provider client for API-based counting 171 | 172 | Returns: 173 | Token count 174 | """ 175 | provider_lower = provider.lower() 176 | 177 | if provider_lower == "openai": 178 | return self.count_openai_tokens(messages, model) 179 | elif provider_lower == "anthropic": 180 | # Anthropic count_tokens is async, so for sync context use OpenAI approximation 181 | if ( 182 | client 183 | and hasattr(client, "messages") 184 | and hasattr(client.messages, "count_tokens") 185 | ): 186 | # This would need to be called in an async context 187 | return self.count_openai_tokens(messages, "gpt-4") 188 | return self.count_openai_tokens(messages, "gpt-4") 189 | elif provider_lower == "gemini": 190 | return self.count_gemini_tokens(messages, model, client) 191 | elif provider_lower == "together": 192 | return self.count_together_tokens(messages, model) 193 | else: 194 | # Default to OpenAI tokenizer 195 | return self.count_openai_tokens(messages, "gpt-4") 196 | 197 | def _extract_text(self, content: Union[str, List[Dict[str, Any]]]) -> str: 198 | """Extract text from various content formats.""" 199 | if isinstance(content, str): 200 | return content 201 | 202 | if isinstance(content, list): 203 | texts = [] 204 | for item in content: 205 | if isinstance(item, dict): 206 | if "content" in item: 207 | texts.append(str(item["content"])) 208 | elif "text" in item: 209 | texts.append(str(item["text"])) 210 | else: 211 | texts.append(str(item)) 212 | return " ".join(texts) 213 | 214 | return str(content) 215 | 216 | def estimate_remaining_tokens( 217 | self, 218 | messages: Union[str, List[Dict[str, Any]]], 219 | model: str, 220 | provider: str, 221 | max_context_tokens: int = 128000, 222 | response_buffer: int = 4000, 223 | client: Optional[Any] = None, 224 | ) -> int: 225 | """ 226 | Estimate remaining tokens in context window. 
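The tokens already used by the messages and the response_buffer are subtracted from max_context_tokens, and the result is clamped at zero.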
227 | 228 | Args: 229 | messages: Current messages 230 | model: Model name 231 | provider: Provider name 232 | max_context_tokens: Maximum context window size 233 | response_buffer: Tokens to reserve for response 234 | client: Optional provider client 235 | 236 | Returns: 237 | Estimated remaining tokens 238 | """ 239 | used_tokens = self.count_tokens(messages, model, provider, client) 240 | return max(0, max_context_tokens - used_tokens - response_buffer) 241 | -------------------------------------------------------------------------------- /defog/llm/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from enum import Enum 3 | from typing import Optional, Union, Dict, Any, Literal 4 | 5 | 6 | class OpenAIFunctionSpecs(BaseModel): 7 | name: str # name of the function to call 8 | description: Optional[str] = None # description of the function 9 | parameters: Optional[Union[str, Dict[str, Any]]] = ( 10 | None # parameters of the function 11 | ) 12 | 13 | 14 | class AnthropicFunctionSpecs(BaseModel): 15 | name: str # name of the function to call 16 | description: Optional[str] = None # description of the function 17 | input_schema: Optional[Union[str, Dict[str, Any]]] = ( 18 | None # parameters of the function 19 | ) 20 | -------------------------------------------------------------------------------- /defog/llm/providers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLLMProvider, LLMResponse 2 | from .anthropic_provider import AnthropicProvider 3 | from .openai_provider import OpenAIProvider 4 | from .gemini_provider import GeminiProvider 5 | from .together_provider import TogetherProvider 6 | 7 | __all__ = [ 8 | "BaseLLMProvider", 9 | "LLMResponse", 10 | "AnthropicProvider", 11 | "OpenAIProvider", 12 | "GeminiProvider", 13 | "TogetherProvider", 14 | ] 15 | -------------------------------------------------------------------------------- /defog/llm/providers/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Any, Optional, Callable, Tuple 3 | from dataclasses import dataclass 4 | from ..config.settings import LLMConfig 5 | 6 | 7 | @dataclass 8 | class LLMResponse: 9 | content: Any 10 | model: str 11 | time: float 12 | input_tokens: int 13 | output_tokens: int 14 | cached_input_tokens: Optional[int] = None 15 | output_tokens_details: Optional[Dict[str, int]] = None 16 | cost_in_cents: Optional[float] = None 17 | tool_outputs: Optional[List[Dict[str, Any]]] = None 18 | 19 | 20 | class BaseLLMProvider(ABC): 21 | """Abstract base class for all LLM providers.""" 22 | 23 | def __init__( 24 | self, 25 | api_key: Optional[str] = None, 26 | base_url: Optional[str] = None, 27 | config: Optional[LLMConfig] = None, 28 | ): 29 | self.api_key = api_key 30 | self.base_url = base_url 31 | self.config = config or LLMConfig() 32 | 33 | @abstractmethod 34 | def get_provider_name(self) -> str: 35 | """Return the name of the provider.""" 36 | pass 37 | 38 | @abstractmethod 39 | def build_params( 40 | self, 41 | messages: List[Dict[str, str]], 42 | model: str, 43 | max_completion_tokens: Optional[int] = None, 44 | temperature: float = 0.0, 45 | response_format=None, 46 | seed: int = 0, 47 | tools: Optional[List[Callable]] = None, 48 | tool_choice: Optional[str] = None, 49 | store: bool = True, 50 | metadata: Optional[Dict[str, str]] = None, 51 | 
timeout: int = 100, 52 | prediction: Optional[Dict[str, str]] = None, 53 | reasoning_effort: Optional[str] = None, 54 | mcp_servers: Optional[List[Dict[str, Any]]] = None, 55 | **kwargs 56 | ) -> Tuple[Dict[str, Any], List[Dict[str, str]]]: 57 | """Build parameters for the provider's API call.""" 58 | pass 59 | 60 | @abstractmethod 61 | async def process_response( 62 | self, 63 | client, 64 | response, 65 | request_params: Dict[str, Any], 66 | tools: Optional[List[Callable]], 67 | tool_dict: Dict[str, Callable], 68 | response_format=None, 69 | post_tool_function: Optional[Callable] = None, 70 | **kwargs 71 | ) -> Tuple[ 72 | Any, List[Dict[str, Any]], int, int, Optional[int], Optional[Dict[str, int]] 73 | ]: 74 | """Process the response from the provider.""" 75 | pass 76 | 77 | @abstractmethod 78 | async def execute_chat( 79 | self, 80 | messages: List[Dict[str, str]], 81 | model: str, 82 | max_completion_tokens: Optional[int] = None, 83 | temperature: float = 0.0, 84 | response_format=None, 85 | seed: int = 0, 86 | tools: Optional[List[Callable]] = None, 87 | tool_choice: Optional[str] = None, 88 | store: bool = True, 89 | metadata: Optional[Dict[str, str]] = None, 90 | timeout: int = 100, 91 | prediction: Optional[Dict[str, str]] = None, 92 | reasoning_effort: Optional[str] = None, 93 | post_tool_function: Optional[Callable] = None, 94 | mcp_servers: Optional[List[Dict[str, Any]]] = None, 95 | **kwargs 96 | ) -> LLMResponse: 97 | """Execute a chat completion with the provider.""" 98 | pass 99 | 100 | @abstractmethod 101 | def supports_tools(self, model: str) -> bool: 102 | """Check if the model supports tool calling.""" 103 | pass 104 | 105 | @abstractmethod 106 | def supports_response_format(self, model: str) -> bool: 107 | """Check if the model supports structured response formats.""" 108 | pass 109 | -------------------------------------------------------------------------------- /defog/llm/providers/together_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import Dict, List, Any, Optional, Callable, Tuple 4 | 5 | from .base import BaseLLMProvider, LLMResponse 6 | from ..exceptions import ProviderError, MaxTokensError 7 | from ..cost import CostCalculator 8 | 9 | 10 | class TogetherProvider(BaseLLMProvider): 11 | """Together AI provider implementation.""" 12 | 13 | def __init__(self, api_key: Optional[str] = None, config=None): 14 | super().__init__(api_key or os.getenv("TOGETHER_API_KEY"), config=config) 15 | 16 | def get_provider_name(self) -> str: 17 | return "together" 18 | 19 | def supports_tools(self, model: str) -> bool: 20 | return ( 21 | False # Currently Together models don't support tools in our implementation 22 | ) 23 | 24 | def supports_response_format(self, model: str) -> bool: 25 | return False # Currently Together models don't support structured output in our implementation 26 | 27 | def build_params( 28 | self, 29 | messages: List[Dict[str, str]], 30 | model: str, 31 | max_completion_tokens: Optional[int] = None, 32 | temperature: float = 0.0, 33 | response_format=None, 34 | seed: int = 0, 35 | tools: Optional[List[Callable]] = None, 36 | tool_choice: Optional[str] = None, 37 | store: bool = True, 38 | metadata: Optional[Dict[str, str]] = None, 39 | timeout: int = 100, 40 | prediction: Optional[Dict[str, str]] = None, 41 | reasoning_effort: Optional[str] = None, 42 | mcp_servers: Optional[List[Dict[str, Any]]] = None, 43 | **kwargs, 44 | ) -> Tuple[Dict[str, Any], 
List[Dict[str, str]]]: 45 | """Build parameters for Together's API call.""" 46 | return { 47 | "messages": messages, 48 | "model": model, 49 | "max_tokens": max_completion_tokens, 50 | "temperature": temperature, 51 | "seed": seed, 52 | }, messages 53 | 54 | async def process_response( 55 | self, 56 | client, 57 | response, 58 | request_params: Dict[str, Any], 59 | tools: Optional[List[Callable]], 60 | tool_dict: Dict[str, Callable], 61 | response_format=None, 62 | post_tool_function: Optional[Callable] = None, 63 | **kwargs, 64 | ) -> Tuple[ 65 | Any, List[Dict[str, Any]], int, int, Optional[int], Optional[Dict[str, int]] 66 | ]: 67 | """Process Together API response.""" 68 | if response.choices[0].finish_reason == "length": 69 | raise MaxTokensError("Max tokens reached") 70 | if len(response.choices) == 0: 71 | raise MaxTokensError("Max tokens reached") 72 | 73 | content = response.choices[0].message.content 74 | input_tokens = response.usage.prompt_tokens 75 | output_tokens = response.usage.completion_tokens 76 | 77 | return content, [], input_tokens, output_tokens, None, None 78 | 79 | async def execute_chat( 80 | self, 81 | messages: List[Dict[str, str]], 82 | model: str, 83 | max_completion_tokens: Optional[int] = None, 84 | temperature: float = 0.0, 85 | response_format=None, 86 | seed: int = 0, 87 | tools: Optional[List[Callable]] = None, 88 | tool_choice: Optional[str] = None, 89 | store: bool = True, 90 | metadata: Optional[Dict[str, str]] = None, 91 | timeout: int = 100, 92 | prediction: Optional[Dict[str, str]] = None, 93 | reasoning_effort: Optional[str] = None, 94 | post_tool_function: Optional[Callable] = None, 95 | mcp_servers: Optional[List[Dict[str, Any]]] = None, 96 | **kwargs, 97 | ) -> LLMResponse: 98 | """Execute a chat completion with Together.""" 99 | from together import AsyncTogether 100 | 101 | t = time.time() 102 | client_together = AsyncTogether(timeout=timeout) 103 | params, _ = self.build_params( 104 | messages=messages, 105 | model=model, 106 | max_completion_tokens=max_completion_tokens, 107 | temperature=temperature, 108 | seed=seed, 109 | ) 110 | 111 | try: 112 | response = await client_together.chat.completions.create(**params) 113 | ( 114 | content, 115 | tool_outputs, 116 | input_toks, 117 | output_toks, 118 | cached_toks, 119 | output_details, 120 | ) = await self.process_response( 121 | client=client_together, 122 | response=response, 123 | request_params=params, 124 | tools=tools, 125 | tool_dict={}, 126 | response_format=response_format, 127 | post_tool_function=post_tool_function, 128 | ) 129 | except Exception as e: 130 | raise ProviderError(self.get_provider_name(), f"API call failed: {e}", e) 131 | 132 | # Calculate cost 133 | cost = CostCalculator.calculate_cost( 134 | model, input_toks, output_toks, cached_toks 135 | ) 136 | 137 | return LLMResponse( 138 | model=model, 139 | content=content, 140 | time=round(time.time() - t, 3), 141 | input_tokens=input_toks, 142 | output_tokens=output_toks, 143 | cached_input_tokens=cached_toks, 144 | output_tokens_details=output_details, 145 | cost_in_cents=cost, 146 | tool_outputs=tool_outputs, 147 | ) 148 | -------------------------------------------------------------------------------- /defog/llm/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .handler import ToolHandler 2 | 3 | __all__ = ["ToolHandler"] 4 | -------------------------------------------------------------------------------- /defog/llm/tools/handler.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | import asyncio 3 | from typing import Dict, List, Callable, Any, Optional 4 | from ..exceptions import ToolError 5 | from ..utils_function_calling import ( 6 | execute_tool, 7 | execute_tool_async, 8 | execute_tools_parallel, 9 | verify_post_tool_function, 10 | ) 11 | 12 | 13 | class ToolHandler: 14 | """Handles tool calling logic for LLM providers.""" 15 | 16 | def __init__(self, max_consecutive_errors: int = 3): 17 | self.max_consecutive_errors = max_consecutive_errors 18 | 19 | async def execute_tool_call( 20 | self, 21 | tool_name: str, 22 | args: Dict[str, Any], 23 | tool_dict: Dict[str, Callable], 24 | post_tool_function: Optional[Callable] = None, 25 | ) -> Any: 26 | """Execute a single tool call.""" 27 | tool_to_call = tool_dict.get(tool_name) 28 | if tool_to_call is None: 29 | raise ToolError(tool_name, "Tool not found") 30 | 31 | try: 32 | # Execute tool depending on whether it is async 33 | if inspect.iscoroutinefunction(tool_to_call): 34 | result = await execute_tool_async(tool_to_call, args) 35 | else: 36 | result = execute_tool(tool_to_call, args) 37 | except Exception as e: 38 | raise ToolError(tool_name, f"Error executing tool: {e}", e) 39 | 40 | # Execute post-tool function if provided 41 | if post_tool_function: 42 | try: 43 | if inspect.iscoroutinefunction(post_tool_function): 44 | await post_tool_function( 45 | function_name=tool_name, 46 | input_args=args, 47 | tool_result=result, 48 | ) 49 | else: 50 | post_tool_function( 51 | function_name=tool_name, 52 | input_args=args, 53 | tool_result=result, 54 | ) 55 | except Exception as e: 56 | raise ToolError( 57 | tool_name, f"Error executing post_tool_function: {e}", e 58 | ) 59 | 60 | return result 61 | 62 | async def execute_tool_calls_batch( 63 | self, 64 | tool_calls: List[Dict[str, Any]], 65 | tool_dict: Dict[str, Callable], 66 | enable_parallel: bool = False, 67 | post_tool_function: Optional[Callable] = None, 68 | ) -> List[Any]: 69 | """Execute multiple tool calls either in parallel or sequentially.""" 70 | try: 71 | results = await execute_tools_parallel( 72 | tool_calls, tool_dict, enable_parallel 73 | ) 74 | 75 | # Execute post-tool function for each result if provided 76 | if post_tool_function: 77 | for i, (tool_call, result) in enumerate(zip(tool_calls, results)): 78 | func_name = tool_call.get("function", {}).get( 79 | "name" 80 | ) or tool_call.get("name") 81 | func_args = tool_call.get("function", {}).get( 82 | "arguments" 83 | ) or tool_call.get("arguments", {}) 84 | 85 | try: 86 | if inspect.iscoroutinefunction(post_tool_function): 87 | await post_tool_function( 88 | function_name=func_name, 89 | input_args=func_args, 90 | tool_result=result, 91 | ) 92 | else: 93 | post_tool_function( 94 | function_name=func_name, 95 | input_args=func_args, 96 | tool_result=result, 97 | ) 98 | except Exception as e: 99 | # Don't fail the entire batch for post-tool function errors 100 | print( 101 | f"Warning: Error executing post_tool_function for {func_name}: {e}" 102 | ) 103 | 104 | return results 105 | except Exception as e: 106 | raise ToolError("batch", f"Error executing tool batch: {e}", e) 107 | 108 | def build_tool_dict(self, tools: List[Callable]) -> Dict[str, Callable]: 109 | """Build a dictionary mapping tool names to functions.""" 110 | return {tool.__name__: tool for tool in tools} 111 | 112 | def validate_post_tool_function( 113 | self, post_tool_function: Optional[Callable] 114 | ) -> None: 115 | 
"""Validate the post-tool function signature.""" 116 | if post_tool_function: 117 | verify_post_tool_function(post_tool_function) 118 | -------------------------------------------------------------------------------- /defog/llm/utils_memory.py: -------------------------------------------------------------------------------- 1 | """Chat utilities with memory management support.""" 2 | 3 | from typing import Dict, List, Optional, Any, Union, Callable 4 | from dataclasses import dataclass 5 | 6 | from .utils import chat_async 7 | from .llm_providers import LLMProvider 8 | from .providers.base import LLMResponse 9 | from .config import LLMConfig 10 | from .memory import MemoryManager, compactify_messages, TokenCounter 11 | 12 | 13 | @dataclass 14 | class MemoryConfig: 15 | """Configuration for memory management.""" 16 | 17 | enabled: bool = True 18 | token_threshold: int = 50000 # ~50k tokens before compactifying 19 | preserve_last_n_messages: int = 10 20 | summary_max_tokens: int = 4000 21 | max_context_tokens: int = 128000 # 128k context window 22 | 23 | 24 | async def chat_async_with_memory( 25 | provider: Union[LLMProvider, str], 26 | model: str, 27 | messages: List[Dict[str, str]], 28 | memory_manager: Optional[MemoryManager] = None, 29 | memory_config: Optional[MemoryConfig] = None, 30 | auto_compactify: bool = True, 31 | max_completion_tokens: Optional[int] = None, 32 | temperature: float = 0.0, 33 | response_format=None, 34 | seed: int = 0, 35 | store: bool = True, 36 | metadata: Optional[Dict[str, str]] = None, 37 | timeout: int = 100, 38 | backup_model: Optional[str] = None, 39 | backup_provider: Optional[Union[LLMProvider, str]] = None, 40 | prediction: Optional[Dict[str, str]] = None, 41 | reasoning_effort: Optional[str] = None, 42 | tools: Optional[List[Callable]] = None, 43 | tool_choice: Optional[str] = None, 44 | max_retries: Optional[int] = None, 45 | post_tool_function: Optional[Callable] = None, 46 | config: Optional[LLMConfig] = None, 47 | ) -> LLMResponse: 48 | """ 49 | Execute a chat completion with memory management support. 50 | 51 | This function extends chat_async with automatic conversation memory management 52 | and compactification when approaching token limits. 53 | 54 | Args: 55 | provider: LLM provider to use 56 | model: Model name 57 | messages: List of message dictionaries 58 | memory_manager: Optional MemoryManager instance (created if not provided) 59 | memory_config: Memory configuration settings 60 | auto_compactify: Whether to automatically compactify when needed 61 | ... 
(all other chat_async parameters) 62 | 63 | Returns: 64 | LLMResponse object with the result 65 | """ 66 | # Initialize memory config if not provided 67 | if memory_config is None: 68 | memory_config = MemoryConfig() 69 | 70 | # Initialize memory manager if not provided and memory is enabled 71 | if memory_manager is None and memory_config.enabled: 72 | memory_manager = MemoryManager( 73 | token_threshold=memory_config.token_threshold, 74 | preserve_last_n_messages=memory_config.preserve_last_n_messages, 75 | summary_max_tokens=memory_config.summary_max_tokens, 76 | enabled=memory_config.enabled, 77 | ) 78 | 79 | # If memory is disabled, just pass through to regular chat_async 80 | if not memory_config.enabled or memory_manager is None: 81 | return await chat_async( 82 | provider=provider, 83 | model=model, 84 | messages=messages, 85 | max_completion_tokens=max_completion_tokens, 86 | temperature=temperature, 87 | response_format=response_format, 88 | seed=seed, 89 | store=store, 90 | metadata=metadata, 91 | timeout=timeout, 92 | backup_model=backup_model, 93 | backup_provider=backup_provider, 94 | prediction=prediction, 95 | reasoning_effort=reasoning_effort, 96 | tools=tools, 97 | tool_choice=tool_choice, 98 | max_retries=max_retries, 99 | post_tool_function=post_tool_function, 100 | config=config, 101 | ) 102 | 103 | # Get current messages from memory manager 104 | current_messages = memory_manager.get_current_messages() 105 | 106 | # Add new messages to memory 107 | token_counter = TokenCounter() 108 | new_tokens = token_counter.count_tokens(messages, model, str(provider)) 109 | memory_manager.add_messages(messages, new_tokens) 110 | 111 | # Check if we should compactify 112 | if auto_compactify and memory_manager.should_compactify(): 113 | # Get messages to summarize and preserve 114 | system_messages, messages_to_summarize, preserved_messages = ( 115 | memory_manager.get_messages_for_compactification() 116 | ) 117 | 118 | # Compactify messages 119 | compactified_messages, new_token_count = await compactify_messages( 120 | system_messages=system_messages, 121 | messages_to_summarize=messages_to_summarize, 122 | preserved_messages=preserved_messages, 123 | provider=str(provider), 124 | model=model, 125 | max_summary_tokens=memory_config.summary_max_tokens, 126 | config=config, # Pass config for API credentials 127 | ) 128 | 129 | # Update memory manager with compactified messages 130 | if compactified_messages and len(compactified_messages) > len(system_messages): 131 | # Find the summary message (first non-system message in the result) 132 | summary_idx = len(system_messages) 133 | summary_message = compactified_messages[summary_idx] 134 | memory_manager.update_after_compactification( 135 | system_messages=system_messages, 136 | summary_message=summary_message, 137 | preserved_messages=preserved_messages, 138 | new_token_count=new_token_count, 139 | ) 140 | 141 | # Use compactified messages for the API call 142 | messages_for_api = compactified_messages 143 | else: 144 | # Use all messages from memory 145 | messages_for_api = memory_manager.get_current_messages() 146 | 147 | # Make the API call with the potentially compactified messages 148 | response = await chat_async( 149 | provider=provider, 150 | model=model, 151 | messages=messages_for_api, 152 | max_completion_tokens=max_completion_tokens, 153 | temperature=temperature, 154 | response_format=response_format, 155 | seed=seed, 156 | store=store, 157 | metadata=metadata, 158 | timeout=timeout, 159 | backup_model=backup_model, 160 
| backup_provider=backup_provider, 161 | prediction=prediction, 162 | reasoning_effort=reasoning_effort, 163 | tools=tools, 164 | tool_choice=tool_choice, 165 | max_retries=max_retries, 166 | post_tool_function=post_tool_function, 167 | config=config, 168 | ) 169 | 170 | # Add the assistant's response to memory 171 | assistant_message = {"role": "assistant", "content": response.content} 172 | response_tokens = response.output_tokens or 0 173 | memory_manager.add_messages([assistant_message], response_tokens) 174 | 175 | # Add memory stats to response metadata 176 | if hasattr(response, "_memory_stats"): 177 | response._memory_stats = memory_manager.get_stats() 178 | 179 | return response 180 | 181 | 182 | # Convenience function for creating a memory manager 183 | def create_memory_manager( 184 | token_threshold: int = 100000, 185 | preserve_last_n_messages: int = 10, 186 | summary_max_tokens: int = 2000, 187 | enabled: bool = True, 188 | ) -> MemoryManager: 189 | """ 190 | Create a new MemoryManager instance. 191 | 192 | Args: 193 | token_threshold: Token count threshold for triggering compactification 194 | preserve_last_n_messages: Number of recent messages to always preserve 195 | summary_max_tokens: Maximum tokens for the summary 196 | enabled: Whether memory management is enabled 197 | 198 | Returns: 199 | Configured MemoryManager instance 200 | """ 201 | return MemoryManager( 202 | token_threshold=token_threshold, 203 | preserve_last_n_messages=preserve_last_n_messages, 204 | summary_max_tokens=summary_max_tokens, 205 | enabled=enabled, 206 | ) 207 | -------------------------------------------------------------------------------- /defog/llm/web_search.py: -------------------------------------------------------------------------------- 1 | from defog.llm.llm_providers import LLMProvider 2 | from defog.llm.utils_logging import ToolProgressTracker, SubTaskLogger 3 | import os 4 | 5 | 6 | async def web_search_tool( 7 | question: str, 8 | model: str, 9 | provider: LLMProvider, 10 | max_tokens: int = 2048, 11 | ): 12 | """ 13 | Search the web for the answer to the question. 14 | """ 15 | async with ToolProgressTracker( 16 | "Web Search", 17 | f"Searching for: {question[:50]}{'...' 
if len(question) > 50 else ''}", 18 | ) as tracker: 19 | subtask_logger = SubTaskLogger() 20 | subtask_logger.log_provider_info( 21 | provider.value if hasattr(provider, "value") else str(provider), model 22 | ) 23 | 24 | if provider in [LLMProvider.OPENAI, LLMProvider.OPENAI.value]: 25 | from openai import AsyncOpenAI 26 | 27 | client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 28 | 29 | tracker.update(20, "Initiating web search") 30 | subtask_logger.log_search_status(question) 31 | 32 | response = await client.responses.create( 33 | model=model, 34 | tools=[{"type": "web_search_preview"}], 35 | tool_choice="required", 36 | input=question, 37 | # in the responses API, this means both the reasoning and the output tokens 38 | max_output_tokens=max_tokens, 39 | ) 40 | tracker.update(80, "Processing search results") 41 | subtask_logger.log_subtask("Extracting citations and content", "processing") 42 | 43 | usage = { 44 | "input_tokens": response.usage.input_tokens, 45 | "output_tokens": response.usage.output_tokens, 46 | } 47 | output_text = response.output_text 48 | websites_cited = [] 49 | for output in response.output: 50 | if hasattr(output, "content") and output.content: 51 | for content in output.content: 52 | if content.annotations: 53 | for annotation in content.annotations: 54 | websites_cited.append( 55 | { 56 | "url": annotation.url, 57 | "title": annotation.title, 58 | } 59 | ) 60 | 61 | subtask_logger.log_result_summary( 62 | "Web Search", 63 | { 64 | "websites_found": len(websites_cited), 65 | "tokens_used": usage["input_tokens"] + usage["output_tokens"], 66 | }, 67 | ) 68 | 69 | return { 70 | "usage": usage, 71 | "search_results": output_text, 72 | "websites_cited": websites_cited, 73 | } 74 | 75 | elif provider in [LLMProvider.ANTHROPIC, LLMProvider.ANTHROPIC.value]: 76 | from anthropic import AsyncAnthropic 77 | from anthropic.types import TextBlock 78 | 79 | client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) 80 | 81 | tracker.update(20, "Initiating web search") 82 | subtask_logger.log_search_status(question, max_results=5) 83 | 84 | response = await client.messages.create( 85 | model=model, 86 | max_tokens=max_tokens, 87 | messages=[{"role": "user", "content": question}], 88 | tools=[ 89 | { 90 | "type": "web_search_20250305", 91 | "name": "web_search", 92 | "max_uses": 5, 93 | # can also use allowed_domains to limit the search to specific domains 94 | # can also use blocked_domains to exclude specific domains 95 | } 96 | ], 97 | tool_choice={"type": "any"}, 98 | ) 99 | 100 | tracker.update(80, "Processing search results") 101 | subtask_logger.log_subtask("Extracting citations and content", "processing") 102 | 103 | usage = { 104 | "input_tokens": response.usage.input_tokens, 105 | "output_tokens": response.usage.output_tokens, 106 | } 107 | search_results = response.content 108 | # we want to use only the TextBlock class in the search results 109 | search_results = [ 110 | block for block in search_results if isinstance(block, TextBlock) 111 | ] 112 | 113 | # convert the search_results into simple text with citations 114 | # (where citations = text + hyperlinks 115 | output_text = [ 116 | ( 117 | f'' + block.text + "" 118 | if block.citations 119 | else block.text 120 | ) 121 | for block in search_results 122 | ] 123 | websites_cited = [ 124 | {"url": block.citations[0].url, "title": block.citations[0].title} 125 | for block in search_results 126 | if block.citations 127 | ] 128 | 129 | subtask_logger.log_result_summary( 130 | "Web Search", 131 | { 132 | 
"text_blocks": len(search_results), 133 | "websites_cited": len(websites_cited), 134 | "tokens_used": usage["input_tokens"] + usage["output_tokens"], 135 | }, 136 | ) 137 | 138 | return { 139 | "usage": usage, 140 | "search_results": output_text, 141 | "websites_cited": websites_cited, 142 | } 143 | elif provider in [LLMProvider.GEMINI, LLMProvider.GEMINI.value]: 144 | from google import genai 145 | from google.genai.types import ( 146 | Tool, 147 | GenerateContentConfig, 148 | GoogleSearch, 149 | ToolConfig, 150 | FunctionCallingConfig, 151 | ) 152 | 153 | client = genai.Client(api_key=os.getenv("GEMINI_API_KEY")) 154 | google_search_tool = Tool(google_search=GoogleSearch()) 155 | 156 | tracker.update(20, "Initiating Google search") 157 | subtask_logger.log_search_status(question) 158 | 159 | response = await client.aio.models.generate_content( 160 | model=model, 161 | contents=question, 162 | config=GenerateContentConfig( 163 | tools=[google_search_tool], 164 | response_modalities=["TEXT"], 165 | tool_config=ToolConfig( 166 | function_calling_config=FunctionCallingConfig( 167 | mode="ANY", 168 | ), 169 | ), 170 | ), 171 | ) 172 | tracker.update(80, "Processing search results") 173 | subtask_logger.log_subtask("Extracting grounding metadata", "processing") 174 | 175 | usage = { 176 | "input_tokens": response.usage_metadata.prompt_token_count, 177 | "thinking_tokens": response.usage_metadata.thoughts_token_count or 0, 178 | "output_tokens": response.usage_metadata.candidates_token_count, 179 | } 180 | 181 | websites_cited = [] 182 | if response.candidates: 183 | for candidate in response.candidates: 184 | if ( 185 | candidate.grounding_metadata 186 | and candidate.grounding_metadata.grounding_chunks 187 | ): 188 | for chunk in candidate.grounding_metadata.grounding_chunks: 189 | websites_cited.append( 190 | {"source": chunk.web.title, "url": chunk.web.uri} 191 | ) 192 | 193 | output_text = response.text 194 | 195 | subtask_logger.log_result_summary( 196 | "Web Search", 197 | { 198 | "websites_found": len(websites_cited), 199 | "total_tokens": usage["input_tokens"] 200 | + usage["thinking_tokens"] 201 | + usage["output_tokens"], 202 | }, 203 | ) 204 | 205 | return { 206 | "usage": usage, 207 | "search_results": output_text, 208 | "websites_cited": websites_cited, 209 | } 210 | 211 | else: 212 | raise ValueError(f"Provider {provider} not supported") 213 | -------------------------------------------------------------------------------- /defog/llm/youtube_transcript.py: -------------------------------------------------------------------------------- 1 | # converts a youtube video to a detailed, ideally diarized transcript 2 | from defog.llm.utils_logging import ToolProgressTracker, SubTaskLogger 3 | import os 4 | 5 | 6 | async def get_transcript( 7 | video_url: str, model: str = "gemini-2.5-pro-preview-06-05" 8 | ) -> str: 9 | """ 10 | Get a detailed, diarized transcript of a YouTube video. 11 | 12 | Args: 13 | video_url: The URL of the YouTube video. 14 | 15 | Returns: 16 | A detailed, ideally diarized transcript of the video. 17 | """ 18 | async with ToolProgressTracker( 19 | "YouTube Transcript", 20 | f"Transcribing video from: {video_url[:50]}{'...' 
if len(video_url) > 50 else ''}", 21 | ) as tracker: 22 | subtask_logger = SubTaskLogger() 23 | subtask_logger.log_provider_info("Gemini", model) 24 | 25 | if os.getenv("GEMINI_API_KEY") is None: 26 | raise ValueError("GEMINI_API_KEY is not set") 27 | 28 | from google import genai 29 | from google.genai.types import ( 30 | Content, 31 | Part, 32 | FileData, 33 | VideoMetadata, 34 | ) 35 | 36 | client = genai.Client(api_key=os.getenv("GEMINI_API_KEY")) 37 | 38 | tracker.update(10, "Processing video") 39 | subtask_logger.log_subtask( 40 | "Using low FPS (0.2) for efficient processing", "info" 41 | ) 42 | 43 | response = await client.aio.models.generate_content( 44 | model=model, 45 | contents=Content( 46 | parts=[ 47 | Part( 48 | file_data=FileData(file_uri=video_url), 49 | video_metadata=VideoMetadata(fps=0.2), 50 | ), 51 | Part( 52 | text="Please provide a detailed, accurate transcript of the video. Please include timestamps for each speaker. Do not describe the video, just create a great transcript." 53 | ), 54 | ] 55 | ), 56 | ) 57 | 58 | tracker.update(90, "Finalizing transcript") 59 | transcript_length = len(response.text) if response.text else 0 60 | 61 | subtask_logger.log_result_summary( 62 | "YouTube Transcript", 63 | { 64 | "transcript_length": f"{transcript_length} characters", 65 | "model_used": model, 66 | }, 67 | ) 68 | 69 | return response.text 70 | -------------------------------------------------------------------------------- /defog/query_methods.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from defog.query import execute_query 3 | from datetime import datetime 4 | 5 | 6 | def get_query( 7 | self, 8 | question: str, 9 | hard_filters: str = "", 10 | previous_context: list = [], 11 | glossary: str = "", 12 | debug: bool = False, 13 | dev: bool = False, 14 | temp: bool = False, 15 | profile: bool = False, 16 | ignore_cache: bool = False, 17 | model: str = "", 18 | use_golden_queries: bool = True, 19 | subtable_pruning: bool = False, 20 | glossary_pruning: bool = False, 21 | prune_max_tokens: int = 2000, 22 | prune_bm25_num_columns: int = 10, 23 | prune_glossary_max_tokens: int = 1000, 24 | prune_glossary_num_cos_sim_units: int = 10, 25 | prune_glossary_bm25_units: int = 10, 26 | ): 27 | """ 28 | Sends the query to the defog servers, and return the response. 29 | :param question: The question to be asked. 30 | :return: The response from the defog server. 
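If the request succeeds, the response includes the generated SQL under query_generated together with ran_successfully, error_message, query_db, the updated previous_context, and reason_for_query.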
31 | """ 32 | try: 33 | data = { 34 | "question": question, 35 | "api_key": self.api_key, 36 | "previous_context": previous_context, 37 | "db_type": self.db_type if self.db_type != "databricks" else "postgres", 38 | "glossary": glossary, 39 | "hard_filters": hard_filters, 40 | "dev": dev, 41 | "temp": temp, 42 | "ignore_cache": ignore_cache, 43 | "model": model, 44 | "use_golden_queries": use_golden_queries, 45 | "subtable_pruning": subtable_pruning, 46 | "glossary_pruning": glossary_pruning, 47 | "prune_max_tokens": prune_max_tokens, 48 | "prune_bm25_num_columns": prune_bm25_num_columns, 49 | "prune_glossary_max_tokens": prune_glossary_max_tokens, 50 | "prune_glossary_num_cos_sim_units": prune_glossary_num_cos_sim_units, 51 | "prune_glossary_bm25_units": prune_glossary_bm25_units, 52 | } 53 | 54 | t_start = datetime.now() 55 | r = requests.post( 56 | self.generate_query_url, 57 | json=data, 58 | timeout=300, 59 | verify=False, 60 | ) 61 | resp = r.json() 62 | t_end = datetime.now() 63 | time_taken = (t_end - t_start).total_seconds() 64 | query_generated = resp.get("sql", resp.get("query_generated")) 65 | ran_successfully = resp.get("ran_successfully") 66 | error_message = resp.get("error_message") 67 | query_db = self.db_type 68 | resp = { 69 | "query_generated": query_generated, 70 | "ran_successfully": ran_successfully, 71 | "error_message": error_message, 72 | "query_db": query_db, 73 | "previous_context": resp.get("previous_context"), 74 | "reason_for_query": resp.get("reason_for_query"), 75 | } 76 | if profile: 77 | resp["time_taken"] = time_taken 78 | 79 | return resp 80 | except Exception as e: 81 | if debug: 82 | print(e) 83 | return { 84 | "ran_successfully": False, 85 | "error_message": "Sorry :( Our server is at capacity right now and we are unable to process your query. Please try again in a few minutes?", 86 | } 87 | 88 | 89 | def run_query( 90 | self, 91 | question: str, 92 | hard_filters: str = "", 93 | previous_context: list = [], 94 | glossary: str = "", 95 | query: dict = None, 96 | retries: int = 3, 97 | dev: bool = False, 98 | temp: bool = False, 99 | profile: bool = False, 100 | ignore_cache: bool = False, 101 | model: str = "", 102 | use_golden_queries: bool = True, 103 | subtable_pruning: bool = False, 104 | glossary_pruning: bool = False, 105 | prune_max_tokens: int = 2000, 106 | prune_bm25_num_columns: int = 10, 107 | prune_glossary_max_tokens: int = 1000, 108 | prune_glossary_num_cos_sim_units: int = 10, 109 | prune_glossary_bm25_units: int = 10, 110 | ): 111 | """ 112 | Sends the question to the defog servers, executes the generated SQL, 113 | and returns the response. 114 | :param question: The question to be asked. 115 | :return: The response from the defog server. 
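On success this also includes the executed query's columns and data; on failure it contains ran_successfully set to False and an error_message.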
116 | """ 117 | if query is None: 118 | print(f"Generating the query for your question: {question}...") 119 | query = self.get_query( 120 | question, 121 | hard_filters, 122 | previous_context, 123 | glossary=glossary, 124 | dev=dev, 125 | temp=temp, 126 | profile=profile, 127 | model=model, 128 | ignore_cache=ignore_cache, 129 | use_golden_queries=use_golden_queries, 130 | subtable_pruning=subtable_pruning, 131 | glossary_pruning=glossary_pruning, 132 | prune_max_tokens=prune_max_tokens, 133 | prune_bm25_num_columns=prune_bm25_num_columns, 134 | prune_glossary_max_tokens=prune_glossary_max_tokens, 135 | prune_glossary_num_cos_sim_units=prune_glossary_num_cos_sim_units, 136 | prune_glossary_bm25_units=prune_glossary_bm25_units, 137 | ) 138 | if query["ran_successfully"]: 139 | try: 140 | print("Query generated, now running it on your database...") 141 | tstart = datetime.now() 142 | colnames, result, executed_query = execute_query( 143 | query=query["query_generated"], 144 | api_key=self.api_key, 145 | db_type=self.db_type, 146 | db_creds=self.db_creds, 147 | question=question, 148 | hard_filters=hard_filters, 149 | retries=retries, 150 | dev=dev, 151 | temp=temp, 152 | base_url=self.base_url, 153 | ) 154 | tend = datetime.now() 155 | time_taken = (tend - tstart).total_seconds() 156 | resp = { 157 | "columns": colnames, 158 | "data": result, 159 | "query_generated": executed_query, 160 | "ran_successfully": True, 161 | "reason_for_query": query.get("reason_for_query"), 162 | "previous_context": query.get("previous_context"), 163 | } 164 | if profile: 165 | resp["execution_time_taken"] = time_taken 166 | resp["generation_time_taken"] = query.get("time_taken") 167 | return resp 168 | except Exception as e: 169 | return { 170 | "ran_successfully": False, 171 | "error_message": str(e), 172 | "query_generated": query["query_generated"], 173 | } 174 | else: 175 | return {"ran_successfully": False, "error_message": query["error_message"]} 176 | -------------------------------------------------------------------------------- /defog/serve.py: -------------------------------------------------------------------------------- 1 | # create a FastAPI app 2 | from fastapi import FastAPI, Request 3 | from fastapi.middleware.cors import CORSMiddleware 4 | 5 | from defog import Defog 6 | import pandas as pd 7 | import os 8 | import json 9 | from io import StringIO 10 | 11 | app = FastAPI() 12 | 13 | origins = ["*"] 14 | app.add_middleware( 15 | CORSMiddleware, 16 | allow_origins=origins, 17 | allow_credentials=True, 18 | allow_methods=["*"], 19 | allow_headers=["*"], 20 | ) 21 | 22 | home_dir = os.path.expanduser("~") 23 | defog_path = os.path.join(home_dir, ".defog") 24 | 25 | 26 | @app.get("/") 27 | async def root(): 28 | return {"message": "Hello, I am Defog"} 29 | 30 | 31 | @app.post("/generate_query") 32 | async def generate(request: Request): 33 | params = await request.json() 34 | question = params.get("question") 35 | previous_context = params.get("previous_context") 36 | defog = Defog() 37 | resp = defog.run_query(question, previous_context=previous_context) 38 | return resp 39 | 40 | 41 | @app.post("/integration/get_tables_db_creds") 42 | async def get_tables_db_creds(request: Request): 43 | try: 44 | defog = Defog() 45 | except: 46 | return {"error": "no defog instance found"} 47 | 48 | try: 49 | with open(os.path.join(defog_path, "tables.json"), "r") as f: 50 | table_names = json.load(f) 51 | except: 52 | table_names = [] 53 | 54 | try: 55 | with open(os.path.join(defog_path, 
"selected_tables.json"), "r") as f: 56 | selected_table_names = json.load(f) 57 | except: 58 | selected_table_names = [] 59 | 60 | db_type = defog.db_type 61 | db_creds = defog.db_creds 62 | api_key = defog.api_key 63 | 64 | return { 65 | "tables": table_names, 66 | "db_creds": db_creds, 67 | "db_type": db_type, 68 | "selected_tables": selected_table_names, 69 | "api_key": api_key, 70 | } 71 | 72 | 73 | @app.post("/integration/get_metadata") 74 | async def get_metadata(request: Request): 75 | try: 76 | with open(os.path.join(defog_path, "metadata.json"), "r") as f: 77 | metadata = json.load(f) 78 | 79 | return {"metadata": metadata} 80 | except: 81 | return {"error": "no metadata found"} 82 | 83 | 84 | @app.post("/integration/generate_tables") 85 | async def get_tables(request: Request): 86 | params = await request.json() 87 | api_key = params.get("api_key") 88 | db_type = params.get("db_type") 89 | db_creds = params.get("db_creds") 90 | for k in ["api_key", "db_type"]: 91 | if k in db_creds: 92 | del db_creds[k] 93 | 94 | defog = Defog(api_key, db_type, db_creds) 95 | table_names = defog.generate_db_schema(tables=[], return_tables_only=True) 96 | 97 | with open(os.path.join(defog_path, "tables.json"), "w") as f: 98 | json.dump(table_names, f) 99 | 100 | return {"tables": table_names} 101 | 102 | 103 | @app.post("/integration/generate_metadata") 104 | async def generate_metadata(request: Request): 105 | params = await request.json() 106 | tables = params.get("tables") 107 | 108 | with open(os.path.join(defog_path, "selected_tables.json"), "w") as f: 109 | json.dump(tables, f) 110 | 111 | defog = Defog() 112 | table_metadata = defog.generate_db_schema( 113 | tables=tables, scan=True, upload=True, return_format="csv_string" 114 | ) 115 | metadata = ( 116 | pd.read_csv(StringIO(table_metadata)).fillna("").to_dict(orient="records") 117 | ) 118 | 119 | with open(os.path.join(defog_path, "metadata.json"), "w") as f: 120 | json.dump(metadata, f) 121 | 122 | defog.update_db_schema(StringIO(table_metadata)) 123 | return {"metadata": metadata} 124 | 125 | 126 | @app.post("/integration/update_metadata") 127 | async def update_metadata(request: Request): 128 | params = await request.json() 129 | metadata = params.get("metadata") 130 | defog = Defog() 131 | metadata = pd.DataFrame(metadata).to_csv(index=False) 132 | defog.update_db_schema(StringIO(metadata)) 133 | return {"status": "success"} 134 | 135 | 136 | @app.post("/instruct/get_glossary_golden_queries") 137 | async def update_glossary(request: Request): 138 | defog = Defog() 139 | glossary = defog.get_glossary() 140 | golden_queries = defog.get_golden_queries(format="json") 141 | return {"glossary": glossary, "golden_queries": golden_queries} 142 | 143 | 144 | @app.post("/instruct/update_glossary") 145 | async def update_glossary(request: Request): 146 | params = await request.json() 147 | glossary = params.get("glossary") 148 | defog = Defog() 149 | defog.update_glossary(glossary=glossary) 150 | return {"status": "success"} 151 | 152 | 153 | @app.post("/instruct/update_golden_queries") 154 | async def update_golden_queries(request: Request): 155 | params = await request.json() 156 | golden_queries = params.get("golden_queries") 157 | golden_queries = [ 158 | x 159 | for x in golden_queries 160 | if x["sql"] != "" and x["question"] != "" and x["user_validated"] 161 | ] 162 | defog = Defog() 163 | defog.update_golden_queries(golden_queries=golden_queries) 164 | return {"status": "success"} 165 | 
-------------------------------------------------------------------------------- /defog/static/404.html: -------------------------------------------------------------------------------- 1 | 404: This page could not be found

-------------------------------------------------------------------------------- /defog/static/_next/static/chunks/pages/_app-db0976def6406e5e.js: -------------------------------------------------------------------------------- 1 | (self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[888],{91118:function(n,e,_){(window.__NEXT_P=window.__NEXT_P||[]).push(["/_app",function(){return _(36290)}])},36290:function(n,e,_){"use strict";_.r(e),_.d(e,{default:function(){return App}});var t=_(28598);_(50558);var u=_(82684);let c=u.createContext();function App(n){let{Component:e,pageProps:_}=n,[r,i]=(0,u.useState)({});return(0,t.jsx)(c.Provider,{value:[r,i],children:(0,t.jsx)(e,{..._})})}_(46731)},50558:function(){},46731:function(){}},function(n){var __webpack_exec__=function(e){return n(n.s=e)};n.O(0,[774,179],function(){return __webpack_exec__(91118),__webpack_exec__(84142)}),_N_E=n.O()}]); -------------------------------------------------------------------------------- /defog/static/_next/static/chunks/pages/_error-ee42a9921d95ff81.js: -------------------------------------------------------------------------------- 1 | (self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[820],{81981:function(n,_,u){(window.__NEXT_P=window.__NEXT_P||[]).push(["/_error",function(){return u(30730)}])}},function(n){n.O(0,[774,888,179],function(){return n(n.s=81981)}),_N_E=n.O()}]); -------------------------------------------------------------------------------- /defog/static/_next/static/chunks/pages/extract-metadata-2adf74ad3bcc8699.js: -------------------------------------------------------------------------------- 1 | (self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[99],{97362:function(e,t,a){(window.__NEXT_P=window.__NEXT_P||[]).push(["/extract-metadata",function(){return a(16334)}])},31143:function(e,t,a){"use strict";var s=a(28598),n=a(1887),l=a.n(n);t.Z=()=>(0,s.jsxs)(l(),{children:[(0,s.jsx)("title",{children:"Defog.ai - AI Assistant for Data Analysis"}),(0,s.jsx)("meta",{name:"description",content:"Train your AI data assistant on your own device"}),(0,s.jsx)("meta",{name:"viewport",content:"width=device-width, initial-scale=1"}),(0,s.jsx)("link",{rel:"icon",href:"/favicon.ico"})]})},62429:function(e,t,a){"use strict";var s=a(28598);a(82684);var n=a(79869),l=a(12691),d=a.n(l);t.Z=e=>{let{id:t,children:a}=e,{Content:l,Sider:i}=n.Layout,r=[{key:"manage-database",title:"Manage Database",icon:(0,s.jsx)(d(),{href:"/extract-metadata",children:"\uD83D\uDCBE Manage DB"})},{key:"instruct-model",title:"Instruct Model",icon:(0,s.jsx)(d(),{href:"/instruct-model",children:"\uD83D\uDC68‍\uD83C\uDFEB Instruct Model"})},{key:"query-data",title:"Query Data",icon:(0,s.jsx)(d(),{href:"/query-data",children:"\uD83D\uDD0D Query Data"})}];return(0,s.jsx)(n.Layout,{style:{height:"100vh"},children:(0,s.jsxs)(l,{children:[(0,s.jsx)(i,{style:{height:"100vh",position:"fixed"},children:(0,s.jsx)(n.v2,{style:{width:200,paddingTop:"2em",paddingBottom:"2em"},mode:"inline",selectedKeys:[t],items:r})}),(0,s.jsx)("div",{style:{paddingLeft:240,paddingTop:30,backgroundColor:"#f5f5f5"},children:a})]})})}},16334:function(e,t,a){"use strict";a.r(t);var s=a(28598),n=a(82684),l=a(31143),d=a(62429),i=a(79869);t.default=()=>{let{Option:e}=i.Select,[t,a]=(0,n.useState)(""),[r,o]=(0,n.useState)("postgres"),[c,h]=(0,n.useState)({}),[p,m]=(0,n.useState)([]),[u,x]=(0,n.useState)(!1),[y,b]=(0,n.useState)([]),[g,j]=(0,n.useState)([]),[w]=i.l0.useForm(),getTables=async()=>{let e=await 
fetch("http://localhost:1235/integration/get_tables_db_creds",{method:"POST"}),t=await e.json();t.error||(console.log(t),o(t.db_type),h(t.db_creds),m(t.tables),j(t.selected_tables),w.setFieldsValue({db_type:t.db_type,api_key:t.api_key,...t.db_creds}))},getMetadata=async()=>{let e=await fetch("http://localhost:1235/integration/get_metadata",{method:"POST"}),t=await e.json();t.error||b((null==t?void 0:t.metadata)||[])};(0,n.useEffect)(()=>{x(!0),getTables().then(()=>{getMetadata().then(()=>{x(!1)})})},[]);let f={postgres:["host","port","user","password","database"],mysql:["host","port","user","password","database"],redshift:["host","port","user","password","database"],snowflake:["account","warehouse","user","password"],databricks:["server_hostname","access_token","http_path","schema"]};return(0,s.jsxs)(s.Fragment,{children:[(0,s.jsx)(l.Z,{}),(0,s.jsxs)(d.Z,{id:"manage-database",userType:"admin",children:[(0,s.jsx)("h1",{style:{paddingBottom:"1em"},children:"Extract Metadata"}),(0,s.jsxs)(i.X2,{type:"flex",height:"100vh",children:[(0,s.jsx)(i.JX,{md:{span:8},xs:{span:24},children:(0,s.jsxs)("div",{children:[(0,s.jsxs)(i.l0,{name:"db_creds",form:w,labelCol:{span:6},wrapperCol:{span:18},style:{maxWidth:400},disabled:u,onFinish:async e=>{x(!0),e={db_creds:e,db_type:e.db_type||r,api_key:e.api_key||t};let a=await fetch("http://localhost:1235/integration/generate_tables",{method:"POST",body:JSON.stringify(e)}),s=await a.json();m(s.tables),x(!1)},children:[(0,s.jsx)(i.l0.Item,{name:"api_key",label:"API Key",children:(0,s.jsx)(i.II,{style:{width:"100%"},onChange:e=>{a(e.target.value)}})}),(0,s.jsx)(i.l0.Item,{name:"db_type",label:"DB Type",children:(0,s.jsx)(i.Select,{style:{width:"100%"},onChange:e=>{o(e)},options:["databricks","mysql","postgres","redshift","snowflake"].map(e=>({value:e,key:e,label:e}))})}),void 0!==f[r]&&f[r].map(e=>(0,s.jsx)(i.l0.Item,{label:e,name:e,children:(0,s.jsx)(i.II,{style:{width:"100%"}})},r+"_"+e)),(0,s.jsx)(i.l0.Item,{wrapperCol:{span:24},children:(0,s.jsx)(i.zx,{type:"primary",style:{width:"100%"},htmlType:"submit",children:"Get Tables"})})]}),p.length>0&&(0,s.jsxs)(i.l0,{name:"db_tables",labelCol:{span:8},wrapperCol:{span:16},style:{maxWidth:400},disabled:u,onFinish:async e=>{x(!0);let t=await fetch("http://localhost:1235/integration/generate_metadata",{method:"POST",body:JSON.stringify({tables:e.tables})}),a=await t.json();x(!1),b((null==a?void 0:a.metadata)||[])},children:[(0,s.jsx)(i.l0.Item,{name:"tables",label:"Tables to index",value:g,children:(0,s.jsx)(i.Select,{mode:"multiple",style:{width:"100%",maxWidth:400},placeholder:"Select tables to index",defaultValue:g,onChange:e=>{console.log(e),j(e)},children:p.map(t=>(0,s.jsx)(e,{value:t,children:t},t))})}),(0,s.jsx)(i.l0.Item,{wrapperCol:{span:24},children:(0,s.jsx)(i.zx,{type:"primary",style:{width:"100%",maxWidth:535},htmlType:"submit",children:"Extract Metadata"})})]})]})}),(0,s.jsxs)(i.JX,{md:{span:16},xs:{span:24},style:{paddingRight:"2em",height:600,overflowY:"scroll"},children:[y.length>0&&(0,s.jsx)(i.zx,{type:"primary",style:{width:"100%",maxWidth:535},disabled:u,loading:u,onClick:async()=>{x(!0);let e=await fetch("http://localhost:1235/integration/update_metadata",{method:"POST",body:JSON.stringify({metadata:y})}),t=await e.json();console.log(t),x(!1),void 0!==t.suggested_joins&&null!==t.suggested_joins&&""!==t.suggested_joins&&(document.getElementById("allowed-joins").value=t.suggested_joins),i.yw.success("Metadata updated successfully!")},children:"Update metadata on 
server"}),y.length>0?(0,s.jsxs)(i.X2,{style:{marginTop:"1em",position:"sticky",top:0,paddingBottom:"1em",paddingTop:"1em",backgroundColor:"white",zIndex:100},children:[(0,s.jsx)(i.JX,{xs:{span:24},md:{span:4},style:{overflowWrap:"break-word"},children:(0,s.jsx)("b",{children:"Table Name"})}),(0,s.jsx)(i.JX,{xs:{span:24},md:{span:4},style:{overflowWrap:"break-word"},children:(0,s.jsx)("b",{children:"Column Name"})}),(0,s.jsx)(i.JX,{xs:{span:24},md:{span:4},style:{overflowWrap:"break-word"},children:(0,s.jsx)("b",{children:"Data Type"})}),(0,s.jsx)(i.JX,{xs:{span:24},md:{span:12},children:(0,s.jsx)("b",{children:"Description (Optional)"})})]}):null,y.length>0&&y.map((e,t)=>(0,s.jsxs)(i.X2,{style:{marginTop:"1em"},children:[(0,s.jsx)(i.JX,{xs:{span:24},md:{span:4},style:{overflowWrap:"break-word"},children:e.table_name}),(0,s.jsx)(i.JX,{xs:{span:24},md:{span:4},style:{overflowWrap:"break-word"},children:e.column_name}),(0,s.jsx)(i.JX,{xs:{span:24},md:{span:4},style:{overflowWrap:"break-word"},children:e.data_type}),(0,s.jsx)(i.JX,{xs:{span:24},md:{span:12},children:(0,s.jsx)(i.II.TextArea,{placeholder:"Description of what this column does",defaultValue:e.column_description,autoSize:{minRows:2},onChange:e=>{let a=[...y];a[t].column_description=e.target.value,b(a)}},t)})]},t))]})]})]})]})}}},function(e){e.O(0,[238,774,888,179],function(){return e(e.s=97362)}),_N_E=e.O()}]); -------------------------------------------------------------------------------- /defog/static/_next/static/chunks/pages/index-b60f249c1d54d3bf.js: -------------------------------------------------------------------------------- 1 | (self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[405],{34376:function(e,t,n){e.exports=n(84142)},75557:function(e,t,n){(window.__NEXT_P=window.__NEXT_P||[]).push(["/",function(){return n(48591)}])},31143:function(e,t,n){"use strict";var i=n(28598),a=n(1887),s=n.n(a);t.Z=()=>(0,i.jsxs)(s(),{children:[(0,i.jsx)("title",{children:"Defog.ai - AI Assistant for Data Analysis"}),(0,i.jsx)("meta",{name:"description",content:"Train your AI data assistant on your own device"}),(0,i.jsx)("meta",{name:"viewport",content:"width=device-width, initial-scale=1"}),(0,i.jsx)("link",{rel:"icon",href:"/favicon.ico"})]})},62429:function(e,t,n){"use strict";var i=n(28598);n(82684);var a=n(79869),s=n(12691),r=n.n(s);t.Z=e=>{let{id:t,children:n}=e,{Content:s,Sider:o}=a.Layout,d=[{key:"manage-database",title:"Manage Database",icon:(0,i.jsx)(r(),{href:"/extract-metadata",children:"\uD83D\uDCBE Manage DB"})},{key:"instruct-model",title:"Instruct Model",icon:(0,i.jsx)(r(),{href:"/instruct-model",children:"\uD83D\uDC68‍\uD83C\uDFEB Instruct Model"})},{key:"query-data",title:"Query Data",icon:(0,i.jsx)(r(),{href:"/query-data",children:"\uD83D\uDD0D Query Data"})}];return(0,i.jsx)(a.Layout,{style:{height:"100vh"},children:(0,i.jsxs)(s,{children:[(0,i.jsx)(o,{style:{height:"100vh",position:"fixed"},children:(0,i.jsx)(a.v2,{style:{width:200,paddingTop:"2em",paddingBottom:"2em"},mode:"inline",selectedKeys:[t],items:d})}),(0,i.jsx)("div",{style:{paddingLeft:240,paddingTop:30,backgroundColor:"#f5f5f5"},children:n})]})})}},48591:function(e,t,n){"use strict";n.r(t);var i=n(28598),a=n(82684),s=n(34376),r=n(31143),o=n(62429),d=n(79869);t.default=()=>{let 
e=(0,s.useRouter)();return(0,a.useEffect)(()=>{fetch("http://localhost:1235/integration/get_tables_db_creds",{method:"POST",headers:{"Content-Type":"application/json"}}).then(e=>e.json()).then(t=>{t.tables&&t.db_creds?e.push("/query-data"):e.push("/extract-metadata")})},[]),(0,i.jsxs)(i.Fragment,{children:[(0,i.jsx)(r.Z,{}),(0,i.jsxs)(o.Z,{children:[(0,i.jsx)("h1",{style:{paddingBottom:"1em"},children:"Welcome to Defog!"}),(0,i.jsxs)("h3",{children:["Please wait while we log you in and redirect you to the right page... ",(0,i.jsx)(d.yC,{})]})]})]})}}},function(e){e.O(0,[238,774,888,179],function(){return e(e.s=75557)}),_N_E=e.O()}]); -------------------------------------------------------------------------------- /defog/static/_next/static/chunks/pages/instruct-model-d040be04cf7f21f2.js: -------------------------------------------------------------------------------- 1 | (self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[673],{2252:function(e,s,t){(window.__NEXT_P=window.__NEXT_P||[]).push(["/instruct-model",function(){return t(14373)}])},31143:function(e,s,t){"use strict";var a=t(28598),n=t(1887),r=t.n(n);s.Z=()=>(0,a.jsxs)(r(),{children:[(0,a.jsx)("title",{children:"Defog.ai - AI Assistant for Data Analysis"}),(0,a.jsx)("meta",{name:"description",content:"Train your AI data assistant on your own device"}),(0,a.jsx)("meta",{name:"viewport",content:"width=device-width, initial-scale=1"}),(0,a.jsx)("link",{rel:"icon",href:"/favicon.ico"})]})},62429:function(e,s,t){"use strict";var a=t(28598);t(82684);var n=t(79869),r=t(12691),l=t.n(r);s.Z=e=>{let{id:s,children:t}=e,{Content:r,Sider:i}=n.Layout,o=[{key:"manage-database",title:"Manage Database",icon:(0,a.jsx)(l(),{href:"/extract-metadata",children:"\uD83D\uDCBE Manage DB"})},{key:"instruct-model",title:"Instruct Model",icon:(0,a.jsx)(l(),{href:"/instruct-model",children:"\uD83D\uDC68‍\uD83C\uDFEB Instruct Model"})},{key:"query-data",title:"Query Data",icon:(0,a.jsx)(l(),{href:"/query-data",children:"\uD83D\uDD0D Query Data"})}];return(0,a.jsx)(n.Layout,{style:{height:"100vh"},children:(0,a.jsxs)(r,{children:[(0,a.jsx)(i,{style:{height:"100vh",position:"fixed"},children:(0,a.jsx)(n.v2,{style:{width:200,paddingTop:"2em",paddingBottom:"2em"},mode:"inline",selectedKeys:[s],items:o})}),(0,a.jsx)("div",{style:{paddingLeft:240,paddingTop:30,backgroundColor:"#f5f5f5"},children:t})]})})}},14373:function(e,s,t){"use strict";t.r(s);var a=t(28598),n=t(82684),r=t(31143),l=t(62429),i=t(79869);s.default=()=>{let[e,s]=(0,n.useState)(!0),[t]=i.l0.useForm(),[o,d]=(0,n.useState)([]),[c,h]=(0,n.useState)([]);return(0,n.useEffect)(async()=>{let e=await fetch("http://localhost:1235/instruct/get_glossary_golden_queries",{method:"POST",headers:{"Content-Type":"application/json"}}),a=await e.json();t.setFieldsValue({glossary:a.glossary}),h(a.golden_queries),s(!1)},[]),(0,a.jsxs)(a.Fragment,{children:[(0,a.jsx)(r.Z,{}),(0,a.jsxs)(l.Z,{id:"instruct-model",children:[(0,a.jsx)("h1",{children:"Instruct Model"}),(0,a.jsx)("h3",{children:"Update glossary"}),(0,a.jsxs)("div",{style:{maxWidth:800},children:[(0,a.jsxs)("div",{children:["The glossary is a list of instructions that the model will use to generate queries.",(0,a.jsx)("br",{}),(0,a.jsx)("br",{}),"An example of a glossary entry is:",(0,a.jsx)("br",{}),(0,a.jsx)("br",{}),(0,a.jsx)("span",{style:{fontFamily:"monospace"},children:"- When a user asks for APR time, they are asking for Average Payment Reimbursement. 
You can calculate this by dividing the total amount reimbursed by the number of reimbursement requests."}),"."]}),(0,a.jsxs)(i.l0,{name:"glossary",layout:"vertical",form:t,initialValues:{remember:!0},style:{paddingTop:"1em"},defaultValue:{glossary:o},loading:e,disabled:e,onFinish:async e=>{let s=await fetch("http://localhost:1235/instruct/update_glossary",{method:"POST",headers:{"Content-Type":"application/json"},body:JSON.stringify({glossary:e.glossary})}),t=await s.json();console.log(t)},children:[(0,a.jsx)(i.l0.Item,{label:"Glossary",name:"glossary",children:(0,a.jsx)(i.II.TextArea,{rows:5})}),(0,a.jsx)(i.l0.Item,{children:(0,a.jsx)(i.zx,{type:"primary",htmlType:"submit",children:"Update Glossary"})})]})]}),(0,a.jsx)("h3",{children:"Golden Queries"}),c.length>0?(0,a.jsxs)(i.X2,{style:{marginTop:"1em",position:"sticky",top:0,paddingBottom:"1em",paddingTop:"1em",backgroundColor:"white",zIndex:100},children:[(0,a.jsx)(i.JX,{xs:{span:24},md:{span:1},style:{overflowWrap:"break-word"},children:(0,a.jsx)("b",{children:" "})}),(0,a.jsx)(i.JX,{xs:{span:24},md:{span:6},style:{overflowWrap:"break-word"},children:(0,a.jsx)("b",{children:"Question"})}),(0,a.jsx)(i.JX,{xs:{span:24},md:{span:16},style:{overflowWrap:"break-word"},children:(0,a.jsx)("b",{children:"SQL"})}),(0,a.jsx)(i.JX,{xs:{span:24},md:{span:1},style:{overflowWrap:"break-word"},children:(0,a.jsx)("b",{children:"Good?"})})]}):null,c.length>0&&c.map((e,s)=>(0,a.jsxs)(i.X2,{style:{marginTop:"1em"},gutter:{xs:8,sm:16,md:24,lg:32},children:[(0,a.jsx)(i.JX,{xs:{span:24},md:{span:1},style:{overflowWrap:"break-word"},children:(0,a.jsx)(i.zx,{ghost:!0,onClick:()=>{let e=[...c];e.splice(s,1),h(e)},children:"⛔"})}),(0,a.jsx)(i.JX,{xs:{span:24},md:{span:6},style:{overflowWrap:"break-word"},children:(0,a.jsx)(i.II.TextArea,{placeholder:"Question",defaultValue:e.question,autoSize:{minRows:2},onChange:e=>{let t=[...c];t[s].question=e.target.value,h(t)}},s)}),(0,a.jsx)(i.JX,{xs:{span:24},md:{span:16},style:{overflowWrap:"break-word"},children:(0,a.jsx)(i.II.TextArea,{placeholder:"SQL",defaultValue:e.sql,autoSize:{minRows:2},onChange:e=>{let t=[...c];t[s].sql=e.target.value,h(t)}},s)}),(0,a.jsx)(i.JX,{xs:{span:24},md:{span:1},style:{overflowWrap:"break-word"},children:(0,a.jsx)(i.XZ,{checked:e.user_validated,onChange:e=>{let t=[...c];t[s].user_validated=e.target.checked,h(t)}})})]},s)),(0,a.jsxs)(i.X2,{children:[(0,a.jsx)(i.JX,{span:24,style:{marginTop:"1em"},children:(0,a.jsx)(i.zx,{type:"primary",onClick:()=>{let e=[...c];e.push({question:"",sql:"",user_validated:!1}),h(e)},children:"Add new row"})}),(0,a.jsx)(i.JX,{span:24,style:{marginTop:"1em"},children:(0,a.jsx)(i.zx,{type:"primary",onClick:async()=>{let e=await fetch("http://localhost:1235/instruct/update_golden_queries",{method:"POST",headers:{"Content-Type":"application/json"},body:JSON.stringify({golden_queries:c})}),s=await e.json();console.log(s)},children:"Update Golden Queries on Server"})})]})]})]})}}},function(e){e.O(0,[238,774,888,179],function(){return e(e.s=2252)}),_N_E=e.O()}]); -------------------------------------------------------------------------------- /defog/static/_next/static/chunks/pages/query-data-b8197e7950b177eb.js: -------------------------------------------------------------------------------- 1 | (self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[137],{81743:function(e,t,n){"use strict";Object.defineProperty(t,"__esModule",{value:!0}),function(e,t){for(var n in t)Object.defineProperty(e,n,{enumerable:!0,get:t[n]})}(t,{noSSR:function(){return 
noSSR},default:function(){return dynamic}});let a=n(43709),l=(n(82684),a._(n(36586)));function convertModule(e){return{default:(null==e?void 0:e.default)||e}}function noSSR(e,t){return delete t.webpack,delete t.modules,e(t)}function dynamic(e,t){let n=l.default,a={loading:e=>{let{error:t,isLoading:n,pastDelay:a}=e;return null}};e instanceof Promise?a.loader=()=>e:"function"==typeof e?a.loader=e:"object"==typeof e&&(a={...a,...e}),a={...a,...t};let r=a.loader;return(a.loadableGenerated&&(a={...a,...a.loadableGenerated},delete a.loadableGenerated),"boolean"!=typeof a.ssr||a.ssr)?n({...a,loader:()=>null!=r?r().then(convertModule):Promise.resolve(convertModule(()=>null))}):(delete a.webpack,delete a.modules,noSSR(n,a))}("function"==typeof t.default||"object"==typeof t.default&&null!==t.default)&&void 0===t.default.__esModule&&(Object.defineProperty(t.default,"__esModule",{value:!0}),Object.assign(t.default,t),e.exports=t.default)},80805:function(e,t,n){"use strict";Object.defineProperty(t,"__esModule",{value:!0}),Object.defineProperty(t,"LoadableContext",{enumerable:!0,get:function(){return r}});let a=n(43709),l=a._(n(82684)),r=l.default.createContext(null)},36586:function(e,t,n){"use strict";/** 2 | @copyright (c) 2017-present James Kyle 3 | MIT License 4 | Permission is hereby granted, free of charge, to any person obtaining 5 | a copy of this software and associated documentation files (the 6 | "Software"), to deal in the Software without restriction, including 7 | without limitation the rights to use, copy, modify, merge, publish, 8 | distribute, sublicense, and/or sell copies of the Software, and to 9 | permit persons to whom the Software is furnished to do so, subject to 10 | the following conditions: 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 16 | NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 17 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 18 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 19 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE 20 | */Object.defineProperty(t,"__esModule",{value:!0}),Object.defineProperty(t,"default",{enumerable:!0,get:function(){return u}});let a=n(43709),l=a._(n(82684)),r=n(80805),o=[],i=[],s=!1;function load(e){let t=e(),n={loading:!0,loaded:null,error:null};return n.promise=t.then(e=>(n.loading=!1,n.loaded=e,e)).catch(e=>{throw n.loading=!1,n.error=e,e}),n}let LoadableSubscription=class LoadableSubscription{promise(){return this._res.promise}retry(){this._clearTimeouts(),this._res=this._loadFn(this._opts.loader),this._state={pastDelay:!1,timedOut:!1};let{_res:e,_opts:t}=this;e.loading&&("number"==typeof t.delay&&(0===t.delay?this._state.pastDelay=!0:this._delay=setTimeout(()=>{this._update({pastDelay:!0})},t.delay)),"number"==typeof t.timeout&&(this._timeout=setTimeout(()=>{this._update({timedOut:!0})},t.timeout))),this._res.promise.then(()=>{this._update({}),this._clearTimeouts()}).catch(e=>{this._update({}),this._clearTimeouts()}),this._update({})}_update(e){this._state={...this._state,error:this._res.error,loaded:this._res.loaded,loading:this._res.loading,...e},this._callbacks.forEach(e=>e())}_clearTimeouts(){clearTimeout(this._delay),clearTimeout(this._timeout)}getCurrentValue(){return this._state}subscribe(e){return this._callbacks.add(e),()=>{this._callbacks.delete(e)}}constructor(e,t){this._loadFn=e,this._opts=t,this._callbacks=new Set,this._delay=null,this._timeout=null,this.retry()}};function Loadable(e){return function(e,t){let n=Object.assign({loader:null,loading:null,delay:200,timeout:null,webpack:null,modules:null},t),a=null;function init(){if(!a){let t=new LoadableSubscription(e,n);a={getCurrentValue:t.getCurrentValue.bind(t),subscribe:t.subscribe.bind(t),retry:t.retry.bind(t),promise:t.promise.bind(t)}}return a.promise()}if(!s){let e=n.webpack?n.webpack():n.modules;e&&i.push(t=>{for(let n of e)if(t.includes(n))return init()})}function LoadableComponent(e,t){!function(){init();let e=l.default.useContext(r.LoadableContext);e&&Array.isArray(n.modules)&&n.modules.forEach(t=>{e(t)})}();let o=l.default.useSyncExternalStore(a.subscribe,a.getCurrentValue,a.getCurrentValue);return l.default.useImperativeHandle(t,()=>({retry:a.retry}),[]),l.default.useMemo(()=>{var t;return o.loading||o.error?l.default.createElement(n.loading,{isLoading:o.loading,pastDelay:o.pastDelay,timedOut:o.timedOut,error:o.error,retry:a.retry}):o.loaded?l.default.createElement((t=o.loaded)&&t.default?t.default:t,e):null},[e,o])}return LoadableComponent.preload=()=>init(),LoadableComponent.displayName="LoadableComponent",l.default.forwardRef(LoadableComponent)}(load,e)}function flushInitializers(e,t){let n=[];for(;e.length;){let a=e.pop();n.push(a(t))}return Promise.all(n).then(()=>{if(e.length)return flushInitializers(e,t)})}Loadable.preloadAll=()=>new Promise((e,t)=>{flushInitializers(o).then(e,t)}),Loadable.preloadReady=e=>(void 0===e&&(e=[]),new Promise(t=>{let res=()=>(s=!0,t());flushInitializers(i,e).then(res,res)})),window.__NEXT_PRELOADREADY=Loadable.preloadReady;let u=Loadable},51774:function(e,t,n){e.exports=n(81743)},20328:function(e,t,n){(window.__NEXT_P=window.__NEXT_P||[]).push(["/query-data",function(){return n(81800)}])},31143:function(e,t,n){"use strict";var 
a=n(28598),l=n(1887),r=n.n(l);t.Z=()=>(0,a.jsxs)(r(),{children:[(0,a.jsx)("title",{children:"Defog.ai - AI Assistant for Data Analysis"}),(0,a.jsx)("meta",{name:"description",content:"Train your AI data assistant on your own device"}),(0,a.jsx)("meta",{name:"viewport",content:"width=device-width, initial-scale=1"}),(0,a.jsx)("link",{rel:"icon",href:"/favicon.ico"})]})},62429:function(e,t,n){"use strict";var a=n(28598);n(82684);var l=n(79869),r=n(12691),o=n.n(r);t.Z=e=>{let{id:t,children:n}=e,{Content:r,Sider:i}=l.Layout,s=[{key:"manage-database",title:"Manage Database",icon:(0,a.jsx)(o(),{href:"/extract-metadata",children:"\uD83D\uDCBE Manage DB"})},{key:"instruct-model",title:"Instruct Model",icon:(0,a.jsx)(o(),{href:"/instruct-model",children:"\uD83D\uDC68‍\uD83C\uDFEB Instruct Model"})},{key:"query-data",title:"Query Data",icon:(0,a.jsx)(o(),{href:"/query-data",children:"\uD83D\uDD0D Query Data"})}];return(0,a.jsx)(l.Layout,{style:{height:"100vh"},children:(0,a.jsxs)(r,{children:[(0,a.jsx)(i,{style:{height:"100vh",position:"fixed"},children:(0,a.jsx)(l.v2,{style:{width:200,paddingTop:"2em",paddingBottom:"2em"},mode:"inline",selectedKeys:[t],items:s})}),(0,a.jsx)("div",{style:{paddingLeft:240,paddingTop:30,backgroundColor:"#f5f5f5"},children:n})]})})}},81800:function(e,t,n){"use strict";n.r(t);var a=n(28598),l=n(82684),r=n(31143),o=n(62429),i=n(51774),s=n.n(i);let u=s()(()=>Promise.all([n.e(235),n.e(13),n.e(296),n.e(354),n.e(283)]).then(n.t.bind(n,14979,19)).then(e=>e.AskDefogChat),{loadableGenerated:{webpack:()=>[14979]},ssr:!1});t.default=()=>{let[e,t]=(0,l.useState)([]);return(0,l.useEffect)(()=>{fetch("http://localhost:1235/integration/get_tables_db_creds",{method:"POST",headers:{"Content-Type":"application/json"}}).then(e=>e.json()).then(e=>{e.selected_tables&&t(e.selected_tables)})},[]),(0,a.jsxs)(a.Fragment,{children:[(0,a.jsx)(r.Z,{}),(0,a.jsxs)(o.Z,{id:"query-data",children:[(0,a.jsx)("h1",{children:"Query your database"}),(0,a.jsxs)("p",{children:["Ask Defog questions about your data. You have selected the following tables: ",(0,a.jsx)("code",{children:e.join(", ")})]}),(0,a.jsx)(u,{maxWidth:"100%",height:"80vh",apiEndpoint:"http://localhost:1235/generate_query",buttonText:"Ask Defog",darkMode:!1,debugMode:!0})]})]})}}},function(e){e.O(0,[238,774,888,179],function(){return e(e.s=20328)}),_N_E=e.O()}]); -------------------------------------------------------------------------------- /defog/static/_next/static/chunks/webpack-1657be5a4830bbb9.js: -------------------------------------------------------------------------------- 1 | !function(){"use strict";var e,r,_,t,n,u,i,c,o,a={},p={};function __webpack_require__(e){var r=p[e];if(void 0!==r)return r.exports;var _=p[e]={exports:{}},t=!0;try{a[e].call(_.exports,_,_.exports,__webpack_require__),t=!1}finally{t&&delete p[e]}return _.exports}__webpack_require__.m=a,__webpack_require__.amdO={},e=[],__webpack_require__.O=function(r,_,t,n){if(_){n=n||0;for(var u=e.length;u>0&&e[u-1][2]>n;u--)e[u]=e[u-1];e[u]=[_,t,n];return}for(var i=1/0,u=0;u=n&&Object.keys(__webpack_require__.O).every(function(e){return __webpack_require__.O[e](_[o])})?_.splice(o--,1):(c=!1,nDefog.ai - AI Assistant for Data Analysis

Extract Metadata

-------------------------------------------------------------------------------- /defog/static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/defog-ai/defog-python/86b2141d875259ba0cd2128d3883cd2af60f0121/defog/static/favicon.ico -------------------------------------------------------------------------------- /defog/static/index.html: -------------------------------------------------------------------------------- 1 | Defog.ai - AI Assistant for Data Analysis

Welcome to Defog!

Please wait while we log you in and redirect you to the right page...

-------------------------------------------------------------------------------- /defog/static/instruct-model.html: -------------------------------------------------------------------------------- 1 | Defog.ai - AI Assistant for Data Analysis

Instruct Model

Update glossary

The glossary is a list of instructions that the model will use to generate queries.

An example of a glossary entry is:

- When a user asks for APR time, they are asking for Average Payment Reimbursement. You can calculate this by dividing the total amount reimbursed by the number of reimbursement requests.

Golden Queries

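The page above describes the glossary as free-text instructions the model uses when generating queries, alongside user-validated golden queries. A hedged sketch of pushing both through the corresponding endpoints defined in `defog/serve.py` (assuming, as with the frontend, that the app is running on port 1235):

```python
# Sketch only: assumes the defog/serve.py app is running on localhost:1235.
import requests

BASE = "http://localhost:1235"

glossary = (
    "- When a user asks for APR time, they are asking for Average Payment "
    "Reimbursement. You can calculate this by dividing the total amount "
    "reimbursed by the number of reimbursement requests."
)
requests.post(f"{BASE}/instruct/update_glossary", json={"glossary": glossary})

golden_queries = [
    {
        "question": "What is the total number of orders?",
        "sql": "SELECT COUNT(*) FROM orders;",
        # serve.py drops entries that are empty or not user-validated
        "user_validated": True,
    }
]
requests.post(
    f"{BASE}/instruct/update_golden_queries", json={"golden_queries": golden_queries}
)
```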
-------------------------------------------------------------------------------- /defog/static/next.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /defog/static/query-data.html: -------------------------------------------------------------------------------- 1 | Defog.ai - AI Assistant for Data Analysis

Query your database

Ask Defog questions about your data. You have selected the following tables:

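The "Query Data" page is a thin client over the library call shown earlier: `/generate_query` in `defog/serve.py` forwards the question to `Defog().run_query()`. A minimal sketch of the same call made directly from Python, assuming an API key and database credentials have already been saved so that `Defog()` can be constructed without arguments (exactly what `serve.py` relies on):

```python
# Sketch only: assumes Defog() can load previously saved credentials,
# as defog/serve.py does when handling /generate_query.
from defog import Defog

defog = Defog()
resp = defog.run_query("What is the total number of orders?")

if resp["ran_successfully"]:
    print(resp["query_generated"])  # SQL that was executed
    print(resp["columns"])          # column names
    print(resp["data"][:5])         # first few result rows
else:
    print("Query failed:", resp["error_message"])
```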
-------------------------------------------------------------------------------- /defog/static/vercel.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | prompt_toolkit 3 | psycopg2-binary>=2.9.5 4 | asyncpg 5 | aiomysql 6 | aioodbc 7 | pwinput 8 | requests>=2.28.2 9 | aiohttp 10 | tabulate 11 | uvicorn 12 | tqdm 13 | setuptools 14 | pydantic 15 | anthropic==0.52.2 16 | google-genai==1.16.1 17 | openai==1.84.0 18 | together==1.3.11 19 | tiktoken==0.9.0 20 | pytest 21 | pytest-asyncio 22 | mcp 23 | rich -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Inside of setup.cfg 2 | [metadata] 3 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages, setup 3 | 4 | extras = { 5 | "postgres": ["psycopg2-binary"], 6 | "mysql": ["mysql-connector-python"], 7 | "snowflake": ["snowflake-connector-python"], 8 | "bigquery": ["google-cloud-bigquery"], 9 | "redshift": ["psycopg2-binary"], 10 | "databricks": ["databricks-sql-connector"], 11 | "sqlserver": ["pyodbc"], 12 | } 13 | 14 | 15 | def package_files(directory): 16 | paths = [] 17 | for path, directories, filenames in os.walk(directory): 18 | for filename in filenames: 19 | paths.append(os.path.join("..", path, filename)) 20 | return paths 21 | 22 | 23 | next_static_files = package_files("defog/static") 24 | 25 | setup( 26 | name="defog", 27 | packages=find_packages(), 28 | package_data={"defog": next_static_files}, 29 | version="0.71.0", 30 | description="Defog is a Python library that helps you generate data queries from natural language questions.", 31 | author="Full Stack Data Pte. 
Ltd.", 32 | license="MIT", 33 | install_requires=[ 34 | "requests>=2.28.2", 35 | "psycopg2-binary>=2.9.5", 36 | "prompt-toolkit>=3.0.38", 37 | "fastapi", 38 | "uvicorn", 39 | "tqdm", 40 | "pwinput", 41 | "aiohttp", 42 | "pydantic", 43 | "tabulate", 44 | ], 45 | entry_points={ 46 | "console_scripts": [ 47 | "defog=defog.cli:main", 48 | ], 49 | }, 50 | author_email="founders@defog.ai", 51 | url="https://github.com/defog-ai/defog-python", 52 | long_description="Defog is a Python library that helps you generate data queries from natural language questions.", 53 | long_description_content_type="text/markdown", 54 | extras_require=extras, 55 | ) 56 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/defog-ai/defog-python/86b2141d875259ba0cd2128d3883cd2af60f0121/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_citations.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pytest 3 | import os 4 | from defog.llm.citations import citations_tool 5 | from defog.llm.llm_providers import LLMProvider 6 | 7 | 8 | class TestCitations(unittest.IsolatedAsyncioTestCase): 9 | def setUp(self): 10 | # Skip tests if API keys are not available 11 | self.skip_openai = not os.getenv("OPENAI_API_KEY") 12 | self.skip_anthropic = not os.getenv("ANTHROPIC_API_KEY") 13 | 14 | @pytest.mark.asyncio 15 | async def test_simple_anthropic_citations(self): 16 | if self.skip_anthropic: 17 | self.skipTest("ANTHROPIC_API_KEY not set") 18 | 19 | question = "What is Rishabh's favourite food?" 20 | instructions = "Answer the question with high quality citations. If you don't know the answer, say 'I don't know'." 21 | documents = [ 22 | { 23 | "document_name": "Rishabh's favourite food.txt", 24 | "document_content": "Rishabh's favourite food is pizza.", 25 | }, 26 | { 27 | "document_name": "Medha's favourite food.txt", 28 | "document_content": "Medha's favourite food is pineapple.", 29 | }, 30 | ] 31 | 32 | response = await citations_tool( 33 | question, 34 | instructions, 35 | documents, 36 | "claude-3-7-sonnet-latest", 37 | LLMProvider.ANTHROPIC, 38 | ) 39 | 40 | # Check that response is a list of blocks 41 | self.assertIsInstance(response, list) 42 | self.assertGreater(len(response), 0) 43 | 44 | # Check that at least one block contains text mentioning pizza 45 | found_pizza = False 46 | for block in response: 47 | if block.get("type") == "text" and "pizza" in block.get("text", "").lower(): 48 | found_pizza = True 49 | break 50 | self.assertTrue( 51 | found_pizza, "Response should mention pizza as Rishabh's favourite food" 52 | ) 53 | 54 | @pytest.mark.asyncio 55 | async def test_openai_citations(self): 56 | if self.skip_openai: 57 | self.skipTest("OPENAI_API_KEY not set") 58 | 59 | question = "What are the main benefits of renewable energy?" 60 | instructions = "Provide a detailed answer with proper citations from the provided documents." 61 | documents = [ 62 | { 63 | "document_name": "Solar Energy Benefits.txt", 64 | "document_content": """Solar energy offers numerous environmental and economic benefits. It produces clean electricity 65 | without greenhouse gas emissions during operation, helping combat climate change. Solar installations can reduce 66 | electricity bills and provide energy independence. 
The technology has become increasingly cost-effective with 67 | falling panel prices and improved efficiency.""", 68 | }, 69 | { 70 | "document_name": "Wind Power Advantages.txt", 71 | "document_content": """Wind power is one of the fastest-growing renewable energy sources globally. It generates 72 | electricity without air pollution or water consumption during operation. Wind farms can be built on land or 73 | offshore, providing flexibility in deployment. The technology creates jobs in manufacturing, installation, 74 | and maintenance sectors.""", 75 | }, 76 | { 77 | "document_name": "Renewable Energy Economics.txt", 78 | "document_content": """Renewable energy sources have become increasingly competitive with fossil fuels in terms of cost. 79 | The levelized cost of electricity from renewables has decreased significantly over the past decade. Investment 80 | in renewable energy infrastructure stimulates economic growth and reduces dependence on volatile fossil fuel markets.""", 81 | }, 82 | ] 83 | 84 | response = await citations_tool( 85 | question, instructions, documents, "gpt-4o", LLMProvider.OPENAI 86 | ) 87 | 88 | # Check response structure 89 | self.assertIsInstance(response, list) 90 | self.assertGreater(len(response), 0) 91 | 92 | # Check that all blocks have required structure 93 | for block in response: 94 | self.assertIn("type", block) 95 | self.assertEqual(block["type"], "text") 96 | self.assertIn("text", block) 97 | self.assertIn("citations", block) 98 | self.assertIsInstance(block["citations"], list) 99 | 100 | @pytest.mark.asyncio 101 | async def test_no_relevant_documents_anthropic(self): 102 | if self.skip_anthropic: 103 | self.skipTest("ANTHROPIC_API_KEY not set") 104 | 105 | question = "What is the capital of Mars?" 106 | instructions = "Answer the question based only on the provided documents. If the information is not available, say 'I don't know'." 107 | documents = [ 108 | { 109 | "document_name": "Earth Geography.txt", 110 | "document_content": "London is the capital of England. Paris is the capital of France. Berlin is the capital of Germany.", 111 | }, 112 | { 113 | "document_name": "Ocean Facts.txt", 114 | "document_content": "The Pacific Ocean is the largest ocean on Earth. 
The Atlantic Ocean separates Europe and Africa from the Americas.", 115 | }, 116 | ] 117 | 118 | response = await citations_tool( 119 | question, 120 | instructions, 121 | documents, 122 | "claude-3-7-sonnet-latest", 123 | LLMProvider.ANTHROPIC, 124 | ) 125 | 126 | # Check that response indicates lack of information 127 | response_text = " ".join( 128 | [block.get("text", "") for block in response if block.get("type") == "text"] 129 | ).lower() 130 | self.assertTrue( 131 | "don't know" in response_text 132 | or "not" in response_text 133 | or "no information" in response_text, 134 | "Response should indicate lack of information about Mars capital", 135 | ) 136 | 137 | def test_unsupported_provider(self): 138 | question = "Test question" 139 | instructions = "Test instructions" 140 | documents = [{"document_name": "test.txt", "document_content": "test content"}] 141 | 142 | with self.assertRaises(ValueError) as context: 143 | # Using asyncio.run since this should fail immediately 144 | import asyncio 145 | 146 | asyncio.run( 147 | citations_tool( 148 | question, 149 | instructions, 150 | documents, 151 | "test-model", 152 | "unsupported_provider", 153 | ) 154 | ) 155 | 156 | self.assertIn("not supported", str(context.exception)) 157 | 158 | 159 | if __name__ == "__main__": 160 | unittest.main() 161 | -------------------------------------------------------------------------------- /tests/test_code_interp.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | import unittest 6 | import pytest 7 | import os 8 | from defog.llm.code_interp import code_interpreter_tool 9 | from defog.llm.llm_providers import LLMProvider 10 | 11 | 12 | class TestCodeInterp(unittest.IsolatedAsyncioTestCase): 13 | def setUp(self): 14 | """Set up test fixtures with sample CSV data for all tests.""" 15 | self.sample_csv = """name,age,salary,department 16 | John Doe,30,50000,Engineering 17 | Jane Smith,25,45000,Marketing 18 | Bob Johnson,35,60000,Engineering 19 | Alice Brown,28,55000,Sales 20 | Charlie Wilson,32,48000,Marketing""" 21 | 22 | self.complex_question = "What is the average salary by department? Show the results in a table format." 
23 | 24 | @pytest.mark.asyncio 25 | async def test_openai_complex_analysis(self): 26 | """Test OpenAI provider with complex aggregation question.""" 27 | if not os.getenv("OPENAI_API_KEY"): 28 | self.skipTest("OPENAI_API_KEY not set") 29 | 30 | response = await code_interpreter_tool( 31 | question=self.complex_question, 32 | model="gpt-4o", 33 | provider=LLMProvider.OPENAI, 34 | csv_string=self.sample_csv, 35 | ) 36 | 37 | self.assertIsInstance(response, dict) 38 | self.assertIn("code", response) 39 | self.assertIn("output", response) 40 | self.assertGreater(len(response["output"]), 0) 41 | # Should contain department names from our data in output 42 | output_text = response["output"] 43 | self.assertIn("Engineering", output_text) 44 | self.assertIn("Marketing", output_text) 45 | self.assertIn("Sales", output_text) 46 | 47 | @pytest.mark.asyncio 48 | async def test_anthropic_complex_analysis(self): 49 | """Test Anthropic provider with complex aggregation question.""" 50 | if not os.getenv("ANTHROPIC_API_KEY"): 51 | self.skipTest("ANTHROPIC_API_KEY not set") 52 | 53 | response = await code_interpreter_tool( 54 | question=self.complex_question, 55 | model="claude-3-7-sonnet-latest", 56 | provider=LLMProvider.ANTHROPIC, 57 | csv_string=self.sample_csv, 58 | ) 59 | 60 | self.assertIsInstance(response, dict) 61 | self.assertIn("code", response) 62 | self.assertIn("output", response) 63 | self.assertGreater(len(response["output"]), 0) 64 | # Should contain department names from our data in output 65 | output_text = response["output"] 66 | self.assertIn("Engineering", output_text) 67 | self.assertIn("Marketing", output_text) 68 | self.assertIn("Sales", output_text) 69 | 70 | @pytest.mark.asyncio 71 | async def test_gemini_complex_analysis(self): 72 | """Test Gemini provider with complex aggregation question.""" 73 | if not os.getenv("GEMINI_API_KEY"): 74 | self.skipTest("GEMINI_API_KEY not set") 75 | 76 | response = await code_interpreter_tool( 77 | question=self.complex_question, 78 | model="gemini-2.0-flash", 79 | provider=LLMProvider.GEMINI, 80 | csv_string=self.sample_csv, 81 | ) 82 | 83 | self.assertIsInstance(response, dict) 84 | self.assertIn("code", response) 85 | self.assertIn("output", response) 86 | # Should have generated some code and output 87 | self.assertGreater(len(response["code"] + response["output"]), 0) 88 | 89 | @pytest.mark.asyncio 90 | async def test_large_dataset_analysis(self): 91 | """Test analysis with larger dataset.""" 92 | if not os.getenv("OPENAI_API_KEY"): 93 | self.skipTest("OPENAI_API_KEY not set") 94 | 95 | # Generate larger CSV with 100 rows 96 | large_csv_header = "id,name,age,salary,department,years_experience\n" 97 | large_csv_rows = [] 98 | departments = ["Engineering", "Marketing", "Sales", "HR", "Finance"] 99 | 100 | for i in range(100): 101 | dept = departments[i % len(departments)] 102 | age = 22 + (i % 40) 103 | salary = 40000 + (i * 500) 104 | experience = max(0, age - 22) 105 | large_csv_rows.append( 106 | f"{i+1},Employee{i+1},{age},{salary},{dept},{experience}" 107 | ) 108 | 109 | large_csv = large_csv_header + "\n".join(large_csv_rows) 110 | 111 | response = await code_interpreter_tool( 112 | question="What are the key insights about salary distribution across departments? 
Include summary statistics.", 113 | model="gpt-4o", 114 | provider=LLMProvider.OPENAI, 115 | csv_string=large_csv, 116 | ) 117 | 118 | self.assertIsInstance(response, dict) 119 | self.assertIn("code", response) 120 | self.assertIn("output", response) 121 | self.assertGreater(len(response["output"]), 0) 122 | # Should contain department names and statistical terms in output 123 | output_lower = response["output"].lower() 124 | self.assertTrue(any(dept.lower() in output_lower for dept in departments)) 125 | self.assertTrue( 126 | any( 127 | word in output_lower 128 | for word in ["mean", "median", "average", "distribution"] 129 | ) 130 | ) 131 | 132 | @pytest.mark.asyncio 133 | async def test_mathematical_calculations(self): 134 | """Test mathematical calculations and formulas.""" 135 | if not os.getenv("ANTHROPIC_API_KEY"): 136 | self.skipTest("ANTHROPIC_API_KEY not set") 137 | 138 | math_csv = """product,price,quantity,discount_rate 139 | Laptop,1000,5,0.1 140 | Mouse,25,100,0.05 141 | Keyboard,75,50,0.15 142 | Monitor,300,20,0.08""" 143 | 144 | response = await code_interpreter_tool( 145 | question="Calculate the total revenue after applying discounts for each product, and find which product generates the most revenue.", 146 | model="claude-3-7-sonnet-latest", 147 | provider=LLMProvider.ANTHROPIC, 148 | csv_string=math_csv, 149 | ) 150 | 151 | self.assertIsInstance(response, dict) 152 | self.assertIn("code", response) 153 | self.assertIn("output", response) 154 | self.assertGreater(len(response["output"]), 0) 155 | # Should contain product names and revenue calculations in output 156 | output_text = response["output"] 157 | self.assertIn("Laptop", output_text) 158 | self.assertTrue(any(char.isdigit() for char in output_text)) 159 | 160 | @pytest.mark.asyncio 161 | async def test_time_series_data(self): 162 | """Test analysis with time-based data.""" 163 | if not os.getenv("OPENAI_API_KEY"): 164 | self.skipTest("OPENAI_API_KEY not set") 165 | 166 | time_series_csv = """date,sales,customers 167 | 2024-01-01,1000,50 168 | 2024-01-02,1200,60 169 | 2024-01-03,800,40 170 | 2024-01-04,1500,75 171 | 2024-01-05,1100,55 172 | 2024-01-06,1300,65 173 | 2024-01-07,900,45""" 174 | 175 | response = await code_interpreter_tool( 176 | question="Analyze the sales trend over time and calculate the daily growth rate.", 177 | model="gpt-4o", 178 | provider=LLMProvider.OPENAI, 179 | csv_string=time_series_csv, 180 | ) 181 | 182 | self.assertIsInstance(response, dict) 183 | self.assertIn("code", response) 184 | self.assertIn("output", response) 185 | self.assertGreater(len(response["output"]), 0) 186 | # Should contain time-related analysis terms in output 187 | output_lower = response["output"].lower() 188 | self.assertTrue( 189 | any(word in output_lower for word in ["trend", "growth", "time", "daily"]) 190 | ) 191 | 192 | 193 | if __name__ == "__main__": 194 | unittest.main() 195 | -------------------------------------------------------------------------------- /tests/test_llm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pytest 3 | from defog.llm.utils import ( 4 | map_model_to_provider, 5 | chat_async, 6 | ) 7 | from defog.llm.llm_providers import LLMProvider 8 | import re 9 | 10 | from pydantic import BaseModel 11 | 12 | messages_sql = [ 13 | { 14 | "role": "system", 15 | "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases. 
Return only the SQL without ```.", 16 | }, 17 | { 18 | "role": "user", 19 | "content": f"""Question: What is the total number of orders? 20 | Schema: 21 | ```sql 22 | CREATE TABLE orders ( 23 | order_id int, 24 | customer_id int, 25 | employee_id int, 26 | order_date date 27 | ); 28 | ``` 29 | """, 30 | }, 31 | ] 32 | 33 | acceptable_sql = [ 34 | "select count(*) from orders", 35 | "select count(order_id) from orders", 36 | "select count(*) as total_orders from orders", 37 | "select count(order_id) as total_orders from orders", 38 | ] 39 | 40 | 41 | class ResponseFormat(BaseModel): 42 | reasoning: str 43 | sql: str 44 | 45 | 46 | messages_sql_structured = [ 47 | { 48 | "role": "system", 49 | "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases.", 50 | }, 51 | { 52 | "role": "user", 53 | "content": f"""Question: What is the total number of orders? 54 | Schema: 55 | ```sql 56 | CREATE TABLE orders ( 57 | order_id int, 58 | customer_id int, 59 | employee_id int, 60 | order_date date 61 | ); 62 | ``` 63 | """, 64 | }, 65 | ] 66 | 67 | 68 | class TestChatClients(unittest.IsolatedAsyncioTestCase): 69 | def check_sql(self, sql: str): 70 | sql = sql.replace("```sql", "").replace("```", "").strip(";\n").lower() 71 | sql = re.sub(r"(\s+)", " ", sql) 72 | self.assertIn(sql, acceptable_sql) 73 | 74 | def test_map_model_to_provider(self): 75 | self.assertEqual( 76 | map_model_to_provider("claude-3-5-sonnet-20241022"), 77 | LLMProvider.ANTHROPIC, 78 | ) 79 | 80 | self.assertEqual( 81 | map_model_to_provider("gemini-1.5-flash-002"), 82 | LLMProvider.GEMINI, 83 | ) 84 | 85 | self.assertEqual(map_model_to_provider("gpt-4o-mini"), LLMProvider.OPENAI) 86 | 87 | self.assertEqual(map_model_to_provider("deepseek-chat"), LLMProvider.DEEPSEEK) 88 | 89 | self.assertEqual( 90 | map_model_to_provider("meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"), 91 | LLMProvider.TOGETHER, 92 | ) 93 | 94 | with self.assertRaises(Exception): 95 | map_model_to_provider("unknown-model") 96 | 97 | @pytest.mark.asyncio(loop_scope="session") 98 | async def test_simple_chat_async(self): 99 | models = [ 100 | "claude-3-7-sonnet-latest", 101 | "gpt-4.1-mini", 102 | "o4-mini", 103 | "o3", 104 | "gemini-2.0-flash", 105 | "gemini-2.5-pro-preview-03-25", 106 | ] 107 | messages = [ 108 | {"role": "user", "content": "Return a greeting in not more than 2 words\n"} 109 | ] 110 | for model in models: 111 | provider = map_model_to_provider(model) 112 | response = await chat_async( 113 | provider=provider, 114 | model=model, 115 | messages=messages, 116 | max_completion_tokens=4000, 117 | temperature=0.0, 118 | seed=0, 119 | max_retries=1, 120 | ) 121 | self.assertIsInstance(response.content, str) 122 | self.assertIsInstance(response.time, float) 123 | 124 | @pytest.mark.asyncio(loop_scope="session") 125 | async def test_sql_chat_async(self): 126 | models = [ 127 | "gpt-4o-mini", 128 | "o1", 129 | "gemini-2.0-flash", 130 | "gemini-2.5-pro-preview-03-25", 131 | "o3", 132 | "o4-mini", 133 | "gpt-4.1-mini", 134 | "gpt-4.1-nano", 135 | ] 136 | for model in models: 137 | provider = map_model_to_provider(model) 138 | response = await chat_async( 139 | provider=provider, 140 | model=model, 141 | messages=messages_sql, 142 | max_completion_tokens=4000, 143 | temperature=0.0, 144 | seed=0, 145 | max_retries=1, 146 | ) 147 | self.check_sql(response.content) 148 | self.assertIsInstance(response.time, float) 149 | 150 | @pytest.mark.asyncio(loop_scope="session") 151 | async def 
test_sql_chat_structured_reasoning_effort_async(self): 152 | reasoning_effort = ["low", "medium", "high", None] 153 | for effort in reasoning_effort: 154 | for model in ["o4-mini", "claude-3-7-sonnet-latest"]: 155 | provider = map_model_to_provider(model) 156 | response = await chat_async( 157 | provider=provider, 158 | model=model, 159 | messages=messages_sql_structured, 160 | max_completion_tokens=32000, 161 | temperature=0.0, 162 | seed=0, 163 | response_format=ResponseFormat, 164 | reasoning_effort=effort, 165 | max_retries=1, 166 | ) 167 | self.check_sql(response.content.sql) 168 | self.assertIsInstance(response.content.reasoning, str) 169 | 170 | @pytest.mark.asyncio(loop_scope="session") 171 | async def test_sql_chat_structured_async(self): 172 | models = [ 173 | "gpt-4o", 174 | "o1", 175 | "gemini-2.0-flash", 176 | "gemini-2.5-pro-preview-03-25", 177 | "claude-3-7-sonnet-latest", # Added Anthropic model to test structured output 178 | "o3", 179 | "o4-mini", 180 | "gpt-4.1-mini", 181 | "gpt-4.1-nano", 182 | ] 183 | for model in models: 184 | provider = map_model_to_provider(model) 185 | response = await chat_async( 186 | provider=provider, 187 | model=model, 188 | messages=messages_sql_structured, 189 | max_completion_tokens=4000, 190 | temperature=0.0, 191 | seed=0, 192 | response_format=ResponseFormat, 193 | max_retries=1, 194 | ) 195 | self.check_sql(response.content.sql) 196 | self.assertIsInstance(response.content.reasoning, str) 197 | 198 | 199 | if __name__ == "__main__": 200 | unittest.main() 201 | -------------------------------------------------------------------------------- /tests/test_llm_response.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from defog.llm.utils import LLMResponse 3 | from defog.llm.cost.calculator import CostCalculator 4 | 5 | 6 | class TestLLMResponse(unittest.TestCase): 7 | def test_cost_calculator(self): 8 | # Test CostCalculator directly 9 | self.assertAlmostEqual( 10 | CostCalculator.calculate_cost("gpt-4o", 1000, 1000, 500), 11 | (0.0025 * 1 + 0.00125 * 0.5 + 0.01 * 1) * 100, 12 | places=10, 13 | ) 14 | 15 | self.assertAlmostEqual( 16 | CostCalculator.calculate_cost("claude-3-5-sonnet", 1000, 1000), 17 | (0.003 * 1 + 0.015 * 1) * 100, 18 | places=10, 19 | ) 20 | 21 | # Test unsupported model 22 | self.assertIsNone(CostCalculator.calculate_cost("unknown-model", 1000, 1000)) 23 | 24 | def test_token_costs(self): 25 | test_cases = [ 26 | { 27 | "model_name": "gpt-4o", 28 | "input_tokens": 1000, 29 | "cached_input_tokens": 500, 30 | "output_tokens": 1000, 31 | }, 32 | { 33 | "model_name": "gpt-4o-mini", 34 | "input_tokens": 1000, 35 | "cached_input_tokens": 500, 36 | "output_tokens": 1000, 37 | }, 38 | { 39 | "model_name": "o1", 40 | "input_tokens": 1000, 41 | "cached_input_tokens": 500, 42 | "output_tokens": 1000, 43 | }, 44 | { 45 | "model_name": "o1-mini", 46 | "input_tokens": 1000, 47 | "cached_input_tokens": 500, 48 | "output_tokens": 1000, 49 | }, 50 | { 51 | "model_name": "o3-mini", 52 | "input_tokens": 1000, 53 | "cached_input_tokens": 500, 54 | "output_tokens": 1000, 55 | }, 56 | { 57 | "model_name": "deepseek-chat", 58 | "input_tokens": 1000, 59 | "cached_input_tokens": 500, 60 | "output_tokens": 1000, 61 | }, 62 | { 63 | "model_name": "deepseek-reasoner", 64 | "input_tokens": 1000, 65 | "cached_input_tokens": 500, 66 | "output_tokens": 1000, 67 | }, 68 | { 69 | "model_name": "claude-3-5-sonnet", 70 | "input_tokens": 1000, 71 | "cached_input_tokens": 0, 72 | "output_tokens": 
1000, 73 | }, 74 | ] 75 | 76 | for case in test_cases: 77 | with self.subTest(case=case): 78 | # Calculate expected cost using CostCalculator 79 | expected_cost_in_cents = CostCalculator.calculate_cost( 80 | model=case["model_name"], 81 | input_tokens=case["input_tokens"], 82 | output_tokens=case["output_tokens"], 83 | cached_input_tokens=case["cached_input_tokens"], 84 | ) 85 | 86 | # Create LLMResponse and check if cost is calculated correctly 87 | response = LLMResponse( 88 | content="", 89 | model=case["model_name"], 90 | time=0.0, 91 | input_tokens=case["input_tokens"], 92 | output_tokens=case["output_tokens"], 93 | cached_input_tokens=case["cached_input_tokens"], 94 | cost_in_cents=expected_cost_in_cents, 95 | ) 96 | 97 | self.assertAlmostEqual( 98 | response.cost_in_cents, expected_cost_in_cents, places=10 99 | ) 100 | 101 | 102 | if __name__ == "__main__": 103 | unittest.main() 104 | -------------------------------------------------------------------------------- /tests/test_util.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import Mock, MagicMock, patch 3 | from defog.util import identify_categorical_columns, parse_update, get_feedback 4 | 5 | 6 | class TestIdentifyCategoricalColumns(unittest.TestCase): 7 | def setUp(self): 8 | self.cur = Mock() 9 | self.cur.execute = MagicMock() 10 | self.cur.fetchone = MagicMock() 11 | self.cur.fetchall = MagicMock() 12 | 13 | def test_identify_categorical_columns_succeed(self): 14 | # Mock 2 distinct values each with their respective occurence counts of 3, 30 15 | self.cur.fetchone.return_value = (2,) 16 | self.cur.fetchall.return_value = [("value1", 3), ("value2", 30)] 17 | rows = [ 18 | {"column_name": "test_column", "data_type": "varchar"}, 19 | {"column_name": "test_column_int", "data_type": "bigint"}, 20 | ] 21 | 22 | # Call the function 23 | result = identify_categorical_columns(self.cur, "test_table", rows) 24 | 25 | # Assert the results 26 | self.assertEqual(len(result), len(rows)) 27 | self.assertEqual(result[0]["top_values"], "value1,value2") 28 | 29 | def test_identify_categorical_columns_exceed_threshold(self): 30 | # Mock 20 distinct values each with their respective occurence counts of 1 31 | self.cur.fetchone.return_value = (20,) 32 | self.cur.fetchall.return_value = [(f"value{i}", 1) for i in range(20)] 33 | rows = [{"column_name": "test_column", "data_type": "varchar"}] 34 | 35 | # Call the function 36 | result = identify_categorical_columns( 37 | self.cur, "test_table", rows, distinct_threshold=10 38 | ) 39 | 40 | # Assert results is still the same as rows (not modified) 41 | self.assertEqual(result, rows) 42 | self.assertIn("column_name", result[0]) 43 | self.assertIn("data_type", result[0]) 44 | self.assertNotIn("top_values", result[0]) 45 | 46 | def test_identify_categorical_columns_within_modified_threshold(self): 47 | # Mock 20 distinct values each with their respective occurence counts of 1 48 | self.cur.fetchone.return_value = (20,) 49 | self.cur.fetchall.return_value = [(f"value{i}", 1) for i in range(20)] 50 | rows = [{"column_name": "test_column", "data_type": "varchar"}] 51 | 52 | # Call the function 53 | result = identify_categorical_columns( 54 | self.cur, "test_table", rows, distinct_threshold=20 55 | ) 56 | 57 | # Assert that we get 20 distinct values as required 58 | self.assertEqual(len(result), 1) 59 | self.assertEqual( 60 | result[0]["top_values"], 61 | 
"value0,value1,value10,value11,value12,value13,value14,value15,value16,value17,value18,value19,value2,value3,value4,value5,value6,value7,value8,value9", 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | unittest.main() 67 | 68 | 69 | class TestGetFeedback(unittest.TestCase): 70 | @patch("defog.util.prompt", return_value="y") 71 | @patch("requests.post") 72 | def test_positive_feedback(self, mock_post, mock_prompt): 73 | get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") 74 | assert mock_post.call_count == 1 75 | self.assertIn("good", mock_post.call_args.kwargs["json"]["feedback"]) 76 | self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"]) 77 | 78 | @patch("defog.util.prompt", side_effect=["n", "bad query", "", "", ""]) 79 | @patch("requests.post") 80 | def test_negative_feedback_with_text(self, mock_post, mock_prompt): 81 | get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") 82 | # 2 calls: 1 to /feedback, 1 to /reflect_on_error 83 | assert mock_post.call_count == 2 84 | self.assertIn("api_key", mock_post.call_args.kwargs["json"]["api_key"]) 85 | self.assertIn("user_question", mock_post.call_args.kwargs["json"]["question"]) 86 | self.assertIn( 87 | "sql_generated", mock_post.call_args.kwargs["json"]["sql_generated"] 88 | ) 89 | self.assertIn("bad query", mock_post.call_args.kwargs["json"]["error"]) 90 | 91 | @patch("defog.util.prompt", side_effect=["n", "", "", "", ""]) 92 | @patch("requests.post") 93 | def test_negative_feedback_without_text(self, mock_post, mock_prompt): 94 | get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") 95 | # 2 calls: 1 to /feedback, 1 to /reflect_on_error 96 | assert mock_post.call_count == 2 97 | self.assertIn("api_key", mock_post.call_args.kwargs["json"]["api_key"]) 98 | self.assertIn("user_question", mock_post.call_args.kwargs["json"]["question"]) 99 | self.assertIn( 100 | "sql_generated", mock_post.call_args.kwargs["json"]["sql_generated"] 101 | ) 102 | self.assertIn("", mock_post.call_args.kwargs["json"]["error"]) 103 | 104 | @patch("defog.util.prompt", side_effect=["invalid", "y"]) 105 | @patch("requests.post") 106 | def test_invalid_then_valid_input(self, mock_post, mock_prompt): 107 | get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") 108 | assert mock_post.call_count == 1 109 | self.assertIn("good", mock_post.call_args.kwargs["json"]["feedback"]) 110 | self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"]) 111 | 112 | @patch("defog.util.prompt", side_effect=[""]) 113 | @patch("requests.post") 114 | def test_skip_input(self, mock_post, mock_prompt): 115 | get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") 116 | mock_post.assert_not_called() 117 | 118 | 119 | if __name__ == "__main__": 120 | unittest.main() 121 | -------------------------------------------------------------------------------- /tests/test_web_search.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pytest 3 | import os 4 | from defog.llm.web_search import web_search_tool 5 | from defog.llm.llm_providers import LLMProvider 6 | 7 | 8 | class TestWebSearchTool(unittest.IsolatedAsyncioTestCase): 9 | 10 | def setUp(self): 11 | self.test_question = "What is the capital of France?" 
12 |         self.max_tokens = 1024
13 | 
14 |     def _validate_basic_structure(self, result):
15 |         """Validate basic structure common to all providers"""
16 |         self.assertIsInstance(result, dict)
17 |         self.assertIn("usage", result)
18 |         self.assertIn("search_results", result)
19 |         self.assertIn("websites_cited", result)
20 | 
21 |     def _validate_usage_structure(self, usage, provider):
22 |         """Validate usage structure with provider-specific fields"""
23 |         self.assertIsInstance(usage, dict)
24 |         self.assertIn("input_tokens", usage)
25 |         self.assertIn("output_tokens", usage)
26 |         self.assertIsInstance(usage["input_tokens"], int)
27 |         self.assertIsInstance(usage["output_tokens"], int)
28 |         self.assertGreater(usage["input_tokens"], 0)
29 |         self.assertGreater(usage["output_tokens"], 0)
30 | 
31 |         if provider == LLMProvider.GEMINI:
32 |             self.assertIn("thinking_tokens", usage)
33 |             self.assertIsInstance(usage["thinking_tokens"], int)
34 | 
35 |     def _validate_citations(self, citations, provider):
36 |         """Validate citations structure with provider-specific fields"""
37 |         self.assertIsInstance(citations, list)
38 |         for citation in citations:
39 |             self.assertIsInstance(citation, dict)
40 |             if provider == LLMProvider.GEMINI:
41 |                 self.assertIn("source", citation)
42 |                 self.assertIn("url", citation)
43 |                 self.assertIsInstance(citation["source"], str)
44 |             else:
45 |                 self.assertIn("url", citation)
46 |                 self.assertIn("title", citation)
47 |                 self.assertIsInstance(citation["title"], str)
48 | 
49 |             self.assertIsInstance(citation["url"], str)
50 |             self.assertTrue(citation["url"].startswith(("http://", "https://")))
51 | 
52 |     def _validate_search_results(self, search_results, provider):
53 |         """Validate search results structure with provider-specific types"""
54 |         if provider == LLMProvider.ANTHROPIC:
55 |             self.assertIsInstance(search_results, list)
56 |             self.assertGreater(len(search_results), 0)
57 |         else:
58 |             self.assertIsInstance(search_results, str)
59 |             self.assertGreater(len(search_results), 0)
60 | 
61 |     async def _test_provider_structure(self, provider, model, api_key_env):
62 |         """Generic test for provider structure"""
63 |         if not os.getenv(api_key_env):
64 |             self.skipTest(f"{api_key_env} not set")
65 | 
66 |         result = await web_search_tool(
67 |             question=self.test_question,
68 |             model=model,
69 |             provider=provider,
70 |             max_tokens=self.max_tokens,
71 |         )
72 | 
73 |         self._validate_basic_structure(result)
74 |         self._validate_usage_structure(result["usage"], provider)
75 |         self._validate_search_results(result["search_results"], provider)
76 |         self._validate_citations(result["websites_cited"], provider)
77 | 
78 |         return result
79 | 
80 |     @pytest.mark.asyncio
81 |     async def test_web_search_openai_structure(self):
82 |         await self._test_provider_structure(
83 |             LLMProvider.OPENAI, "gpt-4.1-mini", "OPENAI_API_KEY"
84 |         )
85 | 
86 |     @pytest.mark.asyncio
87 |     async def test_web_search_anthropic_structure(self):
88 |         await self._test_provider_structure(
89 |             LLMProvider.ANTHROPIC, "claude-3-7-sonnet-latest", "ANTHROPIC_API_KEY"
90 |         )
91 | 
92 |     @pytest.mark.asyncio
93 |     async def test_web_search_gemini_structure(self):
94 |         await self._test_provider_structure(
95 |             LLMProvider.GEMINI, "gemini-2.0-flash", "GEMINI_API_KEY"
96 |         )
97 | 
98 |     @pytest.mark.asyncio
99 |     async def test_web_search_unsupported_provider(self):
100 |         with self.assertRaises(ValueError) as context:
101 |             await web_search_tool(
102 |                 question=self.test_question,
103 |                 model="test-model",
104 |                 provider=LLMProvider.GROK,
105 |             )
106 | 
107 |         self.assertIn("Provider LLMProvider.GROK not supported", str(context.exception))
108 | 
109 |     @pytest.mark.asyncio
110 |     async def test_web_search_different_questions(self):
111 |         questions = [
112 |             "What is machine learning?",
113 |             "Current weather in Tokyo",
114 |             "Latest news about artificial intelligence",
115 |         ]
116 | 
117 |         providers_config = [
118 |             (LLMProvider.OPENAI, "gpt-4.1-mini", "OPENAI_API_KEY"),
119 |             (LLMProvider.ANTHROPIC, "claude-3-7-sonnet-latest", "ANTHROPIC_API_KEY"),
120 |             (LLMProvider.GEMINI, "gemini-2.0-flash", "GEMINI_API_KEY"),
121 |         ]
122 | 
123 |         available_providers = [
124 |             (provider, model)
125 |             for provider, model, env_key in providers_config
126 |             if os.getenv(env_key)
127 |         ]
128 | 
129 |         if not available_providers:
130 |             self.skipTest("No API keys set for testing")
131 | 
132 |         for provider, model in available_providers:
133 |             for question in questions:
134 |                 with self.subTest(provider=provider.value, question=question):
135 |                     result = await web_search_tool(
136 |                         question=question,
137 |                         model=model,
138 |                         provider=provider,
139 |                         max_tokens=1024,
140 |                     )
141 | 
142 |                     self._validate_basic_structure(result)
143 |                     self._validate_search_results(result["search_results"], provider)
144 | 
145 |     @pytest.mark.asyncio
146 |     async def test_web_search_custom_max_tokens(self):
147 |         providers_config = [
148 |             (LLMProvider.OPENAI, "gpt-4.1-mini", "OPENAI_API_KEY"),
149 |             (LLMProvider.ANTHROPIC, "claude-3-7-sonnet-latest", "ANTHROPIC_API_KEY"),
150 |             (LLMProvider.GEMINI, "gemini-2.0-flash", "GEMINI_API_KEY"),
151 |         ]
152 | 
153 |         available_providers = [
154 |             (provider, model)
155 |             for provider, model, env_key in providers_config
156 |             if os.getenv(env_key)
157 |         ]
158 | 
159 |         if not available_providers:
160 |             self.skipTest("No API keys set for testing")
161 | 
162 |         for provider, model in available_providers:
163 |             with self.subTest(provider=provider.value):
164 |                 result = await web_search_tool(
165 |                     question="Brief summary of Python programming language",
166 |                     model=model,
167 |                     provider=provider,
168 |                     max_tokens=1024,
169 |                 )
170 | 
171 |                 self._validate_basic_structure(result)
172 | 
173 | 
174 | if __name__ == "__main__":
175 |     unittest.main()
176 | 
--------------------------------------------------------------------------------
/tests/test_youtube.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import pytest
4 | from defog.llm.youtube_transcript import get_transcript
5 | 
6 | 
7 | @pytest.mark.asyncio
8 | async def test_youtube_transcript_end_to_end():
9 |     """End-to-end test for YouTube transcript generation."""
10 |     # Skip test if GEMINI_API_KEY is not set
11 |     if not os.getenv("GEMINI_API_KEY"):
12 |         pytest.skip("GEMINI_API_KEY not set")
13 | 
14 |     # Use a short, public YouTube video for testing
15 |     video_url = "https://www.youtube.com/watch?v=EysJTNLQVZw"
16 | 
17 |     # Get transcript
18 |     transcript = await get_transcript(video_url)
19 | 
20 |     # Basic assertions
21 |     assert transcript is not None
22 |     assert isinstance(transcript, str)
23 |     assert len(transcript) > 0
24 | 
25 |     # Check that transcript contains some expected content
26 |     # (This will vary by video, but should contain some words)
27 |     assert len(transcript.split()) > 10
28 | 
29 |     print(f"Generated transcript ({len(transcript)} characters):")
30 |     print(transcript[:200] + "..." if len(transcript) > 200 else transcript)
31 | 
32 | 
33 | if __name__ == "__main__":
34 |     asyncio.run(test_youtube_transcript_end_to_end())
35 | 
--------------------------------------------------------------------------------