├── .gitignore
├── requirements.txt
├── run.sh
├── setup.sh
├── README.md
├── qwen3_coder_chat_template.jinja
└── qwen_server_with_tools.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.venv

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
mlx-lm>=0.26.0
mlx>=0.17.0
transformers>=4.45.0
huggingface-hub>=0.24.0

--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# If no arguments are provided, use the default port 1234;
# otherwise pass all command-line arguments through to the server.
if [ $# -eq 0 ]; then
    set -- --port 1234
fi
uv run python qwen_server_with_tools.py \
    --model mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit \
    --chat-template "$(cat qwen3_coder_chat_template.jinja)" \
    "$@"

--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Check if uv is installed
if ! command -v uv &> /dev/null; then
    echo "Error: uv is not installed."
    echo ""
    echo "To install uv, run one of the following:"
    echo "  curl -LsSf https://astral.sh/uv/install.sh | sh"
    echo "  brew install uv"
    echo "  pip install uv"
    echo ""
    echo "For more options, visit: https://docs.astral.sh/uv/installation/"
    exit 1
fi

uv venv
source .venv/bin/activate
uv pip install -r requirements.txt
deactivate

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Qwen MLX Server with Tool Support

An enhanced MLX server that provides OpenAI-compatible tool calling for Qwen3 models by parsing their native XML format and converting it to OpenAI JSON format.

## Features

- ✅ **XML to JSON conversion**: Automatically converts Qwen3's `<tool_call>` XML format to OpenAI JSON
- ✅ **OpenAI compatibility**: Drop-in replacement for OpenAI's chat completions API
- ✅ **Streaming support**: Proper streaming with XML filtering to prevent raw XML in output
- ✅ **Robust parsing**: Handles incomplete and malformed XML gracefully
- ✅ **vLLM compliance**: Based on the official vLLM `Qwen3XMLToolParser` implementation

## Installation

```bash
git clone https://github.com/yourusername/qwen-mlx-server.git
cd qwen-mlx-server
pip install -r requirements.txt
```

## Quick Start

```bash
# Basic usage with the default template
python qwen_server_with_tools.py --model mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit

# With a custom chat template for better tool calling
python qwen_server_with_tools.py \
    --model mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit \
    --chat-template "$(cat qwen3_coder_chat_template.jinja)"

# With a different log level (WARNING for production, DEBUG for development)
python qwen_server_with_tools.py \
    --model mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit \
    --log-level WARNING

# With an existing LM Studio download
python qwen_server_with_tools.py \
    --model ~/.cache/lm-studio/models/mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit \
    --chat-template "$(cat qwen3_coder_chat_template.jinja)"
# Note: when entering API details into a tool such as Qwen Code, set the model
# name to "default_model" to avoid re-downloading the model.
```

## Usage Example

### Tool Calling Request

```bash
curl -X POST http://127.0.0.1:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit",
    "messages": [
      {
        "role": "user",
        "content": "Calculate 15 * 7"
      }
    ],
    "tools": [
      {
        "type": "function",
        "function": {
          "name": "calculate",
          "description": "Perform mathematical calculations",
          "parameters": {
            "type": "object",
            "properties": {
              "expression": {
                "type": "string",
                "description": "Mathematical expression to evaluate"
              }
            },
            "required": ["expression"]
          }
        }
      }
    ],
    "stream": false
  }'
```
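
### Python Client

Because the server is a drop-in replacement for the OpenAI API, any OpenAI-compatible client works. A minimal sketch using the official `openai` Python package (`pip install openai`; assumes the server is running on the default port 8080, and the API key can be any placeholder):

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8080/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit",
    messages=[{"role": "user", "content": "Calculate 15 * 7"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Perform mathematical calculations",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {
                        "type": "string",
                        "description": "Mathematical expression to evaluate",
                    }
                },
                "required": ["expression"],
            },
        },
    }],
)

# The server has already converted the model's XML output into standard
# OpenAI tool_call objects, so no XML handling is needed client-side.
for call in response.choices[0].message.tool_calls or []:
    print(call.function.name, call.function.arguments)
```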

### How It Works

The server automatically converts Qwen3's native XML output:
```xml
<tool_call>
<function=calculate>
<parameter=expression>
15 * 7
</parameter>
</function>
</tool_call>
```

To OpenAI-compatible JSON:
```json
{
  "choices": [{
    "message": {
      "role": "assistant",
      "content": "",
      "tool_calls": [{
        "type": "function",
        "id": "call_12345",
        "function": {
          "name": "calculate",
          "arguments": "{\"expression\": \"15 * 7\"}"
        }
      }]
    },
    "finish_reason": "tool_calls"
  }]
}
```

## Command Line Options

| Option | Description | Default |
|--------|-------------|---------|
| `--model` | HuggingFace model path | `mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit` |
| `--host` | Server host | `127.0.0.1` |
| `--port` | Server port | `8080` |
| `--chat-template` | Custom chat template, passed as template text (e.g. `"$(cat qwen3_coder_chat_template.jinja)"`) | `""` (uses model default) |
| `--use-default-chat-template` | Force use of the model's default template | `False` |
| `--log-level` | Logging verbosity | `INFO` |
| `--max-tokens` | Default max tokens to generate | `512` |

## Logging Levels

- `DEBUG`: Shows detailed XML parsing and conversion steps (useful for development)
- `INFO`: Standard operational messages (default)
- `WARNING`: Only warnings and errors (recommended for production)
- `ERROR`: Only errors

## Files

- `qwen_server_with_tools.py` - Main server with XML→JSON tool parsing
- `qwen3_coder_chat_template.jinja` - Optimized Qwen3-Coder chat template
- `requirements.txt` - Python dependencies
- `README.md` - This documentation

## Supported Models

Designed for Qwen3-Coder models, but should work with any Qwen model that outputs XML tool calls:

- `mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit`
- `mlx-community/Qwen3-Coder-7B-A3B-Instruct-4bit`
- Other Qwen3 variants

## Development

To see detailed XML parsing logs:
```bash
python qwen_server_with_tools.py \
    --model mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit \
    --log-level DEBUG
```

## Implementation Notes

- Based on vLLM's `Qwen3XMLToolParser` for maximum compatibility
- Handles both streaming and non-streaming requests correctly
- Gracefully handles incomplete XML during token-by-token generation
- Maintains full OpenAI Chat Completions API compatibility
- Supports parameter type conversion and validation
- Filters XML from streaming output to prevent malformed responses (see the sketch below)
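
The streaming filter is deliberately simple: as soon as a `<` appears in the accumulated text, the server stops forwarding content chunks and lets the tool parser handle the withheld text when generation finishes. A simplified sketch of the idea (illustrative only, not the server's exact code):

```python
def filter_segment(accumulated: str, segment: str) -> str:
    """Return the part of `segment` that is safe to stream to the client.

    `accumulated` is all content generated so far, ending with `segment`.
    """
    pos = accumulated.find("<")
    if pos < 0:
        return segment                            # no potential XML yet
    already_sent = len(accumulated) - len(segment)
    return segment[: max(0, pos - already_sent)]  # cut at the first "<"
```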

## License

MIT License - feel free to use this in your projects!

--------------------------------------------------------------------------------
/qwen3_coder_chat_template.jinja:
--------------------------------------------------------------------------------
{% macro render_item_list(item_list, tag_name='required') %}
    {%- if item_list is defined and item_list is iterable and item_list | length > 0 %}
        {%- if tag_name %}{{- '\n<' ~ tag_name ~ '>' -}}{% endif %}
            {{- '[' }}
            {%- for item in item_list -%}
                {%- if loop.index > 1 %}{{- ", "}}{% endif -%}
                {%- if item is string -%}
                    {{ "`" ~ item ~ "`" }}
                {%- else -%}
                    {{ item }}
                {%- endif -%}
            {%- endfor -%}
            {{- ']' }}
        {%- if tag_name %}{{- '</' ~ tag_name ~ '>' -}}{% endif %}
    {%- endif %}
{% endmacro %}

{%- if messages[0]["role"] == "system" %}
    {%- set system_message = messages[0]["content"] %}
    {%- set loop_messages = messages[1:] %}
{%- else %}
    {%- set loop_messages = messages %}
{%- endif %}

{%- if not tools is defined %}
    {%- set tools = [] %}
{%- endif %}

{%- if system_message is defined %}
    {{- "<|im_start|>system\n" + system_message }}
{%- else %}
    {%- if tools is iterable and tools | length > 0 %}
        {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }}
    {%- endif %}
{%- endif %}
{%- if tools is iterable and tools | length > 0 %}
    {{- "\n\nYou have access to the following functions:\n\n" }}
    {{- "<tools>" }}
    {%- for tool in tools %}
        {%- if tool.function is defined %}
            {%- set tool = tool.function %}
        {%- endif %}
        {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
        {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
        {{- '\n<parameters>' }}
        {%- for param_name, param_fields in tool.parameters.properties|items %}
            {{- '\n<parameter>' }}
            {{- '\n<name>' ~ param_name ~ '</name>' }}
            {%- if param_fields.type is defined %}
                {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
            {%- endif %}
            {%- if param_fields.description is defined %}
                {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
            {%- endif %}
            {{- render_item_list(param_fields.enum, 'enum') }}
            {%- set handled_keys = ['type', 'description', 'enum', 'required'] %}
            {%- for json_key in param_fields.keys() | reject("in", handled_keys) %}
                {%- set normed_json_key = json_key | replace("-", "_") | replace(" ", "_") | replace("$", "") %}
                {%- if param_fields[json_key] is mapping %}
                    {{- '\n<' ~ normed_json_key ~ '>' ~ (param_fields[json_key] | tojson | safe) ~ '</' ~ normed_json_key ~ '>' }}
                {%- else %}
                    {{- '\n<' ~ normed_json_key ~ '>' ~ (param_fields[json_key] | string) ~ '</' ~ normed_json_key ~ '>' }}
                {%- endif %}
            {%- endfor %}
            {{- render_item_list(param_fields.required, 'required') }}
            {{- '\n</parameter>' }}
        {%- endfor %}
        {{- render_item_list(tool.parameters.required, 'required') }}
        {{- '\n</parameters>' }}
        {%- if tool.return is defined %}
            {%- if tool.return is mapping %}
                {{- '\n<returns>' ~ (tool.return | tojson | safe) ~ '</returns>' }}
            {%- else %}
                {{- '\n<returns>' ~ (tool.return | string) ~ '</returns>' }}
            {%- endif %}
        {%- endif %}
        {{- '\n</function>' }}
    {%- endfor %}
    {{- "\n</tools>" }}
    {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
{%- endif %}
{%- if system_message is defined %}
    {{- '<|im_end|>\n' }}
{%- else %}
    {%- if tools is iterable and tools | length > 0 %}
        {{- '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- for message in loop_messages %}
    {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
        {{- '<|im_start|>' + message.role }}
        {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
            {{- '\n' + message.content | trim + '\n' }}
        {%- endif %}
        {%- for tool_call in message.tool_calls %}
            {%- if tool_call.function is defined %}
                {%- set tool_call = tool_call.function %}
            {%- endif %}
            {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
            {%- if tool_call.arguments is defined %}
                {%- if tool_call.arguments is string %}
                    {# Handle JSON string format like '{"location": "Paris"}' #}
                    {%- if tool_call.arguments.startswith('{') and tool_call.arguments.endswith('}') %}
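                        {# Best-effort parse: this naive comma/colon split will mangle
                           nested objects or values containing ',' or ':'; dict-form
                           arguments (handled below) are the reliable path. #}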
                        {%- set json_content = tool_call.arguments[1:-1] %}
                        {%- set pairs = json_content.split(',') %}
                        {%- for pair in pairs %}
                            {%- set colon_pos = pair.find(':') %}
                            {%- if colon_pos > 0 %}
                                {%- set key = pair[:colon_pos].strip().strip('"') %}
                                {%- set value = pair[colon_pos+1:].strip().strip('"') %}
                                {{- '<parameter=' + key + '>\n' }}
                                {{- value }}
                                {{- '\n</parameter>\n' }}
                            {%- endif %}
                        {%- endfor %}
                    {%- else %}
                        {# Fallback: treat as single unnamed parameter #}
                        {{- '<parameter=arguments>\n' }}
                        {{- tool_call.arguments }}
                        {{- '\n</parameter>\n' }}
                    {%- endif %}
                {%- else %}
                    {# Handle dict format #}
                    {%- for args_name, args_value in tool_call.arguments|items %}
                        {{- '<parameter=' + args_name + '>\n' }}
                        {%- set args_value = args_value if args_value is string else args_value | string %}
                        {{- args_value }}
                        {{- '\n</parameter>\n' }}
                    {%- endfor %}
                {%- endif %}
            {%- endif %}
            {{- '</function>\n</tool_call>' }}
        {%- endfor %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "tool" %}
        {%- if loop.previtem and loop.previtem.role != "tool" %}
            {{- '<|im_start|>user\n' }}
        {%- endif %}
        {{- '<tool_response>\n' }}
        {{- message.content }}
        {{- '\n</tool_response>\n' }}
        {%- if not loop.last and loop.nextitem.role != "tool" %}
            {{- '<|im_end|>\n' }}
        {%- elif loop.last %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- else %}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}

--------------------------------------------------------------------------------
/qwen_server_with_tools.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

"""
MLX Server with enhanced Qwen3 XML tool parsing support.

This script extends the MLX server to properly parse Qwen3's XML-style tool calls
and convert them to OpenAI-compatible JSON format responses.

Based on vLLM's Qwen3XMLToolParser implementation.
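
Usage (see the README for all options):
    python qwen_server_with_tools.py --model mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit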
10 | """ 11 | 12 | import json 13 | import re 14 | import uuid 15 | import argparse 16 | import logging 17 | from typing import Dict, List, Optional, Any, Union 18 | from collections.abc import Sequence 19 | 20 | # Import MLX server components 21 | from mlx_lm.server import APIHandler, ModelProvider, PromptCache, run 22 | from mlx_lm.utils import load 23 | 24 | 25 | class Qwen3ToolParser: 26 | """Tool parser for Qwen3's XML format that converts to OpenAI JSON format.""" 27 | 28 | def __init__(self, tokenizer): 29 | self.tokenizer = tokenizer 30 | 31 | # Track tool calls for finish_reason handling (like vLLM) 32 | self.prev_tool_call_arr = [] 33 | 34 | # XML parsing patterns (matching vLLM exactly) 35 | self.tool_call_complete_regex = re.compile( 36 | r"(.*?)", re.DOTALL 37 | ) 38 | self.tool_call_regex = re.compile( 39 | r"(.*?)|(.*?)$", re.DOTALL 40 | ) 41 | self.tool_call_function_regex = re.compile( 42 | r"|| Any: 53 | """Convert parameter value based on its expected type.""" 54 | # Handle null value for any type 55 | if param_value.lower() == "null": 56 | return None 57 | 58 | if param_name not in param_config: 59 | if param_config != {}: 60 | logging.warning( 61 | f"Parsed parameter '{param_name}' is not defined in the tool " 62 | f"parameters for tool '{func_name}', directly returning the string value." 63 | ) 64 | return param_value 65 | 66 | if ( 67 | isinstance(param_config[param_name], dict) 68 | and "type" in param_config[param_name] 69 | ): 70 | param_type = str(param_config[param_name]["type"]).strip().lower() 71 | else: 72 | param_type = "string" 73 | 74 | if param_type in ["string", "str", "text", "varchar", "char", "enum"]: 75 | return param_value 76 | elif ( 77 | param_type.startswith("int") 78 | or param_type.startswith("uint") 79 | or param_type.startswith("long") 80 | or param_type.startswith("short") 81 | or param_type.startswith("unsigned") 82 | ): 83 | try: 84 | param_value = int(param_value) 85 | except: 86 | logging.warning( 87 | f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool " 88 | f"'{func_name}', degenerating to string." 89 | ) 90 | return param_value 91 | elif param_type.startswith("num") or param_type.startswith("float"): 92 | try: 93 | float_param_value = float(param_value) 94 | param_value = float_param_value if float_param_value - int(float_param_value) != 0 else int(float_param_value) 95 | except: 96 | logging.warning( 97 | f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool " 98 | f"'{func_name}', degenerating to string." 99 | ) 100 | return param_value 101 | elif param_type in ["boolean", "bool", "binary"]: 102 | param_value = param_value.lower() 103 | if param_value not in ["true", "false"]: 104 | logging.warning( 105 | f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean (`true` of `false`) in tool '{func_name}', degenerating to false." 106 | ) 107 | return param_value == "true" 108 | else: 109 | if param_type == "object" or param_type.startswith("dict"): 110 | try: 111 | param_value = json.loads(param_value) 112 | return param_value 113 | except: 114 | logging.warning( 115 | f"Parsed value '{param_value}' of parameter '{param_name}' is not a valid JSON object in tool " 116 | f"'{func_name}', will try other methods to parse it." 
            try:
                param_value = eval(param_value)
            except Exception:
                logging.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via "
                    f"Python `eval()` in tool '{func_name}', degenerating to string."
                )
            return param_value

    def _get_arguments_config(self, func_name: str, tools: Optional[List[Dict]]) -> dict:
        """Get parameter configuration for a function from the tools list."""
        if tools is None:
            return {}
        for config in tools:
            if not isinstance(config, dict):
                continue
            if config.get("type") == "function" and isinstance(config.get("function"), dict):
                if config["function"].get("name") == func_name:
                    params = config["function"].get("parameters", {})
                    if isinstance(params, dict) and "properties" in params:
                        return params["properties"]
                    elif isinstance(params, dict):
                        return params
                    else:
                        return {}
        logging.warning(f"Tool '{func_name}' is not defined in the tools list.")
        return {}

    def _parse_xml_function_call(self, function_call_str: str, tools: Optional[List[Dict]]) -> Optional[Dict]:
        """Parse XML function call format to OpenAI JSON format."""
        try:
            # Handle incomplete XML gracefully
            if ">" not in function_call_str:
                logging.warning(f"Incomplete XML function call: {function_call_str[:100]}...")
                return None

            # Extract function name
            end_index = function_call_str.index(">")
            function_name = function_call_str[:end_index]
            param_config = self._get_arguments_config(function_name, tools)
            parameters = function_call_str[end_index + 1:]

            param_dict = {}

            # Handle incomplete parameters more gracefully
            parameter_matches = self.tool_call_parameter_regex.findall(parameters)
            for match in parameter_matches:
                try:
                    match_text = match[0] if match[0] else match[1]
                    if ">" not in match_text:
                        logging.warning(f"Incomplete parameter in XML: {match_text[:50]}...")
                        continue

                    idx = match_text.index(">")
                    param_name = match_text[:idx]
                    param_value = str(match_text[idx + 1:])

                    # Strip one leading and one trailing newline around the value
                    if param_value.startswith("\n"):
                        param_value = param_value[1:]
                    if param_value.endswith("\n"):
                        param_value = param_value[:-1]

                    param_dict[param_name] = self._convert_param_value(
                        param_value, param_name, param_config, function_name
                    )
                except Exception as param_e:
                    logging.warning(f"Error parsing parameter {match}: {param_e}")
                    continue

            return {
                "type": "function",
                "id": f"call_{uuid.uuid4().hex[:24]}",
                "function": {
                    "name": function_name,
                    "arguments": json.dumps(param_dict, ensure_ascii=False)
                }
            }
        except Exception as e:
            logging.error(f"Error parsing XML function call '{function_call_str[:100]}...': {e}")
            return None

    def _get_function_calls(self, model_output: str) -> List[str]:
        """Extract function calls from model output (matching the vLLM implementation)."""
        # Find all tool calls
        matched_ranges = self.tool_call_regex.findall(model_output)
        raw_tool_calls = [
            match[0] if match[0] else match[1] for match in matched_ranges
        ]

        # Back-off strategy if no tool_call tags found (like vLLM)
        if len(raw_tool_calls) == 0:
            raw_tool_calls = [model_output]

        raw_function_calls = []
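        # Each candidate region may contain one or more <function=...> blocks;
        # collect them all so that parallel tool calls are supported.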
        for tool_call in raw_tool_calls:
            raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))

        function_calls = [
            match[0] if match[0] else match[1] for match in raw_function_calls
        ]
        return function_calls

    def extract_tool_calls(self, model_output: str, tools: Optional[List[Dict]] = None) -> Dict:
        """Extract tool calls from model output and return them in OpenAI format."""
        # Quick check to avoid unnecessary processing (like vLLM)
        if self.tool_call_prefix not in model_output:
            return {
                "tools_called": False,
                "tool_calls": [],
                "content": model_output
            }

        try:
            function_calls = self._get_function_calls(model_output)
            if len(function_calls) == 0:
                return {
                    "tools_called": False,
                    "tool_calls": [],
                    "content": model_output
                }

            tool_calls = []
            for function_call_str in function_calls:
                parsed_call = self._parse_xml_function_call(function_call_str, tools)
                if parsed_call:
                    tool_calls.append(parsed_call)

            # Populate prev_tool_call_arr for the serving layer to set finish_reason (like vLLM)
            self.prev_tool_call_arr.clear()  # Clear previous calls
            for tool_call in tool_calls:
                if tool_call:
                    self.prev_tool_call_arr.append({
                        "name": tool_call["function"]["name"],
                        "arguments": tool_call["function"]["arguments"],
                    })

            # Extract content before tool calls (like vLLM - no rstrip)
            content_index = model_output.find(self.tool_call_start_token)
            content_index = (
                content_index
                if content_index >= 0
                else model_output.find(self.tool_call_prefix)
            )
            content = model_output[:content_index] if content_index > 0 else ""

            return {
                "tools_called": len(tool_calls) > 0,
                "tool_calls": tool_calls,
                "content": content if content else None
            }

        except Exception as e:
            logging.error(f"Error extracting tool calls from the response: {e}")
            return {
                "tools_called": False,
                "tool_calls": [],
                "content": model_output
            }


class EnhancedAPIHandler(APIHandler):
    """Enhanced API handler with Qwen3 XML tool parsing support."""

    def __init__(
        self,
        model_provider: ModelProvider,
        *args,
        prompt_cache: Optional[PromptCache] = None,
        system_fingerprint: Optional[str] = None,
        **kwargs,
    ):
        self.tool_parser = None
        self.full_generated_text = ""  # Store full text including tool calls
        self.in_tool_call = False  # Track if we're currently in a tool call during streaming
        self.tool_text = ""  # For original MLX server compatibility
        super().__init__(model_provider, *args, prompt_cache=prompt_cache, system_fingerprint=system_fingerprint, **kwargs)
        # Initialize tool parser after parent initialization
        if hasattr(self, 'tokenizer'):
            self.tool_parser = Qwen3ToolParser(self.tokenizer)

    def handle_completion(self, prompt, stop_id_sequences):
        """Override to capture full generated text including XML tool calls."""
        from mlx_lm.generate import stream_generate
        from mlx_lm.sample_utils import make_sampler, make_logits_processors
        from mlx_lm.server import stopping_criteria, sequence_overlap
        import mlx.core as mx
        import time

        # Reset streaming state for a new request
        self.in_tool_call = False
        self.tool_text = ""

        tokens = []
        finish_reason = "length"
        stop_sequence_suffix = None
        if self.stream:
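            # Streaming mode: finish the HTTP headers now; incremental chunks are
            # written below as Server-Sent Events ("data: ..." lines).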
            self.end_headers()
            logging.debug("Starting stream:")
        else:
            logging.debug("Starting completion:")
        token_logprobs = []
        top_tokens = []

        # Debug: log the request body
        logging.debug(f"REQUEST BODY: {json.dumps(self.body, indent=2)}")

        prompt = self.get_prompt_cache(prompt)

        text = ""
        full_text = ""  # Capture ALL generated text including XML
        tic = time.perf_counter()
        sampler = make_sampler(
            self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            xtc_probability=self.xtc_probability,
            xtc_threshold=self.xtc_threshold,
            xtc_special_tokens=[
                self.tokenizer.eos_token_id,
                self.tokenizer.encode("\n"),
            ],
        )
        logits_processors = make_logits_processors(
            self.logit_bias,
            self.repetition_penalty,
            self.repetition_context_size,
        )

        tool_calls = []
        segment = ""

        for gen_response in stream_generate(
            model=self.model,
            tokenizer=self.tokenizer,
            prompt=prompt,
            max_tokens=self.max_tokens,
            sampler=sampler,
            logits_processors=logits_processors,
            prompt_cache=self.prompt_cache.cache,
            draft_model=self.model_provider.draft_model,
            num_draft_tokens=self.num_draft_tokens,
        ):
            logging.debug(gen_response.text)

            # Capture ALL text for XML parsing
            full_text += gen_response.text

            # Use the original logic for tool calling detection.
            # For Qwen models, tokenizer.has_tool_calling is False, so all text goes to 'text'.
            if (
                self.tokenizer.has_tool_calling
                and gen_response.text == self.tokenizer.tool_call_start
            ):
                self.in_tool_call = True
            elif self.in_tool_call:
                if gen_response.text == self.tokenizer.tool_call_end:
                    tool_calls.append(self.tool_text)
                    self.tool_text = ""
                    self.in_tool_call = False
                else:
                    self.tool_text += gen_response.text
            else:
                text += gen_response.text
                segment += gen_response.text

            token = gen_response.token
            logprobs = gen_response.logprobs
            tokens.append(token)

            if self.logprobs > 0:
                sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
                top_indices = sorted_indices[: self.logprobs]
                top_logprobs = logprobs[top_indices]
                top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
                top_tokens.append(tuple(top_token_info))

            token_logprobs.append(logprobs[token].item())

            stop_condition = stopping_criteria(
                tokens, stop_id_sequences, self.tokenizer.eos_token_id
            )
            if stop_condition.stop_met:
                finish_reason = "stop"
                if stop_condition.trim_length:
                    stop_sequence_suffix = self.tokenizer.decode(
                        tokens[-stop_condition.trim_length :]
                    )
                    text = text[: -len(stop_sequence_suffix)]
                    full_text = full_text[: -len(stop_sequence_suffix)]
                    segment = ""
                break

            if self.stream:
                if any(
                    (
                        sequence_overlap(tokens, sequence)
                        for sequence in stop_id_sequences
                    )
                ):
                    continue
                elif segment:
                    try:
                        # Simple approach: stop streaming as soon as we see a "<"
                        # character in the accumulated text
                        if not self.in_tool_call:
                            # Look for any < that could start XML in the accumulated text
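                            # Note: this hides ALL later content once any "<" appears,
                            # including legitimate "<" in normal prose or code; the
                            # withheld text is still parsed from full_text at the end.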
                            bracket_pos = text.find("<")
                            if bracket_pos >= 0:
                                self.in_tool_call = True
                                # Calculate what part of this segment comes before the <
                                text_before_segment = text[:-len(segment)] if len(segment) <= len(text) else ""

                                if bracket_pos >= len(text_before_segment):
                                    # The < is in this segment
                                    chars_before_bracket = bracket_pos - len(text_before_segment)
                                    filtered_segment = segment[:chars_before_bracket]
                                else:
                                    # The < was in previous segments, don't send anything
                                    filtered_segment = ""
                            else:
                                # No < found yet, send the segment
                                filtered_segment = segment
                        else:
                            # Already detected <, don't send anything more
                            filtered_segment = ""

                        if filtered_segment:
                            delta_response = {
                                "id": self.request_id,
                                "object": "chat.completion.chunk",
                                "created": self.created,
                                "model": self.requested_model,
                                "system_fingerprint": self.system_fingerprint,
                                "choices": [{
                                    "index": 0,
                                    "delta": {
                                        "role": "assistant",
                                        "content": filtered_segment
                                    },
                                    "finish_reason": None
                                }]
                            }
                            self.wfile.write(f"data: {json.dumps(delta_response)}\n\n".encode())
                            self.wfile.flush()
                        segment = ""
                    except (BrokenPipeError, ConnectionResetError):
                        logging.warning("Client disconnected during streaming")
                        break
                    except Exception as e:
                        logging.error(f"Error sending streaming chunk: {e}")
                        break

        self.prompt_cache.tokens.extend(tokens)

        if gen_response.finish_reason is not None:
            finish_reason = gen_response.finish_reason

        logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
        logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
        logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")

        logging.debug(f"FULL GENERATED TEXT: {repr(full_text)}")

        # Check if we have XML tool calls in the FULL text
        if "<tool_call>" in full_text or "<function=" in full_text:
            if self.tool_parser is None:
                self.tool_parser = Qwen3ToolParser(self.tokenizer)
            parsed = self.tool_parser.extract_tool_calls(
                full_text, tools=self.body.get("tools")
            )
            if parsed["tools_called"]:
                # Store tool calls as simple JSON strings; they are converted to
                # OpenAI format in generate_response / the final stream chunk.
                tool_calls = [
                    json.dumps({
                        "name": tc["function"]["name"],
                        "arguments": json.loads(tc["function"]["arguments"]),
                    })
                    for tc in parsed["tool_calls"]
                ]
                text = parsed["content"] or ""

        if len(tool_calls) > 0:
            finish_reason = "tool_calls"

        if self.stream:
            try:
                # The final chunk carries tool_calls (if any); finish_reason is
                # reported at the choice level per the OpenAI spec.
                final_delta = {}

                if tool_calls:
                    # Convert back to OpenAI format for streaming
                    openai_tool_calls = []
                    for i, tool_call_json in enumerate(tool_calls):
                        tc_data = json.loads(tool_call_json)
                        openai_tool_calls.append({
                            "index": i,
                            "type": "function",
                            "id": f"call_{uuid.uuid4().hex[:24]}",
                            "function": {
                                "name": tc_data["name"],
                                "arguments": json.dumps(tc_data["arguments"])
                            }
                        })
                    final_delta["tool_calls"] = openai_tool_calls

                final_response = {
                    "id": self.request_id,
                    "object": "chat.completion.chunk",
                    "created": self.created,
                    "model": self.requested_model,
                    "system_fingerprint": self.system_fingerprint,
                    "choices": [{
                        "index": 0,
                        "delta": final_delta,
                        "finish_reason": finish_reason
                    }]
                }

                self.wfile.write(f"data: {json.dumps(final_response)}\n\n".encode())
                self.wfile.flush()

                if self.stream_options is not None and self.stream_options["include_usage"]:
                    usage_response = self.completion_usage_response(len(prompt), len(tokens))
                    self.wfile.write(f"data: {json.dumps(usage_response)}\n\n".encode())
                    self.wfile.flush()

                self.wfile.write("data: [DONE]\n\n".encode())
                self.wfile.flush()
            except (BrokenPipeError, ConnectionResetError) as e:
logging.warning(f"Client disconnected during streaming: {e}") 551 | except Exception as e: 552 | logging.error(f"Error during streaming response: {e}") 553 | try: 554 | self.wfile.write("data: [DONE]\n\n".encode()) 555 | self.wfile.flush() 556 | except: 557 | pass 558 | else: 559 | response = self.generate_response( 560 | text, 561 | finish_reason, 562 | len(prompt), 563 | len(tokens), 564 | token_logprobs=token_logprobs, 565 | top_tokens=top_tokens, 566 | tokens=tokens, 567 | tool_calls=tool_calls, 568 | ) 569 | response_json = json.dumps(response).encode() 570 | indent = "\t" # Backslashes can't be inside of f-strings 571 | logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") 572 | 573 | # Send an additional Content-Length header when it is known 574 | self.send_header("Content-Length", str(len(response_json))) 575 | self.end_headers() 576 | self.wfile.write(response_json) 577 | self.wfile.flush() 578 | def generate_response( 579 | self, 580 | text: str, 581 | finish_reason: Union[str, None], 582 | prompt_token_count: Optional[int] = None, 583 | completion_token_count: Optional[int] = None, 584 | token_logprobs: Optional[List[float]] = None, 585 | top_tokens: Optional[List[Dict[int, float]]] = None, 586 | tokens: Optional[List[int]] = None, 587 | tool_calls: Optional[List[str]] = None, 588 | ) -> dict: 589 | """Enhanced response generation with XML tool parsing.""" 590 | 591 | logging.debug(f"generate_response called with text: {repr(text[:100])}") 592 | logging.debug(f"finish_reason: {finish_reason}") 593 | 594 | # Initialize tool parser if needed 595 | if self.tool_parser is None and hasattr(self, 'tokenizer'): 596 | self.tool_parser = Qwen3ToolParser(self.tokenizer) 597 | 598 | # Parse tool calls from the text if any XML format is detected 599 | parsed_tools = None 600 | final_content = text 601 | 602 | if tool_calls or ("" in text or " OpenAI JSON format") 837 | 838 | # Use our enhanced handler 839 | run(args.host, args.port, ModelProvider(args), handler_class=EnhancedAPIHandler) 840 | 841 | 842 | if __name__ == "__main__": 843 | main() --------------------------------------------------------------------------------