├── docs ├── authors.rst ├── endpoints.md ├── homehub.png ├── mlxchat.png ├── overrides │ └── main.html ├── models.md ├── community_projects.md ├── changelog.md ├── installation.md ├── examples │ ├── function_calling.md │ └── chatbot.md ├── cli_reference.md ├── index.md ├── contributing.md └── usage.md ├── tests ├── __init__.py └── test_fastmlx.py ├── .gitignore ├── requirements.txt ├── MANIFEST.in ├── fastmlx ├── types │ ├── model.py │ └── chat │ │ └── chat_completion.py ├── __init__.py ├── tools │ ├── arcee_agent.j2 │ ├── config.json │ ├── xlam.j2 │ ├── command-r-plus.j2 │ └── llama-3_1.j2 ├── fastmlx.py └── utils.py ├── AUTHORS.rst ├── .pre-commit-config.yaml ├── LICENSE ├── .github └── workflows │ ├── tests.yml │ ├── deploy-docs.yml │ ├── python-publish.yml │ └── update_changelog.yml ├── pyproject.toml ├── mkdocs.yml ├── update_changelog.py └── README.md /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../AUTHORS.rst 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test package for fastmlx.""" 2 | -------------------------------------------------------------------------------- /docs/endpoints.md: -------------------------------------------------------------------------------- 1 | 2 | # Endpoints 3 | 4 | ::: fastmlx 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | *.egg-info 4 | env/ 5 | venv/* 6 | -------------------------------------------------------------------------------- /docs/homehub.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arcee-ai/fastmlx/HEAD/docs/homehub.png -------------------------------------------------------------------------------- /docs/mlxchat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arcee-ai/fastmlx/HEAD/docs/mlxchat.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.15 2 | mlx-lm>=0.15.2 3 | mlx-vlm>=0.0.12 4 | fastapi>=0.111.0 5 | jinja2 -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include requirements.txt 4 | 5 | recursive-exclude * __pycache__ 6 | recursive-exclude * *.py[co] 7 | 8 | -------------------------------------------------------------------------------- /fastmlx/types/model.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class SupportedModels(BaseModel): 7 | vlm: List[str] 8 | lm: List[str] 9 | -------------------------------------------------------------------------------- /fastmlx/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for fastmlx.""" 2 | 3 | __author__ = """Prince Canuma""" 4 | __email__ = "prince.gdt@gmail.com" 5 | __version__ = "0.2.1" 6 | 7 | from .fastmlx import * 8 | 
-------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * Prince Canuma 9 | 10 | Contributors 11 | ------------ 12 | 13 | None yet. Why not be the first? 14 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black-pre-commit-mirror 3 | rev: 24.2.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.13.2 8 | hooks: 9 | - id: isort 10 | args: 11 | - --profile=black -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 | {% if page.nb_url %} 5 | 6 | {% include ".icons/material/download.svg" %} 7 | 8 | {% endif %} 9 | 10 | {{ super() }} 11 | {% endblock content %} 12 | -------------------------------------------------------------------------------- /fastmlx/tools/arcee_agent.j2: -------------------------------------------------------------------------------- 1 | {% if tools %} 2 | In this environment, you have access to a set of tools you can use to answer the user's question. 3 | 4 | You may call them like this: 5 | 6 | 7 | $TOOL_NAME 8 | 9 | <$PARAMETER_NAME>$PARAMETER_VALUE 10 | ... 11 | 12 | 13 | 14 | 15 | Here are the tools available: 16 | {{ tools }} 17 | 18 | {% endif %} 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache Software License 2.0 2 | 3 | Copyright (c) 2024, Prince Canuma 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | 17 | -------------------------------------------------------------------------------- /fastmlx/tools/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "models": { 3 | "arcee-agent": { 4 | "prompt_template": "arcee_agent.j2", 5 | "parallel_tool_calling": true, 6 | "tool_role": "tool" 7 | }, 8 | "command-r-plus": { 9 | "prompt_template": "command-r-plus.j2", 10 | "parallel_tool_calling": true, 11 | "tool_role": "tool" 12 | }, 13 | "llama-3_1": { 14 | "prompt_template": "llama-3_1.j2", 15 | "parallel_tool_calling": true, 16 | "eom_token": ["<|eom_id|>"], 17 | "tool_role": "ipython" 18 | }, 19 | "xlam": { 20 | "prompt_template": "xlam.j2", 21 | "parallel_tool_calling": true, 22 | "tool_role": "tool" 23 | }, 24 | "default": { 25 | "prompt_template": "llama-3_1.j2", 26 | "parallel_tool_calling": false 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /docs/models.md: -------------------------------------------------------------------------------- 1 | # Managing Models 2 | 3 | ## Listing Supported Models 4 | 5 | To see all vision and language models supported by MLX: 6 | 7 | ```python 8 | import requests 9 | 10 | url = "http://localhost:8000/v1/supported_models" 11 | response = requests.get(url) 12 | print(response.json()) 13 | ``` 14 | 15 | ## Listing Available Models 16 | 17 | To see all available models: 18 | 19 | ```python 20 | import requests 21 | 22 | url = "http://localhost:8000/v1/models" 23 | response = requests.get(url) 24 | print(response.json()) 25 | ``` 26 | 27 | ### Deleting Models 28 | 29 | To remove any models loaded to memory: 30 | 31 | ```python 32 | import requests 33 | 34 | url = "http://localhost:8000/v1/models" 35 | params = { 36 | "model_name": "hf-repo-or-path", 37 | } 38 | response = requests.delete(url, params=params) 39 | print(response) 40 | ``` 41 | -------------------------------------------------------------------------------- /docs/community_projects.md: -------------------------------------------------------------------------------- 1 | Here are some projects built by the community that use FastMLX: 2 | 3 | 1. FastMLX-MineCraft by Mathieu 4 | 2. MLX Chat by Nils Durner 5 | 3. AI Home Hub by Prince Canuma 6 | 7 | 8 | ### PROJECTS IN DETAIL 9 | #### [FastMLX-MineCraft](https://x.com/mwrites__/status/1837465176582353080) by [Mathieu](https://x.com/mwrites__) 10 | 11 | Remote image 12 | 13 | ####[MLX Chat](https://github.com/ndurner/mlx_chat) by [Nils Durner](https://github.com/ndurner) 14 | Chat interface for MLX for on-device Language Model use on Apple Silicon. Built on FastMLX. 15 | 16 | ![MLX Chat](./mlxchat.png) 17 | 18 | ####[Home Hub](https://x.com/Prince_Canuma/status/1813689110089101623) by [Prince Canuma](https://x.com/Prince_Canuma) 19 | Turning your Mac into an AI home server. 
20 | 21 | ![AI Home Hub](./homehub.png) -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [v0.1.0] - 11 July 2024 4 | 5 | 6 | **What's Changed** 7 | 8 | - Add support for token streaming and custom CORS by [@Blaizzy](https://github.com/Blaizzy) 9 | - Add support for Parallel calls by [@Blaizzy](https://github.com/Blaizzy) 10 | - Add Parallel calls usage by [@Blaizzy](https://github.com/Blaizzy) 11 | 12 | **Fixes :** 13 | 14 | - Cross origin Support [#2](https://github.com/Blaizzy/fastmlx/issues/2) 15 | - Max tokens not overriding [#5](https://github.com/Blaizzy/fastmlx/issues/5) 16 | 17 | ## [v0.0.1] - 09 July 2024 18 | 19 | 20 | **What's Changed** 21 | 22 | - Setup FastMLX by [@Blaizzy](https://github.com/Blaizzy) 23 | - Add support for VLMs by [@Blaizzy](https://github.com/Blaizzy) 24 | - Add support for LMs by by [@Blaizzy](https://github.com/Blaizzy) 25 | 26 | **New Contributors** 27 | 28 | - [@Blaizzy](https://github.com/Blaizzy) made their first contribution in [https://github.com/Blaizzy/fastmlx/pull/1](https://github.com/Blaizzy/fastmlx/pull/1) 29 | 30 | 31 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Test PRs 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | runs-on: macos-14 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install MLX 22 | run: | 23 | pip install mlx>=0.15 24 | 25 | - name: Install pre-commit 26 | run: | 27 | python -m pip install pre-commit 28 | pre-commit run --all 29 | if ! git diff --quiet; then 30 | echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change' 31 | exit 1 32 | fi 33 | 34 | - name: Install package and dependencies 35 | run: | 36 | python -m pip install pytest 37 | python -m pip install -e . 38 | 39 | - name: Run tests 40 | run: | 41 | pytest -s . 
-------------------------------------------------------------------------------- /.github/workflows/deploy-docs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main # or your default branch name 7 | pull_request: 8 | branches: 9 | - main # or your default branch name 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install mkdocs mkdocs-material mkdocstrings[python] mkdocs-autorefs mkdocs-git-revision-date-localized-plugin mkdocs-jupyter 26 | 27 | - name: Build documentation 28 | run: mkdocs build 29 | 30 | - name: Deploy to GitHub Pages 31 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 32 | uses: peaceiris/actions-gh-pages@v3 33 | with: 34 | github_token: ${{ secrets.GITHUB_TOKEN }} 35 | publish_dir: ./site -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | deploy: 15 | 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up Python 21 | uses: actions/setup-python@v3 22 | with: 23 | python-version: '3.10' 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install build 28 | - name: Build package 29 | run: python -m build 30 | - name: Publish package 31 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 32 | with: 33 | user: __token__ 34 | password: ${{ secrets.PYPI_API_TOKEN }} 35 | packages_dir: dist 36 | -------------------------------------------------------------------------------- /fastmlx/tools/xlam.j2: -------------------------------------------------------------------------------- 1 | {% if tools %} 2 | [BEGIN OF TASK INSTRUCTION] 3 | You are an expert in composing functions. You are given a question and a set of possible functions. 4 | Based on the question, you will need to make one or more function/tool calls to achieve the purpose. 5 | If none of the functions can be used, point it out and refuse to answer. 6 | If the given question lacks the parameters required by the function, also point it out. 7 | [END OF TASK INSTRUCTION] 8 | 9 | [BEGIN OF AVAILABLE TOOLS] 10 | {{ tools }} 11 | [END OF AVAILABLE TOOLS] 12 | 13 | [BEGIN OF FORMAT INSTRUCTION] 14 | The output MUST strictly adhere to the following JSON format, and NO other text MUST be included. 15 | The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]' 16 | ``` 17 | { 18 | "tool_calls": [ 19 | {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}, 20 | ... 
(more tool calls as required) 21 | ] 22 | } 23 | [END OF FORMAT INSTRUCTION] 24 | 25 | [BEGIN OF QUERY] 26 | {{query}} 27 | [END OF QUERY] 28 | 29 | {% endif %} 30 | -------------------------------------------------------------------------------- /fastmlx/tools/command-r-plus.j2: -------------------------------------------------------------------------------- 1 | {% if tools %} 2 | 3 | You are a helpful assistant with access to functions. Use them if required. 4 | In addition to plain text responses, you can choose to call one or more of the provided functions. 5 | 6 | Use the following rules to decide when to call a function: 7 | * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so 8 | * if you need external information that can be obtained by calling one or more of the provided functions, generate function calls 9 | 10 | If you decide to call functions: 11 | * prefix function calls with functools marker (no closing marker required) 12 | * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] 13 | * follow the provided JSON schema. Do not hallucinate arguments or values. Do not blindly copy values from the provided samples 14 | * respect the argument type formatting. E.g., if the type is number and format is float, write value 7 as 7.0 15 | * make sure you pick the right functions that match the user intent 16 | 17 | Available functions as JSON spec: 18 | {{tools}} 19 | Today is {{ current_date }}. 20 | {% endif %} 21 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Stable release 4 | 5 | To install the latest stable release of FastMLX, use the following command: 6 | 7 | ``` 8 | pip install -U fastmlx 9 | ``` 10 | 11 | This is the recommended method to install **FastMLX**, as it will always install the most recent stable release. 12 | 13 | If [pip](https://pip.pypa.io) isn't installed, you can follow the [Python installation guide](http://docs.python-guide.org/en/latest/starting/installation/) to set it up. 14 | 15 | ## Installation from Sources 16 | 17 | To install **FastMLX** directly from the source code, run this command in your terminal: 18 | 19 | ``` 20 | pip install git+https://github.com/Blaizzy/fastmlx 21 | ``` 22 | ## Running the Server 23 | 24 | There are two ways to start the FastMLX server: 25 | 26 | Using the `fastmlx` command: 27 | 28 | ```bash 29 | fastmlx 30 | ``` 31 | 32 | or 33 | 34 | Using `uvicorn` directly: 35 | 36 | ```bash 37 | uvicorn fastmlx:app --reload --workers 0 38 | ``` 39 | 40 | > WARNING: The `--reload` flag should not be used in production. It is only intended for development purposes. 41 | 42 | ### Additional Notes 43 | 44 | - **Dependencies**: Ensure that you have the required dependencies installed. FastMLX relies on several libraries, which `pip` will handle automatically.
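- **Verifying the install**: As a quick sanity check, you can import the package and, once the server is running, query the `/v1/models` endpoint described in [Managing Models](models.md). The snippet below is a minimal sketch and assumes the server is listening on the default `localhost:8000`.

```python
import requests
import fastmlx

# Confirm the package is importable and report its version.
print(f"fastmlx version: {fastmlx.__version__}")

# With the server started (e.g. by running `fastmlx`), list the models currently loaded into memory.
response = requests.get("http://localhost:8000/v1/models")
print(response.json())
```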
-------------------------------------------------------------------------------- /fastmlx/types/chat/chat_completion.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Literal, Optional, Union 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class FunctionCall(BaseModel): 7 | name: str 8 | arguments: str 9 | 10 | 11 | class ToolCall(BaseModel): 12 | id: str 13 | type: str = "function" 14 | function: FunctionCall 15 | 16 | 17 | class Function(BaseModel): 18 | name: str 19 | description: str 20 | parameters: Dict[str, Any] 21 | 22 | 23 | class ChatCompletionContentPartParam(BaseModel): 24 | type: Literal["text", "image_url"] 25 | text: str = None 26 | image_url: dict = None 27 | 28 | 29 | class ChatMessage(BaseModel): 30 | role: str 31 | content: Union[str, List[ChatCompletionContentPartParam]] 32 | 33 | 34 | class Usage(BaseModel): 35 | prompt_tokens: int 36 | completion_tokens: int 37 | total_tokens: int 38 | 39 | 40 | class ChatCompletionRequest(BaseModel): 41 | model: str 42 | messages: List[ChatMessage] 43 | image: Optional[str] = Field(default=None) 44 | max_tokens: Optional[int] = Field(default=100) 45 | stream: Optional[bool] = Field(default=False) 46 | temperature: Optional[float] = Field(default=0.2) 47 | tools: Optional[List[Function]] = Field(default=None) 48 | tool_choice: Optional[str] = Field(default=None) 49 | stream_options: Optional[Dict[str, Any]] = Field(default=None) 50 | 51 | 52 | class ChatCompletionResponse(BaseModel): 53 | id: str 54 | object: str = "chat.completion" 55 | created: int 56 | model: str 57 | usage: Usage 58 | choices: List[dict] 59 | tool_calls: Optional[List[ToolCall]] = None 60 | 61 | 62 | class ChatCompletionChunk(BaseModel): 63 | id: str 64 | object: str = "chat.completion.chunk" 65 | created: int 66 | model: str 67 | choices: List[Dict[str, Any]] 68 | usage: Optional[Usage] = None 69 | -------------------------------------------------------------------------------- /fastmlx/tools/llama-3_1.j2: -------------------------------------------------------------------------------- 1 | {% if tools %} 2 | Environment: ipython 3 | Tools: brave_search, wolfram_alpha 4 | 5 | Cutting Knowledge Date: December 2023 6 | Today Date: {{ current_date }} 7 | 8 | # Tool Instructions 9 | - Always execute python code in messages that you share. 10 | - When looking for real time information use relevant functions if available else fallback to brave_search 11 | 12 | You have access to the following functions: 13 | Available tools: {{ tools }} 14 | 15 | If you choose to call a function ONLY reply in the following format: 16 | parameters 17 | where 18 | 19 | start_tag => ` a JSON dict with the function argument name as key and function argument value as value. 21 | end_tag => `` 22 | 23 | {% if parallel_tool_calling %} 24 | To call functions, use the following format: 25 | 26 | 27 | 28 | { 29 | "param1": "value1", 30 | "param2": "value2" 31 | } 32 | 33 | 34 | 35 | { 36 | "param1": "value1", 37 | "param2": "value2" 38 | } 39 | 40 | 41 | 42 | Reminders: 43 | - Enclose all function calls within tags. 44 | - Each function call should be in its own tag. 45 | - Function parameters should be in valid JSON format. 46 | - You can call multiple functions in parallel by including multiple tags. 47 | - Required parameters MUST be specified for each function call. 
48 | {% else %} 49 | Here is an example of how to call a function: 50 | {"example_name": "example_value"} 51 | 52 | Reminders: 53 | - Function calls MUST follow the specified format 54 | - Required parameters MUST be specified 55 | - Only call one function at a time 56 | - Put the entire function call reply on one line 57 | 58 | {% endif %} 59 | {% endif %} 60 | 61 | You are a helpful Assistant.{% if parallel_tool_calling %} Respond to queries efficiently, using parallel function calls when appropriate to gather information or perform tasks simultaneously{% endif %}. -------------------------------------------------------------------------------- /docs/examples/function_calling.md: -------------------------------------------------------------------------------- 1 | ## Function Calling 2 | 3 | FastMLX now supports tool calling in accordance with the OpenAI API specification. This feature is available for the following models: 4 | 5 | - Llama 3.1 6 | - Arcee Agent 7 | - C4ai-Command-R-Plus 8 | - Firefunction 9 | - xLAM 10 | 11 | Supported modes: 12 | 13 | - Without Streaming 14 | - Parallel Tool Calling 15 | 16 | > Note: Tool choice and OpenAI-compliant streaming for function calling are currently under development. 17 | 18 | This example demonstrates how to use the `get_current_weather` tool with the `Llama 3.1` model. The API will process the user's question and use the provided tool to fetch the required information. 19 | 20 | 21 | ```python 22 | import requests 23 | import json 24 | 25 | url = "http://localhost:8000/v1/chat/completions" 26 | headers = {"Content-Type": "application/json"} 27 | data = { 28 | "model": "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit", 29 | "messages": [ 30 | { 31 | "role": "user", 32 | "content": "What's the weather like in San Francisco and Washington?" 33 | } 34 | ], 35 | "tools": [ 36 | { 37 | "name": "get_current_weather", 38 | "description": "Get the current weather", 39 | "parameters": { 40 | "type": "object", 41 | "properties": { 42 | "location": { 43 | "type": "string", 44 | "description": "The city and state, e.g. San Francisco, CA" 45 | }, 46 | "format": { 47 | "type": "string", 48 | "enum": ["celsius", "fahrenheit"], 49 | "description": "The temperature unit to use. Infer this from the user's location." 50 | } 51 | }, 52 | "required": ["location", "format"] 53 | } 54 | } 55 | ], 56 | "max_tokens": 150, 57 | "temperature": 0.7, 58 | "stream": False, 59 | } 60 | 61 | response = requests.post(url, headers=headers, data=json.dumps(data)) 62 | print(response.json()) 63 | ``` 64 | 65 | > Note: Streaming is available for regular text generation, but the streaming implementation for function calling is still in development and does not yet fully comply with the OpenAI specification. -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fastmlx" 3 | version = "0.2.1" 4 | dynamic = [ 5 | "dependencies", 6 | ] 7 | description = "FastMLX is a high performance production ready API to host MLX models." 
8 | readme = "README.md" 9 | requires-python = ">=3.8" 10 | keywords = [ 11 | "fastmlx", 12 | "MLX", 13 | "Apple MLX", 14 | "vision language models", 15 | "VLMs", 16 | "large language models", 17 | "LLMs", 18 | ] 19 | license = {text = "Apache Software License 2.0"} 20 | authors = [ 21 | {name = "Prince Canuma", email = "prince.gdt@gmail.com"}, 22 | ] 23 | classifiers = [ 24 | "Intended Audience :: Developers", 25 | "License :: OSI Approved :: Apache Software License", 26 | "Natural Language :: English", 27 | "Programming Language :: Python :: 3.8", 28 | "Programming Language :: Python :: 3.9", 29 | "Programming Language :: Python :: 3.10", 30 | "Programming Language :: Python :: 3.11", 31 | "Programming Language :: Python :: 3.12", 32 | ] 33 | 34 | [project.entry-points."console_scripts"] 35 | fastmlx = "fastmlx.fastmlx:run" 36 | 37 | [project.optional-dependencies] 38 | all = [ 39 | "fastmlx[extra]", 40 | ] 41 | 42 | extra = [] 43 | 44 | 45 | [tool] 46 | [tool.setuptools.packages.find] 47 | include = ["fastmlx*"] 48 | exclude = ["docs*"] 49 | 50 | [tool.setuptools.dynamic] 51 | dependencies = {file = ["requirements.txt"]} 52 | 53 | 54 | [tool.distutils.bdist_wheel] 55 | universal = true 56 | 57 | 58 | [tool.bumpversion] 59 | current_version = "0.2.1" 60 | commit = true 61 | tag = true 62 | 63 | [[tool.bumpversion.files]] 64 | filename = "pyproject.toml" 65 | search = 'version = "{current_version}"' 66 | replace = 'version = "{new_version}"' 67 | 68 | [[tool.bumpversion.files]] 69 | filename = "fastmlx/__init__.py" 70 | search = '__version__ = "{current_version}"' 71 | replace = '__version__ = "{new_version}"' 72 | 73 | 74 | [tool.flake8] 75 | exclude = [ 76 | "docs", 77 | ] 78 | max-line-length = 88 79 | 80 | 81 | [project.urls] 82 | Homepage = "https://github.com/Blaizzy/fastmlx" 83 | 84 | [build-system] 85 | requires = ["setuptools>=64", "setuptools_scm>=8"] 86 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /.github/workflows/update_changelog.yml: -------------------------------------------------------------------------------- 1 | name: Update Changelog 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | update-changelog: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install requests python-dotenv 20 | - name: Update Changelog 21 | env: 22 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 23 | run: python update_changelog.py 24 | - name: Commit changes 25 | run: | 26 | set -x # Enable verbose output 27 | 28 | echo "Configuring Git..." 29 | git config --local user.email "action@github.com" 30 | git config --local user.name "GitHub Action" 31 | 32 | echo "GitHub event name: ${{ github.event_name }}" 33 | echo "GitHub head ref: ${{ github.head_ref }}" 34 | echo "GitHub ref name: ${{ github.ref_name }}" 35 | 36 | echo "Fetching latest changes..." 37 | git fetch origin 38 | 39 | echo "Checking out and updating branch..." 
40 | if [ "${{ github.event_name }}" = "pull_request" ]; then 41 | git checkout -B "${{ github.head_ref }}" "origin/${{ github.head_ref }}" 42 | git pull origin "${{ github.head_ref }}" 43 | else 44 | git checkout -B "${{ github.ref_name }}" "origin/${{ github.ref_name }}" 45 | git pull origin "${{ github.ref_name }}" 46 | fi 47 | 48 | echo "Current branch after checkout:" 49 | git branch 50 | 51 | echo "Running update script..." 52 | python update_changelog.py 53 | 54 | echo "Checking for changes..." 55 | git add docs/changelog.md 56 | git pull 57 | if git diff --staged --quiet; then 58 | echo "No changes to commit" 59 | else 60 | echo "Changes detected, committing..." 61 | git commit -m "Update changelog for latest release" 62 | echo "Pushing changes..." 63 | git push origin HEAD:"${{ github.head_ref || github.ref_name }}" || echo "Failed to push changes" 64 | fi 65 | 66 | echo "Final Git status:" 67 | git status 68 | -------------------------------------------------------------------------------- /docs/cli_reference.md: -------------------------------------------------------------------------------- 1 | # CLI Reference 2 | 3 | The **FastMLX** API server can be configured using various command-line arguments. Here is a detailed reference for each available option. 4 | 5 | ## Usage 6 | 7 | ``` 8 | fastmlx [OPTIONS] 9 | ``` 10 | 11 | ## Options 12 | 13 | ### `--allowed-origins` 14 | 15 | - **Type**: List of strings 16 | - **Default**: `["*"]` 17 | - **Description**: List of allowed origins for CORS (Cross-Origin Resource Sharing). 18 | 19 | ### `--host` 20 | 21 | - **Type**: String 22 | - **Default**: `"0.0.0.0"` 23 | - **Description**: Host to run the server on. 24 | 25 | ### `--port` 26 | 27 | - **Type**: Integer 28 | - **Default**: `8000` 29 | - **Description**: Port to run the server on. 30 | 31 | ### `--reload` 32 | 33 | - **Type**: Boolean 34 | - **Default**: `False` 35 | - **Description**: Enable auto-reload of the server. Only works when 'workers' is set to None. 36 | 37 | ### `--workers` 38 | 39 | - **Type**: Integer or Float 40 | - **Default**: Calculated based on `FASTMLX_NUM_WORKERS` environment variable or 2 if not set. 41 | - **Description**: Number of workers. This option overrides the `FASTMLX_NUM_WORKERS` environment variable. 42 | 43 | - If an integer, it specifies the exact number of workers to use. 44 | - If a float, it represents the fraction of available CPU cores to use (minimum 1 worker). 45 | - To use all available CPU cores, set it to 1.0. 46 | 47 | **Examples**: 48 | - `--workers 1`: Use 1 worker 49 | - `--workers 1.0`: Use all available CPU cores 50 | - `--workers 0.5`: Use half of the available CPU cores 51 | - `--workers 0.0`: Use 1 worker 52 | 53 | ## Environment Variables 54 | 55 | - `FASTMLX_NUM_WORKERS`: Sets the default number of workers if not specified via the `--workers` argument. 56 | 57 | ## Examples 58 | 59 | 1. Run the server on localhost with default settings: 60 | ``` 61 | fastmlx 62 | ``` 63 | 64 | 2. Run the server on a specific host and port: 65 | ``` 66 | fastmlx --host 127.0.0.1 --port 5000 67 | ``` 68 | 69 | 3. Run the server with 4 workers: 70 | ``` 71 | fastmlx --workers 4 72 | ``` 73 | 74 | 4. Run the server using half of the available CPU cores: 75 | ``` 76 | fastmlx --workers 0.5 77 | ``` 78 | 79 | 5. Enable auto-reload (for development): 80 | ``` 81 | fastmlx --reload 82 | ``` 83 | 84 | Remember that the `--reload` option is intended for development purposes and should not be used in production environments. 
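## Worker Resolution (Illustrative)

The precedence rules above can be summarized in a few lines of Python. This is an illustrative sketch of the documented behaviour (CLI argument, then `FASTMLX_NUM_WORKERS`, then the default of 2, with floats treated as a fraction of the CPU cores), not the actual FastMLX implementation; the function name `resolve_workers` is made up for the example.

```python
import os

def resolve_workers(cli_workers=None, default=2):
    """Sketch of the documented precedence: --workers > FASTMLX_NUM_WORKERS > default."""
    value = cli_workers
    if value is None:
        env = os.environ.get("FASTMLX_NUM_WORKERS")
        value = env if env is not None else default
    # Preserve the int/float distinction: "1" means one worker, "1.0" means all cores.
    if isinstance(value, str):
        value = float(value) if "." in value else int(value)
    if isinstance(value, int):
        return max(1, value)
    # Floats are a fraction of the available CPU cores, with a minimum of one worker.
    return max(1, int((os.cpu_count() or 1) * value))

print(resolve_workers(4))    # exactly 4 workers
print(resolve_workers(0.5))  # half of the available cores
print(resolve_workers(1.0))  # all available cores
print(resolve_workers(0.0))  # falls back to 1 worker
```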
-------------------------------------------------------------------------------- /docs/examples/chatbot.md: -------------------------------------------------------------------------------- 1 | This example demonstrates how to create a chatbot application using FastMLX with a Gradio interface. 2 | 3 | ```python 4 | 5 | import argparse 6 | import gradio as gr 7 | import requests 8 | import json 9 | 10 | import asyncio 11 | 12 | async def process_sse_stream(url, headers, data): 13 | response = requests.post(url, headers=headers, json=data, stream=True) 14 | if response.status_code != 200: 15 | raise gr.Error(f"Error: Received status code {response.status_code}") 16 | full_content = "" 17 | for line in response.iter_lines(): 18 | if line: 19 | line = line.decode('utf-8') 20 | if line.startswith('data: '): 21 | event_data = line[6:] # Remove 'data: ' prefix 22 | if event_data == '[DONE]': 23 | break 24 | try: 25 | chunk_data = json.loads(event_data) 26 | content = chunk_data['choices'][0]['delta']['content'] 27 | yield str(content) 28 | except (json.JSONDecodeError, KeyError): 29 | continue 30 | 31 | async def chat(message, history, temperature, max_tokens): 32 | 33 | url = "http://localhost:8000/v1/chat/completions" 34 | headers = {"Content-Type": "application/json"} 35 | data = { 36 | "model": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", 37 | "messages": [{"role": "user", "content": message['text']}], 38 | "max_tokens": max_tokens, 39 | "temperature": temperature, 40 | "stream": True 41 | } 42 | 43 | if len(message['files']) > 0: 44 | data["model"] = "mlx-community/nanoLLaVA-1.5-8bit" 45 | data["image"] = message['files'][-1]["path"] 46 | 47 | response = requests.post(url, headers=headers, json=data, stream=True) 48 | if response.status_code != 200: 49 | raise gr.Error(f"Error: Received status code {response.status_code}") 50 | 51 | full_content = "" 52 | for line in response.iter_lines(): 53 | if line: 54 | line = line.decode('utf-8') 55 | if line.startswith('data: '): 56 | event_data = line[6:] # Remove 'data: ' prefix 57 | if event_data == '[DONE]': 58 | break 59 | try: 60 | chunk_data = json.loads(event_data) 61 | content = chunk_data['choices'][0]['delta']['content'] 62 | full_content += content 63 | yield full_content 64 | except (json.JSONDecodeError, KeyError): 65 | continue 66 | 67 | demo = gr.ChatInterface( 68 | fn=chat, 69 | title="FastMLX Chat UI", 70 | additional_inputs_accordion=gr.Accordion( 71 | label="⚙️ Parameters", open=False, render=False 72 | ), 73 | additional_inputs=[ 74 | gr.Slider( 75 | minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature", render=False 76 | ), 77 | gr.Slider( 78 | minimum=128, 79 | maximum=4096, 80 | step=1, 81 | value=200, 82 | label="Max new tokens", 83 | render=False 84 | ), 85 | ], 86 | multimodal=True, 87 | ) 88 | 89 | demo.launch(inbrowser=True) 90 | ``` -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # FastMLX 2 | 3 | [![PyPI version](https://img.shields.io/pypi/v/fastmlx.svg)](https://pypi.python.org/pypi/fastmlx) 4 | [![Conda version](https://img.shields.io/conda/vn/conda-forge/fastmlx.svg)](https://anaconda.org/conda-forge/fastmlx) 5 | [![Updates](https://pyup.io/repos/github/Blaizzy/fastmlx/shield.svg)](https://pyup.io/repos/github/Blaizzy/fastmlx) 6 | 7 | **FastMLX** is a high-performance, production-ready API for hosting MLX models, including Vision Language Models (VLMs) and Language Models (LMs). 
It provides an easy-to-use interface for integrating powerful machine learning capabilities into your applications. 8 | 9 | 10 | ## Key Features 11 | 12 | - **OpenAI-compatible API**: Easily integrate with existing applications that use OpenAI's API. 13 | - **Dynamic Model Loading**: Load MLX models on-the-fly or use pre-loaded models for better performance. 14 | - **Support for Multiple Model Types**: Compatible with various MLX model architectures. 15 | - **Image Processing Capabilities**: Handle both text and image inputs for versatile model interactions. 16 | - **Efficient Resource Management**: Optimized for high-performance and scalability. 17 | - **Error Handling**: Robust error management for production environments. 18 | - **Customizable**: Easily extendable to accommodate specific use cases and model types. 19 | 20 | ## Quick Start 21 | 22 | [Get started with FastMLX](installation.md): Learn how to install and set up FastMLX in your environment. 23 | 24 | Explore Examples: Hands-on guides, such as: 25 | 26 | - [Chatbot application](examples/chatbot.md) 27 | - [Function calling](examples/function_calling.md) 28 | 29 | ### Installation 30 | 31 | Install **FastMLX** on your system by running the following command: 32 | 33 | ``` 34 | pip install -U fastmlx 35 | ``` 36 | 37 | ### Running the Server 38 | 39 | Start the **FastMLX** server using the following command: 40 | 41 | ```bash 42 | fastmlx 43 | ``` 44 | 45 | or with multiple workers for improved performance: 46 | 47 | ```bash 48 | fastmlx --workers 4 49 | ``` 50 | 51 | ### Making API Calls 52 | 53 | Once the server is running, you can interact with the API. Here's an example using a Vision Language Model: 54 | 55 | ```python 56 | import requests 57 | import json 58 | 59 | url = "http://localhost:8000/v1/chat/completions" 60 | headers = {"Content-Type": "application/json"} 61 | data = { 62 | "model": "mlx-community/nanoLLaVA-1.5-4bit", 63 | "image": "http://images.cocodataset.org/val2017/000000039769.jpg", 64 | "messages": [{"role": "user", "content": "What are these"}], 65 | "max_tokens": 100 66 | } 67 | 68 | response = requests.post(url, headers=headers, data=json.dumps(data)) 69 | print(response.json()) 70 | ``` 71 | 72 | ## What's Next? 73 | 74 | - Check out the [Installation](installation.md) guide for detailed setup instructions. 75 | - Learn more about the API usage in the [Usage](usage.md) section. 76 | - Explore advanced features and configurations in the [API Reference](endpoints.md). 77 | - If you're interested in contributing, see our [Contributing](contributing.md) guidelines. 78 | 79 | ## License 80 | 81 | FastMLX is free software, licensed under the Apache Software License 2.0. 82 | 83 | For more detailed information and advanced usage, please explore the rest of our documentation. If you encounter any issues or have questions, don't hesitate to [report an issue](https://github.com/Blaizzy/fastmlx/issues) on our GitHub repository. 84 | 85 | Happy coding with FastMLX! -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: FastMLX 2 | site_description: FastMLX is a high performance production ready API to host MLX models. 
3 | site_author: Blaizzy 4 | repo_name: Blaizzy/fastmlx 5 | site_url: https://Blaizzy.github.io/fastmlx 6 | repo_url: https://github.com/Blaizzy/fastmlx 7 | 8 | copyright: "Copyright © 2024 - 2024 Prince Canuma" 9 | 10 | theme: 11 | palette: 12 | - scheme: default 13 | primary: black 14 | toggle: 15 | icon: material/toggle-switch-off-outline 16 | name: Switch to dark mode 17 | - scheme: slate 18 | primary: black 19 | accent: indigo 20 | toggle: 21 | icon: material/toggle-switch 22 | name: Switch to light mode 23 | name: material 24 | icon: 25 | repo: fontawesome/brands/github 26 | # logo: assets/logo.png 27 | # favicon: assets/favicon.png 28 | features: 29 | - navigation.instant 30 | - navigation.tracking 31 | - navigation.top 32 | - navigation.footer # Adds a "Next" and "Previous" navigation to the footer 33 | - search.highlight 34 | - search.share 35 | - content.code.copy # Adds a copy button to code blocks 36 | 37 | custom_dir: "docs/overrides" 38 | font: 39 | text: Google Sans 40 | code: Regular 41 | 42 | plugins: 43 | - search 44 | - mkdocstrings 45 | # - pdf-export 46 | - mkdocs-jupyter: 47 | include_source: True 48 | ignore_h1_titles: True 49 | execute: True 50 | allow_errors: false 51 | ignore: ["conf.py"] 52 | execute_ignore: ["*ignore.ipynb"] 53 | 54 | markdown_extensions: 55 | - admonition 56 | - abbr 57 | - attr_list 58 | - def_list 59 | - footnotes 60 | - meta 61 | - md_in_html 62 | - pymdownx.superfences 63 | - pymdownx.highlight: 64 | linenums: true 65 | - toc: 66 | permalink: true 67 | 68 | extra: 69 | social: 70 | - icon: fontawesome/brands/github 71 | link: https://github.com/Blaizzy 72 | - icon: fontawesome/brands/twitter 73 | link: https://twitter.com/Prince_Canuma 74 | version: 75 | provider: mike # Uncomment if you decide to use mike for versioning 76 | consent: 77 | title: Cookie consent 78 | description: >- 79 | We use cookies to recognize your repeated visits and preferences, as well 80 | as to measure the effectiveness of our documentation and whether users 81 | find what they're searching for. With your consent, you're helping us to 82 | make our documentation better. 83 | 84 | extra_css: 85 | - stylesheets/extra.css 86 | 87 | # extra: 88 | # analytics: 89 | # provider: google 90 | # property: UA-XXXXXXXXX-X 91 | 92 | nav: 93 | - Home: index.md 94 | - Installation: installation.md 95 | - Usage: usage.md 96 | - CLI Reference: cli_reference.md 97 | - Examples: 98 | - Multi-Modal Chatbot: examples/chatbot.md 99 | - Function Calling: examples/function_calling.md 100 | 101 | - Managing Models: models.md 102 | - API Reference: endpoints.md 103 | - Contributing: contributing.md 104 | - Community Projects: community_projects.md 105 | - Report Issues: https://github.com/Blaizzy/fastmlx/issues 106 | - Changelog: changelog.md 107 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Join us in making a difference! 2 | 3 | Your contributions are always welcome and we would love to see how you can make our project even better. Your input is invaluable to us, and we ensure that all contributors receive recognition for their efforts. 4 | 5 | ## Ways to contribute 6 | 7 | Here’s how you can get involved: 8 | 9 | ### Report Bugs 10 | 11 | Report bugs at . 12 | 13 | If you are reporting a bug, please include: 14 | 15 | - Your operating system name and version. 16 | - Any details about your local setup that might be helpful in troubleshooting. 
17 | - Detailed steps to reproduce the bug. 18 | 19 | ### Fix Bugs 20 | 21 | Look through the GitHub issues for bugs. Anything tagged with `bug` and `help wanted` is open to whoever wants to implement it. 22 | 23 | ### Implement Features 24 | 25 | Look through the GitHub issues for features. If anything tagged `enhancement` and `help wanted` catches your eye, dive in and start coding. Your ideas can become a reality in FastMLX! 26 | 27 | ### Write Documentation 28 | 29 | We’re always in need of more documentation, whether it’s for our official docs, adding helpful comments in the code, or writing blog posts and articles. Clear and comprehensive documentation empowers the community, and your contributions are crucial! 30 | 31 | ### Submit Feedback 32 | 33 | The best way to share your thoughts is by filing an issue on our GitHub page: . Whether you’re suggesting a new feature or sharing your experience, we want to hear from you! 34 | 35 | Proposing a feature? 36 | 37 | - Describe in detail how it should work. 38 | - Keep it focused and manageable to make implementation smoother. 39 | - Remember, this is a volunteer-driven project, and your contributions are always appreciated! 40 | 41 | ## How to get Started! 42 | 43 | Ready to contribute? Follow these simple steps to set up FastMLX for local development and start making a difference. 44 | 45 | 1. Fork the repository. 46 | - Head over to the [fastmlx GitHub repo]() and click the Fork button to create your copy of the repository. 47 | 48 | 2. Clone your fork locally 49 | - Open your terminal and run the following command to clone your forked repository: 50 | 51 | ```shell 52 | $ git clone git@github.com:your_name_here/fastmlx.git 53 | ``` 54 | 55 | 3. Set Up Your Development Environment 56 | - Install your local copy of FastMLX into a virtual environment. If you’re using `virtualenvwrapper`, follow these steps: 57 | 58 | ```shell 59 | $ mkvirtualenv fastmlx 60 | $ cd fastmlx/ 61 | $ python setup.py develop 62 | ``` 63 | 64 | Tip: If you don’t have `virtualenvwrapper` installed, you can install it with `pip install virtualenvwrapper`. 65 | 66 | 4. Create a Development Branch 67 | - Create a new branch to work on your bugfix or feature: 68 | 69 | ```shell 70 | $ git checkout -b name-of-your-bugfix-or-feature 71 | ``` 72 | 73 | Now you’re ready to make changes! 74 | 75 | 5. Run Tests and Code Checks 76 | 77 | - When you're done making changes, check that your changes pass flake8 78 | and the tests, including testing other Python versions with tox: 79 | 80 | ```shell 81 | $ flake8 fastmlx tests 82 | $ pytest . 83 | ``` 84 | 85 | - To install flake8 and tox, simply run: 86 | ``` 87 | pip install flake8 tox 88 | ``` 89 | 90 | 6. Commit and Push Your Changes 91 | - Once everything looks good, commit your changes with a descriptive message: 92 | 93 | ```shell 94 | $ git add . 95 | $ git commit -m "Your detailed description of your changes." 96 | $ git push origin name-of-your-bugfix-or-feature 97 | ``` 98 | 99 | 7. Submit a Pull Request 100 | - Head back to the FastMLX GitHub repo and open a pull request. We’ll review your changes, provide feedback, and merge them once everything is ready. 101 | 102 | ## Pull Request Guidelines 103 | 104 | Before you submit a pull request, check that it meets these guidelines: 105 | 106 | 1. The pull request should include tests. 107 | 2. If the pull request adds functionality, the docs should be updated. 
108 | Put your new functionality into a function with a docstring, and add 109 | the feature to the list in README.rst. 110 | 3. The pull request should work for Python 3.8 and later, and 111 | for PyPy. Check and make sure that the tests pass for all 112 | supported Python versions. 113 | -------------------------------------------------------------------------------- /update_changelog.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to automatically update the CHANGELOG.md file based on GitHub releases. 3 | 4 | This script fetches release information from the GitHub API and updates 5 | the CHANGELOG.md file with the latest release notes. 6 | 7 | Usage: 8 | python update_changelog.py 9 | 10 | Requirements: 11 | - requests 12 | - python-dotenv 13 | 14 | Make sure to set up a .env file with your GitHub token: 15 | GITHUB_TOKEN=your_token_here 16 | """ 17 | 18 | import os 19 | import re 20 | from datetime import datetime 21 | 22 | import requests 23 | from dotenv import load_dotenv 24 | 25 | # Load environment variables 26 | load_dotenv() 27 | 28 | # GitHub repository details 29 | REPO_OWNER = "Blaizzy" 30 | REPO_NAME = "fastmlx" 31 | GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") 32 | 33 | # File paths 34 | CHANGELOG_PATH = "docs/changelog.md" 35 | 36 | 37 | def get_releases(): 38 | """Fetch all releases information from GitHub API.""" 39 | url = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/releases" 40 | headers = { 41 | "Authorization": f"token {GITHUB_TOKEN}", 42 | "Accept": "application/vnd.github.v3+json", 43 | } 44 | response = requests.get(url, headers=headers) 45 | response.raise_for_status() 46 | return response.json() 47 | 48 | 49 | def parse_version(version_string): 50 | """Parse version string to tuple, handling 'v' prefix.""" 51 | return tuple(map(int, version_string.lstrip("v").split("."))) 52 | 53 | 54 | def compare_versions(v1, v2): 55 | """Compare two version tuples.""" 56 | return (v1 > v2) - (v1 < v2) 57 | 58 | 59 | def create_issue_link(issue_number): 60 | """Create a clickable link for an issue.""" 61 | return f"[#{issue_number}](https://github.com/{REPO_OWNER}/{REPO_NAME}/issues/{issue_number})" 62 | 63 | 64 | def create_contributor_link(username): 65 | """Create a clickable link for a contributor.""" 66 | return f"[@{username}](https://github.com/{username})" 67 | 68 | 69 | def format_release_notes(body): 70 | """Format the release notes in a cleaner structure.""" 71 | formatted_notes = [] 72 | current_section = None 73 | 74 | for line in body.split("\n"): 75 | line = line.strip() 76 | 77 | if not line: 78 | continue 79 | 80 | if line.startswith("##"): 81 | current_section = line.lstrip("#").strip() 82 | formatted_notes.append(f"\n**{current_section}**\n") 83 | elif line.startswith("**Full Changelog**"): 84 | # Skip the full changelog link 85 | continue 86 | elif line.startswith("*"): 87 | if line.startswith("* @"): 88 | # Handle new contributors 89 | match = re.match( 90 | r"\* @(\w+) made their first contribution in (https://.*)", line 91 | ) 92 | if match: 93 | username, url = match.groups() 94 | formatted_notes.append( 95 | f"- {create_contributor_link(username)} made their first contribution in [{url}]({url})" 96 | ) 97 | else: 98 | # Clean up bullet points 99 | cleaned_line = re.sub( 100 | r"by @(\w+) in (https://.*)", 101 | lambda m: f"by {create_contributor_link(m.group(1))}", 102 | line, 103 | ) 104 | cleaned_line = cleaned_line.replace("* ", "- ") 105 | formatted_notes.append(cleaned_line) 106 | 
else: 107 | formatted_notes.append(line) 108 | 109 | # Replace issue numbers with clickable links 110 | formatted_notes = [ 111 | re.sub(r"#(\d+)", lambda m: create_issue_link(m.group(1)), item) 112 | for item in formatted_notes 113 | ] 114 | 115 | return "\n".join(formatted_notes) 116 | 117 | 118 | def update_changelog(releases): 119 | """Update the CHANGELOG.md file with the release information.""" 120 | with open(CHANGELOG_PATH, "r") as file: 121 | content = file.read() 122 | 123 | # Extract existing versions from changelog 124 | existing_versions = re.findall(r"##\s*\[?v?([\d.]+)", content) 125 | existing_versions = [parse_version(v) for v in existing_versions] 126 | 127 | # Sort releases by version (newest first) 128 | releases.sort(key=lambda r: parse_version(r["tag_name"]), reverse=True) 129 | 130 | new_content = "" 131 | added_versions = [] 132 | 133 | for release in releases: 134 | version = parse_version(release["tag_name"]) 135 | 136 | # Skip if this version is already in the changelog 137 | if version in existing_versions: 138 | continue 139 | 140 | print(f"Adding new version: {'.'.join(map(str, version))}") # Debug print 141 | 142 | release_date = datetime.strptime( 143 | release["published_at"], "%Y-%m-%dT%H:%M:%SZ" 144 | ).strftime("%d %B %Y") 145 | new_content += f"## [{release['tag_name']}] - {release_date}\n\n" 146 | new_content += format_release_notes(release["body"]) 147 | new_content += "\n\n" 148 | added_versions.append(version) 149 | 150 | if new_content: 151 | # Find the position to insert new content (after the header and introduction) 152 | header_end = content.find("\n\n", content.find("# Changelog")) 153 | if header_end == -1: 154 | header_end = content.find("\n", content.find("# Changelog")) 155 | 156 | if header_end != -1: 157 | updated_content = ( 158 | content[: header_end + 2] + new_content + content[header_end + 2 :] 159 | ) 160 | else: 161 | # If we can't find the proper position, just prepend the new content 162 | updated_content = "# Changelog\n\n" + new_content + content 163 | 164 | # Write updated changelog 165 | with open(CHANGELOG_PATH, "w") as file: 166 | file.write(updated_content) 167 | else: 168 | print("No new versions to add") 169 | 170 | 171 | def main(): 172 | try: 173 | releases = get_releases() 174 | update_changelog(releases) 175 | print("Changelog updated and formatted successfully") 176 | except requests.RequestException as e: 177 | print(f"Error fetching release information: {e}") 178 | except IOError as e: 179 | print(f"Error updating changelog file: {e}") 180 | except Exception as e: 181 | print(f"An unexpected error occurred: {e}") 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | This guide covers the server setup, and usage of FastMLX, including making API calls and managing models. 4 | 5 | ## 1. Installation 6 | Follow the [installation guide](installation.md) to install FastMLX. 7 | 8 | ## 2. Running the server 9 | Start the FastMLX server with the following command: 10 | 11 | ```bash 12 | fastmlx 13 | ``` 14 | 15 | or 16 | 17 | Using `uvicorn` directly: 18 | 19 | ```bash 20 | uvicorn fastmlx:app --reload --workers 0 21 | ``` 22 | 23 | > [!WARNING] 24 | > The `--reload` flag should not be used in production. It is only intended for development purposes. 
25 | 26 | ### Running with Multiple Workers (Parallel Processing) 27 | 28 | For improved performance and parallel processing capabilities, you can specify either the absolute number of worker processes or the fraction of CPU cores to use. 29 | 30 | You can set the number of workers in three ways (listed in order of precedence): 31 | 32 | 1. Command-line argument: 33 | ```bash 34 | fastmlx --workers 4 35 | ``` 36 | or 37 | ```bash 38 | uvicorn fastmlx:app --workers 4 39 | ``` 40 | 41 | 2. Environment variable: 42 | ```bash 43 | export FASTMLX_NUM_WORKERS=4 44 | fastmlx 45 | ``` 46 | 47 | 3. Default value (2 workers) 48 | 49 | To use all available CPU cores, set the value to 1.0: 50 | 51 | ```bash 52 | fastmlx --workers 1.0 53 | ``` 54 | 55 | > [!NOTE] 56 | > - The `--reload` flag is not compatible with multiple workers. 57 | > - The number of workers should typically not exceed the number of CPU cores available on your machine for optimal performance. 58 | 59 | ### Considerations for Multi-Worker Setup 60 | 61 | 1. **Stateless Application**: Ensure your FastMLX application is stateless, as each worker process operates independently. 62 | 2. **Database Connections**: If your app uses a database, make sure your connection pooling is configured to handle multiple workers. 63 | 3. **Resource Usage**: Monitor your system's resource usage to find the optimal number of workers for your specific hardware and application needs. 64 | 4. **Load Balancing**: When running with multiple workers, incoming requests are automatically load-balanced across the worker processes. 65 | 66 | ## 3. Making API Calls 67 | 68 | Use the API similar to OpenAI's chat completions: 69 | 70 | ### Vision Language Model 71 | 72 | #### Without Streaming 73 | Here's an example of how to use a Vision Language Model: 74 | 75 | ```python 76 | import requests 77 | import json 78 | 79 | url = "http://localhost:8000/v1/chat/completions" 80 | headers = {"Content-Type": "application/json"} 81 | data = { 82 | "model": "mlx-community/nanoLLaVA-1.5-4bit", 83 | "image": "http://images.cocodataset.org/val2017/000000039769.jpg", 84 | "messages": [{"role": "user", "content": "What are these"}], 85 | "max_tokens": 100 86 | } 87 | 88 | response = requests.post(url, headers=headers, data=json.dumps(data)) 89 | print(response.json()) 90 | ``` 91 | 92 | #### With Streaming 93 | ```python 94 | import requests 95 | import json 96 | 97 | def process_sse_stream(url, headers, data): 98 | response = requests.post(url, headers=headers, json=data, stream=True) 99 | 100 | if response.status_code != 200: 101 | print(f"Error: Received status code {response.status_code}") 102 | print(response.text) 103 | return 104 | 105 | full_content = "" 106 | 107 | try: 108 | for line in response.iter_lines(): 109 | if line: 110 | line = line.decode('utf-8') 111 | if line.startswith('data: '): 112 | event_data = line[6:] # Remove 'data: ' prefix 113 | if event_data == '[DONE]': 114 | print("\nStream finished. 
✅") 115 | break 116 | try: 117 | chunk_data = json.loads(event_data) 118 | content = chunk_data['choices'][0]['delta']['content'] 119 | full_content += content 120 | print(content, end='', flush=True) 121 | except json.JSONDecodeError: 122 | print(f"\nFailed to decode JSON: {event_data}") 123 | except KeyError: 124 | print(f"\nUnexpected data structure: {chunk_data}") 125 | 126 | except KeyboardInterrupt: 127 | print("\nStream interrupted by user.") 128 | except requests.exceptions.RequestException as e: 129 | print(f"\nAn error occurred: {e}") 130 | 131 | if __name__ == "__main__": 132 | url = "http://localhost:8000/v1/chat/completions" 133 | headers = {"Content-Type": "application/json"} 134 | data = { 135 | "model": "mlx-community/nanoLLaVA-1.5-4bit", 136 | "image": "http://images.cocodataset.org/val2017/000000039769.jpg", 137 | "messages": [{"role": "user", "content": "What are these?"}], 138 | "max_tokens": 500, 139 | "stream": True 140 | } 141 | process_sse_stream(url, headers, data) 142 | ``` 143 | ### Language Model 144 | 145 | #### Without Streaming 146 | 147 | Here's an example of how to use a Language Model: 148 | 149 | ```python 150 | import requests 151 | import json 152 | 153 | url = "http://localhost:8000/v1/chat/completions" 154 | headers = {"Content-Type": "application/json"} 155 | data = { 156 | "model": "mlx-community/gemma-2-9b-it-4bit", 157 | "messages": [{"role": "user", "content": "What is the capital of France?"}], 158 | "max_tokens": 100 159 | } 160 | 161 | response = requests.post(url, headers=headers, data=json.dumps(data)) 162 | print(response.json()) 163 | ``` 164 | 165 | #### With Streaming 166 | 167 | ```python 168 | import requests 169 | import json 170 | 171 | def process_sse_stream(url, headers, data): 172 | response = requests.post(url, headers=headers, json=data, stream=True) 173 | 174 | if response.status_code != 200: 175 | print(f"Error: Received status code {response.status_code}") 176 | print(response.text) 177 | return 178 | 179 | full_content = "" 180 | 181 | try: 182 | for line in response.iter_lines(): 183 | if line: 184 | line = line.decode('utf-8') 185 | if line.startswith('data: '): 186 | event_data = line[6:] # Remove 'data: ' prefix 187 | if event_data == '[DONE]': 188 | print("\nStream finished. ✅") 189 | break 190 | try: 191 | chunk_data = json.loads(event_data) 192 | content = chunk_data['choices'][0]['delta']['content'] 193 | full_content += content 194 | print(content, end='', flush=True) 195 | except json.JSONDecodeError: 196 | print(f"\nFailed to decode JSON: {event_data}") 197 | except KeyError: 198 | print(f"\nUnexpected data structure: {chunk_data}") 199 | 200 | except KeyboardInterrupt: 201 | print("\nStream interrupted by user.") 202 | except requests.exceptions.RequestException as e: 203 | print(f"\nAn error occurred: {e}") 204 | 205 | if __name__ == "__main__": 206 | url = "http://localhost:8000/v1/chat/completions" 207 | headers = {"Content-Type": "application/json"} 208 | data = { 209 | "model": "mlx-community/gemma-2-9b-it-4bit", 210 | "messages": [{"role": "user", "content": "Hi, how are you?"}], 211 | "max_tokens": 500, 212 | "stream": True 213 | } 214 | process_sse_stream(url, headers, data) 215 | ``` 216 | 217 | For more detailed API documentation, please refer to the [API Reference](endpoints.md) section. 
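### Using an OpenAI Client Library

Because the API mirrors OpenAI's chat completions endpoint, you can also call it with the official `openai` Python client instead of raw `requests`. This is a hedged sketch: it assumes the `openai` package (v1 or later) is installed and that the server's responses follow the OpenAI schema as advertised; the API key is a placeholder, since the examples above send no authentication.

```python
from openai import OpenAI

# Point the client at the local FastMLX server instead of api.openai.com.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="mlx-community/gemma-2-9b-it-4bit",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=100,
)
print(response.choices[0].message.content)
```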
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastMLX 2 | 3 | [![image](https://img.shields.io/pypi/v/fastmlx.svg)](https://pypi.python.org/pypi/fastmlx) 4 | [![image](https://img.shields.io/conda/vn/conda-forge/fastmlx.svg)](https://anaconda.org/conda-forge/fastmlx) 5 | [![image](https://pyup.io/repos/github/Blaizzy/fastmlx/shield.svg)](https://pyup.io/repos/github/Blaizzy/fastmlx) 6 | 7 | **FastMLX is a high performance production ready API to host MLX models, including Vision Language Models (VLMs) and Language Models (LMs).** 8 | 9 | - Free software: Apache Software License 2.0 10 | - Documentation: https://Blaizzy.github.io/fastmlx 11 | 12 | ## Features 13 | 14 | - **OpenAI-compatible API**: Easily integrate with existing applications that use OpenAI's API. 15 | - **Dynamic Model Loading**: Load MLX models on-the-fly or use pre-loaded models for better performance. 16 | - **Support for Multiple Model Types**: Compatible with various MLX model architectures. 17 | - **Image Processing Capabilities**: Handle both text and image inputs for versatile model interactions. 18 | - **Efficient Resource Management**: Optimized for high-performance and scalability. 19 | - **Error Handling**: Robust error management for production environments. 20 | - **Customizable**: Easily extendable to accommodate specific use cases and model types. 21 | 22 | ## Usage 23 | 24 | 1. **Installation** 25 | 26 | ```bash 27 | pip install fastmlx 28 | ``` 29 | 30 | 2. **Running the Server** 31 | 32 | Start the FastMLX server: 33 | ```bash 34 | fastmlx 35 | ``` 36 | or 37 | 38 | ```bash 39 | uvicorn fastmlx:app --reload --workers 0 40 | ``` 41 | 42 | > [!WARNING] 43 | > The `--reload` flag should not be used in production. It is only intended for development purposes. 44 | 45 | ### Running with Multiple Workers (Parallel Processing) 46 | 47 | For improved performance and parallel processing capabilities, you can specify either the absolute number of worker processes or the fraction of CPU cores to use. This is particularly useful for handling multiple requests simultaneously. 48 | 49 | You can also set the `FASTMLX_NUM_WORKERS` environment variable to specify the number of workers or the fraction of CPU cores to use. `workers` defaults to 2 if not passed explicitly or set via the environment variable. 50 | 51 | In order of precedence (highest to lowest), the number of workers is determined by the following: 52 | - Explicitly passed as a command-line argument 53 | - `--workers 4` will set the number of workers to 4 54 | - `--workers 0.5` will set the number of workers to half the number of CPU cores available (minimum of 1) 55 | - Set via the `FASTMLX_NUM_WORKERS` environment variable 56 | - Default value of 2 57 | 58 | To use all available CPU cores, set the value to **1.0**. 59 | 60 | Example: 61 | ```bash 62 | fastmlx --workers 4 63 | ``` 64 | or 65 | 66 | ```bash 67 | uvicorn fastmlx:app --workers 4 68 | ``` 69 | 70 | > [!NOTE] 71 | > - `--reload` flag is not compatible with multiple workers 72 | > - The number of workers should typically not exceed the number of CPU cores available on your machine for optimal performance. 73 | 74 | ### Considerations for Multi-Worker Setup 75 | 76 | 1. **Stateless Application**: Ensure your FastMLX application is stateless, as each worker process operates independently. 77 | 2. 
**Database Connections**: If your app uses a database, make sure your connection pooling is configured to handle multiple workers. 78 | 3. **Resource Usage**: Monitor your system's resource usage to find the optimal number of workers for your specific hardware and application needs. Additionally, you can remove any unused models using the delete model endpoint. 79 | 4. **Load Balancing**: When running with multiple workers, incoming requests are automatically load-balanced across the worker processes. 80 | 81 | By leveraging multiple workers, you can significantly improve the throughput and responsiveness of your FastMLX application, especially under high load conditions. 82 | 83 | 3. **Making API Calls** 84 | 85 | Use the API similar to OpenAI's chat completions: 86 | 87 | **Vision Language Model** 88 | ```python 89 | import requests 90 | import json 91 | 92 | url = "http://localhost:8000/v1/chat/completions" 93 | headers = {"Content-Type": "application/json"} 94 | data = { 95 | "model": "mlx-community/nanoLLaVA-1.5-4bit", 96 | "image": "http://images.cocodataset.org/val2017/000000039769.jpg", 97 | "messages": [{"role": "user", "content": "What are these"}], 98 | "max_tokens": 100 99 | } 100 | 101 | response = requests.post(url, headers=headers, data=json.dumps(data)) 102 | print(response.json()) 103 | ``` 104 | 105 | With streaming: 106 | ```python 107 | import requests 108 | import json 109 | 110 | def process_sse_stream(url, headers, data): 111 | response = requests.post(url, headers=headers, json=data, stream=True) 112 | 113 | if response.status_code != 200: 114 | print(f"Error: Received status code {response.status_code}") 115 | print(response.text) 116 | return 117 | 118 | full_content = "" 119 | 120 | try: 121 | for line in response.iter_lines(): 122 | if line: 123 | line = line.decode('utf-8') 124 | if line.startswith('data: '): 125 | event_data = line[6:] # Remove 'data: ' prefix 126 | if event_data == '[DONE]': 127 | print("\nStream finished. 
✅") 128 | break 129 | try: 130 | chunk_data = json.loads(event_data) 131 | content = chunk_data['choices'][0]['delta']['content'] 132 | full_content += content 133 | print(content, end='', flush=True) 134 | except json.JSONDecodeError: 135 | print(f"\nFailed to decode JSON: {event_data}") 136 | except KeyError: 137 | print(f"\nUnexpected data structure: {chunk_data}") 138 | 139 | except KeyboardInterrupt: 140 | print("\nStream interrupted by user.") 141 | except requests.exceptions.RequestException as e: 142 | print(f"\nAn error occurred: {e}") 143 | 144 | if __name__ == "__main__": 145 | url = "http://localhost:8000/v1/chat/completions" 146 | headers = {"Content-Type": "application/json"} 147 | data = { 148 | "model": "mlx-community/nanoLLaVA-1.5-4bit", 149 | "image": "http://images.cocodataset.org/val2017/000000039769.jpg", 150 | "messages": [{"role": "user", "content": "What are these?"}], 151 | "max_tokens": 500, 152 | "stream": True 153 | } 154 | process_sse_stream(url, headers, data) 155 | ``` 156 | 157 | **Language Model** 158 | ```python 159 | import requests 160 | import json 161 | 162 | url = "http://localhost:8000/v1/chat/completions" 163 | headers = {"Content-Type": "application/json"} 164 | data = { 165 | "model": "mlx-community/gemma-2-9b-it-4bit", 166 | "messages": [{"role": "user", "content": "What is the capital of France?"}], 167 | "max_tokens": 100 168 | } 169 | 170 | response = requests.post(url, headers=headers, data=json.dumps(data)) 171 | print(response.json()) 172 | ``` 173 | 174 | With streaming: 175 | ```python 176 | import requests 177 | import json 178 | 179 | def process_sse_stream(url, headers, data): 180 | response = requests.post(url, headers=headers, json=data, stream=True) 181 | 182 | if response.status_code != 200: 183 | print(f"Error: Received status code {response.status_code}") 184 | print(response.text) 185 | return 186 | 187 | full_content = "" 188 | 189 | try: 190 | for line in response.iter_lines(): 191 | if line: 192 | line = line.decode('utf-8') 193 | if line.startswith('data: '): 194 | event_data = line[6:] # Remove 'data: ' prefix 195 | if event_data == '[DONE]': 196 | print("\nStream finished. ✅") 197 | break 198 | try: 199 | chunk_data = json.loads(event_data) 200 | content = chunk_data['choices'][0]['delta']['content'] 201 | full_content += content 202 | print(content, end='', flush=True) 203 | except json.JSONDecodeError: 204 | print(f"\nFailed to decode JSON: {event_data}") 205 | except KeyError: 206 | print(f"\nUnexpected data structure: {chunk_data}") 207 | 208 | except KeyboardInterrupt: 209 | print("\nStream interrupted by user.") 210 | except requests.exceptions.RequestException as e: 211 | print(f"\nAn error occurred: {e}") 212 | 213 | if __name__ == "__main__": 214 | url = "http://localhost:8000/v1/chat/completions" 215 | headers = {"Content-Type": "application/json"} 216 | data = { 217 | "model": "mlx-community/gemma-2-9b-it-4bit", 218 | "messages": [{"role": "user", "content": "Hi, how are you?"}], 219 | "max_tokens": 500, 220 | "stream": True 221 | } 222 | process_sse_stream(url, headers, data) 223 | ``` 224 | 225 | 4. **Function Calling** 226 | 227 | FastMLX now supports tool calling in accordance with the OpenAI API specification. 
This feature is available for the following models: 228 | 229 | - Llama 3.1 230 | - Arcee Agent 231 | - C4ai-Command-R-Plus 232 | - Firefunction 233 | - xLAM 234 | 235 | Supported modes: 236 | - Without Streaming 237 | - Parallel Tool Calling 238 | 239 | > Note: Tool choice and OpenAI-compliant streaming for function calling are currently under development. 240 | 241 | Here's an example of how to use function calling with FastMLX: 242 | 243 | ```python 244 | import requests 245 | import json 246 | 247 | url = "http://localhost:8000/v1/chat/completions" 248 | headers = {"Content-Type": "application/json"} 249 | data = { 250 | "model": "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit", 251 | "messages": [ 252 | { 253 | "role": "user", 254 | "content": "What's the weather like in San Francisco and Washington?" 255 | } 256 | ], 257 | "tools": [ 258 | { 259 | "name": "get_current_weather", 260 | "description": "Get the current weather", 261 | "parameters": { 262 | "type": "object", 263 | "properties": { 264 | "location": { 265 | "type": "string", 266 | "description": "The city and state, e.g. San Francisco, CA" 267 | }, 268 | "format": { 269 | "type": "string", 270 | "enum": ["celsius", "fahrenheit"], 271 | "description": "The temperature unit to use. Infer this from the user's location." 272 | } 273 | }, 274 | "required": ["location", "format"] 275 | } 276 | } 277 | ], 278 | "max_tokens": 150, 279 | "temperature": 0.7, 280 | "stream": False, 281 | } 282 | 283 | response = requests.post(url, headers=headers, data=json.dumps(data)) 284 | print(response.json()) 285 | ``` 286 | 287 | This example demonstrates how to use the `get_current_weather` tool with the Llama 3.1 model. The API processes the user's question and returns a tool call that your application can execute to fetch the required information (see the sketch under "Handling Tool Call Responses" below). 288 | 289 | Please note that while streaming is available for regular text generation, the streaming implementation for function calling is still in development and does not yet fully comply with the OpenAI specification. 290 | 291 | 5. **List Supported Models** 292 | 293 | To see all vision and language models supported by MLX: 294 | 295 | ```python 296 | import requests 297 | 298 | url = "http://localhost:8000/v1/supported_models" 299 | response = requests.get(url) 300 | print(response.json()) 301 | ``` 302 | 303 | 6. **Add Available Model** 304 | 305 | You can add new models to the API: 306 | 307 | ```python 308 | import requests 309 | 310 | url = "http://localhost:8000/v1/models" 311 | params = { 312 | "model_name": "hf-repo-or-path", 313 | } 314 | 315 | response = requests.post(url, params=params) 316 | print(response.json()) 317 | ``` 318 | 319 | 7. **List Available Models** 320 | 321 | Returns the list of models that have been added, in an OpenAI-compliant format: 322 | 323 | ```python 324 | import requests 325 | 326 | url = "http://localhost:8000/v1/models" 327 | response = requests.get(url) 328 | print(response.json()) 329 | ``` 330 | 331 | 8. **Delete Models** 332 | 333 | To remove any model loaded in memory: 334 | 335 | ```python 336 | import requests 337 | 338 | url = "http://localhost:8000/v1/models" 339 | params = { 340 | "model_name": "hf-repo-or-path", 341 | } 342 | response = requests.delete(url, params=params) 343 | print(response) 344 | ``` 345 | 346 | For more detailed usage instructions and API documentation, please refer to the [full documentation](https://Blaizzy.github.io/fastmlx). 
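### Handling Tool Call Responses

FastMLX returns the requested tool call; it does not execute your function for you. The sketch below shows one way a client might dispatch the `tool_calls` field of the function calling response from step 4. The `get_current_weather` implementation is a hypothetical stub, and the field layout follows the server's response types (a top-level `tool_calls` list whose `arguments` are a JSON-encoded string).

```python
import json


def get_current_weather(location, format):
    # Hypothetical stub; a real client would call an actual weather service.
    return f"The weather in {location} is 21 degrees ({format})."


def dispatch_tool_calls(response_json):
    """Run the tools requested in a chat completion response.

    `response_json` is the dict returned by `response.json()` in the
    function calling example from step 4.
    """
    results = []
    for call in response_json.get("tool_calls") or []:
        name = call["function"]["name"]
        arguments = json.loads(call["function"]["arguments"])
        if name == "get_current_weather":
            results.append(get_current_weather(**arguments))
    return results
```

After executing a tool, you would typically append its result to the conversation as a new message and send a follow-up request so the model can produce the final answer.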
347 | -------------------------------------------------------------------------------- /fastmlx/fastmlx.py: -------------------------------------------------------------------------------- 1 | """Main module for FastMLX API server. 2 | 3 | This module provides a FastAPI-based server for hosting MLX models, 4 | including Vision Language Models (VLMs) and Language Models (LMs). 5 | It offers an OpenAI-compatible API for chat completions and model management. 6 | """ 7 | 8 | import argparse 9 | import asyncio 10 | import os 11 | import time 12 | from typing import Any, Dict, List 13 | from urllib.parse import unquote 14 | 15 | from fastapi import FastAPI, HTTPException, Response 16 | from fastapi.middleware.cors import CORSMiddleware 17 | from fastapi.responses import JSONResponse, StreamingResponse 18 | 19 | from .types.chat.chat_completion import ( 20 | ChatCompletionRequest, 21 | ChatCompletionResponse, 22 | ChatMessage, 23 | Usage, 24 | ) 25 | from .types.model import SupportedModels 26 | 27 | try: 28 | from mlx_vlm import generate as vlm_generate 29 | from mlx_vlm.prompt_utils import apply_chat_template as apply_vlm_chat_template 30 | from mlx_vlm.utils import load_config 31 | 32 | from .utils import ( 33 | MODEL_REMAPPING, 34 | MODELS, 35 | apply_lm_chat_template, 36 | get_eom_token, 37 | get_tool_prompt, 38 | handle_function_calls, 39 | lm_generate, 40 | lm_stream_generator, 41 | load_lm_model, 42 | load_vlm_model, 43 | vlm_stream_generator, 44 | ) 45 | 46 | MLX_AVAILABLE = True 47 | except ImportError: 48 | print("Warning: mlx or mlx_lm not available. Some functionality will be limited.") 49 | MLX_AVAILABLE = False 50 | 51 | 52 | class ModelProvider: 53 | def __init__(self): 54 | self.models: Dict[str, Dict[str, Any]] = {} 55 | self.lock = asyncio.Lock() 56 | 57 | def load_model(self, model_name: str): 58 | if model_name not in self.models: 59 | config = load_config(model_name) 60 | model_type = MODEL_REMAPPING.get(config["model_type"], config["model_type"]) 61 | if model_type in MODELS["vlm"]: 62 | self.models[model_name] = load_vlm_model(model_name, config) 63 | else: 64 | self.models[model_name] = load_lm_model(model_name, config) 65 | 66 | return self.models[model_name] 67 | 68 | async def remove_model(self, model_name: str) -> bool: 69 | async with self.lock: 70 | if model_name in self.models: 71 | del self.models[model_name] 72 | return True 73 | return False 74 | 75 | async def get_available_models(self): 76 | async with self.lock: 77 | return list(self.models.keys()) 78 | 79 | 80 | app = FastAPI() 81 | 82 | 83 | def int_or_float(value): 84 | 85 | try: 86 | return int(value) 87 | except ValueError: 88 | try: 89 | return float(value) 90 | except ValueError: 91 | raise argparse.ArgumentTypeError(f"{value} is not an int or float") 92 | 93 | 94 | def calculate_default_workers(workers: int = 2) -> int: 95 | if num_workers_env := os.getenv("FASTMLX_NUM_WORKERS"): 96 | try: 97 | workers = int(num_workers_env) 98 | except ValueError: 99 | workers = max(1, int(os.cpu_count() * float(num_workers_env))) 100 | return workers 101 | 102 | 103 | # Add CORS middleware 104 | def setup_cors(app: FastAPI, allowed_origins: List[str]): 105 | app.add_middleware( 106 | CORSMiddleware, 107 | allow_origins=allowed_origins, 108 | allow_credentials=True, 109 | allow_methods=["*"], 110 | allow_headers=["*"], 111 | ) 112 | 113 | 114 | # Initialize the ModelProvider 115 | model_provider = ModelProvider() 116 | 117 | 118 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 119 | async 
def chat_completion(request: ChatCompletionRequest): 120 | """ 121 | Handle chat completion requests for both VLM and LM models. 122 | 123 | Args: 124 | request (ChatCompletionRequest): The chat completion request. 125 | 126 | Returns: 127 | ChatCompletionResponse (ChatCompletionResponse): The generated chat completion response. 128 | 129 | Raises: 130 | HTTPException (str): If MLX library is not available. 131 | """ 132 | if not MLX_AVAILABLE: 133 | raise HTTPException(status_code=500, detail="MLX library not available") 134 | 135 | stream = request.stream 136 | model_data = model_provider.load_model(request.model) 137 | model = model_data["model"] 138 | config = model_data["config"] 139 | model_type = MODEL_REMAPPING.get(config["model_type"], config["model_type"]) 140 | stop_words = get_eom_token(request.model) 141 | 142 | if model_type in MODELS["vlm"]: 143 | processor = model_data["processor"] 144 | image_processor = model_data["image_processor"] 145 | 146 | image_url = None 147 | chat_messages = [] 148 | 149 | for msg in request.messages: 150 | if isinstance(msg.content, str): 151 | chat_messages.append({"role": msg.role, "content": msg.content}) 152 | elif isinstance(msg.content, list): 153 | text_content = "" 154 | for content_part in msg.content: 155 | if content_part.type == "text": 156 | text_content += content_part.text + " " 157 | elif content_part.type == "image_url": 158 | image_url = content_part.image_url["url"] 159 | chat_messages.append( 160 | {"role": msg.role, "content": text_content.strip()} 161 | ) 162 | 163 | if not image_url and model_type in MODELS["vlm"]: 164 | raise HTTPException( 165 | status_code=400, detail="Image URL not provided for VLM model" 166 | ) 167 | 168 | prompt = "" 169 | if model.config.model_type != "paligemma": 170 | prompt = apply_vlm_chat_template(processor, config, chat_messages) 171 | else: 172 | prompt = chat_messages[-1]["content"] 173 | 174 | if stream: 175 | return StreamingResponse( 176 | vlm_stream_generator( 177 | model, 178 | request.model, 179 | processor, 180 | image_url, 181 | prompt, 182 | image_processor, 183 | request.max_tokens, 184 | request.temperature, 185 | stream_options=request.stream_options, 186 | ), 187 | media_type="text/event-stream", 188 | ) 189 | else: 190 | # Generate the response 191 | output = vlm_generate( 192 | model, 193 | processor, 194 | image_url, 195 | prompt, 196 | image_processor, 197 | max_tokens=request.max_tokens, 198 | temp=request.temperature, 199 | verbose=False, 200 | ) 201 | 202 | else: 203 | # Add function calling information to the prompt 204 | if request.tools and "firefunction-v2" not in request.model: 205 | # Handle system prompt 206 | if request.messages and request.messages[0].role == "system": 207 | pass 208 | else: 209 | # Generate system prompt based on model and tools 210 | prompt, user_role = get_tool_prompt( 211 | request.model, 212 | [tool.model_dump() for tool in request.tools], 213 | request.messages[-1].content, 214 | ) 215 | 216 | if user_role: 217 | request.messages[-1].content = prompt 218 | else: 219 | # Insert the system prompt at the beginning of the messages 220 | request.messages.insert( 221 | 0, ChatMessage(role="system", content=prompt) 222 | ) 223 | 224 | tokenizer = model_data["tokenizer"] 225 | 226 | chat_messages = [ 227 | {"role": msg.role, "content": msg.content} for msg in request.messages 228 | ] 229 | prompt = apply_lm_chat_template(tokenizer, chat_messages, request) 230 | 231 | if stream: 232 | return StreamingResponse( 233 | lm_stream_generator( 234 | model, 
235 | request.model, 236 | tokenizer, 237 | prompt, 238 | request.max_tokens, 239 | request.temperature, 240 | stop_words=stop_words, 241 | stream_options=request.stream_options, 242 | ), 243 | media_type="text/event-stream", 244 | ) 245 | else: 246 | output, token_length_info = lm_generate( 247 | model, 248 | tokenizer, 249 | prompt, 250 | request.max_tokens, 251 | temp=request.temperature, 252 | stop_words=stop_words, 253 | ) 254 | 255 | # Parse the output to check for function calls 256 | return handle_function_calls(output, request, token_length_info) 257 | 258 | 259 | @app.get("/v1/supported_models", response_model=SupportedModels) 260 | async def get_supported_models(): 261 | """ 262 | Get a list of supported model types for VLM and LM. 263 | 264 | Returns: 265 | JSONResponse (json): A JSON response containing the supported models. 266 | """ 267 | return JSONResponse(content=MODELS) 268 | 269 | 270 | @app.get("/v1/models") 271 | async def list_models(): 272 | """ 273 | Get list of models - provided in OpenAI API compliant format. 274 | """ 275 | models = await model_provider.get_available_models() 276 | models_data = [] 277 | for model in models: 278 | models_data.append( 279 | { 280 | "id": model, 281 | "object": "model", 282 | "created": int(time.time()), 283 | "owned_by": "system", 284 | } 285 | ) 286 | return {"object": "list", "data": models_data} 287 | 288 | 289 | @app.post("/v1/models") 290 | async def add_model(model_name: str): 291 | """ 292 | Add a new model to the API. 293 | 294 | Args: 295 | model_name (str): The name of the model to add. 296 | 297 | Returns: 298 | dict (dict): A dictionary containing the status of the operation. 299 | """ 300 | model_provider.load_model(model_name) 301 | return {"status": "success", "message": f"Model {model_name} added successfully"} 302 | 303 | 304 | @app.delete("/v1/models") 305 | async def remove_model(model_name: str): 306 | """ 307 | Remove a model from the API. 308 | 309 | Args: 310 | model_name (str): The name of the model to remove. 311 | 312 | Returns: 313 | Response (str): A 204 No Content response if successful. 314 | 315 | Raises: 316 | HTTPException (str): If the model is not found. 317 | """ 318 | model_name = unquote(model_name).strip('"') 319 | removed = await model_provider.remove_model(model_name) 320 | if removed: 321 | return Response(status_code=204) # 204 No Content - successful deletion 322 | else: 323 | raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") 324 | 325 | 326 | def run(): 327 | parser = argparse.ArgumentParser(description="FastMLX API server") 328 | parser.add_argument( 329 | "--allowed-origins", 330 | nargs="+", 331 | default=["*"], 332 | help="List of allowed origins for CORS", 333 | ) 334 | parser.add_argument( 335 | "--host", type=str, default="0.0.0.0", help="Host to run the server on" 336 | ) 337 | parser.add_argument( 338 | "--port", type=int, default=8000, help="Port to run the server on" 339 | ) 340 | parser.add_argument( 341 | "--reload", 342 | type=bool, 343 | default=False, 344 | help="Enable auto-reload of the server. Only works when 'workers' is set to None.", 345 | ) 346 | 347 | parser.add_argument( 348 | "--workers", 349 | type=int_or_float, 350 | default=calculate_default_workers(), 351 | help="""Number of workers. Overrides the `FASTMLX_NUM_WORKERS` env variable. 352 | Can be either an int or a float. 353 | If an int, it will be the number of workers to use. 
354 | If a float, number of workers will be this fraction of the number of CPU cores available, with a minimum of 1. 355 | Defaults to the `FASTMLX_NUM_WORKERS` env variable if set and to 2 if not. 356 | To use all available CPU cores, set it to 1.0. 357 | 358 | Examples: 359 | --workers 1 (will use 1 worker) 360 | --workers 1.0 (will use all available CPU cores) 361 | --workers 0.5 (will use half the number of CPU cores available) 362 | --workers 0.0 (will use 1 worker)""", 363 | ) 364 | 365 | args = parser.parse_args() 366 | if isinstance(args.workers, float): 367 | args.workers = max(1, int(os.cpu_count() * args.workers)) 368 | 369 | setup_cors(app, args.allowed_origins) 370 | 371 | import uvicorn 372 | 373 | uvicorn.run( 374 | "fastmlx:app", 375 | host=args.host, 376 | port=args.port, 377 | reload=args.reload, 378 | workers=args.workers, 379 | loop="asyncio", 380 | ) 381 | 382 | 383 | if __name__ == "__main__": 384 | run() 385 | -------------------------------------------------------------------------------- /tests/test_fastmlx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for `fastmlx` package.""" 4 | 5 | import json 6 | from unittest.mock import MagicMock, patch 7 | 8 | import pytest 9 | from fastapi.testclient import TestClient 10 | 11 | # Import the actual classes and functions 12 | from fastmlx import ( 13 | ChatCompletionRequest, 14 | ChatCompletionResponse, 15 | ChatMessage, 16 | ModelProvider, 17 | Usage, 18 | app, 19 | handle_function_calls, 20 | ) 21 | 22 | 23 | # Create mock classes that inherit from the original classes 24 | class MockModelProvider(ModelProvider): 25 | def __init__(self): 26 | super().__init__() 27 | self.models = {} 28 | 29 | def load_model(self, model_name: str): 30 | if model_name not in self.models: 31 | model_type = "vlm" if "llava" in model_name.lower() else "lm" 32 | self.models[model_name] = { 33 | "model": MagicMock(), 34 | "processor": MagicMock(), 35 | "tokenizer": MagicMock(), 36 | "image_processor": MagicMock() if model_type == "vlm" else None, 37 | "config": {"model_type": model_type}, 38 | } 39 | return self.models[model_name] 40 | 41 | async def remove_model(self, model_name: str) -> bool: 42 | if model_name in self.models: 43 | del self.models[model_name] 44 | return True 45 | return False 46 | 47 | async def get_available_models(self): 48 | return list(self.models.keys()) 49 | 50 | 51 | # Mock MODELS dictionary 52 | MODELS = {"vlm": ["llava"], "lm": ["phi"]} 53 | 54 | 55 | # Mock functions 56 | def mock_generate(*args, **kwargs): 57 | return "generated response", { 58 | "prompt_tokens": 10, 59 | "completion_tokens": 20, 60 | "total_tokens": 30, 61 | } 62 | 63 | 64 | def mock_vlm_stream_generate(*args, **kwargs): 65 | yield "Hello" 66 | yield " world" 67 | yield "!" 
68 | 69 | 70 | def mock_lm_stream_generate(*args, **kwargs): 71 | yield "Testing" 72 | yield " stream" 73 | yield " generation" 74 | 75 | 76 | @pytest.fixture(scope="module") 77 | def client(): 78 | # Apply patches 79 | with patch("fastmlx.fastmlx.model_provider", MockModelProvider()), patch( 80 | "fastmlx.fastmlx.vlm_generate", mock_generate 81 | ), patch("fastmlx.fastmlx.lm_generate", mock_generate), patch( 82 | "fastmlx.fastmlx.MODELS", MODELS 83 | ), patch( 84 | "fastmlx.utils.vlm_stream_generate", mock_vlm_stream_generate 85 | ), patch( 86 | "fastmlx.utils.lm_stream_generate", mock_lm_stream_generate 87 | ): 88 | yield TestClient(app) 89 | 90 | 91 | def test_chat_completion_vlm(client): 92 | request = ChatCompletionRequest( 93 | model="test_llava_model", 94 | messages=[ChatMessage(role="user", content="Hello")], 95 | image="test_image", 96 | ) 97 | response = client.post( 98 | "/v1/chat/completions", json=json.loads(request.model_dump_json()) 99 | ) 100 | 101 | assert response.status_code == 200 102 | assert "generated response" in response.json()["choices"][0]["message"]["content"] 103 | assert "usage" in response.json() 104 | usage = response.json()["usage"] 105 | assert "prompt_tokens" in usage 106 | assert "completion_tokens" in usage 107 | assert "total_tokens" in usage 108 | 109 | 110 | def test_chat_completion_lm(client): 111 | request = ChatCompletionRequest( 112 | model="test_phi_model", messages=[ChatMessage(role="user", content="Hello")] 113 | ) 114 | response = client.post( 115 | "/v1/chat/completions", json=json.loads(request.model_dump_json()) 116 | ) 117 | 118 | assert response.status_code == 200 119 | assert "generated response" in response.json()["choices"][0]["message"]["content"] 120 | assert "usage" in response.json() 121 | usage = response.json()["usage"] 122 | assert "prompt_tokens" in usage 123 | assert "completion_tokens" in usage 124 | assert "total_tokens" in usage 125 | 126 | 127 | @pytest.mark.asyncio 128 | async def test_vlm_streaming(client): 129 | 130 | # Mock the vlm_stream_generate function 131 | response = client.post( 132 | "/v1/chat/completions", 133 | json={ 134 | "model": "test_llava_model", 135 | "messages": [{"role": "user", "content": "Describe this image"}], 136 | "image": "base64_encoded_image_data", 137 | "stream": True, 138 | }, 139 | ) 140 | 141 | assert response.status_code == 200 142 | assert response.headers["content-type"].startswith("text/event-stream") 143 | 144 | chunks = list(response.iter_lines()) 145 | assert len(chunks) == 8 # 7 content chunks + [DONE] 146 | for chunk in chunks[:-2]: # Exclude the [DONE] message 147 | if chunk: 148 | chunk = chunk.split("data: ")[1] 149 | data = json.loads(chunk) 150 | assert "id" in data 151 | assert data["object"] == "chat.completion.chunk" 152 | assert "created" in data 153 | assert data["model"] == "test_llava_model" 154 | assert len(data["choices"]) == 1 155 | assert data["choices"][0]["index"] == 0 156 | assert "delta" in data["choices"][0] 157 | assert "role" in data["choices"][0]["delta"] 158 | assert "content" in data["choices"][0]["delta"] 159 | if "usage" in data: 160 | usage = data["usage"] 161 | assert "prompt_tokens" in usage 162 | assert "completion_tokens" in usage 163 | assert "total_tokens" in usage 164 | 165 | assert chunks[-2] == "data: [DONE]" 166 | 167 | 168 | @pytest.mark.asyncio 169 | async def test_lm_streaming(client): 170 | 171 | # Mock the lm_stream_generate function 172 | response = client.post( 173 | "/v1/chat/completions", 174 | json={ 175 | "model": 
"test_phi_model", 176 | "messages": [{"role": "user", "content": "Hello, how are you?"}], 177 | "stream": True, 178 | }, 179 | ) 180 | 181 | assert response.status_code == 200 182 | assert response.headers["content-type"].startswith("text/event-stream") 183 | 184 | chunks = list(response.iter_lines()) 185 | assert len(chunks) == 8 # 7 content chunks + [DONE] 186 | 187 | for chunk in chunks[:-2]: # Exclude the [DONE] message 188 | if chunk: 189 | chunk = chunk.split("data: ")[1] 190 | data = json.loads(chunk) 191 | assert "id" in data 192 | assert data["object"] == "chat.completion.chunk" 193 | assert "created" in data 194 | assert data["model"] == "test_phi_model" 195 | assert len(data["choices"]) == 1 196 | assert data["choices"][0]["index"] == 0 197 | assert "delta" in data["choices"][0] 198 | assert "role" in data["choices"][0]["delta"] 199 | assert "content" in data["choices"][0]["delta"] 200 | if "usage" in data: 201 | usage = data["usage"] 202 | assert "prompt_tokens" in usage 203 | assert "completion_tokens" in usage 204 | assert "total_tokens" in usage 205 | 206 | assert chunks[-2] == "data: [DONE]" 207 | 208 | 209 | def test_get_supported_models(client): 210 | response = client.get("/v1/supported_models") 211 | assert response.status_code == 200 212 | data = response.json() 213 | assert "vlm" in data 214 | assert "lm" in data 215 | assert data["vlm"] == ["llava"] 216 | assert data["lm"] == ["phi"] 217 | 218 | 219 | def test_list_models(client): 220 | client.post("/v1/models?model_name=test_llava_model") 221 | client.post("/v1/models?model_name=test_phi_model") 222 | 223 | response = client.get("/v1/models") 224 | 225 | assert response.status_code == 200 226 | response_data = response.json() 227 | assert response_data["object"] == "list" 228 | model_ids = {model["id"] for model in response_data["data"]} 229 | assert model_ids == {"test_llava_model", "test_phi_model"} 230 | 231 | 232 | def test_add_model(client): 233 | response = client.post("/v1/models?model_name=new_llava_model") 234 | 235 | assert response.status_code == 200 236 | assert response.json() == { 237 | "status": "success", 238 | "message": "Model new_llava_model added successfully", 239 | } 240 | 241 | 242 | def test_remove_model(client): 243 | # Add a model 244 | response = client.post("/v1/models?model_name=test_model") 245 | assert response.status_code == 200 246 | 247 | # Verify the model is added 248 | response = client.get("/v1/models") 249 | response_data = response.json() 250 | model_ids = {model["id"] for model in response_data["data"]} 251 | assert "test_model" in model_ids 252 | 253 | # Remove the model 254 | response = client.delete("/v1/models?model_name=test_model") 255 | assert response.status_code == 204 256 | 257 | # Verify the model is removed 258 | response = client.get("/v1/models") 259 | response_data = response.json() 260 | model_ids = {model["id"] for model in response_data["data"]} 261 | assert "test_model" not in model_ids 262 | 263 | # Try to remove a non-existent model 264 | response = client.delete("/v1/models?model_name=non_existent_model") 265 | assert response.status_code == 404 266 | assert "Model 'non_existent_model' not found" in response.json()["detail"] 267 | 268 | 269 | def test_handle_function_calls_json_format(): 270 | output = """Here's the weather forecast: 271 | {"tool_calls": [{"name": "get_weather", "arguments": {"location": "New York", "date": "2023-08-15"}}]} 272 | """ 273 | request = MagicMock() 274 | request.model = "test_model" 275 | token_info = MagicMock() 276 | 
token_info.prompt_tokens = 10 277 | token_info.completion_tokens = 20 278 | token_info.total_tokens = 30 279 | token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30) 280 | 281 | response = handle_function_calls(output, request, token_info) 282 | 283 | assert isinstance(response, ChatCompletionResponse) 284 | assert len(response.tool_calls) == 1 285 | assert response.tool_calls[0].function.name == "get_weather" 286 | assert json.loads(response.tool_calls[0].function.arguments) == { 287 | "location": "New York", 288 | "date": "2023-08-15", 289 | } 290 | assert "Here's the weather forecast:" in response.choices[0]["message"]["content"] 291 | assert '{"tool_calls":' not in response.choices[0]["message"]["content"] 292 | assert response.usage 293 | usage = response.usage 294 | assert usage.prompt_tokens == 10 295 | assert usage.completion_tokens == 20 296 | assert usage.total_tokens == 30 297 | 298 | 299 | def test_handle_function_calls_xml_format_old(): 300 | output = """Let me check that for you. 301 | 302 | {"symbol": "AAPL"} 303 | 304 | """ 305 | request = MagicMock() 306 | request.model = "test_model" 307 | token_info = MagicMock() 308 | token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30) 309 | 310 | response = handle_function_calls(output, request, token_info) 311 | 312 | assert isinstance(response, ChatCompletionResponse) 313 | assert len(response.tool_calls) == 1 314 | assert response.tool_calls[0].function.name == "get_stock_price" 315 | assert json.loads(response.tool_calls[0].function.arguments) == {"symbol": "AAPL"} 316 | assert "Let me check that for you." in response.choices[0]["message"]["content"] 317 | assert "" not in response.choices[0]["message"]["content"] 318 | assert response.usage 319 | usage = response.usage 320 | assert usage.prompt_tokens == 10 321 | assert usage.completion_tokens == 20 322 | assert usage.total_tokens == 30 323 | 324 | 325 | def test_handle_function_calls_xml_format_new(): 326 | output = """I'll get that information for you. 327 | 328 | 329 | search_database 330 | latest smartphones 331 | 5 332 | 333 | 334 | """ 335 | request = MagicMock() 336 | request.model = "test_model" 337 | token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30) 338 | 339 | response = handle_function_calls(output, request, token_info) 340 | 341 | assert isinstance(response, ChatCompletionResponse) 342 | assert len(response.tool_calls) == 1 343 | assert response.tool_calls[0].function.name == "search_database" 344 | assert json.loads(response.tool_calls[0].function.arguments) == { 345 | "query": "latest smartphones", 346 | "limit": "5", 347 | } 348 | assert ( 349 | "I'll get that information for you." 
350 | in response.choices[0]["message"]["content"] 351 | ) 352 | assert "" not in response.choices[0]["message"]["content"] 353 | assert response.usage 354 | usage = response.usage 355 | assert usage.prompt_tokens == 10 356 | assert usage.completion_tokens == 20 357 | assert usage.total_tokens == 30 358 | 359 | 360 | def test_handle_function_calls_functools_format(): 361 | output = """Here are the results: 362 | functools[{"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}] 363 | """ 364 | request = MagicMock() 365 | request.model = "test_model" 366 | token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30) 367 | 368 | response = handle_function_calls(output, request, token_info) 369 | 370 | assert isinstance(response, ChatCompletionResponse) 371 | assert response.tool_calls is not None 372 | assert len(response.tool_calls) == 1 373 | assert response.tool_calls[0].function.name == "get_current_weather" 374 | assert json.loads(response.tool_calls[0].function.arguments) == { 375 | "location": "San Francisco, CA", 376 | "format": "fahrenheit", 377 | } 378 | assert "Here are the results:" in response.choices[0]["message"]["content"] 379 | assert "functools[" not in response.choices[0]["message"]["content"] 380 | assert response.usage 381 | usage = response.usage 382 | assert usage.prompt_tokens == 10 383 | assert usage.completion_tokens == 20 384 | assert usage.total_tokens == 30 385 | 386 | 387 | # Add a new test for multiple function calls in functools format 388 | def test_handle_function_calls_multiple_functools(): 389 | output = """Here are the results: 390 | functools[{"name": "get_weather", "arguments": {"location": "New York"}}, {"name": "get_time", "arguments": {"timezone": "EST"}}] 391 | """ 392 | request = MagicMock() 393 | request.model = "test_model" 394 | token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30) 395 | 396 | response = handle_function_calls(output, request, token_info) 397 | assert isinstance(response, ChatCompletionResponse) 398 | assert response.tool_calls is not None 399 | assert len(response.tool_calls) == 2 400 | assert response.tool_calls[0].function.name == "get_weather" 401 | assert json.loads(response.tool_calls[0].function.arguments) == { 402 | "location": "New York" 403 | } 404 | assert response.tool_calls[1].function.name == "get_time" 405 | assert json.loads(response.tool_calls[1].function.arguments) == {"timezone": "EST"} 406 | assert "Here are the results:" in response.choices[0]["message"]["content"] 407 | assert "functools[" not in response.choices[0]["message"]["content"] 408 | assert response.usage 409 | usage = response.usage 410 | assert usage.prompt_tokens == 10 411 | assert usage.completion_tokens == 20 412 | assert usage.total_tokens == 30 413 | 414 | 415 | if __name__ == "__main__": 416 | pytest.main(["-v", __file__]) 417 | -------------------------------------------------------------------------------- /fastmlx/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import time 5 | from contextlib import contextmanager 6 | from datetime import datetime 7 | from typing import Any, Dict, Generator, List, Union 8 | 9 | from jinja2 import Environment, FileSystemLoader 10 | 11 | from .types.chat.chat_completion import ( 12 | ChatCompletionChunk, 13 | ChatCompletionRequest, 14 | ChatCompletionResponse, 15 | FunctionCall, 16 | ToolCall, 17 | Usage, 18 | ) 19 | 20 | # MLX 
Imports 21 | try: 22 | import mlx.core as mx 23 | from mlx_lm import load as lm_load 24 | from mlx_lm import models as lm_models 25 | from mlx_lm.sample_utils import make_sampler 26 | from mlx_lm.tokenizer_utils import TokenizerWrapper 27 | from mlx_lm.utils import generate_step 28 | from mlx_lm.utils import stream_generate as lm_stream_generate 29 | from mlx_vlm import load as vlm_load 30 | from mlx_vlm import models as vlm_models 31 | from mlx_vlm.utils import load_image_processor 32 | from mlx_vlm.utils import stream_generate as vlm_stream_generate 33 | except ImportError: 34 | print("Warning: mlx or mlx_lm not available. Some functionality will be limited.") 35 | 36 | 37 | TOOLS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "tools")) 38 | 39 | 40 | def get_model_type_list(models, type="vlm"): 41 | 42 | # Get the directory path of the models package 43 | models_dir = os.path.dirname(models.__file__) 44 | 45 | # List all items in the models directory 46 | all_items = os.listdir(models_dir) 47 | 48 | if type == "vlm": 49 | submodules = [ 50 | item 51 | for item in all_items 52 | if os.path.isdir(os.path.join(models_dir, item)) 53 | and not item.startswith(".") 54 | and item != "__pycache__" 55 | ] 56 | return submodules 57 | else: 58 | 59 | return [item for item in all_items if not item.startswith("__")] 60 | 61 | 62 | MODELS = { 63 | "vlm": get_model_type_list(vlm_models), 64 | "lm": get_model_type_list(lm_models, "lm"), 65 | } 66 | MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny"} 67 | 68 | 69 | @contextmanager 70 | def working_directory(path): 71 | """A context manager to change the working directory temporarily.""" 72 | current_path = os.getcwd() 73 | try: 74 | os.chdir(path) 75 | yield 76 | finally: 77 | os.chdir(current_path) 78 | 79 | 80 | def load_tools_config(): 81 | with working_directory(TOOLS_PATH): 82 | with open("config.json", "r") as file: 83 | return json.load(file) 84 | 85 | 86 | def get_model_type(model_name, available_models): 87 | # Convert model name to lowercase for case-insensitive matching 88 | model_name_lower = model_name.lower().replace(".", "_") 89 | 90 | # Check if any of the available model types are in the model name 91 | for model_type in available_models: 92 | if model_type != "default" and model_type in model_name_lower: 93 | return model_type 94 | 95 | # If no match is found, return 'default' 96 | return "default" 97 | 98 | 99 | def get_tool_prompt(model_name, tools, prompt): 100 | tool_config = load_tools_config() 101 | available_models = tool_config["models"].keys() 102 | model_type = get_model_type(model_name, available_models) 103 | model_config = tool_config["models"].get( 104 | model_type, tool_config["models"]["default"] 105 | ) 106 | env = Environment(loader=FileSystemLoader(TOOLS_PATH)) 107 | template = env.get_template(model_config["prompt_template"]) 108 | if model_config.get("query", False): 109 | return ( 110 | template.render( 111 | tools=tools, 112 | parallel_tool_calling=model_config.get("parallel_tool_calling", False), 113 | current_date=datetime.now().strftime("%d %b %Y"), 114 | query=prompt, 115 | ), 116 | True, 117 | ) 118 | else: 119 | return ( 120 | template.render( 121 | tools=tools, 122 | parallel_tool_calling=model_config.get("parallel_tool_calling", False), 123 | current_date=datetime.now().strftime("%d %b %Y"), 124 | ), 125 | False, 126 | ) 127 | 128 | 129 | def get_eom_token(model_name): 130 | tool_config = load_tools_config() 131 | available_models = tool_config["models"].keys() 
132 | model_type = get_model_type(model_name, available_models) 133 | model_config = tool_config["models"].get( 134 | model_type, tool_config["models"]["default"] 135 | ) 136 | eom_token = model_config.get("eom_token", None) 137 | return eom_token 138 | 139 | 140 | def apply_lm_chat_template( 141 | tokenizer: Any, chat_messages: List[Dict], request: ChatCompletionRequest 142 | ) -> str: 143 | if tokenizer.chat_template is not None and hasattr( 144 | tokenizer, "apply_chat_template" 145 | ): 146 | if "firefunction-v2" in request.model: 147 | now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 148 | return tokenizer.apply_chat_template( 149 | chat_messages, 150 | functions=json.dumps( 151 | [tool.model_dump() for tool in request.tools], indent=4 152 | ), 153 | datetime=now, 154 | tokenize=False, 155 | ) 156 | else: 157 | return tokenizer.apply_chat_template( 158 | chat_messages, 159 | tokenize=False, 160 | add_generation_prompt=True, 161 | ) 162 | else: 163 | return request.messages[-1].content 164 | 165 | 166 | def handle_function_calls( 167 | output: str, request: ChatCompletionRequest, token_info: Usage 168 | ) -> ChatCompletionResponse: 169 | tool_calls = [] 170 | 171 | # Check for JSON format tool calls 172 | json_match = re.search(r'\{.*"tool_calls":\s*\[.*\].*\}', output, re.DOTALL) 173 | if json_match: 174 | try: 175 | json_data = json.loads(json_match.group()) 176 | for call in json_data.get("tool_calls", []): 177 | tool_calls.append( 178 | ToolCall( 179 | id=f"call_{os.urandom(4).hex()}", 180 | function=FunctionCall( 181 | name=call["name"], arguments=json.dumps(call["arguments"]) 182 | ), 183 | ) 184 | ) 185 | # Remove the JSON from the output 186 | output = re.sub( 187 | r'\{.*"tool_calls":\s*\[.*\].*\}', "", output, flags=re.DOTALL 188 | ).strip() 189 | except json.JSONDecodeError as e: 190 | print(f"Error parsing JSON tool calls: {e}") 191 | 192 | # Check for XML-style function calls 193 | # Check for function calls in both old and new XML formats 194 | elif "" in output.lower(): 195 | try: 196 | # Try parsing old format 197 | function_calls = re.findall(r"\s*({[^<>]+})", output) 198 | for function_name, args_str in function_calls: 199 | args = json.loads(args_str) 200 | tool_calls.append( 201 | ToolCall( 202 | id=f"call_{os.urandom(4).hex()}", 203 | function=FunctionCall( 204 | name=function_name, arguments=json.dumps(args) 205 | ), 206 | ) 207 | ) 208 | 209 | # Try parsing new XML format 210 | invoke_blocks = re.findall( 211 | r"(.*?)", output, re.DOTALL | re.IGNORECASE 212 | ) 213 | for block in invoke_blocks: 214 | tool_name = re.search( 215 | r"(.*?)", block, re.IGNORECASE 216 | ) 217 | parameters = re.findall(r"<(\w+)>(.*?)", block, re.IGNORECASE) 218 | 219 | if tool_name: 220 | args = { 221 | param[0].lower(): param[1] 222 | for param in parameters 223 | if param[0].lower() != "tool_name" 224 | } 225 | tool_calls.append( 226 | ToolCall( 227 | id=f"call_{os.urandom(4).hex()}", 228 | function=FunctionCall( 229 | name=tool_name.group(1), arguments=json.dumps(args) 230 | ), 231 | ) 232 | ) 233 | 234 | # Remove the function calls from the output 235 | output = re.sub( 236 | r".*", 237 | "", 238 | output, 239 | flags=re.DOTALL | re.IGNORECASE, 240 | ).strip() 241 | except Exception as e: 242 | print(f"Error parsing function call: {e}") 243 | 244 | elif "functools[" in output: 245 | try: 246 | functools_match = re.search(r"functools\[(.*?)\]", output, re.DOTALL) 247 | if functools_match: 248 | functools_data = json.loads(f"[{functools_match.group(1)}]") 249 | for call in 
functools_data: 250 | tool_calls.append( 251 | ToolCall( 252 | id=f"call_{os.urandom(4).hex()}", 253 | function=FunctionCall( 254 | name=call["name"], 255 | arguments=json.dumps(call["arguments"]), 256 | ), 257 | ) 258 | ) 259 | # Remove the functools call from the output 260 | output = re.sub( 261 | r"functools\[.*?\]", "", output, flags=re.DOTALL 262 | ).strip() 263 | except Exception as e: 264 | print(f"Error parsing functools call: {e}") 265 | 266 | # Prepare the response 267 | response = ChatCompletionResponse( 268 | id=f"chatcmpl-{os.urandom(4).hex()}", 269 | created=int(time.time()), 270 | model=request.model, 271 | usage=token_info, 272 | choices=[ 273 | { 274 | "index": 0, 275 | "message": {"role": "assistant", "content": output}, 276 | "finish_reason": "stop" if not tool_calls else "tool_call", 277 | } 278 | ], 279 | tool_calls=tool_calls, 280 | ) 281 | 282 | return response 283 | 284 | 285 | # Model Loading and Generation Functions 286 | def load_vlm_model(model_name: str, config: Dict[str, Any]) -> Dict[str, Any]: 287 | model, processor = vlm_load(model_name, {"trust_remote_code": True}) 288 | image_processor = load_image_processor(model_name) 289 | return { 290 | "model": model, 291 | "processor": processor, 292 | "image_processor": image_processor, 293 | "config": config, 294 | } 295 | 296 | 297 | def load_lm_model(model_name: str, config: Dict[str, Any]) -> Dict[str, Any]: 298 | time_start = time.time() 299 | model, tokenizer = lm_load(model_name, model_config=config) 300 | print(f"Model loaded in {time.time() - time_start:.2f} seconds.") 301 | return {"model": model, "tokenizer": tokenizer, "config": config} 302 | 303 | 304 | def vlm_stream_generator( 305 | model, 306 | model_name, 307 | processor, 308 | image, 309 | prompt, 310 | image_processor, 311 | max_tokens, 312 | temperature, 313 | stream_options, 314 | ): 315 | INCLUDE_USAGE = ( 316 | False if stream_options == None else stream_options.get("include_usage", False) 317 | ) 318 | completion_tokens = 0 319 | prompt_tokens = len(mx.array(processor.encode(prompt))) if INCLUDE_USAGE else None 320 | empty_usage: Usage = None 321 | 322 | for token in vlm_stream_generate( 323 | model, 324 | processor, 325 | image, 326 | prompt, 327 | image_processor, 328 | max_tokens=max_tokens, 329 | temp=temperature, 330 | ): 331 | # Update token length info 332 | if INCLUDE_USAGE: 333 | completion_tokens += 1 334 | 335 | chunk = ChatCompletionChunk( 336 | id=f"chatcmpl-{os.urandom(4).hex()}", 337 | created=int(time.time()), 338 | model=model_name, 339 | usage=empty_usage, 340 | choices=[ 341 | { 342 | "index": 0, 343 | "delta": {"role": "assistant", "content": token}, 344 | "finish_reason": None, 345 | } 346 | ], 347 | ) 348 | yield f"data: {json.dumps(chunk.model_dump())}\n\n" 349 | 350 | if INCLUDE_USAGE: 351 | chunk = ChatCompletionChunk( 352 | id=f"chatcmpl-{os.urandom(4).hex()}", 353 | created=int(time.time()), 354 | model=model_name, 355 | choices=[], 356 | usage=Usage( 357 | prompt_tokens=prompt_tokens, 358 | completion_tokens=completion_tokens, 359 | total_tokens=prompt_tokens + completion_tokens, 360 | ), 361 | ) 362 | yield f"data: {json.dumps(chunk.model_dump())}\n\n" 363 | yield "data: [DONE]\n\n" 364 | 365 | 366 | def lm_generate( 367 | model, 368 | tokenizer, 369 | prompt: str, 370 | max_tokens: int = 100, 371 | temp: float = 0.0, 372 | **kwargs, 373 | ) -> Union[str, Generator[str, None, None]]: 374 | """ 375 | Generate a complete response from the model. 376 | 377 | Args: 378 | model (nn.Module): The language model. 
379 | tokenizer (PreTrainedTokenizer): The tokenizer. 380 | prompt (str): The string prompt. 381 | max_tokens (int): The maximum number of tokens. Default: ``100``. 382 | verbose (bool): If ``True``, print tokens and timing information. 383 | Default: ``False``. 384 | formatter (Optional[Callable]): A function which takes a token and a 385 | probability and displays it. 386 | kwargs: The remaining options get passed to :func:`generate_step`. 387 | See :func:`generate_step` for more details. 388 | """ 389 | if not isinstance(tokenizer, TokenizerWrapper): 390 | tokenizer = TokenizerWrapper(tokenizer) 391 | 392 | stop_words = kwargs.pop("stop_words", []) 393 | 394 | stop_words_id = ( 395 | tokenizer._tokenizer(stop_words)["input_ids"][0] if stop_words else None 396 | ) 397 | 398 | prompt_tokens = mx.array(tokenizer.encode(prompt)) 399 | prompt_token_len = len(prompt_tokens) 400 | detokenizer = tokenizer.detokenizer 401 | 402 | detokenizer.reset() 403 | 404 | for (token, logprobs), n in zip( 405 | generate_step(prompt_tokens, model, sampler=make_sampler(temp=temp), **kwargs), 406 | range(max_tokens), 407 | ): 408 | if token == tokenizer.eos_token_id or ( 409 | stop_words_id and token in stop_words_id 410 | ): 411 | break 412 | 413 | detokenizer.add_token(token) 414 | 415 | detokenizer.finalize() 416 | 417 | _completion_tokens = len(detokenizer.tokens) 418 | token_length_info: Usage = Usage( 419 | prompt_tokens=prompt_token_len, 420 | completion_tokens=_completion_tokens, 421 | total_tokens=prompt_token_len + _completion_tokens, 422 | ) 423 | return detokenizer.text, token_length_info 424 | 425 | 426 | def lm_stream_generator( 427 | model, 428 | model_name, 429 | tokenizer, 430 | prompt, 431 | max_tokens, 432 | temperature, 433 | stream_options, 434 | **kwargs, 435 | ): 436 | stop_words = kwargs.pop("stop_words", []) 437 | INCLUDE_USAGE = ( 438 | False if stream_options == None else stream_options.get("include_usage", False) 439 | ) 440 | prompt_tokens = len(tokenizer.encode(prompt)) if INCLUDE_USAGE else None 441 | completion_tokens = 0 442 | empty_usage: Usage = None 443 | 444 | for token in lm_stream_generate( 445 | model, tokenizer, prompt, max_tokens=max_tokens, temp=temperature 446 | ): 447 | if stop_words and token in stop_words: 448 | break 449 | 450 | # Update token length info 451 | if INCLUDE_USAGE: 452 | completion_tokens += 1 453 | 454 | chunk = ChatCompletionChunk( 455 | id=f"chatcmpl-{os.urandom(4).hex()}", 456 | created=int(time.time()), 457 | model=model_name, 458 | usage=empty_usage, 459 | choices=[ 460 | { 461 | "index": 0, 462 | "delta": {"role": "assistant", "content": token}, 463 | "finish_reason": None, 464 | } 465 | ], 466 | ) 467 | yield f"data: {json.dumps(chunk.model_dump())}\n\n" 468 | 469 | if INCLUDE_USAGE: 470 | chunk = ChatCompletionChunk( 471 | id=f"chatcmpl-{os.urandom(4).hex()}", 472 | created=int(time.time()), 473 | model=model_name, 474 | choices=[], 475 | usage=Usage( 476 | prompt_tokens=prompt_tokens, 477 | completion_tokens=completion_tokens, 478 | total_tokens=prompt_tokens + completion_tokens, 479 | ), 480 | ) 481 | yield f"data: {json.dumps(chunk.model_dump())}\n\n" 482 | 483 | yield "data: [DONE]\n\n" 484 | --------------------------------------------------------------------------------