├── src └── mlx_textgen │ ├── __init__.py │ ├── log_utils.py │ ├── rotating_kv_cache_patch.py │ ├── utils.py │ ├── cli.py │ ├── generation_utils.py │ ├── server.py │ ├── tokenizer_utils.py │ ├── chat_presets.py │ ├── chat_utils.py │ ├── sampling_utils.py │ ├── engine.py │ └── cache_utils.py ├── .gitignore ├── LICENSE.md ├── pyproject.toml └── README.md /src/mlx_textgen/__init__.py: -------------------------------------------------------------------------------- 1 | from .rotating_kv_cache_patch import * 2 | 3 | __version__ = '0.2.1' -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | bin/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # Installer logs 27 | pip-log.txt 28 | pip-delete-this-directory.txt 29 | 30 | # Unit test / coverage reports 31 | .tox/ 32 | .coverage 33 | .cache 34 | nosetests.xml 35 | coverage.xml 36 | 37 | # Translations 38 | *.mo 39 | 40 | # Mr Developer 41 | .mr.developer.cfg 42 | .project 43 | .pydevproject 44 | 45 | # Rope 46 | .ropeproject 47 | 48 | # Django stuff: 49 | *.log 50 | *.pot 51 | 52 | # Sphinx documentation 53 | docs/_build/ -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Pok Hin Tam 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mlx-textgen" 3 | description = "An OpenAI-compatible API LLM engine with smart prompt caching, batch processing, structured output with guided decoding, and function calling for all models using MLX." 
4 | dynamic = ["version"]
5 | 
6 | authors = [
7 |     {name = "Nathan Tam", email = "nathan1295@gmail.com"},
8 | ]
9 | requires-python = ">=3.9"
10 | readme = "README.md"
11 | license = {text = "MIT"}
12 | classifiers = [
13 |     "Programming Language :: Python :: 3",
14 |     "License :: OSI Approved :: MIT License",
15 |     "Operating System :: MacOS :: MacOS X",
16 | ]
17 | dependencies = [
18 |     "mlx-lm>=0.24.0",
19 |     "mlx-vlm>=0.1.26",
20 |     "outlines>=0.2.3",
21 |     "fastapi",
22 |     "uvicorn"
23 | ]
24 | 
25 | [project.urls]
26 | Homepage = "https://github.com/nath1295/MLX-Textgen"
27 | 
28 | [project.scripts]
29 | mlx_textgen = "mlx_textgen.cli:main"
30 | 
31 | [build-system]
32 | requires = ["setuptools>=61.0", "wheel"]
33 | build-backend = "setuptools.build_meta"
34 | 
35 | [tool.setuptools]
36 | py-modules = []
37 | 
38 | [tool.setuptools.dynamic]
39 | version = {attr = "mlx_textgen.__version__"}
40 | 
41 | [tool.setuptools.packages.find]
42 | where = ["src"]
--------------------------------------------------------------------------------
/src/mlx_textgen/log_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | import logging.handlers
4 | 
5 | def get_logger(name: str) -> logging.Logger:
6 |     from .utils import get_package_cache_dir
7 |     import os
8 | 
9 |     logging.basicConfig(format='[(%(levelname)s) %(asctime)s]: %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')
10 |     logger = logging.getLogger(name)
11 |     logger.setLevel(logging.INFO)
12 | 
13 |     if not logger.handlers:
14 |         file_dir = os.path.join(get_package_cache_dir(), 'logs', 'api.log')
15 |         if not os.path.exists(os.path.dirname(file_dir)):
16 |             os.makedirs(os.path.dirname(file_dir))
17 |         handler = logging.handlers.TimedRotatingFileHandler(
18 |             filename=file_dir,
19 |             when='H',
20 |             interval=1,
21 |             backupCount=96
22 |         )
23 |         # Create a formatter to define the log entry format
24 |         formatter = logging.Formatter('[%(levelname)s] %(asctime)s - %(name)s - %(message)s')
25 |         handler.setFormatter(formatter)
26 | 
27 |         console_handler = logging.StreamHandler(sys.stdout)  # Use sys.stdout for standard output
28 |         console_handler.setFormatter(formatter)
29 | 
30 |         # Add the custom handlers to the logger
31 |         logger.addHandler(handler)
32 |         logger.addHandler(console_handler)
33 |     return logger
--------------------------------------------------------------------------------
/src/mlx_textgen/rotating_kv_cache_patch.py:
--------------------------------------------------------------------------------
1 | import mlx.core as mx
2 | import mlx.nn as nn
3 | from mlx_lm.models.cache import RotatingKVCache, KVCache  # Import necessary classes
4 | from mlx.utils import tree_map
5 | 
6 | 
7 | def new_update_and_fetch(self, k, v):
8 |     KVCache.update_and_fetch(self, k, v)
9 |     self._idx = self.keys.shape[2]
10 |     return self.state
11 | 
12 | 
13 | def new_state_getter(self):
14 |     if self.offset <= self.max_size:
15 |         return self.keys[..., :self.offset, :], self.values[..., :self.offset, :]
16 |     elif self.keep:
17 |         keys = mx.concatenate([self.keys[..., :self.keep, :], self.keys[..., (self.offset - (self.max_size - self.keep)):self.offset, :]], axis=2)  # concatenate takes a list of arrays
18 |         values = mx.concatenate([self.values[..., :self.keep, :], self.values[..., (self.offset - (self.max_size - self.keep)):self.offset, :]], axis=2)
19 |         return keys, values
20 |     else:
21 |         return self.keys[..., (self.offset - self.max_size):self.offset, :], self.values[..., (self.offset - self.max_size):self.offset, :]
22 | 23 | new_state_setter = KVCache.state.fset 24 | 25 | new_is_trimmable = KVCache.is_trimmable 26 | 27 | new_trim = KVCache.trim 28 | 29 | 30 | 31 | 32 | # --- Apply the Patches --- 33 | RotatingKVCache.update_and_fetch = new_update_and_fetch 34 | RotatingKVCache.state = property(new_state_getter, new_state_setter) 35 | RotatingKVCache.trim = new_trim 36 | RotatingKVCache.is_trimmable = new_is_trimmable 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/mlx_textgen/utils.py: -------------------------------------------------------------------------------- 1 | PACKAGE_NAME = 'mlx_textgen' 2 | 3 | def env_name() -> str: 4 | """Get the current python environment name. 5 | 6 | Returns: 7 | str: Current python environment name. 8 | """ 9 | import os 10 | import sys 11 | base = os.path.basename(sys.prefix) 12 | if base.lower() == 'anaconda3': 13 | return 'base' 14 | elif 'python3' in base.lower(): 15 | return 'base' 16 | else: 17 | return base 18 | 19 | def get_config_file_dir() -> str: 20 | """Get the directory of the package configuration file. 21 | 22 | Returns: 23 | str: Directory of the package configuration file. 24 | """ 25 | import os 26 | return os.path.join(os.path.expanduser('~'), '.config', PACKAGE_NAME, env_name(), 'config.json') 27 | 28 | def get_package_cache_dir() -> str: 29 | """Get the directory where mlx converted models and prompt cache files are stored. 30 | 31 | Returns: 32 | str: Directory where mlx converted models and prompt cache files are stored. 33 | """ 34 | import os 35 | import json 36 | default_config = dict( 37 | cache_dir = os.path.join(os.path.expanduser('~'), '.cache', PACKAGE_NAME) 38 | ) 39 | config_dir = get_config_file_dir() 40 | if not os.path.exists(config_dir): 41 | os.makedirs(os.path.dirname(config_dir), exist_ok=True) 42 | with open(config_dir, 'w') as f: 43 | json.dump(default_config, f, indent=4) 44 | config = default_config 45 | else: 46 | with open(config_dir, 'r') as f: 47 | config = json.load(f) 48 | os.makedirs(config['cache_dir'], exist_ok=True) 49 | return config['cache_dir'] 50 | 51 | def get_prompt_cache_dir() -> str: 52 | """Get the directory of prompt cache files. 53 | 54 | Returns: 55 | str: Directory of prompt cache files. 56 | """ 57 | import os 58 | prompt_cache_dir = os.path.join(get_package_cache_dir(), 'prompt_cache') 59 | os.makedirs(prompt_cache_dir, exist_ok=True) 60 | return prompt_cache_dir 61 | 62 | def set_cache_dir(cache_dir: str) -> None: 63 | """Set the directory where mlx converted models and prompt cache files are stored. 64 | 65 | Args: 66 | cache_dir (str): The new directory for mlx converted models and prompt cache files. 67 | """ 68 | if cache_dir.strip(): 69 | import os 70 | import json 71 | get_package_cache_dir() 72 | config = dict(cache_dir=os.path.abspath(cache_dir)) 73 | with open(get_config_file_dir(), 'w') as f: 74 | json.dump(config, f, indent=4) 75 | 76 | else: 77 | import warnings 78 | warnings.warn("`cache_dir` cannot be None or an empty string. 
`cache_dir` not set.")
79 | 
80 | 
81 | 
82 | 
83 | 
84 | 
--------------------------------------------------------------------------------
/src/mlx_textgen/cli.py:
--------------------------------------------------------------------------------
1 | def main():
2 |     import argparse
3 |     from typing import Optional, List, Union
4 |     from .utils import PACKAGE_NAME, set_cache_dir, get_package_cache_dir
5 |     from .server import serve_api
6 | 
7 |     parser = argparse.ArgumentParser(prog=PACKAGE_NAME, description=f'Welcome to {PACKAGE_NAME} CLI')
8 |     subparsers = parser.add_subparsers(dest='command')
9 | 
10 |     # Subcommand for setting the default config
11 |     def set_config_cli() -> None:
12 |         current = get_package_cache_dir()
13 |         new = input(f'Cache directory for {PACKAGE_NAME} [{current}]: ')
14 |         if new.strip():
15 |             set_cache_dir(new.strip())
16 | 
17 | 
18 |     parser_set_config = subparsers.add_parser('cachedir', help='Set the default cache directory.')
19 |     parser_set_config.set_defaults(func=set_config_cli)
20 | 
21 |     # Subcommand for serving the OpenAI API endpoint
22 |     def config_and_serve(
23 |         config_file: Optional[str] = None,
24 |         model_path: Optional[str] = None,
25 |         tokenizer_path: Optional[str] = None,
26 |         revision: Optional[str] = None,
27 |         model_name: Optional[str] = None,
28 |         host: str = '127.0.0.1',
29 |         port: int = 5001,
30 |         api_key: Optional[str] = None,
31 |         min_tokens: int = 20,
32 |         max_reprocess_tokens: int = 250,
33 |         replace_threshold: float = 0.95,
34 |         max_capacity: int = 50,
35 |         use_reasoning_content: bool = False
36 |     ):
37 |         from .log_utils import get_logger
38 |         from .engine import ModelConfig
39 |         import yaml
40 | 
41 |         if model_path:
42 |             mconfig = dict(
43 |                 model_id_or_path=model_path,
44 |                 tokenizer_repo_or_path=tokenizer_path,
45 |                 model_kwargs=dict(revision=revision),
46 |                 tokenizer_kwargs=dict(revision=revision),
47 |                 model_name=model_name
48 |             )
49 |             config = dict(
50 |                 model_configs = [mconfig],
51 |                 host=host,
52 |                 port=port,
53 |                 api_keys=[api_key] if api_key else None,
54 |                 min_tokens=min_tokens,
55 |                 max_reprocess_tokens=max_reprocess_tokens,
56 |                 replace_threshold=replace_threshold,
57 |                 max_capacity=max_capacity,
58 |                 use_reasoning_content=use_reasoning_content
59 |             )
60 |         elif config_file:
61 |             with open(config_file, 'r') as f:
62 |                 config = yaml.load(f, yaml.SafeLoader)
63 | 
64 |         else:
65 |             raise ValueError('Must provide either "config_file" or "model_path".')
66 | 
67 | 
68 |         config['model_configs'] = [ModelConfig.model_validate(mc) for mc in config['model_configs']]
69 |         config['logger'] = get_logger("MLX Textgen")
70 |         serve_api(**config)
71 | 
72 |     parser_serve = subparsers.add_parser('serve', help='Start the MLX Textgen OpenAI-compatible API server.')
73 |     parser_serve.add_argument('-m', '--model-path', type=str,
74 |                               default=None, help='Path to the model or the HuggingFace repository name if only one model should be served.')
75 |     parser_serve.add_argument('--tokenizer-path', type=str,
76 |                               default=None, help='Path to the tokenizer or the HuggingFace repository name if only one model should be served. If None is given, it will be the model_path. Defaults to None.')
77 |     parser_serve.add_argument('--revision', type=str,
78 |                               default=None, help='Revision of the repository if an HF repository is given. Defaults to None.')
79 |     parser_serve.add_argument('--model-name', type=str,
80 |                               default=None, help='Model name that appears in the API endpoint. If None is given, it will be created automatically from the model path. Defaults to None.')
81 |     parser_serve.add_argument('-cf', '--config-file', type=str,
82 |                               default=None,
83 |                               help='Path to the config file that stores the configs of all models to be served. If this is passed, all other arguments will be ignored.')
84 |     parser_serve.add_argument('--api-key', type=str, default=None, help='API key to access the endpoints. Defaults to None.')
85 |     parser_serve.add_argument('-p', '--port', type=int,
86 |                               default=5001, help='Port to serve the API endpoints on.')
87 |     parser_serve.add_argument('--host', type=str,
88 |                               default='127.0.0.1', help='Host to bind the server to. Defaults to "127.0.0.1".')
89 |     parser_serve.add_argument('--min-tokens', type=int, default=20, help='Minimum number of tokens in the cache to be considered for saving.')
90 |     parser_serve.add_argument('--max-reprocess-tokens', type=int, default=250,
91 |                               help='Maximum number of tokens to be discarded if a cache is regarded as worth saving, but another similar cache exists.')
92 |     parser_serve.add_argument('--replace-threshold', type=float, default=0.95,
93 |                               help='Percentage threshold to consider two caches similar in terms of token prefix. Affected by "max_reprocess_tokens" for longer prompts.')
94 |     parser_serve.add_argument('--max-capacity', type=int, default=50, help='Maximum number of caches per model to save. Older ones will be discarded.')
95 |     parser_serve.add_argument('--use-reasoning-content', action='store_true', default=False, help='Whether to put the thoughts of reasoning models in reasoning_content instead of content in the /v1/chat/completions endpoint.')  # type=bool would treat any non-empty string as True, so expose this as a flag
96 |     parser_serve.set_defaults(func=config_and_serve)
97 | 
98 |     # Subcommand for creating a config file
99 |     def create_config(num_models: int):
100 |         import yaml
101 |         from .engine import ModelConfig
102 | 
103 |         mconfig = [ModelConfig(model_id_or_path=f'/path/to/model_{i}').model_dump() for i in range(num_models)]
104 | 
105 |         config = dict(
106 |             model_configs=mconfig,
107 |             host='127.0.0.1',
108 |             port=5001,
109 |             api_keys=None,
110 |             min_tokens=20,
111 |             max_reprocess_tokens=250,
112 |             replace_threshold=0.95,
113 |             max_capacity=50,
114 |             use_reasoning_content=False
115 |         )
116 |         with open('model_config.yaml', 'w') as f:
117 |             yaml.dump(config, f, sort_keys=False)
118 | 
119 |     parser_cf = subparsers.add_parser('createconfig', help='Create a config file template.')
120 |     parser_cf.add_argument('-n', '--num-models', type=int, default=1, help='Number of model examples in the config file.')
121 |     parser_cf.set_defaults(func=create_config)
122 | 
123 |     args = parser.parse_args()
124 |     if args.command:
125 |         args_kwargs = vars(args)
126 |         args.func(**{k: v for k, v in args_kwargs.items() if k not in ['command', 'func']})
127 |     else:
128 |         parser.print_help()
129 | 
130 | if __name__ == '__main__':
131 |     main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MLX-Textgen
2 | [![PyPI](https://img.shields.io/pypi/v/mlx-textgen)](https://pypi.org/project/mlx-textgen/)
3 | [![PyPI - License](https://img.shields.io/pypi/l/mlx-textgen)](https://pypi.org/project/mlx-textgen/)
4 | [![GitHub Repo stars](https://img.shields.io/github/stars/nath1295/mlx-textgen)](https://pypi.org/project/mlx-textgen/)
5 | 
6 | ## An OpenAI-compatible API LLM engine with smart prompt caching, batch processing, structured output with guided decoding, and function calling for models using MLX
7 | 
8 | MLX-Textgen is a lightweight LLM serving engine that utilizes MLX and a smart KV cache management system to make your LLM generation more seamless on your Apple silicon machine. It features:
9 | - **Multiple KV-cache slots to reduce the need for prompt processing**
10 | - **Structured text generation with JSON schemas, regex, or context-free grammar**
11 | - **Batch inferencing with multiple prompts**
12 | - **Serving multiple models with FastAPI**
13 | - **Common OpenAI API endpoints: `/v1/models`, `/v1/completions`, `/v1/chat/completions`**
14 | 
15 | It is built with:
16 | 1. [mlx-lm](https://github.com/ml-explore/mlx-lm)
17 | 2. [mlx-vlm](https://github.com/Blaizzy/mlx-vlm)
18 | 3. [Outlines](https://github.com/dottxt-ai/outlines)
19 | 4. [FastAPI](https://github.com/fastapi/fastapi)
20 | 
21 | ## Updates
22 | - **2025-06-21:** Some vision models supported with `mlx-vlm` integration. Tested with Gemma 3 family models and Mistral Small 3.1.
23 | - **2025-06-21:** Reasoning parser supported for reasoning models with `deepseek_r1` parser.
24 | - **2025-06-21:** Breaking changes due to new vision model support and code restructuring. Run `mlx_textgen createconfig` to create a new config file.
25 | - **2025-06-21:** Quantising models will need to be done manually with `mlx-lm` or `mlx-vlm`.
26 | 
27 | ## Installing MLX-Textgen
28 | MLX-Textgen can be easily installed with `pip`:
29 | ```
30 | pip install mlx-textgen
31 | ```
32 | 
33 | ## Usage
34 | ### 1. Serving a single model
35 | You can quickly set up an OpenAI-compatible API server with a single command.
36 | 
37 | ```bash
38 | mlx_textgen serve --model-path mlx-community/gemma-3-4b-it-8bit --port 5001
39 | ```
40 | 
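Once the server is running, any OpenAI-compatible client can talk to it. The snippet below is a minimal sketch using the official `openai` Python client; it assumes the server was started with the command above, so it is listening on port 5001 (check `/v1/models` for the exact model id, especially if you passed a custom `--model-name`).

```python
from openai import OpenAI

# Any string works as the API key if the server was started without --api-key.
client = OpenAI(api_key='Your API Key', base_url='http://localhost:5001/v1/')

response = client.chat.completions.create(
    model='mlx-community/gemma-3-4b-it-8bit',  # replace with an id listed by /v1/models
    messages=[dict(role='user', content='Give me a one-sentence summary of Apple silicon.')],
    max_tokens=128,
    stream=False
)
print(response.choices[0].message.content)
```
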
41 | ### 2. Serving multiple models
42 | Create a config file template and add as many models as you like.
43 | ```bash
44 | mlx_textgen createconfig --num-models 2
45 | ```
46 | 
47 | It will generate a file called `model_config.yaml`. Edit this file for the models you want to serve.
48 | ```yaml
49 | model_configs:
50 | - model_id_or_path: /path/to/model_0
51 |   tokenizer_repo_or_path: null
52 |   model_kwargs: null
53 |   tokenizer_kwargs: null
54 |   model_name: null
55 |   enable_cache: true
56 |   preprocess_batch_size: 512
57 |   extra_stop_words: null
58 |   reasoning_parser: null
59 |   default_template: null
60 | - model_id_or_path: /path/to/model_1
61 |   tokenizer_repo_or_path: null
62 |   model_kwargs: null
63 |   tokenizer_kwargs: null
64 |   model_name: null
65 |   enable_cache: true
66 |   preprocess_batch_size: 512
67 |   extra_stop_words: null
68 |   reasoning_parser: null
69 |   default_template: null
70 | host: 127.0.0.1
71 | port: 5001
72 | api_keys: null
73 | min_tokens: 20
74 | max_reprocess_tokens: 250
75 | replace_threshold: 0.95
76 | max_capacity: 50
77 | use_reasoning_content: false
78 | ```
79 | 
80 | Then start the engine:
81 | ```bash
82 | mlx_textgen serve --config-file ./model_config.yaml
83 | ```
84 | 
85 | ### 3. More engine arguments
86 | You can check the details of the other engine arguments by running:
87 | ```bash
88 | mlx_textgen serve --help
89 | ```
90 | 
91 | You can specify the number of cache slots for each model, the minimum number of tokens required to create a cache file, API keys, etc.
92 | 
93 | ## Features
94 | ### 1. Multiple KV cache slots support
95 | All KV caches are stored on disk. Therefore, unlike other LLM serving engines, a newly created KV cache will not overwrite an existing one. This works better for agentic workflows where different types of prompts are used frequently, without losing a previously built cache for a long prompt.
96 | 
97 | ### 2. Guided decoding with Regex, JSON schema, and Grammar
98 | You can pass the guided decoding arguments `guided_json`, `guided_choice`, `guided_regex`, or `guided_grammar` as extra arguments to create structured text generation in a similar fashion to [vllm](https://github.com/vllm-project/vllm).
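For example, the snippet below is a minimal sketch of guided JSON decoding with the `openai` Python client. It assumes a server running locally on port 5001 and a model name taken from `/v1/models`; since `guided_json` and the other guided decoding arguments are not standard OpenAI parameters, they are passed through the client's `extra_body`, which simply adds them to the request body.

```python
from openai import OpenAI

client = OpenAI(api_key='Your API Key', base_url='http://localhost:5001/v1/')

# A hypothetical JSON schema that the generated text must follow.
weather_schema = {
    "type": "object",
    "properties": {
        "city": {"type": "string"},
        "temperature_celsius": {"type": "number"}
    },
    "required": ["city", "temperature_celsius"]
}

output = client.completions.create(
    model='model_name',
    prompt='Report the current weather in London as JSON: ',
    max_tokens=128,
    extra_body=dict(guided_json=weather_schema)  # or guided_regex / guided_choice / guided_grammar
)
print(output.choices[0].text)
```
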
99 | 
100 | ### 3. Batch inference support
101 | Batch inference is supported for multiple prompts or multiple generations for a single prompt. Just pass a list of prompts to the `prompt` argument of the `/v1/completions` endpoint, or set `n=2` (or more) on the `/v1/chat/completions` or `/v1/completions` endpoints for batch inferencing.
102 | 
103 | ### 4. Function calling support
104 | Function calling with the `/v1/chat/completions` endpoint is supported. Simply use the `tools` and `tool_choice` arguments to supply a list of tools. There are three modes of using function calling:
105 | 1. `tool_choice="auto"`: The model will decide if tool calling is needed based on the conversation. If a tool is needed, it will pick the appropriate tool and generate the arguments. Otherwise, it will only respond with normal text.
106 | 2. `tool_choice="required"`: One of the given tools must be selected by the model. The model will pick the appropriate tool and generate the arguments.
107 | 3. `tool_choice={"type": "function", "function": {"name": "<function_name>"}}`: The model will generate the arguments of the selected tool.
108 | 
109 | If function calling is triggered, the call arguments will be contained in the `tool_calls` attribute of the `choices` element in the response. The `finish_reason` will be `tool_calls`.
110 | ```python
111 | from openai import OpenAI
112 | 
113 | tools = [{
114 |     "type": "function",
115 |     "function": {
116 |         "name": "get_current_weather",
117 |         "description": "Get the current weather in a given location",
118 |         "parameters": {
119 |             "type": "object",
120 |             "properties": {
121 |                 "location": {
122 |                     "type": "string"
123 |                 },
124 |                 "unit": {
125 |                     "type": "string",
126 |                     "default": "celsius"
127 |                 }
128 |             },
129 |             "required": ["location"]
130 |         }
131 |     }
132 | }]
133 | 
134 | client = OpenAI(api_key='Your API Key', base_url='http://localhost:5001/v1/')
135 | 
136 | output = client.chat.completions.create(
137 |     model='model_name',
138 |     messages=[
139 |         dict(role='user', content='What is the current weather in London?')
140 |     ],
141 |     max_tokens=256,
142 |     tools=tools,
143 |     tool_choice='auto',
144 |     stream=False
145 | ).choices[0].model_dump()
146 | 
147 | # output:
148 | # {'finish_reason': 'tool_calls',
149 | #  'index': 0,
150 | #  'logprobs': None,
151 | #  'message': {'content': None,
152 | #   'role': 'assistant',
153 | #   'function_call': None,
154 | #   'tool_calls': [{'id': 'call_052c8a6b',
155 | #     'function': {'arguments': '{"location": "London", "unit": "celsius" }',
156 | #      'name': 'get_current_weather'},
157 | #     'type': 'function',
158 | #     'index': 0}]}}
159 | ```
160 | 
161 | If `tool_choice="none"` is passed, the list of tools provided will be ignored and the model will only generate normal text.
162 | 
163 | ### 5. Serving multiple LLMs
164 | Only one model is loaded in RAM at a time, but the engine leverages MLX's fast module loading to spin up another model when it is requested. This allows serving multiple models with one endpoint.
165 | 
166 | ## License
167 | This project is licensed under the terms of the MIT license.
168 | -------------------------------------------------------------------------------- /src/mlx_textgen/generation_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Literal, Any, Dict, Union 2 | from dataclasses import dataclass 3 | from pydantic import BaseModel, Field 4 | from mlx.core import array 5 | 6 | def string_partial_pause(text: str, stop: List[str], stop_len: List[int]) -> Optional[str]: 7 | """Checks if the end of a string partially matches any of the stop strings. 8 | 9 | This function iterates through the `stop` list and checks if the end of the 10 | input `text` matches the beginning of any of the stop strings up to a 11 | certain length specified by `stop_len`. It returns the matched portion of 12 | the stop string if a match is found, otherwise it returns None. 13 | 14 | Args: 15 | text (str): The input string to check. 16 | stop (List[str]): A list of stop strings to compare against. 17 | stop_len (List[int]): A list of maximum lengths to consider for each 18 | stop string. Should be the same length as `stop`. 19 | 20 | Returns: 21 | Optional[str: The matching portion of a stop string if found, 22 | otherwise None. 23 | """ 24 | for s, l in zip(stop, stop_len): 25 | clen = min(len(text), l) 26 | for i in range(clen): 27 | seg = clen - i 28 | if text[-seg:] == s[:seg]: 29 | return s[:seg] 30 | 31 | def get_stop(text: str, stop: List[str]) -> Optional[str]: 32 | """Checks if any of the stop strings are present in the text. 33 | 34 | This function iterates through the `stop` list and checks if any of the 35 | stop strings are present as substrings within the input `text`. 36 | It returns the first stop string found in the text, or None if none are found. 37 | 38 | Args: 39 | text (str): The input string to check. 40 | stop (List[str]): A list of stop strings to search for. 41 | 42 | Returns: 43 | Optional[str]: The first stop string found in the text, or None. 44 | """ 45 | for s in stop: 46 | if s in text: 47 | return s 48 | 49 | @dataclass 50 | class NewStringToken: 51 | new_token: str 52 | stop_str: Optional[str] 53 | 54 | @dataclass 55 | class GenerationOutput: 56 | index: int 57 | token: str 58 | token_id: Optional[int] 59 | stop_str: Optional[str] 60 | logprobs: Optional[Dict[str, Any]] 61 | input_tokens: int 62 | output_tokens: int 63 | finish_reason: Optional[Literal['stop', 'length', 'tool_calls']] 64 | 65 | class StringStop: 66 | """String stop checker to prevent partial stop sequences. 67 | """ 68 | def __init__(self, num_prompt: int, stop: List[str]) -> None: 69 | self.num_prompt = num_prompt 70 | self.buffer = [''] * num_prompt 71 | self.to_yield = [''] * num_prompt 72 | stop_pair = [(s, len(s)) for s in set(stop)] 73 | stop_pair.sort(key=lambda x: x[1], reverse=True) 74 | self.stop = [s[0] for s in stop_pair] 75 | self.stop_len = [s[1] for s in stop_pair] 76 | self.is_stop = [False] * num_prompt 77 | 78 | def get_finalised_token_strings(self, tokens: List[str]) -> List[NewStringToken]: 79 | """Processes a list of newly generated tokens, checking for stop sequences and preparing tokens to yield. 80 | 81 | This method updates the internal buffer with the new tokens, checks for complete stop sequences, 82 | and identifies partial stop sequences. It then determines the tokens to yield and updates the 83 | buffer accordingly. The `is_stop` flag prevents further processing once a stop sequence is found 84 | for a specific sequence. 
85 | 86 | Args: 87 | tokens (List[str]): A list of newly generated tokens, one for each prompt. The length of this 88 | list must match the number of prompts specified during initialization. 89 | 90 | Raises: 91 | ValueError: If the number of provided tokens does not match the expected number of prompts. 92 | 93 | Returns: 94 | List[NewStringToken]: A list of `NewStringToken` objects, one for each prompt. Each object 95 | contains the token to yield and the stop string that was found (if any). 96 | """ 97 | if len(tokens) != self.num_prompt: 98 | raise ValueError('Number of provided tokens not ') 99 | 100 | self.buffer = [o + n if not s else o for o, n, s in zip(self.buffer, tokens, self.is_stop)] 101 | stop = [get_stop(t, self.stop) if not si else None for t, si in zip(self.buffer, self.is_stop)] 102 | self.is_stop = [(si or bool(s)) for si, s in zip(self.is_stop, stop)] 103 | to_yield = [b.split(s)[0] if s else b for b, s in zip(self.buffer, stop)] 104 | temp_stop = [string_partial_pause(b, self.stop, self.stop_len) if not si else None for b, si in zip(to_yield, self.is_stop)] 105 | to_yield = [b.removesuffix(ts) if ts else b for b, ts in zip(to_yield, temp_stop)] 106 | self.buffer = [b.removeprefix(ty) if not si else '' for b, ty, si in zip(self.buffer, to_yield, self.is_stop)] 107 | return [NewStringToken(new_token=ty, stop_str=s) for ty, s in zip(to_yield, stop)] 108 | 109 | def get_remains(self) -> List[NewStringToken]: 110 | return [NewStringToken(new_token=b, stop_str=None) for b in self.buffer] 111 | 112 | def to_completion_logprobs(logprobs: List[Dict[str, Any]]): 113 | tokens = [lp['token'] for lp in logprobs] 114 | token_logprobs = [lp['logprob'] for lp in logprobs] 115 | top_logprobs = [{l['token']: l['logprob'] for l in lp['top_logprobs']} 116 | for lp in logprobs 117 | ] 118 | return dict(tokens=tokens, token_logprobs=token_logprobs, top_logprobs=top_logprobs) 119 | 120 | class TopLogprob(BaseModel): 121 | token: str 122 | logprob: float 123 | 124 | class Logprob(BaseModel): 125 | token: str 126 | logprob: float 127 | top_logprobs: List[TopLogprob] 128 | 129 | class CompletionLogprobs(BaseModel): 130 | tokens: List[str] 131 | token_logprobs: List[float] 132 | top_logprobs: List[Dict[str, float]] 133 | 134 | class LogprobsObject(BaseModel): 135 | content: List[Logprob] 136 | 137 | class CompletionUsageDetails(BaseModel): 138 | reasoning_tokens: int 139 | 140 | class Usage(BaseModel): 141 | prompt_tokens: int 142 | completion_tokens: int 143 | total_tokens: int 144 | completion_tokens_details: Optional[CompletionUsageDetails] = None 145 | 146 | class TextCompletionChoice(BaseModel): 147 | index: int 148 | finish_reason: Optional[Literal['stop', 'length']] = None 149 | text: str 150 | logprobs: Optional[CompletionLogprobs] = None 151 | 152 | class TextCompletionOutput(BaseModel): 153 | id: str = Field(pattern='^cmpl-[a-z0-9]{32}$') 154 | object: Literal['text_completion'] = 'text_completion' 155 | created: int 156 | model: str 157 | choices: List[TextCompletionChoice] 158 | usage: Optional[Usage] = None 159 | 160 | class FunctionInput(BaseModel): 161 | name: str 162 | arguments: Optional[str] 163 | 164 | class ToolCall(BaseModel): 165 | index: int 166 | id: str = Field(pattern='^call_[a-z0-9]{8}$') 167 | function: FunctionInput 168 | type: Literal['function'] = 'function' 169 | 170 | class ChatCompletionDelta(BaseModel): 171 | role: Optional[Literal['assistant']] = None 172 | content: Optional[str] = None 173 | reasoning_content: Optional[str] = None 174 | tool_calls: 
List[ToolCall] = Field(default_factory=list) 175 | 176 | class ChatCompletionStreamChoice(BaseModel): 177 | index: int 178 | finish_reason: Optional[Literal['stop', 'length', 'tool_calls']] = None 179 | logprobs: Optional[LogprobsObject] = None 180 | delta: ChatCompletionDelta 181 | 182 | class ChatMessage(BaseModel): 183 | role: Literal['assistant'] = 'assistant' 184 | content: Optional[str] = None 185 | reasoning_content: Optional[str] = None 186 | tool_calls: List[ToolCall] = Field(default_factory=list) 187 | 188 | class ChatCompletionChoice(BaseModel): 189 | index: int 190 | message: ChatMessage 191 | finish_reason: Optional[Literal['stop', 'length', 'tool_calls']] = None 192 | logprobs: Optional[LogprobsObject] = None 193 | 194 | class ChatCompletionOutput(BaseModel): 195 | id: str = Field(pattern='^chatcmpl-[a-z0-9]{32}$') 196 | object: Literal['chat.completion', 'chat.completion.chunk'] 197 | created: int 198 | model: str 199 | choices: List[Union[ChatCompletionChoice, ChatCompletionStreamChoice]] 200 | usage: Optional[Usage] = None 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | -------------------------------------------------------------------------------- /src/mlx_textgen/server.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Optional, Union, Dict, TYPE_CHECKING 2 | if TYPE_CHECKING: 3 | from .engine import ModelConfig 4 | from logging import Logger 5 | 6 | 7 | def serve_api( 8 | model_configs: List["ModelConfig"], 9 | logger: "Logger", 10 | host: str = '127.0.0.1', 11 | port: int = 5001, 12 | api_keys: Optional[Union[str, List[str]]] = None, 13 | min_tokens: int = 20, 14 | max_reprocess_tokens: int = 250, 15 | replace_threshold: float = 0.95, 16 | max_capacity: int = 50, 17 | use_reasoning_content: bool = False 18 | ): 19 | from .engine import InferenceEngine 20 | from .utils import PACKAGE_NAME 21 | from . import __version__ 22 | from fastapi import FastAPI, HTTPException, status, Request 23 | from fastapi.encoders import jsonable_encoder 24 | from fastapi.responses import JSONResponse, StreamingResponse 25 | from copy import deepcopy 26 | import json 27 | import asyncio 28 | import uvicorn 29 | 30 | engine = InferenceEngine( 31 | model_configs=model_configs, 32 | min_tokens=min_tokens, 33 | max_reprocess_tokens=max_reprocess_tokens, 34 | replace_threshold=replace_threshold, 35 | max_capacity=max_capacity, 36 | use_reasoning_content=use_reasoning_content, 37 | logger=logger 38 | ) 39 | api_keys = api_keys if api_keys else [] 40 | if isinstance(api_keys, str): 41 | api_keys = [api_keys] 42 | 43 | app = FastAPI() 44 | semaphore = asyncio.Semaphore(1) 45 | 46 | def _validate_api_key(request: Request): 47 | api_key = request.headers.get('authorization', 'Bearer ').removeprefix('Bearer ') 48 | if api_keys and (api_key not in api_keys): 49 | raise HTTPException( 50 | status_code=status.HTTP_401_UNAUTHORIZED, 51 | detail="Invalid API key." 52 | ) 53 | 54 | def log_request(content: Dict[str, Any]): 55 | alt = deepcopy(content) 56 | if 'prompt' in alt: 57 | if len(alt['prompt']) > 200: 58 | alt['prompt'] = alt['prompt'][:100] + '...' 
+ alt['prompt'][-100:] 59 | 60 | if 'messages' in alt: 61 | for msgs in alt['messages']: 62 | if not isinstance(msgs, list): 63 | if isinstance(msgs.get('content'), list): 64 | for c in msgs['content']: 65 | img = c.get('image_url') 66 | if img: 67 | if isinstance(img, dict): 68 | img['url'] = img['url'] if not img['url'].startswith('data:') else 'base64image_string' 69 | text = c.get('text') 70 | if text and (len(text) > 200): 71 | c['text'] = text[:99] + '...' + text[-99:] 72 | elif isinstance(msgs.get('content'), str): 73 | text = msgs['content'] 74 | if len(text) > 200: 75 | msgs['content'] = text[:99] + '...' + text[-99:] 76 | else: 77 | for msg in msgs: 78 | if isinstance(msg.get('content'), list): 79 | for c in msg['content']: 80 | img = c.get('image_url') 81 | if img: 82 | if isinstance(img, dict): 83 | img['url'] = img['url'] if not img['url'].startswith('data:') else 'base64image_string' 84 | text = c.get('text') 85 | if text and (len(text) > 200): 86 | c['text'] = text[:99] + '...' + text[-99:] 87 | elif isinstance(msg.get('content'), str): 88 | text = msg['content'] 89 | if len(text) > 200: 90 | msg['content'] = text[:99] + '...' + text[-99:] 91 | 92 | 93 | logger.info(json.dumps(alt, indent=2)) 94 | 95 | @app.get('/v1/models') 96 | async def get_models(request: Request) -> JSONResponse: 97 | _validate_api_key(request) 98 | return JSONResponse(content=jsonable_encoder(dict(object='list', data=engine.model_info))) 99 | 100 | @app.get('/v1/models/{model_id}') 101 | async def get_model(request: Request, model_id: str) -> JSONResponse: 102 | _validate_api_key(request) 103 | model_dict = {info['id']: info for info in engine.model_info} 104 | if model_id not in model_dict: 105 | raise HTTPException( 106 | status_code=status.HTTP_404_NOT_FOUND, 107 | detail=f'Model "{model_id}" does not exist.' 
108 | ) 109 | return JSONResponse(content=jsonable_encoder(model_dict[model_id])) 110 | 111 | @app.post('/v1/completions', response_model=None) 112 | async def completions(request: Request) -> Union[StreamingResponse, JSONResponse]: 113 | content = await request.json() 114 | log_request(content) 115 | _validate_api_key(request) 116 | model = content.get('model') 117 | if model not in engine.model_dict.keys(): 118 | return JSONResponse(jsonable_encoder(dict(error=f'Model "{model}" does not exist.')), status_code=404) 119 | 120 | if isinstance(content.get('stop', None), str): 121 | content['stop'] = [content['stop']] 122 | 123 | stream = content.get('stream', False) 124 | 125 | async with semaphore: 126 | if stream: 127 | async def gen(): 128 | try: 129 | generator = await asyncio.to_thread(engine.generate, **content) 130 | for chunk in generator: 131 | yield f'data: {chunk.model_dump_json()}\n\n' 132 | yield 'data: [DONE]' 133 | except Exception as e: 134 | logger.error(str(e)[:500]) 135 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)[:500]) 136 | return StreamingResponse(gen(), media_type="text/event-stream") 137 | 138 | else: 139 | try: 140 | output = await asyncio.to_thread(engine.generate, **content) 141 | return JSONResponse(jsonable_encoder(output.model_dump())) 142 | except Exception as e: 143 | logger.error(str(e)[:500]) 144 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)[:500]) 145 | 146 | @app.post('/v1/chat/completions', response_model=None) 147 | async def completions(request: Request) -> Union[StreamingResponse, JSONResponse]: 148 | content = await request.json() 149 | log_request(content) 150 | _validate_api_key(request) 151 | model = content.get('model') 152 | if model not in engine.model_dict.keys(): 153 | return JSONResponse(jsonable_encoder(dict(error=f'Model "{model}" does not exist.')), status_code=404) 154 | 155 | if isinstance(content.get('stop', None), str): 156 | content['stop'] = [content['stop']] 157 | 158 | if content.get('max_tokens', None) and (not content.get('max_completion_tokens', None)): 159 | content['max_completion_tokens'] = content['max_tokens'] 160 | 161 | stream = content.get('stream', False) 162 | 163 | async with semaphore: 164 | if stream: 165 | async def gen(): 166 | try: 167 | generator = await asyncio.to_thread(engine.chat_generate, **content) 168 | for chunk in generator: 169 | yield f'data: {chunk.model_dump_json()}\n\n' 170 | yield 'data: [DONE]' 171 | except Exception as e: 172 | logger.error(str(e)[:500]) 173 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)[:500]) 174 | return StreamingResponse(gen(), media_type="text/event-stream") 175 | 176 | else: 177 | try: 178 | output = await asyncio.to_thread(engine.chat_generate, **content) 179 | return JSONResponse(jsonable_encoder(output.model_dump())) 180 | except Exception as e: 181 | logger.error(e) 182 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) 183 | 184 | print(f'{PACKAGE_NAME} OpenAI-compatible LLM API server version: {__version__}') 185 | 186 | uvicorn.run(app, port=port, host=host) 187 | 188 | 189 | 190 | 191 | 192 | -------------------------------------------------------------------------------- /src/mlx_textgen/tokenizer_utils.py: -------------------------------------------------------------------------------- 1 | # adapted from mlx-lm 2 | import json 3 | from functools import partial 4 | from abc import ABC, abstractmethod 5 | 
from typing import List, Optional, Dict, Any, TYPE_CHECKING 6 | if TYPE_CHECKING: 7 | from logging import Logger 8 | from transformers import PreTrainedTokenizer 9 | 10 | REPLACEMENT_CHAR = "\ufffd" 11 | SPECIAL_SPACE = "\u2581" 12 | 13 | def _remove_space(x): 14 | if x and x[0] == " ": 15 | return x[1:] 16 | return x 17 | 18 | 19 | class BaseDetokenizer: 20 | 21 | @abstractmethod 22 | def reset(self, num_seqs: Optional[int] = None) -> None: 23 | pass 24 | 25 | @abstractmethod 26 | def add_tokens(self, token_ids: List[List[int]]) -> None: 27 | pass 28 | 29 | @abstractmethod 30 | def finalize(self) -> None: 31 | pass 32 | 33 | @property 34 | @abstractmethod 35 | def last_segments(self) -> List[str]: 36 | pass 37 | 38 | class NaiveDetokenizer(BaseDetokenizer): 39 | """Detokenizer for streaming. 40 | """ 41 | def __init__(self, tokenizer: "PreTrainedTokenizer") -> None: 42 | self._tokenizer = tokenizer 43 | self.reset() 44 | 45 | def reset(self, num_seqs: Optional[int] = None) -> None: 46 | self.num_seqs = num_seqs 47 | self.texts = [] if self.num_seqs is None else [''] * self.num_seqs 48 | self.tokens = [] if self.num_seqs is None else [[]] * self.num_seqs 49 | self.current_tokens = [] if self.num_seqs is None else [[]] * self.num_seqs 50 | self.final_segments = None 51 | 52 | def add_tokens(self, token_ids: List[List[int]]) -> None: 53 | if self.num_seqs is None: 54 | self.reset(num_seqs=len(token_ids)) 55 | elif len(token_ids) != self.num_seqs: 56 | raise Exception('Number of tokens does not match with number of existing sequences.') 57 | self.current_tokens = [tids + nids for tids, nids in zip(self.current_tokens, token_ids)] 58 | 59 | def finalize(self) -> None: 60 | new_texts = self._tokenizer.batch_decode(self.current_tokens) 61 | new_texts = [nt.rstrip(REPLACEMENT_CHAR) for nt in new_texts] 62 | self.tokens = [t + ct for t, ct in zip(self.tokens, self.current_tokens)] 63 | self.texts = [t + nt for t, nt in zip(self.texts, new_texts)] 64 | self.current_tokens = [[]] * self.num_seqs 65 | self.final_segments = new_texts 66 | 67 | @property 68 | def last_segments(self) -> List[str]: 69 | if self.final_segments is not None: 70 | return self.final_segments 71 | new_texts = self._tokenizer.batch_decode(self.current_tokens) 72 | with_repl = [newt.endswith(REPLACEMENT_CHAR) for newt in new_texts] 73 | bundle = [(ot if wr else ot + nt, nt if wr else [], os if wr else os + ns, '' if wr else ns) for ot, nt, os, ns, wr in zip(self.tokens, self.current_tokens, self.texts, new_texts, with_repl)] 74 | tokens, current_tokens, texts, new_segments = list(zip(*bundle)) 75 | self.tokens = list(tokens) 76 | self.current_tokens = list(current_tokens) 77 | self.texts = list(texts) 78 | new_segments = list(new_segments) 79 | return new_segments 80 | 81 | class SPMDetokenizer(BaseDetokenizer): 82 | """A streaming detokenizer for SPM models. 83 | 84 | It adds tokens to the text if the next token starts with the special SPM 85 | underscore which results in linear complexity. 
86 |     """
87 | 
88 |     def __init__(self, tokenizer: "PreTrainedTokenizer", trim_space=True):
89 |         self.trim_space = trim_space
90 |         self.eos_token = tokenizer.eos_token
91 | 
92 |         # Extract the tokens in a list from id to text
93 |         self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
94 |         for value, tokenid in tokenizer.vocab.items():
95 |             self.tokenmap[tokenid] = value
96 | 
97 |         self.hexcode_tokens = [i for i, t in enumerate(self.tokenmap) if t.startswith('<0x')]
98 | 
99 |         self.reset()
100 | 
101 |     def reset(self, num_seqs: Optional[int] = None) -> None:
102 |         self.num_seqs = num_seqs
103 |         self.range = [] if self.num_seqs is None else list(range(self.num_seqs))
104 |         self.texts = [] if self.num_seqs is None else [''] * self.num_seqs
105 |         self.tokens = [] if self.num_seqs is None else [[] for _ in range(self.num_seqs)]  # independent lists per sequence; '[[]] * n' would alias one shared list
106 |         self.hexcodes = [] if self.num_seqs is None else [[] for _ in range(self.num_seqs)]
107 |         self.segments = [] if self.num_seqs is None else [''] * self.num_seqs
108 | 
109 |     def _get_text_token(self, token_id, raw_token, token_condition, index) -> str:
110 |         self.tokens[index].append(token_id[0])
111 |         is_space, is_hex = token_condition
112 |         output = ''
113 |         if is_hex:
114 |             self.hexcodes[index].append(int(raw_token[3:5], 16))
115 |         elif is_space:
116 |             if self.texts[index] or (not self.trim_space):
117 |                 output = ('' if len(self.hexcodes[index]) == 0 else bytes(self.hexcodes[index]).decode()) + raw_token.replace(SPECIAL_SPACE, ' ')
118 |             else:
119 |                 output = ('' if len(self.hexcodes[index]) == 0 else bytes(self.hexcodes[index]).decode()) + _remove_space(raw_token.replace(SPECIAL_SPACE, ' '))
120 |             self.hexcodes[index] = []
121 |         else:
122 |             output = ('' if len(self.hexcodes[index]) == 0 else bytes(self.hexcodes[index]).decode()) + raw_token
123 |             self.hexcodes[index] = []
124 |         self.texts[index] += output
125 |         return output
126 | 
127 |     def add_tokens(self, token_ids: List[List[int]]) -> None:
128 |         if self.num_seqs is None:
129 |             self.reset(num_seqs=len(token_ids))
130 |         elif len(token_ids) != self.num_seqs:
131 |             raise Exception('Number of tokens does not match with number of existing sequences.')
132 |         raw_tokens = [self.tokenmap[token[0]] for token in token_ids]
133 |         token_conditions = [((rt[0] == SPECIAL_SPACE), (tid[0] in self.hexcode_tokens)) for rt, tid in zip(raw_tokens, token_ids)]
134 |         self.segments = [self._get_text_token(tid, rt, tc, i) for tid, rt, tc, i in zip(token_ids, raw_tokens, token_conditions, self.range)]
135 | 
136 |     def finalize(self) -> None:
137 |         hex_str = [('' if len(self.hexcodes[index]) == 0 else bytes(self.hexcodes[index]).decode()) for index in self.range]
138 |         self.segments = [s + hs for s, hs in zip(self.segments, hex_str)]
139 | 
140 |     @property
141 |     def last_segments(self):
142 |         """Return the last segment of readable text since last time this property was accessed."""
143 |         segments = self.segments
144 |         self.segments = [''] * self.num_seqs
145 |         return segments
146 | 
147 | 
148 | class TokenizerWrapper:
149 |     """A wrapper that combines an HF tokenizer and a detokenizer.
150 | 
151 |     Accessing any attribute other than the ``detokenizer`` is forwarded to the
152 |     huggingface tokenizer.
153 | """ 154 | 155 | def __init__(self, tokenizer: "PreTrainedTokenizer", detokenizer_class: "BaseDetokenizer" = NaiveDetokenizer) -> None: 156 | self._tokenizer = tokenizer 157 | self._detokenizer_class = detokenizer_class 158 | self._detokenizer = detokenizer_class(tokenizer) 159 | 160 | def __getattr__(self, attr): 161 | if attr == "detokenizer": 162 | return self._detokenizer 163 | elif attr.startswith("_"): 164 | return self.__getattribute__(attr) 165 | else: 166 | return getattr(self._tokenizer, attr) 167 | 168 | def __setattr__(self, attr, value): 169 | if attr == "detokenizer": 170 | raise AttributeError("Cannot set the detokenizer.") 171 | elif attr.startswith("_"): 172 | super().__setattr__(attr, value) 173 | else: 174 | setattr(self._tokenizer, attr, value) 175 | 176 | 177 | 178 | def _match(a, b): 179 | if type(a) != type(b): 180 | return False 181 | if isinstance(a, dict): 182 | return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a) 183 | if isinstance(a, list): 184 | return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b)) 185 | 186 | return a == b 187 | 188 | 189 | def _is_spm_decoder(decoder): 190 | _target_description = { 191 | "type": "Sequence", 192 | "decoders": [ 193 | {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, 194 | {"type": "ByteFallback"}, 195 | {"type": "Fuse"}, 196 | {"type": "Strip", "content": " ", "start": 1, "stop": 0}, 197 | ], 198 | } 199 | return _match(_target_description, decoder) 200 | 201 | 202 | def _is_spm_decoder_no_space(decoder): 203 | _target_description = { 204 | "type": "Sequence", 205 | "decoders": [ 206 | {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, 207 | {"type": "ByteFallback"}, 208 | {"type": "Fuse"}, 209 | ], 210 | } 211 | return _match(_target_description, decoder) 212 | 213 | 214 | def _is_bpe_decoder(decoder): 215 | return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" 216 | 217 | 218 | def load_tokenizer(model_path: str, tokenizer_config_extra: Optional[Dict[str, Any]] = None, logger: Optional["Logger"] = None) -> TokenizerWrapper: 219 | """Load a huggingface tokenizer and try to infer the type of streaming 220 | detokenizer to use. 221 | 222 | Note, to use a fast streaming tokenizer, pass a local file path rather than 223 | a Hugging Face repo ID. 
224 | """ 225 | import os 226 | from transformers import AutoTokenizer 227 | from huggingface_hub import hf_hub_download 228 | 229 | tokenizer_config_extra = dict() if tokenizer_config_extra is None else tokenizer_config_extra 230 | detokenizer_class = NaiveDetokenizer 231 | 232 | tokenizer_file = os.path.join(model_path, "tokenizer.json") 233 | if not os.path.exists(tokenizer_file): 234 | tokenizer_file = hf_hub_download(repo_id=str(model_path), filename='tokenizer.json') 235 | with open(tokenizer_file, "r") as fid: 236 | tokenizer_content = json.load(fid) 237 | if "decoder" in tokenizer_content: 238 | if _is_spm_decoder(tokenizer_content["decoder"]): 239 | if logger: 240 | logger.info('Using SPM decoder.') 241 | detokenizer_class = SPMDetokenizer 242 | elif _is_spm_decoder_no_space(tokenizer_content["decoder"]): 243 | if logger: 244 | logger.info('Using SPM decoder with trim_space=False.') 245 | detokenizer_class = partial(SPMDetokenizer, trim_space=False) 246 | elif logger: 247 | logger.info('Using Naive decoder.') 248 | 249 | return TokenizerWrapper( 250 | AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), 251 | detokenizer_class, 252 | ) -------------------------------------------------------------------------------- /src/mlx_textgen/chat_presets.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | PRESETS_EOT = { 4 | "chatml": "", 5 | "llama3": "<|eot_id|>", 6 | "gemma": "", 7 | "deepseek": '<\uff5cend\u2581of\u2581sentence\uff5c>', 8 | "openchat": None, 9 | "phi": None 10 | } 11 | 12 | PRESETS = { 13 | "chatml": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' 
}}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{{\\\"name\\\": , \\\"arguments\\\": }}\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", 14 | 15 | "llama3": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n", 16 | "mistral": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful assistant.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", 17 | 18 | "gemma": "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", 19 | 20 | "deepseek": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %} {%- if message['role'] == 'system' %} {% set ns.system_prompt = message['content'] %} {%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %} {%- if message['role'] == 'user' %} {%- set ns.is_tool = false -%}{{'<\uff5cUser\uff5c>' + message['content']}} {%- endif %} {%- if message['role'] == 'assistant' and message['content'] is none %} {%- set ns.is_tool = false -%} {%- for tool in message['tool_calls']%} {%- if not ns.is_first %}{{'<\uff5cAssistant\uff5c><\uff5ctool\u2581calls\u2581begin\uff5c><\uff5ctool\u2581call\u2581begin\uff5c>' + tool['type'] + '<\uff5ctool\u2581sep\uff5c>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<\uff5ctool\u2581call\u2581end\uff5c>'}} {%- set ns.is_first = true -%} {%- else %}{{'\\n' + '<\uff5ctool\u2581call\u2581begin\uff5c>' + tool['type'] + '<\uff5ctool\u2581sep\uff5c>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<\uff5ctool\u2581call\u2581end\uff5c>'}}{{'<\uff5ctool\u2581calls\u2581end\uff5c><\uff5cend\u2581of\u2581sentence\uff5c>'}} {%- endif %} {%- endfor %} {%- endif %} {%- if message['role'] == 'assistant' and message['content'] is not none %} {%- if ns.is_tool %}{{'<\uff5ctool\u2581outputs\u2581end\uff5c>' + message['content'] + '<\uff5cend\u2581of\u2581sentence\uff5c>'}} {%- set ns.is_tool = false -%} {%- else %}{{'<\uff5cAssistant\uff5c>' + message['content'] + '<\uff5cend\u2581of\u2581sentence\uff5c>'}} {%- endif %} {%- endif %} {%- if message['role'] == 'tool' %} {%- set ns.is_tool = true -%} {%- if ns.is_output_first %}{{'<\uff5ctool\u2581outputs\u2581begin\uff5c><\uff5ctool\u2581output\u2581begin\uff5c>' + message['content'] + '<\uff5ctool\u2581output\u2581end\uff5c>'}} {%- set ns.is_output_first = false %} {%- else %}{{'\\n<\uff5ctool\u2581output\u2581begin\uff5c>' + message['content'] + '<\uff5ctool\u2581output\u2581end\uff5c>'}} {%- endif %} {%- 
endif %}{%- endfor -%}{% if ns.is_tool %}{{'<\uff5ctool\u2581outputs\u2581end\uff5c>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<\uff5cAssistant\uff5c>'}}{% endif %}", 21 | 22 | "openchat": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", 23 | 24 | "phi": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}" 25 | } 26 | 27 | TOOL_LIST = [{'properties': {'example_arg': {'title': 'Example Arg', 'type': 'string'}}, 'required': ['example_arg'], 'title': 'ExampleTool', 'type': 'function'}] 28 | TOOL_CALL_LIST = [dict(id='call_1324asdf', type='function', function=dict(name='ExampleTool', arguments='{"example_arg": "adsf"}'))] 29 | MSG_WITH_TOOL = [ 30 | dict(role='user', content='Hello'), 31 | dict(role='assistant', tool_calls=TOOL_CALL_LIST) 32 | ] 33 | MSG_SINGLE_ASSISTANT = [dict(role='assistant', content='')] 34 | DEFAULT_TOOL_SYSTEM = ''' 35 | 36 | # Tools 37 | 38 | You may call one or more functions to assist with the user query. 39 | 40 | You are provided with function signatures within XML tags: 41 | 42 | $$tool_list$$ 43 | 44 | 45 | For each function call, return a json object with function name and arguments within XML tags: 46 | 47 | {"name": , "arguments": } 48 | 49 | 50 | The tool output will be encapsulated within the XML tags provided by the user.''' -------------------------------------------------------------------------------- /src/mlx_textgen/chat_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, List, Dict, Any, Optional, Union, Tuple, TYPE_CHECKING 2 | from pydantic import BaseModel, model_validator 3 | if TYPE_CHECKING: 4 | from transformers.tokenization_utils import PreTrainedTokenizer 5 | 6 | class TextContent(BaseModel): 7 | type: Literal['text'] 8 | text: str 9 | 10 | class ImageURL(BaseModel): 11 | url: str 12 | 13 | class ImageContent(BaseModel): 14 | type: Literal['image_url', 'input_image'] 15 | image_url: Union[ImageURL, str] 16 | 17 | class Parameters(BaseModel): 18 | type: Literal['object'] 19 | properties: Dict[str, Any] 20 | 21 | class FunctionSchema(BaseModel): 22 | name: str 23 | description: Optional[str] = None 24 | parameters: Parameters 25 | 26 | class OpenAIToolSchema(BaseModel): 27 | type: Literal['function'] 28 | function: FunctionSchema 29 | 30 | class ToolChoiceFunction(BaseModel): 31 | name: str 32 | 33 | class ToolChoiceSchema(BaseModel): 34 | type: Literal['function'] 35 | function: ToolChoiceFunction 36 | 37 | class ResponseFormat(BaseModel): 38 | type: Literal['json_schema'] 39 | json_schema: Dict[str, Any] 40 | 41 | class Function(BaseModel): 42 | name: str 43 | arguments: str 44 | 45 | class ToolCall(BaseModel): 46 | id: str 47 | type: Literal['function'] 48 | function: Function 49 | 50 | class ChatMessage(BaseModel): 51 | role: Literal['system', 'developer', 'user', 'assistant', 'tool'] 52 | content: Optional[Union[str, List[Union[TextContent, ImageContent]]]] = None 53 | tool_calls: 
Optional[List[ToolCall]] = None 54 | 55 | @property 56 | def tool_call_list(self): 57 | if self.tool_calls: 58 | import json 59 | tool_call_list = [] 60 | for t in self.tool_calls: 61 | t_dict = t.model_dump() 62 | arg_str = t_dict['function']['arguments'] 63 | if arg_str.strip(): 64 | t_dict['function']['arguments'] = json.loads(arg_str.strip()) 65 | else: 66 | t_dict['function']['arguments'] = {} 67 | tool_call_list.append(t_dict) 68 | return tool_call_list 69 | 70 | def model_post_init(self, __context): 71 | self.role = 'system' if self.role == 'developer' else self.role 72 | if isinstance(self.content, list): 73 | if all(c.type == 'text' for c in self.content): 74 | self.content = '\n'.join([c.text for c in self.content]) 75 | 76 | if self.tool_calls: 77 | import json 78 | tool_call_list = [] 79 | for t in self.tool_calls: 80 | t_dict = t.model_dump() 81 | arg_str = t_dict['function']['arguments'] 82 | if arg_str.strip(): 83 | t_dict['function']['arguments'] = json.loads(arg_str.strip()) 84 | else: 85 | t_dict['function']['arguments'] = {} 86 | tool_call_list.append(t_dict) 87 | self.tool_calls = tool_call_list 88 | else: 89 | self.tool_calls = None 90 | 91 | @model_validator(mode='after') 92 | def content_validation(self): 93 | if self.role == 'assistant': 94 | if (self.content is not None) and (not isinstance(self.content, str)) and (any(x.type != 'text' for x in self.content)): 95 | raise ValueError('"assistant" content must be string.') 96 | 97 | elif (self.content is None) and (not self.tool_calls): 98 | raise ValueError('No "tool_calls" or "content" for the assistant message.') 99 | 100 | elif(self.content is not None) and (self.tool_calls is not None): 101 | raise ValueError('Cannot have both tool_calls and content.') 102 | 103 | elif self.role != 'user': 104 | if self.content is None: 105 | raise ValueError(f'Content cannot be None for role "{self.role}".') 106 | elif (isinstance(self.content, list)) and (any(c.type != 'text' for c in self.content)): 107 | raise ValueError(f'Content can only be string for role "{self.role}".') 108 | 109 | return self 110 | 111 | # Model to format mapping 112 | model_to_format = { 113 | # Models using message_list_with_image format 114 | "idefics2": "message_list_with_image", 115 | "idefics3": "message_list_with_image_first", 116 | "aya_vision": "message_list_with_image", 117 | "mistral3": "message_list_with_image_first", 118 | "qwen2_vl": "message_list_with_image", 119 | "qwen2_5_vl": "message_list_with_image_first", 120 | "kimi_vl": "message_list_with_image", 121 | "llama4": "message_list_with_image", 122 | "smolvlm": "message_list_with_image_first", 123 | "llava": "message_list_with_image", 124 | "llava_next": "message_list_with_image", 125 | "mllama": "message_list_with_image", 126 | # Models using message_list_with_content_image format 127 | "internvl_chat": "message_list_with_image_type", 128 | "pixtral": "message_list_with_image_type", 129 | # Models using 130 | "gemma3": "message_with_start_image_token", 131 | # Models using \n 132 | "llava-qwen2": "message_with_image_token_new_line", 133 | "bunny-llama": "message_with_image_token_new_line", 134 | "deepseek_vl_v2": "message_with_image_token_new_line", 135 | # Models using message_with_image_token format 136 | "multi_modality": "message_with_image_token", 137 | # Models using <|image_i|> 138 | "phi3_v": "message_with_numbered_image_tokens", 139 | # Models using prompt_with_image_token format 140 | "paligemma": "prompt_with_image_token", 141 | # Models using prompt_only format 142 | 
"florence2": "prompt_only", 143 | "molmo": "prompt_only", 144 | } 145 | 146 | def convert_tool_to_json_schema(tool: Dict[str, Any]) -> Dict[str, Any]: 147 | params_json = tool.get('function', dict()).get('parameters') 148 | if params_json is not None: 149 | if 'title' not in params_json.keys(): 150 | params_json['title'] = tool['function']['name'].title() + 'Args' 151 | return params_json 152 | else: 153 | return tool 154 | 155 | def build_function_call_schema(tools: list[Dict[str, Any]]) -> Dict[str, Any]: 156 | defs = dict() 157 | ref_list = [] 158 | for tool in tools: 159 | tool_name = tool['function']['name'] 160 | tool_title = tool_name.title() 161 | defs[tool_title] = dict( 162 | properties=dict( 163 | name=dict(const=tool_name, title='Name', type='string'), 164 | arguments={'$ref': f'#/$defs/{tool_title}Args'} 165 | ), 166 | required=['name', 'arguments'], 167 | title=tool_title, 168 | type='object' 169 | ) 170 | defs[tool_title + 'Args'] = convert_tool_to_json_schema(tool) 171 | ref_list.append({'$ref': f'#/$defs/{tool_title}'}) 172 | 173 | schema = {'$defs': defs, 'anyOf': ref_list, 'title': 'FunctionCall'} if len(ref_list) > 1 else {'$defs': defs, '$ref': list(ref_list[0].values())[0], 'title': 'FunctionCall'} 174 | return schema 175 | 176 | def get_message_json(model_type: str, message: ChatMessage, image_count: int = 0) -> Dict[str, Optional[Union[str, List[Dict[str, str]]]]]: 177 | format_type = model_to_format.get(model_type.lower()) 178 | if not format_type: 179 | raise ValueError(f'Cannot find format type for model type "model_type".') 180 | 181 | if message.role != 'user': 182 | return message.model_dump() 183 | 184 | if format_type in ('message_list_with_image', "message_list_with_image_first"): 185 | if isinstance(message.content, str): 186 | content = [{"type": "text", "text": message.content}] 187 | return dict(role=message.role, content=content) 188 | elif isinstance(message.content, list): 189 | content = [{"type": "text", "text": c.text} if c.type == 'text' else {"type": "image"} for c in message.content] 190 | return dict(role=message.role, content=content) 191 | else: 192 | return message.model_dump() 193 | 194 | elif format_type == 'message_list_with_image_type': 195 | if isinstance(message.content, str): 196 | content = [{"type": "text", "content": message.content}] 197 | return dict(role=message.role, content=content) 198 | elif isinstance(message.content, list): 199 | content = [{"type": "text", "content": c.text} if c.type == 'text' else {"type": "image"} for c in message.content] 200 | return dict(role=message.role, content=content) 201 | else: 202 | return message.model_dump() 203 | 204 | elif format_type in ("message_with_start_image_token", "message_with_image_token_new_line", "message_with_image_token"): 205 | if format_type == 'message_with_start_image_token': 206 | image_token = '' 207 | elif format_type == 'message_with_image_token_new_line': 208 | image_token = '\n' 209 | else: 210 | image_token = '' 211 | 212 | if isinstance(message.content, str): 213 | return dict(role=message.role, content=message.content) 214 | elif isinstance(message.content, list): 215 | content = '' 216 | for c in message.content: 217 | if c.type == 'text': 218 | content += c.text 219 | else: 220 | content += image_token 221 | return dict(role=message.role, content=content) 222 | else: 223 | return message.model_dump() 224 | 225 | elif format_type == 'message_with_numbered_image_tokens': 226 | current_image_count = image_count + 1 227 | if isinstance(message.content, str): 228 
| return dict(role=message.role, content=message.content) 229 | elif isinstance(message.content, list): 230 | content = '' 231 | for c in message.content: 232 | if c.type == 'text': 233 | content += c.text 234 | else: 235 | content += f'<|image_{current_image_count}|> ' 236 | current_image_count += 1 237 | return dict(role=message.role, content=content) 238 | else: 239 | return message.model_dump() 240 | 241 | else: 242 | raise ValueError('Model chat template type not supported.') 243 | 244 | def convert_vision_message_list(messages: List[ChatMessage], model_type: str) -> List[Dict[str, Optional[Union[str, List[Dict[str, str]]]]]]: 245 | image_count = 0 246 | img_count_list = [] 247 | for m in messages: 248 | img_count_list.append(image_count) 249 | if isinstance(m.content, list): 250 | image_count += sum([1 if c.type != 'text' else 0 for c in m.content]) 251 | 252 | msg_dicts = [get_message_json(model_type, m, c) for m, c in zip(messages, img_count_list)] 253 | return msg_dicts 254 | 255 | 256 | DEFAULT_TEMPLATES = Literal['chatml', 'llama3', 'gemma', 'deepseek', 'openchat', 'phi'] 257 | 258 | class ChatTemplate: 259 | 260 | def __init__(self, 261 | tokenizer: "PreTrainedTokenizer", 262 | model_type: str, 263 | default_template: Optional[DEFAULT_TEMPLATES] = None, 264 | is_vision: bool = False, 265 | reasoning_parser: Optional[Literal['deepseek_r1']] = None 266 | ): 267 | from .chat_presets import PRESETS, PRESETS_EOT 268 | self._tokenizer = tokenizer 269 | self._model_type = model_type 270 | self._is_vision = is_vision 271 | self._reasoning_parser = reasoning_parser 272 | self._ban_none_content = False 273 | if default_template: 274 | if default_template in PRESETS.keys(): 275 | self.tokenizer.chat_template = PRESETS[default_template] 276 | self.eot = PRESETS_EOT[default_template] 277 | else: 278 | raise ValueError(f'Default chat template "{default_template}" does not exist.') 279 | elif self.tokenizer.chat_template is None: 280 | self.tokenizer.chat_template = PRESETS['chatml'] 281 | self.eot = PRESETS_EOT['chatml'] 282 | else: 283 | self.eot = self.tokenizer.eos_token 284 | 285 | @property 286 | def tokenizer(self) -> "PreTrainedTokenizer": 287 | return self._tokenizer 288 | 289 | @property 290 | def is_vision(self) -> bool: 291 | return self._is_vision 292 | 293 | @property 294 | def reasoning_start(self) -> Optional[str]: 295 | if self._reasoning_parser == 'deepseek_r1': 296 | return '<think>\n' 297 | 298 | @property 299 | def reasoning_end(self) -> Optional[str]: 300 | if self._reasoning_parser == 'deepseek_r1': 301 | return '\n</think>\n\n' 302 | 303 | @property 304 | def support_system(self) -> bool: 305 | if not hasattr(self, '_support_system'): 306 | try: 307 | system = 'Test system message, see if it exists.'
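# Probe the active chat template: render a dummy system + user exchange and
# check whether the system text survives in the rendered prompt. Templates
# that raise, or that silently drop the system turn, are treated as not
# supporting a system role; _validate_msg_seq then folds the system content
# into the first user message instead.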
308 | messages = [dict(role='system', content=system), dict(role='user', content='Hi there')] 309 | messages = [ChatMessage.model_validate(m).model_dump() for m in messages] 310 | prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) 311 | self._support_system = system in prompt 312 | except: 313 | self._support_system = False 314 | return self._support_system 315 | 316 | @property 317 | def allow_multiple_assistant(self) -> bool: 318 | if not hasattr(self, '_allow_multiple_assistant'): 319 | try: 320 | from .chat_presets import MSG_SINGLE_ASSISTANT 321 | messages = MSG_SINGLE_ASSISTANT * 2 322 | messages = [ChatMessage.model_validate(m).model_dump() for m in messages] 323 | self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False, continue_final_message=True) 324 | self._allow_multiple_assistant = True 325 | except: 326 | self._allow_multiple_assistant = False 327 | return self._allow_multiple_assistant 328 | 329 | @property 330 | def support_tool_call(self) -> bool: 331 | if not hasattr(self, '_support_tool_call'): 332 | try: 333 | from .chat_presets import MSG_WITH_TOOL, TOOL_LIST 334 | messages = MSG_WITH_TOOL[:1] 335 | messages = [ChatMessage.model_validate(m).model_dump() for m in messages] 336 | p_with_tool = self.tokenizer.apply_chat_template(messages, tools=TOOL_LIST, add_generation_prompt=True, tokenize=False) 337 | p_wo_tool = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) 338 | self._support_tool_call = p_with_tool != p_wo_tool 339 | except: 340 | self._support_tool_call = False 341 | return self._support_tool_call 342 | 343 | @property 344 | def tool_start(self) -> str: 345 | if not hasattr(self, '_tool_start'): 346 | if self.support_tool_call: 347 | from .chat_presets import MSG_WITH_TOOL, TOOL_LIST 348 | messages = [ChatMessage.model_validate(m).model_dump() for m in MSG_WITH_TOOL] 349 | try: 350 | p_with_tool = self.tokenizer.apply_chat_template(messages, tools=TOOL_LIST, add_generation_prompt=True, tokenize=False) 351 | except: 352 | messages = [self._content_to_str_with_tool_call(m) for m in messages] 353 | p_with_tool = self.tokenizer.apply_chat_template(messages, tools=TOOL_LIST, add_generation_prompt=True, tokenize=False) 354 | self._ban_none_content = True 355 | p_wo_tool = self.tokenizer.apply_chat_template(messages[:1], tools=TOOL_LIST, add_generation_prompt=True, tokenize=False) 356 | diff_str = p_with_tool.removeprefix(p_wo_tool) 357 | tool_first_index = diff_str.find('{') 358 | self._tool_start = diff_str[:tool_first_index] 359 | if self.reasoning_start: 360 | rchunk = self.reasoning_start + self.reasoning_end 361 | if rchunk in self._tool_start: 362 | self._tool_start = self._tool_start.split(rchunk)[-1] 363 | 364 | else: 365 | self._tool_start = '\n' 366 | return self._tool_start 367 | 368 | def _content_to_str_with_tool_call(self, message: Dict[str, Any]) -> Dict[str, Any]: 369 | if (message['role'] == 'assistant') and ('tool_calls' in message): 370 | if message.get('content', None) is None: 371 | message['content'] = '' 372 | return message 373 | 374 | def _validate_msg_seq(self, messages: List[Union[Dict[str, Any], ChatMessage]]) -> List[ChatMessage]: 375 | if len(messages) == 0: 376 | raise ValueError('Cannot have an empty message list.') 377 | msgs = [ChatMessage.model_validate(m) if not isinstance(m, ChatMessage) else m for m in messages] 378 | 379 | system_count = sum([1 if m.role == 'system' else 0 for m in msgs]) 380 | if system_count > 
1: 381 | raise ValueError('Cannot have more than one system messages.') 382 | elif (system_count == 1) and (msgs[0].role != 'system'): 383 | raise ValueError('If system messages is provided, it must be the first message.') 384 | elif (len(msgs) < 2 and (system_count == 1)): 385 | raise ValueError('Cannot have a messages list with only a system message.') 386 | 387 | if system_count and (not self.support_system): 388 | system = '\n' + msgs[0].content.strip() + '\n' 389 | if msgs[1].role != 'user': 390 | msgs = [ChatMessage(role='user', content=system)] + msgs[:1] 391 | elif isinstance(msgs[1].content, list): 392 | msgs[1].content = [TextContent(type='text', text=system + '\n\n')] + msgs[1].content 393 | else: 394 | msgs[1].content = system + '\n\n' + msgs[1].content 395 | 396 | msgs_wo_sys = msgs[1:] if system_count else msgs 397 | if (not self.allow_multiple_assistant) and (msgs_wo_sys[0].role != 'user'): 398 | msgs_wo_sys = [ChatMessage(role='user', content='')] + msgs_wo_sys 399 | 400 | last_role = None 401 | for i, m in enumerate(msgs_wo_sys): 402 | if not self.support_tool_call and m.role == 'tool': 403 | m.role = 'user' 404 | if isinstance(m.content, list): 405 | m.content = [TextContent(type='text', text='\n')] + m.content + [TextContent(type='text', text='\n')] 406 | elif isinstance(m.content, str): 407 | m.content = '\n' + m.content + '\n' 408 | 409 | if (not self.is_vision) and (isinstance(m.content, list)): 410 | raise ValueError('Text only template only support string contents.') 411 | 412 | if (not self.allow_multiple_assistant) and (m.role == last_role): 413 | raise ValueError('Current chat template only support user/assistant/user/assistant message sequences.') 414 | last_role = m.role 415 | 416 | msgs = [msgs[0]] + msgs_wo_sys if system_count else msgs_wo_sys 417 | return msgs 418 | 419 | def apply_chat_template(self, 420 | messages: List[Union[Dict[str, Any], ChatMessage]], 421 | tools: Optional[List[Dict[str, Any]]] = None, 422 | tool_choice: Union[Literal['none', 'auto', 'required'], Dict[str, Union[str, Dict[str, str]]]] = 'auto', 423 | reasoning: bool = False, 424 | add_generation_prompt: bool = True 425 | ) -> Tuple[str, Optional[List[str]]]: 426 | import json 427 | from .chat_presets import DEFAULT_TOOL_SYSTEM 428 | msgs = self._validate_msg_seq(messages) 429 | images = [] 430 | for m in msgs: 431 | if isinstance(m.content, list): 432 | for c in m.content: 433 | if isinstance(c, ImageContent): 434 | if isinstance(c.image_url, ImageURL): 435 | image_str = c.image_url.url 436 | else: 437 | image_str = c.image_url 438 | images.append(image_str) 439 | images = images if images else None 440 | 441 | 442 | msgs = convert_vision_message_list(msgs, self._model_type) if self.is_vision else [m.model_dump() for m in msgs] 443 | tools = tools if tool_choice != 'none' else None 444 | if tools: 445 | [OpenAIToolSchema.model_validate(t) for t in tools] 446 | continue_final_message = False 447 | if msgs[-1]['role'] =='assistant': 448 | continue_final_message = True 449 | add_generation_prompt = False 450 | if tools and (not self.support_tool_call): 451 | tool_json = '\n'.join([json.dumps(tool) for tool in tools]) 452 | if msgs[0]['role'] == 'system': 453 | content = msgs[0].get('content', '') + DEFAULT_TOOL_SYSTEM.replace('$$tool_list$$', tool_json) 454 | msgs[0]['content'] = content.strip() 455 | else: 456 | msgs = [dict(role='system', content=DEFAULT_TOOL_SYSTEM.replace('$$tool_list$$', tool_json).strip())] + msgs 457 | 458 | last_role = msgs[-1]['role'] 459 | 460 | if 
self._ban_none_content: 461 | msgs = [self._content_to_str_with_tool_call(m) for m in msgs] 462 | 463 | prompt = self.tokenizer.apply_chat_template(conversation=msgs, 464 | tools=tools if self.support_tool_call else None, tokenize=False, 465 | add_generation_prompt=add_generation_prompt, 466 | continue_final_message=continue_final_message) 467 | 468 | if last_role == 'user': 469 | if reasoning and self.reasoning_start: 470 | if not prompt.rstrip().endswith(self.reasoning_start.rstrip()): 471 | prompt += self.reasoning_start 472 | elif self.reasoning_start: 473 | if not prompt.rstrip().endswith(self.reasoning_start.rstrip()): 474 | prompt += self.reasoning_start 475 | prompt += self.reasoning_end 476 | 477 | if (tool_choice == 'required') or isinstance(tool_choice, dict): 478 | prompt += self.tool_start 479 | 480 | return prompt, images 481 | 482 | 483 | 484 | 485 | 486 | 487 | 488 | 489 | 490 | 491 | 492 | 493 | 494 | -------------------------------------------------------------------------------- /src/mlx_textgen/sampling_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Optional, Tuple, TYPE_CHECKING 2 | from pydantic import BaseModel, Field, model_validator 3 | import mlx.core as mx 4 | 5 | if TYPE_CHECKING: 6 | from mlx.core import array 7 | try: 8 | from outlines.processors.base_logits_processor import OutlinesLogitsProcessor 9 | from outlines.processors.structured import JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor 10 | from outlines.models.transformers import TransformerTokenizer 11 | except: 12 | pass 13 | 14 | 15 | def apply_repetition_penalty(logits: "array", token_ids: "array", penalty: float = 1.0, context_size: Optional[int] = None, row_range: Optional["array"] = None) -> "array": 16 | """Applies a repetition penalty to the logits based on previously seen tokens. 17 | 18 | This penalty discourages the model from repeating tokens that have 19 | already appeared in the sequence. 20 | 21 | Args: 22 | logits (array): The logits to adjust. 23 | token_ids (array): The previously generated token ids. 24 | penalty (float, optional): The penalty to apply. Should be >= 1.0. 25 | Values > 1 penalize repetition, values < 1 encourage it. Default: 1.0. 26 | context_size (Optional[int], optional): The context size to consider for the repetition penalty. 27 | If None, the entire history is used. Default: None. 28 | row_range (Optional[array], optional): An optional array of indices to use for selecting logits. Defaults to None. 29 | 30 | Returns: 31 | array: The adjusted logits. 32 | """ 33 | if penalty == 1: 34 | return logits 35 | 36 | tokens = token_ids if context_size is None else token_ids[:, -context_size:] 37 | if tokens.shape[1] > 0: 38 | rows = mx.arange(tokens.shape[0])[:, mx.newaxis] if row_range is None else row_range 39 | selected_logits = logits[rows, tokens] 40 | selected_logits = mx.where( 41 | selected_logits < 0, selected_logits * penalty, selected_logits / penalty 42 | ) 43 | logits[rows, tokens] = selected_logits 44 | return logits 45 | 46 | def apply_presence_penalty(logits: "array", token_ids: "array", penalty: float = 0.0, context_size: Optional[int] = None, row_range: Optional["array"] = None) -> "array": 47 | """Applies a presence penalty to the logits based on previously seen tokens. 48 | 49 | This penalty adds a fixed bias to the logits of tokens that have 50 | already appeared in the sequence. This encourages the model to 51 | generate novel tokens. 
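    For example (an illustrative sketch): with penalty=1.5, every token id that
    already appears in token_ids has 1.5 subtracted from its logit before
    sampling, regardless of how many times that token occurred; compare
    apply_frequency_penalty below, which scales the subtraction with the square
    root of the occurrence count.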
52 | 53 | Args: 54 | logits (array): The logits to adjust. 55 | token_ids (array): The previously generated token ids. 56 | penalty (float, optional): The penalty to apply. Should be >= 0. 57 | Values > 0 penalize presence, values < 0 encourage it. Default: 0.0. 58 | context_size (Optional[int], optional): The context size to consider for the presence penalty. 59 | If None, the entire history is used. Default: None. 60 | row_range (Optional[array], optional): An optional array of indices to use for selecting logits. Defaults to None. 61 | 62 | Returns: 63 | array: The adjusted logits. 64 | """ 65 | if penalty == 0: 66 | return logits 67 | 68 | tokens = token_ids if context_size is None else token_ids[:, -context_size:] 69 | if tokens.shape[1] > 0: 70 | rows = mx.arange(tokens.shape[0])[:, mx.newaxis] if row_range is None else row_range 71 | logits[rows, tokens] -= penalty 72 | return logits 73 | 74 | 75 | def apply_frequency_penalty(logits: "array", token_ids: "array", penalty: float = 0.0, context_size: Optional[int] = None, vocab_ids: Optional["array"] = None) -> "array": 76 | """Applies a frequency penalty to the logits based on the frequency of previously seen tokens. 77 | 78 | Args: 79 | logits (array): The logits to adjust. 80 | token_ids (array): The previously generated token ids. 81 | penalty (float, optional): The penalty to apply. Should be greater than or equal to 0. 82 | Values > 0 discourage frequent tokens, while values < 0 encourage them. Default: 0.0. 83 | context_size (Optional[int, optional): The context size to consider for the frequency penalty. 84 | If None, the entire history is used. Default: None. 85 | vocab_ids (Optional[array], optional): An optional array of vocab ids to use for calculating the frequency. Defaults to None. 86 | 87 | Returns: 88 | array: The adjusted logits. 89 | """ 90 | if penalty == 0: 91 | return logits 92 | 93 | tokens = token_ids if context_size is None else token_ids[:, -context_size:] 94 | 95 | if tokens.size == 0: 96 | return logits 97 | 98 | tids = mx.arange(logits.shape[1], dtype=tokens.dtype) if vocab_ids is None else vocab_ids 99 | frequency_factor = (tokens[..., mx.newaxis] == tids).sum(axis=1) ** 0.5 100 | 101 | logits -= frequency_factor * penalty 102 | 103 | return logits 104 | 105 | def create_logit_bias_args(logit_bias: Dict[int, float]) -> Dict[str, "array"]: 106 | """Creates arguments for applying a logit bias. 107 | 108 | Args: 109 | logit_bias (Dict[int, float]): A dictionary mapping token ids to bias values. 110 | 111 | Returns: 112 | Dict[str, "array"]: A dictionary containing the logit keys and biases as mlx arrays. 113 | """ 114 | args = dict( 115 | logit_key=mx.array(list(logit_bias.keys())), 116 | logit_bias=mx.array(list(logit_bias.values())) 117 | ) 118 | return args 119 | 120 | 121 | def apply_logit_bias(logits: "array", logit_keys: "array", logit_bias: "array") -> "array": 122 | """Applies a logit bias to the logits. 123 | 124 | Args: 125 | logits (array): The logits to adjust. 126 | logit_keys (array): The token ids to apply the bias to. 127 | logit_bias (array): The bias values to apply to the logits. 128 | 129 | Returns: 130 | array: The adjusted logits. 131 | """ 132 | if logit_bias is None: 133 | return logits 134 | 135 | logits[:, logit_keys] += logit_bias 136 | return logits 137 | 138 | def apply_temperature(logits: "array", temperature: float = 0.0, row_range: Optional["array"] = None) -> "array": 139 | """Applies temperature scaling to the logits. 
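    For example (illustrative): dividing by temperature=0.5 doubles every
    logit, which sharpens the distribution, while temperature=2.0 halves every
    logit, which flattens it.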
140 | 141 | If temperature is 0, performs greedy sampling by setting the probability 142 | of the most likely token to 1 and all others to 0. Otherwise, divides 143 | the logits by the temperature. 144 | 145 | Args: 146 | logits (array): The logits to adjust. 147 | temperature (float, optional): The temperature to apply. If 0, performs greedy sampling. 148 | Defaults to 0.0. 149 | row_range (Optional["array"], optional): An optional array of indices to use for selecting logits. 150 | Defaults to None. 151 | Returns: 152 | array: The adjusted logits. 153 | """ 154 | if temperature != 0: 155 | return logits / temperature 156 | 157 | else: 158 | from mlx.core import arange, inf 159 | indices = logits.argmax(axis=1).reshape(-1, 1) 160 | if row_range is None: 161 | rows = arange(logits.shape[0]).reshape(-1, 1) 162 | else: 163 | rows = row_range 164 | logits[:, :] = -inf 165 | logits[rows, indices] = 1 166 | return logits 167 | 168 | def apply_top_k(logits: "array", top_k: Optional[int] = None, row_range: Optional["array"] = None) -> "array": 169 | """Applies top-k filtering to the logits. 170 | 171 | This keeps only the top k tokens with the highest probabilities and 172 | sets the probabilities of the remaining tokens to negative infinity, 173 | effectively removing them from consideration. 174 | 175 | Args: 176 | logits (array): The logits to adjust. 177 | top_k (Optional[int, optional): The number of top tokens to keep. 178 | If None, no filtering is applied. Defaults to None. 179 | row_range (Optional["array"], optional): An optional array of indices to use for selecting logits. 180 | Defaults to None. 181 | 182 | Returns: 183 | array: The adjusted logits. 184 | """ 185 | if (top_k is None) or (logits.shape[1] < top_k): 186 | return logits 187 | 188 | rows = mx.arange(logits.shape[0]).reshape(-1, 1) if row_range is None else row_range 189 | token_sorted = mx.argsort(-logits) 190 | logits[rows, token_sorted[:, top_k:]] = -mx.inf 191 | return logits 192 | 193 | def apply_top_p(logits: "array", top_p: float = 1.0, is_prob: bool = False, row_range: Optional["array"] = None) -> "array": 194 | """Applies top-p filtering to the logits. 195 | 196 | This keeps only the tokens with a cumulative probability above a 197 | certain threshold (top_p) and sets the probabilities of the remaining 198 | tokens to zero, effectively removing them from consideration. 199 | 200 | Args: 201 | logits (array): The logits to adjust. 202 | top_p (float, optional): The cumulative probability threshold. 203 | Should be between 0 and 1. If 1, no filtering is applied. 204 | Defaults to 1.0. 205 | is_prob (bool, optional): Whether the input is probabilities or logits. 206 | If False, softmax is applied to the logits before filtering. 207 | Defaults to False. 208 | row_range (Optional["array"], optional): An optional array of indices to use for selecting logits. 209 | Defaults to None. 210 | 211 | Returns: 212 | array: The adjusted logits or probabilities. 
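    Example (a minimal sketch; the numbers are arbitrary):

        import mlx.core as mx
        logits = mx.array([[2.0, 1.0, 0.1, -1.0]])
        probs = apply_top_p(logits, top_p=0.9)
        # tokens outside the top-p nucleus now carry zero probability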
213 | """ 214 | if top_p == 1: 215 | return logits if is_prob else mx.softmax(logits, axis=-1) 216 | 217 | rows = mx.arange(logits.shape[0]).reshape(-1, 1) if row_range is None else row_range 218 | probs = mx.softmax(logits, axis=-1) if not is_prob else logits 219 | token_sorted = mx.argsort(probs) 220 | sorted_probs = probs[rows, token_sorted] 221 | cumulative_probs = mx.cumsum(sorted_probs, axis=-1) 222 | top_probs = mx.where( 223 | cumulative_probs > 1 - top_p, 224 | sorted_probs, 225 | 0, 226 | ) 227 | probs[rows, token_sorted] = top_probs 228 | return probs 229 | 230 | def apply_min_p(logits: "array", min_p: float = 0.0, is_prob: bool = False, min_tokens_to_keep: int = 1, row_range: Optional["array"] = None) -> "array": 231 | """Applies min-p filtering to the logits. 232 | 233 | This keeps only the tokens with a probability above a 234 | certain threshold (min_p * max(probs)) and sets the probabilities 235 | of the remaining tokens to zero, effectively removing them from 236 | consideration. A minimum number of tokens are always kept to 237 | avoid completely removing all candidates. 238 | 239 | Args: 240 | logits (array): The logits to adjust. 241 | min_p (float, optional): The minimum probability threshold, 242 | as a fraction of the most likely probability. Should be 243 | between 0 and 1. If 0, no filtering is applied. 244 | Defaults to 0.0. 245 | is_prob (bool, optional): Whether the input is probabilities or 246 | logits. If False, softmax is applied to the logits before 247 | filtering. Defaults to False. 248 | min_tokens_to_keep (int, optional): The minimum number of tokens 249 | to keep, regardless of their probability. Defaults to 1. 250 | row_range (Optional["array"], optional): An optional array of 251 | indices to use for selecting logits. Defaults to None. 252 | 253 | Returns: 254 | array: The adjusted logits or probabilities. 255 | """ 256 | if min_p == 0: 257 | return logits if is_prob else mx.softmax(logits, axis=-1) 258 | 259 | rows = mx.arange(logits.shape[0]).reshape(-1, 1) if row_range is None else row_range 260 | probs = mx.softmax(logits, axis=-1) if not is_prob else logits 261 | token_sorted = mx.argsort(-probs) 262 | sorted_probs = probs[rows, token_sorted] 263 | top_probs = probs.max(axis=-1).reshape(-1, 1) 264 | scaled_min_p = min_p * top_probs 265 | tokens_to_remove = sorted_probs < scaled_min_p 266 | tokens_to_remove[..., :min_tokens_to_keep] = False 267 | selected_probs = mx.where(tokens_to_remove, 0, sorted_probs) 268 | probs[rows, token_sorted] = selected_probs 269 | return probs 270 | 271 | def create_json_logit_processor(json_schema: Dict[str, Any], tokenizer: "TransformerTokenizer", whitespace_pattern: Optional[str] = None) -> "JSONLogitsProcessor": 272 | """Creates a JSON logits processor for constrained generation. 273 | 274 | The JSON logits processor ensures that the generated text conforms to the 275 | provided JSON schema. 276 | 277 | Args: 278 | json_schema (Dict[str, Any]): The JSON schema to constrain the generation. 279 | tokenizer (TransformerTokenizer): The tokenizer used by the model. 280 | whitespace_pattern (Optional[str], optional): A regex pattern defining what constitutes a whitespace. 281 | Defaults to None which uses the tokenizer's default whitespace pattern. 282 | 283 | Returns: 284 | JSONLogitsProcessor: A JSON logits processor instance. 
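    Example (a minimal sketch; outlines_tokenizer is assumed to be an
    already-constructed TransformerTokenizer wrapping the model's tokenizer):

        schema = {
            "title": "Person",
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name", "age"],
        }
        processor = create_json_logit_processor(schema, tokenizer=outlines_tokenizer)
        # the processor is then applied to the logits at every decoding step,
        # as done in Sampler.sample below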
285 | """ 286 | from outlines.processors.structured import JSONLogitsProcessor 287 | return JSONLogitsProcessor(schema=json_schema, tokenizer=tokenizer, whitespace_pattern=whitespace_pattern) 288 | 289 | def create_regex_logit_processor(regex_pattern: str, tokenizer: "TransformerTokenizer") -> "RegexLogitsProcessor": 290 | """Creates a regex logits processor for constrained generation. 291 | 292 | The regex logits processor ensures that the generated text conforms to the 293 | provided regex pattern. 294 | 295 | Args: 296 | regex_pattern (str): The regex pattern to constrain the generation. 297 | tokenizer (TransformerTokenizer): The tokenizer used by the model. 298 | 299 | Returns: 300 | RegexLogitsProcessor: A regex logits processor instance. 301 | """ 302 | from outlines.processors.structured import RegexLogitsProcessor 303 | return RegexLogitsProcessor(regex_string=regex_pattern, tokenizer=tokenizer) 304 | 305 | def create_choice_logit_processor(choices: List[str], tokenizer: "TransformerTokenizer") -> "RegexLogitsProcessor": 306 | """Creates a choice logits processor for constrained generation. 307 | 308 | The choice logits processor ensures that the generated text is one of the 309 | provided choices. 310 | 311 | Args: 312 | choices (List[str): A list of strings representing the possible choices. 313 | tokenizer (TransformerTokenizer): The tokenizer used by the model. 314 | 315 | Returns: 316 | RegexLogitsProcessor: A regex logits processor instance. 317 | """ 318 | regex_pattern = r"(" + r"|".join(list(set(choices))) + r")" 319 | return create_regex_logit_processor(regex_pattern=regex_pattern, tokenizer=tokenizer) 320 | 321 | def create_cfg_logit_processor(cfg_str: str, tokenizer: "TransformerTokenizer") -> "CFGLogitsProcessor": 322 | """Creates a CFG logits processor for constrained generation. 323 | 324 | The CFG logits processor ensures that the generated text conforms to the 325 | provided Context-Free Grammar (CFG). 326 | 327 | Args: 328 | cfg_str (str): The CFG string to constrain the generation. 329 | tokenizer (TransformerTokenizer): The tokenizer used by the model. 330 | 331 | Returns: 332 | CFGLogitsProcessor: A CFG logits processor instance. 
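    Example (an illustrative sketch; the grammar is a small Lark-style grammar
    and outlines_tokenizer is assumed to be a TransformerTokenizer):

        arithmetic_grammar = '''
        ?start: expression
        ?expression: term (("+" | "-") term)*
        ?term: factor (("*" | "/") factor)*
        ?factor: NUMBER | "(" expression ")"
        %import common.NUMBER
        '''
        processor = create_cfg_logit_processor(arithmetic_grammar, tokenizer=outlines_tokenizer)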
333 | """ 334 | from outlines.processors.structured import CFGLogitsProcessor 335 | return CFGLogitsProcessor(cfg_str=cfg_str, tokenizer=tokenizer) 336 | 337 | class SamplingParams(BaseModel): 338 | temperature: float = Field(ge = 0.0, default= 0.0) 339 | top_k: Optional[int] = Field(gt = 0, default = None) 340 | top_p: float = Field(gt = 0.0, le = 1.0, default = 1.0) 341 | min_p: float = Field(ge = 0.0, le = 1.0, default = 0.0) 342 | stop: List[str] = Field(default_factory=list) 343 | max_completion_tokens: int = Field(gt = 0, default=4096) 344 | max_reasoning_tokens: int = Field(ge = 0, default=0) 345 | min_tokens_to_keep: int = Field(gt = 0, default = 1) 346 | frequency_penalty: float = Field(ge = -2.0, le = 2.0, default = 0.0) 347 | presence_penalty: float = Field(ge = -2.0, le = 2.0, default = 0.0) 348 | repetition_penalty: float = Field(gt = 0.0, default = 1.0) 349 | penalty_context_size: Optional[int] = Field(gt = 0, default = 1000) 350 | logit_bias: Optional[Dict[int, float]] = Field(default = None) 351 | seed: Optional[int] = Field(default = None) 352 | guided_json: Optional[Dict[str, Any]] = None 353 | guided_choice: Optional[List[str]] = None 354 | guided_regex: Optional[str] = None 355 | guided_grammar: Optional[str] = None 356 | whitespace_pattern: Optional[str] = None 357 | logprobs: bool = False 358 | top_logprobs: int = Field(gt=0, default=4) 359 | 360 | @model_validator(mode='after') 361 | def validate_single_penalty_type(self) -> "SamplingParams": 362 | """ 363 | Ensures that only one of frequency_penalty, presence_penalty, 364 | or repetition_penalty is active (not at its default/neutral value). 365 | """ 366 | active_penalties = 0 367 | 368 | if self.frequency_penalty != 0.0: 369 | active_penalties += 1 370 | if self.presence_penalty != 0.0: 371 | active_penalties += 1 372 | if self.repetition_penalty != 1.0: # Default is 1.0 (no penalty) 373 | active_penalties += 1 374 | 375 | if active_penalties > 1: 376 | raise ValueError( 377 | "Only one of 'frequency_penalty', 'presence_penalty', or 'repetition_penalty' " 378 | "can be active (i.e., set to a non-default value) at a time." 379 | ) 380 | return self 381 | 382 | @model_validator(mode='after') 383 | def validate_single_guided_decoding_processor(self) -> "SamplingParams": 384 | active_processor = 0 385 | 386 | for p in [self.guided_json, self.guided_choice, self.guided_regex, self.guided_grammar]: 387 | if p is not None: 388 | active_processor += 1 389 | 390 | if active_processor > 1: 391 | raise ValueError( 392 | 'Only one of "guided_json", "guided_choice", "guided_regex", "guided_grammar" can be used at a time.' 393 | ) 394 | 395 | if active_processor > 0: 396 | from importlib.util import find_spec 397 | if find_spec("outlines") is None: 398 | raise ValueError('Guided decoding is not supported as Outlines is not installed. Please install Outlines with `pip install outlines` to enable guided decoding.') 399 | 400 | return self 401 | 402 | class Sampler: 403 | """Implements the sampling procedure for the model. 404 | """ 405 | def __init__(self, params: SamplingParams, tokenizer: Optional["TransformerTokenizer"] = None) -> None: 406 | """Initializes the Sampler with the given parameters and tokenizer. 407 | 408 | Args: 409 | params (SamplingParams): The sampling parameters to use. 410 | tokenizer (Optional["TransformerTokenizer"], optional): The tokenizer used by the model. Defaults to None. 
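        Example (a minimal sketch; no guided decoding is requested, so no
        tokenizer needs to be supplied):

            params = SamplingParams(temperature=0.7, top_p=0.9, repetition_penalty=1.1)
            sampler = Sampler(params)
            # inside the decode loop:
            # new_tokens, logprobs = sampler.sample(logits, token_ids, start_index)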
411 | """ 412 | self.params = params 413 | self.tokenizer = tokenizer 414 | 415 | if self.params.logit_bias: 416 | self.logit_bias_args = create_logit_bias_args(self.params.logit_bias) 417 | 418 | if self.params.seed is not None: 419 | mx.random.seed(self.params.seed) 420 | 421 | if self.params.guided_json: 422 | self.structured_processor = create_json_logit_processor(self.params.guided_json, tokenizer=tokenizer, whitespace_pattern=self.params.whitespace_pattern) 423 | elif self.params.guided_choice: 424 | self.structured_processor = create_choice_logit_processor(self.params.guided_choice, tokenizer=tokenizer) 425 | elif self.params.guided_regex: 426 | self.structured_processor = create_regex_logit_processor(self.params.guided_regex, tokenizer=tokenizer) 427 | elif self.params.guided_grammar: 428 | self.structured_processor = create_cfg_logit_processor(self.params.guided_grammar, tokenizer=tokenizer) 429 | else: 430 | self.structured_processor = None 431 | 432 | def sample(self, logits: "array", token_ids: "array", start_index: int) -> Tuple["array", "array"]: 433 | """Samples the next token based on the logits. 434 | 435 | Applies various sampling techniques such as guided decoding, logit bias, 436 | frequency/presence/repetition penalties, temperature scaling, top-k, 437 | top-p, and min-p filtering. 438 | 439 | Args: 440 | logits (array): The logits to sample from. 441 | token_ids (array): The previously generated token ids. 442 | start_index (int): The starting index of the current generation step. 443 | 444 | Returns: 445 | Tuple[array, array]: A tuple containing the new tokens and the log probabilities. 446 | """ 447 | if not hasattr(self, 'vocab_ids'): 448 | self.vocab_ids = mx.arange(logits.shape[1], dtype=token_ids.dtype) 449 | 450 | if not hasattr(self, 'row_range'): 451 | self.row_range = mx.arange(logits.shape[0]).reshape(-1, 1) 452 | 453 | if self.row_range.shape[0] != logits.shape[0]: 454 | self.row_range = mx.arange(logits.shape[0]).reshape(-1, 1) 455 | 456 | if self.structured_processor: 457 | logits = self.structured_processor(input_ids=token_ids[:, start_index:], logits=logits) 458 | 459 | if self.params.logit_bias: 460 | logits = apply_logit_bias(logits, **self.logit_bias_args) 461 | 462 | if self.params.frequency_penalty != 0: 463 | logits = apply_frequency_penalty( 464 | logits=logits, 465 | token_ids=token_ids, 466 | penalty=self.params.frequency_penalty, 467 | context_size=self.params.penalty_context_size, 468 | vocab_ids=self.vocab_ids 469 | ) 470 | 471 | elif self.params.presence_penalty != 0: 472 | logits = apply_presence_penalty( 473 | logits=logits, 474 | token_ids=token_ids, 475 | penalty=self.params.presence_penalty, 476 | context_size=self.params.penalty_context_size, 477 | row_range=self.row_range 478 | ) 479 | 480 | elif self.params.repetition_penalty != 1: 481 | logits = apply_repetition_penalty( 482 | logits=logits, 483 | token_ids=token_ids, 484 | penalty=self.params.repetition_penalty, 485 | context_size=self.params.penalty_context_size, 486 | row_range=self.row_range 487 | ) 488 | 489 | logits = apply_temperature(logits=logits, temperature=self.params.temperature, row_range=self.row_range) 490 | logits = apply_top_k(logits=logits, top_k=self.params.top_k, row_range=self.row_range) 491 | 492 | logprobs = logits - mx.logsumexp(logits, axis=-1).reshape(-1, 1) 493 | 494 | probs = apply_min_p(logits=logits, min_p=self.params.min_p, is_prob=False, min_tokens_to_keep=self.params.min_tokens_to_keep, row_range=self.row_range) 495 | probs = 
apply_top_p(logits=probs, is_prob=True, top_p=self.params.top_p, row_range=self.row_range) 496 | 497 | new_tokens = mx.random.categorical(mx.log(probs), axis=-1).reshape(-1, 1) 498 | 499 | return new_tokens, logprobs 500 | 501 | 502 | 503 | -------------------------------------------------------------------------------- /src/mlx_textgen/engine.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional, Any, Literal, Union, Iterator, TYPE_CHECKING 2 | from .chat_utils import DEFAULT_TEMPLATES 3 | from pydantic import BaseModel 4 | if TYPE_CHECKING: 5 | from logging import Logger 6 | from .model_utils import LLMModel 7 | from .generation_utils import TextCompletionOutput, ChatCompletionOutput 8 | 9 | def print_images_list(images: List): 10 | new = [] 11 | for i in images: 12 | if isinstance(i, list): 13 | new.append(print_images_list(i)) 14 | elif isinstance(i, str): 15 | new.append(i[:50]) 16 | else: 17 | new.append(i) 18 | return new 19 | 20 | class ModelConfig(BaseModel): 21 | model_id_or_path: str 22 | tokenizer_repo_or_path: Optional[str] = None 23 | model_kwargs: Optional[Dict[str, Any]] = None 24 | tokenizer_kwargs: Optional[Dict[str, Any]] = None 25 | model_name: Optional[str] = None 26 | enable_cache: bool = True 27 | preprocess_batch_size: int = 512 28 | extra_stop_words: Optional[Union[str, List[str]]] = None 29 | reasoning_parser: Optional[Literal['deepseek_r1']] = None 30 | default_template: Optional[DEFAULT_TEMPLATES] = None 31 | 32 | def get_model_name(model_id_or_path: str, model_name: Optional[str] = None) -> str: 33 | if model_name: 34 | return model_name 35 | name = model_id_or_path.split('/')[-1].split('\\')[-1] 36 | name = name.lower().replace('_', '-') 37 | if 'bit' in name.split('-')[-1]: 38 | name = '-'.join(name.split('-')[:-1]) 39 | return name 40 | 41 | class InferenceEngine: 42 | 43 | def __init__(self, 44 | model_configs: List[ModelConfig], 45 | min_tokens: int = 20, 46 | max_reprocess_tokens: int = 250, 47 | replace_threshold: float = 0.95, 48 | max_capacity: int = 50, 49 | use_reasoning_content: bool = False, 50 | logger: Optional["Logger"] = None 51 | ): 52 | self._logger = logger 53 | self._use_reasoning_content = use_reasoning_content 54 | from .model_utils import make_model_exist 55 | import time 56 | start = time.perf_counter() 57 | self._model_dict = dict() 58 | for mc in model_configs: 59 | model_name = get_model_name(mc.model_id_or_path, mc.model_name) 60 | if model_name not in self._model_dict: 61 | self._model_dict[model_name] = mc 62 | kwargs = mc.model_kwargs if mc.model_kwargs else dict() 63 | make_model_exist(model_id_or_path=mc.model_id_or_path, **kwargs) 64 | else: 65 | msg = f'More than one model is named as "{model_name}". Please set the "model_name" argument differently for these models.' 66 | self.log(msg, 'error') 67 | ValueError(msg) 68 | self._model = None 69 | self._cache_manage_config = dict( 70 | min_tokens=min_tokens, 71 | max_reprocess_tokens=max_reprocess_tokens, 72 | replace_threshold=replace_threshold, 73 | max_capacity=max_capacity 74 | ) 75 | end = time.perf_counter() 76 | self.log(f'All models prepared locally. Time taken: {end - start:.3f}s.') 77 | 78 | def log(self, msg: str, level: Literal["error", "warning", "info", "debug"] = "info") -> None: 79 | """Logs a message to the logger at the specified level. 80 | 81 | Args: 82 | msg (str): The message to log. 83 | level (Literal["error", "warning", "info", "debug"], optional): The logging level. 
Defaults to "info". 84 | """ 85 | levels = dict( 86 | error=40, 87 | warning=30, 88 | info=20, 89 | debug=10 90 | ) 91 | if self._logger: 92 | self._logger.log(level=levels.get(level), msg=msg) 93 | 94 | @property 95 | def model(self) -> Optional["LLMModel"]: 96 | return self._model 97 | 98 | @property 99 | def model_dict(self) -> Dict[str, ModelConfig]: 100 | return self._model_dict 101 | 102 | @property 103 | def model_info(self) -> List[Dict[str, Optional[Union[str, int, List, Dict[str, Any]]]]]: 104 | if not hasattr(self, '_model_info'): 105 | self._model_info = [self._get_model_info(k, v) for k, v in self.model_dict.items()] 106 | return self._model_info 107 | 108 | def _get_model_info(self, key: str, config: ModelConfig) -> Dict[str, Optional[Union[str, int, List, Dict[str, Any]]]]: 109 | from datetime import datetime as dt 110 | config = dict( 111 | id=key, 112 | object='model', 113 | created=int(dt.now().timestamp()), 114 | owned_by=None, 115 | permission=[], 116 | root='root', 117 | info=dict( 118 | tokenizer_id=config.tokenizer_repo_or_path if config.tokenizer_repo_or_path else config.model_id_or_path, 119 | tokenizer_kwargs=config.tokenizer_kwargs 120 | ) 121 | ) 122 | return config 123 | 124 | def load_model(self, model_name: str) -> None: 125 | import time 126 | 127 | start = time.perf_counter() 128 | if model_name not in self._model_dict: 129 | error = f'Model "{model_name}" does not exist.' 130 | self.log(error, 'error') 131 | raise ValueError(error) 132 | 133 | if self.model and (self.model.model_name == model_name): 134 | return 135 | 136 | elif self.model: 137 | self.model.unload() 138 | del self._model 139 | self._model = None 140 | 141 | from .model_utils import LLMModel 142 | mc = self.model_dict[model_name] 143 | self._model = LLMModel( 144 | model_id_or_path=mc.model_id_or_path, 145 | tokenizer_repo_or_path=mc.tokenizer_repo_or_path, 146 | model_kwargs=mc.model_kwargs, 147 | tokenizer_kwargs=mc.tokenizer_kwargs, 148 | model_name=mc.model_name, 149 | logger=self._logger, 150 | enable_cache=mc.enable_cache, 151 | cache_manage_config=self._cache_manage_config, 152 | preprocess_batch_size=mc.preprocess_batch_size, 153 | extra_stop_words=mc.extra_stop_words, 154 | reasoning_parser=mc.reasoning_parser, 155 | default_template=mc.default_template 156 | ) 157 | end = time.perf_counter() 158 | self.log(f'Model "{model_name}" loaded. 
Time taken: {end - start:.3f}s.') 159 | 160 | def generate(self, 161 | model: str, 162 | prompt: Union[str, List[str]], 163 | images: Optional[List[Optional[List[str]]]] = None, 164 | logit_bias: Optional[Dict[str, int]] = None, 165 | logprobs: Optional[int] = None, 166 | stream: bool = False, 167 | n: int = 1, 168 | max_tokens: int = 4096, 169 | stop: Optional[List[str]] = None, 170 | seed: Optional[int] = None, 171 | presence_penalty: float = 0.0, 172 | frequency_penalty: float = 0.0, 173 | repetition_penalty: float = 1.0, 174 | penalty_context_size: Optional[int] = 1000, 175 | temperature: float = 0.0, 176 | top_k: Optional[int] = None, 177 | top_p: float = 1.0, 178 | min_p: float = 0.0, 179 | min_tokens_to_keep: int = 1, 180 | guided_json: Optional[Dict[str, Any]] = None, 181 | guided_choice: Optional[List[str]] = None, 182 | guided_regex: Optional[str] = None, 183 | guided_grammar: Optional[str] = None, 184 | whitespace_pattern: Optional[str] = None, 185 | **kwargs 186 | ) -> Union["TextCompletionOutput", Iterator["TextCompletionOutput"]]: 187 | from .sampling_utils import SamplingParams 188 | from datetime import datetime as dt 189 | import uuid 190 | from .generation_utils import TextCompletionOutput, to_completion_logprobs 191 | self.load_model(model) 192 | params = SamplingParams( 193 | temperature=temperature, 194 | top_k=top_k, 195 | top_p=top_p, 196 | min_p=min_p, 197 | stop=stop if stop else [], 198 | max_completion_tokens=max_tokens, 199 | max_reasoning_tokens=0, 200 | min_tokens_to_keep=min_tokens_to_keep, 201 | frequency_penalty=frequency_penalty, 202 | presence_penalty=presence_penalty, 203 | repetition_penalty=repetition_penalty, 204 | penalty_context_size=penalty_context_size, 205 | logit_bias={int(k): v for k, v in logit_bias} if logit_bias else None, 206 | seed=seed, 207 | guided_json=guided_json, 208 | guided_choice=guided_choice, 209 | guided_regex=guided_regex, 210 | guided_grammar=guided_grammar, 211 | whitespace_pattern=whitespace_pattern, 212 | logprobs=True if logprobs else False, 213 | top_logprobs=logprobs if logprobs else 4 214 | ) 215 | cpl_id = 'cmpl-' + uuid.uuid4().hex 216 | created = int(dt.now().timestamp()) 217 | if stream: 218 | def gen_output(): 219 | output = self.model.stream(prompts=prompt, sampling_params=params, images=images, n=n, is_thinking=False) 220 | status = [None] if isinstance(prompt, str) else [None] * len(prompt) 221 | prompt_tokens = None 222 | completion_tokens = [0] if isinstance(prompt, str) else [0] * len(prompt) 223 | for gos in output: 224 | choices = [ 225 | dict( 226 | index=go.index, 227 | text=go.token, 228 | finish_reason=go.finish_reason, 229 | logprobs=to_completion_logprobs([go.logprobs]) if params.logprobs and (not s) else None 230 | ) 231 | for s, go in zip(status, gos) 232 | ] 233 | if prompt_tokens is None: 234 | prompt_tokens = sum([go.input_tokens for go in gos]) 235 | completion_tokens = [go.output_tokens if not s else c for c, s, go in zip(completion_tokens, status, gos)] 236 | status = [go.finish_reason if not s else s for s, go in zip(status, gos)] 237 | cmpl = dict( 238 | id=cpl_id, 239 | created=created, 240 | model=model, 241 | choices=choices 242 | ) 243 | yield TextCompletionOutput.model_validate(cmpl) 244 | 245 | # Usage information 246 | completion_tokens = sum(completion_tokens) 247 | cmpl = dict( 248 | id=cpl_id, 249 | created=created, 250 | model=model, 251 | choices=[dict(index=go.index, text='') for go in gos], 252 | usage=dict(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, 
total_tokens=prompt_tokens + completion_tokens) 253 | ) 254 | yield TextCompletionOutput.model_validate(cmpl) 255 | return gen_output() 256 | 257 | else: 258 | output = self.model.generate(prompts=prompt, sampling_params=params, images=images, n=n, is_thinking=False) 259 | prompt_tokens = sum(output['input_tokens']) 260 | completion_tokens = sum(output['output_tokens']) 261 | usage = dict( 262 | prompt_tokens=prompt_tokens, 263 | completion_tokens=completion_tokens, 264 | total_tokens=prompt_tokens + completion_tokens 265 | ) 266 | logprobs = output['logprobs'] if output['logprobs'] else range(len(output['indices'])) 267 | cmpl = dict( 268 | id=cpl_id, 269 | created=created, 270 | model=model, 271 | choices=[dict( 272 | index=i, 273 | text=t, 274 | logprobs=to_completion_logprobs(l) if params.logprobs else None, 275 | finish_reason=fr 276 | ) for i, t, l, fr in zip( 277 | output['indices'], output['texts'], logprobs, output['finish_reasons'] 278 | )], 279 | usage=usage 280 | ) 281 | return TextCompletionOutput.model_validate(cmpl) 282 | 283 | def chat_generate(self, 284 | model: str, 285 | messages: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], 286 | logit_bias: Optional[Dict[str, int]] = None, 287 | logprobs: bool = False, 288 | top_logprobs: int = 4, 289 | stream: bool = False, 290 | n: int = 1, 291 | max_completion_tokens: int = 4096, 292 | reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None, 293 | max_reasoning_tokens: Optional[int] = None, 294 | stop: Optional[List[str]] = None, 295 | response_format: Optional[Dict[str, Any]] = None, 296 | tools: Optional[List[Dict[str, Any]]] = None, 297 | tool_choice: Union[Literal['auto', 'required', 'none'], Dict[str, Any]] = 'auto', 298 | seed: Optional[int] = None, 299 | presence_penalty: float = 0.0, 300 | frequency_penalty: float = 0.0, 301 | repetition_penalty: float = 1.0, 302 | penalty_context_size: Optional[int] = 1000, 303 | temperature: float = 0.0, 304 | top_k: Optional[int] = None, 305 | top_p: float = 1.0, 306 | min_p: float = 0.0, 307 | min_tokens_to_keep: int = 1, 308 | guided_json: Optional[Dict[str, Any]] = None, 309 | guided_choice: Optional[List[str]] = None, 310 | guided_regex: Optional[str] = None, 311 | guided_grammar: Optional[str] = None, 312 | whitespace_pattern: Optional[str] = None, 313 | use_reasoning_content: Optional[bool] = None, 314 | **kwargs 315 | ) -> Union["ChatCompletionOutput", Iterator["ChatCompletionOutput"]]: 316 | from .chat_utils import OpenAIToolSchema, build_function_call_schema, convert_tool_to_json_schema, ToolChoiceSchema, ResponseFormat 317 | from datetime import datetime as dt 318 | import uuid 319 | import json 320 | from .sampling_utils import SamplingParams 321 | from .generation_utils import ChatCompletionOutput 322 | if (not tools) and (tool_choice == 'auto'): 323 | _tool_choice = 'none' 324 | _tools = None 325 | else: 326 | _tool_choice = tool_choice 327 | _tools = tools if (_tool_choice in ('auto', 'required')) or (isinstance(_tool_choice, dict)) else None 328 | if isinstance(_tool_choice, dict): 329 | ToolChoiceSchema.model_validate(_tool_choice) 330 | 331 | if _tools: 332 | [OpenAIToolSchema.model_validate(t) for t in _tools] 333 | elif (_tool_choice == 'required') or isinstance(_tool_choice, dict): 334 | error = 'Required function calling, but no tools are provided.' 335 | self.log(error, 'error') 336 | raise ValueError(error) 337 | 338 | if guided_json and response_format: 339 | error = 'Either use "response_format" or "guided_json", but not both.' 
340 | self.log(error, 'error') 341 | raise ValueError(error) 342 | 343 | _guided_json = ResponseFormat.model_validate(response_format).json_schema if response_format else guided_json 344 | 345 | if ((_tool_choice == 'required') or isinstance(_tool_choice, dict)) and _guided_json: 346 | error = 'Cannot use a JSON schema alongside tool calling.' 347 | self.log(error, 'error') 348 | raise ValueError(error) 349 | 350 | if _tool_choice == 'required': 351 | _guided_json = build_function_call_schema(_tools) 352 | elif isinstance(_tool_choice, dict): 353 | tool_name = _tool_choice['function']['name'] 354 | tool_schemas = [t for t in _tools if t['function']['name'] == tool_name] 355 | if len(tool_schemas) == 0: 356 | error = f'Provided tool choice "{tool_name}" is not in the given list of tools.' 357 | self.log(error, 'error') 358 | raise ValueError(error) 359 | _guided_json = build_function_call_schema(tool_schemas) 360 | 361 | self.load_model(model) 362 | 363 | multi_msgs = isinstance(messages[0], list) 364 | end_roles = [msgs[-1]['role'] for msgs in messages] if multi_msgs else [messages[-1]['role']] 365 | role_check = end_roles[0] 366 | if any(r != role_check for r in end_roles): 367 | error = 'Different message sequences have different roles for the last message.' 368 | self.log(error, 'error') 369 | raise ValueError(error) 370 | 371 | if role_check == 'assistant': 372 | _max_reasoning_tokens = 0 373 | 374 | elif self.model.chat_template.reasoning_start: 375 | rmap = dict(low=512, medium=2048, high=4096) 376 | if max_reasoning_tokens is not None: 377 | _max_reasoning_tokens = max_reasoning_tokens 378 | elif reasoning_effort: 379 | _max_reasoning_tokens = rmap[reasoning_effort] 380 | else: 381 | _max_reasoning_tokens = rmap['medium'] 382 | 383 | elif max_reasoning_tokens or reasoning_effort: 384 | _max_reasoning_tokens = 0 385 | self.log(f'Model "{self.model.model_name}" is not configured to support reasoning. 
Setting max_reasoning_tokens to 0.') 386 | 387 | else: 388 | _max_reasoning_tokens = 0 389 | 390 | use_reasoning_content = self._use_reasoning_content if use_reasoning_content is None else use_reasoning_content 391 | 392 | if _max_reasoning_tokens: 393 | tparams = SamplingParams( 394 | temperature=temperature, 395 | top_k=top_k, 396 | top_p=top_p, 397 | min_p=min_p, 398 | stop=[self.model.chat_template.reasoning_end.rstrip()], 399 | max_completion_tokens=max_completion_tokens, 400 | max_reasoning_tokens=_max_reasoning_tokens, 401 | min_tokens_to_keep=min_tokens_to_keep, 402 | frequency_penalty=frequency_penalty, 403 | presence_penalty=presence_penalty, 404 | repetition_penalty=repetition_penalty, 405 | penalty_context_size=penalty_context_size, 406 | logit_bias={int(k): v for k, v in logit_bias.items()} if logit_bias else None, 407 | seed=seed, 408 | guided_json=None, 409 | guided_choice=None, 410 | guided_regex=None, 411 | guided_grammar=None, 412 | whitespace_pattern=whitespace_pattern, 413 | logprobs=logprobs, 414 | top_logprobs=top_logprobs 415 | ) 416 | else: 417 | tparams = None 418 | 419 | stop_str = stop if stop else [] 420 | if _tool_choice == 'auto' and (self.model.chat_template.tool_start not in stop_str): 421 | stop_str.append(self.model.chat_template.tool_start) 422 | params = SamplingParams( 423 | temperature=temperature, 424 | top_k=top_k, 425 | top_p=top_p, 426 | min_p=min_p, 427 | stop=stop_str, 428 | max_completion_tokens=max_completion_tokens, 429 | max_reasoning_tokens=_max_reasoning_tokens, 430 | min_tokens_to_keep=min_tokens_to_keep, 431 | frequency_penalty=frequency_penalty, 432 | presence_penalty=presence_penalty, 433 | repetition_penalty=repetition_penalty, 434 | penalty_context_size=penalty_context_size, 435 | logit_bias={int(k): v for k, v in logit_bias.items()} if logit_bias else None, 436 | seed=seed, 437 | guided_json=_guided_json, 438 | guided_choice=guided_choice, 439 | guided_regex=guided_regex, 440 | guided_grammar=guided_grammar, 441 | whitespace_pattern=whitespace_pattern, 442 | logprobs=logprobs, 443 | top_logprobs=top_logprobs 444 | ) 445 | 446 | cpl_id = 'chatcmpl-' + uuid.uuid4().hex 447 | created = int(dt.now().timestamp()) 448 | 449 | if len(messages) == 0: 450 | error = 'No message sequences provided.' 451 | self.log(error, 'error') 452 | raise ValueError(error) 453 | 454 | if self.model.is_vision and _max_reasoning_tokens and (n > 1): 455 | error = 'Vision model cannot have `n` > 1 when reasoning is enabled.' 456 | self.log(error, 'error') 457 | raise ValueError(error) 458 | 459 | 460 | msgs = [messages] if isinstance(messages[0], dict) else messages 461 | for m in msgs: 462 | if len(m) == 0: 463 | error = 'Cannot have empty message sequences.' 
464 | self.log(error, 'error') 465 | raise ValueError(error) 466 | 467 | if tparams: 468 | pi_pairs = [ 469 | self.model.chat_template.apply_chat_template(m, tools=_tools, reasoning=True) 470 | for m in msgs 471 | ] 472 | else: 473 | pi_pairs = [ 474 | self.model.chat_template.apply_chat_template(m, tools=_tools, tool_choice=_tool_choice, reasoning=False) 475 | for m in msgs 476 | ] 477 | 478 | prompts = [pp[0] for pp in pi_pairs] 479 | images = [pp[1] for pp in pi_pairs] 480 | 481 | if stream: 482 | def gen_tokens(): 483 | nonlocal prompts, images, tparams, params, n, cpl_id, created 484 | object='chat.completion.chunk' 485 | indices = list(range(len(prompts) * n)) 486 | prompt_tokens = None 487 | reasoning_tokens = [0] * len(indices) 488 | completion_tokens = [0] * len(indices) 489 | cmpls = [dict( 490 | id=cpl_id, 491 | object=object, 492 | created=created, 493 | model=model, 494 | choices = [dict(index=i, delta=dict(role='assistant', content=''))] 495 | ) for i in indices] 496 | for c in cmpls: 497 | yield ChatCompletionOutput.model_validate(c) 498 | if tparams: 499 | gen_outputs = [''] * len(indices) 500 | status = [None] * len(indices) 501 | toutput = self.model.stream(prompts=prompts, sampling_params=tparams, images=images, n=n, is_thinking=True) 502 | prefs = [self.model.chat_template.reasoning_start if not use_reasoning_content else ''] * len(indices) 503 | tend = self.model.chat_template.reasoning_end if self.model.chat_template.reasoning_end else '' 504 | for gos in toutput: 505 | gen_outputs = [t + go.token if not s else t for s, t, go in zip(status, gen_outputs, gos)] 506 | cmpls = [dict( 507 | id=cpl_id, 508 | object=object, 509 | created=created, 510 | model=model, 511 | choices = [dict( 512 | index=go.index, 513 | delta=dict( 514 | content='', 515 | reasoning_content=p + go.token if not s else None 516 | ) if use_reasoning_content else dict( 517 | content=(p + go.token + tend if go.finish_reason else p + go.token) if not s else None), 518 | finish_reason=None, 519 | logprobs=dict(content=[go.logprobs]) if tparams.logprobs and (not s) else None 520 | )] 521 | ) for s, go, p in zip(status, gos, prefs)] 522 | prefs = ['' if go.token is not None else p for p, go in zip(prefs, gos)] 523 | if prompt_tokens is None: 524 | prompt_tokens = [go.input_tokens for go in gos] 525 | reasoning_tokens = [go.output_tokens if s else rt for s, go, rt in zip(status, gos, reasoning_tokens)] 526 | status = [go.finish_reason if not s else s for s, go in zip(status, gos)] 527 | 528 | for cmpl in cmpls: 529 | if cmpl['choices'][0]['delta']['reasoning_content' if use_reasoning_content else 'content'] is not None: 530 | yield ChatCompletionOutput.model_validate(cmpl) 531 | 532 | new_prompts = [] 533 | new_images = [] 534 | for prompt in prompts: 535 | new_prompts.extend([prompt] * n) 536 | for img in images: 537 | new_images.extend([img] * n) 538 | new_prompts = [p + gt + self.model.chat_template.reasoning_end for p, gt in zip(new_prompts, gen_outputs)] 539 | prompts = new_prompts 540 | images = new_images 541 | 542 | if (_tool_choice == 'required') or (isinstance(_tool_choice, dict)): 543 | prompts = [p + self.model.chat_template.tool_start for p in prompts] 544 | is_tool_call = True 545 | else: 546 | is_tool_call = False 547 | 548 | output = self.model.stream(prompts=prompts, sampling_params=params, images=images, n=1, is_thinking=False) 549 | 550 | else: 551 | is_tool_call = (_tool_choice == 'required') or (isinstance(_tool_choice, dict)) 552 | output = self.model.stream(prompts=prompts, 
sampling_params=params, images=images, n=n, is_thinking=False) 553 | 554 | # Definite tool calls 555 | if is_tool_call: 556 | tool_call_strs = [''] * len(indices) 557 | call_ids = ['call_' + uuid.uuid4().hex[:8] for i in range(len(indices))] 558 | status = [None] * len(indices) 559 | logprob_list = [[]] * len(indices) 560 | for gos in output: 561 | tool_call_strs = [tcs + go.token for tcs, go in zip(tool_call_strs, gos)] 562 | if prompt_tokens is None: 563 | prompt_tokens = [go.input_tokens for go in gos] 564 | completion_tokens = [go.output_tokens if not s else c for s, go, c in zip(status, gos, completion_tokens)] 565 | if params.logprobs: 566 | logprob_list = [lp + [go.logprobs] if not s else lp for s, go, lp in zip(status, gos, logprob_list)] 567 | status = [s if s else go.finish_reason for s, go in zip(status, gos)] 568 | 569 | tool_call_list = [json.loads(tcs) for tcs in tool_call_strs] 570 | cmpls = [dict( 571 | id=cpl_id, 572 | object=object, 573 | created=created, 574 | model=model, 575 | choices=[ 576 | dict( 577 | index=go.index, 578 | logprobs=dict(content=lp) if params.logprobs else None, 579 | delta=dict( 580 | tool_calls=[dict(index=0, id=cid, function=dict(name=tc['name'], arguments=json.dumps(tc['arguments'])))] 581 | ) 582 | ) 583 | ] 584 | ) for go, cid, tc, lp in zip(gos, call_ids, tool_call_list, logprob_list)] 585 | for cmpl in cmpls: 586 | yield ChatCompletionOutput.model_validate(cmpl) 587 | 588 | cmpls = [dict( 589 | id=cpl_id, 590 | object=object, 591 | created=created, 592 | model=model, 593 | choices=[ 594 | dict(index=go.index, delta=dict(), finish_reason='tool_calls') 595 | ], 596 | usage=dict(prompt_tokens=pt, completion_tokens=ct + rt, total_tokens=pt + ct + rt, completion_tokens_details=dict(reasoning_tokens=rt)) 597 | ) for go, pt, rt, ct in zip(gos, prompt_tokens, reasoning_tokens, completion_tokens)] 598 | for cmpl in cmpls: 599 | yield ChatCompletionOutput.model_validate(cmpl) 600 | 601 | # Any other form of generation except 'auto' 602 | elif _tool_choice == 'none': 603 | status = [None] * len(indices) 604 | for gos in output: 605 | cmpls = [dict( 606 | id=cpl_id, 607 | object=object, 608 | created=created, 609 | model=model, 610 | choices=[ 611 | dict( 612 | index=go.index, 613 | logprobs=dict(content=[go.logprobs]) if params.logprobs else None, 614 | delta=dict( 615 | content=go.token if go.token and (not s) else None 616 | ) 617 | ) 618 | ] 619 | ) for go, s in zip(gos, status)] 620 | if prompt_tokens is None: 621 | prompt_tokens = [go.input_tokens for go in gos] 622 | completion_tokens = [go.output_tokens if not s else c for s, go, c in zip(status, gos, completion_tokens)] 623 | status = [s if s else go.finish_reason for s, go in zip(status, gos)] 624 | for cmpl in cmpls: 625 | if cmpl['choices'][0]['delta']['content'] is not None: 626 | yield ChatCompletionOutput.model_validate(cmpl) 627 | 628 | # Yield ended sequences as well 629 | cmpls = [dict( 630 | id=cpl_id, 631 | object=object, 632 | created=created, 633 | model=model, 634 | choices=[ 635 | dict( 636 | index=go.index, 637 | delta=dict(), 638 | finish_reason=go.finish_reason 639 | ) 640 | ], 641 | usage=dict(prompt_tokens=pt, completion_tokens=ct + rt, total_tokens=pt + ct + rt, completion_tokens_details=dict(reasoning_tokens=rt)) 642 | ) for go, ct, rt, pt in zip(gos, completion_tokens, reasoning_tokens, prompt_tokens) if go.finish_reason] 643 | for cmpl in cmpls: 644 | yield ChatCompletionOutput.model_validate(cmpl) 645 | 646 | # Deal with tool_choice="auto" case 647 | else: 648 | status = 
[None] * len(indices) 649 | if len(prompts) != len(indices): 650 | new_prompts = [] 651 | new_images = [] 652 | for p, img in zip(prompts, images): 653 | new_prompts.extend([p] * n) 654 | new_images.extend([img] * n) 655 | prompts = new_prompts 656 | images = new_images 657 | 658 | for gos in output: 659 | cmpls = [dict( 660 | id=cpl_id, 661 | object=object, 662 | created=created, 663 | model=model, 664 | choices=[ 665 | dict( 666 | index=go.index, 667 | logprobs=dict(content=[go.logprobs]) if params.logprobs else None, 668 | delta=dict( 669 | content=go.token if go.token and (not s) else None 670 | ) 671 | ) 672 | ] 673 | ) for go, s in zip(gos, status)] 674 | if prompt_tokens is None: 675 | prompt_tokens = [go.input_tokens for go in gos] 676 | completion_tokens = [go.output_tokens if not s else c for s, go, c in zip(status, gos, completion_tokens)] 677 | prompts = [p + go.token if not s else p for s, go, p in zip(status, gos, prompts)] 678 | status = [s if s else (go.finish_reason if go.stop_str != self.model.chat_template.tool_start else 'tool_call_start') for s, go in zip(status, gos)] 679 | for cmpl in cmpls: 680 | if cmpl['choices'][0]['delta']['content'] is not None: 681 | yield ChatCompletionOutput.model_validate(cmpl) 682 | 683 | # Yield ended sequences as well 684 | cmpls = [dict( 685 | id=cpl_id, 686 | object=object, 687 | created=created, 688 | model=model, 689 | choices=[ 690 | dict( 691 | index=go.index, 692 | delta=dict(), 693 | finish_reason=go.finish_reason 694 | ) 695 | ], 696 | usage=dict(prompt_tokens=pt, completion_tokens=ct + rt, total_tokens=pt + ct + rt, completion_tokens_details=dict(reasoning_tokens=rt)) 697 | ) for go, ct, rt, pt in zip(gos, completion_tokens, reasoning_tokens, prompt_tokens) if (go.finish_reason and (go.stop_str != self.model.chat_template.tool_start))] 698 | for cmpl in cmpls: 699 | yield ChatCompletionOutput.model_validate(cmpl) 700 | 701 | # Dealing with tool calls for those with tool_call_start 702 | if len([s for s in status if s == 'tool_call_start']) > 0: 703 | # Set guided decoding to tool call schema 704 | params.guided_choice = None 705 | params.guided_regex = None 706 | params.guided_grammar = None 707 | params.guided_json = build_function_call_schema(_tools) 708 | 709 | prompts = [p + self.model.chat_template.tool_start for p, s in zip(prompts, status) if s == 'tool_call_start'] 710 | images = [img for img, s in zip(images, status) if s == 'tool_call_start'] 711 | indices = [i for i, s in zip(indices, status) if s == 'tool_call_start'] 712 | prompt_tokens = [pt for pt, s in zip(prompt_tokens, status) if s == 'tool_call_start'] 713 | completion_tokens = [ct for ct, s in zip(completion_tokens, status) if s == 'tool_call_start'] 714 | reasoning_tokens = [rt for rt, s in zip(reasoning_tokens, status) if s == 'tool_call_start'] 715 | tool_call_strs = [''] * len(indices) 716 | call_ids = ['call_' + uuid.uuid4().hex[:8] for i in range(len(indices))] 717 | logprob_list = [[]] * len(indices) 718 | status = [None] * len(indices) 719 | output = self.model.stream(prompts, params, images, n=1, is_thinking=False) 720 | 721 | for gos in output: 722 | tool_call_strs = [tcs + go.token for tcs, go in zip(tool_call_strs, gos)] 723 | completion_tokens = [go.output_tokens if not s else c for s, go, c in zip(status, gos, completion_tokens)] 724 | if params.logprobs: 725 | logprob_list = [lp + [go.logprobs] if not s else lp for s, go, lp in zip(status, gos, logprob_list)] 726 | status = [s if s else go.finish_reason for s, go in zip(status, gos)] 727 
| 728 | tool_call_list = [json.loads(tcs) for tcs in tool_call_strs] 729 | cmpls = [dict( 730 | id=cpl_id, 731 | object=object, 732 | created=created, 733 | model=model, 734 | choices=[ 735 | dict( 736 | index=i, 737 | logprobs=dict(content=lp) if params.logprobs else None, 738 | delta=dict( 739 | tool_calls=[dict(index=0, id=cid, function=dict(name=tc['name'], arguments=json.dumps(tc['arguments'])))] 740 | ) 741 | ) 742 | ] 743 | ) for cid, tc, lp, i in zip(call_ids, tool_call_list, logprob_list, indices)] 744 | for cmpl in cmpls: 745 | yield ChatCompletionOutput.model_validate(cmpl) 746 | 747 | cmpls = [dict( 748 | id=cpl_id, 749 | object=object, 750 | created=created, 751 | model=model, 752 | choices=[ 753 | dict(index=i, delta=dict(), finish_reason='tool_calls') 754 | ], 755 | usage=dict(prompt_tokens=pt, completion_tokens=ct, total_tokens=pt + ct + rt) 756 | ) for i, pt, rt, ct in zip(indices, prompt_tokens, reasoning_tokens, completion_tokens)] 757 | for cmpl in cmpls: 758 | yield ChatCompletionOutput.model_validate(cmpl) 759 | 760 | return gen_tokens() 761 | 762 | else: 763 | object = 'chat.completion' 764 | indices = list(range(len(prompts) * n)) 765 | cmpls = dict( 766 | id=cpl_id, 767 | object=object, 768 | created=created, 769 | model=model, 770 | choices = [ 771 | dict( 772 | index=i, 773 | message=dict( 774 | role='assistant', 775 | content='', 776 | reasoning_content='', 777 | tool_calls=[] 778 | ), 779 | finish_reason=None, 780 | logprobs=dict(content=[]) if params.logprobs else None 781 | ) 782 | for i in indices], 783 | usage = dict( 784 | prompt_tokens=0, 785 | completion_tokens=0, 786 | total_tokens=0, 787 | completion_tokens_details=dict( 788 | reasoning_tokens=0 789 | ) 790 | ) 791 | ) 792 | if tparams: 793 | output = self.model.generate(prompts=prompts, images=images, sampling_params=tparams, n=n, is_thinking=True) 794 | texts = output['texts'] 795 | input_tokens = output['input_tokens'] 796 | output_tokens = output['output_tokens'] 797 | logprobs = output['logprobs'] 798 | for i, t in zip(indices, texts): 799 | if use_reasoning_content: 800 | cmpls['choices'][i]['message']['reasoning_content'] = t 801 | else: 802 | cmpls['choices'][i]['message']['content'] = self.model.chat_template.reasoning_start + t + self.model.chat_template.reasoning_end 803 | if params.logprobs: 804 | for i, lp in zip(indices, logprobs): 805 | cmpls['choices'][i]['logprobs']['content'].extend(lp) 806 | 807 | if cmpls['usage']['prompt_tokens'] == 0: 808 | cmpls['usage']['prompt_tokens'] = sum(input_tokens) 809 | 810 | cmpls['usage']['completion_tokens'] += sum(output_tokens) 811 | cmpls['usage']['completion_tokens_details']['reasoning_tokens'] = sum(output_tokens) 812 | 813 | if n > 1: 814 | new_prompts = [] 815 | new_images = [] 816 | for p, img in zip(prompts, images): 817 | new_prompts.extend([p] * n) 818 | new_images.extend([img] * n) 819 | 820 | prompts = new_prompts 821 | images = new_images 822 | 823 | prompts = [p + t + self.model.chat_template.reasoning_end for p, t in zip(prompts, texts)] 824 | 825 | if (_tool_choice == 'required') or (isinstance(_tool_choice, dict)): 826 | prompts = [p + self.model.chat_template.tool_start for p in prompts] 827 | 828 | output = self.model.generate(prompts=prompts, images=images, sampling_params=params, n=1, is_thinking=False) 829 | 830 | else: 831 | output = self.model.generate(prompts=prompts, images=images, sampling_params=params, n=n, is_thinking=False) 832 | 833 | if (_tool_choice == 'required') or (isinstance(_tool_choice, dict)): 834 | texts = 
output['texts'] 835 | input_tokens = output['input_tokens'] 836 | output_tokens = output['output_tokens'] 837 | logprobs = output['logprobs'] 838 | 839 | tool_call_dicts = [json.loads(t) for t in texts] 840 | cids = ['call_' + uuid.uuid4().hex[:8] for i in range(len(indices))] 841 | tool_calls = [dict( 842 | index=0, 843 | id=cid, 844 | type='function', 845 | function=dict(name=tc['name'], arguments=json.dumps(tc['arguments'])) 846 | ) for tc, cid in zip(tool_call_dicts, cids)] 847 | 848 | if cmpls['usage']['prompt_tokens'] == 0: 849 | cmpls['usage']['prompt_tokens'] = sum(input_tokens) 850 | 851 | if params.logprobs: 852 | for i, lp in zip(indices, logprobs): 853 | cmpls['choices'][i]['logprobs']['content'].extend(lp) 854 | 855 | cmpls['usage']['completion_tokens'] += sum(output_tokens) 856 | 857 | cmpls['usage']['total_tokens'] = cmpls['usage']['completion_tokens'] + cmpls['usage']['prompt_tokens'] 858 | 859 | for i, tc in zip(indices, tool_calls): 860 | cmpls['choices'][i]['message']['tool_calls'].append(tc) 861 | cmpls['choices'][i]['finish_reason'] = 'tool_calls' 862 | 863 | else: 864 | texts = output['texts'] 865 | input_tokens = output['input_tokens'] 866 | output_tokens = output['output_tokens'] 867 | logprobs = output['logprobs'] 868 | finish_reasons = output['finish_reasons'] 869 | stop_strs = output['stop_strs'] 870 | is_auto = _tool_choice == 'auto' 871 | 872 | finish_reasons = ['tool_calls' if ((ss == self.model.chat_template.tool_start) and is_auto) else fr for fr, ss in zip(finish_reasons, stop_strs)] 873 | 874 | if cmpls['usage']['prompt_tokens'] == 0: 875 | cmpls['usage']['prompt_tokens'] = sum(input_tokens) 876 | 877 | if params.logprobs: 878 | for i, lp in zip(indices, logprobs): 879 | cmpls['choices'][i]['logprobs']['content'].extend(lp) 880 | 881 | cmpls['usage']['completion_tokens'] += sum(output_tokens) 882 | 883 | cmpls['usage']['total_tokens'] = cmpls['usage']['completion_tokens'] + cmpls['usage']['prompt_tokens'] 884 | 885 | for i, t, fr in zip(indices, texts, finish_reasons): 886 | cmpls['choices'][i]['message']['content'] += t 887 | cmpls['choices'][i]['finish_reason'] = fr 888 | 889 | if len([fr for fr in finish_reasons if fr == 'tool_calls']) > 0: 890 | if (n > 1) and (len(prompts) != len(indices)): 891 | new_prompts = [] 892 | new_images = [] 893 | for p, img in zip(prompts, images): 894 | new_prompts.extend([p] * n) 895 | new_images.extend([img] * n) 896 | 897 | prompts = new_prompts 898 | images = new_images 899 | 900 | indices = [i for i, fr in zip(indices, finish_reasons) if fr == 'tool_calls'] 901 | prompts = [p + t + self.model.chat_template.tool_start for p, fr, t in zip(prompts, finish_reasons, texts) if fr == 'tool_calls'] 902 | images = [img for img, fr in zip(images, finish_reasons) if fr == 'tool_calls'] 903 | 904 | params.guided_choice = None 905 | params.guided_regex = None 906 | params.guided_grammar = None 907 | params.guided_json = build_function_call_schema(_tools) 908 | 909 | output = self.model.generate(prompts, params, images, n=1, is_thinking=False) 910 | 911 | texts = output['texts'] 912 | input_tokens = output['input_tokens'] 913 | output_tokens = output['output_tokens'] 914 | logprobs = output['logprobs'] 915 | 916 | if params.logprobs: 917 | for i, lp in zip(indices, logprobs): 918 | cmpls['choices'][i]['logprobs']['content'].extend(lp) 919 | 920 | cmpls['usage']['completion_tokens'] += sum(output_tokens) 921 | 922 | cmpls['usage']['total_tokens'] += sum(output_tokens) 923 | 924 | for i, t, fr in zip(indices, texts, finish_reasons): 
925 | tc = json.loads(t) 926 | cmpls['choices'][i]['message']['tool_calls'].append(dict( 927 | index=0, 928 | id='call_' + uuid.uuid4().hex[:8], 929 | type='function', 930 | function=dict(name=tc['name'], arguments=json.dumps(tc['arguments'])) 931 | )) 932 | 933 | indices = list(range(len(cmpls['choices']))) 934 | for i in indices: 935 | if not cmpls['choices'][i]['message']['reasoning_content']: 936 | cmpls['choices'][i]['message']['reasoning_content'] = None 937 | if not cmpls['choices'][i]['message']['content']: 938 | cmpls['choices'][i]['message']['content'] = None 939 | 940 | return ChatCompletionOutput.model_validate(cmpls) 941 | 942 | 943 | 944 | 945 | 946 | 947 | 948 | 949 | 950 | 951 | 952 | 953 | 954 | 955 | 956 | 957 | 958 | 959 | 960 | 961 | 962 | 963 | 964 | 965 | 966 | 967 | 968 | 969 | 970 | 971 | 972 | 973 | 974 | 975 | 976 | 977 | 978 | 979 | 980 | 981 | -------------------------------------------------------------------------------- /src/mlx_textgen/cache_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict, Union, Tuple, Literal, Callable, TYPE_CHECKING 2 | from pydantic import BaseModel, Field 3 | if TYPE_CHECKING: 4 | from logging import Logger 5 | from mlx_lm.models.cache import KVCache, RotatingKVCache 6 | from mlx.core import array 7 | 8 | 9 | def create_cache_dict( 10 | cache: List[Union["KVCache", "RotatingKVCache"]] 11 | ) -> Optional[Dict[str, "array"]]: 12 | """ 13 | Converts a list of KVCache or RotatingKVCache objects into a dictionary of arrays suitable for saving. 14 | 15 | The dictionary keys are generated automatically based on the internal structure of the KVCache/RotatingKVCache objects. 16 | Only the cached portion of the keys and values are extracted when the cache offset is not zero. 17 | 18 | Args: 19 | cache (List[Union["KVCache", "RotatingKVCache"]]): A list of KVCache or RotatingKVCache objects. 20 | 21 | Returns: 22 | Optional[Dict[str, "array"]]: A dictionary containing the keys and values arrays from the cache, or None if the cache offset is 0. 23 | """ 24 | from mlx.utils import tree_flatten 25 | offset = cache[0].offset 26 | if offset != 0: 27 | cache_dict = dict(tree_flatten([(c.keys[..., :c.offset, :], c.values[..., :c.offset, :]) for c in cache])) 28 | return cache_dict 29 | 30 | def save_cache( 31 | cache: List[Union["KVCache", "RotatingKVCache"]], 32 | file: str, 33 | model_name: str, 34 | metadata: Optional[Dict[str, str]] = None, 35 | logger: Optional["Logger"] = None 36 | ) -> None: 37 | """ 38 | Saves the provided KVCache or RotatingKVCache to a file in safetensors format. 39 | 40 | Args: 41 | cache (List[Union["KVCache", "RotatingKVCache"]]): A list of KVCache or RotatingKVCache objects to be saved. 42 | file (str): The file path where the cache will be saved (e.g., "path/to/cache.safetensors"). 43 | model_name (str): The name of the model associated with this cache. This will be stored in the metadata. 44 | metadata (Optional[Dict[str, str]], optional): Optional metadata to be saved along with the cache. Defaults to None. 45 | logger (Optional["Logger"], optional): An optional logger object for logging progress and errors. Defaults to None. 
46 | """ 47 | from time import perf_counter 48 | from mlx.core import save_safetensors, clear_cache 49 | 50 | if logger: 51 | start = perf_counter() 52 | 53 | cache_dict = create_cache_dict(cache=cache) 54 | 55 | if cache_dict: 56 | metadata = metadata if metadata else {} 57 | metadata['model_name'] = model_name 58 | save_safetensors(file=file, arrays=cache_dict, metadata=metadata) 59 | 60 | del cache_dict 61 | clear_cache() 62 | 63 | if logger: 64 | end = perf_counter() 65 | logger.info(f'Save cache for model "{model_name}" to "{file}"; Time taken: {end - start:.3f}s.') 66 | 67 | def load_cache( 68 | file: str, 69 | cache: List[Union["KVCache", "RotatingKVCache"]], 70 | logger: Optional["Logger"] = None 71 | ) -> Tuple[List[Union["KVCache", "RotatingKVCache"]], Dict[str, str]]: 72 | """ 73 | Loads a KVCache or RotatingKVCache from a safetensors file. 74 | 75 | The function reads the safetensors file, extracts the cached keys and values, and updates the provided KVCache objects. 76 | It also returns any metadata stored within the file. 77 | 78 | Args: 79 | file (str): The path to the safetensors file containing the cache. 80 | cache (List[Union["KVCache", "RotatingKVCache"): A list of KVCache or RotatingKVCache objects to be updated with the loaded cache. 81 | logger (Optional["Logger"], optional): An optional logger object for logging progress and errors. Defaults to None. 82 | 83 | Returns: 84 | Tuple[List[Union["KVCache", "RotatingKVCache"]], Dict[str, str]]: A tuple containing the updated KVCache objects and a dictionary of metadata loaded from the file. 85 | """ 86 | from time import perf_counter 87 | from mlx.core import load, eval, clear_cache 88 | from mlx.utils import tree_unflatten 89 | 90 | if logger: 91 | start = perf_counter() 92 | 93 | cache_dict, metadata = load(file, return_metadata=True) 94 | cache_list = tree_unflatten(list(cache_dict.items())) 95 | 96 | if len(cache_list) != len(cache): 97 | error = f'Cache file length and cache length mismatch.' 98 | 99 | if logger: 100 | logger.error(error) 101 | 102 | raise ValueError(error) 103 | 104 | elif not all(c.offset == 0 for c in cache): 105 | error = f'Provided list of caches are not empty.' 
106 | 107 | if logger: 108 | logger.error(error) 109 | 110 | raise ValueError(error) 111 | 112 | for i, (key, value) in enumerate(cache_list): 113 | cache[i].update_and_fetch(key, value) 114 | eval(cache[i].state) 115 | 116 | eval([c.state for c in cache]) 117 | del cache_dict 118 | clear_cache() 119 | 120 | if logger: 121 | end = perf_counter() 122 | logger.info(f'Loaded cache from "{file}"; Time taken: {end - start:.3f}s.') 123 | 124 | return cache, metadata 125 | 126 | 127 | def split_list_by_image_token(data_list: List[int], image_token_id: Union[int, List[int]]) -> Tuple[List[List[int]], List[List[int]]]: 128 | result = [] 129 | images = [] 130 | current_sublist = [] 131 | img_sublist = [] 132 | token_set = [image_token_id] if isinstance(image_token_id, int) else image_token_id 133 | for i in data_list: 134 | if i not in token_set: 135 | current_sublist.append(i) 136 | if img_sublist: 137 | images.append(img_sublist) 138 | img_sublist = [] 139 | else: 140 | img_sublist.append(i) 141 | if current_sublist: 142 | result.append(current_sublist) 143 | current_sublist = [] 144 | result.append(current_sublist) 145 | if img_sublist: 146 | images.append(img_sublist) 147 | return result, images 148 | 149 | class CacheInfo(BaseModel): 150 | cache_id: str = Field(pattern='cache_\d+') 151 | token_ids: List[int] 152 | images: Optional[List[str]] = None 153 | image_diffs: List[int] = Field(default_factory=list) 154 | last_modified: float 155 | 156 | @property 157 | def length(self) -> int: 158 | """Returns the length of the token_ids list. 159 | 160 | This is a cached property, so the length is only computed once. 161 | 162 | Returns: 163 | int: The number of tokens in the `token_ids` list. 164 | """ 165 | if not hasattr(self, '_length'): 166 | self._length = len(self.token_ids) 167 | return self._length 168 | 169 | class SeqIndex(BaseModel): 170 | token_ids: List[int] 171 | images: Optional[List[str]] = None 172 | index: int 173 | 174 | @property 175 | def length(self) -> int: 176 | """Returns the length of the token_ids list. 177 | 178 | This is a cached property, so the length is only computed once. 179 | 180 | Returns: 181 | int: The number of tokens in the `token_ids` list. 182 | """ 183 | if not hasattr(self, '_length'): 184 | self._length = len(self.token_ids) 185 | return self._length 186 | 187 | class CacheManager: 188 | """The CacheManager is responsible for storing and retrieving cached prompt completions to accelerate generation. 189 | It supports both text-only models and vision models (models that incorporate images into the prompt). 190 | 191 | The cache is stored in a directory named after the model under the mlx prompt cache directory 192 | (see `utils.get_prompt_cache_dir`). 193 | """ 194 | def __init__(self, 195 | model_name: str, 196 | is_vision: bool = False, 197 | image_token_id: Optional[int] = None, 198 | extra_image_tokens: Optional[List[int]] = None, 199 | min_tokens: int = 20, 200 | max_reprocess_tokens: int = 250, 201 | replace_threshold: float = 0.95, 202 | max_capacity: int = 50, 203 | logger: Optional["Logger"] = None 204 | ) -> None: 205 | """Initializes the CacheManager with the specified parameters. 206 | 207 | The CacheManager is responsible for storing and retrieving cached prompt completions to accelerate generation. 208 | It supports both text-only models and vision models (models that incorporate images into the prompt). 
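        For example (illustrative argument values), a text-only model might use
        `CacheManager(model_name='my-llm')`, while a vision model would also need its image
        token, e.g. `CacheManager(model_name='my-vlm', is_vision=True, image_token_id=32000)`.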
209 | 210 | The cache is stored in a directory named after the model under the mlx prompt cache directory 211 | (see `utils.get_prompt_cache_dir`). 212 | 213 | Args: 214 | model_name (str): A unique identifier for the model using this cache manager. 215 | Used to create a dedicated cache directory within the mlx prompt cache directory. 216 | is_vision (bool, optional): Whether the model is a vision model (accepts images as input). 217 | Defaults to False. If True, `image_token_id` must be set. 218 | image_token_id (Optional[int], optional): The token ID that represents an image in the model's 219 | vocabulary. Required if `is_vision` is True. Defaults to None. 220 | min_tokens (int, optional): The minimum number of tokens a prompt must have to be considered 221 | for caching. Prompts shorter than this length will not be cached. Defaults to 20. 222 | max_reprocess_tokens (int, optional): When comparing sequences to determine if one sequence should be replaced by 223 | another cached sequence, the sequences need to share at least `replace_threshold` * `min(len(seq1), len(seq2))` tokens. This threshold is raised 224 | if the minimum length is low to `max(replace_threshold, (min_len - max_reprocess_tokens) / min_len)`. Defaults to 250. 225 | replace_threshold (float, optional): A similarity threshold (between 0 and 1) used to determine 226 | whether to replace an existing cache entry with a new one. A higher value means the new prompt must be 227 | more similar to the existing one to trigger replacement. Defaults to 0.95. 228 | max_capacity (int, optional): The maximum number of prompts to store in the cache. When the cache is full, 229 | the least recently used (LRU) entries are evicted to make space for new ones. Defaults to 50. 230 | logger (Optional["Logger"], optional): An optional logger object for logging cache operations. 231 | Defaults to None. 232 | 233 | Raises: 234 | ValueError: If `image_token_id` is None when `is_vision` is True. 235 | ValueError: If `min_tokens` is less than 0. 236 | ValueError: If `max_capacity` is less than 1. 237 | ValueError: If `max_reprocess_tokens` is less than `min_tokens`. 238 | ValueError: If `replace_threshold` is not between 0 and 1. 239 | """ 240 | from .utils import get_prompt_cache_dir 241 | import os 242 | 243 | self._logger = logger 244 | self.cache_dir = os.path.join(get_prompt_cache_dir(), model_name) 245 | os.makedirs(self.cache_dir, exist_ok=True) 246 | self.cache_info_dir = os.path.join(self.cache_dir, 'cached_prompts.json') 247 | 248 | if is_vision and (image_token_id is None): 249 | error = '"image_token_id" cannot be None for vision model.' 250 | self.log(error, level='error') 251 | raise ValueError(error) 252 | 253 | if min_tokens < 0: 254 | error = '"min_tokens" needs to be an integer larger than or equal to 0.' 255 | self.log(error, level='error') 256 | raise ValueError(error) 257 | 258 | if max_capacity < 1: 259 | error = '"max_capacity" needs to be larger than or equal to 1.' 260 | self.log(error, level='error') 261 | raise ValueError(error) 262 | 263 | if max_reprocess_tokens < min_tokens: 264 | error = '"max_reprocess_tokens" needs to be larger than or equal to "min_tokens".' 265 | self.log(error, level='error') 266 | raise ValueError(error) 267 | 268 | if (replace_threshold <= 0) or (replace_threshold > 1): 269 | error = '"replace_threshold" needs to be between 0 and 1.' 
270 | self.log(error, level='error') 271 | raise ValueError(error) 272 | 273 | self.model_name = model_name 274 | self.is_vision = is_vision 275 | self.image_token_id = image_token_id 276 | self._image_tokens = [] if not extra_image_tokens else extra_image_tokens 277 | if (self.image_token_id not in self._image_tokens) and self.is_vision: 278 | self._image_tokens.append(self.image_token_id) 279 | self.min_tokens = min_tokens 280 | self.max_reprocess_tokens = max_reprocess_tokens 281 | self.replace_threshold = replace_threshold 282 | self.max_capacity = max_capacity 283 | 284 | def log(self, msg: str, level: Literal["error", "warning", "info", "debug"] = "info") -> None: 285 | """Logs a message to the logger at the specified level. 286 | 287 | Args: 288 | msg (str): The message to log. 289 | level (Literal["error", "warning", "info", "debug"], optional): The logging level. Defaults to "info". 290 | """ 291 | levels = dict( 292 | error=40, 293 | warning=30, 294 | info=20, 295 | debug=10 296 | ) 297 | if self._logger: 298 | self._logger.log(level=levels.get(level), msg=msg) 299 | 300 | @property 301 | def cache_info(self) -> List[CacheInfo]: 302 | """ 303 | Retrieves a list of CacheInfo objects representing the cached prompts. 304 | 305 | The list is sorted by the last modified timestamp in descending order, 306 | so the most recently used caches appear first. 307 | 308 | Returns: 309 | List[CacheInfo]: A list of CacheInfo objects, sorted by last_modified (most recent first). 310 | """ 311 | import os 312 | import json 313 | if not hasattr(self, '_cache_info'): 314 | if os.path.exists(self.cache_info_dir): 315 | with open(self.cache_info_dir, 'r') as f: 316 | cache_info = json.load(f) 317 | else: 318 | cache_info = {} 319 | 320 | self._cache_info = [CacheInfo(cache_id=k, **v) for k, v in cache_info.items()] 321 | 322 | self._cache_info.sort(key=lambda x: x.last_modified, reverse=True) 323 | return self._cache_info 324 | 325 | def get_new_cache_id(self, num: int = 1) -> List[str]: 326 | """ 327 | Generates a list of new, unique cache IDs. 328 | 329 | This method finds the next available cache IDs, ensuring they don't conflict with existing IDs. 330 | If there are gaps in the existing IDs, it will fill those gaps first. Otherwise, it will increment from the maximum existing ID. 331 | 332 | Args: 333 | num (int): The number of new cache IDs to generate. Must be at least 1. 334 | 335 | Returns: 336 | List[str]: A list of new, unique cache IDs. 337 | """ 338 | if num < 1: 339 | error = f'Trying to get {num} cache IDs. Must at least get 1 new cache ID.' 340 | self.log(error, level='error') 341 | raise ValueError(error) 342 | 343 | existing = [int(cf.cache_id.removeprefix('cache_')) for cf in self.cache_info] 344 | if len(existing) == 0: 345 | return [f'cache_{i}' for i in range(num)] 346 | 347 | max_id = max(existing) 348 | count = 0 349 | new = [] 350 | for i in range(max_id + 1): 351 | if i not in existing: 352 | new.append(f'cache_{i}') 353 | count += 1 354 | if count == num: 355 | break 356 | 357 | remain = num - count 358 | if remain > 0: 359 | for i in range(remain): 360 | new.append(f'cache_{max_id + 1 + i}') 361 | return new 362 | 363 | def save_cache_info(self) -> None: 364 | """ 365 | Saves the current cache information to a JSON file. 366 | 367 | This method serializes the `cache_info` list into a dictionary and saves it as a JSON file. 368 | The JSON file is used to persist the cache metadata across sessions. 
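        Example:
            The persisted file maps cache IDs to their metadata. For a text-only model an
            entry looks roughly like the following (token IDs and timestamp are illustrative):

                {
                    "cache_0": {
                        "token_ids": [1, 15043, 3186],
                        "last_modified": 1714000000.123
                    }
                }

            Vision models additionally persist the "images" and "image_diffs" fields.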
369 | """ 370 | import json 371 | cach_info_dict = { 372 | cf.cache_id: (dict(token_ids=cf.token_ids, images=cf.images, image_diffs=cf.image_diffs, last_modified=cf.last_modified) if self.is_vision else dict(token_ids=cf.token_ids, last_modified=cf.last_modified)) for cf in self.cache_info 373 | } 374 | with open(self.cache_info_dir, 'w') as f: 375 | json.dump(cach_info_dict, f, indent=4) 376 | 377 | def drop_cache_by_id(self, cache_ids: Union[str, List[str]]) -> None: 378 | """ 379 | Removes the specified cache(s) from the cache manager. 380 | 381 | This method removes cache information from the `cache_info` list, saves the updated cache information, 382 | and deletes the corresponding cache files from disk. 383 | 384 | Args: 385 | cache_ids (Union[str, List[str]]): A single cache ID or a list of cache IDs to remove. 386 | """ 387 | import time 388 | start = time.perf_counter() 389 | 390 | import os 391 | 392 | existing = [cf.cache_id for cf in self.cache_info] 393 | to_drop = [cache_ids] if isinstance(cache_ids, str) else cache_ids 394 | to_drop = [td for td in to_drop if td in existing] 395 | 396 | self._cache_info = [cf for cf in self._cache_info if cf.cache_id not in to_drop] 397 | self.save_cache_info() 398 | 399 | for cid in to_drop: 400 | os.remove(os.path.join(self.cache_dir, f'{cid}.safetensors')) 401 | 402 | to_drop_list = [f'"{td}"' for td in to_drop] 403 | to_drop_str = ', '.join(to_drop_list) 404 | 405 | end = time.perf_counter() 406 | self.log(f'Cache dropped for model "{self.model_name}": {to_drop_str}. Time taken: {end - start:.3f}s.') 407 | 408 | def split_cache(self, 409 | cache: List[Union["KVCache", "RotatingKVCache"]], 410 | create_cache_fn: Callable[[], List[Union["KVCache", "RotatingKVCache"]]], 411 | offsets: "array" 412 | ) -> List[List[Union["KVCache", "RotatingKVCache"]]]: 413 | """ 414 | Splits the provided KV cache into multiple smaller caches based on sequence offsets. 415 | 416 | This function divides an existing KV cache into a list of new KV caches, one for each sequence 417 | in the batch. It uses provided offsets to determine where each sequence's cached data begins 418 | within the original cache. A `create_cache_fn` is used to generate new, empty caches with 419 | the same structure as the input `cache`. 420 | 421 | If the original `cache` is empty (indicated by a zero offset), the function returns an empty list. 422 | 423 | Args: 424 | cache (List[Union["KVCache", "RotatingKVCache"]]): The KV cache to split. This cache contains cached 425 | data for a batch of sequences. 426 | create_cache_fn (Callable[[], List[Union["KVCache", "RotatingKVCache"]]]): A function that creates a new, 427 | empty KV cache with the same structure (number of layers, dimensions) as the input `cache`. This is 428 | used to initialize the individual caches for each sequence. 429 | offsets ("array"): An array of integer offsets. Each offset indicates the starting index of the cached 430 | data for a specific sequence within the original `cache`. The length of this array should match the 431 | batch size (number of sequences) represented by the cache. 432 | 433 | Returns: 434 | List[List[Union["KVCache", "RotatingKVCache"]]]: A list of new KV caches. Each element in the outer list 435 | represents a single sequence from the batch, and contains a list of KVCache/RotatingKVCache objects (one 436 | for each layer in the model). Returns an empty list if the input `cache` is empty. 
437 | 438 | Raises: 439 | ValueError: If the batch size of the input `cache` does not match the length of the `offsets` array. 440 | """ 441 | from mlx.core import eval, clear_cache 442 | 443 | if cache[0].offset == 0: 444 | return [] 445 | 446 | bsize = cache[0].keys.shape[0] 447 | if bsize != offsets.shape[0]: 448 | error = 'Number of token sequences and number of offsets mismatched.' 449 | self.log(error, level='error') 450 | raise ValueError(error) 451 | 452 | new_cache = [create_cache_fn() for i in range(bsize)] 453 | for i, nc in enumerate(new_cache): 454 | for j, l in enumerate(nc): 455 | c = cache[j] 456 | l.update_and_fetch(c.keys[i:(i+1), :, offsets[i].tolist():c.offset, :], c.values[i:(i+1), :, offsets[i].tolist():c.offset, :]) 457 | eval(c.state) 458 | del cache 459 | clear_cache() 460 | return new_cache 461 | 462 | def search_cache_non_vision(self, token_ids: List[int], cache_info: Optional[List[CacheInfo]] = None) -> Optional[Tuple[str, int, int]]: 463 | """Searches the cache for a matching prefix of token IDs (non-vision model). 464 | 465 | This function searches the existing cached prompts for the longest matching prefix 466 | with the input `token_ids`. The function prioritizes longer shared prefixes. 467 | If multiple caches have the same shared prefix length, the cache with the shortest 468 | total length is selected. 469 | 470 | Args: 471 | token_ids (List[int]): The list of token IDs to search for in the cache. 472 | cache_info (Optional[List[CacheInfo]]): Optional list of cache infos. If None, use the internal `cache_info`. 473 | 474 | Returns: 475 | Optional[Tuple[str, int]]: A tuple containing the cache ID of the best match and the 476 | length of the shared token ID prefix, or None if no suitable match is found. 477 | """ 478 | cache_info = self.cache_info if cache_info is None else cache_info 479 | if len(cache_info) == 0: 480 | return 481 | 482 | from itertools import takewhile 483 | 484 | selected = None 485 | current_shared = 0 486 | slen = max([cf.length for cf in cache_info]) 487 | 488 | for cf in cache_info: 489 | shared = sum([1 for _ in takewhile(lambda x: x[0] == x[1], zip(token_ids, cf.token_ids))]) 490 | if shared > current_shared: 491 | selected = cf.cache_id 492 | current_shared = shared 493 | slen = cf.length 494 | 495 | elif (shared == current_shared) and (cf.length < slen) and (current_shared != 0): 496 | selected = cf.cache_id 497 | current_shared = shared 498 | slen = cf.length 499 | 500 | if selected and (current_shared == len(token_ids)): # Need to have at least one token for the model to process before generation. 501 | current_shared -= 1 502 | 503 | if selected and (current_shared > 0): 504 | return selected, current_shared, 0 505 | 506 | else: 507 | return 508 | 509 | def search_cache_vision(self, token_ids: List[int], images: Optional[List[str]] = None) -> Optional[Tuple[str, int, int]]: 510 | """Searches the cache for a matching prefix of token IDs and images (vision model). 511 | 512 | This function attempts to find the best matching cache entry for a given input consisting of token IDs and, 513 | for vision models, a list of image identifiers. It prioritizes longer shared prefixes of token IDs and 514 | matching image identifiers when image tokens are present in the shared prefix. If multiple cache entries have 515 | the same shared prefix length, the cache entry with the shortest total length is selected. 
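        For example (illustrative token IDs only), given the input prefix [1, 5, 9, 7, 3],
        a cached entry starting [1, 5, 9, 2, ...] shares three tokens while one starting
        [1, 5, 9, 7, 3, 6] shares five, so the latter is selected. If a cached entry covers
        the whole input, the shared length is reduced so that at least one token is left
        for the model to process before generation.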
516 | 517 | When image tokens are present in the shared prefix of tokens, the function compares the provided image 518 | identifiers with the image identifiers stored in the cache entry. A cache entry is considered a better match 519 | if it has a longer shared prefix of image identifiers. 520 | 521 | If no images are provided or if no cached prompts include images, the function falls back to searching based 522 | solely on the token IDs. 523 | 524 | Args: 525 | token_ids (List[int]): The list of token IDs to search for in the cache. 526 | images (Optional[List[str]], optional): The list of image identifiers (e.g., file paths or unique IDs) 527 | corresponding to any image tokens present in the `token_ids`. Defaults to None. 528 | 529 | Returns: 530 | Optional[Tuple[str, int, int]]: A tuple containing the cache ID of the best match, the length of the 531 | shared token ID prefix, and the summed `image_diffs` for the images in that prefix. Returns None if no 532 | suitable match is found. The length returned represents the number of matching tokens, including any 533 | image tokens that have corresponding image identifier matches. 534 | """ 535 | if (not images) or (len(self.cache_info) == 0): 536 | return self.search_cache_non_vision(token_ids=token_ids) 537 | 538 | cache_info = self.cache_info 539 | from itertools import takewhile 540 | 541 | selected = None 542 | current_shared = 0 543 | current_diffs = [] 544 | slen = max([cf.length for cf in cache_info]) 545 | st_chunks, si_chunks = split_list_by_image_token(token_ids, self._image_tokens) 546 | simage_lens = [len(i) for i in si_chunks] 547 | 548 | for cf in cache_info: 549 | shared = len(list(takewhile(lambda x: x[0] == x[1], zip(token_ids, cf.token_ids)))) 550 | ct_chunks, ci_chunks = split_list_by_image_token(cf.token_ids, self._image_tokens) 551 | cimage_lens = [len(i) for i in ci_chunks] 552 | diffs = cf.image_diffs 553 | image_lens_in_shared = [len(i) for i in split_list_by_image_token(token_ids[:shared], self._image_tokens)[1]] 554 | num_images_in_shared = len(image_lens_in_shared) 555 | diffs_in_shared = diffs[:num_images_in_shared] 556 | if (num_images_in_shared > 0) and ((simage_lens[num_images_in_shared - 1] != image_lens_in_shared[-1]) or (cimage_lens[num_images_in_shared - 1] != image_lens_in_shared[-1])): 557 | shared -= image_lens_in_shared[-1] 558 | image_lens_in_shared = image_lens_in_shared[:-1] 559 | num_images_in_shared -= 1 560 | diffs_in_shared = diffs_in_shared[:-1] 561 | 562 | if shared > current_shared: 563 | if num_images_in_shared == 0: 564 | selected = cf.cache_id 565 | current_shared = shared 566 | current_diffs = [] 567 | slen = cf.length 568 | 569 | else: # Compare the images 570 | oimages = images[:num_images_in_shared] 571 | cimages = cf.images[:num_images_in_shared] 572 | shared_images = len(list(takewhile(lambda x: x[0] == x[1], zip(oimages, cimages)))) 573 | if len(oimages) == shared_images: 574 | selected = cf.cache_id 575 | current_shared = shared 576 | current_diffs = diffs_in_shared 577 | slen = cf.length 578 | else: 579 | text_seqs, img_seqs = split_list_by_image_token(token_ids[:shared], self._image_tokens) 580 | shared_seq = text_seqs[0] 581 | for i in range(shared_images): 582 | shared_seq += img_seqs[i] + text_seqs[i + 1] 583 | shared = len(shared_seq) 584 | if shared > current_shared: 585 | selected = cf.cache_id 586 | current_shared = shared 587 | current_diffs = diffs_in_shared[:shared_images] 588 | slen = cf.length 589 | 590 | elif (shared == current_shared) and (cf.length < slen) and (current_shared != 0): 591 | if num_images_in_shared == 0: 592 
| selected = cf.cache_id 593 | current_shared = shared 594 | current_diffs = [] 595 | slen = cf.length 596 | 597 | else: # Compare the images 598 | oimages = images[:num_images_in_shared] 599 | cimages = cf.images[:num_images_in_shared] 600 | shared_images = len(list(takewhile(lambda x: x[0] == x[1], zip(oimages, cimages)))) 601 | if len(oimages) == shared_images: 602 | selected = cf.cache_id 603 | current_shared = shared 604 | current_diffs = diffs_in_shared 605 | slen = cf.length 606 | 607 | # Need to have at least one token for the model to process before generation. 608 | if selected and (current_shared == len(token_ids)): 609 | if token_ids[current_shared - 1] not in self._image_tokens: 610 | current_shared -= 1 611 | else: # Need to remove the entire image 612 | current_shared -= split_list_by_image_token(token_ids[:current_shared], self._image_tokens)[1][-1] 613 | current_diffs = current_diffs[:-1] 614 | 615 | if selected and (current_shared > 0): 616 | out_diffs = sum(current_diffs) if current_diffs else 0 617 | return selected, current_shared, out_diffs 618 | 619 | else: 620 | return 621 | 622 | def search_cache(self, token_ids: List[int], images: Optional[List[str]] = None) -> Optional[Tuple[str, int, int]]: 623 | """Searches the cache for a matching prompt. 624 | 625 | This method intelligently searches the cache, using either `search_cache_vision` for 626 | vision-enabled models (models that accept images) or `search_cache_non_vision` for text-only 627 | models. The appropriate search function is called based on the `is_vision` attribute 628 | of the `CacheManager`. 629 | 630 | Args: 631 | token_ids (List[int]): The sequence of token IDs representing the prompt. 632 | images (Optional[List[str]], optional): A list of image identifiers (e.g., file paths) 633 | associated with the prompt. Required only for vision models. Defaults to None. 634 | 635 | Returns: 636 | Optional[Tuple[str, int]]: A tuple containing the cache ID of the best matching 637 | prompt and the length of the shared token prefix (the number of matching tokens). 638 | Returns None if no suitable match is found in the cache. 639 | """ 640 | if self.is_vision: 641 | return self.search_cache_vision(token_ids=token_ids, images=images) 642 | else: 643 | return self.search_cache_non_vision(token_ids=token_ids) 644 | 645 | def get_cache(self, 646 | create_cache_fn: Callable[[], List[Union["KVCache", "RotatingKVCache"]]], 647 | token_ids: "array", 648 | offsets: "array", 649 | images: Optional[List[Optional[List[str]]]] = None 650 | ) -> Tuple[List[Union["KVCache", "RotatingKVCache"]], int]: 651 | """Retrieves and pre-fills the KV cache based on existing cached prompts, maximizing reuse for accelerated generation. 652 | 653 | This function searches the cache for the longest matching prefixes of the given token ID sequences 654 | (and associated images, if applicable). If matches are found, the corresponding KV cache states are 655 | loaded and used to pre-fill the provided `cache`. The function aims to leverage existing cached 656 | computations to avoid redundant processing. 657 | 658 | Args: 659 | create_cache_fn (Callable[[], List[Union["KVCache", "RotatingKVCache"]]]): A function that creates a new, 660 | empty KV cache with the same structure (layers, dimensions) as required by the model. This 661 | function is used to initialize the cache if no suitable cached prompts are found, or to construct 662 | intermediate caches during the pre-filling process. 
663 | token_ids ("array"): A 2D array of token IDs representing the input prompts. Each row corresponds to a 664 | separate prompt sequence. 665 | offsets ("array"): A 1D array of integer offsets. Each offset specifies the starting position of a prompt 666 | sequence within the `token_ids` array. This enables the function to handle batched inputs efficiently. 667 | images (Optional[List[Optional[List[str]]]], optional): A list of lists containing image identifiers for each 668 | prompt sequence. This argument is only relevant for vision models and should be set to None for 669 | text-only models. If a prompt contains images, the corresponding inner list should contain the 670 | identifiers (e.g., file paths) of those images. Defaults to None. 671 | 672 | Returns: 673 | List[Union["KVCache", "RotatingKVCache"]]: A list of KVCache or RotatingKVCache objects representing the 674 | pre-filled KV cache. If suitable cached prompts are found, the cache will be partially filled with 675 | the loaded states. If no matches are found, the function returns a newly created, empty cache. 676 | Raises: 677 | ValueError: If the number of token ID sequences in `token_ids` does not match the number of offsets 678 | provided in the `offsets` array. This indicates an inconsistency in the input data. 679 | """ 680 | import time 681 | 682 | start = time.perf_counter() 683 | 684 | if token_ids.shape[0] != offsets.shape[0]: 685 | error = 'Number of token sequences and number of offsets mistmatch.' 686 | self.log(error, level='error') 687 | raise ValueError(error) 688 | 689 | if token_ids.shape[1] < self.min_tokens: 690 | cache = create_cache_fn() 691 | end = time.perf_counter() 692 | self.log(f'Existing cache not required as the prompts have fewer than {self.min_tokens} tokens. Time taken: {end - start:.3f}s.') 693 | return cache, 0 694 | 695 | offset_list = offsets.tolist() 696 | token_seqs = [t[o:] for t, o in zip(token_ids.tolist(), offset_list)] 697 | search_results = [self.search_cache(tids, images[i]) for i, tids in enumerate(token_seqs)] 698 | coverage = [0 if sr is None else sr[1] + o - sr[2] for sr, o in zip(search_results, offset_list)] 699 | 700 | cache = create_cache_fn() 701 | min_coverage = min(coverage) 702 | 703 | token_offset_index = [i for i, c in enumerate(coverage) if c == min_coverage][0] 704 | token_offset = search_results[token_offset_index][1] + offset_list[token_offset_index] if search_results[token_offset_index] else 0 705 | 706 | if min_coverage == 0: 707 | end = time.perf_counter() 708 | self.log(f'No suitable cache found. 
Time taken: {end - start:.3f}s.') 709 | return cache, token_offset 710 | 711 | import os 712 | from datetime import datetime 713 | from mlx.core import load, eval, clear_cache, zeros 714 | from mlx.utils import tree_unflatten 715 | 716 | search_results = [None if ((sr is None) or (o >= min_coverage)) else sr for sr, o in zip(search_results, offset_list)] 717 | cache_to_load = list(set([sr[0] for sr in search_results if sr])) 718 | cache_files = [os.path.join(self.cache_dir, f'{cid}.safetensors') for cid in cache_to_load] 719 | cache_dict = {} 720 | for cid, cf in zip(cache_to_load, cache_files): 721 | cd, metadata = load(cf, return_metadata=True) 722 | cache_dict[cid] = tree_unflatten(list(cd.items())) 723 | 724 | for i, c in enumerate(cache): 725 | shape = cache_dict[cid][i][0].shape 726 | kv_heads = shape[1] 727 | emb_size = shape[3] 728 | key_dtype = cache_dict[cid][i][0].dtype 729 | value_dtype = cache_dict[cid][i][1].dtype 730 | keys = zeros(shape=(token_ids.shape[0], kv_heads, min_coverage, emb_size), dtype=key_dtype) 731 | values = zeros(shape=(token_ids.shape[0], kv_heads, min_coverage, emb_size), dtype=value_dtype) 732 | 733 | for j, sr in enumerate(search_results): 734 | if sr is None: 735 | continue 736 | else: 737 | keys[j, :, offset_list[j]:, :] = cache_dict[sr[0]][i][0][0, :, :(min_coverage - offset_list[j]), :] 738 | values[j, :, offset_list[j]:, :] = cache_dict[sr[0]][i][1][0, :, :(min_coverage - offset_list[j]), :] 739 | 740 | c.update_and_fetch(keys, values) 741 | eval(c.state) 742 | 743 | del cd, cache_dict 744 | clear_cache() 745 | 746 | ts = datetime.now().timestamp() 747 | for cid in cache_to_load: 748 | [sr for sr in self.cache_info if sr.cache_id == cid][0].last_modified = ts 749 | self.save_cache_info() 750 | 751 | cache_id_str = ', '.join([f'"{cid}"' for cid in cache_to_load]) 752 | end = time.perf_counter() 753 | self.log(f'Reusing cache {cache_id_str}. {token_offset} tokens for each prompt prefilled. Time taken: {end - start:.3f}s.') 754 | return cache, token_offset 755 | 756 | def find_seq_to_keep_drop_update(self, token_ids: List[List[int]], images: Optional[List[Optional[List[str]]]] = None) -> Tuple[List[int], List[str], List[str]]: 757 | """Identifies sequences to keep, drop, or update in the cache. 758 | 759 | This function compares newly generated token ID sequences with existing cached sequences to determine 760 | whether each new sequence should be kept (added to the cache), dropped (ignored), or used to update 761 | an existing cache entry. The comparison is based on the length of shared token prefixes and, for vision 762 | models, the matching of image identifiers. 763 | Args: 764 | token_ids (List[List[int]]): A list of token ID sequences representing the newly generated prompts. 765 | images (Optional[List[Optional[List[str]]]], optional): A list of lists of image identifiers. 766 | Each inner list corresponds to the images associated with a token sequence (or None if there are no 767 | images for that sequence). This is only required for vision models. Defaults to None. 768 | 769 | Returns: 770 | Tuple[List[int], List[str], List[str]]: A tuple containing three lists: 771 | - keep (List[int]): A list of indices (corresponding to the input `token_ids` list) of the sequences 772 | that should be added as new entries to the cache. 773 | - drop (List[str]): A list of cache IDs of existing cache entries that should be removed from the cache. 
774 | - update (List[str]): A list of cache IDs of existing cache entries that should be updated with the 775 | newly generated sequences. 776 | """ 777 | import time 778 | start = time.perf_counter() 779 | 780 | from itertools import takewhile 781 | 782 | images = [None] * len(token_ids) if images is None else images 783 | if len(images) != len(token_ids): 784 | error = 'Number of image lists and number of token sequences mismatch.' 785 | self.log(error, level='error') 786 | raise ValueError(error) 787 | seqs = [SeqIndex(token_ids=tids, images=images[i], index=i) for i, tids in enumerate(token_ids)] 788 | seqs = [s for s in seqs if s.length >= self.min_tokens] 789 | 790 | if len(seqs) == 0: 791 | end = time.perf_counter() 792 | self.log(f'Found no new cache to save. Time taken: {end - start:.3f}s.') 793 | return [], [], [] 794 | 795 | # self comparing 796 | to_process: List[SeqIndex] = [s for s in seqs] 797 | cont_seqs: List[SeqIndex] = [] 798 | done_seqs = [] 799 | while len(to_process) != 0: 800 | seq = to_process[0] 801 | comp = [s for s in to_process if s.index != seq.index] 802 | 803 | strong_seqs = [] 804 | num_images = 0 if (not seq.images) or (not self.is_vision) else len(seq.images) 805 | st_chunks, si_chunks = split_list_by_image_token(seq.token_ids, image_token_id=self._image_tokens) 806 | simage_lens = [len(i) for i in si_chunks] 807 | 808 | for cseq in comp: 809 | num_share = len(list(takewhile(lambda x: x[0] == x[1], zip(seq.token_ids, cseq.token_ids)))) 810 | slen = seq.length 811 | clen = cseq.length 812 | if num_images: 813 | ct_chunks, ci_chunks = split_list_by_image_token(cseq.token_ids, image_token_id=self._image_tokens) 814 | cimage_lens = [len(i) for i in ci_chunks] 815 | image_lens_in_shared = [len(i) for i in split_list_by_image_token(seq.token_ids[:num_share], self._image_tokens)[1]] 816 | num_images_in_shared = len(image_lens_in_shared) 817 | if (num_images_in_shared > 0) and ((simage_lens[num_images_in_shared - 1] != image_lens_in_shared[-1]) or (cimage_lens[num_images_in_shared - 1] != image_lens_in_shared[-1])): 818 | num_share -= image_lens_in_shared[-1] 819 | image_lens_in_shared = image_lens_in_shared[:-1] 820 | num_images_in_shared -= 1 821 | 822 | if cseq.images: 823 | num_share_images = len(list(takewhile(lambda x: x[0] == x[1], zip(seq.images[:num_images_in_shared], cseq.images[:num_images_in_shared])))) 824 | else: 825 | num_share_images = 0 826 | 827 | if num_share_images != num_images_in_shared: 828 | text_seqs, img_seqs = split_list_by_image_token(seq.token_ids[:num_share], self._image_tokens) 829 | shared_seq = text_seqs[0] 830 | for i in range(num_share_images): 831 | shared_seq += img_seqs[i] + text_seqs[i + 1] 832 | num_share = len(shared_seq) 833 | 834 | min_len = min(slen, clen) 835 | threshold = max(self.replace_threshold, (min_len - self.max_reprocess_tokens) / min_len) 836 | if (clen > slen) and ((num_share / slen) > threshold): 837 | strong_seqs.append(cseq) 838 | break 839 | elif (slen >= clen) and ((num_share / clen) > threshold): 840 | done_seqs.append(cseq) 841 | 842 | if not len(strong_seqs): 843 | cont_seqs.append(seq) 844 | 845 | done_seqs.append(seq) 846 | to_process = [s for s in to_process if s.index not in [cs.index for cs in done_seqs]] 847 | 848 | # Comparing to existing caches 849 | existing_cache = self.cache_info 850 | keep = [] 851 | drop = [] 852 | update = [] 853 | if len(existing_cache) == 0: 854 | keep = [seq.index for seq in cont_seqs] 855 | end = time.perf_counter() 856 | self.log(f'Found {len(keep)} new caches to save. 
        for seq in cont_seqs:
            to_update = None
            to_update_len = 0
            num_images = 0 if (not seq.images) or (not self.is_vision) else len(seq.images)
            st_chunks, si_chunks = split_list_by_image_token(seq.token_ids, image_token_id=self._image_tokens)
            simage_lens = [len(i) for i in si_chunks]

            for cseq in existing_cache:
                num_share = sum([1 for _ in takewhile(lambda x: x[0] == x[1], zip(seq.token_ids, cseq.token_ids))])
                slen = seq.length
                clen = cseq.length
                if num_images:
                    ct_chunks, ci_chunks = split_list_by_image_token(cseq.token_ids, image_token_id=self._image_tokens)
                    cimage_lens = [len(i) for i in ci_chunks]
                    image_lens_in_shared = [len(i) for i in split_list_by_image_token(seq.token_ids[:num_share], self._image_tokens)[1]]
                    num_images_in_shared = len(image_lens_in_shared)
                    if (num_images_in_shared > 0) and ((simage_lens[num_images_in_shared - 1] != image_lens_in_shared[-1]) or (cimage_lens[num_images_in_shared - 1] != image_lens_in_shared[-1])):
                        num_share -= image_lens_in_shared[-1]
                        image_lens_in_shared = image_lens_in_shared[:-1]
                        num_images_in_shared -= 1

                    if cseq.images:
                        num_share_images = sum([1 for _ in takewhile(lambda x: x[0] == x[1], zip(seq.images[:num_images_in_shared], cseq.images[:num_images_in_shared]))])
                    else:
                        num_share_images = 0
                    if num_share_images != num_images_in_shared:
                        text_seqs, img_seqs = split_list_by_image_token(seq.token_ids[:num_share], self._image_tokens)
                        shared_seq = text_seqs[0]
                        for i in range(num_share_images):
                            shared_seq += img_seqs[i] + text_seqs[i + 1]
                        num_share = len(shared_seq)
                min_len = min(slen, clen)
                threshold = max(self.replace_threshold, (min_len - self.max_reprocess_tokens) / min_len)
                if (clen >= slen) and ((num_share / slen) > threshold):
                    if cseq.length > to_update_len:
                        to_update = cseq.cache_id
                        to_update_len = cseq.length
                elif (slen > clen) and ((num_share / clen) > threshold):
                    drop.append(cseq.cache_id)

            if to_update:
                update.append(to_update)
            else:
                keep.append(seq.index)

        drop = list(set(drop))
        update = list(set(update))
        end = time.perf_counter()
        self.log(f'Found {len(keep)} new caches to save. Time taken: {end - start:.3f}s.')
        return keep, drop, update

    def save_cache(self,
                   cache: List[Union["KVCache", "RotatingKVCache"]],
                   token_ids: "array",
                   offsets: "array",
                   create_cache_fn: Callable[[], List[Union["KVCache", "RotatingKVCache"]]],
                   images: Optional[List[Optional[List[str]]]] = None,
                   image_diffs: Optional[List[List[int]]] = None
                   ) -> None:
        """Saves the provided KV cache to disk for later reuse.

        This method takes a KV cache, token IDs, and offsets, and saves the relevant portions of the
        cache to disk. It splits the input KV cache into smaller caches, one for each sequence in the
        batch, and then determines which of those sequences should be saved to the cache based on factors
        such as minimum token length and similarity to existing cached prompts. The function then saves
        the selected caches to disk as safetensors files, along with metadata about the cached prompts.
        It also manages the cache's capacity, dropping least recently used (LRU) entries if necessary,
        and updating cache metadata.

        Args:
            cache (List[Union["KVCache", "RotatingKVCache"]]): The KV cache to save. This cache contains cached
                key/value states for a batch of sequences.
            token_ids ("array"): A 2D array of token IDs. Each row represents a sequence in the batch. Used
                to identify which sequences should be cached and to create cache metadata.
            offsets ("array"): A 1D array of integer offsets. Each offset indicates the starting index of the
                corresponding sequence within the `token_ids` array. Used for batched inputs.
            create_cache_fn (Callable[[], List[Union["KVCache", "RotatingKVCache"]]]): A function that creates a new,
                empty KV cache with the same structure (number of layers, dimensions) as the input `cache`. This is
                used when splitting the original cache into smaller caches for individual sequences.
            images (Optional[List[Optional[List[str]]]], optional): A list of lists of image identifiers.
                Only required for vision models. If the model uses images, each inner list corresponds to the
                images associated with a token sequence (or None if there are no images for that sequence).
                Defaults to None.
            image_diffs (Optional[List[List[int]]], optional): Per-sequence token-count adjustments for image
                placeholders, used to reconcile the number of cached positions with the length of the
                corresponding `token_ids` row. Only relevant for vision models. Defaults to None.

        Raises:
            ValueError: If the number of tokens in the cache does not match the length of the corresponding
                token ID sequence.
            ValueError: If the number of caches does not match the number of token ID lists.
            ValueError: If the number of caches does not match the number of offsets.
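
        Example:
            Illustrative sketch only; `manager`, `prompt_cache`, `token_ids`, `offsets` and `make_cache`
            are stand-in names for objects created elsewhere in the engine:

                manager.save_cache(
                    cache=prompt_cache,          # per-layer KV caches after prefill/generation
                    token_ids=token_ids,         # 2D array, shape (batch, seq_len)
                    offsets=offsets,             # 1D array of per-sequence start indices
                    create_cache_fn=make_cache,  # builds an empty cache with the same layout
                )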
        """
        import time
        start = time.perf_counter()

        from datetime import datetime
        from mlx.core import clear_cache
        import os

        if cache[0].state is None:
            end = time.perf_counter()
            self.log(f'No cache saved. Time taken: {end - start:.3f}s.')
            return

        B, kv_heads, num_tokens, embed_size = cache[0].keys[..., :cache[0].offset, :].shape
        if (image_diffs is not None) and (image_diffs[0]):
            num_tokens += sum(image_diffs[0])

        if num_tokens != token_ids.shape[1]:
            error = 'Number of tokens and token IDs mismatch while saving cache.'
            self.log(msg=error, level='error')
            raise ValueError(error)

        if B != token_ids.shape[0]:
            error = 'Number of caches and number of token ID lists mismatch while saving cache.'
            self.log(msg=error, level='error')
            raise ValueError(error)

        if offsets.shape[0] != B:
            error = 'Number of caches and number of offsets mismatch while saving cache.'
            self.log(msg=error, level='error')
            raise ValueError(error)

        token_seqs = [t[o:] for t, o in zip(token_ids.tolist(), offsets.tolist())]
        caches = self.split_cache(cache=cache, create_cache_fn=create_cache_fn, offsets=offsets)
        seq_lens = [len(t) for t in token_seqs]
        token_seqs = [t for t, l in zip(token_seqs, seq_lens) if l >= self.min_tokens]
        caches = [c for c, l in zip(caches, seq_lens) if l >= self.min_tokens]

        if len(token_seqs) == 0:
            del caches
            clear_cache()
            end = time.perf_counter()
            self.log(f'No cache saved. Time taken: {end - start:.3f}s.')
            return

        keep, drop, update = self.find_seq_to_keep_drop_update(token_seqs, images=images)

        if update:
            ts = datetime.now().timestamp()
            for cid in update:
                [sr for sr in self.cache_info if sr.cache_id == cid][0].last_modified = ts
            self.save_cache_info()

        if keep:
            token_seqs = [ts for i, ts in enumerate(token_seqs) if i in keep]
            caches = [c for i, c in enumerate(caches) if i in keep]
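            # Capacity check (worked example with assumed numbers): with max_capacity=10, 9 existing
            # entries, 3 new sequences to keep and 1 entry already marked for dropping,
            # extra_to_drop = 9 + 3 - 1 - 10 = 1, so one more existing entry is evicted before the
            # new caches are written.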
            extra_to_drop = len(self.cache_info) + len(token_seqs) - len(drop) - self.max_capacity
            if extra_to_drop > 0:
                drop = drop + [cf.cache_id for cf in self.cache_info if cf.cache_id not in drop][-extra_to_drop:]

        if drop:
            self.drop_cache_by_id(drop)

        if keep:
            images = [None] * len(token_seqs) if images is None else images
            image_diffs = [[]] * len(token_seqs) if image_diffs is None else image_diffs
            ts = datetime.now().timestamp()
            new_ids = self.get_new_cache_id(num=len(keep))
            new_cache_infos = [CacheInfo(cache_id=cid, token_ids=tids, images=img, image_diffs=imd, last_modified=ts) for tids, img, imd, cid in zip(token_seqs, images, image_diffs, new_ids)]
            new_files = [os.path.join(self.cache_dir, f'{cid}.safetensors') for cid in new_ids]
            for c, nf in zip(caches, new_files):
                save_cache(cache=c, file=nf, model_name=self.model_name, logger=self._logger)
            self._cache_info.extend(new_cache_infos)
            self.save_cache_info()

        del caches
        clear_cache()
        end = time.perf_counter()
        self.log(f'Save cache process done. Total time taken: {end - start:.3f}s.')

    def clear(self) -> None:
        """Drops all existing cache entries."""
        cache_ids = [c.cache_id for c in self.cache_info]
        self.drop_cache_by_id(cache_ids=cache_ids)

--------------------------------------------------------------------------------