├── .gcloudignore ├── .github └── workflows │ └── main.yml ├── .gitignore ├── LICENSE ├── README.md ├── defog ├── __init__.py ├── admin_methods.py ├── async_admin_methods.py ├── async_generate_schema.py ├── async_health_methods.py ├── async_query_methods.py ├── cli.py ├── generate_schema.py ├── health_methods.py ├── llm │ ├── __init__.py │ ├── citations.py │ ├── code_interp.py │ ├── config │ │ ├── __init__.py │ │ ├── constants.py │ │ └── settings.py │ ├── cost │ │ ├── __init__.py │ │ ├── calculator.py │ │ └── models.py │ ├── exceptions │ │ ├── __init__.py │ │ └── llm_exceptions.py │ ├── llm_providers.py │ ├── mcp_readme.md │ ├── memory │ │ ├── __init__.py │ │ ├── compactifier.py │ │ ├── history_manager.py │ │ └── token_counter.py │ ├── models.py │ ├── orchestrator.py │ ├── providers │ │ ├── __init__.py │ │ ├── anthropic_provider.py │ │ ├── base.py │ │ ├── gemini_provider.py │ │ ├── openai_provider.py │ │ └── together_provider.py │ ├── tools │ │ ├── __init__.py │ │ └── handler.py │ ├── utils.py │ ├── utils_function_calling.py │ ├── utils_logging.py │ ├── utils_mcp.py │ ├── utils_memory.py │ ├── web_search.py │ └── youtube_transcript.py ├── query.py ├── query_methods.py ├── serve.py ├── static │ ├── 404.html │ ├── _next │ │ └── static │ │ │ ├── chunks │ │ │ ├── 0c428ae2.7f5ab17ef2110d84.js │ │ │ ├── 238-21e16f207d48d221.js │ │ │ ├── 283.503b46c22e64c702.js │ │ │ ├── 9b8d1757.9dd2c712a441d5f8.js │ │ │ ├── ea88be26.e203a40320569dfc.js │ │ │ ├── faae6fda.7ebb95670d232bb0.js │ │ │ ├── framework-02223fe42ab9321b.js │ │ │ ├── main-d30d248d262e39c4.js │ │ │ ├── pages │ │ │ │ ├── _app-db0976def6406e5e.js │ │ │ │ ├── _error-ee42a9921d95ff81.js │ │ │ │ ├── extract-metadata-2adf74ad3bcc8699.js │ │ │ │ ├── index-b60f249c1d54d3bf.js │ │ │ │ ├── instruct-model-d040be04cf7f21f2.js │ │ │ │ └── query-data-b8197e7950b177eb.js │ │ │ ├── polyfills-c67a75d1b6f99dc8.js │ │ │ └── webpack-1657be5a4830bbb9.js │ │ │ ├── css │ │ │ └── 321c398b2a784143.css │ │ │ └── lQfswjqOtyFH-kQc0_q8t │ │ │ ├── _buildManifest.js │ │ │ └── _ssgManifest.js │ ├── extract-metadata.html │ ├── favicon.ico │ ├── index.html │ ├── instruct-model.html │ ├── next.svg │ ├── query-data.html │ └── vercel.svg └── util.py ├── orchestrator_dynamic_example.py ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── test_async_defog.py ├── test_citations.py ├── test_code_interp.py ├── test_defog.py ├── test_llm.py ├── test_llm_response.py ├── test_llm_tool_calls.py ├── test_mcp_chat_async.py ├── test_memory.py ├── test_orchestrator_e2e.py ├── test_query.py ├── test_util.py ├── test_web_search.py └── test_youtube.py /.gcloudignore: -------------------------------------------------------------------------------- 1 | # This file specifies files that are *not* uploaded to Google Cloud 2 | # using gcloud. It follows the same syntax as .gitignore, with the addition of 3 | # "#!include" directives (which insert the entries of the given .gitignore-style 4 | # file at that point). 
5 | # 6 | # For more information, run: 7 | # $ gcloud topic gcloudignore 8 | # 9 | .gcloudignore 10 | # If you would like to upload your .git directory, .gitignore file or files 11 | # from your .gitignore file, remove the corresponding line 12 | # below: 13 | .git 14 | .gitignore 15 | 16 | node_modules 17 | #!include:.gitignore 18 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: [push] 4 | 5 | jobs: 6 | changes: 7 | runs-on: ubuntu-latest 8 | outputs: 9 | defog: ${{ steps.filter.outputs.defog }} 10 | tests: ${{ steps.filter.outputs.tests }} 11 | requirements: ${{ steps.filter.outputs.requirements }} 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: dorny/paths-filter@v2 15 | id: filter 16 | with: 17 | filters: | 18 | defog: 19 | - 'defog/**' 20 | tests: 21 | - 'tests/**' 22 | requirements: 23 | - 'requirements.txt' 24 | - 'setup.py' 25 | - 'setup.cfg' 26 | 27 | test: 28 | runs-on: ubuntu-latest 29 | needs: [changes] 30 | if: needs.changes.outputs.defog == 'true' || needs.changes.outputs.tests == 'true' || needs.changes.outputs.requirements == 'true' 31 | steps: 32 | - uses: actions/checkout@v2 33 | - name: Set up Python 34 | uses: actions/setup-python@v2 35 | with: 36 | python-version: '3.10' 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install -r requirements.txt 41 | - name: Run affected tests 42 | run: | 43 | if [ "${{ needs.changes.outputs.requirements }}" == "true" ]; then 44 | pytest tests 45 | elif [ "${{ needs.changes.outputs.defog }}" == "true" ] || [ "${{ needs.changes.outputs.tests }}" == "true" ]; then 46 | pytest tests 47 | fi 48 | env: 49 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 50 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 51 | TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }} 52 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Jupyter Notebook 55 | .ipynb_checkpoints 56 | 57 | # IPython 58 | profile_default/ 59 | ipython_config.py 60 | 61 | # pyenv 62 | .python-version 63 | 64 | *.DS_Store 65 | local_test.py 66 | 67 | defog_metadata.csv 68 | golden_queries.csv 69 | golden_queries.json 70 | glossary.txt 71 | 72 | # Ignore virtual environment directories 73 | .virtual/ 74 | myenv/ 75 | venv/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Defog.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Defog Python 2 | 3 | [](https://github.com/defog-ai/defog-python/actions/workflows/main.yml) 4 | 5 | # TLDR 6 | 7 | This library used to be an SDK for accessing Defog's cloud hosted text-to-SQL service. It has since transformed into a general purpose library for: 8 | 9 | 1. Making convenient, cross-provider LLM calls (including server-hosted tools) 10 | 2. Easily extracting information from databases to make them easy to use 11 | 12 | If you are looking for text-to-SQL or deep-research like capabilities, check out [introspect](https://github.com/defog-ai/introspect), our open-source, MIT licensed repo that uses this library as a dependency. 13 | 14 | # Using this library 15 | 16 | ## LLM Utilities (`defog.llm`) 17 | 18 | The `defog.llm` module provides cross-provider LLM functionality with support for function calling, structured output, and specialized tools. 19 | 20 | **Note:** As of the latest version, all LLM functions are async-only. Synchronous methods have been removed to improve performance and consistency. 
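Because every function is a coroutine, a plain synchronous script needs an event loop to drive it. A minimal sketch (assuming the package is installed and the relevant API key, e.g. `OPENAI_API_KEY`, is set in the environment):

```python
import asyncio

from defog.llm.utils import chat_async


async def main():
    # Any supported provider/model pair works here; "openai"/"gpt-4o" is just an example
    response = await chat_async(
        provider="openai",
        model="gpt-4o",
        messages=[{"role": "user", "content": "Say hello!"}],
    )
    print(response.content)


if __name__ == "__main__":
    # asyncio.run creates the event loop, runs the coroutine, and cleans up
    asyncio.run(main())
```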
21 | 22 | ### Core Chat Functions 23 | 24 | ```python 25 | from defog.llm.utils import chat_async, chat_async_legacy, LLMResponse 26 | from defog.llm.llm_providers import LLMProvider 27 | 28 | # Unified async interface with explicit provider specification 29 | response: LLMResponse = await chat_async( 30 | provider=LLMProvider.OPENAI, # or "openai", LLMProvider.ANTHROPIC, etc. 31 | model="gpt-4o", 32 | messages=[{"role": "user", "content": "Hello!"}], 33 | max_completion_tokens=1000, 34 | temperature=0.0 35 | ) 36 | 37 | print(response.content) # Response text 38 | print(f"Cost: ${response.cost_in_cents/100:.4f}") 39 | 40 | # Alternative: Legacy model-to-provider inference 41 | 42 | response = await chat_async_legacy( 43 | model="gpt-4o", 44 | messages=[{"role": "user", "content": "Hello!"}] 45 | ) 46 | ``` 47 | 48 | ### Provider-Specific Examples 49 | 50 | ```python 51 | from defog.llm.utils import chat_async 52 | from defog.llm.llm_providers import LLMProvider 53 | 54 | # OpenAI with function calling 55 | response = await chat_async( 56 | provider=LLMProvider.OPENAI, 57 | model="gpt-4o", 58 | messages=[{"role": "user", "content": "What's the weather in Paris?"}], 59 | tools=[my_function], # Optional function calling 60 | tool_choice="auto" 61 | ) 62 | 63 | # Anthropic with structured output 64 | response = await chat_async( 65 | provider=LLMProvider.ANTHROPIC, 66 | model="claude-3-5-sonnet", 67 | messages=[{"role": "user", "content": "Hello!"}], 68 | response_format=MyPydanticModel # Structured output 69 | ) 70 | 71 | # Gemini 72 | response = await chat_async( 73 | provider=LLMProvider.GEMINI, 74 | model="gemini-2.0-flash", 75 | messages=[{"role": "user", "content": "Hello!"}] 76 | ) 77 | ``` 78 | 79 | ### Code Interpreter Tool 80 | 81 | Execute Python code in sandboxed environments across providers: 82 | 83 | ```python 84 | from defog.llm.code_interp import code_interpreter_tool 85 | from defog.llm.llm_providers import LLMProvider 86 | 87 | result = await code_interpreter_tool( 88 | question="Analyze this CSV data and create a visualization", 89 | model="gpt-4o", 90 | provider=LLMProvider.OPENAI, 91 | csv_string="name,age\nAlice,25\nBob,30", 92 | instructions="You are a data analyst. Create clear visualizations." 
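    # Note: provider/model can be swapped for LLMProvider.ANTHROPIC or LLMProvider.GEMINI;
    # code_interpreter_tool supports all three backends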
93 | ) 94 | 95 | print(result["code"]) # Generated Python code 96 | print(result["output"]) # Execution results 97 | ``` 98 | 99 | ### Web Search Tool 100 | 101 | Search the web for current information: 102 | 103 | ```python 104 | from defog.llm.web_search import web_search_tool 105 | from defog.llm.llm_providers import LLMProvider 106 | 107 | result = await web_search_tool( 108 | question="What are the latest developments in AI?", 109 | model="claude-3-5-sonnet", 110 | provider=LLMProvider.ANTHROPIC, 111 | max_tokens=2048 112 | ) 113 | 114 | print(result["search_results"]) # Search results text 115 | print(result["websites_cited"]) # Source citations 116 | ``` 117 | 118 | ### Function Calling 119 | 120 | Define tools for LLMs to call: 121 | 122 | ```python 123 | from pydantic import BaseModel 124 | from defog.llm.utils import chat_async 125 | 126 | class WeatherInput(BaseModel): 127 | location: str 128 | units: str = "celsius" 129 | 130 | def get_weather(input: WeatherInput) -> str: 131 | """Get current weather for a location""" 132 | return f"Weather in {input.location}: 22°{input.units[0].upper()}, sunny" 133 | 134 | response = await chat_async( 135 | model="gpt-4o", 136 | messages=[{"role": "user", "content": "What's the weather in Paris?"}], 137 | tools=[get_weather], 138 | tool_choice="auto" 139 | ) 140 | ``` 141 | 142 | ### Memory Compactification 143 | 144 | Automatically manage long conversations by intelligently summarizing older messages while preserving context: 145 | 146 | ```python 147 | from defog.llm import chat_async_with_memory, create_memory_manager, MemoryConfig 148 | 149 | # Create a memory manager with custom settings 150 | memory_manager = create_memory_manager( 151 | token_threshold=50000, # Compactify when reaching 50k tokens 152 | preserve_last_n_messages=10, # Keep last 10 messages intact 153 | summary_max_tokens=2000, # Max tokens for summary 154 | enabled=True 155 | ) 156 | 157 | # System messages are automatically preserved across compactifications 158 | response1 = await chat_async_with_memory( 159 | provider="openai", 160 | model="gpt-4o", 161 | messages=[ 162 | {"role": "system", "content": "You are a helpful Python tutor."}, 163 | {"role": "user", "content": "Tell me about Python"} 164 | ], 165 | memory_manager=memory_manager 166 | ) 167 | 168 | # Continue the conversation - memory is automatically managed 169 | response2 = await chat_async_with_memory( 170 | provider="openai", 171 | model="gpt-4o", 172 | messages=[{"role": "user", "content": "What about its use in data science?"}], 173 | memory_manager=memory_manager 174 | ) 175 | 176 | # The system message is preserved even after compactification! 
177 | # Check current conversation state: 178 | print(f"Total messages: {len(memory_manager.get_current_messages())}") 179 | print(f"Compactifications: {memory_manager.get_stats()['compactification_count']}") 180 | 181 | # Or use memory configuration without explicit manager 182 | response = await chat_async_with_memory( 183 | provider="anthropic", 184 | model="claude-3-5-sonnet", 185 | messages=[{"role": "user", "content": "Hello!"}], 186 | memory_config=MemoryConfig( 187 | enabled=True, 188 | token_threshold=100000, # 100k tokens before compactification 189 | preserve_last_n_messages=10, 190 | summary_max_tokens=4000 191 | ) 192 | ) 193 | ``` 194 | 195 | Key features: 196 | - **System message preservation**: System messages are always kept intact, never summarized 197 | - **Automatic summarization**: When token count exceeds threshold, older messages are intelligently summarized 198 | - **Context preservation**: Recent messages are kept intact for continuity 199 | - **Provider agnostic**: Works with all supported LLM providers 200 | - **Token counting**: Uses tiktoken for accurate OpenAI token counts, with intelligent fallbacks for other providers 201 | - **Flexible configuration**: Customize thresholds, preservation rules, and summary sizes 202 | 203 | How it works: 204 | 1. As conversation grows, token count is tracked 205 | 2. When threshold is exceeded, older messages (except system messages) are summarized 206 | 3. Summary is added as a user message with `[Previous conversation summary]` prefix 207 | 4. Recent messages + system messages are preserved for context 208 | 5. Process repeats as needed for very long conversations 209 | 210 | ### MCP (Model Context Protocol) Support 211 | 212 | Connect to MCP servers for extended tool capabilities: 213 | 214 | ```python 215 | from defog.llm.utils_mcp import initialize_mcp_client 216 | 217 | # Initialize with config file 218 | mcp_client = await initialize_mcp_client( 219 | config="path/to/mcp_config.json", 220 | model="claude-3-5-sonnet" 221 | ) 222 | 223 | # Process queries with MCP tools 224 | response, tool_outputs = await mcp_client.mcp_chat( 225 | "Use the calculator tool to compute 123 * 456" 226 | ) 227 | ``` 228 | 229 | # Testing 230 | For developers who want to test or add tests for this client, you can run: 231 | ``` 232 | pytest tests 233 | ``` 234 | Note that we will transfer the existing .defog/connection.json file over to /tmp (if at all), and transfer the original file back once the tests are done to avoid messing with the original config. 235 | If submitting a PR, please use the `black` linter to lint your code. You can add it as a git hook to your repo by running the command below: 236 | ```bash 237 | echo -e '#!/bin/sh\n#\n# Run linter before commit\nblack $(git rev-parse --show-toplevel)' > .git/hooks/pre-commit 238 | chmod +x .git/hooks/pre-commit 239 | ``` 240 | -------------------------------------------------------------------------------- /defog/async_health_methods.py: -------------------------------------------------------------------------------- 1 | from defog.util import make_async_post_request 2 | 3 | 4 | async def check_golden_queries_coverage(self, dev: bool = False): 5 | """ 6 | Check the number of tables and columns inside the metadata schema that are covered by the golden queries. 
7 | """ 8 | url = f"{self.base_url}/get_golden_queries_coverage" 9 | payload = {"api_key": self.api_key, "dev": dev} 10 | return await make_async_post_request(url, payload) 11 | 12 | 13 | async def check_md_valid(self, dev: bool = False): 14 | """ 15 | Check if the metadata schema is valid. 16 | """ 17 | url = f"{self.base_url}/check_md_valid" 18 | payload = {"api_key": self.api_key, "db_type": self.db_type, "dev": dev} 19 | return await make_async_post_request(url, payload) 20 | 21 | 22 | async def check_gold_queries_valid(self, dev: bool = False): 23 | """ 24 | Check if the golden queries are valid and can be executed on a given database without errors. 25 | """ 26 | url = f"{self.base_url}/check_gold_queries_valid" 27 | payload = {"api_key": self.api_key, "db_type": self.db_type, "dev": dev} 28 | return await make_async_post_request(url, payload) 29 | 30 | 31 | async def check_glossary_valid(self, dev: bool = False): 32 | """ 33 | Check if the glossary is valid by verifying if all schema, table, and column names referenced are present in the metadata. 34 | """ 35 | url = f"{self.base_url}/check_glossary_valid" 36 | payload = {"api_key": self.api_key, "dev": dev} 37 | return await make_async_post_request(url, payload) 38 | 39 | 40 | async def check_glossary_consistency(self, dev: bool = False): 41 | """ 42 | Check if all logic in the glossary is consistent and coherent. 43 | """ 44 | url = f"{self.base_url}/check_glossary_consistency" 45 | payload = {"api_key": self.api_key, "dev": dev} 46 | return await make_async_post_request(url, payload) 47 | -------------------------------------------------------------------------------- /defog/async_query_methods.py: -------------------------------------------------------------------------------- 1 | from defog.util import make_async_post_request 2 | from defog.query import async_execute_query 3 | from datetime import datetime 4 | 5 | 6 | async def get_query( 7 | self, 8 | question: str, 9 | hard_filters: str = "", 10 | previous_context: list = [], 11 | glossary: str = "", 12 | debug: bool = False, 13 | dev: bool = False, 14 | temp: bool = False, 15 | profile: bool = False, 16 | ignore_cache: bool = False, 17 | model: str = "", 18 | use_golden_queries: bool = True, 19 | subtable_pruning: bool = False, 20 | glossary_pruning: bool = False, 21 | prune_max_tokens: int = 2000, 22 | prune_bm25_num_columns: int = 10, 23 | prune_glossary_max_tokens: int = 1000, 24 | prune_glossary_num_cos_sim_units: int = 10, 25 | prune_glossary_bm25_units: int = 10, 26 | ): 27 | """ 28 | Asynchronously sends the query to the defog servers, and return the response. 29 | :param question: The question to be asked. 30 | :return: The response from the defog server. 
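        On success the dict contains query_generated, ran_successfully, error_message,
        query_db, previous_context and reason_for_query (plus time_taken when
        profile=True); on failure, ran_successfully is False and error_message is set.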
31 | """ 32 | try: 33 | data = { 34 | "question": question, 35 | "api_key": self.api_key, 36 | "previous_context": previous_context, 37 | "db_type": self.db_type if self.db_type != "databricks" else "postgres", 38 | "glossary": glossary, 39 | "hard_filters": hard_filters, 40 | "dev": dev, 41 | "temp": temp, 42 | "ignore_cache": ignore_cache, 43 | "model": model, 44 | "use_golden_queries": use_golden_queries, 45 | "subtable_pruning": subtable_pruning, 46 | "glossary_pruning": glossary_pruning, 47 | "prune_max_tokens": prune_max_tokens, 48 | "prune_bm25_num_columns": prune_bm25_num_columns, 49 | "prune_glossary_max_tokens": prune_glossary_max_tokens, 50 | "prune_glossary_num_cos_sim_units": prune_glossary_num_cos_sim_units, 51 | "prune_glossary_bm25_units": prune_glossary_bm25_units, 52 | } 53 | 54 | t_start = datetime.now() 55 | 56 | resp = await make_async_post_request( 57 | url=self.generate_query_url, payload=data, timeout=300 58 | ) 59 | 60 | t_end = datetime.now() 61 | time_taken = (t_end - t_start).total_seconds() 62 | query_generated = resp.get("sql", resp.get("query_generated")) 63 | ran_successfully = resp.get("ran_successfully") 64 | error_message = resp.get("error_message") 65 | query_db = self.db_type 66 | resp = { 67 | "query_generated": query_generated, 68 | "ran_successfully": ran_successfully, 69 | "error_message": error_message, 70 | "query_db": query_db, 71 | "previous_context": resp.get("previous_context"), 72 | "reason_for_query": resp.get("reason_for_query"), 73 | } 74 | if profile: 75 | resp["time_taken"] = time_taken 76 | 77 | return resp 78 | except Exception as e: 79 | if debug: 80 | print(e) 81 | return { 82 | "ran_successfully": False, 83 | "error_message": "Sorry :( Our server is at capacity right now and we are unable to process your query. Please try again in a few minutes?", 84 | } 85 | 86 | 87 | async def run_query( 88 | self, 89 | question: str, 90 | hard_filters: str = "", 91 | previous_context: list = [], 92 | glossary: str = "", 93 | query: dict = None, 94 | retries: int = 3, 95 | dev: bool = False, 96 | temp: bool = False, 97 | profile: bool = False, 98 | ignore_cache: bool = False, 99 | model: str = "", 100 | use_golden_queries: bool = True, 101 | subtable_pruning: bool = False, 102 | glossary_pruning: bool = False, 103 | prune_max_tokens: int = 2000, 104 | prune_bm25_num_columns: int = 10, 105 | prune_glossary_max_tokens: int = 1000, 106 | prune_glossary_num_cos_sim_units: int = 10, 107 | prune_glossary_bm25_units: int = 10, 108 | ): 109 | """ 110 | Asynchronously sends the question to the defog servers, executes the generated SQL, 111 | and returns the response. 112 | :param question: The question to be asked. 113 | :return: The response from the defog server. 
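        On success the dict contains columns, data, query_generated, ran_successfully,
        reason_for_query and previous_context (plus execution_time_taken and
        generation_time_taken when profile=True); on failure, ran_successfully is False
        and error_message is set.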
114 | """ 115 | if query is None: 116 | print(f"Generating the query for your question: {question}...") 117 | query = await self.get_query( 118 | question, 119 | hard_filters, 120 | previous_context, 121 | glossary=glossary, 122 | dev=dev, 123 | temp=temp, 124 | profile=profile, 125 | model=model, 126 | ignore_cache=ignore_cache, 127 | use_golden_queries=use_golden_queries, 128 | subtable_pruning=subtable_pruning, 129 | glossary_pruning=glossary_pruning, 130 | prune_max_tokens=prune_max_tokens, 131 | prune_bm25_num_columns=prune_bm25_num_columns, 132 | prune_glossary_max_tokens=prune_glossary_max_tokens, 133 | prune_glossary_num_cos_sim_units=prune_glossary_num_cos_sim_units, 134 | prune_glossary_bm25_units=prune_glossary_bm25_units, 135 | ) 136 | if query["ran_successfully"]: 137 | try: 138 | print("Query generated, now running it on your database...") 139 | tstart = datetime.now() 140 | colnames, result, executed_query = await async_execute_query( 141 | query=query["query_generated"], 142 | api_key=self.api_key, 143 | db_type=self.db_type, 144 | db_creds=self.db_creds, 145 | question=question, 146 | hard_filters=hard_filters, 147 | retries=retries, 148 | dev=dev, 149 | temp=temp, 150 | ) 151 | tend = datetime.now() 152 | time_taken = (tend - tstart).total_seconds() 153 | resp = { 154 | "columns": colnames, 155 | "data": result, 156 | "query_generated": executed_query, 157 | "ran_successfully": True, 158 | "reason_for_query": query.get("reason_for_query"), 159 | "previous_context": query.get("previous_context"), 160 | } 161 | if profile: 162 | resp["execution_time_taken"] = time_taken 163 | resp["generation_time_taken"] = query.get("time_taken") 164 | return resp 165 | except Exception as e: 166 | return { 167 | "ran_successfully": False, 168 | "error_message": str(e), 169 | "query_generated": query["query_generated"], 170 | } 171 | else: 172 | return {"ran_successfully": False, "error_message": query["error_message"]} 173 | -------------------------------------------------------------------------------- /defog/health_methods.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | def check_golden_queries_coverage(self, dev: bool = False): 5 | """ 6 | Check the number of tables and columns inside the metadata schema that are covered by the golden queries. 7 | """ 8 | try: 9 | r = requests.post( 10 | f"{self.base_url}/get_golden_queries_coverage", 11 | json={"api_key": self.api_key, "dev": dev}, 12 | verify=False, 13 | ) 14 | resp = r.json() 15 | return resp 16 | except Exception as e: 17 | return {"error": str(e)} 18 | 19 | 20 | def check_md_valid(self, dev: bool = False): 21 | """ 22 | Check if the metadata schema is valid. 
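    Returns the parsed JSON response on success, or a dict with an "error" key if the request fails.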
23 | """ 24 | try: 25 | r = requests.post( 26 | f"{self.base_url}/check_md_valid", 27 | json={"api_key": self.api_key, "db_type": self.db_type, "dev": dev}, 28 | verify=False, 29 | ) 30 | resp = r.json() 31 | return resp 32 | except Exception as e: 33 | return {"error": str(e)} 34 | 35 | 36 | def check_gold_queries_valid(self, dev: bool = False): 37 | """ 38 | Check if the golden queries are valid, and can actually be executed on a given database without errors 39 | """ 40 | r = requests.post( 41 | f"{self.base_url}/check_gold_queries_valid", 42 | json={"api_key": self.api_key, "db_type": self.db_type, "dev": dev}, 43 | verify=False, 44 | ) 45 | resp = r.json() 46 | return resp 47 | 48 | 49 | def check_glossary_valid(self, dev: bool = False): 50 | """ 51 | Check if glossary is valid by verifying if all schema, table and column names referenced are present in the metadata. 52 | """ 53 | r = requests.post( 54 | f"{self.base_url}/check_glossary_valid", 55 | json={"api_key": self.api_key, "dev": dev}, 56 | verify=False, 57 | ) 58 | resp = r.json() 59 | return resp 60 | 61 | 62 | def check_glossary_consistency(self, dev: bool = False): 63 | """ 64 | Check if all logic in the glossary is consistent and coherent. 65 | """ 66 | r = requests.post( 67 | f"{self.base_url}/check_glossary_consistency", 68 | json={"api_key": self.api_key, "dev": dev}, 69 | verify=False, 70 | ) 71 | resp = r.json() 72 | return resp 73 | -------------------------------------------------------------------------------- /defog/llm/__init__.py: -------------------------------------------------------------------------------- 1 | """LLM module with memory management capabilities.""" 2 | 3 | from .utils import chat_async, LLMResponse 4 | from .utils_memory import ( 5 | chat_async_with_memory, 6 | create_memory_manager, 7 | MemoryConfig, 8 | ) 9 | from .memory import ( 10 | MemoryManager, 11 | ConversationHistory, 12 | compactify_messages, 13 | TokenCounter, 14 | ) 15 | 16 | __all__ = [ 17 | # Core functions 18 | "chat_async", 19 | "chat_async_with_memory", 20 | "LLMResponse", 21 | # Memory management 22 | "MemoryManager", 23 | "ConversationHistory", 24 | "MemoryConfig", 25 | "create_memory_manager", 26 | "compactify_messages", 27 | "TokenCounter", 28 | ] 29 | -------------------------------------------------------------------------------- /defog/llm/citations.py: -------------------------------------------------------------------------------- 1 | from defog.llm.llm_providers import LLMProvider 2 | from defog.llm.utils_logging import ToolProgressTracker, SubTaskLogger 3 | import os 4 | import asyncio 5 | 6 | 7 | async def upload_document_to_openai_vector_store(document, store_id): 8 | from openai import AsyncOpenAI 9 | 10 | client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 11 | 12 | file_name = document["document_name"] 13 | if not file_name.endswith(".txt"): 14 | file_name = file_name + ".txt" 15 | file_content = document["document_content"] 16 | if isinstance(file_content, str): 17 | # convert to bytes 18 | file_content = file_content.encode("utf-8") 19 | 20 | # first, upload the file to the vector store 21 | file = await client.files.create( 22 | file=(file_name, file_content), purpose="assistants" 23 | ) 24 | 25 | # then add it to the vector store 26 | await client.vector_stores.files.create( 27 | vector_store_id=store_id, 28 | file_id=file.id, 29 | ) 30 | 31 | 32 | async def citations_tool( 33 | question: str, 34 | instructions: str, 35 | documents: list[dict], 36 | model: str, 37 | provider: LLMProvider, 38 | 
max_tokens: int = 16000, 39 | ): 40 | """ 41 | Use this tool to get an answer to a well-cited answer to a question, 42 | given a list of documents. 43 | """ 44 | async with ToolProgressTracker( 45 | "Citations Tool", f"Generating citations for {len(documents)} documents" 46 | ) as tracker: 47 | subtask_logger = SubTaskLogger() 48 | subtask_logger.log_provider_info( 49 | provider.value if hasattr(provider, "value") else str(provider), model 50 | ) 51 | 52 | if provider in [LLMProvider.OPENAI, LLMProvider.OPENAI.value]: 53 | from openai import AsyncOpenAI 54 | 55 | client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 56 | 57 | # create an ephemeral vector store 58 | store = await client.vector_stores.create() 59 | store_id = store.id 60 | 61 | # Upload all documents in parallel 62 | tracker.update(10, "Uploading documents to vector store") 63 | subtask_logger.log_subtask("Starting document uploads", "processing") 64 | 65 | upload_tasks = [] 66 | for idx, document in enumerate(documents, 1): 67 | subtask_logger.log_document_upload( 68 | document["document_name"], idx, len(documents) 69 | ) 70 | upload_tasks.append( 71 | upload_document_to_openai_vector_store(document, store_id) 72 | ) 73 | 74 | await asyncio.gather(*upload_tasks) 75 | tracker.update(40, "Documents uploaded") 76 | 77 | # keep polling until the vector store is ready 78 | is_ready = False 79 | while not is_ready: 80 | store = await client.vector_stores.files.list(vector_store_id=store_id) 81 | total_completed = sum( 82 | 1 for file in store.data if file.status == "completed" 83 | ) 84 | is_ready = total_completed == len(documents) 85 | 86 | # Update progress based on indexing status 87 | progress = 40 + (total_completed / len(documents) * 40) # 40-80% range 88 | tracker.update( 89 | progress, f"Indexing {total_completed}/{len(documents)} files" 90 | ) 91 | subtask_logger.log_vector_store_status(total_completed, len(documents)) 92 | 93 | if not is_ready: 94 | await asyncio.sleep(1) 95 | 96 | # get the answer 97 | tracker.update(80, "Generating citations") 98 | subtask_logger.log_subtask("Querying with file search", "processing") 99 | 100 | response = await client.responses.create( 101 | model=model, 102 | input=question, 103 | tools=[ 104 | { 105 | "type": "file_search", 106 | "vector_store_ids": [store_id], 107 | } 108 | ], 109 | tool_choice="required", 110 | instructions=instructions, 111 | max_output_tokens=max_tokens, 112 | ) 113 | 114 | # convert the response to a list of blocks 115 | # similar to a subset of the Anthropic citations API 116 | blocks = [] 117 | for part in response.output: 118 | if part.type == "message": 119 | contents = part.content 120 | for item in contents: 121 | if item.type == "output_text": 122 | blocks.append( 123 | { 124 | "text": item.text, 125 | "type": "text", 126 | "citations": [ 127 | {"document_title": i.filename} 128 | for i in item.annotations 129 | ], 130 | } 131 | ) 132 | tracker.update(95, "Processing results") 133 | subtask_logger.log_result_summary( 134 | "Citations", 135 | {"blocks_generated": len(blocks), "documents_processed": len(documents)}, 136 | ) 137 | 138 | return blocks 139 | 140 | elif provider in [LLMProvider.ANTHROPIC, LLMProvider.ANTHROPIC.value]: 141 | from anthropic import AsyncAnthropic 142 | 143 | client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) 144 | 145 | document_contents = [] 146 | for document in documents: 147 | document_contents.append( 148 | { 149 | "type": "document", 150 | "source": { 151 | "type": "text", 152 | "media_type": "text/plain", 
153 | "data": document["document_content"], 154 | }, 155 | "title": document["document_name"], 156 | "citations": {"enabled": True}, 157 | } 158 | ) 159 | 160 | # Create content messages with citations enabled for individual tool calls 161 | tracker.update(50, "Preparing document contents") 162 | subtask_logger.log_subtask( 163 | f"Processing {len(documents)} documents for Anthropic", "processing" 164 | ) 165 | 166 | messages = [ 167 | { 168 | "role": "user", 169 | "content": [ 170 | {"type": "text", "text": question}, 171 | # Add all individual document contents 172 | *document_contents, 173 | ], 174 | } 175 | ] 176 | 177 | tracker.update(70, "Generating citations") 178 | subtask_logger.log_subtask("Calling Anthropic API with citations", "processing") 179 | 180 | response = await client.messages.create( 181 | model=model, 182 | messages=messages, 183 | system=instructions, 184 | max_tokens=max_tokens, 185 | ) 186 | 187 | tracker.update(90, "Processing results") 188 | response_with_citations = [item.to_dict() for item in response.content] 189 | 190 | subtask_logger.log_result_summary( 191 | "Citations", 192 | { 193 | "content_blocks": len(response_with_citations), 194 | "documents_processed": len(documents), 195 | }, 196 | ) 197 | 198 | return response_with_citations 199 | 200 | # This else is outside the context manager, move it inside 201 | raise ValueError(f"Provider {provider} not supported for citations tool") 202 | -------------------------------------------------------------------------------- /defog/llm/code_interp.py: -------------------------------------------------------------------------------- 1 | from defog.llm.llm_providers import LLMProvider 2 | from defog.llm.utils_logging import ToolProgressTracker, SubTaskLogger 3 | import os 4 | from io import BytesIO 5 | 6 | 7 | async def code_interpreter_tool( 8 | question: str, 9 | model: str, 10 | provider: LLMProvider, 11 | csv_string: str = "", 12 | instructions: str = "You are a Python programmer. You are given a question and a CSV string of data. You need to answer the question using the data. You are also given a sandboxed server environment where you can run the code.", 13 | ): 14 | """ 15 | Creates a python script to answer the question, where the python script is executed in a sandboxed server environment. 16 | """ 17 | async with ToolProgressTracker( 18 | "Code Interpreter", 19 | f"Executing code to answer: {question[:50]}{'...' 
if len(question) > 50 else ''}", 20 | ) as tracker: 21 | subtask_logger = SubTaskLogger() 22 | subtask_logger.log_provider_info( 23 | provider.value if hasattr(provider, "value") else str(provider), model 24 | ) 25 | 26 | # create a csv file from the csv_string 27 | tracker.update(10, "Preparing data file") 28 | subtask_logger.log_subtask("Creating CSV file from string", "processing") 29 | csv_file = BytesIO(csv_string.encode("utf-8")) 30 | csv_file.name = "data.csv" 31 | 32 | if provider in [LLMProvider.OPENAI, LLMProvider.OPENAI.value]: 33 | from openai import AsyncOpenAI 34 | from openai.types.responses import ( 35 | ResponseCodeInterpreterToolCall, 36 | ResponseOutputMessage, 37 | ) 38 | 39 | client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 40 | 41 | tracker.update(20, "Uploading data file") 42 | subtask_logger.log_subtask("Uploading CSV to OpenAI", "processing") 43 | file = await client.files.create( 44 | file=csv_file, 45 | purpose="user_data", 46 | ) 47 | 48 | tracker.update(40, "Running code interpreter") 49 | subtask_logger.log_code_execution("python") 50 | response = await client.responses.create( 51 | model=model, 52 | tools=[ 53 | { 54 | "type": "code_interpreter", 55 | "container": {"type": "auto", "file_ids": [file.id]}, 56 | } 57 | ], 58 | tool_choice="required", 59 | input=[ 60 | { 61 | "role": "user", 62 | "content": [{"type": "input_text", "text": question}], 63 | } 64 | ], 65 | instructions=instructions, 66 | ) 67 | tracker.update(80, "Processing results") 68 | subtask_logger.log_subtask("Extracting code and output", "processing") 69 | 70 | code = "" 71 | output_text = "" 72 | 73 | for chunk in response.output: 74 | if isinstance(chunk, ResponseCodeInterpreterToolCall): 75 | code += chunk.code 76 | elif isinstance(chunk, ResponseOutputMessage): 77 | for content in chunk.content: 78 | output_text += content.text 79 | 80 | subtask_logger.log_result_summary( 81 | "Code Execution", 82 | {"code_length": len(code), "output_length": len(output_text)}, 83 | ) 84 | 85 | return {"code": code, "output": output_text} 86 | elif provider in [LLMProvider.ANTHROPIC, LLMProvider.ANTHROPIC.value]: 87 | from anthropic import AsyncAnthropic 88 | 89 | client = AsyncAnthropic( 90 | api_key=os.getenv("ANTHROPIC_API_KEY"), 91 | default_headers={"anthropic-beta": "code-execution-2025-05-22"}, 92 | ) 93 | 94 | tracker.update(20, "Uploading data file") 95 | subtask_logger.log_subtask("Uploading CSV to Anthropic", "processing") 96 | file_object = await client.beta.files.upload( 97 | file=csv_file, 98 | ) 99 | 100 | tracker.update(40, "Running code execution") 101 | subtask_logger.log_code_execution("python") 102 | response = await client.messages.create( 103 | model=model, 104 | max_tokens=8192, 105 | messages=[ 106 | { 107 | "role": "user", 108 | "content": [ 109 | { 110 | "type": "text", 111 | "text": instructions 112 | + "\n\nThe question you must answer is: " 113 | + question, 114 | }, 115 | {"type": "container_upload", "file_id": file_object.id}, 116 | ], 117 | } 118 | ], 119 | tools=[{"type": "code_execution_20250522", "name": "code_execution"}], 120 | tool_choice={"type": "any"}, 121 | ) 122 | tracker.update(80, "Processing results") 123 | subtask_logger.log_subtask("Extracting code and output", "processing") 124 | 125 | code = "" 126 | output_text = "" 127 | for chunk in response.content: 128 | if chunk.type == "server_tool_use": 129 | code += chunk.input["code"] 130 | elif chunk.type == "text": 131 | output_text += chunk.text 132 | 133 | subtask_logger.log_result_summary( 134 | 
"Code Execution", 135 | {"code_length": len(code), "output_length": len(output_text)}, 136 | ) 137 | 138 | return {"code": code, "output": output_text} 139 | elif provider in [LLMProvider.GEMINI, LLMProvider.GEMINI.value]: 140 | from google import genai 141 | from google.genai import types 142 | 143 | client = genai.Client(api_key=os.getenv("GEMINI_API_KEY")) 144 | 145 | tracker.update(20, "Uploading data file") 146 | subtask_logger.log_subtask("Uploading CSV to Gemini", "processing") 147 | file_csv = await client.aio.files.upload( 148 | file=csv_file, 149 | config=types.UploadFileConfig( 150 | mime_type="text/csv", 151 | display_name="data.csv", 152 | ), 153 | ) 154 | 155 | tracker.update(40, "Running code execution") 156 | subtask_logger.log_code_execution("python") 157 | response = await client.aio.models.generate_content( 158 | model=model, 159 | contents=[file_csv, question], 160 | config=types.GenerateContentConfig( 161 | tools=[types.Tool(code_execution=types.ToolCodeExecution())], 162 | ), 163 | ) 164 | 165 | tracker.update(80, "Processing results") 166 | subtask_logger.log_subtask("Extracting code and output", "processing") 167 | 168 | parts = response.candidates[0].content.parts 169 | 170 | code = "" 171 | output_text = "" 172 | 173 | for part in parts: 174 | if ( 175 | hasattr(part, "executable_code") 176 | and part.executable_code is not None 177 | ): 178 | code += part.executable_code.code 179 | if hasattr(part, "text") and part.text is not None: 180 | output_text += part.text 181 | 182 | subtask_logger.log_result_summary( 183 | "Code Execution", 184 | {"code_length": len(code), "output_length": len(output_text)}, 185 | ) 186 | 187 | return {"code": code, "output": output_text} 188 | -------------------------------------------------------------------------------- /defog/llm/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .settings import LLMConfig 2 | from .constants import DEFAULT_TIMEOUT, MAX_RETRIES, DEFAULT_TEMPERATURE 3 | 4 | __all__ = ["LLMConfig", "DEFAULT_TIMEOUT", "MAX_RETRIES", "DEFAULT_TEMPERATURE"] 5 | -------------------------------------------------------------------------------- /defog/llm/config/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_TIMEOUT = 100 # seconds 2 | MAX_RETRIES = 3 3 | DEFAULT_TEMPERATURE = 0.0 4 | 5 | # Provider-specific constants 6 | DEEPSEEK_BASE_URL = "https://api.deepseek.com" 7 | OPENAI_BASE_URL = "https://api.openai.com/v1/" 8 | 9 | # Model families that require special handling 10 | O_MODELS = ["o1-mini", "o1-preview", "o1", "o3-mini", "o3", "o4-mini"] 11 | 12 | DEEPSEEK_MODELS = ["deepseek-chat", "deepseek-reasoner"] 13 | 14 | MODELS_WITHOUT_TEMPERATURE = ["deepseek-reasoner"] + [model for model in O_MODELS] 15 | 16 | MODELS_WITHOUT_RESPONSE_FORMAT = [ 17 | "o1-mini", 18 | "o1-preview", 19 | "deepseek-chat", 20 | "deepseek-reasoner", 21 | ] 22 | 23 | MODELS_WITHOUT_TOOLS = ["o1-mini", "o1-preview", "deepseek-chat", "deepseek-reasoner"] 24 | 25 | MODELS_WITH_PARALLEL_TOOL_CALLS = ["gpt-4o", "gpt-4o-mini"] 26 | 27 | MODELS_WITH_PREDICTION = ["gpt-4o", "gpt-4o-mini"] 28 | -------------------------------------------------------------------------------- /defog/llm/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Dict, Any 3 | from .constants import ( 4 | DEFAULT_TIMEOUT, 5 | MAX_RETRIES, 6 | DEFAULT_TEMPERATURE, 7 | 
DEEPSEEK_BASE_URL, 8 | OPENAI_BASE_URL, 9 | ) 10 | 11 | 12 | class LLMConfig: 13 | """Configuration management for LLM providers.""" 14 | 15 | def __init__( 16 | self, 17 | timeout: int = DEFAULT_TIMEOUT, 18 | max_retries: int = MAX_RETRIES, 19 | default_temperature: float = DEFAULT_TEMPERATURE, 20 | api_keys: Optional[Dict[str, str]] = None, 21 | base_urls: Optional[Dict[str, str]] = None, 22 | enable_parallel_tool_calls: bool = True, 23 | ): 24 | self.timeout = timeout 25 | self.max_retries = max_retries 26 | self.default_temperature = default_temperature 27 | self.enable_parallel_tool_calls = enable_parallel_tool_calls 28 | 29 | # API keys with environment fallbacks 30 | self.api_keys = api_keys or {} 31 | self._setup_api_keys() 32 | 33 | # Base URLs with defaults 34 | self.base_urls = base_urls or {} 35 | self._setup_base_urls() 36 | 37 | def _setup_api_keys(self): 38 | """Setup API keys with environment variable fallbacks.""" 39 | key_mappings = { 40 | "openai": "OPENAI_API_KEY", 41 | "anthropic": "ANTHROPIC_API_KEY", 42 | "gemini": "GEMINI_API_KEY", 43 | "deepseek": "DEEPSEEK_API_KEY", 44 | "together": "TOGETHER_API_KEY", 45 | } 46 | 47 | for provider, env_var in key_mappings.items(): 48 | if provider not in self.api_keys: 49 | self.api_keys[provider] = os.getenv(env_var) 50 | 51 | def _setup_base_urls(self): 52 | """Setup base URLs with defaults.""" 53 | default_urls = { 54 | "openai": OPENAI_BASE_URL, 55 | "deepseek": DEEPSEEK_BASE_URL, 56 | } 57 | 58 | for provider, url in default_urls.items(): 59 | if provider not in self.base_urls: 60 | self.base_urls[provider] = url 61 | 62 | def get_api_key(self, provider: str) -> Optional[str]: 63 | """Get API key for a provider.""" 64 | return self.api_keys.get(provider) 65 | 66 | def get_base_url(self, provider: str) -> Optional[str]: 67 | """Get base URL for a provider.""" 68 | return self.base_urls.get(provider) 69 | 70 | def validate_provider_config(self, provider: str) -> bool: 71 | """Validate that a provider has the required configuration.""" 72 | api_key = self.get_api_key(provider) 73 | return api_key is not None and api_key != "" 74 | 75 | def update_config(self, **kwargs): 76 | """Update configuration values.""" 77 | for key, value in kwargs.items(): 78 | if hasattr(self, key): 79 | setattr(self, key, value) 80 | -------------------------------------------------------------------------------- /defog/llm/cost/__init__.py: -------------------------------------------------------------------------------- 1 | from .calculator import CostCalculator 2 | from .models import MODEL_COSTS 3 | 4 | __all__ = ["CostCalculator", "MODEL_COSTS"] 5 | -------------------------------------------------------------------------------- /defog/llm/cost/calculator.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from .models import MODEL_COSTS 3 | 4 | 5 | class CostCalculator: 6 | """Handles cost calculation for LLM usage.""" 7 | 8 | @staticmethod 9 | def calculate_cost( 10 | model: str, 11 | input_tokens: int, 12 | output_tokens: int, 13 | cached_input_tokens: Optional[int] = None, 14 | ) -> Optional[float]: 15 | """ 16 | Calculate cost in cents for the given token usage. 
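
        Args:
            model: Model name; matched exactly first, then by the longest partial
                match against the known model names.
            input_tokens: Number of input (prompt) tokens.
            output_tokens: Number of output (completion) tokens.
            cached_input_tokens: Optional count of cached input tokens; their cost is
                added using the model's cached rate when one is defined.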
17 | 18 | Returns: 19 | Cost in cents, or None if model pricing is not available 20 | """ 21 | # Find exact match first 22 | if model in MODEL_COSTS: 23 | model_name = model 24 | else: 25 | # Attempt partial matches if no exact match 26 | potential_model_names = [] 27 | for mname in MODEL_COSTS.keys(): 28 | if mname in model: 29 | potential_model_names.append(mname) 30 | 31 | if not potential_model_names: 32 | return None 33 | 34 | # Use the longest match 35 | model_name = max(potential_model_names, key=len) 36 | 37 | costs = MODEL_COSTS[model_name] 38 | 39 | # Calculate base cost 40 | cost_in_cents = ( 41 | input_tokens / 1000 * costs["input_cost_per1k"] 42 | + output_tokens / 1000 * costs["output_cost_per1k"] 43 | ) * 100 44 | 45 | # Add cached input cost if available 46 | if cached_input_tokens and "cached_input_cost_per1k" in costs: 47 | cost_in_cents += ( 48 | cached_input_tokens / 1000 * costs["cached_input_cost_per1k"] 49 | ) * 100 50 | 51 | return cost_in_cents 52 | 53 | @staticmethod 54 | def is_model_supported(model: str) -> bool: 55 | """Check if cost calculation is supported for the given model.""" 56 | if model in MODEL_COSTS: 57 | return True 58 | 59 | # Check for partial matches 60 | return any(mname in model for mname in MODEL_COSTS.keys()) 61 | -------------------------------------------------------------------------------- /defog/llm/cost/models.py: -------------------------------------------------------------------------------- 1 | MODEL_COSTS = { 2 | "chatgpt-4o": {"input_cost_per1k": 0.0025, "output_cost_per1k": 0.01}, 3 | "gpt-4o": { 4 | "input_cost_per1k": 0.0025, 5 | "cached_input_cost_per1k": 0.00125, 6 | "output_cost_per1k": 0.01, 7 | }, 8 | "gpt-4o-mini": { 9 | "input_cost_per1k": 0.00015, 10 | "cached_input_cost_per1k": 0.000075, 11 | "output_cost_per1k": 0.0006, 12 | }, 13 | "gpt-4.1": { 14 | "input_cost_per1k": 0.002, 15 | "cached_input_cost_per1k": 0.0005, 16 | "output_cost_per1k": 0.008, 17 | }, 18 | "gpt-4.1-mini": { 19 | "input_cost_per1k": 0.0004, 20 | "cached_input_cost_per1k": 0.0001, 21 | "output_cost_per1k": 0.0016, 22 | }, 23 | "gpt-4.1-nano": { 24 | "input_cost_per1k": 0.0001, 25 | "cached_input_cost_per1k": 0.000025, 26 | "output_cost_per1k": 0.0004, 27 | }, 28 | "o1": { 29 | "input_cost_per1k": 0.015, 30 | "cached_input_cost_per1k": 0.0075, 31 | "output_cost_per1k": 0.06, 32 | }, 33 | "o1-preview": {"input_cost_per1k": 0.015, "output_cost_per1k": 0.06}, 34 | "o1-mini": { 35 | "input_cost_per1k": 0.003, 36 | "cached_input_cost_per1k": 0.00055, 37 | "output_cost_per1k": 0.012, 38 | }, 39 | "o3-mini": { 40 | "input_cost_per1k": 0.0011, 41 | "cached_input_cost_per1k": 0.00055, 42 | "output_cost_per1k": 0.0044, 43 | }, 44 | "o3": { 45 | "input_cost_per1k": 0.01, 46 | "cached_input_cost_per1k": 0.0025, 47 | "output_cost_per1k": 0.04, 48 | }, 49 | "o4-mini": { 50 | "input_cost_per1k": 0.0011, 51 | "cached_input_cost_per1k": 0.000275, 52 | "output_cost_per1k": 0.0044, 53 | }, 54 | "gpt-4-turbo": {"input_cost_per1k": 0.01, "output_cost_per1k": 0.03}, 55 | "gpt-3.5-turbo": {"input_cost_per1k": 0.0005, "output_cost_per1k": 0.0015}, 56 | "claude-3-5-sonnet": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.015}, 57 | "claude-3-5-haiku": {"input_cost_per1k": 0.00025, "output_cost_per1k": 0.00125}, 58 | "claude-3-opus": {"input_cost_per1k": 0.015, "output_cost_per1k": 0.075}, 59 | "claude-3-sonnet": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.015}, 60 | "claude-3-haiku": {"input_cost_per1k": 0.00025, "output_cost_per1k": 0.00125}, 61 | 
"gemini-1.5-pro": {"input_cost_per1k": 0.00125, "output_cost_per1k": 0.005}, 62 | "gemini-1.5-flash": {"input_cost_per1k": 0.000075, "output_cost_per1k": 0.0003}, 63 | "gemini-1.5-flash-8b": { 64 | "input_cost_per1k": 0.0000375, 65 | "output_cost_per1k": 0.00015, 66 | }, 67 | "gemini-2.0-flash": { 68 | "input_cost_per1k": 0.00010, 69 | "output_cost_per1k": 0.0004, 70 | }, 71 | "gemini-2.5-flash": { 72 | "input_cost_per1k": 0.00015, 73 | "output_cost_per1k": 0.0035, 74 | }, 75 | "gemini-2.5-pro": { 76 | "input_cost_per1k": 0.00125, 77 | "output_cost_per1k": 0.01, 78 | }, 79 | "deepseek-chat": { 80 | "input_cost_per1k": 0.00027, 81 | "cached_input_cost_per1k": 0.00007, 82 | "output_cost_per1k": 0.0011, 83 | }, 84 | "deepseek-reasoner": { 85 | "input_cost_per1k": 0.00055, 86 | "cached_input_cost_per1k": 0.00014, 87 | "output_cost_per1k": 0.00219, 88 | }, 89 | } 90 | -------------------------------------------------------------------------------- /defog/llm/exceptions/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm_exceptions import ( 2 | LLMError, 3 | ProviderError, 4 | ToolError, 5 | MaxTokensError, 6 | ConfigurationError, 7 | AuthenticationError, 8 | APIError, 9 | ) 10 | 11 | __all__ = [ 12 | "LLMError", 13 | "ProviderError", 14 | "ToolError", 15 | "MaxTokensError", 16 | "ConfigurationError", 17 | "AuthenticationError", 18 | "APIError", 19 | ] 20 | -------------------------------------------------------------------------------- /defog/llm/exceptions/llm_exceptions.py: -------------------------------------------------------------------------------- 1 | class LLMError(Exception): 2 | """Base exception for all LLM-related errors.""" 3 | 4 | pass 5 | 6 | 7 | class ProviderError(LLMError): 8 | """Exception raised when there's an error with a specific LLM provider.""" 9 | 10 | def __init__(self, provider: str, message: str, original_error: Exception = None): 11 | self.provider = provider 12 | self.original_error = original_error 13 | super().__init__(f"Provider '{provider}': {message}") 14 | 15 | 16 | class ToolError(LLMError): 17 | """Exception raised when there's an error with tool calling.""" 18 | 19 | def __init__(self, tool_name: str, message: str, original_error: Exception = None): 20 | self.tool_name = tool_name 21 | self.original_error = original_error 22 | super().__init__(f"Tool '{tool_name}': {message}") 23 | 24 | 25 | class MaxTokensError(LLMError): 26 | """Exception raised when maximum tokens are reached.""" 27 | 28 | pass 29 | 30 | 31 | class ConfigurationError(LLMError): 32 | """Exception raised when there's a configuration error.""" 33 | 34 | pass 35 | 36 | 37 | class AuthenticationError(LLMError): 38 | """Exception raised when there's an authentication error with a provider.""" 39 | 40 | pass 41 | 42 | 43 | class APIError(LLMError): 44 | """Exception raised when there's a general API error.""" 45 | 46 | pass 47 | -------------------------------------------------------------------------------- /defog/llm/llm_providers.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class LLMProvider(Enum): 5 | OPENAI = "openai" 6 | ANTHROPIC = "anthropic" 7 | GEMINI = "gemini" 8 | GROK = "grok" 9 | DEEPSEEK = "deepseek" 10 | TOGETHER = "together" 11 | -------------------------------------------------------------------------------- /defog/llm/mcp_readme.md: -------------------------------------------------------------------------------- 1 | # MCP Servers Guide 2 | 3 
| This guide outlines the steps for setting up and using Model Context Protocol (MCP) servers, which allow external tools to be used by an LLM. 4 | 5 | ## 1. Set Up Your MCP Servers 6 | Set up your MCP servers by following the official [SDKs](https://github.com/modelcontextprotocol). 7 | If you're hosting them outside of your application, ensure that your servers operate with [SSE transport](https://modelcontextprotocol.io/docs/concepts/transports#server-sent-events-sse). 8 | 9 | ```python 10 | from mcp.server.fastmcp import FastMCP 11 | 12 | # Initialize FastMCP with an available port 13 | mcp = FastMCP( 14 | name="name_of_remote_server", 15 | port=3001, 16 | ) 17 | 18 | # Add your tools and prompts 19 | # @mcp.tool() 20 | # async def example_tool(): 21 | # ... 22 | 23 | # Start the server with SSE transport 24 | if __name__ == "__main__": 25 | mcp.run(transport="sse") 26 | ``` 27 | 28 | ## 2. Configure mcp_config.json 29 | 30 | After your servers are running, create `mcp_config.json` and add the servers to connect to them: 31 | 32 | ```json 33 | { 34 | "mcpServers": { 35 | "name_of_remote_server": { 36 | "command": "sse", 37 | "args": ["http://host.docker.internal:3001"] 38 | }, 39 | "name_of_local_server": { 40 | "command": "npx", 41 | "args": ["-y", "@modelcontextprotocol/package-name"] 42 | } 43 | } 44 | } 45 | ``` 46 | 47 | ### Configuration Options: 48 | 49 | 1. **Remote Servers (SSE):** 50 | - Use `"command": "sse"` 51 | - Use `"args": ["url_to_server"]` for remote servers 52 | - Use `"args": ["http://host.docker.internal:PORT/sse"]` if your application is running in a Docker container and the servers are running on your local machine 53 | - Make sure the port matches your server's configuration 54 | 55 | 2. **Local Servers in your application (stdio):** 56 | - Specify the command to run the server (e.g., `"npx"`, `"python"`) 57 | - Provide arguments in an array (e.g., `["-y", "package-name"]`) 58 | - Optionally provide environment variables with `"env": {}` 59 | 60 | ## 3. Using MCPClient 61 | 62 | Once your servers are configured, use the `MCPClient` class to interact with them: 63 | 64 | ### Step 1: Initialize and Connect 65 | 66 | ```python 67 | from defog.llm.utils_mcp import MCPClient 68 | import json 69 | 70 | # Load config 71 | with open("path/to/mcp_config.json", "r") as f: 72 | config = json.load(f) 73 | 74 | # Initialize client with your preferred model. Only Claude and OpenAI models are supported. 75 | mcp_client = MCPClient(model_name="claude-3-7-sonnet-20250219") 76 | 77 | # Connect to all servers defined in config 78 | await mcp_client.connect_to_server_from_config(config) 79 | ``` 80 | 81 | ### Step 2: Process Queries 82 | 83 | ```python 84 | # Send a query to be processed by the LLM with access to MCP tools 85 | # The response contains the final text output 86 | # tool_outputs is a list of all tool calls made during processing 87 | response, tool_outputs = await mcp_client.mcp_chat(query="What is the average sale in the northeast region?") 88 | ``` 89 | 90 | ### Step 3: Clean Up 91 | 92 | ```python 93 | # Always clean up when done to close connections 94 | await mcp_client.cleanup() 95 | ``` 96 | 97 | ## 4. 
Using Prompt Templates 98 | 99 | Prompt templates that are defined in the MCP servers can be invoked in queries using the `/command` syntax: 100 | 101 | ```python 102 | # This will apply the "gen_report" prompt template to "sales by region" 103 | response, _ = await mcp_client.mcp_chat(query="/gen_report sales by region") 104 | ``` 105 | If no further text is provided after the command, the prompt template will be applied to the output of the previous message in the message history. 106 | 107 | ## 5. Troubleshooting 108 | 109 | If you encounter the "unhandled errors in a TaskGroup" error: 110 | 111 | 1. **Check Server Status**: Make sure your server is running on the specified port 112 | 2. **Verify Configuration**: Ensure port numbers match in both server and config 113 | 3. **Host Binding**: Server must use `host="0.0.0.0"` (not `127.0.0.1`) 114 | 4. **Network Access**: For Docker, ensure host.docker.internal resolves correctly 115 | 5. **Port Exposure**: Make sure the port is accessible from the Docker container 116 | 117 | If using stdio servers, check for syntax errors in your command or arguments. 118 | 119 | ## 6. MCPClient Reference 120 | 121 | Key methods in `MCPClient`: 122 | 123 | - `connect_to_server_from_config(config)`: Connect to all servers in config json file 124 | - `mcp_chat(query)`: Process a query using LLM and available tools 125 | - Calls tools in a loop until no more tools are called or max tokens are reached 126 | - Stores messages in message history so the LLM can use it as context for following queries 127 | - `call_tool(tool_name, tool_args)`: Directly call a specific tool 128 | - `get_prompt(prompt_name, args)`: Retrieve a specific prompt template 129 | - `cleanup()`: Close all server connections -------------------------------------------------------------------------------- /defog/llm/memory/__init__.py: -------------------------------------------------------------------------------- 1 | """Memory management utilities for LLM conversations.""" 2 | 3 | from .history_manager import MemoryManager, ConversationHistory 4 | from .compactifier import compactify_messages 5 | from .token_counter import TokenCounter 6 | 7 | __all__ = [ 8 | "MemoryManager", 9 | "ConversationHistory", 10 | "compactify_messages", 11 | "TokenCounter", 12 | ] 13 | -------------------------------------------------------------------------------- /defog/llm/memory/compactifier.py: -------------------------------------------------------------------------------- 1 | """Message compactification utilities for conversation memory management.""" 2 | 3 | from typing import List, Dict, Any, Optional, Tuple 4 | from ..utils import chat_async 5 | from .token_counter import TokenCounter 6 | 7 | 8 | async def compactify_messages( 9 | system_messages: List[Dict[str, Any]], 10 | messages_to_summarize: List[Dict[str, Any]], 11 | preserved_messages: List[Dict[str, Any]], 12 | provider: str, 13 | model: str, 14 | max_summary_tokens: int = 2000, 15 | **kwargs, 16 | ) -> Tuple[List[Dict[str, Any]], int]: 17 | """ 18 | Compactify a conversation by summarizing older messages while preserving system messages. 
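    The summarized portion is replaced by a single user message prefixed with
    "[Previous conversation summary]", and system messages plus the preserved
    recent messages are kept verbatim.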
19 | 20 | Args: 21 | system_messages: System messages to preserve at the beginning 22 | messages_to_summarize: Messages to be summarized 23 | preserved_messages: Recent messages to keep as-is 24 | provider: LLM provider to use for summarization 25 | model: Model to use for summarization 26 | max_summary_tokens: Maximum tokens for the summary 27 | **kwargs: Additional arguments for the LLM call 28 | 29 | Returns: 30 | Tuple of (new_messages, total_token_count) 31 | """ 32 | if not messages_to_summarize: 33 | # Nothing to summarize, return system + preserved messages 34 | all_messages = system_messages + preserved_messages 35 | token_counter = TokenCounter() 36 | total_tokens = token_counter.count_tokens(all_messages, model, provider) 37 | return all_messages, total_tokens 38 | 39 | # Create a summary prompt 40 | summary_prompt = _create_summary_prompt(messages_to_summarize, max_summary_tokens) 41 | 42 | # Generate summary using the same provider/model 43 | summary_response = await chat_async( 44 | provider=provider, 45 | model=model, 46 | messages=[{"role": "user", "content": summary_prompt}], 47 | temperature=0.3, # Lower temperature for more consistent summaries 48 | **kwargs, 49 | ) 50 | 51 | # Create summary message 52 | summary_message = { 53 | "role": "user", # Summary as user context 54 | "content": f"[Previous conversation summary]\n{summary_response.content}", 55 | } 56 | 57 | # Combine: system messages + summary + preserved messages 58 | new_messages = system_messages + [summary_message] + preserved_messages 59 | 60 | # Calculate new token count 61 | token_counter = TokenCounter() 62 | total_tokens = token_counter.count_tokens(new_messages, model, provider) 63 | 64 | return new_messages, total_tokens 65 | 66 | 67 | def _create_summary_prompt(messages: List[Dict[str, Any]], max_tokens: int) -> str: 68 | """Create a prompt for summarizing conversation history.""" 69 | 70 | # Format messages for summary 71 | formatted_messages = [] 72 | for msg in messages: 73 | role = msg.get("role", "unknown") 74 | content = msg.get("content", "") 75 | 76 | # Handle different content types 77 | if isinstance(content, dict): 78 | # Tool calls or structured content 79 | content = f"[Structured content: {type(content).__name__}]" 80 | elif isinstance(content, list): 81 | # Multiple content items 82 | content = f"[Multiple content items: {len(content)}]" 83 | 84 | formatted_messages.append(f"{role.upper()}: {content}") 85 | 86 | conversation_text = "\n\n".join(formatted_messages) 87 | 88 | prompt = f"""Please provide a concise summary of the following conversation. 89 | Focus on: 90 | 1. Key topics discussed 91 | 2. Important decisions or conclusions reached 92 | 3. Any unresolved questions or ongoing tasks 93 | 4. Critical context that should be preserved 94 | 95 | Keep the summary under {max_tokens // 4} words (approximately {max_tokens} tokens). 96 | 97 | Conversation: 98 | {conversation_text} 99 | 100 | Summary:""" 101 | 102 | return prompt 103 | 104 | 105 | async def smart_compactify( 106 | memory_manager, provider: str, model: str, **kwargs 107 | ) -> Tuple[List[Dict[str, Any]], int]: 108 | """ 109 | Intelligently compactify messages using the memory manager. 110 | 111 | This is a convenience function that works with a MemoryManager instance. 
112 | 113 | Args: 114 | memory_manager: MemoryManager instance 115 | provider: LLM provider 116 | model: Model name 117 | **kwargs: Additional LLM arguments 118 | 119 | Returns: 120 | Tuple of (new_messages, total_token_count) 121 | """ 122 | if not memory_manager.should_compactify(): 123 | return ( 124 | memory_manager.get_current_messages(), 125 | memory_manager.history.total_tokens, 126 | ) 127 | 128 | # Get messages to summarize and preserve 129 | system_messages, messages_to_summarize, preserved_messages = ( 130 | memory_manager.get_messages_for_compactification() 131 | ) 132 | 133 | # Perform compactification 134 | new_messages, new_token_count = await compactify_messages( 135 | system_messages=system_messages, 136 | messages_to_summarize=messages_to_summarize, 137 | preserved_messages=preserved_messages, 138 | provider=provider, 139 | model=model, 140 | max_summary_tokens=memory_manager.summary_max_tokens, 141 | **kwargs, 142 | ) 143 | 144 | # Update memory manager 145 | if new_messages and len(new_messages) > len(system_messages): 146 | # Find the summary message (first non-system message in the result) 147 | summary_idx = len(system_messages) 148 | summary_message = new_messages[summary_idx] 149 | memory_manager.update_after_compactification( 150 | system_messages=system_messages, 151 | summary_message=summary_message, 152 | preserved_messages=preserved_messages, 153 | new_token_count=new_token_count, 154 | ) 155 | 156 | return new_messages, new_token_count 157 | -------------------------------------------------------------------------------- /defog/llm/memory/history_manager.py: -------------------------------------------------------------------------------- 1 | """Manages conversation history and memory for LLM interactions.""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import List, Dict, Optional, Any 5 | from datetime import datetime 6 | import copy 7 | 8 | 9 | @dataclass 10 | class ConversationHistory: 11 | """Container for conversation messages with metadata.""" 12 | 13 | messages: List[Dict[str, Any]] = field(default_factory=list) 14 | total_tokens: int = 0 15 | created_at: datetime = field(default_factory=datetime.now) 16 | last_compactified_at: Optional[datetime] = None 17 | compactification_count: int = 0 18 | 19 | def add_message(self, message: Dict[str, Any], tokens: int = 0) -> None: 20 | """Add a message to the history.""" 21 | self.messages.append(copy.deepcopy(message)) 22 | self.total_tokens += tokens 23 | 24 | def get_messages(self) -> List[Dict[str, Any]]: 25 | """Get a copy of all messages.""" 26 | return copy.deepcopy(self.messages) 27 | 28 | def clear(self) -> None: 29 | """Clear all messages and reset token count.""" 30 | self.messages.clear() 31 | self.total_tokens = 0 32 | 33 | def replace_messages( 34 | self, new_messages: List[Dict[str, Any]], new_token_count: int 35 | ) -> None: 36 | """Replace all messages with new ones.""" 37 | self.messages = copy.deepcopy(new_messages) 38 | self.total_tokens = new_token_count 39 | self.last_compactified_at = datetime.now() 40 | self.compactification_count += 1 41 | 42 | 43 | class MemoryManager: 44 | """Manages conversation memory with automatic compactification.""" 45 | 46 | def __init__( 47 | self, 48 | token_threshold: int = 50000, # Default to ~100k tokens before compactifying 49 | preserve_last_n_messages: int = 10, 50 | summary_max_tokens: int = 4000, 51 | enabled: bool = True, 52 | ): 53 | """ 54 | Initialize the memory manager. 
55 | 56 | Args: 57 | token_threshold: Number of tokens before triggering compactification 58 | preserve_last_n_messages: Number of recent messages to always preserve 59 | summary_max_tokens: Maximum tokens for the summary 60 | enabled: Whether memory management is enabled 61 | """ 62 | self.token_threshold = token_threshold 63 | self.preserve_last_n_messages = preserve_last_n_messages 64 | self.summary_max_tokens = summary_max_tokens 65 | self.enabled = enabled 66 | self.history = ConversationHistory() 67 | 68 | def should_compactify(self) -> bool: 69 | """Check if memory should be compactified based on token count.""" 70 | if not self.enabled: 71 | return False 72 | return self.history.total_tokens >= self.token_threshold 73 | 74 | def add_messages(self, messages: List[Dict[str, Any]], tokens: int) -> None: 75 | """Add multiple messages to history.""" 76 | for message in messages: 77 | self.history.add_message( 78 | message, tokens // len(messages) if messages else 0 79 | ) 80 | 81 | def get_messages_for_compactification( 82 | self, 83 | ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: 84 | """ 85 | Split messages into system messages, messages to summarize, and messages to preserve. 86 | 87 | System messages are always preserved at the beginning. 88 | Only user/assistant messages are eligible for summarization. 89 | 90 | Returns: 91 | Tuple of (system_messages, messages_to_summarize, messages_to_preserve) 92 | """ 93 | all_messages = self.history.get_messages() 94 | 95 | # Separate system messages from others 96 | system_messages = [] 97 | other_messages = [] 98 | 99 | for msg in all_messages: 100 | if msg.get("role") == "system": 101 | system_messages.append(msg) 102 | else: 103 | other_messages.append(msg) 104 | 105 | # Apply preservation logic only to non-system messages 106 | if len(other_messages) <= self.preserve_last_n_messages: 107 | return system_messages, [], other_messages 108 | 109 | split_index = len(other_messages) - self.preserve_last_n_messages 110 | return ( 111 | system_messages, 112 | other_messages[:split_index], 113 | other_messages[split_index:], 114 | ) 115 | 116 | def update_after_compactification( 117 | self, 118 | system_messages: List[Dict[str, Any]], 119 | summary_message: Dict[str, Any], 120 | preserved_messages: List[Dict[str, Any]], 121 | new_token_count: int, 122 | ) -> None: 123 | """Update history after compactification.""" 124 | # Combine: system messages + summary + preserved messages 125 | new_messages = system_messages + [summary_message] + preserved_messages 126 | self.history.replace_messages(new_messages, new_token_count) 127 | 128 | def get_current_messages(self) -> List[Dict[str, Any]]: 129 | """Get current conversation messages.""" 130 | return self.history.get_messages() 131 | 132 | def get_stats(self) -> Dict[str, Any]: 133 | """Get memory statistics.""" 134 | return { 135 | "total_tokens": self.history.total_tokens, 136 | "message_count": len(self.history.messages), 137 | "compactification_count": self.history.compactification_count, 138 | "last_compactified_at": ( 139 | self.history.last_compactified_at.isoformat() 140 | if self.history.last_compactified_at 141 | else None 142 | ), 143 | "enabled": self.enabled, 144 | "token_threshold": self.token_threshold, 145 | } 146 | -------------------------------------------------------------------------------- /defog/llm/memory/token_counter.py: -------------------------------------------------------------------------------- 1 | """Token counting utilities for different 
LLM providers.""" 2 | 3 | from typing import List, Dict, Any, Optional, Union 4 | import tiktoken 5 | from functools import lru_cache 6 | import json 7 | 8 | 9 | class TokenCounter: 10 | """ 11 | Accurate token counting for different LLM providers. 12 | 13 | Uses: 14 | - tiktoken for OpenAI models (and as fallback for all other providers) 15 | - API endpoints for Anthropic/Gemini when client is provided 16 | """ 17 | 18 | def __init__(self): 19 | self._encoding_cache = {} 20 | 21 | @lru_cache(maxsize=10) 22 | def _get_openai_encoding(self, model: str): 23 | """Get and cache tiktoken encoding for OpenAI models.""" 24 | try: 25 | # Try to get encoding for specific model 26 | return tiktoken.encoding_for_model(model) 27 | except KeyError: 28 | # Default encodings for different model families 29 | if "gpt-4o" in model: 30 | return tiktoken.get_encoding("o200k_base") 31 | else: 32 | # Default for gpt-4, gpt-3.5-turbo, and others 33 | return tiktoken.get_encoding("cl100k_base") 34 | 35 | def count_openai_tokens( 36 | self, messages: Union[str, List[Dict[str, Any]]], model: str = "gpt-4" 37 | ) -> int: 38 | """ 39 | Count tokens for OpenAI models using tiktoken. 40 | 41 | Args: 42 | messages: Either a string or list of message dicts 43 | model: OpenAI model name 44 | 45 | Returns: 46 | Token count 47 | """ 48 | encoding = self._get_openai_encoding(model) 49 | 50 | if isinstance(messages, str): 51 | return len(encoding.encode(messages)) 52 | 53 | # Handle message list format 54 | # Based on OpenAI's token counting guide 55 | tokens_per_message = ( 56 | 3 # Every message follows <|im_start|>{role}\n{content}<|im_end|>\n 57 | ) 58 | tokens_per_name = 1 # If there's a name, the role is omitted 59 | 60 | total_tokens = 0 61 | for message in messages: 62 | total_tokens += tokens_per_message 63 | 64 | for key, value in message.items(): 65 | if key == "role": 66 | total_tokens += len(encoding.encode(value)) 67 | elif key == "content": 68 | if isinstance(value, str): 69 | total_tokens += len(encoding.encode(value)) 70 | else: 71 | # Handle tool calls or other structured content 72 | total_tokens += len(encoding.encode(json.dumps(value))) 73 | elif key == "name": 74 | total_tokens += tokens_per_name 75 | total_tokens += len(encoding.encode(value)) 76 | 77 | total_tokens += 3 # Every reply is primed with <|im_start|>assistant<|im_sep|> 78 | return total_tokens 79 | 80 | async def count_anthropic_tokens( 81 | self, messages: List[Dict[str, Any]], model: str, client: Optional[Any] = None 82 | ) -> int: 83 | """ 84 | Count tokens for Anthropic models using their API. 85 | 86 | Args: 87 | messages: List of message dicts 88 | model: Anthropic model name 89 | client: Optional Anthropic client instance 90 | 91 | Returns: 92 | Token count 93 | """ 94 | if client is None: 95 | # Use OpenAI tokenizer as approximation 96 | return self.count_openai_tokens(messages, "gpt-4") 97 | 98 | try: 99 | # Use Anthropic's token counting endpoint 100 | response = await client.messages.count_tokens( 101 | model=model, messages=messages 102 | ) 103 | return response.input_tokens 104 | except Exception: 105 | # Fallback to OpenAI tokenizer 106 | return self.count_openai_tokens(messages, "gpt-4") 107 | 108 | def count_gemini_tokens( 109 | self, 110 | content: Union[str, List[Dict[str, Any]]], 111 | model: str, 112 | client: Optional[Any] = None, 113 | ) -> int: 114 | """ 115 | Count tokens for Gemini models. 
116 | 117 | Args: 118 | content: Text or message list 119 | model: Gemini model name 120 | client: Optional Gemini client instance 121 | 122 | Returns: 123 | Token count 124 | """ 125 | if client is None: 126 | # Use OpenAI tokenizer as approximation 127 | return self.count_openai_tokens(content, "gpt-4") 128 | 129 | try: 130 | # Extract text content 131 | text = self._extract_text(content) 132 | 133 | # Use Gemini's count_tokens method 134 | response = client.count_tokens(text) 135 | return response.total_tokens 136 | except Exception: 137 | # Fallback to OpenAI tokenizer 138 | return self.count_openai_tokens(content, "gpt-4") 139 | 140 | def count_together_tokens( 141 | self, messages: Union[str, List[Dict[str, Any]]], model: str 142 | ) -> int: 143 | """ 144 | Count tokens for Together models using OpenAI tokenizer as approximation. 145 | 146 | Args: 147 | messages: Text or message list 148 | model: Together model name 149 | 150 | Returns: 151 | Estimated token count 152 | """ 153 | # Use OpenAI tokenizer as approximation 154 | return self.count_openai_tokens(messages, "gpt-4") 155 | 156 | def count_tokens( 157 | self, 158 | messages: Union[str, List[Dict[str, Any]]], 159 | model: str, 160 | provider: str, 161 | client: Optional[Any] = None, 162 | ) -> int: 163 | """ 164 | Universal token counting method. 165 | 166 | Args: 167 | messages: Text or message list 168 | model: Model name 169 | provider: Provider name (openai, anthropic, gemini, together) 170 | client: Optional provider client for API-based counting 171 | 172 | Returns: 173 | Token count 174 | """ 175 | provider_lower = provider.lower() 176 | 177 | if provider_lower == "openai": 178 | return self.count_openai_tokens(messages, model) 179 | elif provider_lower == "anthropic": 180 | # Anthropic count_tokens is async, so for sync context use OpenAI approximation 181 | if ( 182 | client 183 | and hasattr(client, "messages") 184 | and hasattr(client.messages, "count_tokens") 185 | ): 186 | # This would need to be called in an async context 187 | return self.count_openai_tokens(messages, "gpt-4") 188 | return self.count_openai_tokens(messages, "gpt-4") 189 | elif provider_lower == "gemini": 190 | return self.count_gemini_tokens(messages, model, client) 191 | elif provider_lower == "together": 192 | return self.count_together_tokens(messages, model) 193 | else: 194 | # Default to OpenAI tokenizer 195 | return self.count_openai_tokens(messages, "gpt-4") 196 | 197 | def _extract_text(self, content: Union[str, List[Dict[str, Any]]]) -> str: 198 | """Extract text from various content formats.""" 199 | if isinstance(content, str): 200 | return content 201 | 202 | if isinstance(content, list): 203 | texts = [] 204 | for item in content: 205 | if isinstance(item, dict): 206 | if "content" in item: 207 | texts.append(str(item["content"])) 208 | elif "text" in item: 209 | texts.append(str(item["text"])) 210 | else: 211 | texts.append(str(item)) 212 | return " ".join(texts) 213 | 214 | return str(content) 215 | 216 | def estimate_remaining_tokens( 217 | self, 218 | messages: Union[str, List[Dict[str, Any]]], 219 | model: str, 220 | provider: str, 221 | max_context_tokens: int = 128000, 222 | response_buffer: int = 4000, 223 | client: Optional[Any] = None, 224 | ) -> int: 225 | """ 226 | Estimate remaining tokens in context window. 
227 | 228 | Args: 229 | messages: Current messages 230 | model: Model name 231 | provider: Provider name 232 | max_context_tokens: Maximum context window size 233 | response_buffer: Tokens to reserve for response 234 | client: Optional provider client 235 | 236 | Returns: 237 | Estimated remaining tokens 238 | """ 239 | used_tokens = self.count_tokens(messages, model, provider, client) 240 | return max(0, max_context_tokens - used_tokens - response_buffer) 241 | -------------------------------------------------------------------------------- /defog/llm/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from enum import Enum 3 | from typing import Optional, Union, Dict, Any, Literal 4 | 5 | 6 | class OpenAIFunctionSpecs(BaseModel): 7 | name: str # name of the function to call 8 | description: Optional[str] = None # description of the function 9 | parameters: Optional[Union[str, Dict[str, Any]]] = ( 10 | None # parameters of the function 11 | ) 12 | 13 | 14 | class AnthropicFunctionSpecs(BaseModel): 15 | name: str # name of the function to call 16 | description: Optional[str] = None # description of the function 17 | input_schema: Optional[Union[str, Dict[str, Any]]] = ( 18 | None # parameters of the function 19 | ) 20 | -------------------------------------------------------------------------------- /defog/llm/providers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLLMProvider, LLMResponse 2 | from .anthropic_provider import AnthropicProvider 3 | from .openai_provider import OpenAIProvider 4 | from .gemini_provider import GeminiProvider 5 | from .together_provider import TogetherProvider 6 | 7 | __all__ = [ 8 | "BaseLLMProvider", 9 | "LLMResponse", 10 | "AnthropicProvider", 11 | "OpenAIProvider", 12 | "GeminiProvider", 13 | "TogetherProvider", 14 | ] 15 | -------------------------------------------------------------------------------- /defog/llm/providers/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Any, Optional, Callable, Tuple 3 | from dataclasses import dataclass 4 | from ..config.settings import LLMConfig 5 | 6 | 7 | @dataclass 8 | class LLMResponse: 9 | content: Any 10 | model: str 11 | time: float 12 | input_tokens: int 13 | output_tokens: int 14 | cached_input_tokens: Optional[int] = None 15 | output_tokens_details: Optional[Dict[str, int]] = None 16 | cost_in_cents: Optional[float] = None 17 | tool_outputs: Optional[List[Dict[str, Any]]] = None 18 | 19 | 20 | class BaseLLMProvider(ABC): 21 | """Abstract base class for all LLM providers.""" 22 | 23 | def __init__( 24 | self, 25 | api_key: Optional[str] = None, 26 | base_url: Optional[str] = None, 27 | config: Optional[LLMConfig] = None, 28 | ): 29 | self.api_key = api_key 30 | self.base_url = base_url 31 | self.config = config or LLMConfig() 32 | 33 | @abstractmethod 34 | def get_provider_name(self) -> str: 35 | """Return the name of the provider.""" 36 | pass 37 | 38 | @abstractmethod 39 | def build_params( 40 | self, 41 | messages: List[Dict[str, str]], 42 | model: str, 43 | max_completion_tokens: Optional[int] = None, 44 | temperature: float = 0.0, 45 | response_format=None, 46 | seed: int = 0, 47 | tools: Optional[List[Callable]] = None, 48 | tool_choice: Optional[str] = None, 49 | store: bool = True, 50 | metadata: Optional[Dict[str, str]] = None, 51 | 
timeout: int = 100, 52 | prediction: Optional[Dict[str, str]] = None, 53 | reasoning_effort: Optional[str] = None, 54 | mcp_servers: Optional[List[Dict[str, Any]]] = None, 55 | **kwargs 56 | ) -> Tuple[Dict[str, Any], List[Dict[str, str]]]: 57 | """Build parameters for the provider's API call.""" 58 | pass 59 | 60 | @abstractmethod 61 | async def process_response( 62 | self, 63 | client, 64 | response, 65 | request_params: Dict[str, Any], 66 | tools: Optional[List[Callable]], 67 | tool_dict: Dict[str, Callable], 68 | response_format=None, 69 | post_tool_function: Optional[Callable] = None, 70 | **kwargs 71 | ) -> Tuple[ 72 | Any, List[Dict[str, Any]], int, int, Optional[int], Optional[Dict[str, int]] 73 | ]: 74 | """Process the response from the provider.""" 75 | pass 76 | 77 | @abstractmethod 78 | async def execute_chat( 79 | self, 80 | messages: List[Dict[str, str]], 81 | model: str, 82 | max_completion_tokens: Optional[int] = None, 83 | temperature: float = 0.0, 84 | response_format=None, 85 | seed: int = 0, 86 | tools: Optional[List[Callable]] = None, 87 | tool_choice: Optional[str] = None, 88 | store: bool = True, 89 | metadata: Optional[Dict[str, str]] = None, 90 | timeout: int = 100, 91 | prediction: Optional[Dict[str, str]] = None, 92 | reasoning_effort: Optional[str] = None, 93 | post_tool_function: Optional[Callable] = None, 94 | mcp_servers: Optional[List[Dict[str, Any]]] = None, 95 | **kwargs 96 | ) -> LLMResponse: 97 | """Execute a chat completion with the provider.""" 98 | pass 99 | 100 | @abstractmethod 101 | def supports_tools(self, model: str) -> bool: 102 | """Check if the model supports tool calling.""" 103 | pass 104 | 105 | @abstractmethod 106 | def supports_response_format(self, model: str) -> bool: 107 | """Check if the model supports structured response formats.""" 108 | pass 109 | -------------------------------------------------------------------------------- /defog/llm/providers/together_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import Dict, List, Any, Optional, Callable, Tuple 4 | 5 | from .base import BaseLLMProvider, LLMResponse 6 | from ..exceptions import ProviderError, MaxTokensError 7 | from ..cost import CostCalculator 8 | 9 | 10 | class TogetherProvider(BaseLLMProvider): 11 | """Together AI provider implementation.""" 12 | 13 | def __init__(self, api_key: Optional[str] = None, config=None): 14 | super().__init__(api_key or os.getenv("TOGETHER_API_KEY"), config=config) 15 | 16 | def get_provider_name(self) -> str: 17 | return "together" 18 | 19 | def supports_tools(self, model: str) -> bool: 20 | return ( 21 | False # Currently Together models don't support tools in our implementation 22 | ) 23 | 24 | def supports_response_format(self, model: str) -> bool: 25 | return False # Currently Together models don't support structured output in our implementation 26 | 27 | def build_params( 28 | self, 29 | messages: List[Dict[str, str]], 30 | model: str, 31 | max_completion_tokens: Optional[int] = None, 32 | temperature: float = 0.0, 33 | response_format=None, 34 | seed: int = 0, 35 | tools: Optional[List[Callable]] = None, 36 | tool_choice: Optional[str] = None, 37 | store: bool = True, 38 | metadata: Optional[Dict[str, str]] = None, 39 | timeout: int = 100, 40 | prediction: Optional[Dict[str, str]] = None, 41 | reasoning_effort: Optional[str] = None, 42 | mcp_servers: Optional[List[Dict[str, Any]]] = None, 43 | **kwargs, 44 | ) -> Tuple[Dict[str, Any], 
List[Dict[str, str]]]: 45 | """Build parameters for Together's API call.""" 46 | return { 47 | "messages": messages, 48 | "model": model, 49 | "max_tokens": max_completion_tokens, 50 | "temperature": temperature, 51 | "seed": seed, 52 | }, messages 53 | 54 | async def process_response( 55 | self, 56 | client, 57 | response, 58 | request_params: Dict[str, Any], 59 | tools: Optional[List[Callable]], 60 | tool_dict: Dict[str, Callable], 61 | response_format=None, 62 | post_tool_function: Optional[Callable] = None, 63 | **kwargs, 64 | ) -> Tuple[ 65 | Any, List[Dict[str, Any]], int, int, Optional[int], Optional[Dict[str, int]] 66 | ]: 67 | """Process Together API response.""" 68 | if response.choices[0].finish_reason == "length": 69 | raise MaxTokensError("Max tokens reached") 70 | if len(response.choices) == 0: 71 | raise MaxTokensError("Max tokens reached") 72 | 73 | content = response.choices[0].message.content 74 | input_tokens = response.usage.prompt_tokens 75 | output_tokens = response.usage.completion_tokens 76 | 77 | return content, [], input_tokens, output_tokens, None, None 78 | 79 | async def execute_chat( 80 | self, 81 | messages: List[Dict[str, str]], 82 | model: str, 83 | max_completion_tokens: Optional[int] = None, 84 | temperature: float = 0.0, 85 | response_format=None, 86 | seed: int = 0, 87 | tools: Optional[List[Callable]] = None, 88 | tool_choice: Optional[str] = None, 89 | store: bool = True, 90 | metadata: Optional[Dict[str, str]] = None, 91 | timeout: int = 100, 92 | prediction: Optional[Dict[str, str]] = None, 93 | reasoning_effort: Optional[str] = None, 94 | post_tool_function: Optional[Callable] = None, 95 | mcp_servers: Optional[List[Dict[str, Any]]] = None, 96 | **kwargs, 97 | ) -> LLMResponse: 98 | """Execute a chat completion with Together.""" 99 | from together import AsyncTogether 100 | 101 | t = time.time() 102 | client_together = AsyncTogether(timeout=timeout) 103 | params, _ = self.build_params( 104 | messages=messages, 105 | model=model, 106 | max_completion_tokens=max_completion_tokens, 107 | temperature=temperature, 108 | seed=seed, 109 | ) 110 | 111 | try: 112 | response = await client_together.chat.completions.create(**params) 113 | ( 114 | content, 115 | tool_outputs, 116 | input_toks, 117 | output_toks, 118 | cached_toks, 119 | output_details, 120 | ) = await self.process_response( 121 | client=client_together, 122 | response=response, 123 | request_params=params, 124 | tools=tools, 125 | tool_dict={}, 126 | response_format=response_format, 127 | post_tool_function=post_tool_function, 128 | ) 129 | except Exception as e: 130 | raise ProviderError(self.get_provider_name(), f"API call failed: {e}", e) 131 | 132 | # Calculate cost 133 | cost = CostCalculator.calculate_cost( 134 | model, input_toks, output_toks, cached_toks 135 | ) 136 | 137 | return LLMResponse( 138 | model=model, 139 | content=content, 140 | time=round(time.time() - t, 3), 141 | input_tokens=input_toks, 142 | output_tokens=output_toks, 143 | cached_input_tokens=cached_toks, 144 | output_tokens_details=output_details, 145 | cost_in_cents=cost, 146 | tool_outputs=tool_outputs, 147 | ) 148 | -------------------------------------------------------------------------------- /defog/llm/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .handler import ToolHandler 2 | 3 | __all__ = ["ToolHandler"] 4 | -------------------------------------------------------------------------------- /defog/llm/tools/handler.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | import asyncio 3 | from typing import Dict, List, Callable, Any, Optional 4 | from ..exceptions import ToolError 5 | from ..utils_function_calling import ( 6 | execute_tool, 7 | execute_tool_async, 8 | execute_tools_parallel, 9 | verify_post_tool_function, 10 | ) 11 | 12 | 13 | class ToolHandler: 14 | """Handles tool calling logic for LLM providers.""" 15 | 16 | def __init__(self, max_consecutive_errors: int = 3): 17 | self.max_consecutive_errors = max_consecutive_errors 18 | 19 | async def execute_tool_call( 20 | self, 21 | tool_name: str, 22 | args: Dict[str, Any], 23 | tool_dict: Dict[str, Callable], 24 | post_tool_function: Optional[Callable] = None, 25 | ) -> Any: 26 | """Execute a single tool call.""" 27 | tool_to_call = tool_dict.get(tool_name) 28 | if tool_to_call is None: 29 | raise ToolError(tool_name, "Tool not found") 30 | 31 | try: 32 | # Execute tool depending on whether it is async 33 | if inspect.iscoroutinefunction(tool_to_call): 34 | result = await execute_tool_async(tool_to_call, args) 35 | else: 36 | result = execute_tool(tool_to_call, args) 37 | except Exception as e: 38 | raise ToolError(tool_name, f"Error executing tool: {e}", e) 39 | 40 | # Execute post-tool function if provided 41 | if post_tool_function: 42 | try: 43 | if inspect.iscoroutinefunction(post_tool_function): 44 | await post_tool_function( 45 | function_name=tool_name, 46 | input_args=args, 47 | tool_result=result, 48 | ) 49 | else: 50 | post_tool_function( 51 | function_name=tool_name, 52 | input_args=args, 53 | tool_result=result, 54 | ) 55 | except Exception as e: 56 | raise ToolError( 57 | tool_name, f"Error executing post_tool_function: {e}", e 58 | ) 59 | 60 | return result 61 | 62 | async def execute_tool_calls_batch( 63 | self, 64 | tool_calls: List[Dict[str, Any]], 65 | tool_dict: Dict[str, Callable], 66 | enable_parallel: bool = False, 67 | post_tool_function: Optional[Callable] = None, 68 | ) -> List[Any]: 69 | """Execute multiple tool calls either in parallel or sequentially.""" 70 | try: 71 | results = await execute_tools_parallel( 72 | tool_calls, tool_dict, enable_parallel 73 | ) 74 | 75 | # Execute post-tool function for each result if provided 76 | if post_tool_function: 77 | for i, (tool_call, result) in enumerate(zip(tool_calls, results)): 78 | func_name = tool_call.get("function", {}).get( 79 | "name" 80 | ) or tool_call.get("name") 81 | func_args = tool_call.get("function", {}).get( 82 | "arguments" 83 | ) or tool_call.get("arguments", {}) 84 | 85 | try: 86 | if inspect.iscoroutinefunction(post_tool_function): 87 | await post_tool_function( 88 | function_name=func_name, 89 | input_args=func_args, 90 | tool_result=result, 91 | ) 92 | else: 93 | post_tool_function( 94 | function_name=func_name, 95 | input_args=func_args, 96 | tool_result=result, 97 | ) 98 | except Exception as e: 99 | # Don't fail the entire batch for post-tool function errors 100 | print( 101 | f"Warning: Error executing post_tool_function for {func_name}: {e}" 102 | ) 103 | 104 | return results 105 | except Exception as e: 106 | raise ToolError("batch", f"Error executing tool batch: {e}", e) 107 | 108 | def build_tool_dict(self, tools: List[Callable]) -> Dict[str, Callable]: 109 | """Build a dictionary mapping tool names to functions.""" 110 | return {tool.__name__: tool for tool in tools} 111 | 112 | def validate_post_tool_function( 113 | self, post_tool_function: Optional[Callable] 114 | ) -> None: 115 | 
"""Validate the post-tool function signature.""" 116 | if post_tool_function: 117 | verify_post_tool_function(post_tool_function) 118 | -------------------------------------------------------------------------------- /defog/llm/utils_memory.py: -------------------------------------------------------------------------------- 1 | """Chat utilities with memory management support.""" 2 | 3 | from typing import Dict, List, Optional, Any, Union, Callable 4 | from dataclasses import dataclass 5 | 6 | from .utils import chat_async 7 | from .llm_providers import LLMProvider 8 | from .providers.base import LLMResponse 9 | from .config import LLMConfig 10 | from .memory import MemoryManager, compactify_messages, TokenCounter 11 | 12 | 13 | @dataclass 14 | class MemoryConfig: 15 | """Configuration for memory management.""" 16 | 17 | enabled: bool = True 18 | token_threshold: int = 50000 # ~50k tokens before compactifying 19 | preserve_last_n_messages: int = 10 20 | summary_max_tokens: int = 4000 21 | max_context_tokens: int = 128000 # 128k context window 22 | 23 | 24 | async def chat_async_with_memory( 25 | provider: Union[LLMProvider, str], 26 | model: str, 27 | messages: List[Dict[str, str]], 28 | memory_manager: Optional[MemoryManager] = None, 29 | memory_config: Optional[MemoryConfig] = None, 30 | auto_compactify: bool = True, 31 | max_completion_tokens: Optional[int] = None, 32 | temperature: float = 0.0, 33 | response_format=None, 34 | seed: int = 0, 35 | store: bool = True, 36 | metadata: Optional[Dict[str, str]] = None, 37 | timeout: int = 100, 38 | backup_model: Optional[str] = None, 39 | backup_provider: Optional[Union[LLMProvider, str]] = None, 40 | prediction: Optional[Dict[str, str]] = None, 41 | reasoning_effort: Optional[str] = None, 42 | tools: Optional[List[Callable]] = None, 43 | tool_choice: Optional[str] = None, 44 | max_retries: Optional[int] = None, 45 | post_tool_function: Optional[Callable] = None, 46 | config: Optional[LLMConfig] = None, 47 | ) -> LLMResponse: 48 | """ 49 | Execute a chat completion with memory management support. 50 | 51 | This function extends chat_async with automatic conversation memory management 52 | and compactification when approaching token limits. 53 | 54 | Args: 55 | provider: LLM provider to use 56 | model: Model name 57 | messages: List of message dictionaries 58 | memory_manager: Optional MemoryManager instance (created if not provided) 59 | memory_config: Memory configuration settings 60 | auto_compactify: Whether to automatically compactify when needed 61 | ... 
(all other chat_async parameters) 62 | 63 | Returns: 64 | LLMResponse object with the result 65 | """ 66 | # Initialize memory config if not provided 67 | if memory_config is None: 68 | memory_config = MemoryConfig() 69 | 70 | # Initialize memory manager if not provided and memory is enabled 71 | if memory_manager is None and memory_config.enabled: 72 | memory_manager = MemoryManager( 73 | token_threshold=memory_config.token_threshold, 74 | preserve_last_n_messages=memory_config.preserve_last_n_messages, 75 | summary_max_tokens=memory_config.summary_max_tokens, 76 | enabled=memory_config.enabled, 77 | ) 78 | 79 | # If memory is disabled, just pass through to regular chat_async 80 | if not memory_config.enabled or memory_manager is None: 81 | return await chat_async( 82 | provider=provider, 83 | model=model, 84 | messages=messages, 85 | max_completion_tokens=max_completion_tokens, 86 | temperature=temperature, 87 | response_format=response_format, 88 | seed=seed, 89 | store=store, 90 | metadata=metadata, 91 | timeout=timeout, 92 | backup_model=backup_model, 93 | backup_provider=backup_provider, 94 | prediction=prediction, 95 | reasoning_effort=reasoning_effort, 96 | tools=tools, 97 | tool_choice=tool_choice, 98 | max_retries=max_retries, 99 | post_tool_function=post_tool_function, 100 | config=config, 101 | ) 102 | 103 | # Get current messages from memory manager 104 | current_messages = memory_manager.get_current_messages() 105 | 106 | # Add new messages to memory 107 | token_counter = TokenCounter() 108 | new_tokens = token_counter.count_tokens(messages, model, str(provider)) 109 | memory_manager.add_messages(messages, new_tokens) 110 | 111 | # Check if we should compactify 112 | if auto_compactify and memory_manager.should_compactify(): 113 | # Get messages to summarize and preserve 114 | system_messages, messages_to_summarize, preserved_messages = ( 115 | memory_manager.get_messages_for_compactification() 116 | ) 117 | 118 | # Compactify messages 119 | compactified_messages, new_token_count = await compactify_messages( 120 | system_messages=system_messages, 121 | messages_to_summarize=messages_to_summarize, 122 | preserved_messages=preserved_messages, 123 | provider=str(provider), 124 | model=model, 125 | max_summary_tokens=memory_config.summary_max_tokens, 126 | config=config, # Pass config for API credentials 127 | ) 128 | 129 | # Update memory manager with compactified messages 130 | if compactified_messages and len(compactified_messages) > len(system_messages): 131 | # Find the summary message (first non-system message in the result) 132 | summary_idx = len(system_messages) 133 | summary_message = compactified_messages[summary_idx] 134 | memory_manager.update_after_compactification( 135 | system_messages=system_messages, 136 | summary_message=summary_message, 137 | preserved_messages=preserved_messages, 138 | new_token_count=new_token_count, 139 | ) 140 | 141 | # Use compactified messages for the API call 142 | messages_for_api = compactified_messages 143 | else: 144 | # Use all messages from memory 145 | messages_for_api = memory_manager.get_current_messages() 146 | 147 | # Make the API call with the potentially compactified messages 148 | response = await chat_async( 149 | provider=provider, 150 | model=model, 151 | messages=messages_for_api, 152 | max_completion_tokens=max_completion_tokens, 153 | temperature=temperature, 154 | response_format=response_format, 155 | seed=seed, 156 | store=store, 157 | metadata=metadata, 158 | timeout=timeout, 159 | backup_model=backup_model, 160 
| backup_provider=backup_provider, 161 | prediction=prediction, 162 | reasoning_effort=reasoning_effort, 163 | tools=tools, 164 | tool_choice=tool_choice, 165 | max_retries=max_retries, 166 | post_tool_function=post_tool_function, 167 | config=config, 168 | ) 169 | 170 | # Add the assistant's response to memory 171 | assistant_message = {"role": "assistant", "content": response.content} 172 | response_tokens = response.output_tokens or 0 173 | memory_manager.add_messages([assistant_message], response_tokens) 174 | 175 | # Add memory stats to response metadata 176 | if hasattr(response, "_memory_stats"): 177 | response._memory_stats = memory_manager.get_stats() 178 | 179 | return response 180 | 181 | 182 | # Convenience function for creating a memory manager 183 | def create_memory_manager( 184 | token_threshold: int = 100000, 185 | preserve_last_n_messages: int = 10, 186 | summary_max_tokens: int = 2000, 187 | enabled: bool = True, 188 | ) -> MemoryManager: 189 | """ 190 | Create a new MemoryManager instance. 191 | 192 | Args: 193 | token_threshold: Token count threshold for triggering compactification 194 | preserve_last_n_messages: Number of recent messages to always preserve 195 | summary_max_tokens: Maximum tokens for the summary 196 | enabled: Whether memory management is enabled 197 | 198 | Returns: 199 | Configured MemoryManager instance 200 | """ 201 | return MemoryManager( 202 | token_threshold=token_threshold, 203 | preserve_last_n_messages=preserve_last_n_messages, 204 | summary_max_tokens=summary_max_tokens, 205 | enabled=enabled, 206 | ) 207 | -------------------------------------------------------------------------------- /defog/llm/web_search.py: -------------------------------------------------------------------------------- 1 | from defog.llm.llm_providers import LLMProvider 2 | from defog.llm.utils_logging import ToolProgressTracker, SubTaskLogger 3 | import os 4 | 5 | 6 | async def web_search_tool( 7 | question: str, 8 | model: str, 9 | provider: LLMProvider, 10 | max_tokens: int = 2048, 11 | ): 12 | """ 13 | Search the web for the answer to the question. 14 | """ 15 | async with ToolProgressTracker( 16 | "Web Search", 17 | f"Searching for: {question[:50]}{'...' 
if len(question) > 50 else ''}", 18 | ) as tracker: 19 | subtask_logger = SubTaskLogger() 20 | subtask_logger.log_provider_info( 21 | provider.value if hasattr(provider, "value") else str(provider), model 22 | ) 23 | 24 | if provider in [LLMProvider.OPENAI, LLMProvider.OPENAI.value]: 25 | from openai import AsyncOpenAI 26 | 27 | client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 28 | 29 | tracker.update(20, "Initiating web search") 30 | subtask_logger.log_search_status(question) 31 | 32 | response = await client.responses.create( 33 | model=model, 34 | tools=[{"type": "web_search_preview"}], 35 | tool_choice="required", 36 | input=question, 37 | # in the responses API, this means both the reasoning and the output tokens 38 | max_output_tokens=max_tokens, 39 | ) 40 | tracker.update(80, "Processing search results") 41 | subtask_logger.log_subtask("Extracting citations and content", "processing") 42 | 43 | usage = { 44 | "input_tokens": response.usage.input_tokens, 45 | "output_tokens": response.usage.output_tokens, 46 | } 47 | output_text = response.output_text 48 | websites_cited = [] 49 | for output in response.output: 50 | if hasattr(output, "content") and output.content: 51 | for content in output.content: 52 | if content.annotations: 53 | for annotation in content.annotations: 54 | websites_cited.append( 55 | { 56 | "url": annotation.url, 57 | "title": annotation.title, 58 | } 59 | ) 60 | 61 | subtask_logger.log_result_summary( 62 | "Web Search", 63 | { 64 | "websites_found": len(websites_cited), 65 | "tokens_used": usage["input_tokens"] + usage["output_tokens"], 66 | }, 67 | ) 68 | 69 | return { 70 | "usage": usage, 71 | "search_results": output_text, 72 | "websites_cited": websites_cited, 73 | } 74 | 75 | elif provider in [LLMProvider.ANTHROPIC, LLMProvider.ANTHROPIC.value]: 76 | from anthropic import AsyncAnthropic 77 | from anthropic.types import TextBlock 78 | 79 | client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) 80 | 81 | tracker.update(20, "Initiating web search") 82 | subtask_logger.log_search_status(question, max_results=5) 83 | 84 | response = await client.messages.create( 85 | model=model, 86 | max_tokens=max_tokens, 87 | messages=[{"role": "user", "content": question}], 88 | tools=[ 89 | { 90 | "type": "web_search_20250305", 91 | "name": "web_search", 92 | "max_uses": 5, 93 | # can also use allowed_domains to limit the search to specific domains 94 | # can also use blocked_domains to exclude specific domains 95 | } 96 | ], 97 | tool_choice={"type": "any"}, 98 | ) 99 | 100 | tracker.update(80, "Processing search results") 101 | subtask_logger.log_subtask("Extracting citations and content", "processing") 102 | 103 | usage = { 104 | "input_tokens": response.usage.input_tokens, 105 | "output_tokens": response.usage.output_tokens, 106 | } 107 | search_results = response.content 108 | # we want to use only the TextBlock class in the search results 109 | search_results = [ 110 | block for block in search_results if isinstance(block, TextBlock) 111 | ] 112 | 113 | # convert the search_results into simple text with citations 114 | # (where citations = text + hyperlinks 115 | output_text = [ 116 | ( 117 | f'' + block.text + "" 118 | if block.citations 119 | else block.text 120 | ) 121 | for block in search_results 122 | ] 123 | websites_cited = [ 124 | {"url": block.citations[0].url, "title": block.citations[0].title} 125 | for block in search_results 126 | if block.citations 127 | ] 128 | 129 | subtask_logger.log_result_summary( 130 | "Web Search", 131 | { 132 | 
"text_blocks": len(search_results), 133 | "websites_cited": len(websites_cited), 134 | "tokens_used": usage["input_tokens"] + usage["output_tokens"], 135 | }, 136 | ) 137 | 138 | return { 139 | "usage": usage, 140 | "search_results": output_text, 141 | "websites_cited": websites_cited, 142 | } 143 | elif provider in [LLMProvider.GEMINI, LLMProvider.GEMINI.value]: 144 | from google import genai 145 | from google.genai.types import ( 146 | Tool, 147 | GenerateContentConfig, 148 | GoogleSearch, 149 | ToolConfig, 150 | FunctionCallingConfig, 151 | ) 152 | 153 | client = genai.Client(api_key=os.getenv("GEMINI_API_KEY")) 154 | google_search_tool = Tool(google_search=GoogleSearch()) 155 | 156 | tracker.update(20, "Initiating Google search") 157 | subtask_logger.log_search_status(question) 158 | 159 | response = await client.aio.models.generate_content( 160 | model=model, 161 | contents=question, 162 | config=GenerateContentConfig( 163 | tools=[google_search_tool], 164 | response_modalities=["TEXT"], 165 | tool_config=ToolConfig( 166 | function_calling_config=FunctionCallingConfig( 167 | mode="ANY", 168 | ), 169 | ), 170 | ), 171 | ) 172 | tracker.update(80, "Processing search results") 173 | subtask_logger.log_subtask("Extracting grounding metadata", "processing") 174 | 175 | usage = { 176 | "input_tokens": response.usage_metadata.prompt_token_count, 177 | "thinking_tokens": response.usage_metadata.thoughts_token_count or 0, 178 | "output_tokens": response.usage_metadata.candidates_token_count, 179 | } 180 | 181 | websites_cited = [] 182 | if response.candidates: 183 | for candidate in response.candidates: 184 | if ( 185 | candidate.grounding_metadata 186 | and candidate.grounding_metadata.grounding_chunks 187 | ): 188 | for chunk in candidate.grounding_metadata.grounding_chunks: 189 | websites_cited.append( 190 | {"source": chunk.web.title, "url": chunk.web.uri} 191 | ) 192 | 193 | output_text = response.text 194 | 195 | subtask_logger.log_result_summary( 196 | "Web Search", 197 | { 198 | "websites_found": len(websites_cited), 199 | "total_tokens": usage["input_tokens"] 200 | + usage["thinking_tokens"] 201 | + usage["output_tokens"], 202 | }, 203 | ) 204 | 205 | return { 206 | "usage": usage, 207 | "search_results": output_text, 208 | "websites_cited": websites_cited, 209 | } 210 | 211 | else: 212 | raise ValueError(f"Provider {provider} not supported") 213 | -------------------------------------------------------------------------------- /defog/llm/youtube_transcript.py: -------------------------------------------------------------------------------- 1 | # converts a youtube video to a detailed, ideally diarized transcript 2 | from defog.llm.utils_logging import ToolProgressTracker, SubTaskLogger 3 | import os 4 | 5 | 6 | async def get_transcript( 7 | video_url: str, model: str = "gemini-2.5-pro-preview-06-05" 8 | ) -> str: 9 | """ 10 | Get a detailed, diarized transcript of a YouTube video. 11 | 12 | Args: 13 | video_url: The URL of the YouTube video. 14 | 15 | Returns: 16 | A detailed, ideally diarized transcript of the video. 17 | """ 18 | async with ToolProgressTracker( 19 | "YouTube Transcript", 20 | f"Transcribing video from: {video_url[:50]}{'...' 
if len(video_url) > 50 else ''}", 21 | ) as tracker: 22 | subtask_logger = SubTaskLogger() 23 | subtask_logger.log_provider_info("Gemini", model) 24 | 25 | if os.getenv("GEMINI_API_KEY") is None: 26 | raise ValueError("GEMINI_API_KEY is not set") 27 | 28 | from google import genai 29 | from google.genai.types import ( 30 | Content, 31 | Part, 32 | FileData, 33 | VideoMetadata, 34 | ) 35 | 36 | client = genai.Client(api_key=os.getenv("GEMINI_API_KEY")) 37 | 38 | tracker.update(10, "Processing video") 39 | subtask_logger.log_subtask( 40 | "Using low FPS (0.2) for efficient processing", "info" 41 | ) 42 | 43 | response = await client.aio.models.generate_content( 44 | model=model, 45 | contents=Content( 46 | parts=[ 47 | Part( 48 | file_data=FileData(file_uri=video_url), 49 | video_metadata=VideoMetadata(fps=0.2), 50 | ), 51 | Part( 52 | text="Please provide a detailed, accurate transcript of the video. Please include timestamps for each speaker. Do not describe the video, just create a great transcript." 53 | ), 54 | ] 55 | ), 56 | ) 57 | 58 | tracker.update(90, "Finalizing transcript") 59 | transcript_length = len(response.text) if response.text else 0 60 | 61 | subtask_logger.log_result_summary( 62 | "YouTube Transcript", 63 | { 64 | "transcript_length": f"{transcript_length} characters", 65 | "model_used": model, 66 | }, 67 | ) 68 | 69 | return response.text 70 | -------------------------------------------------------------------------------- /defog/query_methods.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from defog.query import execute_query 3 | from datetime import datetime 4 | 5 | 6 | def get_query( 7 | self, 8 | question: str, 9 | hard_filters: str = "", 10 | previous_context: list = [], 11 | glossary: str = "", 12 | debug: bool = False, 13 | dev: bool = False, 14 | temp: bool = False, 15 | profile: bool = False, 16 | ignore_cache: bool = False, 17 | model: str = "", 18 | use_golden_queries: bool = True, 19 | subtable_pruning: bool = False, 20 | glossary_pruning: bool = False, 21 | prune_max_tokens: int = 2000, 22 | prune_bm25_num_columns: int = 10, 23 | prune_glossary_max_tokens: int = 1000, 24 | prune_glossary_num_cos_sim_units: int = 10, 25 | prune_glossary_bm25_units: int = 10, 26 | ): 27 | """ 28 | Sends the query to the defog servers, and return the response. 29 | :param question: The question to be asked. 30 | :return: The response from the defog server. 
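    Example (an illustrative sketch; the question string is made up and the
    Defog instance is assumed to already have credentials configured):

        from defog import Defog

        defog = Defog()
        resp = defog.get_query("How many users signed up last month?")
        if resp["ran_successfully"]:
            print(resp["query_generated"])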
31 | """ 32 | try: 33 | data = { 34 | "question": question, 35 | "api_key": self.api_key, 36 | "previous_context": previous_context, 37 | "db_type": self.db_type if self.db_type != "databricks" else "postgres", 38 | "glossary": glossary, 39 | "hard_filters": hard_filters, 40 | "dev": dev, 41 | "temp": temp, 42 | "ignore_cache": ignore_cache, 43 | "model": model, 44 | "use_golden_queries": use_golden_queries, 45 | "subtable_pruning": subtable_pruning, 46 | "glossary_pruning": glossary_pruning, 47 | "prune_max_tokens": prune_max_tokens, 48 | "prune_bm25_num_columns": prune_bm25_num_columns, 49 | "prune_glossary_max_tokens": prune_glossary_max_tokens, 50 | "prune_glossary_num_cos_sim_units": prune_glossary_num_cos_sim_units, 51 | "prune_glossary_bm25_units": prune_glossary_bm25_units, 52 | } 53 | 54 | t_start = datetime.now() 55 | r = requests.post( 56 | self.generate_query_url, 57 | json=data, 58 | timeout=300, 59 | verify=False, 60 | ) 61 | resp = r.json() 62 | t_end = datetime.now() 63 | time_taken = (t_end - t_start).total_seconds() 64 | query_generated = resp.get("sql", resp.get("query_generated")) 65 | ran_successfully = resp.get("ran_successfully") 66 | error_message = resp.get("error_message") 67 | query_db = self.db_type 68 | resp = { 69 | "query_generated": query_generated, 70 | "ran_successfully": ran_successfully, 71 | "error_message": error_message, 72 | "query_db": query_db, 73 | "previous_context": resp.get("previous_context"), 74 | "reason_for_query": resp.get("reason_for_query"), 75 | } 76 | if profile: 77 | resp["time_taken"] = time_taken 78 | 79 | return resp 80 | except Exception as e: 81 | if debug: 82 | print(e) 83 | return { 84 | "ran_successfully": False, 85 | "error_message": "Sorry :( Our server is at capacity right now and we are unable to process your query. Please try again in a few minutes?", 86 | } 87 | 88 | 89 | def run_query( 90 | self, 91 | question: str, 92 | hard_filters: str = "", 93 | previous_context: list = [], 94 | glossary: str = "", 95 | query: dict = None, 96 | retries: int = 3, 97 | dev: bool = False, 98 | temp: bool = False, 99 | profile: bool = False, 100 | ignore_cache: bool = False, 101 | model: str = "", 102 | use_golden_queries: bool = True, 103 | subtable_pruning: bool = False, 104 | glossary_pruning: bool = False, 105 | prune_max_tokens: int = 2000, 106 | prune_bm25_num_columns: int = 10, 107 | prune_glossary_max_tokens: int = 1000, 108 | prune_glossary_num_cos_sim_units: int = 10, 109 | prune_glossary_bm25_units: int = 10, 110 | ): 111 | """ 112 | Sends the question to the defog servers, executes the generated SQL, 113 | and returns the response. 114 | :param question: The question to be asked. 115 | :return: The response from the defog server. 
116 | """ 117 | if query is None: 118 | print(f"Generating the query for your question: {question}...") 119 | query = self.get_query( 120 | question, 121 | hard_filters, 122 | previous_context, 123 | glossary=glossary, 124 | dev=dev, 125 | temp=temp, 126 | profile=profile, 127 | model=model, 128 | ignore_cache=ignore_cache, 129 | use_golden_queries=use_golden_queries, 130 | subtable_pruning=subtable_pruning, 131 | glossary_pruning=glossary_pruning, 132 | prune_max_tokens=prune_max_tokens, 133 | prune_bm25_num_columns=prune_bm25_num_columns, 134 | prune_glossary_max_tokens=prune_glossary_max_tokens, 135 | prune_glossary_num_cos_sim_units=prune_glossary_num_cos_sim_units, 136 | prune_glossary_bm25_units=prune_glossary_bm25_units, 137 | ) 138 | if query["ran_successfully"]: 139 | try: 140 | print("Query generated, now running it on your database...") 141 | tstart = datetime.now() 142 | colnames, result, executed_query = execute_query( 143 | query=query["query_generated"], 144 | api_key=self.api_key, 145 | db_type=self.db_type, 146 | db_creds=self.db_creds, 147 | question=question, 148 | hard_filters=hard_filters, 149 | retries=retries, 150 | dev=dev, 151 | temp=temp, 152 | base_url=self.base_url, 153 | ) 154 | tend = datetime.now() 155 | time_taken = (tend - tstart).total_seconds() 156 | resp = { 157 | "columns": colnames, 158 | "data": result, 159 | "query_generated": executed_query, 160 | "ran_successfully": True, 161 | "reason_for_query": query.get("reason_for_query"), 162 | "previous_context": query.get("previous_context"), 163 | } 164 | if profile: 165 | resp["execution_time_taken"] = time_taken 166 | resp["generation_time_taken"] = query.get("time_taken") 167 | return resp 168 | except Exception as e: 169 | return { 170 | "ran_successfully": False, 171 | "error_message": str(e), 172 | "query_generated": query["query_generated"], 173 | } 174 | else: 175 | return {"ran_successfully": False, "error_message": query["error_message"]} 176 | -------------------------------------------------------------------------------- /defog/serve.py: -------------------------------------------------------------------------------- 1 | # create a FastAPI app 2 | from fastapi import FastAPI, Request 3 | from fastapi.middleware.cors import CORSMiddleware 4 | 5 | from defog import Defog 6 | import pandas as pd 7 | import os 8 | import json 9 | from io import StringIO 10 | 11 | app = FastAPI() 12 | 13 | origins = ["*"] 14 | app.add_middleware( 15 | CORSMiddleware, 16 | allow_origins=origins, 17 | allow_credentials=True, 18 | allow_methods=["*"], 19 | allow_headers=["*"], 20 | ) 21 | 22 | home_dir = os.path.expanduser("~") 23 | defog_path = os.path.join(home_dir, ".defog") 24 | 25 | 26 | @app.get("/") 27 | async def root(): 28 | return {"message": "Hello, I am Defog"} 29 | 30 | 31 | @app.post("/generate_query") 32 | async def generate(request: Request): 33 | params = await request.json() 34 | question = params.get("question") 35 | previous_context = params.get("previous_context") 36 | defog = Defog() 37 | resp = defog.run_query(question, previous_context=previous_context) 38 | return resp 39 | 40 | 41 | @app.post("/integration/get_tables_db_creds") 42 | async def get_tables_db_creds(request: Request): 43 | try: 44 | defog = Defog() 45 | except: 46 | return {"error": "no defog instance found"} 47 | 48 | try: 49 | with open(os.path.join(defog_path, "tables.json"), "r") as f: 50 | table_names = json.load(f) 51 | except: 52 | table_names = [] 53 | 54 | try: 55 | with open(os.path.join(defog_path, 
"selected_tables.json"), "r") as f: 56 | selected_table_names = json.load(f) 57 | except: 58 | selected_table_names = [] 59 | 60 | db_type = defog.db_type 61 | db_creds = defog.db_creds 62 | api_key = defog.api_key 63 | 64 | return { 65 | "tables": table_names, 66 | "db_creds": db_creds, 67 | "db_type": db_type, 68 | "selected_tables": selected_table_names, 69 | "api_key": api_key, 70 | } 71 | 72 | 73 | @app.post("/integration/get_metadata") 74 | async def get_metadata(request: Request): 75 | try: 76 | with open(os.path.join(defog_path, "metadata.json"), "r") as f: 77 | metadata = json.load(f) 78 | 79 | return {"metadata": metadata} 80 | except: 81 | return {"error": "no metadata found"} 82 | 83 | 84 | @app.post("/integration/generate_tables") 85 | async def get_tables(request: Request): 86 | params = await request.json() 87 | api_key = params.get("api_key") 88 | db_type = params.get("db_type") 89 | db_creds = params.get("db_creds") 90 | for k in ["api_key", "db_type"]: 91 | if k in db_creds: 92 | del db_creds[k] 93 | 94 | defog = Defog(api_key, db_type, db_creds) 95 | table_names = defog.generate_db_schema(tables=[], return_tables_only=True) 96 | 97 | with open(os.path.join(defog_path, "tables.json"), "w") as f: 98 | json.dump(table_names, f) 99 | 100 | return {"tables": table_names} 101 | 102 | 103 | @app.post("/integration/generate_metadata") 104 | async def generate_metadata(request: Request): 105 | params = await request.json() 106 | tables = params.get("tables") 107 | 108 | with open(os.path.join(defog_path, "selected_tables.json"), "w") as f: 109 | json.dump(tables, f) 110 | 111 | defog = Defog() 112 | table_metadata = defog.generate_db_schema( 113 | tables=tables, scan=True, upload=True, return_format="csv_string" 114 | ) 115 | metadata = ( 116 | pd.read_csv(StringIO(table_metadata)).fillna("").to_dict(orient="records") 117 | ) 118 | 119 | with open(os.path.join(defog_path, "metadata.json"), "w") as f: 120 | json.dump(metadata, f) 121 | 122 | defog.update_db_schema(StringIO(table_metadata)) 123 | return {"metadata": metadata} 124 | 125 | 126 | @app.post("/integration/update_metadata") 127 | async def update_metadata(request: Request): 128 | params = await request.json() 129 | metadata = params.get("metadata") 130 | defog = Defog() 131 | metadata = pd.DataFrame(metadata).to_csv(index=False) 132 | defog.update_db_schema(StringIO(metadata)) 133 | return {"status": "success"} 134 | 135 | 136 | @app.post("/instruct/get_glossary_golden_queries") 137 | async def update_glossary(request: Request): 138 | defog = Defog() 139 | glossary = defog.get_glossary() 140 | golden_queries = defog.get_golden_queries(format="json") 141 | return {"glossary": glossary, "golden_queries": golden_queries} 142 | 143 | 144 | @app.post("/instruct/update_glossary") 145 | async def update_glossary(request: Request): 146 | params = await request.json() 147 | glossary = params.get("glossary") 148 | defog = Defog() 149 | defog.update_glossary(glossary=glossary) 150 | return {"status": "success"} 151 | 152 | 153 | @app.post("/instruct/update_golden_queries") 154 | async def update_golden_queries(request: Request): 155 | params = await request.json() 156 | golden_queries = params.get("golden_queries") 157 | golden_queries = [ 158 | x 159 | for x in golden_queries 160 | if x["sql"] != "" and x["question"] != "" and x["user_validated"] 161 | ] 162 | defog = Defog() 163 | defog.update_golden_queries(golden_queries=golden_queries) 164 | return {"status": "success"} 165 | 
-------------------------------------------------------------------------------- /defog/static/404.html: -------------------------------------------------------------------------------- 1 |