├── mlx_lm ├── py.typed ├── models │ ├── __init__.py │ ├── mistral3.py │ ├── pixtral.py │ ├── gemma3.py │ ├── smollm3.py │ ├── kimi_vl.py │ ├── base.py │ ├── qwen.py │ ├── olmo.py │ ├── bitlinear_layers.py │ ├── ernie4_5.py │ ├── starcoder2.py │ ├── exaone.py │ ├── gemma.py │ ├── phi.py │ ├── gpt_bigcode.py │ ├── glm4.py │ ├── helium.py │ ├── qwen3.py │ ├── cohere.py │ ├── mimo.py │ ├── switch_layers.py │ ├── phixtral.py │ └── qwen2.py ├── _version.py ├── tuner │ ├── __init__.py │ └── callbacks.py ├── __init__.py ├── examples │ ├── merge_config.yaml │ ├── generate_response.py │ ├── chat.py │ ├── openai_tool_use.py │ ├── tool_use.py │ ├── lora_config.yaml │ └── pipeline_generate.py ├── README.md ├── MANAGE.md ├── upload.py ├── UPLOAD.md ├── __main__.py ├── quant │ └── utils.py ├── fuse.py ├── chat.py ├── manage.py ├── cache_prompt.py ├── SERVER.md └── LEARNED_QUANTS.md ├── MANIFEST.in ├── requirements.txt ├── .pre-commit-config.yaml ├── LICENSE ├── ACKNOWLEDGMENTS.md ├── tests ├── test_utils_load_model.py ├── test_gguf.py ├── test_tuner_utils.py ├── test_utils.py ├── test_datsets.py ├── test_tokenizers.py ├── test_sample_utils.py └── test_generate.py ├── .circleci └── config.yml ├── setup.py ├── .gitignore ├── CONTRIBUTING.md └── CODE_OF_CONDUCT.md /mlx_lm/py.typed: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mlx_lm/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | recursive-include mlx_lm/ *.py 3 | -------------------------------------------------------------------------------- /mlx_lm/_version.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | __version__ = "0.25.3" 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.25.0 2 | numpy 3 | transformers[sentencepiece]>=4.39.3 4 | protobuf 5 | pyyaml 6 | jinja2 7 | -------------------------------------------------------------------------------- /mlx_lm/tuner/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import TrainingArgs, evaluate, train 2 | from .utils import linear_to_lora_layers 3 | -------------------------------------------------------------------------------- /mlx_lm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
2 | 3 | import os 4 | 5 | from ._version import __version__ 6 | 7 | os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1" 8 | 9 | from .convert import convert 10 | from .generate import generate, stream_generate 11 | from .utils import load 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black-pre-commit-mirror 3 | rev: 25.1.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/pycqa/isort 7 | rev: 6.0.0 8 | hooks: 9 | - id: isort 10 | args: 11 | - --profile=black 12 | -------------------------------------------------------------------------------- /mlx_lm/examples/merge_config.yaml: -------------------------------------------------------------------------------- 1 | models: 2 | - OpenPipe/mistral-ft-optimized-1218 3 | - mlabonne/NeuralHermes-2.5-Mistral-7B 4 | method: slerp 5 | parameters: 6 | t: 7 | - filter: self_attn 8 | value: [0, 0.5, 0.3, 0.7, 1] 9 | - filter: mlp 10 | value: [1, 0.5, 0.7, 0.3, 0] 11 | - value: 0.5 12 | -------------------------------------------------------------------------------- /mlx_lm/README.md: -------------------------------------------------------------------------------- 1 | ## Generate Text with MLX and :hugs: Hugging Face 2 | 3 | This is an example of large language model text generation that can pull models from 4 | the Hugging Face Hub. 5 | 6 | For more information on this example, see the [README](../README.md) in the 7 | parent directory. 8 | 9 | This package also supports fine-tuning with LoRA or QLoRA. For more information 10 | see the [LoRA documentation](LORA.md). 11 | -------------------------------------------------------------------------------- /mlx_lm/MANAGE.md: -------------------------------------------------------------------------------- 1 | # Managing Models 2 | 3 | You can use `mlx-lm` to manage models downloaded locally on your machine. They 4 | are stored in the Hugging Face cache. 5 | 6 | Scan models: 7 | 8 | ```shell 9 | mlx_lm.manage --scan 10 | ``` 11 | 12 | Specify a `--pattern` to get info on a single model or a specific set of models: 13 | 14 | ```shell 15 | mlx_lm.manage --scan --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit 16 | ``` 17 | 18 | To delete a model (or multiple models): 19 | 20 | ```shell 21 | mlx_lm.manage --delete --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit 22 | ``` 23 | -------------------------------------------------------------------------------- /mlx_lm/upload.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | import argparse 4 | 5 | from .utils import upload_to_hub 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser( 10 | description="Upload a model to the Hugging Face Hub" 11 | ) 12 | 13 | parser.add_argument( 14 | "--path", type=str, default="mlx_model", help="Path to the MLX model." 
15 | ) 16 | parser.add_argument( 17 | "--upload-repo", 18 | help="The Hugging Face repo to upload the model to.", 19 | type=str, 20 | ) 21 | args = parser.parse_args() 22 | upload_to_hub(args.path, args.upload_repo) 23 | -------------------------------------------------------------------------------- /mlx_lm/UPLOAD.md: -------------------------------------------------------------------------------- 1 | ### Packaging for PyPI 2 | 3 | Install `build` and `twine`: 4 | 5 | ``` 6 | pip install --user --upgrade build 7 | pip install --user --upgrade twine 8 | ``` 9 | 10 | Generate the source distribution and wheel: 11 | 12 | ``` 13 | python -m build 14 | ``` 15 | 16 | > [!warning] 17 | > Use a test server first 18 | 19 | #### Test Upload 20 | 21 | Upload to test server: 22 | 23 | ``` 24 | python -m twine upload --repository testpypi dist/* 25 | ``` 26 | 27 | Install from test server and check that it works: 28 | 29 | ``` 30 | python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm 31 | ``` 32 | 33 | #### Upload 34 | 35 | ``` 36 | python -m twine upload dist/* 37 | ``` 38 | -------------------------------------------------------------------------------- /mlx_lm/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | import importlib 4 | import sys 5 | 6 | if __name__ == "__main__": 7 | subcommands = { 8 | "quant.awq", 9 | "quant.dwq", 10 | "quant.dynamic_quant", 11 | "cache_prompt", 12 | "chat", 13 | "convert", 14 | "evaluate", 15 | "fuse", 16 | "generate", 17 | "lora", 18 | "server", 19 | "manage", 20 | "upload", 21 | } 22 | if len(sys.argv) < 2: 23 | raise ValueError(f"CLI requires a subcommand in {subcommands}") 24 | subcommand = sys.argv.pop(1) 25 | if subcommand not in subcommands: 26 | raise ValueError(f"CLI requires a subcommand in {subcommands}") 27 | submodule = importlib.import_module(f"mlx_lm.{subcommand}") 28 | submodule.main() 29 | -------------------------------------------------------------------------------- /mlx_lm/examples/generate_response.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | from mlx_lm import generate, load 4 | 5 | # Specify the checkpoint 6 | checkpoint = "mistralai/Mistral-7B-Instruct-v0.3" 7 | 8 | # Load the corresponding model and tokenizer 9 | model, tokenizer = load(path_or_hf_repo=checkpoint) 10 | 11 | # Specify the prompt and conversation history 12 | prompt = "Why is the sky blue?" 13 | conversation = [{"role": "user", "content": prompt}] 14 | 15 | # Transform the prompt into the chat template 16 | prompt = tokenizer.apply_chat_template( 17 | conversation=conversation, add_generation_prompt=True 18 | ) 19 | 20 | # Specify the maximum number of tokens 21 | max_tokens = 1_000 22 | 23 | # Specify if tokens and timing information will be printed 24 | verbose = True 25 | 26 | # Generate a response with the specified settings 27 | response = generate( 28 | model=model, 29 | tokenizer=tokenizer, 30 | prompt=prompt, 31 | max_tokens=max_tokens, 32 | verbose=verbose, 33 | ) 34 | -------------------------------------------------------------------------------- /mlx_lm/quant/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 
2 | 3 | from pathlib import Path 4 | 5 | import mlx.core as mx 6 | 7 | 8 | def load_data(tokenizer, num_samples: int, sequence_length: int) -> mx.array: 9 | save_dir = Path.home() / ".cache/mlx-lm/calibration_v5.txt" 10 | if not save_dir.exists(): 11 | from urllib import request 12 | 13 | save_dir.parent.mkdir(parents=True, exist_ok=True) 14 | url = "https://gist.githubusercontent.com/tristandruyen/9e207a95c7d75ddf37525d353e00659c/raw/571fda718462de863e5a0171078c175420c7649a/calibration_data_v5_rc.txt" 15 | request.urlretrieve(url, save_dir) 16 | with open(save_dir) as fid: 17 | texts = fid.read() 18 | tokens = tokenizer.encode(texts, return_tensors="mlx")[0] 19 | 20 | # select random non-overlapping chunks 21 | tokens = tokens[: (tokens.size // sequence_length) * sequence_length] 22 | tokens = tokens.reshape(-1, sequence_length) 23 | segments = mx.random.permutation(tokens.shape[0]) 24 | if num_samples > 0: 25 | segments = segments[:num_samples] 26 | return tokens[segments] 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ACKNOWLEDGMENTS.md: -------------------------------------------------------------------------------- 1 | # Individual Contributors 2 | 3 | If you wish to be acknowledged for your contributions, please list your name 4 | with a short description of your contribution(s) below. For example: 5 | 6 | - Jane Smith: Added the `foo` example. 7 | 8 | MLX LM was developed with contributions from the following individuals: 9 | 10 | - Shunta Saito: Added support for PLaMo models. 11 | - Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's `Mamba v1`, Z.ai & THUKEG's `GLM4`, Rednote's `dots.llm1`, Baidu's `Ernie4.5 MoE`, and Allenai's `OLMoE`; added support for the following training algorithms: `full-fine-tuning`; added support for the following other features: `Multiple Optimizers to choose for training` and `reporting training metrics to WandB (Weights & Biases)`. 
12 | - Prince Canuma: Helped add support for the following model architectures: HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen (2, 3 and MoE)`, Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama (3 and 4)`, Google DeepMind's `Gemma 3`, and InternLM's `InternLM 2.5`. 13 | -------------------------------------------------------------------------------- /mlx_lm/examples/chat.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """ 4 | An example of a multi-turn chat with prompt caching. 5 | """ 6 | 7 | from mlx_lm import generate, load 8 | from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache 9 | 10 | model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit") 11 | 12 | # Make the initial prompt cache for the model 13 | prompt_cache = make_prompt_cache(model) 14 | 15 | # User turn 16 | prompt = "Hi my name is <Name>." 17 | messages = [{"role": "user", "content": prompt}] 18 | prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) 19 | 20 | # Assistant response 21 | response = generate( 22 | model, 23 | tokenizer, 24 | prompt=prompt, 25 | verbose=True, 26 | prompt_cache=prompt_cache, 27 | ) 28 | 29 | # User turn 30 | prompt = "What's my name?" 31 | messages = [{"role": "user", "content": prompt}] 32 | prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) 33 | 34 | # Assistant response 35 | response = generate( 36 | model, 37 | tokenizer, 38 | prompt=prompt, 39 | verbose=True, 40 | prompt_cache=prompt_cache, 41 | ) 42 | 43 | # Save the prompt cache to disk to reuse it at a later time 44 | save_prompt_cache("mistral_prompt.safetensors", prompt_cache) 45 | 46 | # Load the prompt cache from disk 47 | prompt_cache = load_prompt_cache("mistral_prompt.safetensors") 48 | -------------------------------------------------------------------------------- /mlx_lm/models/mistral3.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | from mlx.utils import tree_flatten, tree_unflatten 9 | 10 | from . 
import llama 11 | from .base import BaseModelArgs 12 | 13 | 14 | @dataclass 15 | class ModelArgs(BaseModelArgs): 16 | model_type: str 17 | text_config: dict 18 | 19 | def __post_init__(self): 20 | self.text_config["tie_word_embeddings"] = False 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, args: ModelArgs): 25 | super().__init__() 26 | self.args = args 27 | self.model_type = args.model_type 28 | self.language_model = llama.Model(llama.ModelArgs.from_dict(args.text_config)) 29 | 30 | def __call__( 31 | self, 32 | inputs: mx.array, 33 | cache=None, 34 | mask: Optional[mx.array] = None, 35 | input_embeddings: Optional[mx.array] = None, 36 | ): 37 | return self.language_model( 38 | inputs, cache=cache, mask=mask, input_embeddings=input_embeddings 39 | ) 40 | 41 | def sanitize(self, weights): 42 | weights = tree_unflatten(list(weights.items())) 43 | weights.pop("vision_tower", None) 44 | weights.pop("multi_modal_projector", None) 45 | return dict(tree_flatten(weights)) 46 | 47 | @property 48 | def layers(self): 49 | return self.language_model.model.layers 50 | -------------------------------------------------------------------------------- /mlx_lm/tuner/callbacks.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | try: 4 | import wandb 5 | except ImportError: 6 | wandb = None 7 | 8 | 9 | class TrainingCallback: 10 | 11 | def on_train_loss_report(self, train_info: dict): 12 | """Called to report training loss at specified intervals.""" 13 | pass 14 | 15 | def on_val_loss_report(self, val_info: dict): 16 | """Called to report validation loss at specified intervals or the beginning.""" 17 | pass 18 | 19 | 20 | class WandBCallback(TrainingCallback): 21 | def __init__( 22 | self, 23 | project_name: str, 24 | log_dir: str, 25 | config: dict, 26 | wrapped_callback: TrainingCallback = None, 27 | ): 28 | if wandb is None: 29 | raise ImportError( 30 | "wandb is not installed. Please install it to use WandBCallback." 31 | ) 32 | self.wrapped_callback = wrapped_callback 33 | wandb.init(project=project_name, dir=log_dir, config=config) 34 | 35 | def on_train_loss_report(self, train_info: dict): 36 | wandb.log(train_info, step=train_info.get("iteration")) 37 | if self.wrapped_callback: 38 | self.wrapped_callback.on_train_loss_report(train_info) 39 | 40 | def on_val_loss_report(self, val_info: dict): 41 | wandb.log(val_info, step=val_info.get("iteration")) 42 | if self.wrapped_callback: 43 | self.wrapped_callback.on_val_loss_report(val_info) 44 | -------------------------------------------------------------------------------- /mlx_lm/models/pixtral.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | from mlx.utils import tree_flatten, tree_unflatten 9 | 10 | from . 
import llama 11 | from .base import BaseModelArgs 12 | 13 | 14 | @dataclass 15 | class ModelArgs(BaseModelArgs): 16 | model_type: str 17 | text_config: dict 18 | 19 | def __post_init__(self): 20 | self.text_config["tie_word_embeddings"] = False 21 | self.text_config["num_attention_heads"] = self.text_config.get( 22 | "num_attention_heads", 32 23 | ) 24 | 25 | 26 | class Model(nn.Module): 27 | def __init__(self, args: ModelArgs): 28 | super().__init__() 29 | self.args = args 30 | self.model_type = args.model_type 31 | self.language_model = llama.Model(llama.ModelArgs.from_dict(args.text_config)) 32 | 33 | def __call__( 34 | self, 35 | inputs: mx.array, 36 | cache=None, 37 | mask: Optional[mx.array] = None, 38 | input_embeddings: Optional[mx.array] = None, 39 | ): 40 | return self.language_model( 41 | inputs, cache=cache, mask=mask, input_embeddings=input_embeddings 42 | ) 43 | 44 | def sanitize(self, weights): 45 | weights = tree_unflatten(list(weights.items())) 46 | weights.pop("vision_tower", None) 47 | weights.pop("multi_modal_projector", None) 48 | return dict(tree_flatten(weights)) 49 | 50 | @property 51 | def layers(self): 52 | return self.language_model.model.layers 53 | -------------------------------------------------------------------------------- /tests/test_utils_load_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pathlib import Path 3 | 4 | import mlx.nn as nn 5 | 6 | from mlx_lm.models.qwen2 import Model as Qwen2Model 7 | from mlx_lm.utils import get_model_path, load_model 8 | 9 | HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" 10 | 11 | 12 | class TestLoadModelCustomGetClasses(unittest.TestCase): 13 | 14 | def test_load_model_with_custom_get_classes(self): 15 | class CustomQwenModel(nn.Module): 16 | def __init__(self, args): 17 | super().__init__() 18 | self.config = args 19 | self.custom_attribute = "This is a custom model" 20 | 21 | def load_weights(self, weights, **kwargs): 22 | self.qwenWeights = weights 23 | 24 | class CustomQwenConfig: 25 | @classmethod 26 | def from_dict(cls, config): 27 | instance = cls() 28 | for k, v in config.items(): 29 | setattr(instance, k, v) 30 | return instance 31 | 32 | def custom_get_classes(config): 33 | return CustomQwenModel, CustomQwenConfig 34 | 35 | model_path, _ = get_model_path(HF_MODEL_PATH) 36 | model, _ = load_model(model_path, get_model_classes=custom_get_classes) 37 | 38 | self.assertIsInstance(model, CustomQwenModel) 39 | self.assertTrue(hasattr(model, "custom_attribute")) 40 | self.assertEqual(model.custom_attribute, "This is a custom model") 41 | self.assertTrue(hasattr(model, "qwenWeights")) 42 | 43 | def test_load_model_with_default_get_classes(self): 44 | model_path, _ = get_model_path(HF_MODEL_PATH) 45 | model, _ = load_model(model_path) 46 | 47 | self.assertIsInstance(model, Qwen2Model) 48 | 49 | 50 | if __name__ == "__main__": 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | apple: ml-explore/pr-approval@0.1.0 5 | 6 | jobs: 7 | linux_build_and_test: 8 | docker: 9 | - image: cimg/python:3.9 10 | 11 | steps: 12 | - checkout 13 | - run: 14 | name: Run style checks 15 | command: | 16 | pip install pre-commit 17 | pre-commit run --all 18 | if ! 
git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi 19 | 20 | mlx_lm_build_and_test: 21 | macos: 22 | xcode: "15.2.0" 23 | resource_class: macos.m1.large.gen1 24 | steps: 25 | - checkout 26 | - run: 27 | name: Install dependencies 28 | command: | 29 | brew install python@3.9 30 | python3.9 -m venv env 31 | source env/bin/activate 32 | pip install --upgrade pip 33 | pip install unittest-xml-reporting 34 | pip install -e ".[test]" 35 | - run: 36 | name: Run Python tests 37 | command: | 38 | source env/bin/activate 39 | python -m xmlrunner discover -v tests -o test-results/ 40 | - store_test_results: 41 | path: test-results 42 | 43 | workflows: 44 | build_and_test: 45 | when: 46 | matches: 47 | pattern: "^(?!pull/)[-\\w]+$" 48 | value: << pipeline.git.branch >> 49 | jobs: 50 | - mlx_lm_build_and_test 51 | - linux_build_and_test 52 | 53 | prb: 54 | when: 55 | matches: 56 | pattern: "^pull/\\d+(/head)?$" 57 | value: << pipeline.git.branch >> 58 | jobs: 59 | - hold: 60 | type: approval 61 | - apple/authenticate: 62 | context: pr-approval 63 | - mlx_lm_build_and_test: 64 | requires: [ hold ] 65 | - linux_build_and_test: 66 | requires: [ hold ] 67 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from setuptools import setup 7 | 8 | package_dir = Path(__file__).parent / "mlx_lm" 9 | with open("requirements.txt") as fid: 10 | requirements = [l.strip() for l in fid.readlines()] 11 | 12 | sys.path.append(str(package_dir)) 13 | from _version import __version__ 14 | 15 | setup( 16 | name="mlx-lm", 17 | version=__version__, 18 | description="LLMs on Apple silicon with MLX and the Hugging Face Hub", 19 | long_description=open("README.md", encoding="utf-8").read(), 20 | long_description_content_type="text/markdown", 21 | readme="README.md", 22 | author_email="mlx@group.apple.com", 23 | author="MLX Contributors", 24 | url="https://github.com/ml-explore/mlx-lm", 25 | license="MIT", 26 | install_requires=requirements, 27 | packages=["mlx_lm", "mlx_lm.models", "mlx_lm.quant", "mlx_lm.tuner"], 28 | python_requires=">=3.8", 29 | extras_require={ 30 | "test": ["datasets"], 31 | "evaluate": ["lm-eval", "tqdm"], 32 | "quant": ["datasets", "tqdm"], 33 | }, 34 | entry_points={ 35 | "console_scripts": [ 36 | "mlx_lm.awq = mlx_lm.quant.awq:main", 37 | "mlx_lm.dwq = mlx_lm.quant.dwq:main", 38 | "mlx_lm.dynamic_quant = mlx_lm.quant.dynamic_quant:main", 39 | "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", 40 | "mlx_lm.chat = mlx_lm.chat:main", 41 | "mlx_lm.convert = mlx_lm.convert:main", 42 | "mlx_lm.evaluate = mlx_lm.evaluate:main", 43 | "mlx_lm.fuse = mlx_lm.fuse:main", 44 | "mlx_lm.generate = mlx_lm.generate:main", 45 | "mlx_lm.lora = mlx_lm.lora:main", 46 | "mlx_lm.server = mlx_lm.server:main", 47 | "mlx_lm.manage = mlx_lm.manage:main", 48 | "mlx_lm.upload = mlx_lm.upload:main", 49 | ] 50 | }, 51 | ) 52 | -------------------------------------------------------------------------------- /mlx_lm/examples/openai_tool_use.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | """ 3 | This is an example of tool use with mlx_lm and the OpenAI client. 4 | 5 | To run, first start the server: 6 | 7 | >>> mlx_lm.server 8 | 9 | Then run this script. 
10 | """ 11 | from openai import OpenAI 12 | 13 | client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed") 14 | 15 | model = "mlx-community/qwen3-4b-4bit-DWQ" 16 | messages = [{"role": "user", "content": "What's the weather in Boston?"}] 17 | 18 | tools = [ 19 | { 20 | "type": "function", 21 | "function": { 22 | "name": "get_current_weather", 23 | "description": "Get the current weather in a given location", 24 | "parameters": { 25 | "type": "object", 26 | "properties": { 27 | "location": { 28 | "type": "string", 29 | "description": "The city and state, e.g. San Francisco, CA", 30 | }, 31 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, 32 | }, 33 | "required": ["location"], 34 | }, 35 | }, 36 | } 37 | ] 38 | 39 | 40 | def get_current_weather(**kwargs): 41 | return "51 Farenheit, clear skies" 42 | 43 | 44 | functions = {"get_current_weather": get_current_weather} 45 | 46 | # The first query generates a tool call 47 | response = client.chat.completions.create( 48 | model=model, 49 | messages=messages, 50 | tools=tools, 51 | ) 52 | 53 | # Call the function 54 | function = response.choices[0].message.tool_calls[0].function 55 | tool_result = functions[function.name](**json.loads(function.arguments)) 56 | 57 | # Put the result of the function in the messages and generate the final 58 | # response: 59 | messages.append({"role": "tool", "name": function.name, "content": tool_result}) 60 | response = client.chat.completions.create( 61 | model=model, 62 | messages=messages, 63 | tools=tools, 64 | ) 65 | print(response.choices[0].message.content) 66 | -------------------------------------------------------------------------------- /tests/test_gguf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | from pathlib import Path 5 | from unittest.mock import MagicMock, patch 6 | 7 | import mlx.core as mx 8 | 9 | from mlx_lm.gguf import convert_to_gguf 10 | 11 | 12 | class TestConvertToGGUFWithoutMocks(unittest.TestCase): 13 | @classmethod 14 | def setUpClass(cls): 15 | cls.test_dir_fid = tempfile.TemporaryDirectory() 16 | cls.test_dir = cls.test_dir_fid.name 17 | cls.tokenizer_file_path = os.path.join(cls.test_dir, "tokenizer.json") 18 | with open(cls.tokenizer_file_path, "w") as f: 19 | f.write("{}") 20 | 21 | @classmethod 22 | def tearDownClass(cls): 23 | cls.test_dir_fid.cleanup() 24 | 25 | @patch("transformers.AutoTokenizer.from_pretrained") 26 | @patch("mlx.core.save_gguf") 27 | def test_convert_to_gguf( 28 | self, 29 | mock_save_gguf, 30 | mock_from_pretrained, 31 | ): 32 | mock_tokenizer = MagicMock() 33 | mock_tokenizer.vocab_size = 3 34 | mock_tokenizer.get_added_vocab.return_value = {} 35 | mock_tokenizer.get_vocab.return_value = {"": 0, "hello": 1, "world": 2} 36 | mock_tokenizer.all_special_tokens = [""] 37 | mock_tokenizer.all_special_ids = [0] 38 | mock_from_pretrained.return_value = mock_tokenizer 39 | 40 | model_path = Path(self.test_dir) 41 | weights = { 42 | "self_attn.q_proj.weight": mx.random.uniform(shape=[768, 768]), 43 | } 44 | config = { 45 | "num_attention_heads": 1, 46 | "num_hidden_layers": 1, 47 | "hidden_size": 768, 48 | "intermediate_size": 3072, 49 | "_name_or_path": "test-llama", 50 | } 51 | output_file_path = "/fake/output/path/gguf_model.gguf" 52 | 53 | convert_to_gguf(model_path, weights, config, output_file_path) 54 | called_args, _ = mock_save_gguf.call_args 55 | self.assertEqual(called_args[0], output_file_path) 56 | 57 | 58 | if __name__ == 
"__main__": 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /mlx_lm/models/gemma3.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | from mlx.utils import tree_flatten, tree_unflatten 9 | 10 | from . import gemma3_text 11 | from .base import BaseModelArgs 12 | 13 | 14 | @dataclass 15 | class ModelArgs(BaseModelArgs): 16 | model_type: str 17 | text_config: dict 18 | vocab_size: int = 262208 19 | 20 | def __post_init__(self): 21 | self.text_config["vocab_size"] = self.vocab_size 22 | self.text_config["num_attention_heads"] = self.text_config.get( 23 | "num_attention_heads", 8 24 | ) 25 | self.text_config["num_key_value_heads"] = self.text_config.get( 26 | "num_key_value_heads", 4 27 | ) 28 | 29 | 30 | class Model(nn.Module): 31 | def __init__(self, args: ModelArgs): 32 | super().__init__() 33 | self.args = args 34 | self.model_type = args.model_type 35 | self.language_model = gemma3_text.Model( 36 | gemma3_text.ModelArgs.from_dict(args.text_config) 37 | ) 38 | 39 | def __call__( 40 | self, 41 | inputs: mx.array, 42 | cache=None, 43 | mask: Optional[mx.array] = None, 44 | input_embeddings: Optional[mx.array] = None, 45 | ): 46 | return self.language_model( 47 | inputs, cache=cache, mask=mask, input_embeddings=input_embeddings 48 | ) 49 | 50 | def sanitize(self, weights): 51 | weights = tree_unflatten(list(weights.items())) 52 | weights.pop("vision_tower", None) 53 | weights.pop("multi_modal_projector", None) 54 | lm_weights = dict(tree_flatten(weights["language_model"])) 55 | lm_weights = self.language_model.sanitize(lm_weights) 56 | weights["language_model"] = tree_unflatten(list(lm_weights.items())) 57 | return dict(tree_flatten(weights)) 58 | 59 | @property 60 | def layers(self): 61 | return self.language_model.layers 62 | 63 | def make_cache(self): 64 | return self.language_model.make_cache() 65 | -------------------------------------------------------------------------------- /mlx_lm/examples/tool_use.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | import json 4 | 5 | from mlx_lm import generate, load 6 | from mlx_lm.models.cache import make_prompt_cache 7 | 8 | # Specify the checkpoint 9 | checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit" 10 | 11 | # Load the corresponding model and tokenizer 12 | model, tokenizer = load(path_or_hf_repo=checkpoint) 13 | 14 | 15 | # An example tool, make sure to include a docstring and type hints 16 | def multiply(a: float, b: float): 17 | """ 18 | A function that multiplies two numbers 19 | 20 | Args: 21 | a: The first number to multiply 22 | b: The second number to multiply 23 | """ 24 | return a * b 25 | 26 | 27 | tools = {"multiply": multiply} 28 | 29 | # Specify the prompt and conversation history 30 | prompt = "Multiply 12234585 and 48838483920." 
31 | messages = [{"role": "user", "content": prompt}] 32 | 33 | prompt = tokenizer.apply_chat_template( 34 | messages, add_generation_prompt=True, tools=list(tools.values()) 35 | ) 36 | 37 | prompt_cache = make_prompt_cache(model) 38 | 39 | # Generate the initial tool call: 40 | response = generate( 41 | model=model, 42 | tokenizer=tokenizer, 43 | prompt=prompt, 44 | max_tokens=2048, 45 | verbose=True, 46 | prompt_cache=prompt_cache, 47 | ) 48 | 49 | # Parse the tool call: 50 | # (Note, the tool call format is model specific) 51 | tool_open = "" 52 | tool_close = "" 53 | start_tool = response.find(tool_open) + len(tool_open) 54 | end_tool = response.find(tool_close) 55 | tool_call = json.loads(response[start_tool:end_tool].strip()) 56 | tool_result = tools[tool_call["name"]](**tool_call["arguments"]) 57 | 58 | # Put the tool result in the prompt 59 | messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}] 60 | prompt = tokenizer.apply_chat_template( 61 | messages, 62 | add_generation_prompt=True, 63 | ) 64 | 65 | # Generate the final response: 66 | response = generate( 67 | model=model, 68 | tokenizer=tokenizer, 69 | prompt=prompt, 70 | max_tokens=2048, 71 | verbose=True, 72 | prompt_cache=prompt_cache, 73 | ) 74 | -------------------------------------------------------------------------------- /mlx_lm/examples/lora_config.yaml: -------------------------------------------------------------------------------- 1 | # The path to the local model directory or Hugging Face repo. 2 | model: "mlx-community/Llama-3.2-1B-Instruct" 3 | 4 | # Whether or not to train (boolean) 5 | train: true 6 | 7 | # The fine-tuning method: "lora", "dora", or "full". 8 | fine_tune_type: lora 9 | 10 | # The Optimizer with its possible inputs 11 | optimizer: adamw 12 | # optimizer_config: 13 | # adamw: 14 | # betas: [0.9, 0.98] 15 | # eps: 1e-6 16 | # weight_decay: 0.05 17 | # bias_correction: true 18 | 19 | # Directory with {train, valid, test}.jsonl files 20 | data: "mlx-community/WikiSQL" 21 | 22 | # The PRNG seed 23 | seed: 0 24 | 25 | # Number of layers to fine-tune 26 | num_layers: 16 27 | 28 | # Minibatch size. 29 | batch_size: 4 30 | 31 | # Iterations to train for. 32 | iters: 1000 33 | 34 | # Number of validation batches, -1 uses the entire validation set. 35 | val_batches: 25 36 | 37 | # Adam learning rate. 38 | learning_rate: 1e-5 39 | 40 | # Whether to report the logs to WandB 41 | # wand: "wandb-project" 42 | 43 | # Number of training steps between loss reporting. 44 | steps_per_report: 10 45 | 46 | # Number of training steps between validations. 47 | steps_per_eval: 200 48 | 49 | # Load path to resume training with the given adapter weights. 50 | resume_adapter_file: null 51 | 52 | # Save/load path for the trained adapter weights. 53 | adapter_path: "adapters" 54 | 55 | # Save the model every N iterations. 56 | save_every: 100 57 | 58 | # Evaluate on the test set after training 59 | test: false 60 | 61 | # Number of test set batches, -1 uses the entire test set. 62 | test_batches: 100 63 | 64 | # Maximum sequence length. 65 | max_seq_length: 2048 66 | 67 | # Use gradient checkpointing to reduce memory use. 68 | grad_checkpoint: false 69 | 70 | # LoRA parameters can only be specified in a config file 71 | lora_parameters: 72 | # The layer keys to apply LoRA to. 
73 | # These will be applied for the last lora_layers 74 | keys: ["self_attn.q_proj", "self_attn.v_proj"] 75 | rank: 8 76 | scale: 20.0 77 | dropout: 0.0 78 | 79 | # Schedule can only be specified in a config file, uncomment to use. 80 | #lr_schedule: 81 | # name: cosine_decay 82 | # warmup: 100 # 0 for no warmup 83 | # warmup_init: 1e-7 # 0 if not specified 84 | # arguments: [1e-5, 1000, 1e-7] # passed to scheduler 85 | 86 | #hf_dataset: 87 | # path: "billsum" 88 | # train_split: "train[:1000]" 89 | # valid_split: "train[-100:]" 90 | # prompt_feature: "text" 91 | # completion_feature: "summary" 92 | 93 | -------------------------------------------------------------------------------- /mlx_lm/models/smollm3.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from . import llama 10 | 11 | 12 | @dataclass 13 | class ModelArgs(llama.ModelArgs): 14 | model_type: str 15 | no_rope_layer_interval: int = 4 16 | no_rope_layers: Optional[list[int]] = None 17 | 18 | def __post_init__(self): 19 | super().__post_init__() 20 | if self.no_rope_layers is None: 21 | self.no_rope_layers = [ 22 | int((i + 1) % self.no_rope_layer_interval != 0) 23 | for i in range(self.num_hidden_layers) 24 | ] 25 | elif len(self.no_rope_layers) != self.num_hidden_layers: 26 | raise ValueError("`no_rope_layers` length mismatch") 27 | 28 | 29 | class NoPE(nn.Module): 30 | """No-op used to disable rotary embeddings in selected layers.""" 31 | 32 | def __call__(self, x, offset: int = 0): 33 | return x 34 | 35 | 36 | class Model(nn.Module): 37 | """Wrapper around Llama that respects NoPE layers in SmolLM-3.""" 38 | 39 | def __init__(self, args: ModelArgs): 40 | super().__init__() 41 | self.args = args 42 | self.model_type: str = args.model_type 43 | 44 | self.model = llama.LlamaModel(args) 45 | if not args.tie_word_embeddings: 46 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 47 | 48 | for idx, use_rope in enumerate(args.no_rope_layers): 49 | if not use_rope: 50 | self.model.layers[idx].self_attn.rope = NoPE() 51 | 52 | def __call__( 53 | self, 54 | inputs: mx.array, 55 | mask: Optional[mx.array] = None, 56 | cache=None, 57 | input_embeddings: Optional[mx.array] = None, 58 | ): 59 | out = self.model(inputs, mask, cache, input_embeddings) 60 | if self.args.tie_word_embeddings: 61 | out = self.model.embed_tokens.as_linear(out) 62 | else: 63 | out = self.lm_head(out) 64 | return out 65 | 66 | @property 67 | def layers(self): 68 | return self.model.layers 69 | 70 | def sanitize(self, weights: dict): 71 | weights = { 72 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 73 | } 74 | if self.args.tie_word_embeddings: 75 | weights.pop("lm_head.weight", None) 76 | return weights 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Vim 10 | *.swp 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 
| *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # IDE files 135 | .idea/ 136 | .vscode/ 137 | 138 | # .DS_Store files 139 | .DS_Store 140 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MLX LM 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | 1. Fork and submit pull requests to the repo. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. Every PR should have passing tests and at least one review. 11 | 4. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 12 | This should install hooks for running `black` and `clang-format` to ensure 13 | consistent style for C++ and python code. 14 | 15 | You can also run the formatters manually as follows on individual files: 16 | 17 | ```bash 18 | clang-format -i file.cpp 19 | ``` 20 | 21 | ```bash 22 | black file.py 23 | ``` 24 | 25 | or, 26 | 27 | ```bash 28 | # single file 29 | pre-commit run --files file1.py 30 | 31 | # specific files 32 | pre-commit run --files file1.py file2.py 33 | ``` 34 | 35 | or run `pre-commit run --all-files` to check all files in the repo. 36 | 37 | ## Issues 38 | 39 | We use GitHub issues to track public bugs. Please ensure your description is 40 | clear and has sufficient instructions to be able to reproduce the issue. 
41 | 42 | ## License 43 | 44 | By contributing to mlx-lm, you agree that your contributions will be licensed 45 | under the LICENSE file in the root directory of this source tree. 46 | 47 | ## Adding New Models 48 | 49 | Below are some tips to port LLMs available on Hugging Face to MLX. 50 | 51 | From this directory, do an editable install: 52 | 53 | ```shell 54 | pip install -e . 55 | ``` 56 | 57 | Then check if the model has weights in the 58 | [safetensors](https://huggingface.co/docs/safetensors/index) format. If not 59 | [follow instructions](https://huggingface.co/spaces/safetensors/convert) to 60 | convert it. 61 | 62 | After that, add the model file to the 63 | [`mlx_lm/models`](https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models) 64 | directory. You can see other examples there. We recommend starting from a model 65 | that is similar to the model you are porting. 66 | 67 | Make sure the name of the new model file is the same as the `model_type` in the 68 | `config.json`, for example 69 | [starcoder2](https://huggingface.co/bigcode/starcoder2-7b/blob/main/config.json#L17). 70 | 71 | To determine the model layer names, we suggest one of the following: 72 | 73 | - Refer to the Transformers implementation if you are familiar with the 74 | codebase. 75 | - Load the model weights and check the weight names which will tell you about 76 | the model structure. 77 | - Look at the names of the weights by inspecting `model.safetensors.index.json` 78 | in the Hugging Face repo. 79 | 80 | To add LoRA support, edit 81 | [`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tuner/utils.py#L27-L60). 82 | 83 | Finally, add a test for the new model type to the [model 84 | tests](https://github.com/ml-explore/mlx-lm/blob/main/tests/test_models.py). 85 | 86 | You can run the tests with: 87 | 88 | ```shell 89 | python -m unittest discover tests/ 90 | ``` 91 | -------------------------------------------------------------------------------- /tests/test_tuner_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
2 | 3 | import sys 4 | import unittest 5 | from io import StringIO 6 | from unittest.mock import MagicMock 7 | 8 | import mlx.nn as nn 9 | 10 | from mlx_lm.tuner.lora import LoRALinear 11 | from mlx_lm.tuner.utils import print_trainable_parameters 12 | 13 | 14 | class TestTunerUtils(unittest.TestCase): 15 | def setUp(self): 16 | self.capturedOutput = StringIO() 17 | sys.stdout = self.capturedOutput 18 | 19 | def tearDown(self): 20 | sys.stdout = sys.__stdout__ 21 | 22 | def test_quantized_print_trainable_parameters(self): 23 | model = MagicMock() 24 | quantized_linear = MagicMock(spec=nn.QuantizedLinear) 25 | quantized_linear.weight = MagicMock(size=1e6) 26 | quantized_linear.bits = 8 27 | lora_linear = MagicMock(spec=LoRALinear) 28 | lora_linear.weight = MagicMock(size=2e6) 29 | lora_linear.parameters.return_value = [lora_linear.weight] 30 | 31 | linear = MagicMock(spec=nn.Linear) 32 | linear.weight = MagicMock(size=3e6) 33 | linear.parameters.return_value = [linear.weight] 34 | 35 | model.leaf_modules.return_value = { 36 | "quantized_linear": quantized_linear, 37 | "lora_linear": lora_linear, 38 | "linear": linear, 39 | } 40 | 41 | model.trainable_parameters.return_value = { 42 | "layer1.weight": MagicMock(size=1e6), 43 | "layer3.weight": MagicMock(size=2e6), 44 | } 45 | expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n" 46 | print_trainable_parameters(model) 47 | self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits) 48 | self.capturedOutput.truncate(0) 49 | self.capturedOutput.seek(0) 50 | 51 | quantized_linear.weight = MagicMock(size=1e6) 52 | quantized_linear.bits = 4 53 | expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n" 54 | print_trainable_parameters(model) 55 | self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits) 56 | self.capturedOutput.truncate(0) 57 | self.capturedOutput.seek(0) 58 | 59 | def test_print_trainable_parameters(self): 60 | model = MagicMock() 61 | linear1 = MagicMock(spec=nn.Linear) 62 | linear1.weight = MagicMock(size=1e6) 63 | linear1.parameters.return_value = [linear1.weight] 64 | linear2 = MagicMock(spec=nn.Linear) 65 | linear2.weight = MagicMock(size=2e6) 66 | linear2.parameters.return_value = [linear2.weight] 67 | lora_linear = MagicMock(spec=LoRALinear) 68 | lora_linear.weight = MagicMock(size=3e6) 69 | lora_linear.parameters.return_value = [lora_linear.weight] 70 | model.leaf_modules.return_value = { 71 | "linear1": linear1, 72 | "linear2": linear2, 73 | "lora_linear": lora_linear, 74 | } 75 | 76 | model.trainable_parameters.return_value = { 77 | "layer1.weight": MagicMock(size=1e6), 78 | "layer3.weight": MagicMock(size=2e6), 79 | } 80 | expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n" 81 | print_trainable_parameters(model) 82 | self.assertEqual(self.capturedOutput.getvalue(), expected_output) 83 | 84 | 85 | if __name__ == "__main__": 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
2 | 3 | import os 4 | import tempfile 5 | import unittest 6 | 7 | import mlx.core as mx 8 | import mlx.nn as nn 9 | from mlx.utils import tree_flatten 10 | 11 | from mlx_lm import convert, utils 12 | 13 | HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" 14 | 15 | 16 | class TestUtils(unittest.TestCase): 17 | 18 | @classmethod 19 | def setUpClass(cls): 20 | cls.test_dir_fid = tempfile.TemporaryDirectory() 21 | cls.test_dir = cls.test_dir_fid.name 22 | if not os.path.isdir(cls.test_dir): 23 | os.mkdir(cls.test_dir_fid.name) 24 | 25 | @classmethod 26 | def tearDownClass(cls): 27 | cls.test_dir_fid.cleanup() 28 | 29 | def test_load(self): 30 | model, _ = utils.load(HF_MODEL_PATH) 31 | 32 | model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True) 33 | 34 | mx.eval(model_lazy.parameters()) 35 | 36 | p1 = model.layers[0].mlp.up_proj.weight 37 | p2 = model_lazy.layers[0].mlp.up_proj.weight 38 | self.assertTrue(mx.allclose(p1, p2)) 39 | 40 | def test_make_shards(self): 41 | from mlx_lm.models import llama 42 | 43 | args = llama.ModelArgs( 44 | model_type="llama", 45 | hidden_size=2048, 46 | num_hidden_layers=32, 47 | intermediate_size=4096, 48 | num_attention_heads=32, 49 | rms_norm_eps=1e-5, 50 | vocab_size=30_000, 51 | ) 52 | model = llama.Model(args) 53 | weights = tree_flatten(model.parameters()) 54 | gb = sum(p.nbytes for _, p in weights) // 2**30 55 | shards = utils.make_shards(dict(weights), 1) 56 | self.assertTrue(gb <= len(shards) <= gb + 1) 57 | 58 | def test_quantize(self): 59 | from mlx_lm.models import llama 60 | 61 | args = llama.ModelArgs( 62 | model_type="llama", 63 | hidden_size=1024, 64 | num_hidden_layers=4, 65 | intermediate_size=2048, 66 | num_attention_heads=4, 67 | rms_norm_eps=1e-5, 68 | vocab_size=10_000, 69 | ) 70 | model = llama.Model(args) 71 | model, config = utils.quantize_model(model, {}, 64, 4) 72 | weights = dict(tree_flatten(model.parameters())) 73 | self.assertTrue("model.layers.2.mlp.up_proj.scales" in weights) 74 | self.assertTrue("model.layers.2.mlp.up_proj.biases" in weights) 75 | self.assertEqual(config["quantization"]["group_size"], 64) 76 | self.assertEqual(config["quantization"]["bits"], 4) 77 | 78 | def test_convert(self): 79 | mlx_path = os.path.join(self.test_dir, "mlx_model") 80 | 81 | convert(HF_MODEL_PATH, mlx_path=mlx_path, quantize=False) 82 | model, _ = utils.load(mlx_path) 83 | self.assertTrue(isinstance(model.layers[0].mlp.up_proj, nn.QuantizedLinear)) 84 | self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear)) 85 | 86 | # Check model weights have right type 87 | mlx_path = os.path.join(self.test_dir, "mlx_model_bf16") 88 | convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16") 89 | model, _ = utils.load(mlx_path) 90 | 91 | self.assertEqual(model.layers[0].mlp.up_proj.scales.dtype, mx.bfloat16) 92 | self.assertEqual(model.layers[-1].mlp.up_proj.scales.dtype, mx.bfloat16) 93 | 94 | 95 | if __name__ == "__main__": 96 | unittest.main() 97 | -------------------------------------------------------------------------------- /mlx_lm/fuse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from mlx.utils import tree_flatten, tree_unflatten 5 | 6 | from .gguf import convert_to_gguf 7 | from .tuner.utils import dequantize, load_adapters 8 | from .utils import ( 9 | fetch_from_hub, 10 | get_model_path, 11 | save, 12 | upload_to_hub, 13 | ) 14 | 15 | 16 | def parse_arguments() -> argparse.Namespace: 17 | parser = argparse.ArgumentParser( 18 
| description="Fuse fine-tuned adapters into the base model." 19 | ) 20 | parser.add_argument( 21 | "--model", 22 | default="mlx_model", 23 | help="The path to the local model directory or Hugging Face repo.", 24 | ) 25 | parser.add_argument( 26 | "--save-path", 27 | default="fused_model", 28 | help="The path to save the fused model.", 29 | ) 30 | parser.add_argument( 31 | "--adapter-path", 32 | type=str, 33 | default="adapters", 34 | help="Path to the trained adapter weights and config.", 35 | ) 36 | parser.add_argument( 37 | "--upload-repo", 38 | help="The Hugging Face repo to upload the model to.", 39 | type=str, 40 | default=None, 41 | ) 42 | parser.add_argument( 43 | "--de-quantize", 44 | help="Generate a de-quantized model.", 45 | action="store_true", 46 | ) 47 | parser.add_argument( 48 | "--export-gguf", 49 | help="Export model weights in GGUF format.", 50 | action="store_true", 51 | ) 52 | parser.add_argument( 53 | "--gguf-path", 54 | help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.", 55 | default="ggml-model-f16.gguf", 56 | type=str, 57 | ) 58 | return parser.parse_args() 59 | 60 | 61 | def main() -> None: 62 | print("Loading pretrained model") 63 | args = parse_arguments() 64 | 65 | model_path, hf_path = get_model_path(args.model) 66 | model, config, tokenizer = fetch_from_hub(model_path) 67 | 68 | model.freeze() 69 | model = load_adapters(model, args.adapter_path) 70 | 71 | fused_linears = [ 72 | (n, m.fuse(de_quantize=args.de_quantize)) 73 | for n, m in model.named_modules() 74 | if hasattr(m, "fuse") 75 | ] 76 | 77 | if fused_linears: 78 | model.update_modules(tree_unflatten(fused_linears)) 79 | 80 | if args.de_quantize: 81 | print("De-quantizing model") 82 | model = dequantize(model) 83 | config.pop("quantization", None) 84 | 85 | save_path = Path(args.save_path) 86 | save( 87 | save_path, 88 | model_path, 89 | model, 90 | tokenizer, 91 | config, 92 | hf_repo=hf_path, 93 | donate_model=False, 94 | ) 95 | 96 | if args.export_gguf: 97 | model_type = config["model_type"] 98 | if model_type not in ["llama", "mixtral", "mistral"]: 99 | raise ValueError( 100 | f"Model type {model_type} not supported for GGUF conversion." 101 | ) 102 | weights = dict(tree_flatten(model.parameters())) 103 | convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path)) 104 | 105 | if args.upload_repo is not None: 106 | if hf_path is None: 107 | raise ValueError( 108 | "Must provide original Hugging Face repo to upload local model." 109 | ) 110 | upload_to_hub(args.save_path, args.upload_repo) 111 | 112 | 113 | if __name__ == "__main__": 114 | print( 115 | "Calling `python -m mlx_lm.fuse...` directly is deprecated." 116 | " Use `mlx_lm.fuse...` or `python -m mlx_lm fuse ...` instead." 117 | ) 118 | main() 119 | -------------------------------------------------------------------------------- /mlx_lm/models/kimi_vl.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, Optional, Union 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs 10 | from .deepseek_v3 import DeepseekV3Model 11 | 12 | 13 | @dataclass 14 | class TextArgs(BaseModelArgs): 15 | vocab_size: int = 102400 16 | hidden_size: int = 4096 17 | intermediate_size: int = 11008 18 | moe_intermediate_size: int = 1407 19 | num_hidden_layers: int = 30 20 | num_attention_heads: int = 32 21 | num_key_value_heads: int = 32 22 | n_shared_experts: Optional[int] = None 23 | n_routed_experts: Optional[int] = None 24 | routed_scaling_factor: float = 1.0 25 | kv_lora_rank: int = 512 26 | q_lora_rank: int = 1536 27 | qk_rope_head_dim: int = 64 28 | v_head_dim: int = 128 29 | qk_nope_head_dim: int = 128 30 | topk_method: str = "noaux_tc" 31 | scoring_func: str = "sigmoid" 32 | norm_topk_prob: bool = True 33 | n_group: Optional[int] = None 34 | topk_group: Optional[int] = None 35 | num_experts_per_tok: Optional[int] = None 36 | moe_layer_freq: int = 1 37 | first_k_dense_replace: int = 0 38 | max_position_embeddings: int = 2048 39 | rms_norm_eps: float = 1e-6 40 | rope_theta: float = 10000.0 41 | rope_scaling: Dict = None 42 | attention_bias: bool = False 43 | 44 | 45 | @dataclass 46 | class ModelArgs(BaseModelArgs): 47 | text_config: Union[TextArgs, dict] 48 | model_type: str 49 | 50 | def __post_init__(self): 51 | self.text_config = TextArgs.from_dict(self.text_config) 52 | 53 | 54 | class LanguageModel(nn.Module): 55 | def __init__(self, config: TextArgs): 56 | super().__init__() 57 | self.args = config 58 | self.model = DeepseekV3Model(config) 59 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 60 | 61 | def __call__( 62 | self, 63 | inputs: mx.array, 64 | cache: Optional[Any] = None, 65 | mask: Optional[mx.array] = None, 66 | ): 67 | out = self.model(inputs, cache, mask) 68 | return self.lm_head(out) 69 | 70 | 71 | class Model(nn.Module): 72 | def __init__(self, config: ModelArgs): 73 | super().__init__() 74 | self.args = config 75 | self.model_type = config.model_type 76 | self.language_model = LanguageModel(config.text_config) 77 | 78 | def __call__( 79 | self, 80 | inputs: mx.array, 81 | cache: Optional[Any] = None, 82 | mask: Optional[mx.array] = None, 83 | ): 84 | return self.language_model(inputs, cache, mask) 85 | 86 | def sanitize(self, weights): 87 | def keep(key): 88 | return ( 89 | "vision_tower" not in key 90 | and "rotary_emb" not in key 91 | and "multi_modal_projector" not in key 92 | ) 93 | 94 | weights = {k: v for k, v in weights.items() if keep(k)} 95 | # Stack experts 96 | for l in range(self.args.text_config.num_hidden_layers): 97 | prefix = f"language_model.model.layers.{l}" 98 | for m in [("gate_proj"), ("down_proj"), ("up_proj")]: 99 | for k in ["weight", "scales", "biases"]: 100 | if f"{prefix}.mlp.experts.0.{m}.{k}" in weights: 101 | to_join = [ 102 | weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") 103 | for e in range(self.args.text_config.n_routed_experts) 104 | ] 105 | weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) 106 | 107 | return weights 108 | 109 | @property 110 | def layers(self): 111 | return self.language_model.model.layers 112 | 113 | @property 114 | def cast_predicate(self): 115 | def predicate(k): 116 | return "e_score_correction_bias" not in k 117 | 118 | return predicate 119 | -------------------------------------------------------------------------------- /mlx_lm/examples/pipeline_generate.py: 
-------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | """ 4 | Run with: 5 | 6 | ``` 7 | mlx.launch \ 8 | --hostfile /path/to/hosts.json \ 9 | /path/to/pipeline_generate.py \ 10 | --prompt "hello world" 11 | ``` 12 | 13 | Make sure you can run MLX over MPI on two hosts. For more information see the 14 | documentation: 15 | 16 | https://ml-explore.github.io/mlx/build/html/usage/distributed.html). 17 | """ 18 | 19 | import argparse 20 | import json 21 | import resource 22 | from pathlib import Path 23 | 24 | import mlx.core as mx 25 | from huggingface_hub import snapshot_download 26 | from mlx.utils import tree_flatten 27 | 28 | from mlx_lm import load, stream_generate 29 | from mlx_lm.utils import load_model, load_tokenizer 30 | 31 | # Needed for 8 bit model 32 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096)) 33 | 34 | 35 | def download(repo: str, allow_patterns: list[str]) -> Path: 36 | return Path( 37 | snapshot_download( 38 | repo, 39 | allow_patterns=allow_patterns, 40 | ) 41 | ) 42 | 43 | 44 | def shard_and_load(repo): 45 | # Get model path with everything but weight safetensors 46 | model_path = download( 47 | args.model, 48 | allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"], 49 | ) 50 | 51 | # Lazy load and shard model to figure out 52 | # which weights we need 53 | model, _ = load_model(model_path, lazy=True, strict=False) 54 | 55 | group = mx.distributed.init() 56 | rank = group.rank() 57 | model.model.pipeline(group) 58 | 59 | # Figure out which files we need for the local shard 60 | with open(model_path / "model.safetensors.index.json", "r") as fid: 61 | weight_index = json.load(fid)["weight_map"] 62 | 63 | local_files = set() 64 | for k, _ in tree_flatten(model.parameters()): 65 | local_files.add(weight_index[k]) 66 | 67 | # Download weights for local shard 68 | download(args.model, allow_patterns=local_files) 69 | 70 | # Load and shard the model, and load the weights 71 | tokenizer = load_tokenizer(model_path) 72 | model, _ = load_model(model_path, lazy=True, strict=False) 73 | model.model.pipeline(group) 74 | mx.eval(model.parameters()) 75 | 76 | # Synchronize processes before generation to avoid timeout if downloading 77 | # model for the first time. 
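# The CPU all-reduce below doubles as a barrier: every rank must reach this
# point before any rank starts streaming tokens.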
78 | mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) 79 | return model, tokenizer 80 | 81 | 82 | if __name__ == "__main__": 83 | parser = argparse.ArgumentParser(description="LLM pipelined inference example") 84 | parser.add_argument( 85 | "--model", 86 | default="mlx-community/DeepSeek-R1-3bit", 87 | help="HF repo or path to local model.", 88 | ) 89 | parser.add_argument( 90 | "--prompt", 91 | "-p", 92 | default="Write a quicksort in C++.", 93 | help="Message to be processed by the model ('-' reads from stdin)", 94 | ) 95 | parser.add_argument( 96 | "--max-tokens", 97 | "-m", 98 | type=int, 99 | default=256, 100 | help="Maximum number of tokens to generate", 101 | ) 102 | args = parser.parse_args() 103 | 104 | group = mx.distributed.init() 105 | rank = group.rank() 106 | 107 | def rprint(*args, **kwargs): 108 | if rank == 0: 109 | print(*args, **kwargs) 110 | 111 | model, tokenizer = shard_and_load(args.model) 112 | 113 | messages = [{"role": "user", "content": args.prompt}] 114 | prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) 115 | 116 | for response in stream_generate( 117 | model, tokenizer, prompt, max_tokens=args.max_tokens 118 | ): 119 | rprint(response.text, end="", flush=True) 120 | 121 | rprint() 122 | rprint("=" * 10) 123 | rprint( 124 | f"Prompt: {response.prompt_tokens} tokens, " 125 | f"{response.prompt_tps:.3f} tokens-per-sec" 126 | ) 127 | rprint( 128 | f"Generation: {response.generation_tokens} tokens, " 129 | f"{response.generation_tps:.3f} tokens-per-sec" 130 | ) 131 | rprint(f"Peak memory: {response.peak_memory:.3f} GB") 132 | -------------------------------------------------------------------------------- /mlx_lm/models/base.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
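#
# Shared building blocks for the model implementations in this package:
# `BaseModelArgs.from_dict` keeps only the config keys a model's dataclass
# declares, `create_causal_mask` / `create_attention_mask` build (optionally
# windowed) causal masks, and `scaled_dot_product_attention` falls back to a
# quantized-matmul attention path when the KV cache is a `QuantizedKVCache`.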
2 | 3 | import inspect 4 | from dataclasses import dataclass 5 | from typing import Any, Optional 6 | 7 | import mlx.core as mx 8 | from mlx.utils import tree_map 9 | 10 | from .cache import QuantizedKVCache 11 | 12 | 13 | @dataclass 14 | class BaseModelArgs: 15 | @classmethod 16 | def from_dict(cls, params): 17 | return cls( 18 | **{ 19 | k: v 20 | for k, v in params.items() 21 | if k in inspect.signature(cls).parameters 22 | } 23 | ) 24 | 25 | 26 | def create_causal_mask( 27 | N: int, 28 | offset: int = 0, 29 | window_size: Optional[int] = None, 30 | lengths: Optional[mx.array] = None, 31 | ): 32 | rinds = mx.arange(offset + N) 33 | linds = mx.arange(offset, offset + N) if offset else rinds 34 | linds = linds[:, None] 35 | rinds = rinds[None] 36 | mask = linds >= rinds 37 | if window_size is not None: 38 | mask = mask & (linds <= rinds + window_size) 39 | if lengths is not None: 40 | lengths = lengths[:, None, None, None] 41 | mask = mask & (rinds < lengths) 42 | return mask 43 | 44 | 45 | def create_attention_mask( 46 | h: mx.array, cache: Optional[Any] = None, return_array: bool = False 47 | ): 48 | T = h.shape[1] 49 | if T > 1: 50 | offset = 0 51 | window_size = None 52 | if cache is not None and cache[0] is not None: 53 | c = cache[0] 54 | offset = c.offset 55 | if hasattr(c, "max_size"): 56 | window_size = c.max_size 57 | offset = min(window_size, offset) 58 | return_array = return_array or offset + T > window_size 59 | if return_array: 60 | return create_causal_mask(T, offset, window_size=window_size) 61 | else: 62 | return "causal" 63 | else: 64 | mask = None 65 | return mask 66 | 67 | 68 | def quantized_scaled_dot_product_attention( 69 | queries: mx.array, 70 | q_keys: tuple[mx.array, mx.array, mx.array], 71 | q_values: tuple[mx.array, mx.array, mx.array], 72 | scale: float, 73 | mask: Optional[mx.array], 74 | group_size: int = 64, 75 | bits: int = 8, 76 | ) -> mx.array: 77 | B, n_q_heads, L, D = queries.shape 78 | n_kv_heads = q_keys[0].shape[-3] 79 | n_repeats = n_q_heads // n_kv_heads 80 | 81 | queries *= scale 82 | 83 | if n_repeats > 1: 84 | queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) 85 | q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) 86 | q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) 87 | 88 | scores = mx.quantized_matmul( 89 | queries, *q_keys, transpose=True, group_size=group_size, bits=bits 90 | ) 91 | if mask is not None: 92 | if isinstance(mask, str): 93 | qL, kL = scores.shape[-2:] 94 | q_indices = mx.arange(kL - qL, kL) 95 | k_indices = mx.arange(kL) 96 | mask = q_indices[:, None] >= k_indices[None] 97 | if mask.dtype == mx.bool_: 98 | scores = mx.where(mask, scores, mx.finfo(scores.dtype).min) 99 | else: 100 | scores += mask 101 | scores = mx.softmax(scores, axis=-1, precise=True) 102 | out = mx.quantized_matmul( 103 | scores, *q_values, transpose=False, group_size=group_size, bits=bits 104 | ) 105 | 106 | if n_repeats > 1: 107 | out = mx.reshape(out, (B, n_q_heads, L, D)) 108 | 109 | return out 110 | 111 | 112 | def scaled_dot_product_attention( 113 | queries, 114 | keys, 115 | values, 116 | cache, 117 | scale: float, 118 | mask: Optional[mx.array], 119 | ) -> mx.array: 120 | if isinstance(cache, QuantizedKVCache): 121 | return quantized_scaled_dot_product_attention( 122 | queries, 123 | keys, 124 | values, 125 | scale=scale, 126 | mask=mask, 127 | group_size=cache.group_size, 128 | bits=cache.bits, 129 | ) 130 | else: 131 | return mx.fast.scaled_dot_product_attention( 132 | queries, keys, values, 
scale=scale, mask=mask 133 | ) 134 | -------------------------------------------------------------------------------- /mlx_lm/chat.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import argparse 4 | 5 | import mlx.core as mx 6 | 7 | from .generate import stream_generate 8 | from .models.cache import make_prompt_cache 9 | from .sample_utils import make_sampler 10 | from .utils import load 11 | 12 | DEFAULT_TEMP = 0.0 13 | DEFAULT_TOP_P = 1.0 14 | DEFAULT_XTC_PROBABILITY = 0.0 15 | DEFAULT_XTC_THRESHOLD = 0.0 16 | DEFAULT_SEED = None 17 | DEFAULT_MAX_TOKENS = 256 18 | DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" 19 | 20 | 21 | def setup_arg_parser(): 22 | """Set up and return the argument parser.""" 23 | parser = argparse.ArgumentParser(description="Chat with an LLM") 24 | parser.add_argument( 25 | "--model", 26 | type=str, 27 | help="The path to the local model directory or Hugging Face repo.", 28 | default=DEFAULT_MODEL, 29 | ) 30 | parser.add_argument( 31 | "--adapter-path", 32 | type=str, 33 | help="Optional path for the trained adapter weights and config.", 34 | ) 35 | parser.add_argument( 36 | "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature" 37 | ) 38 | parser.add_argument( 39 | "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" 40 | ) 41 | parser.add_argument( 42 | "--xtc-probability", 43 | type=float, 44 | default=DEFAULT_XTC_PROBABILITY, 45 | help="Probability of XTC sampling to happen each next token", 46 | ) 47 | parser.add_argument( 48 | "--xtc-threshold", 49 | type=float, 50 | default=0.0, 51 | help="Thresold the probs of each next token candidate to be sampled by XTC", 52 | ) 53 | parser.add_argument( 54 | "--seed", 55 | type=int, 56 | default=DEFAULT_SEED, 57 | help="PRNG seed", 58 | ) 59 | parser.add_argument( 60 | "--max-kv-size", 61 | type=int, 62 | help="Set the maximum key-value cache size", 63 | default=None, 64 | ) 65 | parser.add_argument( 66 | "--max-tokens", 67 | "-m", 68 | type=int, 69 | default=DEFAULT_MAX_TOKENS, 70 | help="Maximum number of tokens to generate", 71 | ) 72 | return parser 73 | 74 | 75 | def main(): 76 | parser = setup_arg_parser() 77 | args = parser.parse_args() 78 | 79 | if args.seed is not None: 80 | mx.random.seed(args.seed) 81 | 82 | model, tokenizer = load( 83 | args.model, 84 | adapter_path=args.adapter_path, 85 | tokenizer_config={"trust_remote_code": True}, 86 | ) 87 | 88 | def print_help(): 89 | print("The command list:") 90 | print("- 'q' to exit") 91 | print("- 'r' to reset the chat") 92 | print("- 'h' to display these commands") 93 | 94 | print(f"[INFO] Starting chat session with {args.model}.") 95 | print_help() 96 | prompt_cache = make_prompt_cache(model, args.max_kv_size) 97 | while True: 98 | query = input(">> ") 99 | if query == "q": 100 | break 101 | if query == "r": 102 | prompt_cache = make_prompt_cache(model, args.max_kv_size) 103 | continue 104 | if query == "h": 105 | print_help() 106 | continue 107 | messages = [{"role": "user", "content": query}] 108 | prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) 109 | for response in stream_generate( 110 | model, 111 | tokenizer, 112 | prompt, 113 | max_tokens=args.max_tokens, 114 | sampler=make_sampler( 115 | args.temp, 116 | args.top_p, 117 | xtc_threshold=args.xtc_threshold, 118 | xtc_probability=args.xtc_probability, 119 | xtc_special_tokens=( 120 | tokenizer.encode("\n") + list(tokenizer.eos_token_ids) 121 | ), 122 | 
), 123 | prompt_cache=prompt_cache, 124 | ): 125 | print(response.text, flush=True, end="") 126 | print() 127 | 128 | 129 | if __name__ == "__main__": 130 | print( 131 | "Calling `python -m mlx_lm.chat...` directly is deprecated." 132 | " Use `mlx_lm.chat...` or `python -m mlx_lm chat ...` instead." 133 | ) 134 | main() 135 | -------------------------------------------------------------------------------- /tests/test_datsets.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import json 4 | import os 5 | import tempfile 6 | import types 7 | import unittest 8 | 9 | from transformers import AutoTokenizer 10 | 11 | from mlx_lm.tuner import datasets 12 | 13 | HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" 14 | 15 | 16 | class TestDatasets(unittest.TestCase): 17 | 18 | @classmethod 19 | def setUpClass(cls): 20 | cls.test_dir_fid = tempfile.TemporaryDirectory() 21 | cls.test_dir = cls.test_dir_fid.name 22 | if not os.path.isdir(cls.test_dir): 23 | os.mkdir(cls.test_dir_fid.name) 24 | 25 | @classmethod 26 | def tearDownClass(cls): 27 | cls.test_dir_fid.cleanup() 28 | 29 | def save_data(self, data): 30 | for ds in ["train", "valid"]: 31 | with open(os.path.join(self.test_dir, f"{ds}.jsonl"), "w") as fid: 32 | for l in data: 33 | json.dump(l, fid) 34 | fid.write("\n") 35 | 36 | def test_text(self): 37 | data = {"text": "This is an example for the model."} 38 | self.save_data(4 * [data]) 39 | args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) 40 | tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) 41 | train, valid, test = datasets.load_dataset(args, tokenizer) 42 | self.assertEqual(len(train), 4) 43 | self.assertEqual(len(valid), 4) 44 | self.assertEqual(len(test), 0) 45 | self.assertTrue(len(train[0]) > 0) 46 | self.assertTrue(len(valid[0]) > 0) 47 | self.assertTrue(isinstance(train, datasets.TextDataset)) 48 | 49 | def test_completions(self): 50 | data = {"prompt": "What is the capital of France?", "completion": "Paris."} 51 | self.save_data(4 * [data]) 52 | args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) 53 | tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) 54 | train, valid, test = datasets.load_dataset(args, tokenizer) 55 | self.assertEqual(len(train), 4) 56 | self.assertEqual(len(valid), 4) 57 | self.assertEqual(len(test), 0) 58 | self.assertTrue(len(train[0]) > 0) 59 | self.assertTrue(len(valid[0]) > 0) 60 | self.assertTrue(isinstance(train, datasets.CompletionsDataset)) 61 | 62 | def test_chat(self): 63 | data = { 64 | "messages": [ 65 | {"role": "system", "content": "You are a helpful assistant."}, 66 | {"role": "user", "content": "Hello."}, 67 | {"role": "assistant", "content": "How can I assistant you today."}, 68 | ] 69 | } 70 | self.save_data(4 * [data]) 71 | args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) 72 | tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) 73 | train, valid, test = datasets.load_dataset(args, tokenizer) 74 | self.assertEqual(len(train), 4) 75 | self.assertEqual(len(valid), 4) 76 | self.assertEqual(len(test), 0) 77 | self.assertTrue(len(train[0]) > 0) 78 | self.assertTrue(len(valid[0]) > 0) 79 | self.assertTrue(isinstance(train, datasets.ChatDataset)) 80 | 81 | def test_hf(self): 82 | hf_args = { 83 | "path": "billsum", 84 | "prompt_feature": "text", 85 | "completion_feature": "summary", 86 | "train_split": "train[:2%]", 87 | "valid_split": "train[-2%:]", 88 | } 89 | args = 
types.SimpleNamespace( 90 | hf_dataset=hf_args, 91 | test=False, 92 | train=True, 93 | ) 94 | tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) 95 | train, valid, test = datasets.load_dataset(args, tokenizer) 96 | self.assertTrue(len(train) > 0) 97 | self.assertTrue(len(train[0]) > 0) 98 | self.assertTrue(len(valid) > 0) 99 | self.assertTrue(len(valid[0]) > 0) 100 | self.assertEqual(len(test), 0) 101 | 102 | args = types.SimpleNamespace( 103 | hf_dataset=[hf_args, hf_args], 104 | test=False, 105 | train=True, 106 | ) 107 | train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer) 108 | self.assertEqual(2 * len(train), len(train_double)) 109 | self.assertEqual(2 * len(valid), len(valid_double)) 110 | self.assertEqual(2 * len(test), len(test_double)) 111 | 112 | 113 | if __name__ == "__main__": 114 | unittest.main() 115 | -------------------------------------------------------------------------------- /tests/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import unittest 4 | from pathlib import Path 5 | 6 | from huggingface_hub import snapshot_download 7 | 8 | from mlx_lm.tokenizer_utils import ( 9 | BPEStreamingDetokenizer, 10 | NaiveStreamingDetokenizer, 11 | SPMStreamingDetokenizer, 12 | load_tokenizer, 13 | ) 14 | 15 | 16 | class TestTokenizers(unittest.TestCase): 17 | 18 | def download_tokenizer(self, repo): 19 | path = Path( 20 | snapshot_download( 21 | repo_id=repo, 22 | allow_patterns=[ 23 | "tokenizer.json", 24 | "tokenizer_config.json", 25 | "special_tokens_map.json", 26 | "tokenizer.model", 27 | "chat_template.jinja", 28 | ], 29 | ) 30 | ) 31 | return load_tokenizer(path) 32 | 33 | def check_tokenizer(self, tokenizer): 34 | def check(tokens): 35 | expected_text = tokenizer.decode(tokens) 36 | detokenizer = tokenizer.detokenizer 37 | detokenizer.reset() 38 | text = "" 39 | for e, t in enumerate(tokens): 40 | detokenizer.add_token(t) 41 | seg = detokenizer.last_segment 42 | text += seg 43 | self.assertEqual(detokenizer.tokens, tokens[: e + 1]) 44 | detokenizer.finalize() 45 | text += detokenizer.last_segment 46 | self.assertEqual(text, expected_text) 47 | 48 | tokens = tokenizer.encode("こんにちは!私の名前はAI") 49 | check(tokens) 50 | 51 | tokens = tokenizer.encode("a ,b") 52 | check(tokens) 53 | 54 | tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}') 55 | check(tokens) 56 | 57 | tokens = tokenizer.encode("3 3") 58 | check(tokens) 59 | 60 | tokens = tokenizer.encode("import 'package:flutter/material.dart';") 61 | check(tokens) 62 | 63 | tokens = tokenizer.encode("hello\nworld") 64 | check(tokens) 65 | 66 | def test_tokenizers(self): 67 | tokenizer_repos = [ 68 | ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer), 69 | ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer), 70 | ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer), 71 | ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer), 72 | ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer), 73 | ("mlx-community/Falcon3-7B-Instruct-4bit", BPEStreamingDetokenizer), 74 | ] 75 | for tokenizer_repo, expected_detokenizer in tokenizer_repos: 76 | with self.subTest(tokenizer=tokenizer_repo): 77 | tokenizer = self.download_tokenizer(tokenizer_repo) 78 | tokenizer.decode([0, 1, 2]) 79 | self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer)) 80 | self.check_tokenizer(tokenizer) 81 | 82 | # Try one with a 
naive detokenizer
83 | tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
84 | tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
85 | self.check_tokenizer(tokenizer)
86 |
87 | def test_special_tokens(self):
88 | tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx"
89 | tokenizer = self.download_tokenizer(tokenizer_repo)
90 |
91 | detokenizer = tokenizer.detokenizer
92 | detokenizer.reset()
93 | detokenizer.add_token(tokenizer.eos_token_id)
94 | detokenizer.finalize()
95 |
96 | self.assertEqual(detokenizer.last_segment, tokenizer.eos_token)
97 |
98 | def test_tool_calling(self):
99 | tokenizer_repo = "mlx-community/Qwen3-4B-4bit"
100 | tokenizer = self.download_tokenizer(tokenizer_repo)
101 | self.assertTrue(tokenizer.has_tool_calling)
102 | self.assertEqual(tokenizer.tool_call_start, "<tool_call>")
103 | self.assertEqual(tokenizer.tool_call_end, "</tool_call>")
104 |
105 | tokenizer_repo = "mlx-community/Llama-3.2-1B-Instruct-4bit"
106 | tokenizer = self.download_tokenizer(tokenizer_repo)
107 | self.assertFalse(tokenizer.has_tool_calling)
108 |
109 | def test_thinking(self):
110 | tokenizer_repo = "mlx-community/Qwen3-4B-4bit"
111 | tokenizer = self.download_tokenizer(tokenizer_repo)
112 | self.assertTrue(tokenizer.has_thinking)
113 | self.assertEqual(tokenizer.think_start, "<think>")
114 | self.assertEqual(tokenizer.think_end, "</think>")
115 |
116 |
117 | if __name__ == "__main__":
118 | unittest.main()
119 |
--------------------------------------------------------------------------------
/tests/test_sample_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import mlx.core as mx
4 |
5 | from mlx_lm.sample_utils import apply_min_p, apply_top_k, apply_top_p, apply_xtc
6 |
7 |
8 | class TestSampleUtils(unittest.TestCase):
9 | def test_apply_top_p(self):
10 | probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
11 | logits = mx.log(probs)
12 |
13 | new_logits = apply_top_p(logits, 0.3)
14 | actual_probs = mx.softmax(new_logits.squeeze())
15 | self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
16 |
17 | new_logits = apply_top_p(logits, 0.95)
18 | actual_probs = mx.softmax(new_logits.squeeze())
19 | self.assertTrue(mx.allclose(probs.squeeze(), actual_probs))
20 |
21 | probs = mx.array([0.0, 0.5, 0.4, 0.1])[None]
22 | logits = mx.log(probs)
23 | new_logits = apply_top_p(logits, 0.4)
24 | actual_probs = mx.softmax(new_logits.squeeze())
25 | self.assertEqual(actual_probs.tolist(), [0.0, 1.0, 0.0, 0.0])
26 |
27 | new_logits = apply_top_p(logits, 0.6)
28 | actual_probs = mx.softmax(new_logits.squeeze())
29 | self.assertEqual(
30 | [round(p, 4) for p in actual_probs.tolist()], [0.0, 0.5556, 0.4444, 0.0]
31 | )
32 |
33 | new_logits = apply_top_p(logits, 0.95)
34 | actual_probs = mx.softmax(new_logits.squeeze())
35 | actual_rounded = [round(p, 4) for p in actual_probs.tolist()]
36 | expected_rounded = [0.0, 0.5, 0.4, 0.1]
37 | self.assertEqual(actual_rounded, expected_rounded)
38 | self.assertAlmostEqual(sum(actual_probs.tolist()), 1.0)
39 |
40 | # Batch mode works
41 | probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.1, 0.1]])
42 | logits = mx.log(probs)
43 | new_logits = apply_top_p(logits, 0.5)
44 | actual_probs = mx.softmax(new_logits, axis=-1)
45 | self.assertEqual(
46 | actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
47 | )
48 |
49 | def test_apply_min_p(self):
50 | probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
51 | logits = mx.log(probs)
52 | new_logits =
apply_min_p(logits, 0.8) 53 | actual_probs = mx.softmax(new_logits.squeeze()) 54 | self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0]) 55 | 56 | probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] 57 | logits = mx.log(probs) 58 | new_logits = apply_min_p(logits, 0.05) 59 | actual_probs = mx.softmax(new_logits.squeeze()) 60 | self.assertTrue(mx.allclose(actual_probs, mx.squeeze(probs))) 61 | 62 | # Batch mode works 63 | probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) 64 | logits = mx.log(probs) 65 | new_logits = apply_min_p(logits, 0.7) 66 | actual_probs = mx.softmax(new_logits, axis=-1) 67 | self.assertEqual( 68 | actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]] 69 | ) 70 | 71 | def test_apply_top_k(self): 72 | probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] 73 | logits = mx.log(probs) 74 | 75 | new_logits = apply_top_k(logits, 1) 76 | actual_probs = mx.softmax(new_logits.squeeze()) 77 | self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0]) 78 | 79 | probs = mx.array([0.6, 0.0, 0.1, 0.3])[None] 80 | logits = mx.log(probs) 81 | new_logits = apply_top_k(logits, 2) 82 | actual_probs = mx.softmax(new_logits.squeeze()) 83 | self.assertEqual( 84 | [round(p, 4) for p in actual_probs.tolist()], [0.6667, 0.0, 0.0, 0.3333] 85 | ) 86 | 87 | # Batch mode works 88 | probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) 89 | logits = mx.log(probs) 90 | 91 | new_logits = apply_top_k(logits, 1) 92 | actual_probs = mx.softmax(new_logits, axis=-1) 93 | self.assertEqual( 94 | actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]] 95 | ) 96 | 97 | def test_apply_xtc(self): 98 | # Test the threshold 99 | probs = mx.array([[0.4, 0.3, 0.15, 0.15]]) 100 | new_probs = mx.softmax(apply_xtc(mx.log(probs), 1, 0.2, []), -1) 101 | expected = mx.array([[0, 0.5, 0.25, 0.25]]) 102 | self.assertTrue(mx.allclose(new_probs, expected)) 103 | probs = mx.array([[0.4, 0.3, 0.15, 0.15]]) 104 | new_probs = mx.softmax(apply_xtc(mx.log(probs), 1, 0.1, []), -1) 105 | expected = mx.array([[0, 0.0, 0.5, 0.5]]) 106 | self.assertTrue(mx.allclose(new_probs, expected)) 107 | 108 | # Test the special tokens 109 | probs = mx.array([[0.4, 0.3, 0.15, 0.15]]) 110 | new_probs = mx.softmax(apply_xtc(mx.log(probs), 1, 0.1, [0]), -1) 111 | expected = mx.array([[4 / 7, 0.0, 1.5 / 7, 1.5 / 7]]) 112 | self.assertTrue(mx.allclose(new_probs, expected)) 113 | 114 | # Test that with probability 0 the probs don't change 115 | probs = mx.array([[0.4, 0.3, 0.15, 0.15]]) 116 | new_probs = mx.softmax(apply_xtc(mx.log(probs), 0, 0.1, [0]), -1) 117 | self.assertTrue(mx.allclose(new_probs, probs)) 118 | 119 | 120 | if __name__ == "__main__": 121 | unittest.main() 122 | -------------------------------------------------------------------------------- /mlx_lm/manage.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import List, Union 3 | 4 | from huggingface_hub import scan_cache_dir 5 | 6 | 7 | def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: 8 | """ 9 | Inspired by: 10 | - stackoverflow.com/a/8356620/593036 11 | - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data 12 | """ 13 | col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] 14 | row_format = ("{{:{}}} " * len(headers)).format(*col_widths) 15 | lines = [] 16 | lines.append(row_format.format(*headers)) 17 | lines.append(row_format.format(*["-" * w for w in col_widths])) 18 | for row in rows: 19 | 
lines.append(row_format.format(*row)) 20 | return "\n".join(lines) 21 | 22 | 23 | def ask_for_confirmation(message: str) -> bool: 24 | """Ask user for confirmation with Y/N prompt. 25 | Returns True for Y/yes, False for N/no/empty.""" 26 | y = ("y", "yes", "1") 27 | n = ("n", "no", "0", "") 28 | full_message = f"{message} (y/n) " 29 | while True: 30 | answer = input(full_message).lower() 31 | if answer in y: 32 | return True 33 | if answer in n: 34 | return False 35 | print(f"Invalid input. Must be one of: yes/no/y/n or empty for no") 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser(description="MLX Model Cache.") 40 | parser.add_argument( 41 | "--scan", 42 | action="store_true", 43 | help="Scan Hugging Face cache for mlx models.", 44 | ) 45 | parser.add_argument( 46 | "--delete", 47 | action="store_true", 48 | help="Delete models matching the given pattern.", 49 | ) 50 | parser.add_argument( 51 | "--pattern", 52 | type=str, 53 | help="Model repos contain the pattern.", 54 | default="mlx", 55 | ) 56 | 57 | args = parser.parse_args() 58 | 59 | if args.scan: 60 | print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".') 61 | hf_cache_info = scan_cache_dir() 62 | print( 63 | tabulate( 64 | rows=[ 65 | [ 66 | repo.repo_id, 67 | repo.repo_type, 68 | "{:>12}".format(repo.size_on_disk_str), 69 | repo.nb_files, 70 | repo.last_accessed_str, 71 | repo.last_modified_str, 72 | str(repo.repo_path), 73 | ] 74 | for repo in sorted( 75 | hf_cache_info.repos, key=lambda repo: repo.repo_path 76 | ) 77 | if args.pattern in repo.repo_id 78 | ], 79 | headers=[ 80 | "REPO ID", 81 | "REPO TYPE", 82 | "SIZE ON DISK", 83 | "NB FILES", 84 | "LAST_ACCESSED", 85 | "LAST_MODIFIED", 86 | "LOCAL PATH", 87 | ], 88 | ) 89 | ) 90 | 91 | if args.delete: 92 | print(f'Deleting models matching pattern "{args.pattern}"') 93 | hf_cache_info = scan_cache_dir() 94 | 95 | repos = [ 96 | repo 97 | for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) 98 | if args.pattern in repo.repo_id 99 | ] 100 | if repos: 101 | print("\nFound the following models:") 102 | print( 103 | tabulate( 104 | rows=[ 105 | [ 106 | repo.repo_id, 107 | repo.size_on_disk_str, # Added size information 108 | str(repo.repo_path), 109 | ] 110 | for repo in repos 111 | ], 112 | headers=[ 113 | "REPO ID", 114 | "SIZE", # Added size header 115 | "LOCAL PATH", 116 | ], 117 | ) 118 | ) 119 | 120 | confirmed = ask_for_confirmation( 121 | "\nAre you sure you want to delete these models?" 122 | ) 123 | if confirmed: 124 | for model_info in repos: 125 | print(f"\nDeleting {model_info.repo_id}...") 126 | for revision in sorted( 127 | model_info.revisions, key=lambda revision: revision.commit_hash 128 | ): 129 | strategy = hf_cache_info.delete_revisions(revision.commit_hash) 130 | strategy.execute() 131 | print("\nModel(s) deleted successfully.") 132 | else: 133 | print("\nDeletion cancelled - no changes made.") 134 | else: 135 | print(f'No models found matching pattern "{args.pattern}"') 136 | 137 | 138 | if __name__ == "__main__": 139 | print( 140 | "Calling `python -m mlx_lm.manage...` directly is deprecated." 141 | " Use `mlx_lm.manage...` or `python -m mlx_lm manage ...` instead." 142 | ) 143 | main() 144 | -------------------------------------------------------------------------------- /mlx_lm/models/qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
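#
# Original (first-generation) Qwen architecture: a single fused `c_attn`
# projection produces Q/K/V, RoPE is applied per head, and the MLP is a gated
# block whose `w1` / `w2` branches each use half of `intermediate_size`.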
2 | 3 | from dataclasses import dataclass 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 9 | 10 | 11 | @dataclass 12 | class ModelArgs(BaseModelArgs): 13 | model_type: str 14 | hidden_size: int = 2048 15 | num_attention_heads: int = 16 16 | num_hidden_layers: int = 24 17 | kv_channels: int = 128 18 | max_position_embeddings: int = 8192 19 | layer_norm_epsilon: float = 1e-6 20 | intermediate_size: int = 11008 21 | no_bias: bool = True 22 | vocab_size: int = 151936 23 | num_key_value_heads = None 24 | 25 | def __post_init__(self): 26 | if self.num_key_value_heads is None: 27 | self.num_key_value_heads = self.num_attention_heads 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, args: ModelArgs): 32 | super().__init__() 33 | 34 | hidden_size = args.hidden_size 35 | self.num_attention_heads = args.num_attention_heads 36 | 37 | hidden_size_per_attention_head = hidden_size // self.num_attention_heads 38 | 39 | self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False) 40 | 41 | proj_size = args.kv_channels * self.num_attention_heads 42 | 43 | self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True) 44 | self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias) 45 | 46 | self.scale = hidden_size_per_attention_head**-0.5 47 | 48 | def __call__(self, x, mask=None, cache=None): 49 | qkv = self.c_attn(x) 50 | 51 | q, k, v = mx.split(qkv, 3, axis=-1) 52 | 53 | B, L, _ = q.shape 54 | 55 | queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) 56 | keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) 57 | values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) 58 | 59 | if cache is not None: 60 | queries = self.rotary_emb(queries, offset=cache.offset) 61 | keys = self.rotary_emb(keys, offset=cache.offset) 62 | keys, values = cache.update_and_fetch(keys, values) 63 | else: 64 | queries = self.rotary_emb(queries) 65 | keys = self.rotary_emb(keys) 66 | 67 | output = scaled_dot_product_attention( 68 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 69 | ) 70 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 71 | 72 | return self.c_proj(output) 73 | 74 | 75 | class MLP(nn.Module): 76 | def __init__(self, args: ModelArgs): 77 | super().__init__() 78 | 79 | self.w1 = nn.Linear( 80 | args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias 81 | ) 82 | self.w2 = nn.Linear( 83 | args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias 84 | ) 85 | self.c_proj = nn.Linear( 86 | args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias 87 | ) 88 | 89 | def __call__(self, x): 90 | a1 = self.w1(x) 91 | a2 = self.w2(x) 92 | return self.c_proj(a1 * nn.silu(a2)) 93 | 94 | 95 | class TransformerBlock(nn.Module): 96 | def __init__(self, args: ModelArgs): 97 | super().__init__() 98 | 99 | self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) 100 | self.attn = Attention(args) 101 | self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) 102 | self.mlp = MLP(args) 103 | 104 | def __call__(self, x, mask=None, cache=None): 105 | residual = x 106 | x = self.ln_1(x) 107 | x = self.attn(x, mask=mask, cache=cache) 108 | residual = x + residual 109 | x = self.ln_2(residual) 110 | x = self.mlp(x) 111 | x = x + residual 112 | 113 | return x 114 | 115 | 116 | class QwenModel(nn.Module): 117 | def __init__(self, args: ModelArgs): 118 
| super().__init__() 119 | self.wte = nn.Embedding(args.vocab_size, args.hidden_size) 120 | self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)] 121 | self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) 122 | 123 | def __call__(self, inputs, mask=None, cache=None): 124 | x = self.wte(inputs) 125 | 126 | if mask is None: 127 | mask = create_attention_mask(x, cache) 128 | 129 | if cache is None: 130 | cache = [None] * len(self.h) 131 | 132 | for layer, c in zip(self.h, cache): 133 | x = layer(x, mask, c) 134 | 135 | return self.ln_f(x) 136 | 137 | 138 | class Model(nn.Module): 139 | def __init__(self, config: ModelArgs): 140 | super().__init__() 141 | self.model_type = config.model_type 142 | self.transformer = QwenModel(config) 143 | self.lm_head = nn.Linear( 144 | config.hidden_size, config.vocab_size, bias=not config.no_bias 145 | ) 146 | self.args = config 147 | 148 | def __call__( 149 | self, 150 | x: mx.array, 151 | mask: mx.array = None, 152 | cache=None, 153 | ) -> mx.array: 154 | y = self.transformer(x, mask, cache) 155 | return self.lm_head(y) 156 | 157 | @property 158 | def layers(self): 159 | return self.transformer.h 160 | -------------------------------------------------------------------------------- /mlx_lm/cache_prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import argparse 4 | import json 5 | import sys 6 | import time 7 | 8 | import mlx.core as mx 9 | 10 | from .generate import generate_step 11 | from .models.cache import make_prompt_cache, save_prompt_cache 12 | from .utils import load 13 | 14 | DEFAULT_QUANTIZED_KV_START = 5000 15 | 16 | 17 | def setup_arg_parser(): 18 | """Set up and return the argument parser.""" 19 | parser = argparse.ArgumentParser( 20 | description="Cache the state of a prompt to be reused with mlx_lm.generate" 21 | ) 22 | parser.add_argument( 23 | "--model", 24 | type=str, 25 | default="mlx_model", 26 | help="The path to the local model directory or Hugging Face repo.", 27 | ) 28 | parser.add_argument( 29 | "--adapter-path", 30 | type=str, 31 | help="Optional path for the trained adapter weights and config.", 32 | ) 33 | parser.add_argument( 34 | "--trust-remote-code", 35 | action="store_true", 36 | help="Enable trusting remote code for tokenizer", 37 | ) 38 | parser.add_argument( 39 | "--eos-token", 40 | type=str, 41 | default=None, 42 | help="End of sequence token for tokenizer", 43 | ) 44 | parser.add_argument( 45 | "--ignore-chat-template", 46 | action="store_true", 47 | help="Use the raw prompt without the tokenizer's chat template.", 48 | ) 49 | parser.add_argument( 50 | "--use-default-chat-template", 51 | action="store_true", 52 | help="Use the default chat template", 53 | ) 54 | parser.add_argument( 55 | "--max-kv-size", 56 | type=int, 57 | default=None, 58 | help="Set the maximum key-value cache size", 59 | ) 60 | parser.add_argument( 61 | "--prompt-cache-file", 62 | help="The file to save the prompt cache in", 63 | required=True, 64 | ) 65 | parser.add_argument( 66 | "--prompt", 67 | required=True, 68 | help="Message to be processed by the model ('-' reads from stdin)", 69 | ) 70 | parser.add_argument( 71 | "--kv-bits", 72 | type=int, 73 | help="Number of bits for KV cache quantization. 
" 74 | "Defaults to no quantization.", 75 | default=None, 76 | ) 77 | parser.add_argument( 78 | "--kv-group-size", 79 | type=int, 80 | help="Group size for KV cache quantization.", 81 | default=64, 82 | ) 83 | parser.add_argument( 84 | "--quantized-kv-start", 85 | help="When --kv-bits is set, start quantizing the KV cache " 86 | "from this step onwards.", 87 | type=int, 88 | default=DEFAULT_QUANTIZED_KV_START, 89 | ) 90 | return parser 91 | 92 | 93 | def main(): 94 | parser = setup_arg_parser() 95 | args = parser.parse_args() 96 | 97 | # Building tokenizer_config 98 | tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} 99 | if args.eos_token is not None: 100 | tokenizer_config["eos_token"] = args.eos_token 101 | 102 | model, tokenizer = load( 103 | args.model, 104 | adapter_path=args.adapter_path, 105 | tokenizer_config=tokenizer_config, 106 | ) 107 | 108 | args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt 109 | 110 | if args.use_default_chat_template: 111 | if tokenizer.chat_template is None: 112 | tokenizer.chat_template = tokenizer.default_chat_template 113 | 114 | if not args.ignore_chat_template and tokenizer.chat_template is not None: 115 | messages = [{"role": "user", "content": args.prompt}] 116 | prompt = tokenizer.apply_chat_template( 117 | messages, add_generation_prompt=False, continue_final_message=True 118 | ) 119 | 120 | else: 121 | prompt = tokenizer.encode(args.prompt) 122 | 123 | cache = make_prompt_cache(model, args.max_kv_size) 124 | y = mx.array(prompt) 125 | 126 | # Process the prompt 127 | start = time.time() 128 | max_msg_len = 0 129 | 130 | def callback(processed, total_tokens): 131 | current = time.time() 132 | speed = processed / (current - start) 133 | msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)" 134 | nonlocal max_msg_len 135 | max_msg_len = max(max_msg_len, len(msg)) 136 | print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) 137 | 138 | for _ in generate_step( 139 | y, 140 | model, 141 | max_tokens=0, 142 | prompt_cache=cache, 143 | kv_bits=args.kv_bits, 144 | kv_group_size=args.kv_group_size, 145 | quantized_kv_start=args.quantized_kv_start, 146 | prompt_progress_callback=callback, 147 | ): 148 | pass 149 | 150 | print() 151 | print(f"Peak memory: {mx.get_peak_memory() / 1e9:.3f} GB") 152 | 153 | print("Saving...") 154 | metadata = {} 155 | metadata["model"] = args.model 156 | metadata["chat_template"] = json.dumps(tokenizer.chat_template) 157 | metadata["tokenizer_config"] = json.dumps(tokenizer_config) 158 | save_prompt_cache(args.prompt_cache_file, cache, metadata) 159 | 160 | 161 | if __name__ == "__main__": 162 | print( 163 | "Calling `python -m mlx_lm.cache_prompt...` directly is deprecated." 164 | " Use `mlx_lm.cache_prompt...` or `python -m mlx_lm cache_prompt ...` instead." 165 | ) 166 | main() 167 | -------------------------------------------------------------------------------- /mlx_lm/SERVER.md: -------------------------------------------------------------------------------- 1 | # HTTP Model Server 2 | 3 | You use `mlx-lm` to make an HTTP API for generating text with any supported 4 | model. The HTTP API is intended to be similar to the [OpenAI chat 5 | API](https://platform.openai.com/docs/api-reference). 6 | 7 | > [!NOTE] 8 | > The MLX LM server is not recommended for production as it only implements 9 | > basic security checks. 
10 |
11 | Start the server with:
12 |
13 | ```shell
14 | mlx_lm.server --model <path_to_model_or_hf_repo>
15 | ```
16 |
17 | For example:
18 |
19 | ```shell
20 | mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
21 | ```
22 |
23 | This will start a text generation server on port `8080` of the `localhost`
24 | using Mistral 7B instruct. The model will be downloaded from the provided
25 | Hugging Face repo if it is not already in the local cache.
26 |
27 | To see a full list of options run:
28 |
29 | ```shell
30 | mlx_lm.server --help
31 | ```
32 |
33 | You can make a request to the model by running:
34 |
35 | ```shell
36 | curl localhost:8080/v1/chat/completions \
37 | -H "Content-Type: application/json" \
38 | -d '{
39 | "messages": [{"role": "user", "content": "Say this is a test!"}],
40 | "temperature": 0.7
41 | }'
42 | ```
43 |
44 | ### Request Fields
45 |
46 | - `messages`: An array of message objects representing the conversation
47 | history. Each message object should have a role (e.g. user, assistant) and
48 | content (the message text).
49 |
50 | - `role_mapping`: (Optional) A dictionary to customize the role prefixes in
51 | the generated prompt. If not provided, the default mappings are used.
52 |
53 | - `stop`: (Optional) An array of strings or a single string. These are
54 | sequences of tokens on which the generation should stop.
55 |
56 | - `max_tokens`: (Optional) An integer specifying the maximum number of tokens
57 | to generate. Defaults to `512`.
58 |
59 | - `stream`: (Optional) A boolean indicating if the response should be
60 | streamed. If true, responses are sent as they are generated. Defaults to
61 | false.
62 |
63 | - `temperature`: (Optional) A float specifying the sampling temperature.
64 | Defaults to `0.0`.
65 |
66 | - `top_p`: (Optional) A float specifying the nucleus sampling parameter.
67 | Defaults to `1.0`.
68 |
69 | - `top_k`: (Optional) An integer specifying the top-k sampling parameter.
70 | Defaults to `0` (disabled).
71 |
72 | - `min_p`: (Optional) A float specifying the min-p sampling parameter.
73 | Defaults to `0.0` (disabled).
74 |
75 | - `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
76 | Defaults to `1.0`.
77 |
78 | - `repetition_context_size`: (Optional) The size of the context window for
79 | applying repetition penalty. Defaults to `20`.
80 |
81 | - `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
82 | values. Defaults to `None`.
83 |
84 | - `logprobs`: (Optional) An integer specifying the number of top tokens and
85 | corresponding log probabilities to return for each output in the generated
86 | sequence. If set, this can be any value between 1 and 10, inclusive.
87 |
88 | - `model`: (Optional) A string path to a local model or Hugging Face repo id.
89 | If the path is local it must be relative to the directory the server was
90 | started in.
91 |
92 | - `adapters`: (Optional) A string path to low-rank adapters. The path must be
93 | relative to the directory the server was started in.
94 |
95 | - `draft_model`: (Optional) Specifies a smaller model to use for speculative
96 | decoding. Set to `null` to unload.
97 |
98 | - `num_draft_tokens`: (Optional) The number of draft tokens the draft model
99 | should predict at once. Defaults to `3`.
100 |
101 | ### Response Fields
102 |
103 | - `id`: A unique identifier for the chat.
104 |
105 | - `system_fingerprint`: A unique identifier for the system.
106 | 107 | - `object`: Any of "chat.completion", "chat.completion.chunk" (for 108 | streaming), or "text.completion". 109 | 110 | - `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`). 111 | 112 | - `created`: A time-stamp for when the request was processed. 113 | 114 | - `choices`: A list of outputs. Each output is a dictionary containing the fields: 115 | - `index`: The index in the list. 116 | - `logprobs`: A dictionary containing the fields: 117 | - `token_logprobs`: A list of the log probabilities for the generated 118 | tokens. 119 | - `tokens`: A list of the generated token ids. 120 | - `top_logprobs`: A list of lists. Each list contains the `logprobs` 121 | top tokens (if requested) with their corresponding probabilities. 122 | - `finish_reason`: The reason the completion ended. This can be either of 123 | `"stop"` or `"length"`. 124 | - `message`: The text response from the model. 125 | 126 | - `usage`: A dictionary containing the fields: 127 | - `prompt_tokens`: The number of prompt tokens processed. 128 | - `completion_tokens`: The number of tokens generated. 129 | - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. 130 | 131 | ### List Models 132 | 133 | Use the `v1/models` endpoint to list available models: 134 | 135 | ```shell 136 | curl localhost:8080/v1/models -H "Content-Type: application/json" 137 | ``` 138 | 139 | This will return a list of locally available models where each model in the 140 | list contains the following fields: 141 | 142 | - `id`: The Hugging Face repo id. 143 | - `created`: A time-stamp representing the model creation time. 144 | -------------------------------------------------------------------------------- /tests/test_generate.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
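#
# End-to-end generation tests: plain `generate`, logit biasing and custom
# logits processors, speculative decoding with a draft model, and generation
# from precomputed input embeddings (with and without a small prefill step
# size).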
2 | 3 | import unittest 4 | from typing import List 5 | 6 | from mlx_lm.generate import ( 7 | GenerationResponse, 8 | generate, 9 | stream_generate, 10 | ) 11 | from mlx_lm.sample_utils import make_logits_processors, make_sampler 12 | from mlx_lm.utils import load 13 | 14 | 15 | class TestGenerate(unittest.TestCase): 16 | 17 | @classmethod 18 | def setUpClass(cls): 19 | cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" 20 | cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH) 21 | 22 | def test_generate(self): 23 | # Simple test that generation runs 24 | text = generate( 25 | self.model, self.tokenizer, "hello", max_tokens=5, verbose=False 26 | ) 27 | 28 | def test_generate_with_logit_bias(self): 29 | logit_bias = {0: 2000.0, 1: -20.0} 30 | text = generate( 31 | self.model, 32 | self.tokenizer, 33 | "hello", 34 | max_tokens=5, 35 | logits_processors=make_logits_processors(logit_bias), 36 | verbose=False, 37 | ) 38 | self.assertEqual(text, "!!!!!") 39 | 40 | def test_generate_with_processor(self): 41 | init_toks = self.tokenizer.encode("hello") 42 | 43 | all_toks = None 44 | 45 | def logits_processor(toks, logits): 46 | nonlocal all_toks 47 | all_toks = toks 48 | return logits 49 | 50 | generate( 51 | self.model, 52 | self.tokenizer, 53 | "hello", 54 | max_tokens=5, 55 | verbose=False, 56 | logits_processors=[logits_processor], 57 | ) 58 | self.assertEqual(len(all_toks), len(init_toks) + 5) 59 | 60 | def test_stream_generate_speculative(self): 61 | # Use same model as draft model, this is not a speed test 62 | draft_model, _ = load(self.HF_MODEL_PATH) 63 | 64 | results: List[GenerationResponse] = [] 65 | drafted: List[bool] = [] 66 | 67 | # make a determinate sampler 68 | sampler = make_sampler(temp=0.0) 69 | messages = [{"role": "user", "content": "hello"}] 70 | prompt = self.tokenizer.apply_chat_template( 71 | messages, add_generation_prompt=True 72 | ) 73 | 74 | for generation_result in stream_generate( 75 | model=self.model, 76 | tokenizer=self.tokenizer, 77 | prompt=prompt, 78 | max_tokens=5, 79 | draft_model=draft_model, 80 | num_draft_tokens=2, 81 | sampler=sampler, 82 | ): 83 | drafted.append(generation_result.from_draft) 84 | results.append(generation_result) 85 | 86 | self.assertEqual(len(results), 6) 87 | drafted.pop() 88 | # since num_draft_tokens is 2 and draft model is the same, the 89 | # first 2 generations should be drafts, the third should come 90 | # from the target model, and last two should be drafts 91 | self.assertEqual(drafted, [True, True, False, True, True]) 92 | 93 | def test_stream_generate_input_embeddings(self): 94 | sampler = make_sampler(temp=0.0) # determinate sampler 95 | 96 | # get prompt embeddings 97 | messages = [{"role": "user", "content": "Say 'TEST' and nothing else"}] 98 | prompt = self.tokenizer.apply_chat_template( 99 | messages, add_generation_prompt=True 100 | ) 101 | prompt_embeddings = self.model.model.embed_tokens(prompt) 102 | 103 | response = "" 104 | for generation_result in stream_generate( 105 | model=self.model, 106 | tokenizer=self.tokenizer, 107 | prompt=prompt, 108 | max_tokens=5, 109 | sampler=sampler, 110 | input_embeddings=prompt_embeddings, 111 | ): 112 | response += generation_result.text 113 | 114 | self.assertEqual("TEST", response) 115 | 116 | def test_stream_generate_input_embeddings_prefill(self): 117 | sampler = make_sampler(temp=0.0) # determinate sampler 118 | 119 | # get prompt embeddings 120 | messages = [{"role": "user", "content": "Say 'TEST' and nothing else"}] 121 | prompt = 
self.tokenizer.apply_chat_template( 122 | messages, add_generation_prompt=True 123 | ) 124 | prompt_embeddings = self.model.model.embed_tokens(prompt) 125 | 126 | # setup prompt progress callback to track batched prefill 127 | num_prompt_processing_callbacks = 0 128 | 129 | def progress_callback(processed: int, total: int) -> None: 130 | nonlocal num_prompt_processing_callbacks 131 | num_prompt_processing_callbacks += 1 132 | 133 | # generate 134 | prefill_step_size = 5 135 | response = "" 136 | for generation_result in stream_generate( 137 | model=self.model, 138 | tokenizer=self.tokenizer, 139 | prompt=prompt, 140 | max_tokens=5, 141 | sampler=sampler, 142 | input_embeddings=prompt_embeddings, 143 | prefill_step_size=prefill_step_size, 144 | prompt_progress_callback=progress_callback, 145 | ): 146 | response += generation_result.text 147 | 148 | self.assertEqual("TEST", response) 149 | num_embeddings = prompt_embeddings.shape[0] 150 | self.assertEqual( 151 | num_embeddings / prefill_step_size, num_prompt_processing_callbacks 152 | ) 153 | 154 | 155 | if __name__ == "__main__": 156 | unittest.main() 157 | -------------------------------------------------------------------------------- /mlx_lm/LEARNED_QUANTS.md: -------------------------------------------------------------------------------- 1 | # Learned Quantization 2 | 3 | To reduce the quality loss from quantization MLX LM has several options: 4 | 5 | - Distilled Weight Quantization (DWQ) 6 | - Activation-aware Weight Quantization (AWQ)[^1] 7 | - Dynamic quantization 8 | 9 | All methods use calibration data to tune parameters or hyper-parameters of the 10 | model. DWQ fine-tunes non-quantized parameters (including quantization scales 11 | and biases) using the non-quantized model as a teacher. AWQ scales and clips 12 | the weights prior to quantization. Dynamic quantization estimates the 13 | sensitivity of a model's outputs to each layer and uses a higher precision for 14 | layers which have higher sensitivity. 15 | 16 | Dynamic quantization is the fastest to run. DWQ takes longer but typically 17 | yields better results. You can also cascade methods. For example a dynamically 18 | quantized model can be further refined with DWQ. 19 | 20 | To get started, first install the requirements: 21 | 22 | ``` 23 | pip install mlx-lm[quant] 24 | ``` 25 | 26 | ### DWQ 27 | 28 | Use `mlx_lm.dwq` to run DWQ on a given model. For example: 29 | 30 | ```bash 31 | mlx_lm.dwq --model mistralai/Mistral-7B-Instruct-v0.3 32 | ``` 33 | 34 | Some important options, along with their default values are: 35 | 36 | - `--mlx-path mlx_model`: The location to save the DWQ model. 37 | - `--bits 4`: Precision of the quantization. 38 | - `--num-samples 1024`: Number of samples to use. Using more samples can lead to 39 | better results but takes longer. 40 | - `--batch-size 8`: Use a smaller batch size to reduce the memory footprint. 41 | 42 | For a full list of options run: 43 | 44 | ```bash 45 | mlx_lm.dwq --help 46 | ``` 47 | 48 | #### Tips 49 | 50 | - DWQ works best distilling to lower precision, anywhere from 2-bit to 4-bit 51 | models. 52 | - Distilling 16-bit precision to 8-bit and even 6-bit often doesn't work well. 53 | The loss starts out so low that it's difficult to reduce further. 54 | - Decreasing the quantization group size (e.g. `--group-size 32`) doubles the 55 | number of tunable parameters and can work much better. 56 | - If the loss is oscillating and not going down consistently, try reducing the 57 | learning rate. 
If it is decreasing but slowly, try increasing the learning 58 | rate. 59 | - As a rule of thumb, lower precision can benefit from a higher learning rate 60 | since the loss starts out higher. Conversely, higher precision needs a lower 61 | learning rate. 62 | 63 | 64 | #### Memory Use 65 | 66 | A few options to reduce memory use for DWQ: 67 | 68 | - Distill from an 8-bit model instead of a 16-bit model. The 8-bit 69 | models are usually as good as 16-bit precision models. 70 | - Use a shorter maximum sequence length. The default is 2048. Using 71 | `--max-seq-length 512` reduces the memory and still gets good results. 72 | - Use a smaller batch size, e.g. `--batch-size 1` 73 | 74 | ### Dynamic Quantization 75 | 76 | Use `mlx_lm.dynamic_quant` to generate a dynamic quantization of given model. 77 | For example: 78 | 79 | ```bash 80 | mlx_lm.dynamic_quant --model mistralai/Mistral-7B-Instruct-v0.3 81 | ``` 82 | 83 | The script will estimate the sensitivity for each quantizable layer in the 84 | model. It will then quantize the model using higher precision (default 5 bits) 85 | for the more sensitive layers and lower precision (default 4 bits) for the 86 | rest. The script also saves a JSON file with each layer's sensitivities which 87 | saves needing to compute it multiple times to make different precision quants 88 | of the same model. 89 | 90 | Some important options are: 91 | 92 | - `--target-bpw`: The target bits-per-weight. For a given set of quantization 93 | parameters only certain ranges are possible. For example, with the default 94 | parameters a BPW in the range `[4.5, 5.5]` is achievable. 95 | - `--sensitivities`: A path to a precomputed sensitivities file. 96 | - `--low-bits`: The number of bits to use for the less sensitive layers. 97 | - `--high-bits`: The number of bits to use for the more sensitive layers. 98 | 99 | ### AWQ 100 | 101 | Use `mlx_lm.awq` to run AWQ on a given model. For example: 102 | 103 | ```bash 104 | mlx_lm.awq --model mistralai/Mistral-7B-Instruct-v0.3 105 | ``` 106 | 107 | The script can take anywhere form a few minutes to several hours to run 108 | depending on the model size and the number of samples. 109 | 110 | Some important options, along with their default values, are: 111 | 112 | - `--mlx-path mlx_model`: The location to save the AWQ model. 113 | - `--bits 4`: Precision of the quantization. 114 | - `--num-samples 32`: Number of samples to use. Using more samples can lead to 115 | better results but takes longer. 116 | - `--n-grid 10`: The granularity of the AWQ search. A larger grid can lead to 117 | better results but takes longer. 118 | 119 | For a full list of options run: 120 | 121 | ```bash 122 | mlx_lm.awq --help 123 | ``` 124 | 125 | ### Evaluate 126 | 127 | Once the training script finishes, you can evaluate the quality of the model 128 | on downstream tasks using `mlx_lm.evaluate`. For example: 129 | 130 | ```bash 131 | mlx_lm.evaluate \ 132 | --model mlx_model \ 133 | --tasks winogrande boolq arc_challenge arc_easy hellaswag openbookqa piqa social_iqa 134 | ``` 135 | 136 | ### Upload to Hugging Face 137 | 138 | Use `mlx_lm.upload` to upload the quantized model to the Hugging Face Hub. For 139 | example: 140 | 141 | ```bash 142 | mlx_lm.upload \ 143 | --path mlx_model \ 144 | --upload-repo mlx-community/Mistral-7B-Instruct-v0.3-3bit-DWQ 145 | ``` 146 | 147 | [^1]: Refer to the [paper](https://arxiv.org/abs/2306.00978) 148 | and [github repository](https://github.com/mit-han-lab/llm-awq) for more 149 | details. 
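Beyond task evaluation, a quick qualitative check is to load the quantized model directly from Python and sample from it. This is a minimal sketch assuming the default `mlx_model` output path used by the commands above; the prompt is only an illustration:

```python
from mlx_lm import load, generate

# Path produced by --mlx-path above (defaults to mlx_model).
model, tokenizer = load("mlx_model")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Explain the trade-offs of 4-bit quantization."}],
    add_generation_prompt=True,
)
print(generate(model, tokenizer, prompt, max_tokens=128))
```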
150 | -------------------------------------------------------------------------------- /mlx_lm/models/olmo.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import sys 4 | from dataclasses import dataclass 5 | from typing import Any, Optional 6 | 7 | import mlx.core as mx 8 | import mlx.nn as nn 9 | 10 | from .base import BaseModelArgs, create_attention_mask 11 | 12 | try: 13 | import hf_olmo 14 | except ImportError: 15 | print("To run olmo install ai2-olmo: pip install ai2-olmo") 16 | sys.exit(1) 17 | 18 | 19 | @dataclass 20 | class ModelArgs(BaseModelArgs): 21 | model_type: str 22 | d_model: int 23 | n_layers: int 24 | mlp_hidden_size: int 25 | n_heads: int 26 | vocab_size: int 27 | embedding_size: int 28 | rope_theta: float = 10000 29 | rope_traditional: bool = False 30 | mlp_ratio: int = 4 31 | weight_tying: bool = False 32 | 33 | def __post_init__(self): 34 | self.mlp_hidden_size = ( 35 | self.mlp_hidden_size 36 | if self.mlp_hidden_size is not None 37 | else self.mlp_ratio * self.d_model 38 | ) 39 | 40 | 41 | class TransformerBlock(nn.Module): 42 | def __init__(self, args: ModelArgs): 43 | super().__init__() 44 | self.n_heads = args.n_heads 45 | dim = args.d_model 46 | 47 | self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False) 48 | self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False) 49 | 50 | self.att_norm = nn.LayerNorm(dim, affine=False) 51 | self.ff_norm = nn.LayerNorm(dim, affine=False) 52 | 53 | head_dim = dim // self.n_heads 54 | self.scale = head_dim**-0.5 55 | 56 | self.att_proj = nn.Linear(dim, 3 * dim, bias=False) 57 | self.attn_out = nn.Linear(dim, dim, bias=False) 58 | 59 | self.rope = nn.RoPE( 60 | head_dim, 61 | traditional=args.rope_traditional, 62 | base=args.rope_theta, 63 | ) 64 | 65 | self.args = args 66 | 67 | def attend( 68 | self, 69 | x: mx.array, 70 | mask: Optional[mx.array] = None, 71 | cache: Optional[Any] = None, 72 | ) -> mx.array: 73 | B, L, D = x.shape 74 | 75 | queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1) 76 | 77 | # Prepare the queries, keys and values for the attention computation 78 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 79 | keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 80 | values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 81 | 82 | if cache is not None: 83 | queries = self.rope(queries, offset=cache.offset) 84 | keys = self.rope(keys, offset=cache.offset) 85 | keys, values = cache.update_and_fetch(keys, values) 86 | else: 87 | queries = self.rope(queries) 88 | keys = self.rope(keys) 89 | 90 | scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) 91 | if mask is not None: 92 | scores += mask 93 | scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) 94 | output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 95 | return self.attn_out(output) 96 | 97 | def __call__( 98 | self, 99 | x: mx.array, 100 | mask: Optional[mx.array] = None, 101 | cache: Optional[Any] = None, 102 | ) -> mx.array: 103 | r = self.attend(self.att_norm(x), mask, cache) 104 | h = x + r 105 | 106 | x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1) 107 | 108 | out = h + self.ff_out(nn.silu(x2) * x1) 109 | return out 110 | 111 | 112 | class Transformer(nn.Module): 113 | def __init__(self, args: ModelArgs): 114 | super().__init__() 115 | self.n_layers = args.n_layers 116 | self.weight_tying = args.weight_tying 117 | 118 | self.wte = 
nn.Embedding(args.embedding_size, args.d_model) 119 | self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)] 120 | if not self.weight_tying: 121 | self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False) 122 | self.norm = nn.LayerNorm(args.d_model, affine=False) 123 | 124 | def __call__( 125 | self, 126 | inputs: mx.array, 127 | mask: mx.array = None, 128 | cache=None, 129 | ): 130 | h = self.wte(inputs) 131 | 132 | if mask is None: 133 | mask = create_attention_mask(h, cache) 134 | 135 | if cache is None: 136 | cache = [None] * len(self.blocks) 137 | 138 | for block, c in zip(self.blocks, cache): 139 | h = block(h, mask, c) 140 | 141 | h = self.norm(h) 142 | 143 | if self.weight_tying: 144 | return self.wte.as_linear(h), cache 145 | 146 | return self.ff_out(h) 147 | 148 | 149 | class OlmoModel(nn.Module): 150 | def __init__(self, args: ModelArgs): 151 | super().__init__() 152 | self.transformer = Transformer(args) 153 | 154 | def __call__( 155 | self, 156 | inputs: mx.array, 157 | mask: mx.array = None, 158 | cache=None, 159 | ): 160 | return self.transformer(inputs, mask, cache) 161 | 162 | 163 | class Model(nn.Module): 164 | def __init__(self, args: ModelArgs): 165 | super().__init__() 166 | self.model_type = args.model_type 167 | self.model = OlmoModel(args) 168 | self.args = args 169 | 170 | def __call__( 171 | self, 172 | inputs: mx.array, 173 | mask: mx.array = None, 174 | cache=None, 175 | ): 176 | return self.model(inputs, mask, cache) 177 | 178 | @property 179 | def layers(self): 180 | return self.model.transformer.blocks 181 | -------------------------------------------------------------------------------- /mlx_lm/models/bitlinear_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | from mlx.nn.layers.quantized import QuantizedLinear 6 | from mlx.utils import tree_flatten, tree_unflatten 7 | 8 | 9 | def bitnet_quantize(model, quantization_config: dict): 10 | quantize_layers = [] 11 | modules_to_not_convert = quantization_config.get("modules_to_not_convert", []) 12 | invert_weight_scales = ( 13 | quantization_config.get("linear_class", "") != "autobitlinear" 14 | ) 15 | 16 | for name, module in tree_flatten(model.leaf_modules(), is_leaf=nn.Module.is_module): 17 | 18 | # Replace nn.Linear layers, but skip any layer from the `modules_to_not_convert` list 19 | if name not in modules_to_not_convert and isinstance(module, nn.Linear): 20 | old_weight = module.weight 21 | out_features, in_features = old_weight.shape 22 | bias = "bias" in module 23 | new_layer = BitLinear( 24 | in_features, 25 | out_features, 26 | bias=bias, 27 | invert_weight_scales=invert_weight_scales, 28 | ) 29 | quantize_layers.append((name, new_layer)) 30 | if len(quantize_layers) > 0: 31 | model.update_modules(tree_unflatten(quantize_layers)) 32 | return model 33 | 34 | 35 | def make_bitlinear_kernel(): 36 | """ 37 | Custom Metal kernel that performs matrix multiplication directly on 38 | packed weights and scales the output. This eliminates the need to 39 | store unpacked weights in memory. 
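As a pure-Python sketch of the packing this kernel assumes (each uint8 stores four ternary weights in 2-bit fields, least-significant field first, decoded as ((w >> 2*k) & 3) - 1): >>> w = 0b01_00_10_01 # 2-bit fields, LSB first: 1, 2, 0, 1 >>> [((w >> (2 * k)) & 3) - 1 for k in range(4)] [0, 1, -1, 0]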
40 | """ 41 | source = """ 42 | constexpr int M = 4; 43 | constexpr int BLOCK = 32; 44 | 45 | uint tid = thread_position_in_grid.y; 46 | uint in_offset = thread_position_in_grid.x; 47 | 48 | uint batch_idx = tid / (out_features / 4); 49 | uint row_idx = tid % (out_features / 4); 50 | 51 | float sum[4] = {0.0}; 52 | 53 | for (uint i = in_offset * M; i < in_features; i += BLOCK * M) { 54 | float v[M]; 55 | for (int j=0; j<M; j++) { 56 | v[j] = x[batch_idx * in_features + i + j]; 57 | } 58 | 59 | for (int j=0; j<M; j++) { 60 | uint8_t w = packed_weights[row_idx * in_features + i + j]; 61 | sum[0] += v[j] * ((w & 3) - 1); 62 | sum[1] += v[j] * (((w >> 2) & 3) - 1); 63 | sum[2] += v[j] * (((w >> 4) & 3) - 1); 64 | sum[3] += v[j] * (((w >> 6) & 3) - 1); 65 | } 66 | } 67 | 68 | for (int j=0; j<4; j++) { 69 | sum[j] = simd_sum(sum[j]); 70 | } 71 | 72 | // Apply the weight scale, either dividing by it or multiplying with it 73 | if (in_offset == 0) { 74 | float scale = invert_weight_scales ? 1 / weight_scale[0] : weight_scale[0]; 75 | for (int i=0; i<4; i++) { 76 | out[batch_idx * out_features + row_idx + i * (out_features/4)] = static_cast<T>(sum[i] * scale); 77 | } 78 | } 79 | """ 80 | 81 | return mx.fast.metal_kernel( 82 | name="bitlinear_matmul", 83 | input_names=["x", "packed_weights", "weight_scale"], 84 | output_names=["out"], 85 | source=source, 86 | ) 87 | 88 | 89 | _bitlinear_kernel = make_bitlinear_kernel() 90 | 91 | 92 | class BitLinear(nn.Module): 93 | """ 94 | BitLinear module with memory-efficient weight handling. 95 | """ 96 | 97 | def __init__( 98 | self, 99 | in_features, 100 | out_features, 101 | bias=True, 102 | invert_weight_scales=False, 103 | ): 104 | super().__init__() 105 | self.in_features = in_features 106 | self.out_features = out_features 107 | # Calculate packed dimensions - the first dimension gets packed 4:1 108 | # The weights are ternary so can be represented with 2 bits, and they 109 | # are packed in uint8 tensors, hence the number of values per item is 4 110 | packed_out_features = (out_features + 3) // 4 111 | self.weight = mx.zeros((packed_out_features, in_features), dtype=mx.uint8) 112 | 113 | self.invert_weight_scales = invert_weight_scales 114 | self.weight_scale = mx.array([1.0]) 115 | 116 | if bias: 117 | self.bias = mx.zeros((out_features,)) 118 | else: 119 | self.bias = None 120 | 121 | def execute_matmul_kernel(self, x, packed_weights): 122 | original_shape = x.shape 123 | if len(original_shape) > 2: 124 | x = x.reshape(-1, original_shape[-1]) 125 | total_batch_elements, in_features = x.shape 126 | 127 | out_features = self.out_features 128 | 129 | dtype = self.weight_scale.dtype 130 | assert x.dtype == dtype, "Wrong type for input." 131 | out = _bitlinear_kernel( 132 | inputs=[ 133 | x, 134 | packed_weights, 135 | self.weight_scale, 136 | ], 137 | template=[ 138 | ("T", dtype), 139 | ("invert_weight_scales", self.invert_weight_scales), 140 | ("in_features", in_features), 141 | ("out_features", out_features), 142 | ], 143 | grid=(32, total_batch_elements * out_features // 4, 1), 144 | threadgroup=(32, 1, 1), # SIMD width is 32 threads 145 | output_shapes=[(total_batch_elements, out_features)], 146 | output_dtypes=[dtype], 147 | )[0] 148 | 149 | if len(original_shape) > 2: 150 | out = out.reshape(*original_shape[:-1], out_features) 151 | return out 152 | 153 | def __call__(self, x): 154 | y = self.execute_matmul_kernel(x, self.weight) 155 | 156 | if self.bias is not None: 157 | y = mx.add(y, self.bias) 158 | return y 159 | -------------------------------------------------------------------------------- /mlx_lm/models/ernie4_5.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc.
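# A minimal construction sketch (not part of the upstream file; the config values are placeholders): the ModelArgs dataclass below is normally populated from a Hugging Face config dict via the inherited BaseModelArgs.from_dict, which keeps only known fields. # config = {"model_type": "ernie4_5", "hidden_size": 1024, "intermediate_size": 3072, "max_position_embeddings": 4096, "num_attention_heads": 16, "num_key_value_heads": 2, "head_dim": 64, "num_hidden_layers": 4, "rms_norm_eps": 1e-5, "vocab_size": 32000, "rope_theta": 10000.0, "use_bias": False, "tie_word_embeddings": True} # args = ModelArgs.from_dict(config) # model = Model(args)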
2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | from .rope_utils import initialize_rope 11 | 12 | 13 | @dataclass 14 | class ModelArgs(BaseModelArgs): 15 | hidden_size: int 16 | intermediate_size: int 17 | model_type: str 18 | max_position_embeddings: int 19 | num_attention_heads: int 20 | num_key_value_heads: int 21 | head_dim: Optional[int] 22 | num_hidden_layers: int 23 | rms_norm_eps: float 24 | vocab_size: int 25 | rope_theta: float 26 | use_bias: bool 27 | tie_word_embeddings: bool 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, args: ModelArgs): 32 | super().__init__() 33 | 34 | dim = args.hidden_size 35 | self.n_heads = n_heads = args.num_attention_heads 36 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 37 | 38 | self.head_dim = head_dim = args.head_dim or dim // n_heads 39 | self.scale = head_dim**-0.5 40 | 41 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.use_bias) 42 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.use_bias) 43 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.use_bias) 44 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.use_bias) 45 | 46 | self.rope = initialize_rope( 47 | head_dim, 48 | base=args.rope_theta, 49 | traditional=True, 50 | max_position_embeddings=args.max_position_embeddings, 51 | ) 52 | 53 | def __call__( 54 | self, 55 | x: mx.array, 56 | mask: Optional[mx.array] = None, 57 | cache: Optional[Any] = None, 58 | ) -> mx.array: 59 | B, L, D = x.shape 60 | 61 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 62 | 63 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 64 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 65 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 66 | 67 | if cache is not None: 68 | queries = self.rope(queries, offset=cache.offset) 69 | keys = self.rope(keys, offset=cache.offset) 70 | keys, values = cache.update_and_fetch(keys, values) 71 | else: 72 | queries = self.rope(queries) 73 | keys = self.rope(keys) 74 | 75 | output = scaled_dot_product_attention( 76 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 77 | ) 78 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 79 | return self.o_proj(output) 80 | 81 | 82 | class MLP(nn.Module): 83 | def __init__(self, dim, hidden_dim, use_bias=False): 84 | super().__init__() 85 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=use_bias) 86 | self.down_proj = nn.Linear(hidden_dim, dim, bias=use_bias) 87 | self.up_proj = nn.Linear(dim, hidden_dim, bias=use_bias) 88 | 89 | def __call__(self, x) -> mx.array: 90 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 91 | 92 | 93 | class DecoderLayer(nn.Module): 94 | def __init__(self, args: ModelArgs): 95 | super().__init__() 96 | self.self_attn = Attention(args) 97 | self.mlp = MLP(args.hidden_size, args.intermediate_size, args.use_bias) 98 | 99 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 100 | self.post_attention_layernorm = nn.RMSNorm( 101 | args.hidden_size, eps=args.rms_norm_eps 102 | ) 103 | 104 | def __call__( 105 | self, 106 | x: mx.array, 107 | mask: Optional[mx.array] = None, 108 | cache: Optional[Any] = None, 109 | ) -> mx.array: 110 | r = self.self_attn(self.input_layernorm(x), mask, cache) 111 | h = x + r 
112 | r = self.mlp(self.post_attention_layernorm(h)) 113 | return h + r 114 | 115 | 116 | class Ernie45Model(nn.Module): 117 | def __init__(self, args: ModelArgs): 118 | super().__init__() 119 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 120 | self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)] 121 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 122 | 123 | def __call__( 124 | self, 125 | inputs: mx.array, 126 | mask: mx.array = None, 127 | cache=None, 128 | ): 129 | h = self.embed_tokens(inputs) 130 | 131 | if mask is None: 132 | mask = create_attention_mask(h, cache) 133 | 134 | if cache is None: 135 | cache = [None] * len(self.layers) 136 | 137 | for layer, c in zip(self.layers, cache): 138 | h = layer(h, mask, c) 139 | 140 | return self.norm(h) 141 | 142 | 143 | class Model(nn.Module): 144 | def __init__(self, args: ModelArgs): 145 | super().__init__() 146 | self.args = args 147 | self.model_type = args.model_type 148 | self.model = Ernie45Model(args) 149 | if not args.tie_word_embeddings: 150 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 151 | 152 | def __call__( 153 | self, 154 | inputs: mx.array, 155 | mask: mx.array = None, 156 | cache=None, 157 | ): 158 | out = self.model(inputs, mask, cache) 159 | if self.args.tie_word_embeddings: 160 | out = self.model.embed_tokens.as_linear(out) 161 | else: 162 | out = self.lm_head(out) 163 | return out 164 | 165 | @property 166 | def layers(self): 167 | return self.model.layers 168 | -------------------------------------------------------------------------------- /mlx_lm/models/starcoder2.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | 11 | 12 | @dataclass 13 | class ModelArgs(BaseModelArgs): 14 | model_type: str 15 | hidden_size: int 16 | num_hidden_layers: int 17 | intermediate_size: int 18 | num_attention_heads: int 19 | num_key_value_heads: int 20 | norm_epsilon: float = 1e-5 21 | vocab_size: int = 49152 22 | rope_theta: float = 100000 23 | tie_word_embeddings: bool = True 24 | 25 | 26 | class Attention(nn.Module): 27 | def __init__(self, args: ModelArgs): 28 | super().__init__() 29 | self.args = args 30 | 31 | dim = args.hidden_size 32 | self.n_heads = n_heads = args.num_attention_heads 33 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 34 | 35 | head_dim = args.hidden_size // args.num_attention_heads 36 | self.scale = head_dim**-0.5 37 | 38 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) 39 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 40 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 41 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True) 42 | self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta) 43 | 44 | def __call__( 45 | self, 46 | x: mx.array, 47 | mask: Optional[mx.array] = None, 48 | cache: Optional[Any] = None, 49 | ) -> mx.array: 50 | B, L, D = x.shape 51 | 52 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 53 | 54 | # Prepare the queries, keys and values for the attention computation 55 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 56 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 
2, 1, 3) 57 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 58 | 59 | if cache is not None: 60 | queries = self.rope(queries, offset=cache.offset) 61 | keys = self.rope(keys, offset=cache.offset) 62 | keys, values = cache.update_and_fetch(keys, values) 63 | else: 64 | queries = self.rope(queries) 65 | keys = self.rope(keys) 66 | 67 | output = scaled_dot_product_attention( 68 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 69 | ) 70 | 71 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 72 | return self.o_proj(output) 73 | 74 | 75 | class MLP(nn.Module): 76 | def __init__(self, dim, hidden_dim): 77 | super().__init__() 78 | self.c_fc = nn.Linear(dim, hidden_dim, bias=True) 79 | self.c_proj = nn.Linear(hidden_dim, dim, bias=True) 80 | 81 | def __call__(self, x): 82 | return self.c_proj(nn.gelu(self.c_fc(x))) 83 | 84 | 85 | class TransformerBlock(nn.Module): 86 | def __init__(self, args: ModelArgs): 87 | super().__init__() 88 | self.hidden_size = args.hidden_size 89 | self.n_heads = args.num_attention_heads 90 | 91 | self.self_attn = Attention(args) 92 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 93 | self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) 94 | self.post_attention_layernorm = nn.LayerNorm( 95 | args.hidden_size, eps=args.norm_epsilon 96 | ) 97 | self.args = args 98 | 99 | def __call__( 100 | self, 101 | x: mx.array, 102 | mask: Optional[mx.array] = None, 103 | cache: Optional[Any] = None, 104 | ) -> mx.array: 105 | r = self.self_attn(self.input_layernorm(x), mask, cache) 106 | h = x + r 107 | r = self.mlp(self.post_attention_layernorm(h)) 108 | out = h + r 109 | return out 110 | 111 | 112 | class Starcoder2Model(nn.Module): 113 | def __init__(self, args: ModelArgs): 114 | super().__init__() 115 | self.args = args 116 | self.vocab_size = args.vocab_size 117 | self.num_hidden_layers = args.num_hidden_layers 118 | assert self.vocab_size > 0 119 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 120 | self.layers = [ 121 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 122 | ] 123 | self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) 124 | 125 | def __call__( 126 | self, 127 | inputs: mx.array, 128 | mask: mx.array = None, 129 | cache=None, 130 | ): 131 | h = self.embed_tokens(inputs) 132 | 133 | if mask is None: 134 | mask = create_attention_mask(h, cache) 135 | 136 | if cache is None: 137 | cache = [None] * len(self.layers) 138 | 139 | for layer, c in zip(self.layers, cache): 140 | h = layer(h, mask, c) 141 | 142 | return self.norm(h) 143 | 144 | 145 | class Model(nn.Module): 146 | def __init__(self, args: ModelArgs): 147 | super().__init__() 148 | self.args = args 149 | self.model_type = args.model_type 150 | self.model = Starcoder2Model(args) 151 | if not args.tie_word_embeddings: 152 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 153 | 154 | def __call__( 155 | self, 156 | inputs: mx.array, 157 | mask: mx.array = None, 158 | cache=None, 159 | ): 160 | out = self.model(inputs, mask, cache) 161 | if self.args.tie_word_embeddings: 162 | out = self.model.embed_tokens.as_linear(out) 163 | else: 164 | out = self.lm_head(out) 165 | return out 166 | 167 | @property 168 | def layers(self): 169 | return self.model.layers 170 | -------------------------------------------------------------------------------- /mlx_lm/models/exaone.py: -------------------------------------------------------------------------------- 1 | # 
Copyright © 2024 Apple Inc. 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, Optional, Union 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | from .rope_utils import initialize_rope 11 | 12 | 13 | @dataclass 14 | class ModelArgs(BaseModelArgs): 15 | model_type: str 16 | hidden_size: int 17 | num_layers: int 18 | intermediate_size: int 19 | num_attention_heads: int 20 | vocab_size: int 21 | rope_theta: float 22 | layer_norm_epsilon: float 23 | num_key_value_heads: int 24 | head_dim: Optional[int] = None 25 | max_position_embeddings: Optional[int] = None 26 | rope_traditional: bool = False 27 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 28 | tie_word_embeddings: bool = True 29 | attention_bias: bool = False 30 | mlp_bias: bool = False 31 | 32 | 33 | class AttentionModule(nn.Module): 34 | def __init__(self, args: ModelArgs): 35 | super().__init__() 36 | dim = args.hidden_size 37 | self.n_heads = n_heads = args.num_attention_heads 38 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 39 | self.head_dim = head_dim = args.head_dim or (dim // n_heads) 40 | self.scale = head_dim**-0.5 41 | 42 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) 43 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) 44 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) 45 | self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) 46 | 47 | self.rope = initialize_rope( 48 | self.head_dim, 49 | args.rope_theta, 50 | args.rope_traditional, 51 | args.rope_scaling, 52 | args.max_position_embeddings, 53 | ) 54 | 55 | def __call__( 56 | self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None 57 | ) -> mx.array: 58 | B, L, D = x.shape 59 | q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 60 | k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 61 | v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 62 | 63 | if cache is not None: 64 | q = self.rope(q, offset=cache.offset) 65 | k = self.rope(k, offset=cache.offset) 66 | k, v = cache.update_and_fetch(k, v) 67 | else: 68 | q = self.rope(q) 69 | k = self.rope(k) 70 | 71 | out = scaled_dot_product_attention( 72 | q, k, v, cache=cache, scale=self.scale, mask=mask 73 | ) 74 | out = out.transpose(0, 2, 1, 3).reshape(B, L, D) 75 | return self.out_proj(out) 76 | 77 | 78 | class Attention(nn.Module): 79 | def __init__(self, args: ModelArgs): 80 | super().__init__() 81 | self.attention = AttentionModule(args) 82 | 83 | 84 | class MLP(nn.Module): 85 | def __init__(self, args: ModelArgs): 86 | super().__init__() 87 | dim = args.hidden_size 88 | hidden_dim = args.intermediate_size 89 | self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) 90 | self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) 91 | self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias) 92 | 93 | def __call__(self, x: mx.array) -> mx.array: 94 | return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x)) 95 | 96 | 97 | class TransformerBlock(nn.Module): 98 | def __init__(self, args: ModelArgs): 99 | super().__init__() 100 | self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) 101 | self.attn = Attention(args) 102 | self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) 103 | self.mlp = MLP(args) 104 | 105 | def __call__( 
106 | self, 107 | x: mx.array, 108 | mask: Optional[mx.array] = None, 109 | cache: Optional[Any] = None, 110 | ) -> mx.array: 111 | h = x + self.attn.attention(self.ln_1(x), mask, cache) 112 | out = h + self.mlp(self.ln_2(h)) 113 | return out 114 | 115 | 116 | class ExaoneModel(nn.Module): 117 | def __init__(self, args: ModelArgs): 118 | super().__init__() 119 | self.wte = nn.Embedding(args.vocab_size, args.hidden_size) 120 | self.h = [TransformerBlock(args) for _ in range(args.num_layers)] 121 | self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) 122 | 123 | def __call__( 124 | self, 125 | inputs: mx.array, 126 | mask: mx.array = None, 127 | cache=None, 128 | ): 129 | h = self.wte(inputs) 130 | if mask is None: 131 | mask = create_attention_mask(h, cache) 132 | 133 | if cache is None: 134 | cache = [None] * len(self.h) 135 | 136 | for layer, c in zip(self.h, cache): 137 | h = layer(h, mask, cache=c) 138 | 139 | return self.ln_f(h) 140 | 141 | 142 | class Model(nn.Module): 143 | def __init__(self, args: ModelArgs): 144 | super().__init__() 145 | self.args = args 146 | self.model_type = args.model_type 147 | self.transformer = ExaoneModel(args) 148 | if not args.tie_word_embeddings: 149 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 150 | 151 | def __call__( 152 | self, 153 | inputs: mx.array, 154 | mask: mx.array = None, 155 | cache=None, 156 | ): 157 | out = self.transformer(inputs, mask, cache) 158 | if self.args.tie_word_embeddings: 159 | out = self.transformer.wte.as_linear(out) 160 | else: 161 | out = self.lm_head(out) 162 | return out 163 | 164 | @property 165 | def layers(self): 166 | return self.transformer.h 167 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series of 86 | actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. 
This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or permanent 93 | ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within the 113 | community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.1, available at 119 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 126 | [https://www.contributor-covenant.org/translations][translations]. 127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 130 | [Mozilla CoC]: https://github.com/mozilla/diversity 131 | [FAQ]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /mlx_lm/models/gemma.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
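# A small numerical check (a sketch, not upstream code) of the zero-centered RMSNorm convention used below, where the effective scale is (1 + weight) and weight == 0 gives plain RMS normalization. # import mlx.core as mx # x = mx.random.normal((2, 8)) # w = mx.zeros((8,)) # ref = x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + 1e-6) # assert mx.allclose(mx.fast.rms_norm(x, 1.0 + w, 1e-6), ref)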
2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | 11 | 12 | @dataclass 13 | class ModelArgs(BaseModelArgs): 14 | model_type: str 15 | hidden_size: int 16 | num_hidden_layers: int 17 | intermediate_size: int 18 | num_attention_heads: int 19 | head_dim: int 20 | rms_norm_eps: float 21 | vocab_size: int 22 | num_key_value_heads: int 23 | rope_theta: float = 10000 24 | rope_traditional: bool = False 25 | 26 | 27 | class RMSNorm(nn.Module): 28 | def __init__(self, dims: int, eps: float = 1e-5): 29 | super().__init__() 30 | self.weight = mx.ones((dims,)) 31 | self.eps = eps 32 | 33 | def __call__(self, x): 34 | return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps) 35 | 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, args: ModelArgs): 39 | super().__init__() 40 | 41 | dim = args.hidden_size 42 | self.n_heads = n_heads = args.num_attention_heads 43 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 44 | self.head_dim = head_dim = args.head_dim 45 | 46 | self.scale = head_dim**-0.5 47 | 48 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 49 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 50 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 51 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 52 | 53 | self.rope = nn.RoPE( 54 | head_dim, 55 | traditional=args.rope_traditional, 56 | base=args.rope_theta, 57 | ) 58 | 59 | def __call__( 60 | self, 61 | x: mx.array, 62 | mask: Optional[mx.array] = None, 63 | cache: Optional[Any] = None, 64 | ) -> mx.array: 65 | B, L, D = x.shape 66 | 67 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 68 | 69 | # Prepare the queries, keys and values for the attention computation 70 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 71 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 72 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 73 | 74 | if cache is not None: 75 | queries = self.rope(queries, offset=cache.offset) 76 | keys = self.rope(keys, offset=cache.offset) 77 | keys, values = cache.update_and_fetch(keys, values) 78 | else: 79 | queries = self.rope(queries) 80 | keys = self.rope(keys) 81 | 82 | output = scaled_dot_product_attention( 83 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 84 | ) 85 | 86 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 87 | return self.o_proj(output) 88 | 89 | 90 | class MLP(nn.Module): 91 | def __init__(self, dim, hidden_dim): 92 | super().__init__() 93 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 94 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 95 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 96 | 97 | def __call__(self, x) -> mx.array: 98 | return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) 99 | 100 | 101 | class TransformerBlock(nn.Module): 102 | def __init__(self, args: ModelArgs): 103 | super().__init__() 104 | self.num_attention_heads = args.num_attention_heads 105 | self.hidden_size = args.hidden_size 106 | self.self_attn = Attention(args) 107 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 108 | self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 109 | self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 110 | self.args = args 111 | 112 
| def __call__( 113 | self, 114 | x: mx.array, 115 | mask: Optional[mx.array] = None, 116 | cache: Optional[Any] = None, 117 | ) -> mx.array: 118 | r = self.self_attn(self.input_layernorm(x), mask, cache) 119 | h = x + r 120 | r = self.mlp(self.post_attention_layernorm(h)) 121 | out = h + r 122 | return out 123 | 124 | 125 | class GemmaModel(nn.Module): 126 | def __init__(self, args: ModelArgs): 127 | super().__init__() 128 | self.args = args 129 | self.vocab_size = args.vocab_size 130 | self.num_hidden_layers = args.num_hidden_layers 131 | assert self.vocab_size > 0 132 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 133 | self.layers = [ 134 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 135 | ] 136 | self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 137 | 138 | def __call__( 139 | self, 140 | inputs: mx.array, 141 | mask: mx.array = None, 142 | cache=None, 143 | ): 144 | h = self.embed_tokens(inputs) 145 | h = h * (self.args.hidden_size**0.5) 146 | 147 | if mask is None: 148 | mask = create_attention_mask(h, cache) 149 | 150 | if cache is None: 151 | cache = [None] * len(self.layers) 152 | 153 | for layer, c in zip(self.layers, cache): 154 | h = layer(h, mask, c) 155 | 156 | return self.norm(h) 157 | 158 | 159 | class Model(nn.Module): 160 | def __init__(self, args: ModelArgs): 161 | super().__init__() 162 | self.model_type = args.model_type 163 | self.model = GemmaModel(args) 164 | self.args = args 165 | 166 | def __call__( 167 | self, 168 | inputs: mx.array, 169 | mask: mx.array = None, 170 | cache=None, 171 | ): 172 | out = self.model(inputs, mask, cache) 173 | out = self.model.embed_tokens.as_linear(out) 174 | return out 175 | 176 | @property 177 | def layers(self): 178 | return self.model.layers 179 | -------------------------------------------------------------------------------- /mlx_lm/models/phi.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import math 4 | from dataclasses import dataclass 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | 11 | 12 | @dataclass 13 | class ModelArgs(BaseModelArgs): 14 | model_type: str = "phi" 15 | max_position_embeddings: int = 2048 16 | vocab_size: int = 51200 17 | hidden_size: int = 2560 18 | num_attention_heads: int = 32 19 | num_hidden_layers: int = 32 20 | num_key_value_heads: int = 32 21 | partial_rotary_factor: float = 0.4 22 | intermediate_size: int = 10240 23 | layer_norm_eps: float = 1e-5 24 | rope_theta: float = 10000.0 25 | 26 | def __post_init__(self): 27 | if self.num_key_value_heads is None: 28 | self.num_key_value_heads = self.num_attention_heads 29 | 30 | 31 | class PhiAttention(nn.Module): 32 | def __init__(self, config: ModelArgs): 33 | super().__init__() 34 | 35 | self.hidden_size = config.hidden_size 36 | self.num_heads = config.num_attention_heads 37 | self.head_dim = self.hidden_size // self.num_heads 38 | self.num_key_value_heads = config.num_key_value_heads 39 | self.repeats = self.num_heads // self.num_key_value_heads 40 | self.rope_theta = config.rope_theta 41 | self.partial_rotary_factor = config.partial_rotary_factor 42 | 43 | if (self.head_dim * self.num_heads) != self.hidden_size: 44 | raise ValueError( 45 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 46 | f" and `num_heads`: {self.num_heads})." 
47 | ) 48 | 49 | self.q_proj = nn.Linear( 50 | self.hidden_size, self.num_heads * self.head_dim, bias=True 51 | ) 52 | self.k_proj = nn.Linear( 53 | self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True 54 | ) 55 | self.v_proj = nn.Linear( 56 | self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True 57 | ) 58 | self.dense = nn.Linear( 59 | self.num_heads * self.head_dim, self.hidden_size, bias=True 60 | ) 61 | 62 | self.rope = nn.RoPE( 63 | int(self.partial_rotary_factor * self.head_dim), 64 | traditional=False, 65 | base=self.rope_theta, 66 | ) 67 | 68 | def __call__(self, x, mask=None, cache=None): 69 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 70 | 71 | # Extract some shapes 72 | B, L, D = queries.shape 73 | n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads 74 | 75 | # Prepare the queries, keys and values for the attention computation 76 | queries = queries.reshape( 77 | B, 78 | L, 79 | n_heads, 80 | -1, 81 | ).moveaxis(1, 2) 82 | keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2) 83 | values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2) 84 | 85 | # Add RoPE to the queries and keys and combine them with the cache 86 | if cache is not None: 87 | queries = self.rope(queries, offset=cache.offset) 88 | keys = self.rope(keys, offset=cache.offset) 89 | keys, values = cache.update_and_fetch(keys, values) 90 | else: 91 | queries = self.rope(queries) 92 | keys = self.rope(keys) 93 | 94 | scale = math.sqrt(1 / queries.shape[-1]) 95 | output = scaled_dot_product_attention( 96 | queries.astype(mx.float32), 97 | keys, 98 | values, 99 | cache=cache, 100 | scale=scale, 101 | mask=mask, 102 | ).astype(values.dtype) 103 | 104 | output = output.moveaxis(2, 1).reshape(B, L, -1) 105 | 106 | return self.dense(output) 107 | 108 | 109 | class PhiMLP(nn.Module): 110 | def __init__(self, config: ModelArgs): 111 | super().__init__() 112 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 113 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 114 | 115 | def __call__(self, x) -> mx.array: 116 | return self.fc2(nn.gelu_approx(self.fc1(x))) 117 | 118 | 119 | class PhiDecoderLayer(nn.Module): 120 | def __init__(self, config: ModelArgs): 121 | super().__init__() 122 | self.self_attn = PhiAttention(config=config) 123 | self.input_layernorm = nn.LayerNorm( 124 | config.hidden_size, eps=config.layer_norm_eps 125 | ) 126 | self.mlp = PhiMLP(config) 127 | 128 | def __call__(self, x, mask, cache): 129 | h = self.input_layernorm(x) 130 | attn_h = self.self_attn(h, mask, cache) 131 | ff_h = self.mlp(h) 132 | return attn_h + ff_h + x 133 | 134 | 135 | class PhiModel(nn.Module): 136 | def __init__(self, config: ModelArgs): 137 | super().__init__() 138 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 139 | self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)] 140 | self.final_layernorm = nn.LayerNorm( 141 | config.hidden_size, eps=config.layer_norm_eps 142 | ) 143 | 144 | def __call__(self, x, mask, cache): 145 | x = self.embed_tokens(x) 146 | 147 | if mask is None: 148 | mask = create_attention_mask(x, cache) 149 | 150 | if cache is None: 151 | cache = [None] * len(self.layers) 152 | 153 | for layer, c in zip(self.layers, cache): 154 | x = layer(x, mask, c) 155 | return self.final_layernorm(x) 156 | 157 | 158 | class Model(nn.Module): 159 | def __init__(self, config: ModelArgs): 160 | super().__init__() 161 | self.model_type = config.model_type 162 | 
self.model = PhiModel(config) 163 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) 164 | self.args = config 165 | 166 | def __call__( 167 | self, 168 | x: mx.array, 169 | mask: mx.array = None, 170 | cache=None, 171 | ) -> mx.array: 172 | y = self.model(x, mask, cache) 173 | return self.lm_head(y) 174 | 175 | @property 176 | def layers(self): 177 | return self.model.layers 178 | -------------------------------------------------------------------------------- /mlx_lm/models/gpt_bigcode.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | import numpy as np 9 | 10 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 11 | 12 | 13 | @dataclass 14 | class ModelArgs(BaseModelArgs): 15 | model_type: str 16 | n_embd: int 17 | n_layer: int 18 | n_inner: int 19 | n_head: int 20 | n_positions: int 21 | layer_norm_epsilon: float 22 | vocab_size: int 23 | num_key_value_heads: int = None 24 | multi_query: bool = True 25 | attention_bias: bool = True 26 | mlp_bias: bool = True 27 | tie_word_embeddings: bool = True 28 | 29 | def __post_init__(self): 30 | if self.num_key_value_heads is None: 31 | self.num_key_value_heads = 1 if self.multi_query else self.n_head 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, args: ModelArgs): 36 | super().__init__() 37 | 38 | self.dim = dim = args.n_embd 39 | self.n_heads = n_heads = args.n_head 40 | self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head 41 | 42 | self.head_dim = head_dim = dim // n_heads 43 | 44 | self.kv_dim = n_kv_heads * head_dim 45 | 46 | self.scale = head_dim**-0.5 47 | 48 | if hasattr(args, "attention_bias"): 49 | attention_bias = args.attention_bias 50 | else: 51 | attention_bias = False 52 | 53 | self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias) 54 | self.c_proj = nn.Linear(dim, dim, bias=attention_bias) 55 | 56 | def __call__( 57 | self, 58 | x: mx.array, 59 | mask: Optional[mx.array] = None, 60 | cache: Optional[Any] = None, 61 | ) -> mx.array: 62 | B, L, D = x.shape 63 | 64 | qkv = self.c_attn(x) 65 | queries, keys, values = mx.split( 66 | qkv, [self.dim, self.dim + self.kv_dim], axis=-1 67 | ) 68 | 69 | # Prepare the queries, keys and values for the attention computation 70 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 71 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 72 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 73 | 74 | if cache is not None: 75 | keys, values = cache.update_and_fetch(keys, values) 76 | 77 | output = scaled_dot_product_attention( 78 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 79 | ) 80 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 81 | return self.c_proj(output) 82 | 83 | 84 | class MLP(nn.Module): 85 | def __init__(self, args: ModelArgs): 86 | super().__init__() 87 | 88 | dim = args.n_embd 89 | hidden_dim = args.n_inner 90 | if hasattr(args, "mlp_bias"): 91 | mlp_bias = args.mlp_bias 92 | else: 93 | mlp_bias = False 94 | 95 | self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias) 96 | self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) 97 | 98 | def __call__(self, x) -> mx.array: 99 | return self.c_proj(nn.gelu(self.c_fc(x))) 100 | 101 | 102 | class TransformerBlock(nn.Module): 103 | 
def __init__(self, args: ModelArgs): 104 | super().__init__() 105 | self.n_head = args.n_head 106 | self.n_embd = args.n_embd 107 | self.attn = Attention(args) 108 | self.mlp = MLP(args) 109 | self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon) 110 | self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon) 111 | self.args = args 112 | 113 | def __call__( 114 | self, 115 | x: mx.array, 116 | mask: Optional[mx.array] = None, 117 | cache: Optional[Any] = None, 118 | ) -> mx.array: 119 | r = self.attn(self.ln_1(x), mask, cache) 120 | h = x + r 121 | r = self.mlp(self.ln_2(h)) 122 | out = h + r 123 | return out 124 | 125 | 126 | class GPTBigCodeModel(nn.Module): 127 | def __init__(self, args: ModelArgs): 128 | super().__init__() 129 | self.args = args 130 | self.vocab_size = args.vocab_size 131 | assert self.vocab_size > 0 132 | self.wte = nn.Embedding(args.vocab_size, args.n_embd) 133 | self.wpe = nn.Embedding(args.n_positions, args.n_embd) 134 | self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)] 135 | self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon) 136 | 137 | def __call__( 138 | self, 139 | inputs: mx.array, 140 | mask: mx.array = None, 141 | cache=None, 142 | ): 143 | B, L = inputs.shape 144 | 145 | hidden_states = self.wte(inputs) 146 | 147 | # Only build a causal mask when none was provided 148 | if mask is None and hidden_states.shape[1] > 1: 149 | mask = create_attention_mask(hidden_states, cache) 150 | 151 | if cache is None: 152 | cache = [None] * len(self.h) 153 | position_ids = mx.array(np.arange(L)) 154 | else: 155 | position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L)) 156 | 157 | hidden_states += self.wpe(position_ids) 158 | 159 | for layer, c in zip(self.h, cache): 160 | hidden_states = layer(hidden_states, mask, cache=c) 161 | 162 | return self.ln_f(hidden_states) 163 | 164 | 165 | class Model(nn.Module): 166 | def __init__(self, args: ModelArgs): 167 | super().__init__() 168 | self.args = args 169 | self.model_type = args.model_type 170 | self.transformer = GPTBigCodeModel(args) 171 | if not args.tie_word_embeddings: 172 | self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False) 173 | 174 | def __call__( 175 | self, 176 | inputs: mx.array, 177 | mask: mx.array = None, 178 | cache=None, 179 | ): 180 | out = self.transformer(inputs, mask, cache) 181 | if self.args.tie_word_embeddings: 182 | out = self.transformer.wte.as_linear(out) 183 | else: 184 | out = self.lm_head(out) 185 | return out 186 | 187 | @property 188 | def layers(self): 189 | return self.transformer.h 190 | -------------------------------------------------------------------------------- /mlx_lm/models/glm4.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc.
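# A shape sketch (made-up sizes, not upstream code) of the fused projection used by Glm4MLP below: the gate and up weights are stacked along the output axis in a single gate_up_proj, and mx.split recovers the two halves. # import mlx.core as mx # x = mx.random.normal((1, 4, 16)) # w = mx.random.normal((2 * 32, 16)) # stacked [gate; up] weights # gate, up = mx.split(x @ w.T, 2, axis=-1) # assert gate.shape[-1] == up.shape[-1] == 32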
2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | 11 | 12 | @dataclass 13 | class ModelArgs(BaseModelArgs): 14 | model_type: str 15 | hidden_size: int 16 | num_hidden_layers: int 17 | intermediate_size: int 18 | num_attention_heads: int 19 | attention_bias: bool 20 | head_dim: int 21 | rms_norm_eps: float 22 | vocab_size: int 23 | num_key_value_heads: int 24 | partial_rotary_factor: float 25 | rope_theta: float 26 | rope_traditional: bool = True 27 | max_position_embeddings: int = 32768 28 | 29 | 30 | class Glm4MLP(nn.Module): 31 | def __init__(self, args: ModelArgs): 32 | super().__init__() 33 | self.gate_up_proj = nn.Linear( 34 | args.hidden_size, 2 * args.intermediate_size, bias=False 35 | ) 36 | self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) 37 | 38 | def __call__(self, x) -> mx.array: 39 | x = self.gate_up_proj(x) 40 | gate, up_states = mx.split(x, 2, axis=-1) 41 | return self.down_proj(nn.silu(gate) * up_states) 42 | 43 | 44 | class Glm4Attention(nn.Module): 45 | def __init__(self, args: ModelArgs): 46 | super().__init__() 47 | self.head_dim = getattr( 48 | args, "head_dim", args.hidden_size // args.num_attention_heads 49 | ) 50 | self.n_heads = args.num_attention_heads 51 | self.n_kv_heads = args.num_key_value_heads 52 | self.scale = self.head_dim**-0.5 53 | 54 | self.q_proj = nn.Linear( 55 | args.hidden_size, 56 | args.num_attention_heads * self.head_dim, 57 | bias=args.attention_bias, 58 | ) 59 | self.k_proj = nn.Linear( 60 | args.hidden_size, 61 | args.num_key_value_heads * self.head_dim, 62 | bias=args.attention_bias, 63 | ) 64 | self.v_proj = nn.Linear( 65 | args.hidden_size, 66 | args.num_key_value_heads * self.head_dim, 67 | bias=args.attention_bias, 68 | ) 69 | self.o_proj = nn.Linear( 70 | args.num_attention_heads * self.head_dim, args.hidden_size, bias=False 71 | ) 72 | 73 | self.rope = nn.RoPE( 74 | dims=int(self.head_dim * args.partial_rotary_factor), 75 | base=args.rope_theta, 76 | traditional=args.rope_traditional, 77 | ) 78 | 79 | def __call__( 80 | self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None 81 | ) -> mx.array: 82 | B, L, D = x.shape 83 | 84 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 85 | 86 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 87 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 88 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 89 | 90 | if cache is not None: 91 | queries = self.rope(queries, offset=cache.offset) 92 | keys = self.rope(keys, offset=cache.offset) 93 | keys, values = cache.update_and_fetch(keys, values) 94 | else: 95 | queries = self.rope(queries) 96 | keys = self.rope(keys) 97 | 98 | output = scaled_dot_product_attention( 99 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 100 | ) 101 | 102 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 103 | return self.o_proj(output) 104 | 105 | 106 | class Glm4DecoderLayer(nn.Module): 107 | def __init__(self, args: ModelArgs): 108 | super().__init__() 109 | self.self_attn = Glm4Attention(args=args) 110 | 111 | self.mlp = Glm4MLP(args) 112 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 113 | self.post_attention_layernorm = nn.RMSNorm( 114 | args.hidden_size, eps=args.rms_norm_eps 115 | ) 116 | 
self.post_self_attn_layernorm = nn.RMSNorm( 117 | args.hidden_size, eps=args.rms_norm_eps 118 | ) 119 | self.post_mlp_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 120 | 121 | def __call__( 122 | self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None 123 | ) -> mx.array: 124 | x = x + self.post_self_attn_layernorm( 125 | self.self_attn(self.input_layernorm(x), mask, cache) 126 | ) 127 | residual = x 128 | x = ( 129 | self.post_mlp_layernorm(self.mlp(self.post_attention_layernorm(x))) 130 | + residual 131 | ) 132 | return x 133 | 134 | 135 | class Glm4Model(nn.Module): 136 | def __init__(self, args: ModelArgs): 137 | super().__init__() 138 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 139 | self.layers = [ 140 | Glm4DecoderLayer(args=args) for _ in range(args.num_hidden_layers) 141 | ] 142 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 143 | 144 | def __call__( 145 | self, 146 | inputs: mx.array, 147 | mask: Optional[mx.array] = None, 148 | cache: Optional[Any] = None, 149 | ): 150 | h = self.embed_tokens(inputs) 151 | 152 | if mask is None: 153 | mask = create_attention_mask(h, cache) 154 | 155 | if cache is None: 156 | cache = [None] * len(self.layers) 157 | 158 | for layer, c in zip(self.layers, cache): 159 | h = layer(h, mask, cache=c) 160 | 161 | return self.norm(h) 162 | 163 | 164 | class Model(nn.Module): 165 | def __init__(self, args: ModelArgs): 166 | super().__init__() 167 | self.args = args 168 | self.model_type = args.model_type 169 | self.model = Glm4Model(args) 170 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 171 | 172 | def __call__( 173 | self, 174 | inputs: mx.array, 175 | mask: Optional[mx.array] = None, 176 | cache: Optional[Any] = None, 177 | ): 178 | out = self.model(inputs, mask, cache) 179 | return self.lm_head(out) 180 | 181 | @property 182 | def layers(self): 183 | return self.model.layers 184 | -------------------------------------------------------------------------------- /mlx_lm/models/helium.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2025 Apple Inc. 
2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | 11 | 12 | @dataclass 13 | class ModelArgs(BaseModelArgs): 14 | hidden_size: int 15 | num_hidden_layers: int 16 | intermediate_size: int 17 | num_attention_heads: int 18 | num_key_value_heads: int 19 | rms_norm_eps: float 20 | vocab_size: int 21 | attention_bias: bool 22 | head_dim: int 23 | max_position_embeddings: int 24 | mlp_bias: bool 25 | model_type: str 26 | rope_theta: float 27 | tie_word_embeddings: bool 28 | 29 | 30 | class HeliumAttention(nn.Module): 31 | def __init__(self, args: ModelArgs): 32 | super().__init__() 33 | 34 | dim = args.hidden_size 35 | self.n_heads = n_heads = args.num_attention_heads 36 | assert args.num_key_value_heads is not None 37 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 38 | 39 | head_dim = args.hidden_size // n_heads 40 | self.scale = head_dim**-0.5 41 | 42 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) 43 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) 44 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) 45 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 46 | self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) 47 | 48 | def __call__( 49 | self, 50 | x: mx.array, 51 | mask: Optional[mx.array] = None, 52 | cache: Optional[Any] = None, 53 | ) -> mx.array: 54 | B, L, D = x.shape 55 | 56 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 57 | 58 | # Prepare the queries, keys and values for the attention computation 59 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 60 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 61 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 62 | 63 | if cache is not None: 64 | queries = self.rope(queries, offset=cache.offset) 65 | keys = self.rope(keys, offset=cache.offset) 66 | keys, values = cache.update_and_fetch(keys, values) 67 | else: 68 | queries = self.rope(queries) 69 | keys = self.rope(keys) 70 | 71 | output = scaled_dot_product_attention( 72 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 73 | ) 74 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 75 | return self.o_proj(output) 76 | 77 | 78 | class HeliumMLP(nn.Module): 79 | def __init__(self, args: ModelArgs): 80 | super().__init__() 81 | self.hidden_size = args.hidden_size 82 | self.intermediate_size = args.intermediate_size 83 | 84 | self.gate_proj = nn.Linear( 85 | self.hidden_size, self.intermediate_size, bias=args.mlp_bias 86 | ) 87 | self.up_proj = nn.Linear( 88 | self.hidden_size, self.intermediate_size, bias=args.mlp_bias 89 | ) 90 | self.down_proj = nn.Linear( 91 | self.intermediate_size, self.hidden_size, bias=args.mlp_bias 92 | ) 93 | 94 | def __call__(self, x: mx.array) -> mx.array: 95 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 96 | 97 | 98 | class HeliumDecoderLayer(nn.Module): 99 | def __init__(self, args: ModelArgs): 100 | super().__init__() 101 | self.hidden_size = args.hidden_size 102 | 103 | self.self_attn = HeliumAttention(args) 104 | self.mlp = HeliumMLP(args) 105 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 106 | self.post_attention_layernorm = nn.RMSNorm( 107 | args.hidden_size, eps=args.rms_norm_eps 108 | 
) 109 | 110 | def __call__( 111 | self, 112 | x: mx.array, 113 | mask: Optional[mx.array] = None, 114 | cache: Optional[Any] = None, 115 | ) -> mx.array: 116 | r = self.self_attn(self.input_layernorm(x), mask, cache) 117 | h = x + r 118 | r = self.mlp(self.post_attention_layernorm(h)) 119 | out = h + r 120 | return out 121 | 122 | 123 | class HeliumModel(nn.Module): 124 | def __init__(self, args: ModelArgs): 125 | super().__init__() 126 | self.num_hidden_layers = args.num_hidden_layers 127 | self.vocab_size = args.vocab_size 128 | 129 | assert self.vocab_size > 0 130 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 131 | 132 | self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)] 133 | 134 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 135 | 136 | def __call__( 137 | self, 138 | inputs: mx.array, 139 | mask: mx.array = None, 140 | cache=None, 141 | ) -> mx.array: 142 | h = self.embed_tokens(inputs) 143 | 144 | if mask is None: 145 | mask = create_attention_mask(h, cache) 146 | 147 | if cache is None: 148 | cache = [None] * len(self.layers) 149 | 150 | for layer, c in zip(self.layers, cache): 151 | h = layer(h, mask, c) 152 | 153 | return self.norm(h) 154 | 155 | 156 | class Model(nn.Module): 157 | def __init__(self, args: ModelArgs): 158 | super().__init__() 159 | self.args = args 160 | self.model_type = args.model_type 161 | 162 | self.model = HeliumModel(args) 163 | 164 | self.vocab_size = args.vocab_size 165 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 166 | 167 | if not args.tie_word_embeddings: 168 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 169 | 170 | def __call__( 171 | self, 172 | inputs: mx.array, 173 | mask: mx.array = None, 174 | cache=None, 175 | ) -> mx.array: 176 | out = self.model(inputs, mask, cache) 177 | if self.args.tie_word_embeddings: 178 | out = self.model.embed_tokens.as_linear(out) 179 | else: 180 | out = self.lm_head(out) 181 | return out 182 | 183 | @property 184 | def layers(self): 185 | return self.model.layers 186 | -------------------------------------------------------------------------------- /mlx_lm/models/qwen3.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
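# Qwen3 applies an RMSNorm per attention head to the queries and keys before RoPE (QK-norm); a shape sketch with made-up sizes, not upstream code: # import mlx.core as mx # import mlx.nn as nn # B, L, H, D = 1, 5, 8, 64 # q = mx.random.normal((B, L, H * D)) # q_norm = nn.RMSNorm(D, eps=1e-6) # q = q_norm(q.reshape(B, L, H, D)).transpose(0, 2, 1, 3) # normalized over head_dim; q now has shape (B, H, L, D)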
2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, Optional, Union 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | from .rope_utils import initialize_rope 11 | 12 | 13 | @dataclass 14 | class ModelArgs(BaseModelArgs): 15 | model_type: str 16 | hidden_size: int 17 | num_hidden_layers: int 18 | intermediate_size: int 19 | num_attention_heads: int 20 | rms_norm_eps: float 21 | vocab_size: int 22 | num_key_value_heads: int 23 | max_position_embeddings: int 24 | rope_theta: float 25 | head_dim: int 26 | tie_word_embeddings: bool 27 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, args: ModelArgs): 32 | super().__init__() 33 | 34 | dim = args.hidden_size 35 | self.n_heads = n_heads = args.num_attention_heads 36 | assert args.num_key_value_heads is not None 37 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 38 | 39 | head_dim = args.head_dim 40 | self.scale = head_dim**-0.5 41 | 42 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 43 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 44 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 45 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 46 | 47 | self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) 48 | self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) 49 | self.rope = initialize_rope( 50 | head_dim, 51 | base=args.rope_theta, 52 | traditional=False, 53 | scaling_config=args.rope_scaling, 54 | max_position_embeddings=args.max_position_embeddings, 55 | ) 56 | 57 | def __call__( 58 | self, 59 | x: mx.array, 60 | mask: Optional[mx.array] = None, 61 | cache: Optional[Any] = None, 62 | ) -> mx.array: 63 | B, L, D = x.shape 64 | 65 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 66 | 67 | queries = self.q_norm(queries.reshape(B, L, self.n_heads, -1)).transpose( 68 | 0, 2, 1, 3 69 | ) 70 | keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose( 71 | 0, 2, 1, 3 72 | ) 73 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 74 | 75 | if cache is not None: 76 | queries = self.rope(queries, offset=cache.offset) 77 | keys = self.rope(keys, offset=cache.offset) 78 | keys, values = cache.update_and_fetch(keys, values) 79 | else: 80 | queries = self.rope(queries) 81 | keys = self.rope(keys) 82 | 83 | output = scaled_dot_product_attention( 84 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 85 | ) 86 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 87 | return self.o_proj(output) 88 | 89 | 90 | class MLP(nn.Module): 91 | def __init__(self, dim, hidden_dim): 92 | super().__init__() 93 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 94 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 95 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 96 | 97 | def __call__(self, x) -> mx.array: 98 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 99 | 100 | 101 | class TransformerBlock(nn.Module): 102 | def __init__(self, args: ModelArgs): 103 | super().__init__() 104 | self.num_attention_heads = args.num_attention_heads 105 | self.hidden_size = args.hidden_size 106 | self.self_attn = Attention(args) 107 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 108 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 109 | 
self.post_attention_layernorm = nn.RMSNorm( 110 | args.hidden_size, eps=args.rms_norm_eps 111 | ) 112 | self.args = args 113 | 114 | def __call__( 115 | self, 116 | x: mx.array, 117 | mask: Optional[mx.array] = None, 118 | cache: Optional[Any] = None, 119 | ) -> mx.array: 120 | r = self.self_attn(self.input_layernorm(x), mask, cache) 121 | h = x + r 122 | r = self.mlp(self.post_attention_layernorm(h)) 123 | out = h + r 124 | return out 125 | 126 | 127 | class Qwen3Model(nn.Module): 128 | def __init__(self, args: ModelArgs): 129 | super().__init__() 130 | self.args = args 131 | self.vocab_size = args.vocab_size 132 | self.num_hidden_layers = args.num_hidden_layers 133 | assert self.vocab_size > 0 134 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 135 | self.layers = [ 136 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 137 | ] 138 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 139 | 140 | def __call__( 141 | self, 142 | inputs: mx.array, 143 | mask: mx.array = None, 144 | cache=None, 145 | ): 146 | h = self.embed_tokens(inputs) 147 | 148 | if mask is None: 149 | mask = create_attention_mask(h, cache) 150 | 151 | if cache is None: 152 | cache = [None] * len(self.layers) 153 | 154 | for layer, c in zip(self.layers, cache): 155 | h = layer(h, mask, c) 156 | 157 | return self.norm(h) 158 | 159 | 160 | class Model(nn.Module): 161 | def __init__(self, args: ModelArgs): 162 | super().__init__() 163 | self.args = args 164 | self.model_type = args.model_type 165 | self.model = Qwen3Model(args) 166 | if not args.tie_word_embeddings: 167 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 168 | 169 | def __call__( 170 | self, 171 | inputs: mx.array, 172 | mask: mx.array = None, 173 | cache=None, 174 | ): 175 | out = self.model(inputs, mask, cache) 176 | if self.args.tie_word_embeddings: 177 | out = self.model.embed_tokens.as_linear(out) 178 | else: 179 | out = self.lm_head(out) 180 | return out 181 | 182 | def sanitize(self, weights): 183 | if self.args.tie_word_embeddings: 184 | weights.pop("lm_head.weight", None) 185 | return weights 186 | 187 | @property 188 | def layers(self): 189 | return self.model.layers 190 | -------------------------------------------------------------------------------- /mlx_lm/models/cohere.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
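# Two things distinguish this file from the pre-norm decoder blocks elsewhere
# in models/: the TransformerBlock is a *parallel* block (attention and the
# MLP both read the same layer-normed input and their outputs are summed with
# the residual), and the final logits are multiplied by logit_scale. A minimal
# sketch of the parallel-residual pattern with stand-in sublayers and toy
# sizes (illustration only, not the real Cohere dimensions):
import mlx.core as mx
import mlx.nn as nn

dim = 8
x = mx.random.normal(shape=(1, 3, dim))
norm = nn.LayerNorm(dim)
attn = nn.Linear(dim, dim)  # stand-in for self-attention
mlp = nn.Linear(dim, dim)   # stand-in for the gated MLP

h = norm(x)                 # a single shared normalization
out = attn(h) + mlp(h) + x  # both branches in parallel, plus the residual
print(out.shape)            # (1, 3, 8)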
2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | 11 | 12 | @dataclass 13 | class ModelArgs(BaseModelArgs): 14 | model_type: str 15 | hidden_size: int = 8192 16 | num_hidden_layers: int = 40 17 | intermediate_size: int = 22528 18 | num_attention_heads: int = 64 19 | num_key_value_heads: int = 64 20 | rope_theta: float = 8000000.0 21 | vocab_size: int = 256000 22 | layer_norm_eps: float = 1e-05 23 | logit_scale: float = 0.0625 24 | attention_bias: bool = False 25 | layer_norm_bias: bool = False 26 | use_qk_norm: bool = False 27 | 28 | 29 | class LayerNorm2D(nn.Module): 30 | 31 | def __init__(self, d1, d2, eps): 32 | super().__init__() 33 | self.weight = mx.zeros((d1, d2)) 34 | self.eps = eps 35 | 36 | def __call__(self, x): 37 | return self.weight * mx.fast.layer_norm(x, None, None, self.eps) 38 | 39 | 40 | class Attention(nn.Module): 41 | def __init__(self, args: ModelArgs): 42 | super().__init__() 43 | self.args = args 44 | 45 | dim = args.hidden_size 46 | self.n_heads = n_heads = args.num_attention_heads 47 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 48 | 49 | head_dim = args.hidden_size // args.num_attention_heads 50 | self.scale = head_dim**-0.5 51 | 52 | attention_bias = args.attention_bias 53 | 54 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) 55 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) 56 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) 57 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) 58 | 59 | self.use_qk_norm = args.use_qk_norm 60 | if self.use_qk_norm: 61 | self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps) 62 | self.k_norm = LayerNorm2D( 63 | self.n_kv_heads, head_dim, eps=args.layer_norm_eps 64 | ) 65 | 66 | self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) 67 | 68 | def __call__( 69 | self, 70 | x: mx.array, 71 | mask: Optional[mx.array] = None, 72 | cache: Optional[Any] = None, 73 | ) -> mx.array: 74 | B, L, D = x.shape 75 | 76 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 77 | 78 | queries = queries.reshape(B, L, self.n_heads, -1) 79 | keys = keys.reshape(B, L, self.n_kv_heads, -1) 80 | if self.use_qk_norm: 81 | queries = self.q_norm(queries) 82 | keys = self.k_norm(keys) 83 | 84 | queries = queries.transpose(0, 2, 1, 3) 85 | keys = keys.transpose(0, 2, 1, 3) 86 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 87 | 88 | if cache is not None: 89 | queries = self.rope(queries, offset=cache.offset) 90 | keys = self.rope(keys, offset=cache.offset) 91 | keys, values = cache.update_and_fetch(keys, values) 92 | else: 93 | queries = self.rope(queries) 94 | keys = self.rope(keys) 95 | 96 | output = scaled_dot_product_attention( 97 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 98 | ) 99 | 100 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 101 | return self.o_proj(output) 102 | 103 | 104 | class MLP(nn.Module): 105 | def __init__(self, dim, hidden_dim): 106 | super().__init__() 107 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 108 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 109 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 110 | 111 | def __call__(self, x): 112 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
113 | 114 | 115 | class TransformerBlock(nn.Module): 116 | def __init__(self, args: ModelArgs): 117 | super().__init__() 118 | self.hidden_size = args.hidden_size 119 | self.n_heads = args.num_attention_heads 120 | 121 | self.self_attn = Attention(args) 122 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 123 | self.input_layernorm = nn.LayerNorm( 124 | args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias 125 | ) 126 | self.args = args 127 | 128 | def __call__( 129 | self, 130 | x: mx.array, 131 | mask: Optional[mx.array] = None, 132 | cache: Optional[Any] = None, 133 | ) -> mx.array: 134 | h = self.input_layernorm(x) 135 | attn_h = self.self_attn(h, mask, cache) 136 | ff_h = self.mlp(h) 137 | return attn_h + ff_h + x 138 | 139 | 140 | class CohereModel(nn.Module): 141 | def __init__(self, args: ModelArgs): 142 | super().__init__() 143 | self.args = args 144 | self.vocab_size = args.vocab_size 145 | self.num_hidden_layers = args.num_hidden_layers 146 | assert self.vocab_size > 0 147 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 148 | self.layers = [ 149 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 150 | ] 151 | self.norm = nn.LayerNorm( 152 | args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias 153 | ) 154 | 155 | def __call__( 156 | self, 157 | inputs: mx.array, 158 | mask: mx.array = None, 159 | cache=None, 160 | ): 161 | h = self.embed_tokens(inputs) 162 | 163 | if mask is None: 164 | mask = create_attention_mask(h, cache) 165 | 166 | if cache is None: 167 | cache = [None] * len(self.layers) 168 | 169 | for layer, c in zip(self.layers, cache): 170 | h = layer(h, mask, c) 171 | 172 | return self.norm(h) 173 | 174 | 175 | class Model(nn.Module): 176 | def __init__(self, args: ModelArgs): 177 | super().__init__() 178 | self.model_type = args.model_type 179 | self.model = CohereModel(args) 180 | self.args = args 181 | 182 | def __call__( 183 | self, 184 | inputs: mx.array, 185 | mask: mx.array = None, 186 | cache=None, 187 | ): 188 | out = self.model(inputs, mask, cache) 189 | out = self.model.embed_tokens.as_linear(out) 190 | out = out * self.model.args.logit_scale 191 | return out 192 | 193 | @property 194 | def layers(self): 195 | return self.model.layers 196 | -------------------------------------------------------------------------------- /mlx_lm/models/mimo.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2025 Apple Inc. 
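# MiMo checkpoints ship extra multi-token-prediction heads under
# model.mtp_layers.* which this implementation does not use; sanitize() below
# drops them, together with any precomputed rotary inv_freq buffers, before
# the weights are loaded. The same filter applied to a few illustrative
# (made-up) key names:
weights = {
    "model.layers.0.self_attn.q_proj.weight": "...",        # kept
    "model.layers.0.self_attn.rotary_emb.inv_freq": "...",  # dropped, recomputed at runtime
    "model.mtp_layers.0.mlp.gate_proj.weight": "...",       # dropped, unused MTP head
}
kept = {
    k: v
    for k, v in weights.items()
    if "self_attn.rotary_emb.inv_freq" not in k
    and not k.startswith("model.mtp_layers.")
}
print(sorted(kept))  # ['model.layers.0.self_attn.q_proj.weight']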
2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, Optional, Union 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | from .rope_utils import initialize_rope 11 | 12 | 13 | @dataclass 14 | class ModelArgs(BaseModelArgs): 15 | model_type: str 16 | hidden_size: int 17 | num_hidden_layers: int 18 | intermediate_size: int 19 | num_attention_heads: int 20 | rms_norm_eps: float 21 | vocab_size: int 22 | num_key_value_heads: int 23 | max_position_embeddings: int = 32768 24 | rope_theta: float = 10000.0 25 | rope_traditional: bool = False 26 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 27 | tie_word_embeddings: bool = False 28 | num_nextn_predict_layers: int = 2 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__(self, args: ModelArgs): 33 | super().__init__() 34 | 35 | dim = args.hidden_size 36 | self.n_heads = n_heads = args.num_attention_heads 37 | assert args.num_key_value_heads is not None 38 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 39 | 40 | head_dim = args.hidden_size // n_heads 41 | self.scale = head_dim**-0.5 42 | 43 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) 44 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 45 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 46 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 47 | 48 | self.rope = initialize_rope( 49 | head_dim, 50 | base=args.rope_theta, 51 | traditional=args.rope_traditional, 52 | scaling_config=args.rope_scaling, 53 | max_position_embeddings=args.max_position_embeddings, 54 | ) 55 | 56 | def __call__( 57 | self, 58 | x: mx.array, 59 | mask: Optional[mx.array] = None, 60 | cache: Optional[Any] = None, 61 | ) -> mx.array: 62 | B, L, D = x.shape 63 | 64 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 65 | 66 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 67 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 68 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 69 | 70 | if cache is not None: 71 | queries = self.rope(queries, offset=cache.offset) 72 | keys = self.rope(keys, offset=cache.offset) 73 | keys, values = cache.update_and_fetch(keys, values) 74 | else: 75 | queries = self.rope(queries) 76 | keys = self.rope(keys) 77 | 78 | output = scaled_dot_product_attention( 79 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 80 | ) 81 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 82 | return self.o_proj(output) 83 | 84 | 85 | class MLP(nn.Module): 86 | def __init__(self, dim, hidden_dim): 87 | super().__init__() 88 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 89 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 90 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 91 | 92 | def __call__(self, x) -> mx.array: 93 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 94 | 95 | 96 | class TransformerBlock(nn.Module): 97 | def __init__(self, args: ModelArgs): 98 | super().__init__() 99 | self.num_attention_heads = args.num_attention_heads 100 | self.hidden_size = args.hidden_size 101 | self.self_attn = Attention(args) 102 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 103 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 104 | self.post_attention_layernorm = nn.RMSNorm( 105 | args.hidden_size, 
eps=args.rms_norm_eps 106 | ) 107 | self.args = args 108 | 109 | def __call__( 110 | self, 111 | x: mx.array, 112 | mask: Optional[mx.array] = None, 113 | cache: Optional[Any] = None, 114 | ) -> mx.array: 115 | r = self.self_attn(self.input_layernorm(x), mask, cache) 116 | h = x + r 117 | r = self.mlp(self.post_attention_layernorm(h)) 118 | out = h + r 119 | return out 120 | 121 | 122 | class MiMoModel(nn.Module): 123 | def __init__(self, args: ModelArgs): 124 | super().__init__() 125 | self.args = args 126 | self.vocab_size = args.vocab_size 127 | self.num_hidden_layers = args.num_hidden_layers 128 | self.num_nextn_predict_layers = args.num_nextn_predict_layers 129 | 130 | assert self.vocab_size > 0 131 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 132 | self.layers = [ 133 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 134 | ] 135 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 136 | 137 | def __call__( 138 | self, 139 | inputs: mx.array, 140 | mask: mx.array = None, 141 | cache=None, 142 | ): 143 | h = self.embed_tokens(inputs) 144 | 145 | if mask is None: 146 | mask = create_attention_mask(h, cache) 147 | 148 | if cache is None: 149 | cache = [None] * len(self.layers) 150 | 151 | for layer, c in zip(self.layers, cache): 152 | h = layer(h, mask, c) 153 | 154 | h = self.norm(h) 155 | 156 | return h 157 | 158 | 159 | class Model(nn.Module): 160 | def __init__(self, args: ModelArgs): 161 | super().__init__() 162 | self.args = args 163 | self.model_type = args.model_type 164 | self.model = MiMoModel(args) 165 | if not args.tie_word_embeddings: 166 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 167 | 168 | def __call__( 169 | self, 170 | inputs: mx.array, 171 | mask: mx.array = None, 172 | cache=None, 173 | ): 174 | out = self.model(inputs, mask, cache) 175 | 176 | if self.args.tie_word_embeddings: 177 | out = self.model.embed_tokens.as_linear(out) 178 | else: 179 | out = self.lm_head(out) 180 | 181 | return out 182 | 183 | def sanitize(self, weights): 184 | if self.args.tie_word_embeddings: 185 | weights.pop("lm_head.weight", None) 186 | 187 | return { 188 | k: v 189 | for k, v in weights.items() 190 | if "self_attn.rotary_emb.inv_freq" not in k 191 | and not k.startswith("model.mtp_layers.") 192 | } 193 | 194 | @property 195 | def layers(self): 196 | return self.model.layers 197 | -------------------------------------------------------------------------------- /mlx_lm/models/switch_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
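# The two helpers below implement the token-to-expert routing trick used by
# SwitchGLU and SwitchMLP: when there are enough assignments (indices.size >= 64),
# tokens are sorted by their expert index so the gathered matmuls walk each
# expert's weights contiguously, and the outputs are restored to the original
# token order with the inverse permutation (an argsort of the argsort). The
# permutation round trip in isolation, with toy indices:
import mlx.core as mx

indices = mx.array([2, 0, 1, 0, 2])  # expert chosen for each token
order = mx.argsort(indices)          # token order grouped by expert
inv_order = mx.argsort(order)        # inverse permutation

grouped = indices[order]             # array([0, 0, 1, 2, 2])
restored = grouped[inv_order]        # back to array([2, 0, 1, 0, 2])
print(mx.array_equal(restored, indices))  # True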
2 | 3 | import math 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | 9 | def _gather_sort(x, indices): 10 | *_, M = indices.shape 11 | indices = indices.flatten() 12 | order = mx.argsort(indices) 13 | inv_order = mx.argsort(order) 14 | return x.flatten(0, -3)[order // M], indices[order], inv_order 15 | 16 | 17 | def _scatter_unsort(x, inv_order, shape=None): 18 | x = x[inv_order] 19 | if shape is not None: 20 | x = mx.unflatten(x, 0, shape) 21 | return x 22 | 23 | 24 | class QuantizedSwitchLinear(nn.Module): 25 | def __init__( 26 | self, 27 | input_dims: int, 28 | output_dims: int, 29 | num_experts: int, 30 | bias: bool = True, 31 | group_size: int = 64, 32 | bits: int = 4, 33 | ): 34 | super().__init__() 35 | 36 | scale = math.sqrt(1 / input_dims) 37 | self.weight, self.scales, self.biases = mx.quantize( 38 | mx.random.uniform( 39 | low=-scale, 40 | high=scale, 41 | shape=(num_experts, output_dims, input_dims), 42 | ), 43 | group_size=group_size, 44 | bits=bits, 45 | ) 46 | 47 | if bias: 48 | self.bias = mx.zeros((num_experts, output_dims)) 49 | 50 | self.group_size = group_size 51 | self.bits = bits 52 | 53 | # Freeze this model's parameters 54 | self.freeze() 55 | 56 | def unfreeze(self, *args, **kwargs): 57 | """Wrap unfreeze so that we unfreeze any layers we might contain but 58 | our parameters will remain frozen.""" 59 | super().unfreeze(*args, **kwargs) 60 | self.freeze(recurse=False) 61 | 62 | @property 63 | def input_dims(self): 64 | return self.scales.shape[2] * self.group_size 65 | 66 | @property 67 | def output_dims(self): 68 | return self.weight.shape[1] 69 | 70 | @property 71 | def num_experts(self): 72 | return self.weight.shape[0] 73 | 74 | def __call__(self, x, indices, sorted_indices=False): 75 | x = mx.gather_qmm( 76 | x, 77 | self["weight"], 78 | self["scales"], 79 | self["biases"], 80 | rhs_indices=indices, 81 | transpose=True, 82 | group_size=self.group_size, 83 | bits=self.bits, 84 | sorted_indices=sorted_indices, 85 | ) 86 | if "bias" in self: 87 | x = x + mx.expand_dims(self["bias"][indices], -2) 88 | return x 89 | 90 | 91 | class SwitchLinear(nn.Module): 92 | def __init__( 93 | self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True 94 | ): 95 | super().__init__() 96 | scale = math.sqrt(1 / input_dims) 97 | self.weight = mx.random.uniform( 98 | low=-scale, 99 | high=scale, 100 | shape=(num_experts, output_dims, input_dims), 101 | ) 102 | 103 | if bias: 104 | self.bias = mx.zeros((num_experts, output_dims)) 105 | 106 | @property 107 | def input_dims(self): 108 | return self.weight.shape[2] 109 | 110 | @property 111 | def output_dims(self): 112 | return self.weight.shape[1] 113 | 114 | @property 115 | def num_experts(self): 116 | return self.weight.shape[0] 117 | 118 | def __call__(self, x, indices, sorted_indices=False): 119 | x = mx.gather_mm( 120 | x, 121 | self["weight"].swapaxes(-1, -2), 122 | rhs_indices=indices, 123 | sorted_indices=sorted_indices, 124 | ) 125 | if "bias" in self: 126 | x = x + mx.expand_dims(self["bias"][indices], -2) 127 | return x 128 | 129 | def to_quantized(self, group_size: int = 64, bits: int = 4): 130 | num_experts, output_dims, input_dims = self.weight.shape 131 | ql = QuantizedSwitchLinear( 132 | input_dims, output_dims, num_experts, False, group_size, bits 133 | ) 134 | ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits) 135 | if "bias" in self: 136 | ql.bias = self.bias 137 | return ql 138 | 139 | 140 | class SwitchGLU(nn.Module): 141 | def __init__( 142 | self, 143 | 
input_dims: int, 144 | hidden_dims: int, 145 | num_experts: int, 146 | activation=nn.SiLU(), 147 | bias: bool = False, 148 | ): 149 | super().__init__() 150 | 151 | self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) 152 | self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) 153 | self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) 154 | self.activation = activation 155 | 156 | def __call__(self, x, indices) -> mx.array: 157 | x = mx.expand_dims(x, (-2, -3)) 158 | 159 | # When we have many tokens, then sort them to make sure that the access 160 | # of different experts is in order. 161 | do_sort = indices.size >= 64 162 | idx = indices 163 | inv_order = None 164 | if do_sort: 165 | x, idx, inv_order = _gather_sort(x, indices) 166 | 167 | x_up = self.up_proj(x, idx, sorted_indices=do_sort) 168 | x_gate = self.gate_proj(x, idx, sorted_indices=do_sort) 169 | x = self.down_proj( 170 | self.activation(x_gate) * x_up, 171 | idx, 172 | sorted_indices=do_sort, 173 | ) 174 | 175 | if do_sort: 176 | x = _scatter_unsort(x, inv_order, indices.shape) 177 | 178 | return x.squeeze(-2) 179 | 180 | 181 | class SwitchMLP(nn.Module): 182 | def __init__( 183 | self, 184 | input_dims: int, 185 | hidden_dims: int, 186 | num_experts: int, 187 | activation=nn.GELU(approx="precise"), 188 | bias: bool = False, 189 | ): 190 | super().__init__() 191 | 192 | self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) 193 | self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) 194 | self.activation = activation 195 | 196 | def __call__(self, x, indices) -> mx.array: 197 | x = mx.expand_dims(x, (-2, -3)) 198 | 199 | # When we have many tokens, then sort them to make sure that the access 200 | # of different experts is in order. 201 | do_sort = indices.size >= 64 202 | idx = indices 203 | inv_order = None 204 | if do_sort: 205 | x, idx, inv_order = _gather_sort(x, indices) 206 | 207 | x = self.fc1(x, idx, sorted_indices=do_sort) 208 | x = self.activation(x) 209 | x = self.fc2(x, idx, sorted_indices=do_sort) 210 | 211 | if do_sort: 212 | x = _scatter_unsort(x, inv_order, indices.shape) 213 | 214 | return x.squeeze(-2) 215 | -------------------------------------------------------------------------------- /mlx_lm/models/phixtral.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
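# The MOE module in this file selects num_experts_per_tok experts per token by
# arg-partitioning the negated gate logits, normalizes only the selected gate
# scores with a softmax, and returns the score-weighted sum of the chosen
# expert outputs. The selection step on a single toy token (values made up for
# illustration):
import mlx.core as mx

gates = mx.array([[0.1, 2.0, -0.3, 1.5]])  # one token, four experts
k = 2
inds = mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]  # indices of the two largest gates
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1)  # weights over the selected experts only
print(inds, scores)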
2 | 3 | import inspect 4 | import math 5 | from dataclasses import dataclass 6 | from typing import Tuple 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | 11 | from .base import create_attention_mask, scaled_dot_product_attention 12 | from .switch_layers import SwitchMLP 13 | 14 | 15 | @dataclass 16 | class ModelArgs: 17 | model_type: str 18 | num_vocab: int = 51200 19 | model_dim: int = 2560 20 | num_heads: int = 32 21 | num_layers: int = 32 22 | rotary_dim: int = 32 23 | num_experts_per_tok: int = 2 24 | num_local_experts: int = 4 25 | 26 | @classmethod 27 | def from_dict(cls, params): 28 | return cls( 29 | **{ 30 | k: v 31 | for k, v in params.items() 32 | if k in inspect.signature(cls).parameters 33 | } 34 | ) 35 | 36 | 37 | class RoPEAttention(nn.Module): 38 | def __init__(self, dims: int, num_heads: int, rotary_dim: int): 39 | super().__init__() 40 | 41 | self.num_heads = num_heads 42 | 43 | self.rope = nn.RoPE(rotary_dim, traditional=False) 44 | self.Wqkv = nn.Linear(dims, 3 * dims) 45 | self.out_proj = nn.Linear(dims, dims) 46 | 47 | def __call__(self, x, mask=None, cache=None): 48 | qkv = self.Wqkv(x) 49 | queries, keys, values = mx.split(qkv, 3, axis=-1) 50 | 51 | # Extract some shapes 52 | num_heads = self.num_heads 53 | B, L, D = queries.shape 54 | 55 | # Prepare the queries, keys and values for the attention computation 56 | queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 57 | keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 58 | values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) 59 | 60 | # Add RoPE to the queries and keys and combine them with the cache 61 | if cache is not None: 62 | queries = self.rope(queries, offset=cache.offset) 63 | keys = self.rope(keys, offset=cache.offset) 64 | keys, values = cache.update_and_fetch(keys, values) 65 | else: 66 | queries = self.rope(queries) 67 | keys = self.rope(keys) 68 | 69 | queries = queries.astype(mx.float32) 70 | 71 | # Finally perform the attention computation 72 | scale = math.sqrt(1 / queries.shape[-1]) 73 | 74 | output = scaled_dot_product_attention( 75 | queries.astype(mx.float32), 76 | keys, 77 | values, 78 | cache=cache, 79 | scale=scale, 80 | mask=mask, 81 | ).astype(values.dtype) 82 | output = output.moveaxis(2, 1).reshape(B, L, -1) 83 | 84 | return self.out_proj(output) 85 | 86 | 87 | class MOE(nn.Module): 88 | def __init__(self, args: ModelArgs, dim: int, hidden_dim: int): 89 | super().__init__() 90 | self.dim = dim 91 | self.hidden_dim = hidden_dim 92 | self.num_experts = args.num_local_experts 93 | self.num_experts_per_tok = args.num_experts_per_tok 94 | self.switch_mlp = SwitchMLP( 95 | self.dim, self.hidden_dim, self.num_experts, bias=True 96 | ) 97 | self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False) 98 | 99 | def __call__(self, x: mx.array) -> mx.array: 100 | gates = self.gate(x) 101 | 102 | k = self.num_experts_per_tok 103 | inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k] 104 | scores = mx.take_along_axis(gates, inds, axis=-1) 105 | scores = mx.softmax(scores, axis=-1, precise=True) 106 | 107 | y = self.switch_mlp(x, inds) 108 | y = (y * scores[..., None]).sum(axis=-2) 109 | 110 | return y 111 | 112 | 113 | class ParallelBlock(nn.Module): 114 | def __init__(self, config: ModelArgs): 115 | super().__init__() 116 | dims = config.model_dim 117 | mlp_dims = dims * 4 118 | self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim) 119 | self.ln = nn.LayerNorm(dims) 120 | self.moe = MOE(config, 
dims, mlp_dims) 121 | 122 | def __call__(self, x, mask, cache): 123 | h = self.ln(x) 124 | attn_h = self.mixer(h, mask, cache) 125 | ff_h = self.moe(h) 126 | return attn_h + ff_h + x 127 | 128 | 129 | class TransformerDecoder(nn.Module): 130 | def __init__(self, config: ModelArgs): 131 | super().__init__() 132 | self.embd = Embd(config) 133 | self.h = [ParallelBlock(config) for i in range(config.num_layers)] 134 | 135 | def __call__(self, x, mask, cache): 136 | x = self.embd(x) 137 | if cache is None: 138 | cache = [None] * len(self.h) 139 | 140 | for layer, c in zip(self.h, cache): 141 | x = layer(x, mask, c) 142 | return x 143 | 144 | 145 | class Embd(nn.Module): 146 | def __init__(self, config: ModelArgs): 147 | super().__init__() 148 | self.wte = nn.Embedding(config.num_vocab, config.model_dim) 149 | 150 | def __call__(self, x): 151 | return self.wte(x) 152 | 153 | 154 | class OutputHead(nn.Module): 155 | def __init__(self, config: ModelArgs) -> None: 156 | super().__init__() 157 | self.ln = nn.LayerNorm(config.model_dim) 158 | self.linear = nn.Linear(config.model_dim, config.num_vocab) 159 | 160 | def __call__(self, inputs): 161 | return self.linear(self.ln(inputs)) 162 | 163 | 164 | class Model(nn.Module): 165 | def __init__(self, config: ModelArgs): 166 | super().__init__() 167 | self.model_type = config.model_type 168 | self.transformer = TransformerDecoder(config) 169 | self.lm_head = OutputHead(config) 170 | self.args = config 171 | 172 | def __call__( 173 | self, 174 | x: mx.array, 175 | mask: mx.array = None, 176 | cache=None, 177 | ) -> mx.array: 178 | 179 | if mask is None: 180 | mask = create_attention_mask(x, cache) 181 | 182 | y = self.transformer(x, mask, cache) 183 | return self.lm_head(y) 184 | 185 | def sanitize(self, weights): 186 | if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights: 187 | return weights 188 | for l in range(self.args.num_layers): 189 | prefix = f"transformer.h.{l}" 190 | for n in ["fc1", "fc2"]: 191 | for k in ["weight", "scales", "biases", "bias"]: 192 | if f"{prefix}.moe.mlp.0.{n}.{k}" in weights: 193 | to_join = [ 194 | weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}") 195 | for e in range(self.args.num_local_experts) 196 | ] 197 | weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join) 198 | return weights 199 | 200 | @property 201 | def layers(self): 202 | return self.transformer.h 203 | -------------------------------------------------------------------------------- /mlx_lm/models/qwen2.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
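# When tie_word_embeddings is set (the default here), the Model below reuses
# the token embedding matrix as the output projection via
# nn.Embedding.as_linear, and sanitize() pops any stored lm_head.weight so the
# tied checkpoint loads cleanly. The tying mechanism in isolation, with toy
# sizes:
import mlx.core as mx
import mlx.nn as nn

vocab_size, hidden_size = 10, 4
embed_tokens = nn.Embedding(vocab_size, hidden_size)

tokens = mx.array([[1, 2, 3]])
h = embed_tokens(tokens)            # (1, 3, hidden_size) hidden states
logits = embed_tokens.as_linear(h)  # (1, 3, vocab_size), same weights as the embedding
print(logits.shape)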
2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, Optional, Union 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention 10 | from .rope_utils import initialize_rope 11 | 12 | 13 | @dataclass 14 | class ModelArgs(BaseModelArgs): 15 | model_type: str 16 | hidden_size: int 17 | num_hidden_layers: int 18 | intermediate_size: int 19 | num_attention_heads: int 20 | rms_norm_eps: float 21 | vocab_size: int 22 | num_key_value_heads: int 23 | max_position_embeddings: int = 32768 24 | rope_theta: float = 1000000 25 | rope_traditional: bool = False 26 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 27 | tie_word_embeddings: bool = True 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, args: ModelArgs): 32 | super().__init__() 33 | 34 | dim = args.hidden_size 35 | self.n_heads = n_heads = args.num_attention_heads 36 | assert args.num_key_value_heads is not None 37 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 38 | 39 | head_dim = args.hidden_size // n_heads 40 | self.scale = head_dim**-0.5 41 | 42 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) 43 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 44 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) 45 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 46 | 47 | self.rope = initialize_rope( 48 | head_dim, 49 | base=args.rope_theta, 50 | traditional=args.rope_traditional, 51 | scaling_config=args.rope_scaling, 52 | max_position_embeddings=args.max_position_embeddings, 53 | ) 54 | 55 | def __call__( 56 | self, 57 | x: mx.array, 58 | mask: Optional[mx.array] = None, 59 | cache: Optional[Any] = None, 60 | ) -> mx.array: 61 | B, L, D = x.shape 62 | 63 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 64 | 65 | # Prepare the queries, keys and values for the attention computation 66 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 67 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 68 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 69 | 70 | if cache is not None: 71 | queries = self.rope(queries, offset=cache.offset) 72 | keys = self.rope(keys, offset=cache.offset) 73 | keys, values = cache.update_and_fetch(keys, values) 74 | else: 75 | queries = self.rope(queries) 76 | keys = self.rope(keys) 77 | 78 | output = scaled_dot_product_attention( 79 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 80 | ) 81 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 82 | return self.o_proj(output) 83 | 84 | 85 | class MLP(nn.Module): 86 | def __init__(self, dim, hidden_dim): 87 | super().__init__() 88 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 89 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 90 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 91 | 92 | def __call__(self, x) -> mx.array: 93 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 94 | 95 | 96 | class TransformerBlock(nn.Module): 97 | def __init__(self, args: ModelArgs): 98 | super().__init__() 99 | self.num_attention_heads = args.num_attention_heads 100 | self.hidden_size = args.hidden_size 101 | self.self_attn = Attention(args) 102 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 103 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 104 | self.post_attention_layernorm = nn.RMSNorm( 105 | 
args.hidden_size, eps=args.rms_norm_eps 106 | ) 107 | self.args = args 108 | 109 | def __call__( 110 | self, 111 | x: mx.array, 112 | mask: Optional[mx.array] = None, 113 | cache: Optional[Any] = None, 114 | ) -> mx.array: 115 | r = self.self_attn(self.input_layernorm(x), mask, cache) 116 | h = x + r 117 | r = self.mlp(self.post_attention_layernorm(h)) 118 | out = h + r 119 | return out 120 | 121 | 122 | class Qwen2Model(nn.Module): 123 | def __init__(self, args: ModelArgs): 124 | super().__init__() 125 | self.args = args 126 | self.vocab_size = args.vocab_size 127 | self.num_hidden_layers = args.num_hidden_layers 128 | assert self.vocab_size > 0 129 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 130 | self.layers = [ 131 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 132 | ] 133 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 134 | 135 | def __call__( 136 | self, 137 | inputs: mx.array, 138 | mask: mx.array = None, 139 | cache=None, 140 | input_embeddings: Optional[mx.array] = None, 141 | ): 142 | if input_embeddings is not None: 143 | h = input_embeddings 144 | else: 145 | h = self.embed_tokens(inputs) 146 | 147 | if mask is None: 148 | mask = create_attention_mask(h, cache) 149 | 150 | if cache is None: 151 | cache = [None] * len(self.layers) 152 | 153 | for layer, c in zip(self.layers, cache): 154 | h = layer(h, mask, c) 155 | 156 | return self.norm(h) 157 | 158 | 159 | class Model(nn.Module): 160 | def __init__(self, args: ModelArgs): 161 | super().__init__() 162 | self.args = args 163 | self.model_type = args.model_type 164 | self.model = Qwen2Model(args) 165 | if not args.tie_word_embeddings: 166 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 167 | 168 | def __call__( 169 | self, 170 | inputs: mx.array, 171 | mask: mx.array = None, 172 | cache=None, 173 | input_embeddings: Optional[mx.array] = None, 174 | ): 175 | out = self.model(inputs, mask, cache, input_embeddings) 176 | if self.args.tie_word_embeddings: 177 | out = self.model.embed_tokens.as_linear(out) 178 | else: 179 | out = self.lm_head(out) 180 | return out 181 | 182 | def sanitize(self, weights): 183 | if self.args.tie_word_embeddings: 184 | weights.pop("lm_head.weight", None) 185 | # Remove unused precomputed rotary freqs 186 | return { 187 | k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k 188 | } 189 | 190 | @property 191 | def layers(self): 192 | return self.model.layers 193 | --------------------------------------------------------------------------------
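# Each model file above exposes the same minimal interface: build a ModelArgs,
# construct Model, and call it on a batch of token ids to get per-position
# logits; with mask=None the model creates an attention mask internally. A
# sketch of a forward pass through the qwen2 implementation with a
# deliberately tiny, made-up configuration (not a real Qwen2 config):
import mlx.core as mx

from mlx_lm.models.qwen2 import Model, ModelArgs

args = ModelArgs(
    model_type="qwen2",
    hidden_size=64,
    num_hidden_layers=2,
    intermediate_size=128,
    num_attention_heads=4,
    rms_norm_eps=1e-6,
    vocab_size=100,
    num_key_value_heads=2,
)
model = Model(args)
tokens = mx.array([[1, 2, 3, 4]])
logits = model(tokens)  # tied embeddings by default, so the output dim is vocab_size
print(logits.shape)     # (1, 4, 100)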