├── .gitignore
├── Dockerfile
├── README.md
├── add_tool_to_diffbot_llm_inference.md
├── chunking
│   ├── chunk_processor.py
│   ├── chunking_helper.py
│   └── stop_words.txt
├── config.py
├── llm
│   ├── api_models.py
│   ├── diffbot_tool_call_llm.py
│   ├── llms.py
│   ├── openai_gpt.py
│   ├── plugin.py
│   ├── token_utils.py
│   └── tool_call_llm.py
├── models
│   ├── api.py
│   └── utils.py
├── poetry.lock
├── pyproject.toml
├── ranking
│   ├── README.md
│   ├── rank_bm25.py
│   └── wp_1gram_top1m.txt.gz
├── run_tests.sh
├── server
│   ├── log.py
│   ├── main.py
│   └── rag_router.py
├── services
│   ├── execute_js.py
│   └── kg_rag_service.py
├── start_server.sh
├── static
│   ├── babyshark.webp
│   ├── demo.png
│   ├── extract.webp
│   ├── faa.webp
│   ├── freshqa.png
│   ├── math.webp
│   ├── newjersey.webp
│   ├── simpleqa.png
│   ├── strawberry.webp
│   └── weather.webp
├── supervisord.conf
├── system_prompt.txt
├── system_prompt_with_js.txt
└── tests
    └── test_end2end.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # python
2 | **/__pycache__/
3 | /.vscode/
4 | /.venv/
5 | /.pytest_cache/
6 | /.env/
7 | /tmp/
8 | .idea/
9 | *.iml
10 | .cache/
11 | data/
12 | .DS_Store
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM vllm/vllm-openai:latest
2 | 
3 | # Install required packages
4 | RUN apt-get update && apt-get install -y \
5 |     supervisor \
6 |     && apt-get clean
7 | 
8 | # Copy Supervisor configuration
9 | COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
10 | 
11 | # Copy diffbot-llm code
12 | COPY . /code
13 | 
14 | WORKDIR /code
15 | 
16 | # Install requirements
17 | RUN pip install poetry
18 | RUN pip install poetry-plugin-export
19 | RUN pip install pyasynchat # required by supervisord
20 | RUN poetry env use python3.10
21 | RUN poetry export -f requirements.txt --output requirements.txt --without-hashes
22 | RUN poetry run pip install --no-cache-dir --upgrade -r /code/requirements.txt
23 | 
24 | # Expose ports
25 | EXPOSE 3333 8000
26 | 
27 | # Start Supervisor
28 | ENTRYPOINT ["/usr/bin/supervisord"]
29 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffbot GraphRAG LLM
2 | 
3 | ## 1. Introduction
4 | 
5 | Recently, large language models (LLMs) have been trained with more and more data, leading to an increase in the number of parameters and the compute power needed. But what if, instead of feeding the model more data, we purposefully trained it to rely less on its pretraining data and more on its ability to find external knowledge?
6 | 
7 | To test this idea, we fine-tuned Llama 3.3 70B to be an expert tool user of a real-time Knowledge Graph API, providing the first open-source implementation of a GraphRAG system that outperforms Google Gemini and ChatGPT.
8 | 
9 | ## 2. Features
10 | 
11 | ## Real-time web URL extraction
12 | 
13 | ![extract example](./static/extract.webp)
14 | 
15 | As a RAG system, Diffbot LLM can summarize a web document in real time, appropriately crediting the original source.
16 | 
17 | ## Expert Retriever of Factual citations
18 | 
19 | ![Mission statement of the FAA](./static/faa.webp)
20 | 
21 | Diffbot LLM is explicitly trained to align the cited text with the reference source.
22 | 
23 | ## Knowledge Graph Querying
24 | 
25 | ![which state contains J?](./static/newjersey.webp)
26 | 
27 | Diffbot LLM is an expert tool user of the Diffbot (Knowledge Graph) Query Language.
28 | 
29 | ## Image Entailment
30 | 
31 | ![How to draw baby shark](./static/babyshark.webp)
32 | 
33 | Diffbot LLM can also entail images.
34 | 
35 | ## Code Interpreter Tool Use
36 | 
37 | ![strawberry problem](./static/strawberry.webp)
38 | 
39 | 
40 | Instead of relying on the model weights for performing empirical calculations, Diffbot LLM is an expert tool user of a JavaScript interpreter that it can use to inform its response.
41 | 
42 | ![is 9.11 or 9.9 larger](./static/math.webp)
43 | 
44 | ## Fun stuff
45 | 
46 | ![weather in Menlo park](./static/weather.webp)
47 | 
48 | Diffbot LLM is an expert maker of ASCII-art weather forecasts, grounded in real sources.
49 | 
50 | ## 3. Model Download
51 | 
52 | Available on HuggingFace at:
53 | * diffbot-small (8b Llama 3.1 fine-tune): https://huggingface.co/diffbot/Llama-3.1-Diffbot-Small-2412
54 | * diffbot-small-xl (70b Llama 3.3 fine-tune): https://huggingface.co/diffbot/Llama-3.3-Diffbot-Small-XL-2412
55 | 
56 | ## 4. Accuracy Benchmarks
57 | 
58 | ### FreshQA Dataset
59 | 
60 | ![Accuracy for FreshQA 2024 queries](./static/freshqa.png)
61 | 
62 | [FreshQA](https://arxiv.org/abs/2310.03214) is a benchmark that measures real-time accuracy for search RAG systems. Diffbot LLM outperforms gpt-4o (no web access), ChatGPT (with web access), Google Gemini, and Perplexity on real-time factual accuracy.
63 | 
64 | In this evaluation, we focus on 130 FreshQA questions whose answers have changed in 2024, which is after the knowledge
65 | cutoff for all evaluated models as of December 2024.
66 | 
67 | ### MMLU-Pro
68 | 
69 | [MMLU-Pro](https://arxiv.org/abs/2406.01574) is a more difficult version of the [MMLU](https://arxiv.org/abs/2009.03300) benchmark that tests static knowledge of 57 academic subjects using 10-choice multiple-choice questions. See the [MMLU-Pro Leaderboard](https://huggingface.co/spaces/TIGER-Lab/MMLU-Pro).
70 | 
71 | Below are the MMLU-Pro scores of diffbot-small and diffbot-small-xl compared to the base models they were fine-tuned from.
72 | 
73 | | Model | Accuracy (CoT 5-shot) |
74 | | ----- | ----------------- |
75 | | diffbot-small-xl | 72.89 |
76 | | Llama-3.3-70B Instruct | 65.92 |
77 | 
78 | | Model | Accuracy (CoT 5-shot) |
79 | | ----- | ----------------- |
80 | | diffbot-small | 48.64 |
81 | | Llama-3.1-8B Instruct | 44.25 |
82 | 
83 | Note: This is a measurement of the Diffbot GraphRAG LLM API end-to-end, not a measure of the knowledge contained in the weights. The lift in performance over the base models comes from the ability to access external tools.
84 | 
85 | ### SimpleQA
86 | 
87 | ![SimpleQA benchmark evals across several models](./static/simpleqa.png)
88 | 
89 | [SimpleQA](https://openai.com/index/introducing-simpleqa/) is OpenAI's factuality benchmark focused on short, fact-seeking queries. It contains over 4000 questions ranging from science and technology to TV shows and video games.
90 | 
91 | Diffbot's 70b model outperforms every other evaluated model to date, including internet-connected models like Perplexity Sonar Pro and Gemini-2.0-flash. We attribute this performance to its deep-rooted training to distrust its own knowledge.
92 | 
93 | 
94 | ## 5. Demo
95 | 
96 | Try Diffbot LLM using the demo app at https://diffy.chat
97 | 
98 | ## 6. Running Locally
99 | 
100 | Tested minimum hardware configurations:
101 | 
102 | - Nvidia A100 40G for diffbot-small
103 | - Nvidia 2XH100 80G for diffbot-small-xl @ FP8
104 | 
105 | Using the Docker image and models on Hugging Face:
106 | 1. Pull the docker image: `docker pull docker.io/diffbot/diffbot-llm-inference:latest`
107 | 2. Run the docker image. **Note: The model weights will be automatically downloaded from Hugging Face.
108 | This might take a few minutes.**
109 | 
110 | Model: diffbot-small
111 | ```bash
112 | docker run --runtime nvidia --gpus all -p 8001:8001 --ipc=host -e VLLM_OPTIONS="--model diffbot/Llama-3.1-Diffbot-Small-2412 --served-model-name diffbot-small --enable-prefix-caching" docker.io/diffbot/diffbot-llm-inference:latest
113 | ```
114 | 
115 | Model: diffbot-small-xl
116 | ```bash
117 | docker run --runtime nvidia --gpus all -p 8001:8001 --ipc=host -e VLLM_OPTIONS="--model diffbot/Llama-3.3-Diffbot-Small-XL-2412 --served-model-name diffbot-small-xl --enable-prefix-caching --quantization fp8 --tensor-parallel-size 2" docker.io/diffbot/diffbot-llm-inference:latest
118 | ```
119 | 
120 | The Diffbot server leverages vLLM to serve the model, and it is ready to receive requests once vLLM outputs the following message:
121 | ```
122 | INFO: Application startup complete.
123 | INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
124 | ```
125 | 
126 | You can now use the endpoint `http://localhost:8001/rag/v1`, which works exactly like the Serverless API below.
127 | 
128 | ## 7. Using the Serverless API
129 | 
130 | You can also access Diffbot LLM through the serverless API below. The serverless API follows a Zero Data Retention policy: user data is not stored or retained after it has been processed.
131 | 
132 | Get a free Diffbot developer token at https://app.diffbot.com/get-started
133 | 
134 | ```python
135 | from openai import OpenAI
136 | 
137 | client = OpenAI(
138 |     base_url = "https://llm.diffbot.com/rag/v1",
139 |     api_key = ""
140 | )
141 | 
142 | completion = client.chat.completions.create(
143 |     model="diffbot-small-xl",
144 |     temperature=0,
145 |     messages=[
146 |         {
147 |             "role": "user",
148 |             "content": "What is the Diffbot Knowledge Graph?"
149 |         }
150 |     ]
151 | )
152 | print(completion)
153 | ```
154 | Contact support@diffbot.com if you need more credits or higher limits.
155 | 
156 | ## 8. Adding Custom Tools
157 | 
158 | To extend the Diffbot LLM Inference Server with new tools, please refer to [this tutorial](add_tool_to_diffbot_llm_inference.md).
159 | 
--------------------------------------------------------------------------------
/add_tool_to_diffbot_llm_inference.md:
--------------------------------------------------------------------------------
1 | # How to Add a Tool to Diffbot LLM Inference
2 | 
3 | We discuss how to add a tool to Diffbot LLM Inference by showing how to
4 | add the tool `execute_js_v1` for JavaScript code execution.
5 | 
6 | ### 0. Set up local development
7 | 
8 | To set up the virtual environment:
9 | 
10 | ```
11 | poetry env use python3.10
12 | poetry shell
13 | poetry install
14 | ```
15 | 
16 | To start vLLM:
17 | 
18 | Self-host one of the Diffbot LLM models with Docker (see [Self-Hosting](README.md)) and add "-p 8000:8000" to expose
19 | the vLLM endpoint. Set the vLLM endpoint in config.py.
20 | 
21 | To start the server: `./start_server.sh`
22 | 
23 | ### 1. Add the new tool to the system prompt (`system_prompt.txt`).
24 | 
25 | Below is the original system prompt, which includes the definition of the
26 | available tools in JavaScript.
27 | ``` 28 | You are a helpful assistant with access to the following functions. Use them if required - 29 | namespace Diffbot { 30 | // Extract the content from the given URLs. Only call this endpoint if the user mentioned a URL. 31 | type extract_v1 = (_: { 32 | // URLs to extract, up to 5 33 | page_url: string[], 34 | }) => any; 35 | // Query the Diffbot Knowledge Graph for an entity or set of entities that match a set of criteria using the Diffbot Query Language syntax. 36 | type dql_v1 = (_: { 37 | // Diffbot Query Language query 38 | dql_query: string, 39 | }) => any; 40 | // Search the web for information that could help answer the user's question. 41 | type web_search_v1 = (_: { 42 | // List of Google advanced search strings (can include phrases, booleans, site:, before:, after:, filetype:, etc) 43 | text: string[], 44 | // Number of results to return (default 5) 45 | num?: number, 46 | // Page number of results to return (default 1) 47 | page?: number, 48 | }) => any; 49 | } // namespace Diffbot 50 | ``` 51 | 52 | To add the tool `execute_js_v1`, we can add the following lines as the last tool: 53 | 54 | ``` 55 | // Execute JavaScript expressions and get accurate results that could help answer the user's question. 56 | type execute_js_v1 = (_: { 57 | // JavaScript expressions to execute separated by newlines 58 | expressions: string, 59 | }) => any; 60 | ``` 61 | 62 | The final result is: 63 | 64 | ``` 65 | You are a helpful assistant with access to the following functions. Use them if required - 66 | namespace Diffbot { 67 | // Extract the content from the given URLs. Only call this endpoint if the user mentioned a URL. 68 | type extract_v1 = (_: { 69 | // URLs to extract, up to 5 70 | page_url: string[], 71 | }) => any; 72 | // Query the Diffbot Knowledge Graph for an entity or set of entities that match a set of criteria using the Diffbot Query Language syntax. 73 | type dql_v1 = (_: { 74 | // Diffbot Query Language query 75 | dql_query: string, 76 | }) => any; 77 | // Search the web for information that could help answer the user's question. 78 | type web_search_v1 = (_: { 79 | // List of Google advanced search strings (can include phrases, booleans, site:, before:, after:, filetype:, etc) 80 | text: string[], 81 | // Number of results to return (default 5) 82 | num?: number, 83 | // Page number of results to return (default 1) 84 | page?: number, 85 | }) => any; 86 | // Execute JavaScript expressions and get accurate results that could help answer the user's question. 87 | type execute_js_v1 = (_: { 88 | // JavaScript expressions to execute separated by newlines 89 | expressions: string, 90 | }) => any; 91 | } // namespace Diffbot 92 | ``` 93 | 94 | ### 2. Implement the new tool 95 | 96 | See `services/execute_js.py` for the implementation of `execute_js_v1`. 97 | 98 | ### 3. Call the tool in llm/plugin.py 99 | 100 | The `invoke` method is responsible for calling tools requested by the LLM. To call the new tool, we can add the 101 | following lines to this method: 102 | 103 | ```python 104 | if function_name == "execute_js_v1": 105 | resp = await get_js_execution_service().execute_js(function_arguments["expressions"]) 106 | return PluginResponse( 107 | plugin_url=function_name, method="INTERNAL", content=resp.json() 108 | ) 109 | ``` 110 | 111 | where `get_js_execution_service().execute_js()` calls the implementation for this new tool. 
-------------------------------------------------------------------------------- /chunking/chunk_processor.py: -------------------------------------------------------------------------------- 1 | import json, re 2 | import time 3 | from typing import List, Optional, Dict, Any 4 | from collections import deque 5 | from llm.api_models import ChatCompletionRequestMessage 6 | from llm.llms import Role 7 | from chunking.chunking_helper import Chunk, get_similarity_calculator 8 | from models.utils import truncate_long_strings, truncate_long_arrays, truncate_data_dfs 9 | from server.log import get_logstash_logger 10 | 11 | logger = get_logstash_logger("chunk_processor") 12 | 13 | 14 | class DataChunker: 15 | IDENTIFIER_KEYS = ['name', 'title', 'pageUrl', 'date', 'dql_query'] 16 | 17 | @staticmethod 18 | def tokenize(json_obj: Any) -> int: 19 | # estimate size instead of expensive json serializing 20 | if isinstance(json_obj, dict): 21 | return sum((DataChunker.tokenize(k) + DataChunker.tokenize(v)) for k, v in json_obj.items()) 22 | elif isinstance(json_obj, list): 23 | return sum(DataChunker.tokenize(item) for item in json_obj) 24 | elif isinstance(json_obj, str): 25 | return len(json_obj) 26 | else: 27 | return 1 28 | 29 | @staticmethod 30 | def set_nested_dict(d: Dict[str, Any], path: List[str], value: Any) -> None: 31 | if not path: 32 | return 33 | for key in path[:-1]: 34 | d = d.setdefault(key, {}) 35 | d[path[-1]] = value 36 | 37 | @staticmethod 38 | def get_min_chunk_size(max_chunk_size: int): 39 | return max(max_chunk_size - 1000, 50) 40 | 41 | @staticmethod 42 | def filter_urls(markdown_text, max_length=150): 43 | url_pattern = re.compile(r'\[\[?([^\]]+)\]?\]\(([^)]+)\)') 44 | 45 | def replace_long_urls(match): 46 | text = match.group(1) 47 | url = match.group(2) 48 | if len(url) > max_length: 49 | return text 50 | elif '#cite_note' in url: 51 | return '' 52 | else: 53 | return match.group(0) 54 | 55 | return url_pattern.sub(replace_long_urls, markdown_text) 56 | 57 | @staticmethod 58 | def chunk_text(text: str, max_string_chunk_size: int) -> List[str]: 59 | text = DataChunker.filter_urls(text) 60 | min_string_chunk_size = DataChunker.get_min_chunk_size(max_string_chunk_size) 61 | separators = ["\n\n##", "\n\n", "\n"] 62 | text_chunks = [] 63 | start_idx = 0 64 | text_length = len(text) 65 | 66 | while start_idx < text_length: 67 | end_idx = min(start_idx + max_string_chunk_size + min_string_chunk_size, text_length) 68 | best_split_idx = None 69 | start_find_index = start_idx + min_string_chunk_size 70 | if start_find_index < end_idx: 71 | 72 | for sep in separators: 73 | sep_idx = text.rfind(sep, start_find_index, end_idx) 74 | if sep_idx != -1: 75 | best_split_idx = sep_idx 76 | break 77 | 78 | if best_split_idx is None: 79 | best_split_idx = end_idx 80 | 81 | text_chunks.append(text[start_idx:best_split_idx].strip()) 82 | start_idx = best_split_idx + ( 83 | len(sep) if best_split_idx < text_length and sep in text[best_split_idx:] else 0) 84 | 85 | return text_chunks 86 | 87 | @staticmethod 88 | def wrap_chunks(chunks: List[Any]) -> List[Chunk]: 89 | start = time.time() 90 | chunk_objects = [] 91 | for chunk in chunks: 92 | if isinstance(chunk, dict) or isinstance(chunk, list): 93 | chunk_str = json.dumps(chunk, sort_keys=True) 94 | else: 95 | chunk_str = chunk 96 | chunk_objects.append(Chunk(text=chunk_str, size=DataChunker.tokenize(chunk))) 97 | # print("Time to wrap chunks: ", (time.time() - start) * 1000) 98 | return chunk_objects 99 | 100 | @staticmethod 101 | def 
set_identifiers(chunk: dict, identifiers: dict): 102 | if not identifiers: 103 | return 104 | # each chunk has the identifier values e.g. "name": ".." in case the data is divided across multiple chunks 105 | for id_key, id_value in identifiers.items(): 106 | # id_value[1] contains the current path used to get the nested path of the value 107 | DataChunker.set_nested_dict(chunk, id_value[1] + [id_key], id_value[0]) 108 | 109 | @staticmethod 110 | def convert_list_to_dict(data: Any) -> Any: 111 | # simplify list processing by converting to dict 112 | if isinstance(data, dict): 113 | return {k: DataChunker.convert_list_to_dict(v) for k, v in data.items()} 114 | elif isinstance(data, list): 115 | return { 116 | str(i): DataChunker.convert_list_to_dict(item) 117 | for i, item in enumerate(data) 118 | } 119 | else: 120 | return data 121 | 122 | @staticmethod 123 | def convert_dict_to_list(data: Any) -> Any: 124 | if isinstance(data, dict): 125 | try: 126 | sorted_keys = sorted(int(k) for k in data.keys()) 127 | return [ 128 | DataChunker.convert_dict_to_list(data[str(i)]) 129 | if str(i) in data else None for i in sorted_keys 130 | ] 131 | except ValueError: 132 | return {k: DataChunker.convert_dict_to_list(v) for k, v in data.items()} 133 | else: 134 | return data 135 | 136 | @staticmethod 137 | def chunk_object(data: Dict[str, Any], 138 | max_chunk_size: int, min_chunk_size: int, 139 | path: Optional[List[str]] = None, 140 | chunks: Optional[List[Any]] = None, 141 | identifiers: Optional[Dict] = None): 142 | path = path or [] 143 | chunks = chunks or [{}] 144 | 145 | if isinstance(data, dict): 146 | if not identifiers: 147 | identifiers = {k: (v, path) for k, v in data.items() if k in DataChunker.IDENTIFIER_KEYS} 148 | 149 | for key, value in data.items(): 150 | current_path = path + [key] 151 | chunk_size = DataChunker.tokenize(chunks[-1]) 152 | size = DataChunker.tokenize({key: value}) 153 | remaining = max_chunk_size - chunk_size 154 | 155 | if size < remaining: 156 | DataChunker.set_nested_dict(chunks[-1], current_path, value) 157 | DataChunker.set_identifiers(chunks[-1], identifiers) 158 | else: 159 | if chunk_size >= min_chunk_size: 160 | new_chunk = {} 161 | DataChunker.set_identifiers(new_chunk, identifiers) 162 | chunks.append(new_chunk) 163 | DataChunker.chunk_object(value, max_chunk_size, min_chunk_size, current_path, chunks, identifiers) 164 | 165 | elif isinstance(data, str): 166 | data = DataChunker.filter_urls(data) 167 | 168 | # the chunk size reduced by approximated path lengths 169 | max_string_chunk_size = max_chunk_size - len(path) * 6 170 | start = time.time() 171 | string_chunks = DataChunker.chunk_text(data, max_string_chunk_size=max_string_chunk_size) 172 | # print("chunk_text: ", (time.time()-start)*1000) 173 | if chunks[-1]: 174 | chunks.append({}) 175 | for chunk in string_chunks: 176 | DataChunker.set_identifiers(chunks[-1], identifiers) 177 | if path: 178 | DataChunker.set_nested_dict(chunks[-1], path, chunk) 179 | else: 180 | chunks[-1] = chunk 181 | chunks.append({}) 182 | if not chunks[-1]: 183 | chunks.pop() 184 | else: 185 | DataChunker.set_nested_dict(chunks[-1], path, data) 186 | return chunks 187 | 188 | @staticmethod 189 | def chunk_data(message_input: Any, max_chunk_size: int = 5000) -> List[Any]: 190 | # Limits to keep truncation time from being unreasonable. We should increase these limits as we improve 191 | # the truncation performance. 
192 | max_string_length = 500_000 193 | max_array_length = 100 194 | max_json_length = 1_000_000 195 | max_chunks = 500 196 | try: 197 | start = time.time() 198 | if isinstance(message_input, str) and (message_input.startswith("{") or message_input.startswith("[")): 199 | # TODO: loading the whole json takes ~70ms for JSON string with 40M+ chars. 200 | # Try json streaming to stop reading after a particular limit? 201 | message_input = json.loads(message_input) 202 | # print("time to load json: ", (time.time() - start)*1000) 203 | start = time.time() 204 | message_input = truncate_long_strings(message_input, max_string_length=max_string_length) 205 | message_input = truncate_long_arrays(message_input, max_array_length=max_array_length) 206 | message_input, _ = truncate_data_dfs(message_input, max_length=max_json_length) 207 | # print("dumb truncation: ", (time.time() - start) * 1000) 208 | start = time.time() 209 | message_input = DataChunker.convert_list_to_dict(message_input) 210 | # print("convert list to dict: ", (time.time() - start) * 1000) 211 | start = time.time() 212 | processed_chunks = DataChunker.chunk_object(message_input, max_chunk_size=max_chunk_size, 213 | min_chunk_size=DataChunker.get_min_chunk_size(max_chunk_size)) 214 | processed_chunks = processed_chunks[:max_chunks] 215 | # print("chunk_object: ", (time.time()-start)*1000) 216 | except (TypeError, NameError, ValueError, OverflowError) as e: 217 | start = time.time() 218 | logger.error(f"Exception during chunk_data: {e}", exc_info=True) 219 | if not message_input: 220 | message_input = "" 221 | if not isinstance(message_input, str): 222 | message_input = str(message_input) 223 | message_input = message_input[:max_string_length] 224 | processed_chunks = DataChunker.chunk_text(message_input, max_string_chunk_size=max_chunk_size) 225 | # print("chunk exception: ", (time.time() - start) * 1000) 226 | return DataChunker.wrap_chunks(processed_chunks) 227 | 228 | 229 | class ChunkProcessor: 230 | def __init__(self): 231 | self.data_chunker = DataChunker() 232 | 233 | @staticmethod 234 | def get_last_user_message(messages): 235 | last_user_message_index = None 236 | last_user_message = None 237 | for i in range(len(messages) - 1, -1, -1): 238 | # make sure that the user message is not a tool response 239 | if messages[i].role == Role.user and ( 240 | i == 0 or ('' not in messages[i - 1].content or messages[i - 1].role == 'system')) and \ 241 | not messages[i].content.startswith('{"status": '): 242 | last_user_message_index = i 243 | last_user_message = messages[i] 244 | break 245 | return last_user_message_index, last_user_message 246 | 247 | def merge_chunks(self, chunks): 248 | def merge_json(json1, json2): 249 | merged = {} 250 | seen_key = set() 251 | # keep order of keys 252 | key_list = [] 253 | key_list.extend([k for k, v in json1.items()]) 254 | key_list.extend([k for k, v in json2.items()]) 255 | for key in key_list: 256 | if key in seen_key: 257 | continue 258 | seen_key.add(key) 259 | if key in json1 and key in json2: 260 | if isinstance(json1[key], dict) and isinstance(json2[key], dict): 261 | if json1[key] == json2[key]: 262 | merged[key] = json1[key] 263 | else: 264 | merged[key] = merge_json(json1[key], json2[key]) 265 | elif isinstance(json1[key], str) and isinstance(json2[key], str): 266 | if json1[key] == json2[key]: 267 | merged[key] = json1[key] 268 | else: 269 | merged[key] = f"{json1[key]} {json2[key]}" 270 | else: 271 | merged[key] = json2[key] 272 | elif key in json1: 273 | merged[key] = json1[key] 274 
| else: 275 | merged[key] = json2[key] 276 | return merged 277 | 278 | def try_parse_json(s): 279 | try: 280 | if not s.endswith('}'): 281 | return s, False 282 | return json.loads(s), True 283 | except json.JSONDecodeError: 284 | return s, False 285 | 286 | merged_json = {} 287 | merged_str = "" 288 | 289 | for s in chunks: 290 | parsed, is_json = try_parse_json(s) 291 | if is_json: 292 | if merged_json: 293 | merged_json = merge_json(merged_json, parsed) 294 | else: 295 | merged_json = parsed 296 | else: 297 | if merged_str: 298 | merged_str += "\n" + s 299 | else: 300 | merged_str = s 301 | 302 | if merged_json: 303 | if merged_str: 304 | return json.dumps(merged_json) + "\n" + merged_str 305 | else: 306 | return json.dumps(DataChunker.convert_dict_to_list(merged_json)) 307 | else: 308 | return merged_str 309 | 310 | # this method should take the message array as input 311 | # iterate messages in the order - last user message 312 | # Chunk each message content 313 | # Compute similarity of all chunks with the last user role content and pick the most similar chunks 314 | def process_messages(self, messages: List[ChatCompletionRequestMessage], max_tokens: int, 315 | max_tokens_last_tool: int, model_input_token_limit: int = 12000, log_ctx: dict = {}) -> List[ 316 | ChatCompletionRequestMessage]: 317 | max_size = max_tokens * 4 318 | max_size_last_tool = max_tokens_last_tool * 4 319 | model_input_token_limit_size = model_input_token_limit * 4 320 | 321 | if max_size_last_tool > max_size: 322 | raise Exception("max_size_last_tool cannot be greater than max_size") 323 | max_chunk_size = 2500 324 | log_ctx["max_size"] = max_size 325 | log_ctx["max_chunk_size"] = max_chunk_size 326 | try: 327 | 328 | # returns immediately in case under limit: identity function 329 | total_content_size = sum(len(message.content) for message in messages) 330 | if total_content_size < max_size: 331 | return messages 332 | 333 | _, last_user_message = self.get_last_user_message(messages) 334 | 335 | # TODO: this can potentially return more tokens than max_tokens. Write an alternative algorithm 336 | # (e.g., latest messages)? 337 | if not last_user_message: 338 | log_ctx["error"] = "could not find last user message" 339 | return messages 340 | 341 | # no processing possible if the user message is too long; return messages, 342 | # UI should return error 343 | if len(last_user_message.content) >= max_size: 344 | # TODO: truncate last user message by keeping beginning and end, while ignoring the middle? 
345 | log_ctx["error"] = "user message is too long" 346 | return messages 347 | 348 | start_chunking = time.time() 349 | chunks_role_list = [] 350 | last_user_message_index = None 351 | system_message = None 352 | for message in messages: 353 | if message.role == Role.system: 354 | system_message = message 355 | elif message == last_user_message: 356 | # include the last user message in the 357 | chunk = Chunk(text=message.content, size=len(message.content), include_in_request=True, 358 | similarity=0) 359 | chunks_role_list.append(([chunk], message.role, False)) 360 | else: 361 | is_intext_functioncall = False 362 | if message.role == Role.assistant and message.content.startswith(''): 363 | is_intext_functioncall = False 364 | elif message.role == Role.assistant and not message.content.startswith( 365 | '') and '' in message.content: 366 | # flag to identify if there is only functioncall in the assistant message or functioncall exists in text with response 367 | is_intext_functioncall = True 368 | chunks_role_list.append( 369 | (self.chunk_data(message.content, max_chunk_size), message.role, is_intext_functioncall)) 370 | 371 | all_chunks: List[Chunk] = [] 372 | for chunk_list, role, is_intext_functioncall in chunks_role_list: 373 | for chunk in chunk_list: 374 | chunk.role = role 375 | chunk.is_intext_functioncall = is_intext_functioncall 376 | all_chunks.append(chunk) 377 | if last_user_message.content == chunk.text: 378 | last_user_message_index = len(all_chunks) - 1 379 | 380 | if last_user_message_index is None: 381 | last_user_message_index = len(all_chunks) 382 | log_ctx["chunking_time"] = round((time.time() - start_chunking) * 1000) 383 | log_ctx["num_chunks"] = len(all_chunks) 384 | # print("chunking_time: ", log_ctx["chunking_time"]) 385 | # print("num_chunks: ", log_ctx["num_chunks"]) 386 | 387 | start_similarity = time.time() 388 | similarity_calculator = get_similarity_calculator() 389 | similarity_calculator.calculate_similarity(last_user_message.content, all_chunks, explain=False) 390 | log_ctx["similarity_time"] = round((time.time() - start_similarity) * 1000) 391 | # print("similarity_time: ", log_ctx["similarity_time"]) 392 | 393 | last_assistant_message_index = -1 394 | last_tool_messsage_index = -1 395 | last_function_call_index = -1 396 | prev_role = None 397 | 398 | # Hardcode the similarity to 1 the beginning of any assistant message after that (not functioncalls) to prioritize these 399 | for i in range(last_user_message_index, len(all_chunks)): 400 | # Keep the assistant message which does not start with functioncall or has intext functioncall 401 | if all_chunks[i].role == Role.assistant and ( 402 | not all_chunks[i].text.startswith('') or all_chunks[i].is_intext_functioncall): 403 | all_chunks[i].similarity = 1 404 | if prev_role != all_chunks[i].role: # first chunk of this assistant language turn 405 | last_assistant_message_index = i 406 | if all_chunks[i].role == Role.assistant and prev_role != all_chunks[i].role: 407 | last_function_call_index = i 408 | if all_chunks[i].role != Role.assistant and prev_role != all_chunks[i].role: 409 | last_tool_messsage_index = i 410 | prev_role = all_chunks[i].role 411 | # add tool call + response pairs in case of interleaved calls into history messages, only keep the last pair 412 | # move old interleaving messages to history 413 | old_interleaving_chunks = [] 414 | current_chunks = all_chunks[last_user_message_index:] 415 | if last_assistant_message_index != -1: 416 | old_interleaving_chunks = 
all_chunks[last_user_message_index + 1:last_assistant_message_index] 417 | current_chunks = [all_chunks[last_user_message_index]] + all_chunks[last_assistant_message_index:] 418 | 419 | final_messages: deque[ChatCompletionRequestMessage] = deque() 420 | total_length_so_far = 0 421 | if system_message: 422 | total_length_so_far += len(system_message.content) 423 | 424 | expand_context_size, max_size_last_tool = self.expand_last_tool_context(all_chunks, 425 | last_function_call_index, 426 | last_tool_messsage_index, 427 | max_size_last_tool, 428 | model_input_token_limit_size) 429 | 430 | # choose current chunks, that is, those containing the last user message and last tool response 431 | # use at least half of the remaining room with current chunks 432 | min_length = total_length_so_far + round((max_size_last_tool - total_length_so_far) / 2) 433 | 434 | total_length_so_far = self.convert_best_chunks_to_messages(final_messages, current_chunks, 435 | total_length_so_far, min_length, 436 | max_size_last_tool, 437 | similarity_calculator.threshold) 438 | if expand_context_size > 0: 439 | # adjust the max size to make up for the expansion of the tool messages size 440 | max_size += min(total_length_so_far, expand_context_size) 441 | max_size = min(max_size, model_input_token_limit_size) 442 | 443 | # choose history messages 444 | history_chunks = all_chunks[:last_user_message_index] + old_interleaving_chunks 445 | history_chunks = [c for c in history_chunks if c.role != Role.system] 446 | # inverse history chunks so that we prioritize those close to the last user message 447 | history_chunks = history_chunks[::-1] 448 | history_messages: list[ChatCompletionRequestMessage] = list() 449 | self.convert_best_chunks_to_messages(history_messages, history_chunks, 450 | total_length_so_far, min_length=max_size, 451 | max_length=max_size, 452 | threshold=similarity_calculator.threshold) 453 | 454 | for index in range(0, len(history_messages), 2): 455 | remaining = max_size - total_length_so_far 456 | if index + 1 < len(history_messages): 457 | content_length = len(history_messages[index].content) + len(history_messages[index + 1].content) 458 | else: 459 | content_length = len(history_messages[index].content) 460 | if content_length >= remaining: 461 | break 462 | # adding history messages in pairs to avoid errors 463 | final_messages.appendleft(ChatCompletionRequestMessage(content=history_messages[index].content, 464 | role=history_messages[index].role)) 465 | if index + 1 < len(history_messages): 466 | final_messages.appendleft(ChatCompletionRequestMessage(content=history_messages[index + 1].content, 467 | role=history_messages[index + 1].role)) 468 | total_length_so_far += content_length 469 | if system_message: 470 | final_messages.appendleft( 471 | ChatCompletionRequestMessage(content=system_message.content, role=Role.system)) 472 | return list(final_messages) 473 | 474 | except Exception as e: 475 | logger.error(f"Exception while truncating messages: {e}", exc_info=True) 476 | return self.backup_process_messages(messages, max_size=max_size, max_chunk_size=max_chunk_size, 477 | max_chunks=20) 478 | 479 | def expand_last_tool_context(self, all_chunks, last_function_call_index, last_tool_messsage_index, 480 | max_size_last_tool, model_input_token_limit_size) -> int: 481 | expand_needed = False 482 | if last_tool_messsage_index != -1 and last_function_call_index != -1: 483 | # merge the last assistant message into a single message 484 | last_tool_message_content = [] 485 | for i in 
range(last_function_call_index, last_tool_messsage_index): 486 | last_tool_message_content.append(all_chunks[i].text) 487 | last_tool_message_content = self.merge_chunks(last_tool_message_content) 488 | if last_tool_message_content.count('') == 1: 489 | functioncall = last_tool_message_content.index('') 490 | functioncall_content = last_tool_message_content[functioncall + len(''):].strip() 491 | if functioncall_content.startswith('{') and functioncall_content.endswith('}'): 492 | try: 493 | # check if the functioncall is for extract_v1 494 | functioncall_content = json.loads(functioncall_content) 495 | if functioncall_content['name'] == 'extract_v1': 496 | expand_needed = True 497 | except: 498 | pass 499 | expand_content_size = 0 500 | if expand_needed: 501 | # increase the max size of the last tool response to model limit/2 502 | new_max_size_last_tool = model_input_token_limit_size // 2 503 | expand_content_size = new_max_size_last_tool - max_size_last_tool 504 | expand_content_size = 0 if expand_content_size < 0 else expand_content_size 505 | if new_max_size_last_tool > max_size_last_tool: 506 | max_size_last_tool = new_max_size_last_tool 507 | 508 | return expand_content_size, max_size_last_tool 509 | 510 | def convert_best_chunks_to_messages(self, messages, chunks, total_length_so_far, min_length, max_length, threshold): 511 | self.choose_chunks(chunks, max_length, min_length, threshold, total_length_so_far) 512 | # collect messages 513 | first_chunk_of_message = None 514 | combined_text = [] 515 | for i in range(0, len(chunks)): 516 | content_length = len(chunks[i].text) 517 | if content_length > (max_length - total_length_so_far): 518 | break 519 | if chunks[i].include_in_request: 520 | combined_text.append(chunks[i].text) 521 | total_length_so_far += content_length 522 | # store the first chunk of a message for future use 523 | if i == 0 or chunks[i].role != chunks[i - 1].role: 524 | first_chunk_of_message = i 525 | 526 | is_last_chunk = (i + 1 >= len(chunks) # last chunk overall 527 | or chunks[i].role != chunks[i + 1].role # last chunk for this role 528 | or len(chunks[i + 1].text) > (max_length - total_length_so_far) # next chunk doesn't fit 529 | ) 530 | if is_last_chunk: 531 | # this is the last chunk of a message, merge chosen chunks 532 | if not combined_text: 533 | # if no chunk was selected via similarity for current message, chose the first chunk. 
534 | combined_text = [chunks[first_chunk_of_message].text] 535 | total_length_so_far += len(chunks[first_chunk_of_message].text) 536 | messages.append(ChatCompletionRequestMessage(content=self.merge_chunks(combined_text), 537 | role=chunks[i].role)) 538 | combined_text = [] 539 | 540 | return total_length_so_far 541 | 542 | def choose_chunks(self, chunks, max_length, min_length, threshold, total_length_so_far): 543 | estimated_remaining = max_length - total_length_so_far 544 | # sort chunks by similarity 545 | sorted_indices = sorted(range(len(chunks)), key=lambda k: chunks[k].similarity, reverse=True) 546 | min_length = min_length 547 | for i, index in enumerate(sorted_indices): 548 | content_length = len(chunks[index].text) 549 | if content_length > estimated_remaining: 550 | break 551 | if chunks[index].similarity > threshold or (max_length - estimated_remaining) < min_length: 552 | chunks[index].include_in_request = True 553 | estimated_remaining = estimated_remaining - content_length 554 | 555 | def chunk_data(self, data, max_chunk_size): 556 | return self.data_chunker.chunk_data(data, max_chunk_size=max_chunk_size) 557 | 558 | def backup_process_messages(self, messages: List[ChatCompletionRequestMessage], max_size: int, max_chunk_size: int, 559 | max_chunks: int) -> List[ChatCompletionRequestMessage]: 560 | 561 | processed_messages: deque[ChatCompletionRequestMessage] = deque() 562 | 563 | last_user_message_index, _ = self.get_last_user_message(messages) 564 | 565 | # TODO: this can potentially return more tokens than max_tokens 566 | if last_user_message_index is None: 567 | return messages 568 | 569 | current_size = 0 570 | system_message = None 571 | if len(messages) > 0 and messages[0].role == Role.system: 572 | current_size = len(messages[0].content) 573 | system_message = messages[0] 574 | 575 | # first should be the last user message, all following messages, then preceding history messages till the limit is reached 576 | for i in range(last_user_message_index, len(messages)): 577 | remaining = max_size - current_size 578 | content_length = len(messages[i].content) 579 | message_length = min(max_chunk_size * 5, content_length) 580 | if message_length >= remaining or len(processed_messages) >= max_chunks: 581 | break 582 | messages[i].content = messages[i].content[:message_length] 583 | processed_messages.append(messages[i]) 584 | current_size += len(messages[i].content[:message_length]) 585 | 586 | for i in range(last_user_message_index - 1, -1, -1): 587 | remaining = max_size - current_size 588 | content_length = len(messages[i].content) 589 | message_length = min(max_chunk_size, content_length) 590 | if messages[i].role == Role.system: 591 | continue 592 | if message_length >= remaining or len(processed_messages) >= max_chunks: 593 | # every assistant message should have user message pair 594 | if processed_messages[0].role == Role.assistant: 595 | processed_messages.popleft() 596 | break 597 | messages[i].content = messages[i].content[:message_length] 598 | processed_messages.appendleft(messages[i]) 599 | current_size += len(messages[i].content[:message_length]) 600 | 601 | if system_message: 602 | processed_messages.appendleft(system_message) 603 | 604 | return list(processed_messages) 605 | 606 | 607 | chunking_processor = None 608 | 609 | 610 | def get_chunking_processor(): 611 | global chunking_processor 612 | if not chunking_processor: 613 | chunking_processor = ChunkProcessor() 614 | return chunking_processor 615 | 616 | 
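617 | 
618 | if __name__ == "__main__":
619 |     # Example usage sketch: feed an oversized JSON tool response through
620 |     # DataChunker and inspect the resulting chunks. The payload below is
621 |     # illustrative only; assumes running from the repo root, e.g.
622 |     # `python -m chunking.chunk_processor`.
623 |     sample_tool_response = json.dumps({
624 |         "name": "Diffbot",
625 |         "pageUrl": "https://www.diffbot.com",
626 |         "summary": "Structured data extracted from the web. " * 300,
627 |         "employees": [{"id": i, "title": f"Engineer {i}"} for i in range(50)],
628 |     })
629 |     sample_chunks = DataChunker.chunk_data(sample_tool_response, max_chunk_size=2500)
630 |     print(f"split into {len(sample_chunks)} chunks")
631 |     for sample_chunk in sample_chunks[:3]:
632 |         # each chunk repeats the identifier keys (e.g. "name") so it stays attributable
633 |         print(sample_chunk.size, sample_chunk.text[:80])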
-------------------------------------------------------------------------------- /chunking/chunking_helper.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import gzip 3 | import re 4 | 5 | from unidecode import unidecode 6 | 7 | from abc import ABC 8 | from dataclasses import dataclass 9 | from typing import List, Optional 10 | from ranking.rank_bm25 import BM25Okapi 11 | from llm.llms import Role 12 | 13 | 14 | @dataclass 15 | class Chunk: 16 | size: int 17 | text: str 18 | role: Optional[Role] = None 19 | similarity: Optional[float] = None 20 | include_in_request: Optional[bool] = None 21 | # flag to identify if there is only functioncall in the assistant message or 22 | # functioncall exists interleaving with response 23 | is_intext_functioncall: Optional[bool] = None 24 | 25 | 26 | class SimilarityCalculator(ABC): 27 | def __init__(self): 28 | self.threshold = 0.0 29 | 30 | def calculate_similarity(self, query: str, target_chunks: List[Chunk], explain=False): 31 | pass 32 | 33 | 34 | def _load_word_frequency(filepath): 35 | ret = {} # word -> freq 36 | with gzip.open(filepath, "rt") as file: 37 | csvreader = csv.reader(file, delimiter="\t", escapechar=None, doublequote=None, quotechar=None) 38 | for row in csvreader: 39 | if len(row) != 2: 40 | continue 41 | try: 42 | freq = int(row[0]) 43 | word = row[1] 44 | # tokenizer, normalize and combine counts 45 | for token in tokenize(word): 46 | ret[token] = ret.get(token, 0) + freq 47 | except Exception as e: 48 | print(e) 49 | pass 50 | return ret 51 | 52 | 53 | stop_words = set() 54 | with open("chunking/stop_words.txt") as file: 55 | stop_words.update([line.rstrip() for line in file]) 56 | 57 | 58 | def is_stop_word(word): 59 | if word in stop_words: 60 | return True 61 | if len(word) <= 1: 62 | return True 63 | return False 64 | 65 | 66 | non_alphanumeric = re.compile(r'[^a-z0-9]') 67 | 68 | 69 | def tokenize(doc): 70 | doc = unidecode(doc) 71 | doc = doc.lower() 72 | words = doc.split() 73 | normalized_words = [] 74 | for word in words: 75 | # TODO: rewrite this without using regex for better performance 76 | normalized_word = non_alphanumeric.sub(" ", word) 77 | if word != normalized_word: 78 | normalized_words.extend(normalized_word.split()) 79 | words.extend(normalized_words) 80 | words = [word for word in words if not is_stop_word(word)] 81 | return words 82 | 83 | 84 | def remove_urls(text): 85 | text = text.lower() 86 | ret = [] 87 | for chunk in text.split(): 88 | if chunk.startswith("http://") or chunk.startswith("https://"): 89 | continue 90 | ret.append(chunk) 91 | return " ".join(ret) 92 | 93 | 94 | def normalize_query_for_truncation(query): 95 | query = query.lower() 96 | query = remove_urls(query) 97 | return query 98 | 99 | 100 | class BM25Calculator(SimilarityCalculator): 101 | _external_word_freq = _load_word_frequency("ranking/wp_1gram_top1m.txt.gz") 102 | _bm25 = BM25Okapi(external_word_freq=_external_word_freq) 103 | 104 | def __init__(self): 105 | super().__init__() 106 | self.threshold = 0.19 107 | 108 | def calculate_similarity(self, query: str, target_chunks: List[Chunk], explain=False): 109 | corpus = [chunk.text for chunk in target_chunks] 110 | tokenized_corpus = [tokenize(doc) for doc in corpus] 111 | query = normalize_query_for_truncation(query) 112 | query = tokenize(query) 113 | explain_dict = {} if explain else None 114 | scores = self._bm25.get_scores(query, tokenized_corpus, max_term_frequency=10, 115 | recalculate_idf=False, explain=explain_dict) 
116 | scores = self.normalize(scores) 117 | for chunk, score in zip(target_chunks, scores): 118 | chunk.similarity = score 119 | 120 | if explain: 121 | for idx, chunk in enumerate(target_chunks): 122 | chunk.explain_similarity = explain_dict.get(idx) 123 | 124 | @staticmethod 125 | def normalize(scores): 126 | min_score = min(scores) 127 | max_score = max(scores) 128 | if max_score == min_score: 129 | return [0] * len(scores) 130 | return [(score - min_score) / (max_score - min_score) for score in scores] 131 | 132 | 133 | def get_similarity_calculator(similarity_type: str = "bm25") -> SimilarityCalculator: 134 | return BM25Calculator() 135 | 136 | 137 | if __name__ == '__main__': 138 | pass 139 | 140 | 141 | -------------------------------------------------------------------------------- /chunking/stop_words.txt: -------------------------------------------------------------------------------- 1 | i 2 | me 3 | my 4 | myself 5 | we 6 | our 7 | ours 8 | ourselves 9 | you 10 | your 11 | yours 12 | yourself 13 | yourselves 14 | he 15 | him 16 | his 17 | himself 18 | she 19 | her 20 | hers 21 | herself 22 | it 23 | its 24 | itself 25 | they 26 | them 27 | their 28 | theirs 29 | themselves 30 | what 31 | which 32 | who 33 | whom 34 | this 35 | that 36 | these 37 | those 38 | am 39 | is 40 | are 41 | was 42 | were 43 | be 44 | been 45 | being 46 | have 47 | has 48 | had 49 | having 50 | do 51 | does 52 | did 53 | doing 54 | a 55 | an 56 | the 57 | and 58 | but 59 | if 60 | or 61 | because 62 | as 63 | until 64 | while 65 | of 66 | at 67 | by 68 | for 69 | with 70 | about 71 | against 72 | between 73 | into 74 | through 75 | during 76 | before 77 | after 78 | above 79 | below 80 | to 81 | from 82 | up 83 | down 84 | in 85 | out 86 | on 87 | off 88 | over 89 | under 90 | again 91 | further 92 | then 93 | once 94 | here 95 | there 96 | when 97 | where 98 | why 99 | how 100 | all 101 | any 102 | both 103 | each 104 | few 105 | more 106 | most 107 | other 108 | some 109 | such 110 | no 111 | nor 112 | not 113 | only 114 | own 115 | same 116 | so 117 | than 118 | too 119 | very 120 | s 121 | t 122 | can 123 | will 124 | just 125 | don 126 | should -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | SERVER = "http://YOUR_SERVER_HERE:3333" 4 | VLLM_SERVER = "http://localhost:8000" 5 | SYSTEM_PROMPT_FILE="system_prompt.txt" 6 | 7 | class Config: 8 | 9 | def __init__(self, server, vllm_server, system_prompt_file) -> None: 10 | self.server = server 11 | self.vllm_server = vllm_server 12 | self.system_prompt_file = system_prompt_file 13 | with open(self.system_prompt_file, "r") as f: 14 | self.system_prompt=f.read() 15 | 16 | def get_server_url(self): 17 | return self.server 18 | 19 | def get_vllm_server_url(self): 20 | return self.vllm_server 21 | 22 | def get_system_prompt(self): 23 | return self.system_prompt 24 | 25 | config = None 26 | def get_config(): 27 | global config 28 | if config is not None: 29 | return config 30 | 31 | server = os.environ.get("SERVER", None) 32 | if not server: 33 | server=SERVER 34 | vllm_server = os.environ.get("VLLM_SERVER", None) 35 | if not vllm_server: 36 | vllm_server = VLLM_SERVER 37 | system_prompt_file = os.environ.get("SYSTEM_PROMPT_FILE", None) 38 | if not system_prompt_file: 39 | system_prompt_file = SYSTEM_PROMPT_FILE 40 | config = Config(server=server, vllm_server=vllm_server, system_prompt_file=system_prompt_file) 41 | 
return config -------------------------------------------------------------------------------- /llm/api_models.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Manually created based on the OpenAI OpenAPI specification: 3 | https://github.com/openai/openai-openapi/blob/master/openapi.yaml 4 | 5 | openapi generator cannot generate python code correctly for some types. 6 | ''' 7 | from typing import Any, Dict, List, Optional, Literal, Union 8 | from pydantic import BaseModel, Field 9 | 10 | class ChatCompletionNamedToolChoiceFunction(BaseModel): 11 | name: str = Field(description="The name of the function to call.") 12 | 13 | class ChatCompletionNamedToolChoice(BaseModel): 14 | type: Literal['function'] = Field(description="The type of the tool.", default="function") 15 | function: ChatCompletionNamedToolChoiceFunction = Field(description="The specific function to call.") 16 | 17 | class ChatCompletionToolCallFunction(BaseModel): 18 | name: Optional[str] = Field(description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.") 19 | arguments: Optional[str] = Field(description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." ) 20 | 21 | class ChatCompletionToolFunctionObject(BaseModel): 22 | description: Optional[str] = Field(default=None, description="A description of what the function does, used by the model to choose when and how to call the function.") 23 | name: str = Field(description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.") 24 | parameters: Dict[str, Any] = Field(description="The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format.\n\nTo describe a function that accepts no parameters, provide the value `{\"type\": \"object\", \"properties\": {}}`." ) 25 | 26 | class ChatCompletionTool(BaseModel): 27 | type: Literal['function'] = Field(description="The type of the tool. Currently, only `function` is supported.", default="function") 28 | function: ChatCompletionToolFunctionObject = Field(description="function.") 29 | 30 | class ChatCompletionToolCall(BaseModel): 31 | id: str = Field(description="The ID of the tool call.") 32 | type: Literal['function'] = Field(description="The type of the tool. Currently, only `function` is supported.", default="function") 33 | function: ChatCompletionToolCallFunction = Field(description="The function that the model called.") 34 | 35 | ChatCompletionMessageToolCall = ChatCompletionToolCall 36 | class ChatCompletionMessageToolCallChunk(BaseModel): 37 | index: int = Field(description="index") 38 | id: Optional[str] = Field(description="The ID of the tool call.") 39 | type: Optional[Literal['function']] = Field(description="The type of the tool. 
Currently, only `function` is supported.", default="function") 40 | function: ChatCompletionToolCallFunction = Field(description="The function that the model called.") 41 | 42 | class ChatCompletionMessage(BaseModel): 43 | content: Optional[Union[str, List[Dict]]] = Field(description="The contents of the message. `content` is required for all messages, and may be null for assistant messages with function calls.") 44 | role: Optional[Literal["system", "user", "tool", "assistant", "function"]] = Field(description="The role of the messages author. One of `system`, `user`, `assistant`, or `function`.") 45 | tool_call_id: Optional[str] = Field(description="Tool call that this message is responding to. Only apply for role=tool") 46 | tool_calls: Optional[List[ChatCompletionMessageToolCall]] = Field(description="The tool calls generated by the model, such as function calls. Only apply for role=assistant") 47 | 48 | ChatCompletionRequestMessage = ChatCompletionMessage 49 | ChatCompletionResponseMessage = ChatCompletionMessage 50 | 51 | ChatCompletionStreamResponseDeltaFunctionCall = ChatCompletionToolCall 52 | 53 | class ChatCompletionStreamResponseDelta(BaseModel): 54 | content: Optional[str] = Field(description="The contents of the message. `content` is required for all messages, and may be null for assistant messages with function calls.", default=None) 55 | role: Optional[Literal["system", "user", "tool", "assistant", "function"]] = Field(description="The role of the messages author. One of `system`, `user`, `assistant`, or `function`.", default=None) 56 | tool_call_id: Optional[str] = Field(description="Tool call that this message is responding to. Only apply for role=tool", default=None) 57 | tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = Field(description="The tool calls generated by the model, such as function calls. Only apply for role=assistant", default=None) 58 | 59 | class CompletionUsage(BaseModel): 60 | completion_tokens: int = Field(description="Number of tokens in the generated completion.", default=0) 61 | prompt_tokens: int = Field(description="Number of tokens in the prompt.", default=0) 62 | total_tokens: int = Field(description="Total number of tokens used in the request (prompt + completion).", default=0) 63 | 64 | 65 | class ChatCompletionResponseChoice(BaseModel): 66 | finish_reason: Literal["eos", "stop", "length", "tool_calls", "content_filter"]= Field(description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,`length` if the maximum number of tokens specified in the request was reached,`content_filter` if content was omitted due to a flag from our content filters, or `tool_calls` if the model called a tool.") 67 | index: int = Field(description="The index of the choice in the list of choices.") 68 | message: ChatCompletionResponseMessage = Field(alias="message", description="message") 69 | 70 | class ChatCompletionResponse(BaseModel): 71 | id: str = Field(description="A unique identifier for the chat completion.") 72 | choices: List[ChatCompletionResponseChoice] = Field(description="A list of chat completion choices. 
Can be more than one if `n` is greater than 1.") 73 | created: int = Field(description="The Unix timestamp (in seconds) of when the chat completion was created.") 74 | model: str = Field(description="The model used for the chat completion.") 75 | object: str = Field(description="The object type, which is always `chat.completion`", default="chat.completion") 76 | system_fingerprint: Optional[str] = Field(description="This fingerprint represents the backend configuration that the model runs with.") 77 | usage: Optional[CompletionUsage] = Field(default=None, description="usage") 78 | 79 | ChatCompletionFunctionResponse = ChatCompletionResponse 80 | 81 | class ChatCompletionStreamResponseChoice(BaseModel): 82 | delta: Optional[ChatCompletionStreamResponseDelta] = Field(alias="delta") 83 | finish_reason: Optional[Literal["eos", "stop", "length", "tool_calls", "content_filter", ""]] = Field(default=None, description="finish reason") 84 | index: int = Field(description="The index of the choice in the list of choices.") 85 | 86 | class ChatCompletionStreamResponse(BaseModel): 87 | id: str = Field(description="A unique identifier for the chat completion. Each chunk has the same ID.") 88 | choices: List[ChatCompletionStreamResponseChoice] = Field(description="A list of chat completion choices. Can be more than one if `n` is greater than 1.") 89 | created: int = Field(description="The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp.") 90 | model: str = Field(description="The model to generate the completion.") 91 | system_fingerprint: Optional[str] = Field(description="This fingerprint represents the backend configuration that the model runs with.") 92 | object: Literal["chat.completion.chunk"]= Field(description="The object type, which is always `chat.completion.chunk`.", default="chat.completion.chunk") 93 | 94 | class ChatCompletionResponseFormat(BaseModel): 95 | type: Literal['text', 'json_object']= Field(description="The format that the model must output. Must be one of `text` or `json_object`.", default='text') 96 | 97 | class CreateChatCompletionRequest(BaseModel): 98 | messages: List[ChatCompletionRequestMessage] = Field(description="A list of messages comprising the conversation so far.") 99 | model: str = Field(description="ID of the model to use. 
e.g., gpt-3.5-turbo") 100 | frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) 101 | logit_bias: Optional[Dict[str, int]] = Field(description="Modify the likelihood of specified tokens appearing in the completion.", default=None) 102 | max_tokens: Optional[int] = Field(description="The maximum number of [tokens](/tokenizer) to generate in the chat completion.") 103 | n: Optional[int] = Field(description="How many chat completion choices to generate for each input message.", default=1, ge=1, le=128) 104 | presence_penalty: Optional[float] = Field(description="", default=0, ge=-1, le=2) 105 | response_format: Optional[ChatCompletionResponseFormat] = Field(description="An object specifying the format that the model must output.") 106 | seed: Optional[int] = Field(description="If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.", default=None) 107 | stop: Optional[Union[str, List[str]]] = Field(description="Up to 4 sequences where the API will stop generating further tokens.", default=None) 108 | stream: Optional[bool] = Field(alias="stream", default=False) 109 | temperature: Optional[float] = Field(alias="temperature", default=0, ge=0, le=2) 110 | top_p: Optional[float] = Field(alias="top_p", default=1, ge=0, le=1) 111 | tools: Optional[List[ChatCompletionTool]] = Field(description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for", default=None) 112 | tool_choice: Optional[Union[Literal["none", "auto"], ChatCompletionNamedToolChoice]] = Field(default=None) 113 | user: Optional[str] = Field(alias="user", default=None) 114 | 115 | class Error(BaseModel): 116 | code: str = Field(alias="code", description="The code of this Error.") 117 | message: str = Field(alias="message", description="The message of this Error.") 118 | param: Optional[str] = Field(alias="param", description="The param of this Error.") 119 | type: Optional[str] = Field(alias="type", description="The type of this Error.") 120 | 121 | class ErrorResponse(BaseModel): 122 | error: Error = Field(description="The error of this ErrorResponse") 123 | 124 | class LLMException(Exception): 125 | error: Error = None 126 | 127 | def __init__(self, *args: object, error: Error) -> None: 128 | super().__init__(*args) 129 | self.error = error 130 | 131 | class ChunkRequest(BaseModel): 132 | text: str 133 | chunk_size: int 134 | chunk_overlap: int 135 | 136 | class ChunkResponse(BaseModel): 137 | chunk_text: str 138 | offset: int 139 | 140 | class ToolCall(BaseModel): 141 | function_name: str 142 | function_arguments: str 143 | tool_call_id: str 144 | 145 | class ToolCalls(BaseModel): 146 | tool_calls: List[ToolCall] 147 | content: str 148 | 149 | class RequestGlobals(BaseModel): 150 | usage: dict[str, CompletionUsage] = {} 151 | timings: dict[str, float] = {} 152 | diffbot_responses: list[dict[str, Any]] = [] 153 | internal_request: dict[str, Any] = {} 154 | disable_tool_fallback: bool = False 155 | diffbot_token: str = "" -------------------------------------------------------------------------------- /llm/diffbot_tool_call_llm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Union, AsyncIterable, Optional, AsyncIterator 2 | import json 3 | import re 4 | 5 | from config import get_config 6 | from llm.api_models import 
ChatCompletionRequestMessage, CompletionUsage, CreateChatCompletionRequest, \ 7 | ChatCompletionResponse, ChatCompletionStreamResponse, ToolCall, ToolCalls, LLMException, Error, RequestGlobals 8 | from llm.openai_gpt import OpenAIModel, get_openai_llm 9 | from llm.llms import ModelID, Role 10 | from llm.tool_call_llm import ToolCallLLM, FUNCTION_CALL_TOKEN, END_OF_TEXT_TOKEN, parse_tool_call 11 | from server.log import get_logstash_logger 12 | 13 | DIFFBOT_TOOL_USE_PROMPT = get_config().get_system_prompt() 14 | 15 | DIFFBOT_ALTERNATIVE_PROMPT = """You are a helpful assistant without access to any functions. Use the information below to answer the users query. 16 | """ 17 | 18 | END_OF_SYSTEM_TOKEN = '=======' 19 | logger = get_logstash_logger("diffbot_tool_call_llm") 20 | WHITESPACE = re.compile("\s+") 21 | 22 | 23 | def contains_url(query): 24 | regex = r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+" 25 | url = re.findall(regex, query) 26 | return len(url) > 0 27 | 28 | 29 | class DiffbotToolCallLLM(ToolCallLLM): 30 | supported_models = {ModelID.DIFFBOT_SMALL, ModelID.DIFFBOT_SMALL_XL} 31 | 32 | def __init__(self): 33 | super().__init__() 34 | 35 | self.llm = get_openai_llm() 36 | 37 | @classmethod 38 | def is_supported(cls, model: ModelID) -> bool: 39 | return model and model in cls.supported_models 40 | 41 | def get_system_prompt(self, query: Optional[str] = None) -> str: 42 | return f'{DIFFBOT_TOOL_USE_PROMPT}\n{END_OF_SYSTEM_TOKEN}' 43 | 44 | 45 | def remove_system_prompt(self, message: str) -> str: 46 | idx = message.find(END_OF_SYSTEM_TOKEN) 47 | if idx == -1: 48 | return message 49 | return message[idx+len(END_OF_SYSTEM_TOKEN):] 50 | 51 | def chat_completion(self, request: CreateChatCompletionRequest, request_globals: RequestGlobals) \ 52 | -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]: 53 | return self.llm.chat_completion(request, request_globals) 54 | 55 | async def select_tool( 56 | self, 57 | request: CreateChatCompletionRequest, 58 | request_globals: RequestGlobals 59 | ) -> Any: 60 | 61 | response = await self.llm.chat_completion(request, request_globals) 62 | if not response: 63 | return None 64 | if isinstance(response, AsyncIterable): 65 | return self.parse_response(response, request, request_globals) 66 | 67 | elif response.choices[-1]: 68 | tool_call_message = response.choices[-1].message.content 69 | if FUNCTION_CALL_TOKEN in tool_call_message.strip(): 70 | tool_calls = parse_tool_call(tool_call_message) 71 | # check non-empty list of tool_calls 72 | if tool_calls and tool_calls.tool_calls: 73 | return tool_calls 74 | response.choices[-1].message.content = sanitize_response(response.choices[-1].message.content) 75 | 76 | return response 77 | 78 | def parse_content_delta(self, content_delta: str, buffer: List[str]) -> str: 79 | # next chunk should stop at whitespace 80 | ret = "" 81 | next_chunk = content_delta 82 | remaining_chunk = None 83 | 84 | match = WHITESPACE.search(content_delta) 85 | if match: 86 | idx = match.end() 87 | next_chunk = content_delta[:idx] 88 | remaining_chunk = content_delta[idx:] 89 | 90 | buffer.append(next_chunk) 91 | buffer_str = "".join(buffer) 92 | min_length = min(len(buffer_str), len(FUNCTION_CALL_TOKEN)) 93 | if buffer_str[:min_length] != FUNCTION_CALL_TOKEN[:min_length]: 94 | buffer.clear() 95 | ret = buffer_str 96 | if remaining_chunk is not None: 97 | buffer.append(remaining_chunk) 98 | return ret 99 | 100 | async def parse_response(self, response, request: CreateChatCompletionRequest, request_globals: 
RequestGlobals): 101 | all_chunks = [] # all message.chunks being streamed by llm 102 | buffer = [] # latest chunks streamed by llm that cannot be yielded yet 103 | async for chunk in response: 104 | 105 | if chunk.choices and len(chunk.choices) == 1 and chunk.choices[0].delta.content: 106 | all_chunks.append(chunk.choices[0].delta.content) 107 | 108 | # if finished_reason is present, LLM stopped generating. 109 | # if buffer is not empty, it might contain function calls to be parsed. 110 | if chunk.choices and len(chunk.choices) == 1 and chunk.choices[0].finish_reason: 111 | buffer_str = "".join(buffer) + chunk.choices[0].delta.content 112 | 113 | if buffer_str.startswith(FUNCTION_CALL_TOKEN): 114 | all_chunks_str = "".join(all_chunks) 115 | yield parse_tool_call(all_chunks_str) 116 | else: 117 | chunk.choices[0].delta.content = buffer_str 118 | # send last tokens 119 | reason = chunk.choices[0].finish_reason 120 | chunk.choices[0].finish_reason = None 121 | yield chunk 122 | # send last chunk with empty content 123 | chunk.choices[0].finish_reason = reason 124 | chunk.choices[0].delta.content = '' 125 | yield chunk 126 | continue 127 | 128 | if chunk.choices and len(chunk.choices) == 1 and chunk.choices[0].delta.content: 129 | new_content = self.parse_content_delta(chunk.choices[0].delta.content, buffer) 130 | chunk.choices[0].delta.content = new_content 131 | if len(new_content)>0: 132 | yield chunk 133 | 134 | async def combine_stream(non_stream_results, stream_result): 135 | for item in non_stream_results: 136 | yield item 137 | async for item in stream_result: 138 | yield item 139 | 140 | async def sanitize_stream(stream_result): 141 | async for item in stream_result: 142 | delta_content = item.choices[-1].delta.content 143 | item.choices[-1].delta.content = sanitize_response(delta_content) 144 | yield item 145 | 146 | def sanitize_response(text): 147 | if not text: 148 | return text 149 | text = text.replace(END_OF_TEXT_TOKEN, "") 150 | text = text.replace("", "") 151 | return text 152 | 153 | diffbot_tool_call_llm = None 154 | def get_diffbot_tool_call_llm(): 155 | global diffbot_tool_call_llm 156 | if not diffbot_tool_call_llm: 157 | diffbot_tool_call_llm = DiffbotToolCallLLM() 158 | return diffbot_tool_call_llm -------------------------------------------------------------------------------- /llm/llms.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from types import GeneratorType 3 | from typing import Any, List, Optional, Dict, Union 4 | from enum import Enum 5 | from dataclasses import dataclass 6 | 7 | from llm.api_models import ChatCompletionRequestMessage, CompletionUsage, CreateChatCompletionRequest, \ 8 | ChatCompletionResponse, ChatCompletionStreamResponse, RequestGlobals 9 | 10 | # # supported LLM models and functions 11 | 12 | OPENAI_GPT_3_5 = 'gpt-3.5-turbo' 13 | OPENAI_GPT_4 = 'gpt-4-turbo' 14 | MISTRAL_7B_INSTRUCT = 'mistral-7b-instruct-32k' 15 | OUTPUT_TOKENS_RESERVE = 4000 16 | INPUT_TOKENS_LIMIT = 16384 # no model should have a tokenLimit lower than INPUT_TOKENS_LIMIT + OUTPUT_TOKENS_RESERVE. 17 | LAST_TOOL_TOKENS_LIMIT = 6000 # limit on tokens to use for last tool response. should be lower than INPUT_TOKENS_LIMIT. 
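# Worked example of how these limits interact (a sketch of the arithmetic in
# OpenAIModel.chat_completion in llm/openai_gpt.py; the 131072 figure is simply the
# tokenLimit configured below for the diffbot models, not a new setting):
#   if tokenLimit < INPUT_TOKENS_LIMIT + OUTPUT_TOKENS_RESERVE (16384 + 4000 = 20384),
#       the input budget shrinks to tokenLimit - OUTPUT_TOKENS_RESERVE;
#   otherwise it stays at INPUT_TOKENS_LIMIT, so a 131072-token model still truncates
#   its prompt to 16384 tokens, keeping 4000 tokens in reserve for the completion.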
18 | 19 | class ModelID(str, Enum): 20 | DIFFBOT_SMALL = 'diffbot-small' 21 | DIFFBOT_SMALL_XL = 'diffbot-small-xl' 22 | UNKNOWN = 'unknown' 23 | 24 | @classmethod 25 | def get_model_id(cls, model_name: str): 26 | try: 27 | return ModelID(model_name) 28 | except ValueError: 29 | return ModelID.UNKNOWN 30 | 31 | @dataclass 32 | class ModelInfo: 33 | id: ModelID 34 | model: str 35 | tokenLimit: int 36 | 37 | LLMS = { 38 | ModelID.DIFFBOT_SMALL: ModelInfo( 39 | id=ModelID.DIFFBOT_SMALL, 40 | model=ModelID.DIFFBOT_SMALL, 41 | tokenLimit=131072, 42 | ), 43 | ModelID.DIFFBOT_SMALL_XL: ModelInfo( 44 | id=ModelID.DIFFBOT_SMALL_XL, 45 | model=ModelID.DIFFBOT_SMALL_XL, 46 | tokenLimit=131072, 47 | ), 48 | } 49 | 50 | 51 | # # Chat related data structures. 52 | class Role(str, Enum): 53 | system = 'system' 54 | user = 'user' 55 | assistant = 'assistant' 56 | tool = 'tool' 57 | 58 | class LLM(ABC): 59 | @classmethod 60 | def is_supported(cls, model: ModelID) -> bool: 61 | pass 62 | 63 | def generate_prompt(self, system_prompt: str, messages: List[ChatCompletionRequestMessage], maxLength: int) -> str: 64 | pass 65 | 66 | def chat_completion(self, request: CreateChatCompletionRequest, request_globals: RequestGlobals) \ 67 | -> Union[ChatCompletionResponse, ChatCompletionStreamResponse]: 68 | pass 69 | -------------------------------------------------------------------------------- /llm/openai_gpt.py: -------------------------------------------------------------------------------- 1 | import time 2 | from openai import AsyncOpenAI 3 | from typing import Union, AsyncIterator 4 | from llm.token_utils import count_prompt_tokens 5 | 6 | from models.utils import cleanNullValues 7 | from config import Config, get_config 8 | from llm.llms import LLM, ModelID, LLMS, INPUT_TOKENS_LIMIT, OUTPUT_TOKENS_RESERVE, LAST_TOOL_TOKENS_LIMIT 9 | from llm.api_models import ChatCompletionRequestMessage, CompletionUsage, CreateChatCompletionRequest, \ 10 | ChatCompletionResponse, ChatCompletionStreamResponse, LLMException, RequestGlobals, Error 11 | from chunking.chunk_processor import get_chunking_processor 12 | from server.log import get_logstash_logger 13 | 14 | logger = get_logstash_logger("openai_gpt") 15 | 16 | class OpenAIModel(LLM): 17 | supported_model = {ModelID.DIFFBOT_SMALL, ModelID.DIFFBOT_SMALL_XL} 18 | diffbot_client = AsyncOpenAI( 19 | api_key="EMPTY", 20 | base_url=get_config().get_vllm_server_url()+"/v1", 21 | ) 22 | 23 | @classmethod 24 | def is_supported(cls, model: ModelID) -> bool: 25 | return model and model in cls.supported_model 26 | 27 | async def chat_completion(self, request: CreateChatCompletionRequest, request_globals: RequestGlobals) \ 28 | -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]: 29 | model_token_limit = LLMS[request.model].tokenLimit 30 | if model_token_limit < INPUT_TOKENS_LIMIT + OUTPUT_TOKENS_RESERVE: 31 | input_tokens_limit = model_token_limit - OUTPUT_TOKENS_RESERVE 32 | else: 33 | input_tokens_limit = INPUT_TOKENS_LIMIT 34 | log_ctx = {} 35 | start_truncation = time.time() 36 | json_length_before = sum(len(message.content) for message in request.messages) 37 | request.messages = get_chunking_processor().process_messages(request.messages, 38 | max_tokens=input_tokens_limit, 39 | max_tokens_last_tool=LAST_TOOL_TOKENS_LIMIT, 40 | log_ctx=log_ctx) 41 | json_length_after = sum(len(message.content) for message in request.messages) 42 | truncation_time = round((time.time() - start_truncation) * 1000) #ms 43 | log_ctx.update({"truncation_time": 
truncation_time, 44 | "json_length_before": json_length_before, 45 | "json_length_after": json_length_after, 46 | "is_longer_after_truncation": json_length_after > json_length_before, 47 | "is_longer_than_limit": json_length_after > (log_ctx.get("max_size", INPUT_TOKENS_LIMIT * 4)), 48 | "model": request.model 49 | }) 50 | logger.info("message truncation completed", extra=log_ctx) 51 | 52 | request_new = cleanNullValues(request.dict()) 53 | 54 | if request_new['model'] not in request_globals.usage: 55 | request_globals.usage[request_new['model']] = CompletionUsage() 56 | curr_usage = request_globals.usage[request_new['model']] 57 | 58 | if request_new['model'] == ModelID.DIFFBOT_SMALL or request_new['model'] == ModelID.DIFFBOT_SMALL_XL: 59 | client = self.diffbot_client 60 | 61 | # VLLM errs with these unnecessary fields, remove them 62 | request_new.pop('tools', None) 63 | request_new.pop('tool_choice', None) 64 | for message in request_new['messages']: 65 | # Convert all role:tool to role:user for VLLM to enforce alternative user/assisant turns 66 | if message['role'] == "tool": 67 | message['role'] = "user" 68 | message.pop('tool_call_id', None) 69 | 70 | # stop tokens 71 | request_new['stop'] = ['<|endoftext|>', '<|im_end|>', 72 | # to prevent the LLM from having infinite conversations with itself 73 | '### USER:', '### ASSISTANT:' , '### ' 74 | ] 75 | else: 76 | raise LLMException(error=Error(code=422, message="Invalid model: {}".format(request_new['model']))) 77 | 78 | request_globals.internal_request = request_new # save last internal request 79 | response = await client.chat.completions.create(**request_new, timeout=60) 80 | 81 | if request.stream: 82 | # get prompt_token 83 | prompt_tokens_count = count_prompt_tokens(request.messages) 84 | curr_usage.prompt_tokens += prompt_tokens_count 85 | curr_usage.total_tokens += prompt_tokens_count 86 | return parse_stream(response, curr_usage) 87 | else: 88 | chat_response = ChatCompletionResponse(**response.dict()) 89 | curr_usage.prompt_tokens += chat_response.usage.prompt_tokens 90 | curr_usage.completion_tokens += chat_response.usage.completion_tokens 91 | curr_usage.total_tokens += chat_response.usage.total_tokens 92 | 93 | return chat_response 94 | 95 | async def parse_stream(response, curr_usage: CompletionUsage): 96 | async for event in response: 97 | if not event.object: # skip empty events 98 | continue 99 | if event.choices and len(event.choices) == 1 and event.choices[0].delta and event.choices[0].delta.content is None: 100 | event.choices[0].delta.content = "" 101 | curr_usage.completion_tokens += 1 102 | curr_usage.total_tokens += 1 103 | yield ChatCompletionStreamResponse(**event.dict()) 104 | 105 | openai_model = None 106 | def get_openai_llm(): 107 | global openai_model 108 | if not openai_model: 109 | openai_model = OpenAIModel() 110 | return openai_model -------------------------------------------------------------------------------- /llm/plugin.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from urllib.parse import urlencode 3 | import json 4 | import httpx 5 | from llm.api_models import ChatCompletionTool, ChatCompletionToolFunctionObject 6 | from models.api import ResponseModel 7 | from services.execute_js import get_js_execution_service 8 | 9 | TIMEOUT_SECONDS = 60 10 | 11 | class PluginResponse(ResponseModel): 12 | plugin_url: str 13 | method: str 14 | content: object = None 15 | 16 | @dataclass 17 | class PluginApiOperationParameter: 18 
| name: str = None
19 | _in: str = None
20 | description: str = None
21 | required: bool = None
22 | schema: object = None
23 |
24 | @classmethod
25 | def from_dict(cls, param, api_spec):
26 | if "schema" in param:
27 | schema = param.get("schema")
28 | if "$ref" in schema:
29 | schema = cls.resolve_ref(schema.get("$ref"), api_spec)
30 | if "title" in schema:
31 | del schema["title"]
32 | if "items" in schema and "$ref" in schema.get("items"):
33 | schema["items"] = cls.resolve_ref(
34 | schema.get("items").get("$ref"), api_spec
35 | )
36 | if "title" in schema["items"]:
37 | del schema["items"]["title"]
38 | if "description" in schema["items"]:
39 | del schema["items"]["description"]
40 |
41 | return cls(
42 | name=param.get("name"),
43 | _in=param.get("in"),
44 | description=param.get("description"),
45 | required=param.get("required"),
46 | schema=schema,
47 | )
48 |
49 | return None
50 |
51 | @classmethod
52 | def resolve_ref(cls, ref, api_spec):
53 | if not ref.startswith("#"):
54 | raise Exception(f"Unsupported reference: {ref}")
55 |
56 | ref_path = ref[1:].split("/")
57 | current = api_spec
58 | for ref_path_part in ref_path:
59 | if ref_path_part == "":
60 | continue
61 | if current is None:
62 | raise Exception(f"Cannot resolve reference: {ref}")
63 | current = current.get(ref_path_part)
64 |
65 | return current
66 |
67 |
68 | @dataclass
69 | class PluginApiOperationRequestBody:
70 | description: str = None
71 | required: bool = None
72 | content: dict[str, object] = None
73 |
74 |
75 | @dataclass
76 | class PluginApiOperationResponse:
77 | description: str
78 | content: dict[str, object]
79 |
80 |
81 | @dataclass
82 | class PluginApiOperation:
83 | operation_id: str = None
84 | server_url: str = None
85 | api_path: str = None
86 | method: str = None # get/post
87 | description: str = None
88 | parameters: list[PluginApiOperationParameter] = None
89 | request_body: PluginApiOperationRequestBody = None
90 | responses: dict[str, PluginApiOperationResponse] = None
91 |
92 |
93 | def _get_plugin_spec(plugin_url: str):
94 | return _get_content(plugin_url)
95 |
96 | def _get_content(url: str):
97 | # get the content from the input URL
98 | try:
99 | with httpx.Client(timeout=2) as client:
100 | response = client.get(url=url)
101 | return response.json()
102 | except Exception as e:
103 | print(f"Error getting content from {url}: {e}")
104 | return None
105 |
106 | class Plugin:
107 | def __init__(self, plugin_api_spec: dict[str, any]= None):
108 | self.plugin_apis = self._get_plugin_apis(plugin_api_spec)
109 |
110 | def _get_plugin_apis(self, plugin_api_spec: dict[str, any]):
111 | # parse the plugin API spec
112 | if "servers" in plugin_api_spec and len(plugin_api_spec.get("servers")) > 0:
113 | server_url = plugin_api_spec.get("servers")[0].get("url")
114 | else:
115 | raise Exception("Plugin API spec does not define any servers")
116 |
117 | plugin_apis = []
118 | for api_path in plugin_api_spec.get("paths"):
119 | api_detail = plugin_api_spec.get("paths")[api_path]
120 | for method in api_detail:
121 | operation_object = api_detail[method]
122 | operation_id = operation_object.get("operationId")
123 | description = operation_object.get("description")
124 | parameters = []
125 | if "parameters" in operation_object:
126 | parameters = [
127 | PluginApiOperationParameter.from_dict(
128 | parameter, plugin_api_spec
129 | )
130 | for parameter in operation_object.get("parameters")
131 | ]
132 |
133 | plugin_apis.append(
134 | PluginApiOperation(
135 | operation_id=operation_id,
136 
server_url=server_url, 137 | api_path=api_path, 138 | method=method, 139 | description=description, 140 | parameters=parameters, 141 | ) 142 | ) 143 | 144 | return plugin_apis 145 | 146 | def _get_tool_from_plugin_api_operation( 147 | self, pluginApiOperation: PluginApiOperation 148 | ) -> ChatCompletionTool: 149 | properties = {} 150 | required = [] 151 | if pluginApiOperation.parameters: 152 | for param in pluginApiOperation.parameters: 153 | properties[param.name] = param.schema 154 | if param.required: 155 | required.append(param.name) 156 | 157 | if pluginApiOperation.request_body: 158 | for key in pluginApiOperation.request_body: 159 | for property_key in pluginApiOperation.request_body[ 160 | key 161 | ].schema.properties: 162 | schema_properties = pluginApiOperation.request_body[ 163 | key 164 | ].schema.properties 165 | properties[property_key] = { 166 | "type": schema_properties[property_key].type, 167 | "description": schema_properties[property_key].description, 168 | } 169 | if schema_properties[property_key].required: 170 | required.append(property_key) 171 | 172 | return ChatCompletionTool( 173 | type="function", 174 | function=ChatCompletionToolFunctionObject( 175 | description=pluginApiOperation.description, 176 | name=pluginApiOperation.operation_id, 177 | parameters={ 178 | "type": "object", 179 | "properties": properties, 180 | "required": required, 181 | }, 182 | ) 183 | ) 184 | 185 | def get_tools(self): 186 | return [self._get_tool_from_plugin_api_operation(operation) for operation in self.plugin_apis] 187 | 188 | async def invoke(self, function_name: str, function_arguments: object, token: str) -> PluginResponse: 189 | if not function_name or not function_arguments: 190 | return None 191 | 192 | # call local tools 193 | if function_name == "execute_js_v1": 194 | resp = await get_js_execution_service().execute_js(function_arguments["expressions"]) 195 | return PluginResponse( 196 | plugin_url=function_name, method="INTERNAL", content=resp.json() 197 | ) 198 | 199 | # or call external Diffbot tools 200 | plugin_apis = [plugin_api for plugin_api in self.plugin_apis if plugin_api.operation_id == function_name] 201 | if not plugin_apis or len(plugin_apis) == 0: 202 | return None 203 | 204 | params = {} 205 | operation = plugin_apis[0] 206 | if operation.parameters: 207 | for parameter in operation.parameters: 208 | if parameter._in == "path" and parameter.name in function_arguments: 209 | new_path = function_arguments[parameter.name] 210 | operation.api_path = operation.api_path.replace( 211 | parameter.name, new_path 212 | ) 213 | if parameter._in == "query" and parameter.name in function_arguments: 214 | new_query = function_arguments[parameter.name] 215 | params[parameter.name] = new_query 216 | 217 | body = {} 218 | if operation.request_body: 219 | for key in operation.request_body.content: 220 | for property_key in operation.request_body.content[key]["schema"][ 221 | "properties" 222 | ]: 223 | if property_key in function_arguments: 224 | body[property_key] = function_arguments[property_key] 225 | body = json.dumps(body) 226 | else: 227 | body = None 228 | 229 | url = f"{operation.server_url}{operation.api_path}" 230 | if params: 231 | url = f"{url}?{urlencode(params, doseq=True)}" 232 | 233 | print(f"calling plugin api: {url}") 234 | 235 | return await self._call_plugin_api(api_url=url, method=operation.method, body=body, token=token) 236 | 237 | @classmethod 238 | async def _call_plugin_api(cls, api_url: str, method: str, body: str, token: str): 239 | async with 
httpx.AsyncClient(timeout=TIMEOUT_SECONDS) as client: 240 | try: 241 | response = await client.request( 242 | method=method, 243 | url=api_url, 244 | data=body, 245 | headers={"Content-Type": "application/json", "Authorization": f"Bearer {token}"}, 246 | ) 247 | response.raise_for_status() 248 | return PluginResponse( 249 | plugin_url=api_url, method=method, content=response.json() 250 | ) 251 | except httpx.HTTPStatusError as e: 252 | status = ( 253 | "code" in e.response.json() 254 | and e.response.json()["code"] 255 | or e.response.status_code 256 | ) 257 | message = ( 258 | "message" in e.response.json() 259 | and e.response.json()["message"] 260 | or e.response.reason_phrase 261 | ) 262 | return PluginResponse( 263 | plugin_url=api_url, method=method, status=status, message=message 264 | ) 265 | except Exception as e: 266 | return PluginResponse( 267 | plugin_url=api_url, method=method, status=500, message=str(e) 268 | ) 269 | 270 | plugin = Plugin(_get_plugin_spec("https://llm.diffbot.com/api/.well-known/openapi.yaml")) 271 | def get_plugin() -> Plugin: 272 | return plugin -------------------------------------------------------------------------------- /llm/token_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import tiktoken 3 | 4 | from llm.api_models import ChatCompletionRequestMessage, CompletionUsage 5 | 6 | 7 | def count_stream(stream_response, usage: CompletionUsage): 8 | for ret in stream_response: 9 | usage.completion_tokens += 1 10 | yield ret 11 | 12 | def count_prompt_tokens(messages: List[ChatCompletionRequestMessage]): 13 | prompts = "".join([msg.content for msg in messages]) 14 | return num_of_prompt_tokens(prompts) 15 | 16 | def num_of_prompt_tokens(prompts: str, encoding_name: str = 'cl100k_base'): 17 | encoding = tiktoken.get_encoding(encoding_name) 18 | return len(encoding.encode(prompts)) -------------------------------------------------------------------------------- /llm/tool_call_llm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from abc import abstractmethod 4 | from typing import Any, List, Optional, Dict, Union, AsyncIterable 5 | import asyncio 6 | 7 | from llm.api_models import (ChatCompletionRequestMessage, CompletionUsage, CreateChatCompletionRequest, 8 | ChatCompletionResponse, ChatCompletionStreamResponse, Error, LLMException, ToolCall, 9 | ToolCalls, RequestGlobals) 10 | from llm.llms import LLM, Role, ModelID 11 | from models.api import ResponseModel, DiffbotAPIResponse, DQLResponse 12 | from server.log import get_logstash_logger 13 | from llm.plugin import get_plugin, Plugin 14 | 15 | logger = get_logstash_logger("tool_call_llm") 16 | 17 | MAX_NUM_CALLS = 3 18 | FUNCTION_CALL_TOKEN = '' 19 | END_OF_TEXT_TOKEN = '<|endoftext|>' 20 | 21 | def parse_tool_call(content) -> ToolCalls: 22 | try: 23 | content = content.lstrip() 24 | function_calls_str = content.split(FUNCTION_CALL_TOKEN) 25 | ret = [] 26 | for call_str in function_calls_str: 27 | call_str = call_str.strip() 28 | # we don't expect the function call JSON to have linebreaks. The LLM sometimes adds additional paragraphs of 29 | # text after the function call and we want to ignore these additional paragraphs. 
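# Illustrative shape of a raw completion this parser expects (an assumption inferred from
# the parsing below, not a verbatim trace; the tool name and arguments are examples only):
#   FUNCTION_CALL_TOKEN + ' {"name": "web_search_v1", "arguments": {"text": "..."}}' + END_OF_TEXT_TOKEN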
30 | if "\n" in call_str: 31 | call_str = call_str[:call_str.index("\n")] 32 | if not call_str: 33 | continue 34 | start_idx = 0 35 | end_idx = call_str.index(END_OF_TEXT_TOKEN) if END_OF_TEXT_TOKEN in call_str else len(call_str) 36 | json_text = call_str[start_idx:end_idx].strip() 37 | if not json_text.startswith("{") or not json_text.endswith("}"): 38 | continue 39 | try: 40 | function_call = json.loads(json_text) 41 | function_name = function_call['name'] 42 | function_arguments = json.dumps(function_call['arguments']) 43 | ret.append( 44 | ToolCall(function_name=function_name, function_arguments=function_arguments, tool_call_id="")) 45 | except Exception as e: 46 | logger.error(f"Failed to parse tool call: {json_text}. Exception: {e}", exc_info=True) 47 | if not ret: 48 | raise LLMException(error=Error(code=500, message="Invalid tool call. {}".format(content))) 49 | return ToolCalls(tool_calls=ret, content=content) 50 | except Exception: 51 | raise LLMException(error=Error(code=500, message="Invalid tool call. {}".format(content))) 52 | 53 | async def combine_streams(stream_result1: AsyncIterable, stream_result2: AsyncIterable): 54 | async for item in stream_result1: 55 | yield item 56 | async for item in stream_result2: 57 | yield item 58 | 59 | async def async_iterable(data): 60 | for item in data: 61 | yield item 62 | 63 | class ToolCallLLM(LLM): 64 | 65 | @abstractmethod 66 | def get_system_prompt(self, query: Optional[str] = None) -> str: 67 | pass 68 | 69 | @abstractmethod 70 | def remove_system_prompt(self, message: str) -> str: 71 | pass 72 | 73 | @abstractmethod 74 | def select_tool( 75 | self, 76 | request: CreateChatCompletionRequest, 77 | request_globals: RequestGlobals, 78 | ) -> Any: 79 | pass 80 | 81 | async def process_tool_calls( 82 | self, 83 | request: CreateChatCompletionRequest, 84 | plugin: Plugin, 85 | request_globals: RequestGlobals 86 | ) -> Any: 87 | # setup the request for tool calls 88 | tool_call_request = request.copy(deep=True) 89 | 90 | # add/update system prompt for tool calls 91 | start_system_prompt = time.time() 92 | system_prompt = self.get_system_prompt(query=tool_call_request.messages[-1].content) 93 | has_system = False 94 | if tool_call_request.messages and len(tool_call_request.messages) > 0: 95 | for message in tool_call_request.messages: 96 | if message.role == 'system': 97 | message.content = f'{system_prompt}\n\n{message.content}' 98 | has_system = True 99 | break 100 | if not has_system: 101 | tool_call_request.messages.insert(0, ChatCompletionRequestMessage(role='system', content=system_prompt)) 102 | request_globals.timings["system_prompt"] = (time.time() - start_system_prompt) * 1000 103 | 104 | request_globals.diffbot_responses = [] 105 | if request.stream: 106 | return self._process_stream_tool_calls(tool_call_request, request_globals, plugin) 107 | return await self._process_non_stream_tool_calls(tool_call_request, request_globals, plugin) 108 | 109 | async def _process_non_stream_tool_calls( 110 | self, 111 | request: CreateChatCompletionRequest, 112 | request_globals: RequestGlobals, 113 | plugin: Plugin, 114 | num_tool_calls=0, 115 | force_last_tool_call=False 116 | ): 117 | response = await self.select_tool_from_llm_or_user(request, request_globals) 118 | if isinstance(response, ToolCalls) and len(response.tool_calls)>0 and isinstance(response.tool_calls[0], ToolCall): 119 | num_tool_calls += 1 120 | 121 | if force_last_tool_call: 122 | raise LLMException(error=Error(code=500, message=f"Exceeded max function call attempts")) 123 
| 124 | if num_tool_calls > MAX_NUM_CALLS: 125 | # try to generate a response without tool calls as a last attempt 126 | await self.change_request_to_skip_function_calling(request) 127 | force_last_tool_call = True 128 | else: 129 | allow_fallback = not request_globals.disable_tool_fallback 130 | await self.invoke_function_call(request, request_globals, plugin, response, 131 | allow_fallback=allow_fallback) 132 | return await self._process_non_stream_tool_calls(request, request_globals, plugin, 133 | num_tool_calls=num_tool_calls, 134 | force_last_tool_call=force_last_tool_call) 135 | elif isinstance(response, ChatCompletionResponse): 136 | request.messages.append(response.choices[0].message) 137 | return response 138 | 139 | raise LLMException(error=Error(code=500, message="Invalid function call")) 140 | 141 | async def _process_stream_tool_calls( 142 | self, 143 | request: CreateChatCompletionRequest, 144 | request_globals: RequestGlobals, 145 | plugin: Plugin, 146 | num_tool_calls=0, 147 | force_last_tool_call=False 148 | ): 149 | response = await self.select_tool_from_llm_or_user(request, request_globals) 150 | if isinstance(response, ToolCalls): 151 | # emulate async generation of tool call when tool call is provided by user 152 | response = async_iterable([response]) 153 | 154 | async for item in response: 155 | if not isinstance(item, ToolCalls): 156 | yield item 157 | continue 158 | # call the tool, make a completion request, and yield the LLM response here 159 | num_tool_calls += 1 160 | if force_last_tool_call: 161 | raise LLMException(error=Error(code=500, message=f"Exceeded max function call attempts")) 162 | 163 | if num_tool_calls > MAX_NUM_CALLS: 164 | # try to generate a response without tool calls as a last attempt 165 | await self.change_request_to_skip_function_calling(request) 166 | force_last_tool_call = True 167 | else: 168 | # now process the tool call 169 | allow_fallback = not request_globals.disable_tool_fallback 170 | await self.invoke_function_call(request, request_globals, plugin, item, allow_fallback=allow_fallback) 171 | ret = self._process_stream_tool_calls(request, request_globals, plugin, 172 | num_tool_calls=num_tool_calls, 173 | force_last_tool_call=force_last_tool_call) 174 | async for ret_item in ret: 175 | yield ret_item 176 | 177 | 178 | async def select_tool_from_llm_or_user(self, request, request_globals): 179 | # if user requests particular tool request, call function directly without LLM interaction. 
180 | if request.messages[-1].content.startswith(FUNCTION_CALL_TOKEN): 181 | last_message = request.messages[-1].content 182 | # tool call request and response are added again after invocation 183 | request.messages.pop() 184 | response = parse_tool_call(last_message) 185 | else: 186 | response = await self.select_tool(request, request_globals) 187 | return response 188 | 189 | async def invoke_function_call(self, request: CreateChatCompletionRequest, 190 | request_globals: RequestGlobals, 191 | plugin, 192 | tool_calls: ToolCalls, 193 | allow_fallback: bool = False): 194 | invoked_tool_calls = [] 195 | call_tasks = [] 196 | tool_call_content = '' 197 | for tool_call in tool_calls.tool_calls: 198 | invoke_function_name = tool_call.function_name 199 | if not tool_call.function_name or not tool_call.function_arguments: 200 | logger.error("Invalid function call: " + str(tool_calls)) 201 | continue 202 | tool_call_content += "" + json.dumps({"name": tool_call.function_name, "arguments": json.loads(tool_call.function_arguments)}) + '\n' 203 | call_tasks.append(plugin.invoke( 204 | function_name = invoke_function_name, 205 | function_arguments = json.loads(tool_call.function_arguments), 206 | token = request_globals.diffbot_token) 207 | ) 208 | invoked_tool_calls.append(tool_call) 209 | plugin_responses = await asyncio.gather(*call_tasks) 210 | if tool_calls.content: 211 | request.messages.append(ChatCompletionRequestMessage( 212 | role=Role.assistant, 213 | content=tool_calls.content 214 | )) 215 | elif tool_call_content: 216 | request.messages.append(ChatCompletionRequestMessage( 217 | role=Role.assistant, 218 | content=tool_call_content.strip() 219 | )) 220 | 221 | request_with_call = request.copy(deep=True) 222 | 223 | request.messages.append(ChatCompletionRequestMessage( 224 | role=Role.user, 225 | tool_call_id="", 226 | content="" 227 | )) 228 | 229 | tool_responses = [] 230 | for index, plugin_response in enumerate(plugin_responses): 231 | tool_call = invoked_tool_calls[index] 232 | if plugin_response is None: 233 | continue 234 | plugin_content = plugin_response.dict().get('content', {}) 235 | tool_responses.append(plugin_content) 236 | request.messages[-1].tool_call_id = tool_call.tool_call_id # use last tool_call_id 237 | request.messages[-1].tool_call_id = request.messages[-1].tool_call_id.strip() 238 | 239 | request_globals.diffbot_responses.append({ 240 | "diffbot_request": { 241 | "name": tool_call.function_name, 242 | "arguments": json.loads(tool_call.function_arguments), 243 | }, 244 | "diffbot_response": plugin_content 245 | }) 246 | 247 | # tool message should be a valid json so that it can be truncated as a json later 248 | if tool_responses: 249 | request.messages[-1].content = json.dumps(tool_responses) 250 | 251 | # now we'll check if results are good enough. if they are not, fallback to web_search. 
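# Concretely (see verify_results below): a dql_v1 call counts as unsatisfactory when the
# plugin response is not status 200, the payload's own status is not 200, or its data list
# is empty; in that case the model is asked once more to answer with web_search_v1 instead,
# and the original dql_v1 response is dropped from request_globals.diffbot_responses.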
252 | 253 | if not allow_fallback: 254 | return 255 | 256 | tool_call_names = [tool_call.function_name for tool_call in invoked_tool_calls] 257 | if "web_search_v1" in tool_call_names or "dql_v1" not in tool_call_names: 258 | # fallback is for dql tool calls only for now 259 | return 260 | 261 | results_are_good = await self.verify_results(invoked_tool_calls, plugin_responses) 262 | if results_are_good: 263 | return 264 | 265 | start_fallback = time.time() 266 | fallback_request = request_with_call 267 | original_calls = request_with_call.messages[-1].content 268 | fallback_request.messages.append( 269 | ChatCompletionRequestMessage( 270 | content="""This function call does not return satisfactory results. Try again, now using web_search_v1. Your answer MUST start with: {"name": "web_search_v1", ...""", 271 | role="user") 272 | ) 273 | fallback_request.stream=False 274 | # Fallback request should have the same prefix as original request to leverage vLLM's prefix caching. 275 | response = await self.select_tool(fallback_request, request_globals) 276 | logger.info(f"fallback function call request completed", 277 | extra={ 278 | "fallback_request_time": round((time.time() - start_fallback) * 1000), 279 | "llm_response": str(response), 280 | "original_calls": original_calls, 281 | "is_tool_call_response": isinstance(response, ToolCalls) 282 | }) 283 | 284 | if isinstance(response, ToolCalls) and response.tool_calls and response.tool_calls[0].function_name == "web_search_v1": 285 | # remove last dql response 286 | request_globals.diffbot_responses.pop() 287 | # remove last tool call response 288 | request.messages.pop() 289 | # remove last tool call message 290 | tool_call_message = request.messages.pop() 291 | # add extra tokens before back 292 | idx = tool_call_message.content.find(FUNCTION_CALL_TOKEN) 293 | if idx > 0 : 294 | extra_tokens = tool_call_message.content[:idx] 295 | response.content = extra_tokens + " " + response.content 296 | return await self.invoke_function_call(request, request_globals, plugin, response, allow_fallback=False) 297 | 298 | return 299 | 300 | async def verify_results(self, tool_calls: list[ToolCall], plugin_responses: list[DiffbotAPIResponse | DQLResponse]): 301 | for call, resp in zip(tool_calls, plugin_responses): 302 | if call.function_name == "dql_v1" and ( 303 | resp.status != 200 or 304 | resp.dict().get("content", {}).get('status') != 200 or 305 | len(resp.dict().get("content", {}).get('data', [])) == 0): 306 | return False 307 | return True 308 | 309 | 310 | async def change_request_to_skip_function_calling(self, tool_call_request): 311 | tool_call_request.tool_choice = None 312 | tool_call_request.tools = None 313 | for message in tool_call_request.messages: 314 | if message.role == Role.system: 315 | message.content = self.remove_system_prompt(message.content) 316 | -------------------------------------------------------------------------------- /models/api.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Any, Union 3 | from enum import Enum 4 | 5 | class ResponseModel(BaseModel): 6 | status: int = 200 7 | message: str = None 8 | 9 | class DiffbotAPIResponse(ResponseModel): 10 | url: str = "" 11 | type: str = "" 12 | title: str = "" 13 | data: Any = None 14 | dql_time: float = 0.0 15 | diffbotapi_time: float = 0.0 16 | webindex_time: float = 0.0 17 | 18 | class QueryResponse(ResponseModel): 19 | dql_query: str = None 20 | hits: int = 
None 21 | data: Any = None 22 | articles: Any = None 23 | instruction: str = None 24 | page_url: str = None 25 | 26 | class ExtractionResponse(ResponseModel): 27 | data: Any = None 28 | 29 | class SearchResponse(ResponseModel): 30 | query: list[str] = None 31 | search_results: Any = None 32 | instructions: str = None 33 | 34 | class DQLRequest(BaseModel): 35 | size: int = 10 36 | type: str = "query" 37 | query: str = "" 38 | 39 | class DQLResponse(ResponseModel): 40 | query: str = "" 41 | type: str = "" 42 | hits: int = 0 43 | data: Any = None 44 | page_url: str = None 45 | 46 | class WebSearchRequest(BaseModel): 47 | text: str 48 | 49 | class WebSearchResult(BaseModel): 50 | title: str = None 51 | url: str = None 52 | snippet: str = None 53 | 54 | class DiffbotException(Exception): 55 | status_code: int = 200 56 | detail: str 57 | 58 | class DiffbotAPIException(DiffbotException): 59 | page_url: str = None 60 | 61 | class DQLException(DiffbotException): 62 | dql: str = None 63 | 64 | 65 | class JSExecutionRequest(BaseModel): 66 | expression: str = None 67 | 68 | class JSExecutionResponse(ResponseModel): 69 | expression: Any = None 70 | data: Any = None 71 | 72 | class JSExecutionException(DiffbotException): 73 | expression: str = None 74 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | def cleanNullValues(json_obj): 2 | if isinstance(json_obj, dict) and json_obj: 3 | clean = {} 4 | for k, v in json_obj.items(): 5 | nested = cleanNullValues(v) 6 | if nested is not None: 7 | clean[k] = nested 8 | return clean 9 | elif isinstance(json_obj, list) and json_obj: 10 | return [cleanNullValues(v_) for v_ in json_obj if v_ is not None] 11 | elif json_obj is not None: 12 | return json_obj 13 | 14 | return None 15 | 16 | def truncate_long_strings(data, max_string_length=10_000): 17 | """Recursively truncate all strings to max_string_length""" 18 | if isinstance(data, dict): 19 | for k, v in data.items(): 20 | data[k] = truncate_long_strings(v, max_string_length) 21 | elif isinstance(data, list): 22 | for i, v in enumerate(data): 23 | data[i] = truncate_long_strings(v, max_string_length) 24 | elif isinstance(data, str) and len(data)>max_string_length: 25 | return data[:max_string_length] 26 | return data 27 | 28 | 29 | def truncate_long_arrays(data, max_array_length=30): 30 | """Recursively truncate all arrays to max_array_length""" 31 | if isinstance(data, dict): 32 | for k, v in data.items(): 33 | data[k] = truncate_long_arrays(v) 34 | elif isinstance(data, list): 35 | if len(data) > max_array_length: 36 | data = data[:max_array_length] 37 | for i, v in enumerate(data): 38 | data[i] = truncate_long_arrays(v) 39 | return data 40 | 41 | def truncate_data_dfs(data, max_length=100_000): 42 | """Truncate data to fit max_length""" 43 | total_length = 0 44 | if isinstance(data, dict): 45 | for k, v in data.items(): 46 | if total_length > max_length: 47 | data[k] = None 48 | else: 49 | data[k], length = truncate_data_dfs(v, max_length=max_length - total_length) 50 | total_length += length 51 | return data, total_length 52 | 53 | if isinstance(data, list): 54 | for i, v in enumerate(data): 55 | if total_length > max_length: 56 | data[i] = None 57 | else: 58 | data[i], length = truncate_data_dfs(v, max_length=max_length - total_length) 59 | total_length += length 60 | return data, total_length 61 | 62 | if isinstance(data, str): 63 | data_str = data 64 | else: 65 | data_str = 
str(data) 66 | return data, len(data_str) 67 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "diffbot-llm-inference" 3 | version = "0.0.1" 4 | description = "Diffbot LLM Inference Server" 5 | authors = ["Diffbot "] 6 | readme = "README.md" 7 | packages = [{include = "server"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.10,<3.11" 11 | fastapi = "^0.92.0" 12 | uvicorn = "^0.20.0" 13 | openai = "^1.11.1" 14 | python-dotenv = "^0.21.1" 15 | pydantic = "^1.10.5" 16 | tiktoken = "^0.5.2" 17 | numpy = "^1.24.2" 18 | pyyaml = "^6.0" 19 | python-logstash = "^0.4.8" 20 | httpx = "^0.24.1" 21 | aiohttp = "^3.8.6" 22 | setuptools = "^70.3.0" 23 | unidecode = "^1.3.8" 24 | fastapi-utils ="^0.7.0" 25 | diskcache = "^5.6.3" 26 | nodejs="^0.1.1" 27 | 28 | [tool.poetry.scripts] 29 | start = "server.main:start" 30 | 31 | [tool.poetry.group.dev.dependencies] 32 | pytest = "^7.2.1" 33 | pytest-cov = "^4.0.0" 34 | pytest-asyncio = "^0.20.3" 35 | poetry = "^1.8.2" 36 | poetry-plugin-export = "^1.7.1" 37 | 38 | [build-system] 39 | requires = ["poetry-core"] 40 | build-backend = "poetry.core.masonry.api" 41 | 42 | [tool.pytest.ini_options] 43 | pythonpath = [ 44 | "tests", 45 | ] 46 | python_files = [ 47 | "test_*.py" 48 | ] 49 | asyncio_mode="auto" 50 | -------------------------------------------------------------------------------- /ranking/README.md: -------------------------------------------------------------------------------- 1 | ## wp_1gram_top1m.txt.gz 2 | 3 | Top 1M unigrams by frequency from Wikipedia. Original file (8M unigrams, 31MB) from: https://nlp.cs.nyu.edu/wikipedia-data/ 4 | -------------------------------------------------------------------------------- /ranking/rank_bm25.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import math 4 | import numpy as np 5 | 6 | """ 7 | Initial code from: https://github.com/dorianbrown/rank_bm25 8 | We've added support for external frequency counts. 9 | """ 10 | 11 | class BM25: 12 | def __init__(self, external_word_freq: dict[str,int]=None): 13 | self.idf = {} 14 | self.highest_idf = 0 15 | self._calc_idf(word_freq=external_word_freq) 16 | 17 | def _calc_idf(self, word_freq: dict[str,int]): 18 | raise NotImplementedError() 19 | 20 | def get_scores(self, query, corpus, max_term_frequency=0, recalculate_idf=False, explain=False): 21 | raise NotImplementedError() 22 | 23 | class BM25Okapi(BM25): 24 | def __init__(self, 25 | k1=1.5, # controls how document frequency influences score. k1=0 means zero influence. 26 | b=0.1, # controls how document length influences score. b=0 means zero influence. 27 | epsilon=0.0, # assigns an idf value for words that have negative idf 28 | external_word_freq=None): 29 | self.k1 = k1 30 | self.b = b 31 | self.epsilon = epsilon 32 | super().__init__(external_word_freq=external_word_freq) 33 | 34 | def _calc_idf(self, word_freq: dict[str,int]): 35 | """ 36 | Calculates idf based on frequency dict 37 | """ 38 | # collect idf sum to calculate an average idf for epsilon value 39 | idf_sum = 0 40 | # collect words with negative idf to set them a special epsilon value. 
41 | # idf can be negative if word is contained in more than half of documents 42 | negative_idfs = [] 43 | fake_corpus_size = max(word_freq.values()) * 2 44 | for word, freq in word_freq.items(): 45 | idf = math.log(fake_corpus_size - freq + 0.5) - math.log(freq + 0.5) 46 | self.idf[word] = idf 47 | idf_sum += idf 48 | if idf < 0: 49 | negative_idfs.append(word) 50 | self.average_idf = idf_sum / len(self.idf) 51 | 52 | eps = self.epsilon * self.average_idf 53 | for word in negative_idfs: 54 | self.idf[word] = eps 55 | self.highest_idf = max(self.idf.values()) 56 | 57 | def get_scores(self, query, corpus, max_term_frequency=0, recalculate_idf=False, explain: dict = None): 58 | """ 59 | The ATIRE BM25 variant uses an idf function which uses a log(idf) score. To prevent negative idf scores, 60 | this algorithm also adds a floor to the idf value of epsilon. 61 | See [Trotman, A., X. Jia, M. Crane, Towards an Efficient and Effective Search Engine] for more info 62 | """ 63 | corpus_size = 0 64 | doc_freqs = [] 65 | doc_len = [] 66 | 67 | nd = {} # word -> number of documents with word 68 | num_doc = 0 69 | for document in corpus: 70 | doc_len.append(len(document)) 71 | num_doc += len(document) 72 | 73 | frequencies = {} 74 | for word in document: 75 | if word not in frequencies: 76 | frequencies[word] = 0 77 | if max_term_frequency > 0: 78 | frequencies[word] = min(max_term_frequency, frequencies[word] + 1) 79 | else: 80 | frequencies[word] += 1 81 | doc_freqs.append(frequencies) 82 | 83 | for word, freq in frequencies.items(): 84 | nd[word] = nd.get(word, 0) + 1 85 | corpus_size += 1 86 | 87 | avg_doc_len = num_doc / corpus_size 88 | 89 | score = np.zeros(corpus_size) 90 | doc_len = np.array(doc_len) 91 | for q in query: 92 | q_freq = np.array([(doc.get(q) or 0) for doc in doc_freqs]) 93 | score += self.get_idf(q) * (q_freq * (self.k1 + 1) / 94 | (q_freq + self.k1 * (1 - self.b + self.b * doc_len / avg_doc_len))) 95 | if explain is not None: 96 | for idx, doc in enumerate(doc_freqs): 97 | if doc.get(q): 98 | if not explain.get(idx): 99 | explain[idx] = {} 100 | explain[idx][q] = self.get_idf(q) * (q_freq[idx] * (self.k1 + 1) / 101 | (q_freq[idx] + self.k1 * (1 - self.b + self.b * doc_len[idx] / avg_doc_len))) 102 | 103 | return score 104 | 105 | def get_idf(self, q): 106 | return self.idf.get(q, self.highest_idf) -------------------------------------------------------------------------------- /ranking/wp_1gram_top1m.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/ranking/wp_1gram_top1m.txt.gz -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pytest -k 'not test_evaluation.py' 4 | -------------------------------------------------------------------------------- /server/log.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | import logging 4 | from logstash.formatter import LogstashFormatterVersion1 5 | 6 | message_type="diffbot-llm" 7 | 8 | class LogStashCustomFormatter(LogstashFormatterVersion1): 9 | def format(self, record): 10 | record.path = None # clear the path information 11 | return super().format(record) 12 | 13 | class ConsoleCustomFormatter(LogStashCustomFormatter): 14 | def format(self, record): 15 | formatted_record = 
super().format(record).decode('utf-8') 16 | if record.exc_info: 17 | exception_string = ''.join(traceback.format_exception(*record.exc_info)) 18 | print(exception_string, file=sys.stderr) 19 | 20 | return formatted_record 21 | 22 | console_handler = logging.StreamHandler(sys.stdout) 23 | console_handler.setFormatter(ConsoleCustomFormatter(message_type=message_type)) 24 | 25 | handler = console_handler 26 | loggers = {} 27 | def get_logstash_logger(logger_name:str = None): 28 | if logger_name is None or logger_name == '': 29 | logger_name = 'diffbot-llm' 30 | 31 | if logger_name not in loggers: 32 | logger = logging.getLogger(logger_name) 33 | logger.setLevel(logging.INFO) 34 | logger.addHandler(handler) 35 | loggers[logger_name] = logger 36 | 37 | return loggers[logger_name] -------------------------------------------------------------------------------- /server/main.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import uvicorn 4 | from fastapi import FastAPI, Request, Response 5 | from fastapi.exception_handlers import RequestValidationError 6 | from fastapi.exceptions import ValidationError 7 | from starlette.responses import JSONResponse 8 | 9 | from server.rag_router import rag_router 10 | 11 | from config import get_config 12 | 13 | config = get_config() 14 | server_url = config.get_server_url() 15 | 16 | app = FastAPI( 17 | openapi_url=None, 18 | docs_url=None, 19 | redoc_url=None, 20 | servers=[{"url": server_url}] 21 | ) 22 | 23 | # handle CORS preflight requests 24 | @app.options('/{rest_of_path:path}') 25 | async def preflight_handler(request: Request, rest_of_path: str) -> Response: 26 | response = Response() 27 | response.headers['Access-Control-Allow-Origin'] = '*' 28 | response.headers['Access-Control-Allow-Methods'] = '*' 29 | response.headers['Access-Control-Allow-Headers'] = '*' 30 | return response 31 | 32 | # set CORS headers 33 | @app.middleware("http") 34 | async def add_CORS_header(request: Request, call_next): 35 | response = await call_next(request) 36 | response.headers['Access-Control-Allow-Origin'] = '*' 37 | response.headers['Access-Control-Allow-Methods'] = '*' 38 | response.headers['Access-Control-Allow-Headers'] = '*' 39 | return response 40 | 41 | app.include_router(rag_router) 42 | 43 | from server.log import get_logstash_logger 44 | logger = get_logstash_logger("diffbot_llm_api") 45 | 46 | @app.exception_handler(RequestValidationError) 47 | async def validation_exception_handler(request: Request, exc: RequestValidationError): 48 | logger.error(f"Validation error: {exc.errors()}") 49 | return JSONResponse( 50 | status_code=422, 51 | content={"detail": exc.errors()}, 52 | ) 53 | 54 | @app.exception_handler(ValidationError) 55 | async def pydantic_validation_exception_handler(request: Request, exc: ValidationError): 56 | logger.error(f"Pydantic validation error: {exc.errors()}") 57 | return JSONResponse( 58 | status_code=422, 59 | content={"detail": exc.errors()}, 60 | ) 61 | 62 | def start(): 63 | uvicorn.run("server.main:app", host="0.0.0.0", port=3333, reload=True) -------------------------------------------------------------------------------- /server/rag_router.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | from fastapi import APIRouter, HTTPException, Body, Depends 5 | from fastapi.responses import StreamingResponse 6 | from fastapi.security import OAuth2PasswordBearer 7 | 8 | from pydantic import BaseModel, 
ValidationError
9 | from models.utils import cleanNullValues
10 | from server.log import get_logstash_logger
11 | from services.kg_rag_service import KgRagService, get_kg_rag_service
12 | from llm.api_models import ChatCompletionStreamResponse, CreateChatCompletionRequest, \
13 | ErrorResponse, LLMException, RequestGlobals
14 | from llm.llms import ModelID
15 |
16 | rag_router = APIRouter(
17 | prefix="/rag",
18 | tags=['rag']
19 | )
20 |
21 | logger = get_logstash_logger("diffbot_llm_api")
22 |
23 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
24 |
25 | async def authentication(token: str = Depends(oauth2_scheme)):
26 | return token
27 |
28 | # OpenAI api compatible: https://platform.openai.com/docs/api-reference/chat/create
29 | @rag_router.post("/v1/chat/completions")
30 | async def chat_completions(
31 | diffbot_token: str = Depends(authentication),
32 | request: CreateChatCompletionRequest = Body(...),
33 | kg_rag_service: KgRagService = Depends(get_kg_rag_service)
34 | ):
35 |
36 | try:
37 | request_globals = RequestGlobals()
38 | request_globals.diffbot_token = diffbot_token
39 | request_globals.timings["start_time"] = time.time()
40 |
41 | # filter out empty messages
42 | request.messages = [msg for msg in request.messages if msg.content]
43 |
44 | # support the new text/image content schema (kind of; we ignore images for now)
45 | for msg in request.messages:
46 | if type(msg.content) is list and 'text' in msg.content[0]:
47 | msg.content = msg.content[0]['text']
48 |
49 | llm_request = cleanNullValues(request.copy(deep=True).dict())
50 | llm_request = CreateChatCompletionRequest(**llm_request)
51 |
52 | result = await kg_rag_service.chat_completions(
53 | request=llm_request,
54 | request_globals=request_globals,
55 | )
56 |
57 | if request.stream:
58 | return StreamingResponse(
59 | stream_generator_with_usage_logging(result, diffbot_token, request_globals),
60 | media_type="text/event-stream"
61 | )
62 | else:
63 | return result
64 | except ValidationError as e:
65 | raise HTTPException(status_code=422, detail=e.errors())
66 | except LLMException as e:
67 | logger.info(f"LLM exception: {e}", extra={"token": diffbot_token}, exc_info=True)
68 | return ErrorResponse(error=e.error)
69 | except Exception as e:
70 | logger.error(f"Error in chat completions. 
Exception: {e}", extra={"token": diffbot_token}, exc_info=True) 71 | return ErrorResponse(error=e) 72 | 73 | async def stream_generator_with_usage_logging(result, diffbot_token, request_globals: RequestGlobals): 74 | async for item in result: 75 | if isinstance(item, BaseModel): 76 | yield f'data: {json.dumps(item.dict())}\n\n' 77 | 78 | if ("time_to_stream" not in request_globals.timings and isinstance(item, ChatCompletionStreamResponse) 79 | and item.choices and len(item.choices) > 0 and item.choices[0].delta 80 | and item.choices[0].delta.content and not item.choices[0].delta.content.isspace()): 81 | request_globals.timings["time_to_stream"] = (time.time() - request_globals.timings["start_time"]) * 1000 82 | 83 | @rag_router.get("/v1/models") 84 | def supported_models(): 85 | data = [ 86 | { 87 | "id": ModelID.DIFFBOT_SMALL, 88 | "object": "model", 89 | "owned_by": "diffbot" 90 | }, 91 | { 92 | "id": ModelID.DIFFBOT_SMALL_XL, 93 | "object": "model", 94 | "owned_by": "diffbot" 95 | } 96 | ] 97 | body = { 98 | "object": "list", 99 | "data": data 100 | } 101 | return body 102 | 103 | -------------------------------------------------------------------------------- /services/execute_js.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | import asyncio 3 | import json 4 | import os 5 | import tempfile 6 | from typing import Dict, Any 7 | 8 | from models.api import JSExecutionResponse 9 | 10 | 11 | class JSExecutionService(ABC): 12 | 13 | def __init__(self) -> None: 14 | super().__init__() 15 | 16 | async def execute_js_code(self, js_code: str, timeout: int) -> Dict[str, Any]: 17 | # Create temporary file for the JS code 18 | with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f: 19 | # Disable potentially harmful functions by replacing with empty objects 20 | safe_environment = """ 21 | var require = {}; 22 | var process = {}; 23 | var global = {}; 24 | """ 25 | 26 | # Wrap the code to capture console.log output and the final result 27 | wrapped_code = f""" 28 | try {{ 29 | var logs = []; 30 | _consolelog = console.log; 31 | console.log = function(...args) {{logs.push(args.map(String).join(" "));}}; 32 | var result = eval({repr(js_code)}); 33 | _consolelog(JSON.stringify({{logs: logs, result: result}})); 34 | }} catch (error) {{ 35 | _consolelog(JSON.stringify({{error: error.toString()}})); 36 | }} 37 | """ 38 | 39 | # Write the complete code to the temp file 40 | f.write(f""" 41 | {safe_environment} 42 | {wrapped_code} 43 | """) 44 | temp_path = f.name 45 | 46 | try: 47 | # Create subprocess to run Node.js 48 | process = await asyncio.create_subprocess_exec( 49 | 'node', temp_path, 50 | stdout=asyncio.subprocess.PIPE, 51 | stderr=asyncio.subprocess.PIPE 52 | ) 53 | 54 | try: 55 | # Wait for the process with timeout 56 | stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout) 57 | except asyncio.TimeoutError: 58 | process.kill() 59 | raise Exception("JavaScript execution timed out!") 60 | 61 | # Clean up temp file 62 | os.unlink(temp_path) 63 | 64 | if stderr: 65 | stderr_text = stderr.decode().strip() 66 | if stderr_text: 67 | raise Exception(f"JavaScript error: {stderr_text}") 68 | 69 | # Parse the output 70 | output = stdout.decode().strip() 71 | if not output: 72 | raise Exception("No output from JavaScript code") 73 | 74 | try: 75 | result = json.loads(output) 76 | if "error" in result: 77 | raise Exception(result["error"]) 78 | return {"logs": result.get("logs", []), "result": 
result.get("result")} 79 | except json.JSONDecodeError: 80 | raise Exception(f"Invalid JSON output: {output}") 81 | 82 | except Exception as e: 83 | # Clean up temp file in case of error 84 | if os.path.exists(temp_path): 85 | os.unlink(temp_path) 86 | raise e 87 | 88 | def clean_code_block(self, code: str) -> str: 89 | if code.startswith('```javascript') and code.endswith('```'): 90 | code = code[len('```javascript'): -3].strip() 91 | elif code.startswith('```') and code.endswith('```'): 92 | code = code[3:-3].strip() 93 | if code.startswith(''): 94 | code = code[len('')].strip() 95 | return code 96 | 97 | async def execute_js(self, expressions: str = None, timeout: int = 2) -> JSExecutionResponse: 98 | try: 99 | if not expressions: 100 | raise Exception("No JavaScript code provided") 101 | 102 | result = await self.execute_js_code(self.clean_code_block(expressions), timeout) 103 | 104 | if not result: 105 | raise Exception("No result returned from JavaScript code") 106 | 107 | return JSExecutionResponse(expression=expressions, status=200, message="Success", data=result) 108 | 109 | except Exception as e: 110 | return JSExecutionResponse(expression=expressions, status=500, message=str(e)) 111 | 112 | 113 | js_execution_service = None 114 | 115 | 116 | def get_js_execution_service() -> JSExecutionService: 117 | global js_execution_service 118 | if js_execution_service is None: 119 | js_execution_service = JSExecutionService() 120 | return js_execution_service 121 | -------------------------------------------------------------------------------- /services/kg_rag_service.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, AsyncIterator 2 | 3 | from llm.diffbot_tool_call_llm import DiffbotToolCallLLM, get_diffbot_tool_call_llm 4 | from llm.api_models import CompletionUsage, CreateChatCompletionRequest, ChatCompletionResponse, \ 5 | ChatCompletionStreamResponse, LLMException, Error, ChatCompletionRequestMessage, RequestGlobals 6 | from llm.llms import ModelID, LLMS, LLM, Role 7 | from llm.tool_call_llm import ToolCallLLM 8 | from llm.plugin import Plugin, get_plugin 9 | 10 | class KgRagService: 11 | """ 12 | Service to rag against Diffbot Knowledge Graph. 
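    The /rag/v1 endpoints that expose this service (see server/rag_router.py) are
    OpenAI-compatible, so it can be exercised with the standard OpenAI Python client.
    A minimal sketch, assuming the API server from server/main.py is listening on
    localhost:3333 and that DIFFBOT_TOKEN holds a valid Diffbot token:

        from openai import OpenAI

        client = OpenAI(base_url="http://localhost:3333/rag/v1", api_key=DIFFBOT_TOKEN)
        response = client.chat.completions.create(
            model="diffbot-small",  # or "diffbot-small-xl", see llm/llms.py
            messages=[{"role": "user", "content": "What does Diffbot do?"}],
        )
        print(response.choices[0].message.content)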
13 | """ 14 | def __init__(self) -> None: 15 | self.diffbot_models = {ModelID.DIFFBOT_SMALL, ModelID.DIFFBOT_SMALL_XL} 16 | 17 | @classmethod 18 | def get_llm(cls, model_id: ModelID) -> LLM: 19 | if DiffbotToolCallLLM.is_supported(model_id): 20 | return get_diffbot_tool_call_llm() 21 | else: 22 | raise LLMException(error=Error(code=429, message=f"unsupported model: {model_id}", param=None, type=None)) 23 | 24 | @classmethod 25 | def get_tool_call_llm(cls, model_id: ModelID) -> ToolCallLLM: 26 | if DiffbotToolCallLLM.is_supported(model_id): 27 | return get_diffbot_tool_call_llm() 28 | else: 29 | raise LLMException(error=Error(code=429, message=f"unsupported tool call model: {model_id}", param=None, type=None)) 30 | 31 | async def chat_completions(self, 32 | request: CreateChatCompletionRequest, 33 | plugin: Optional[Plugin] = get_plugin(), 34 | request_globals: RequestGlobals = RequestGlobals(), 35 | ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]: 36 | model_id = ModelID.get_model_id(request.model) 37 | request.model = LLMS[model_id].model 38 | llm = self.get_llm(model_id) 39 | 40 | if (model_id not in self.diffbot_models) or (request.tool_choice == "none"): 41 | # only diffbot models support tool calling 42 | print(f"calling {request.model} without tool calling.") 43 | if llm: 44 | request.tool_choice = None 45 | return await llm.chat_completion(request=request, request_globals=request_globals) 46 | else: 47 | raise LLMException(error=Error(code=429, message=f"llm not initialized: {request.model}", param=None, type=None)) 48 | else: 49 | # Using a "diffbot-" model 50 | 51 | print(f"calling {request.model} with tool calling") 52 | 53 | if request.tool_choice is None: 54 | request.tool_choice = "auto" 55 | 56 | request.tools = plugin.get_tools() #TODO: call to llm.diffbot.com and cache? 
57 | tool_call_llm = self.get_tool_call_llm(model_id) 58 | llm_result = await tool_call_llm.process_tool_calls(request, plugin, request_globals) 59 | return llm_result 60 | 61 | async def combine_stream(non_stream_result, stream_result): 62 | if non_stream_result: 63 | yield non_stream_result 64 | async for item in stream_result: 65 | yield item 66 | 67 | async def skip_diffbot_responses(llm_result): 68 | async for item in llm_result: 69 | if isinstance(item, ChatCompletionStreamResponse): 70 | yield item 71 | 72 | kg_rag_service = None 73 | def get_kg_rag_service() -> KgRagService: 74 | global kg_rag_service 75 | if kg_rag_service is None: 76 | kg_rag_service = KgRagService() 77 | 78 | return kg_rag_service -------------------------------------------------------------------------------- /start_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | uvicorn server.main:app --host 0.0.0.0 --port 8001 --workers 4 4 | -------------------------------------------------------------------------------- /static/babyshark.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/babyshark.webp -------------------------------------------------------------------------------- /static/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/demo.png -------------------------------------------------------------------------------- /static/extract.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/extract.webp -------------------------------------------------------------------------------- /static/faa.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/faa.webp -------------------------------------------------------------------------------- /static/freshqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/freshqa.png -------------------------------------------------------------------------------- /static/math.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/math.webp -------------------------------------------------------------------------------- /static/newjersey.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/newjersey.webp -------------------------------------------------------------------------------- /static/simpleqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/simpleqa.png -------------------------------------------------------------------------------- /static/strawberry.webp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/strawberry.webp -------------------------------------------------------------------------------- /static/weather.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diffbot/diffbot-llm-inference/5a72b8d86335b5c0192395e5b0afe51f497016b9/static/weather.webp -------------------------------------------------------------------------------- /supervisord.conf: -------------------------------------------------------------------------------- 1 | [supervisord] 2 | nodaemon=true 3 | 4 | [program:diffbot-llm-inference] 5 | command=/bin/bash -c "cd /code && poetry run sh ./start_server.sh 2>&1 | tee -a /var/log/diffbot.log | cat"; 6 | autostart=true 7 | autorestart=true 8 | stdout_logfile=/dev/stdout 9 | stderr_logfile=/dev/stderr 10 | stdout_logfile_maxbytes=0 11 | stderr_logfile_maxbytes=0 12 | 13 | [program:vllm] 14 | command=/bin/bash -c "python3 -m vllm.entrypoints.openai.api_server ${VLLM_OPTIONS} 2>&1 | tee -a /var/log/vllm.log | cat" 15 | autostart=true 16 | autorestart=true 17 | stdout_logfile=/dev/stdout 18 | stderr_logfile=/dev/stderr 19 | stdout_logfile_maxbytes=0 20 | stderr_logfile_maxbytes=0 -------------------------------------------------------------------------------- /system_prompt.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant with access to the following functions. Use them if required - 2 | namespace Diffbot { 3 | // Extract the content from the given URLs. Only call this endpoint if the user mentioned a URL. 4 | type extract_v1 = (_: { 5 | // URLs to extract, up to 5 6 | page_url: string[], 7 | }) => any; 8 | // Query the Diffbot Knowledge Graph for an entity or set of entities that match a set of criteria using the Diffbot Query Language syntax. 9 | type dql_v1 = (_: { 10 | // Diffbot Query Language query 11 | dql_query: string, 12 | }) => any; 13 | // Search the web for information that could help answer the user's question. 14 | type web_search_v1 = (_: { 15 | // List of Google advanced search strings (can include phrases, booleans, site:, before:, after:, filetype:, etc) 16 | text: string[], 17 | // Number of results to return (default 5) 18 | num?: number, 19 | // Page number of results to return (default 1) 20 | page?: number, 21 | }) => any; 22 | } // namespace Diffbot 23 | -------------------------------------------------------------------------------- /system_prompt_with_js.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant with access to the following functions. Use them if required - 2 | namespace Diffbot { 3 | // Extract the content from the given URLs. Only call this endpoint if the user mentioned a URL. 4 | type extract_v1 = (_: { 5 | // URLs to extract, up to 5 6 | page_url: string[], 7 | }) => any; 8 | // Query the Diffbot Knowledge Graph for an entity or set of entities that match a set of criteria using the Diffbot Query Language syntax. 9 | type dql_v1 = (_: { 10 | // Diffbot Query Language query 11 | dql_query: string, 12 | }) => any; 13 | // Search the web for information that could help answer the user's question. 
14 | type web_search_v1 = (_: { 15 | // List of Google advanced search strings (can include phrases, booleans, site:, before:, after:, filetype:, etc) 16 | text: string[], 17 | // Number of results to return (default 5) 18 | num?: number, 19 | // Page number of results to return (default 1) 20 | page?: number, 21 | }) => any; 22 | // Execute JavaScript expressions and get accurate results that could help answer the user's question. 23 | type execute_js_v1 = (_: { 24 | // JavaScript expressions to execute separated by newlines 25 | expressions: string, 26 | }) => any; 27 | } // namespace Diffbot 28 | -------------------------------------------------------------------------------- /tests/test_end2end.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import time 3 | import threading 4 | import uvicorn 5 | import pytest 6 | from openai import OpenAI 7 | import os 8 | import sys 9 | 10 | DIFFBOT_TOKEN = os.environ.get("DIFFBOT_TOKEN", "") 11 | 12 | if not DIFFBOT_TOKEN: 13 | print("Set DIFFBOT_TOKEN environment variable before running tests.") 14 | sys.exit(1) 15 | 16 | class Server(uvicorn.Server): 17 | def install_signal_handlers(self): 18 | pass 19 | 20 | @contextlib.contextmanager 21 | def run_in_thread(self): 22 | thread = threading.Thread(target=self.run) 23 | thread.start() 24 | try: 25 | while not self.started: 26 | time.sleep(1e-3) 27 | yield 28 | finally: 29 | self.should_exit = True 30 | thread.join() 31 | 32 | @pytest.fixture(scope="session") 33 | def server(): 34 | config = uvicorn.Config("server.main:app", host="127.0.0.1", port=3334, log_level="info") 35 | server = Server(config=config) 36 | with server.run_in_thread(): 37 | yield 38 | 39 | @pytest.fixture(scope="session") 40 | def endpoint(): 41 | return "http://localhost:3334" 42 | 43 | def test_completion_with_tool_call_small(server, endpoint): 44 | client = OpenAI(api_key=DIFFBOT_TOKEN, base_url=endpoint + "/rag/v1") 45 | completion = client.chat.completions.create( 46 | model="diffbot-small", 47 | temperature=0, 48 | messages=[ 49 | {"role": "system", "content": "You are a helpful assistant."}, 50 | { 51 | "role": "user", 52 | "content": "Who is Nike's new CEO?" 53 | } 54 | ] 55 | ) 56 | print (completion) 57 | assert 'hill' in completion.choices[0].message.content.lower(), completion.choices[0].message.content 58 | 59 | def test_completion_with_tool_call_small_stream(server, endpoint): 60 | client = OpenAI(api_key=DIFFBOT_TOKEN, base_url=endpoint + "/rag/v1") 61 | completion = client.chat.completions.create( 62 | model="diffbot-small", 63 | temperature=0, 64 | stream=True, 65 | messages=[ 66 | {"role": "system", "content": "You are a helpful assistant."}, 67 | { 68 | "role": "user", 69 | "content": "Who is Nike's new CEO?" 70 | } 71 | ] 72 | ) 73 | response = "" 74 | for chunk in completion: 75 | if chunk.choices and len(chunk.choices) == 1 and chunk.choices[0].delta.content: 76 | response += chunk.choices[0].delta.content 77 | print (response) 78 | assert 'hill' in response.lower(), response 79 | 80 | def test_javascript(server, endpoint): 81 | client = OpenAI(api_key=DIFFBOT_TOKEN, base_url=endpoint + "/rag/v1") 82 | completion = client.chat.completions.create( 83 | model="diffbot-small", 84 | temperature=0, 85 | messages=[ 86 | {"role": "system", "content": "You are a helpful assistant"}, 87 | { 88 | "role": "user", 89 | "content": 'What is 3245134 * 3476?' 
90 | }, { 91 | "role": "tool", 92 | "content": ' {"name": "execute_js_v1", "arguments": {"expressions": "result = 3245134 * 3476; console.log(result)"}}' 93 | } 94 | ] 95 | ) 96 | print(completion) 97 | answer = completion.choices[0].message.content.lower() 98 | answer = answer.replace(",","") 99 | assert '11280085784' in answer, completion.choices[0].message.content 100 | --------------------------------------------------------------------------------