├── .gitignore ├── mlx_lm │ ├── models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── phi2.py │ │ └── llama.py │ ├── requirements.txt │ ├── __init__.py │ ├── README.md │ ├── UPLOAD.md │ ├── generate.py │ ├── utils.py │ └── convert.py ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | mlx_model 2 | -------------------------------------------------------------------------------- /mlx_lm/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mlx_lm/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx 2 | numpy 3 | transformers 4 | protobuf 5 | -------------------------------------------------------------------------------- /mlx_lm/__init__.py: -------------------------------------------------------------------------------- 1 | from .convert import convert 2 | from .utils import generate, load 3 | -------------------------------------------------------------------------------- /mlx_lm/README.md: -------------------------------------------------------------------------------- 1 | ## Generate Text with MLX and :hugs: Hugging Face 2 | 3 | This is an example of large language model text generation that can pull models from 4 | the Hugging Face Hub. 5 | 6 | For more information on this example, see the 7 | [README](../README.md) in the parent directory. 8 | -------------------------------------------------------------------------------- /mlx_lm/models/base.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class BaseModelArgs: 7 | @classmethod 8 | def from_dict(cls, params): 9 | return cls( 10 | **{ 11 | k: v 12 | for k, v in params.items() 13 | if k in inspect.signature(cls).parameters 14 | } 15 | ) 16 | -------------------------------------------------------------------------------- /mlx_lm/UPLOAD.md: -------------------------------------------------------------------------------- 1 | ### Packaging for PyPI 2 | 3 | Install `build` and `twine`: 4 | 5 | ``` 6 | pip install --user --upgrade build 7 | pip install --user --upgrade twine 8 | ``` 9 | 10 | Generate the source distribution and wheel: 11 | 12 | ``` 13 | python -m build 14 | ``` 15 | 16 | > [!WARNING] 17 | > Use a test server first. 18 | 19 | #### Test Upload 20 | 21 | Upload to the test server: 22 | 23 | ``` 24 | python -m twine upload --repository testpypi dist/* 25 | ``` 26 | 27 | Install from the test server and check that it works: 28 | 29 | ``` 30 | python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm 31 | ``` 32 | 33 | #### Upload 34 | 35 | ``` 36 | python -m twine upload dist/* 37 | ``` 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc.
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # First Token Cutoff LLM sampling 2 | 3 | This code implements the sampling method described [in this blog post](http://antirez.com/news/142) using MLX. 4 | The implementation is a modification of the MLX LLM inference example. 5 | 6 | ## Usage 7 | 8 | First, install MLX. Note that it only works on Apple Silicon: 9 | 10 | $ pip install mlx 11 | 12 | Then convert the PyTorch model into MLX format: 13 | 14 | $ python3 -m mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.2 15 | 16 | Finally, run the inference: 17 | 18 | $ python3 -m mlx_lm.generate --model mlx_model --prompt "[INST] Who was Leonardo Da Vinci? [/INST]" --max-tokens 500 19 | 20 | The default cutoff is 0.7 (tokens up to 70% worse than the best scoring token are accepted for sampling), but you can change it with the `--sampling-cutoff` command line option. A cutoff of 0 makes generation deterministic, always selecting the best token, while a cutoff of 1 accepts every possible token and is rarely useful. More interesting values lie between 0.05 and 0.99, depending on how much variability you want. 21 | 22 | ## Output colorization 23 | 24 | If you add the `--colorize` option to the generate command line above, the output of the LLM is colorized based on the probability of the best token (regardless of the token that is actually sampled). These are the intervals used: 25 | 26 | ``` 27 | if t0 > 0.95: 28 | color = 'white' 29 | elif t0 > 0.70: 30 | color = 'green' 31 | elif t0 > 0.30: 32 | color = 'yellow' 33 | else: 34 | color = 'red' 35 | ``` 36 | 37 | The strength of the first token is a useful hint about the model's internal state, especially when the model is outputting dates or other factual information: in such cases it is often possible to tell whether the model is likely hallucinating. 38 | 39 | ## Sampling algorithm used 40 | 41 | This is how the algorithm works: 42 | 43 | * Compute the softmax() of the logits. 44 | * Sort the tokens by probability. 45 | * Given T0, the probability of the best token, compute for every other token T[i] its relative distance from the best token as: 46 | 47 | r = 1 - (T[i] / T0) 48 | 49 | * Select only the tokens for which r <= cutoff (the value of `--sampling-cutoff`). 50 | * Perform a weighted random pick among the selected tokens (a minimal sketch of this procedure is shown right after this list).
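For reference, here is a minimal NumPy sketch of the steps above (the helper name `first_token_cutoff_sample` is only illustrative and is not part of the package; the actual implementation is the `sample` function in `mlx_lm/utils.py`, which operates on MLX arrays):

```
import numpy as np

def first_token_cutoff_sample(logits, cutoff=0.7):
    # Softmax of the logits (subtract the max for numerical stability).
    exps = np.exp(logits - np.max(logits))
    probs = exps / exps.sum()

    # Sort token indices by decreasing probability; T0 is the best token's probability.
    order = np.argsort(probs)[::-1]
    t0 = probs[order[0]]

    # Keep the tokens whose relative distance r = 1 - T[i]/T0 is within the cutoff.
    accepted = [i for i in order if 1 - probs[i] / t0 <= cutoff]

    # Weighted random pick among the accepted tokens, proportional to their probability.
    weights = probs[accepted] / probs[accepted].sum()
    return int(np.random.choice(accepted, p=weights))
```

With `cutoff=0` only the best token survives the filter, so generation is deterministic; with `cutoff=1` every token is accepted.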
51 | 52 | Note that this puts a hard limit on how many tokens can enter the set of candidates, even when the token probabilities decrease smoothly and monotonically; other methods that try to identify clusters of high-scoring tokens do not have such a limit. 53 | 54 | ## Hacking on the implementation 55 | 56 | The implementation of the sampler is contained in the `sample` function of `mlx_lm/utils.py`. 57 | Modifying it is straightforward. 58 | -------------------------------------------------------------------------------- /mlx_lm/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import mlx.core as mx 5 | 6 | from .utils import generate_step, load 7 | 8 | DEFAULT_MODEL_PATH = "mlx_model" 9 | DEFAULT_PROMPT = "hello" 10 | DEFAULT_MAX_TOKENS = 100 11 | DEFAULT_CO = 0.7 12 | DEFAULT_SEED = 0 13 | 14 | def setup_arg_parser(): 15 | """Set up and return the argument parser.""" 16 | parser = argparse.ArgumentParser(description="LLM inference script") 17 | parser.add_argument( 18 | "--model", 19 | type=str, 20 | default=DEFAULT_MODEL_PATH, 21 | help="The path to the local model directory or Hugging Face repo.", 22 | ) 23 | parser.add_argument( 24 | "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model" 25 | ) 26 | parser.add_argument( 27 | "--max-tokens", 28 | "-m", 29 | type=int, 30 | default=DEFAULT_MAX_TOKENS, 31 | help="Maximum number of tokens to generate", 32 | ) 33 | parser.add_argument( 34 | "--sampling-cutoff", type=float, default=DEFAULT_CO, help="Sampling cutoff" 35 | ) 36 | parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") 37 | parser.add_argument("--colorize", 38 | action='store_true', 39 | help="Colorize output based on T[0] probability") 40 | return parser 41 | 42 | def colorprint(color, s): 43 | color_codes = { 44 | 'black': 30, 45 | 'red': 31, 46 | 'green': 32, 47 | 'yellow': 33, 48 | 'blue': 34, 49 | 'magenta': 35, 50 | 'cyan': 36, 51 | 'white': 39, # 39 is the ANSI "default foreground" code 52 | } 53 | ccode = color_codes.get(color, 30) 54 | print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True) 55 | 56 | def colorprint_by_t0(t0, s): 57 | if t0 > 0.95: 58 | color = 'white' 59 | elif t0 > 0.70: 60 | color = 'green' 61 | elif t0 > 0.30: 62 | color = 'yellow' 63 | else: 64 | color = 'red' 65 | colorprint(color, s) 66 | 67 | def main(args): 68 | mx.random.seed(args.seed) 69 | model, tokenizer = load(args.model) 70 | print("=" * 10) 71 | print("Prompt:", args.prompt) 72 | prompt = tokenizer.encode(args.prompt) 73 | prompt = mx.array(prompt) 74 | tic = time.time() 75 | tokens = [] 76 | skip = 0 77 | for token, n in zip( 78 | generate_step(prompt, model, args.sampling_cutoff), range(args.max_tokens) 79 | ): 80 | t0, token = token # generate_step yields (T0 probability, token) pairs 81 | if token == tokenizer.eos_token_id: 82 | break 83 | if n == 0: 84 | prompt_time = time.time() - tic 85 | tic = time.time() 86 | tokens.append(token.item()) 87 | s = tokenizer.decode(tokens) 88 | 89 | if args.colorize: 90 | colorprint_by_t0(t0, s[skip:]) 91 | else: 92 | print(s[skip:], end="", flush=True) 93 | 94 | skip = len(s) 95 | 96 | print(tokenizer.decode(tokens)[skip:], flush=True) 97 | gen_time = time.time() - tic 98 | print("=" * 10) 99 | if len(tokens) == 0: 100 | print("No tokens generated for this prompt") 101 | return 102 | prompt_tps = prompt.size / prompt_time 103 | gen_tps = (len(tokens) - 1) / gen_time 104 | print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") 105 | print(f"Generation: {gen_tps:.3f} tokens-per-sec")
106 | 107 | 108 | if __name__ == "__main__": 109 | parser = setup_arg_parser() 110 | args = parser.parse_args() 111 | main(args) 112 | -------------------------------------------------------------------------------- /mlx_lm/models/phi2.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import Tuple 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | from .base import BaseModelArgs 9 | 10 | 11 | @dataclass 12 | class ModelArgs(BaseModelArgs): 13 | n_positions: int = 2048 14 | vocab_size: int = 51200 15 | n_embd: int = 2560 16 | n_head: int = 32 17 | n_layer: int = 32 18 | rotary_dim: int = 32 19 | 20 | 21 | class LayerNorm(nn.LayerNorm): 22 | def __call__(self, x: mx.array) -> mx.array: 23 | return super().__call__(x.astype(mx.float32)).astype(x.dtype) 24 | 25 | 26 | class RoPEAttention(nn.Module): 27 | def __init__(self, dims: int, n_head: int, rotary_dim: int): 28 | super().__init__() 29 | 30 | self.n_head = n_head 31 | 32 | self.q_proj = nn.Linear(dims, dims) 33 | self.k_proj = nn.Linear(dims, dims) 34 | self.v_proj = nn.Linear(dims, dims) 35 | self.dense = nn.Linear(dims, dims) 36 | 37 | self.rope = nn.RoPE(rotary_dim, traditional=False) 38 | 39 | def __call__(self, x, mask=None, cache=None): 40 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 41 | 42 | # Extract some shapes 43 | n_head = self.n_head 44 | B, L, D = queries.shape 45 | 46 | # Prepare the queries, keys and values for the attention computation 47 | queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) 48 | keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) 49 | values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) 50 | 51 | # Add RoPE to the queries and keys and combine them with the cache 52 | if cache is not None: 53 | key_cache, value_cache = cache 54 | queries = self.rope(queries, offset=key_cache.shape[2]) 55 | keys = self.rope(keys, offset=key_cache.shape[2]) 56 | keys = mx.concatenate([key_cache, keys], axis=2) 57 | values = mx.concatenate([value_cache, values], axis=2) 58 | else: 59 | queries = self.rope(queries) 60 | keys = self.rope(keys) 61 | 62 | queries = queries.astype(mx.float32) 63 | keys = keys.astype(mx.float32) 64 | 65 | # Finally perform the attention computation 66 | scale = math.sqrt(1 / queries.shape[-1]) 67 | scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) 68 | if mask is not None: 69 | scores = scores + mask 70 | 71 | scores = mx.softmax(scores, axis=-1).astype(values.dtype) 72 | values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 73 | 74 | return self.dense(values_hat), (keys, values) 75 | 76 | 77 | class MLP(nn.Module): 78 | def __init__(self, dim, hidden_dim): 79 | super().__init__() 80 | self.fc1 = nn.Linear(dim, hidden_dim) 81 | self.fc2 = nn.Linear(hidden_dim, dim) 82 | self.act = nn.GELU(approx="precise") 83 | 84 | def __call__(self, x) -> mx.array: 85 | return self.fc2(self.act(self.fc1(x))) 86 | 87 | 88 | class ParallelBlock(nn.Module): 89 | def __init__(self, config: ModelArgs): 90 | super().__init__() 91 | dims = config.n_embd 92 | mlp_dims = dims * 4 93 | self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim) 94 | self.input_layernorm = LayerNorm(dims) 95 | self.mlp = MLP(dims, mlp_dims) 96 | 97 | def __call__(self, x, mask, cache): 98 | h = self.input_layernorm(x) 99 | attn_h, cache = self.self_attn(h, mask, cache) 100 | ff_h = self.mlp(h) 101 | return attn_h + ff_h + x, cache 
102 | 103 | 104 | class Transformer(nn.Module): 105 | def __init__(self, config: ModelArgs): 106 | super().__init__() 107 | self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd) 108 | self.layers = [ParallelBlock(config) for _ in range(config.n_layer)] 109 | self.final_layernorm = LayerNorm(config.n_embd) 110 | 111 | def __call__(self, x, mask, cache): 112 | x = self.embed_tokens(x) 113 | if cache is None: 114 | cache = [None] * len(self.layers) 115 | 116 | for e, layer in enumerate(self.layers): 117 | x, cache[e] = layer(x, mask, cache[e]) 118 | return self.final_layernorm(x), cache 119 | 120 | 121 | class Model(nn.Module): 122 | def __init__(self, config: ModelArgs): 123 | super().__init__() 124 | self.model = Transformer(config) 125 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size) 126 | 127 | def __call__( 128 | self, 129 | x: mx.array, 130 | mask: mx.array = None, 131 | cache: mx.array = None, 132 | ) -> Tuple[mx.array, mx.array]: 133 | mask = None 134 | if x.shape[1] > 1: 135 | mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) 136 | mask = mask.astype(x.dtype) 137 | 138 | y, cache = self.model(x, mask, cache) 139 | return self.lm_head(y), cache 140 | -------------------------------------------------------------------------------- /mlx_lm/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import logging 4 | from pathlib import Path 5 | from typing import Generator, Tuple 6 | 7 | import mlx.core as mx 8 | import mlx.nn as nn 9 | from huggingface_hub import snapshot_download 10 | from transformers import AutoTokenizer, PreTrainedTokenizer 11 | import numpy as np 12 | import random 13 | import time 14 | 15 | # Local imports 16 | from .models import llama, phi2 17 | from .models.base import BaseModelArgs 18 | 19 | # Constants 20 | MODEL_MAPPING = { 21 | "llama": llama, 22 | "mistral": llama, # mistral is compatible with llama 23 | "phi": phi2, 24 | } 25 | 26 | 27 | def _get_classes(config: dict): 28 | """ 29 | Retrieve the model and model args classes based on the configuration. 30 | 31 | Args: 32 | config (dict): The model configuration. 33 | 34 | Returns: 35 | A tuple containing the Model class and the ModelArgs class. 36 | """ 37 | model_type = config["model_type"] 38 | if model_type not in MODEL_MAPPING: 39 | msg = f"Model type {model_type} not supported." 40 | logging.error(msg) 41 | raise ValueError(msg) 42 | 43 | arch = MODEL_MAPPING[model_type] 44 | return arch.Model, arch.ModelArgs 45 | 46 | 47 | def get_model_path(path_or_hf_repo: str) -> Path: 48 | """ 49 | Ensures the model is available locally. If the path does not exist locally, 50 | it is downloaded from the Hugging Face Hub. 51 | 52 | Args: 53 | path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. 54 | 55 | Returns: 56 | Path: The path to the model. 57 | """ 58 | model_path = Path(path_or_hf_repo) 59 | if not model_path.exists(): 60 | model_path = Path( 61 | snapshot_download( 62 | repo_id=path_or_hf_repo, 63 | allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"], 64 | ) 65 | ) 66 | return model_path 67 | 68 | 69 | def generate_step( 70 | prompt: mx.array, model: nn.Module, cutoff: float = 0.7 71 | ) -> Generator[mx.array, None, None]: 72 | """ 73 | A generator producing text based on the given prompt from the model. 74 | 75 | Args: 76 | prompt (mx.array): The input prompt. 77 | model (nn.Module): The model to use for generation.
78 | cutoff (float): The first-token cutoff. With 0 the generation is deterministic; with 1 every token can be sampled. 79 | 80 | Yields: 81 | Generator[mx.array]: A generator producing one (t0, token) pair per call, where t0 is the probability of the best token. 82 | """ 83 | 84 | def sample(logits: mx.array) -> Tuple[float, mx.array]: 85 | random.seed(time.monotonic_ns()) 86 | mx.random.seed(time.monotonic_ns()) 87 | 88 | logits = mx.softmax(logits, axis=-1) 89 | np_logits = np.array(logits) # MX -> NumPy 90 | np_logits = np_logits.flatten() 91 | sorted_indices = np.argsort(np_logits) 92 | sorted_indices = sorted_indices[::-1] # Sort by decreasing probability 93 | 94 | j = 1 95 | t0 = np_logits[sorted_indices[0]] 96 | while j < len(np_logits) and 1 - (np_logits[sorted_indices[j]] / t0) < cutoff: 97 | j += 1 98 | accepted_logits = [] 99 | for i in range(0, j): 100 | accepted_logits.append(float(np_logits[sorted_indices[i]])) 101 | accepted_logits = mx.array(accepted_logits) 102 | idx = mx.random.categorical(mx.log(accepted_logits)) # categorical() expects (log-)scores, not probabilities 103 | idx = int(np.array(idx)) # We can't convert a zero-dim array without passing through numpy 104 | c = sorted_indices[idx] 105 | return (t0, mx.array(np.array([c]))) 106 | 107 | y = prompt 108 | cache = None 109 | while True: 110 | logits, cache = model(y[None], cache=cache) 111 | logits = logits[:, -1, :] 112 | t0, y = sample(logits) 113 | yield (t0, y) 114 | 115 | 116 | def generate( 117 | model: nn.Module, 118 | tokenizer: PreTrainedTokenizer, 119 | prompt: str, 120 | cutoff: float = 0.7, 121 | max_tokens: int = 100, 122 | verbose: bool = False, 123 | ) -> str: 124 | """ 125 | Generate text from the model. 126 | 127 | Args: 128 | model (nn.Module): The language model. 129 | tokenizer (PreTrainedTokenizer): The tokenizer. 130 | prompt (str): The string prompt. 131 | cutoff (float): The sampling cutoff (default 0.7). 132 | max_tokens (int): The maximum number of tokens (default 100). 133 | """ 134 | 135 | prompt = mx.array(tokenizer.encode(prompt)) 136 | 137 | tokens = [] 138 | skip = 0 139 | for (_, token), _ in zip(generate_step(prompt, model, cutoff), range(max_tokens)): 140 | if token == tokenizer.eos_token_id: 141 | break 142 | 143 | tokens.append(token.item()) 144 | 145 | if verbose: 146 | s = tokenizer.decode(tokens) 147 | print(s[skip:], end="", flush=True) 148 | skip = len(s) 149 | 150 | tokens = tokenizer.decode(tokens)[skip:] 151 | if verbose: 152 | print(tokens, flush=True) 153 | return tokens 154 | 155 | 156 | def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]: 157 | """ 158 | Load the model from a given path or a huggingface repository. 159 | 160 | Args: 161 | path_or_hf_repo (str): The path or the huggingface repository to load the model from. 162 | 163 | Returns: 164 | Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer. 165 | 166 | Raises: 167 | FileNotFoundError: If config file or safetensors are not found. 168 | ValueError: If model class or args class are not found.
169 | """ 170 | model_path = get_model_path(path_or_hf_repo) 171 | 172 | try: 173 | with open(model_path / "config.json", "r") as f: 174 | config = json.load(f) 175 | quantization = config.get("quantization", None) 176 | except FileNotFoundError: 177 | logging.error(f"Config file not found in {model_path}") 178 | raise 179 | weight_files = glob.glob(str(model_path / "*.safetensors")) 180 | if not weight_files: 181 | logging.error(f"No safetensors found in {model_path}") 182 | raise FileNotFoundError(f"No safetensors found in {model_path}") 183 | weights = {} 184 | for wf in weight_files: 185 | weights.update(mx.load(wf)) 186 | 187 | model_class, model_args_class = _get_classes(config=config) 188 | 189 | model_args = model_args_class.from_dict(config) 190 | model = model_class(model_args) 191 | 192 | if quantization is not None: 193 | nn.QuantizedLinear.quantize_module(model, **quantization) 194 | 195 | model.load_weights(list(weights.items())) 196 | 197 | mx.eval(model.parameters()) 198 | tokenizer = AutoTokenizer.from_pretrained(model_path) 199 | return model, tokenizer 200 | -------------------------------------------------------------------------------- /mlx_lm/convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import glob 4 | import json 5 | from pathlib import Path 6 | from typing import Dict, Tuple 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import transformers 11 | from mlx.utils import tree_flatten 12 | 13 | from .utils import get_model_path, load 14 | 15 | MAX_FILE_SIZE_GB = 15 16 | 17 | 18 | def configure_parser() -> argparse.ArgumentParser: 19 | """ 20 | Configures and returns the argument parser for the script. 21 | 22 | Returns: 23 | argparse.ArgumentParser: Configured argument parser. 24 | """ 25 | parser = argparse.ArgumentParser( 26 | description="Convert Hugging Face model to MLX format" 27 | ) 28 | 29 | parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.") 30 | parser.add_argument( 31 | "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model." 
32 | ) 33 | parser.add_argument( 34 | "-q", "--quantize", help="Generate a quantized model.", action="store_true" 35 | ) 36 | parser.add_argument( 37 | "--q-group-size", help="Group size for quantization.", type=int, default=64 38 | ) 39 | parser.add_argument( 40 | "--q-bits", help="Bits per weight for quantization.", type=int, default=4 41 | ) 42 | parser.add_argument( 43 | "--dtype", 44 | help="Type to save the parameters, ignored if -q is given.", 45 | type=str, 46 | choices=["float16", "bfloat16", "float32"], 47 | default="float16", 48 | ) 49 | parser.add_argument( 50 | "--upload-repo", 51 | help="The Hugging Face repo to upload the model to.", 52 | type=str, 53 | default=None, 54 | ) 55 | return parser 56 | 57 | 58 | def fetch_from_hub( 59 | model_path: str, 60 | ) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]: 61 | model_path = get_model_path(model_path) 62 | 63 | weight_files = glob.glob(f"{model_path}/*.safetensors") 64 | if not weight_files: 65 | raise FileNotFoundError(f"No safetensors found in {model_path}") 66 | 67 | weights = {} 68 | for wf in weight_files: 69 | weights.update(mx.load(wf).items()) 70 | 71 | config = transformers.AutoConfig.from_pretrained(model_path) 72 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) 73 | 74 | return weights, config.to_dict(), tokenizer 75 | 76 | 77 | def quantize_model( 78 | weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int 79 | ) -> tuple: 80 | """ 81 | Applies quantization to the model weights. 82 | 83 | Args: 84 | weights (dict): Model weights. 85 | config (dict): Model configuration. 86 | hf_path (str): HF model path.. 87 | q_group_size (int): Group size for quantization. 88 | q_bits (int): Bits per weight for quantization. 89 | 90 | Returns: 91 | tuple: Tuple containing quantized weights and config. 92 | """ 93 | quantized_config = copy.deepcopy(config) 94 | model, _ = load(hf_path) 95 | model.load_weights(list(weights.items())) 96 | 97 | nn.QuantizedLinear.quantize_module(model, q_group_size, q_bits) 98 | quantized_config["quantization"] = { 99 | "group_size": q_group_size, 100 | "bits": q_bits, 101 | } 102 | quantized_weights = dict(tree_flatten(model.parameters())) 103 | 104 | return quantized_weights, quantized_config 105 | 106 | 107 | def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list: 108 | """ 109 | Splits the weights into smaller shards. 110 | 111 | Args: 112 | weights (dict): Model weights. 113 | max_file_size_gb (int): Maximum size of each shard in gigabytes. 114 | 115 | Returns: 116 | list: List of weight shards. 117 | """ 118 | max_file_size_bytes = max_file_size_gb << 30 119 | shards = [] 120 | shard, shard_size = {}, 0 121 | for k, v in weights.items(): 122 | estimated_size = v.size * v.dtype.size 123 | if shard_size + estimated_size > max_file_size_bytes: 124 | shards.append(shard) 125 | shard, shard_size = {}, 0 126 | shard[k] = v 127 | shard_size += estimated_size 128 | shards.append(shard) 129 | return shards 130 | 131 | 132 | def upload_to_hub(path: str, upload_repo: str, hf_path: str): 133 | """ 134 | Uploads the model to Hugging Face hub. 135 | 136 | Args: 137 | path (str): Local path to the model. 138 | upload_repo (str): Name of the HF repo to upload to. 139 | hf_path (str): Path to the original Hugging Face model. 
140 | """ 141 | import os 142 | 143 | from huggingface_hub import HfApi, ModelCard, logging 144 | 145 | card = ModelCard.load(hf_path) 146 | card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] 147 | card.text = f""" 148 | # {upload_repo} 149 | This model was converted to MLX format from [`{hf_path}`](). 150 | Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model. 151 | ## Use with mlx 152 | 153 | ```bash 154 | pip install mlx-lm 155 | ``` 156 | 157 | ```python 158 | from mlx_lm import load, generate 159 | 160 | model, tokenizer = load("{upload_repo}") 161 | response = generate(model, tokenizer, prompt="hello", verbose=True) 162 | ``` 163 | """ 164 | card.save(os.path.join(path, "README.md")) 165 | 166 | logging.set_verbosity_info() 167 | 168 | api = HfApi() 169 | api.create_repo(repo_id=upload_repo, exist_ok=True) 170 | api.upload_folder( 171 | folder_path=path, 172 | repo_id=upload_repo, 173 | repo_type="model", 174 | ) 175 | 176 | 177 | def convert( 178 | hf_path: str, 179 | mlx_path: str = "mlx_model", 180 | quantize: bool = False, 181 | q_group_size: int = 64, 182 | q_bits: int = 4, 183 | dtype: str = "float16", 184 | upload_repo: str = None, 185 | ): 186 | print("[INFO] Loading") 187 | weights, config, tokenizer = fetch_from_hub(hf_path) 188 | dtype = mx.float16 if quantize else getattr(mx, dtype) 189 | weights = {k: v.astype(dtype) for k, v in weights.items()} 190 | if quantize: 191 | print("[INFO] Quantizing") 192 | weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits) 193 | 194 | mlx_path = Path(mlx_path) 195 | mlx_path.mkdir(parents=True, exist_ok=True) 196 | shards = make_shards(weights) 197 | for i, shard in enumerate(shards): 198 | mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard) 199 | tokenizer.save_pretrained(mlx_path) 200 | with open(mlx_path / "config.json", "w") as fid: 201 | json.dump(config, fid, indent=4) 202 | 203 | if upload_repo is not None: 204 | upload_to_hub(mlx_path, upload_repo, hf_path) 205 | 206 | 207 | if __name__ == "__main__": 208 | parser = configure_parser() 209 | args = parser.parse_args() 210 | convert(**vars(args)) 211 | -------------------------------------------------------------------------------- /mlx_lm/models/llama.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | from .base import BaseModelArgs 8 | 9 | 10 | @dataclass 11 | class ModelArgs(BaseModelArgs): 12 | hidden_size: int 13 | num_hidden_layers: int 14 | intermediate_size: int 15 | num_attention_heads: int 16 | rms_norm_eps: float 17 | vocab_size: int 18 | num_key_value_heads: int = None 19 | rope_theta: float = 10000 20 | rope_traditional: bool = False 21 | model_type: str = None 22 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 23 | 24 | def __post_init__(self): 25 | if self.num_key_value_heads is None: 26 | self.num_key_value_heads = self.num_attention_heads 27 | 28 | if self.rope_scaling: 29 | required_keys = {"factor", "type"} 30 | if not all(key in self.rope_scaling for key in required_keys): 31 | raise ValueError(f"rope_scaling must contain keys {required_keys}") 32 | 33 | if self.rope_scaling["type"] != "linear": 34 | raise ValueError("rope_scaling 'type' currently only supports 'linear'") 35 | 36 | 37 | class RMSNorm(nn.Module): 38 | def __init__(self, dims: 
int, eps: float = 1e-5): 39 | super().__init__() 40 | self.weight = mx.ones((dims,)) 41 | self.eps = eps 42 | 43 | def _norm(self, x): 44 | return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) 45 | 46 | def __call__(self, x): 47 | output = self._norm(x.astype(mx.float32)).astype(x.dtype) 48 | return self.weight * output 49 | 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, args: ModelArgs): 53 | super().__init__() 54 | 55 | dim = args.hidden_size 56 | self.n_heads = n_heads = args.num_attention_heads 57 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 58 | 59 | self.repeats = n_heads // n_kv_heads 60 | 61 | head_dim = args.hidden_size // n_heads 62 | self.scale = head_dim**-0.5 63 | 64 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 65 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 66 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 67 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 68 | 69 | rope_scale = ( 70 | 1 / args.rope_scaling["factor"] 71 | if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" 72 | else 1 73 | ) 74 | self.rope = nn.RoPE( 75 | head_dim, 76 | traditional=args.rope_traditional, 77 | base=args.rope_theta, 78 | scale=rope_scale, 79 | ) 80 | 81 | def __call__( 82 | self, 83 | x: mx.array, 84 | mask: Optional[mx.array] = None, 85 | cache: Optional[Tuple[mx.array, mx.array]] = None, 86 | ) -> mx.array: 87 | B, L, D = x.shape 88 | 89 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 90 | 91 | # Prepare the queries, keys and values for the attention computation 92 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 93 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 94 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 95 | 96 | def repeat(a): 97 | a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) 98 | return a.reshape([B, self.n_heads, L, -1]) 99 | 100 | if self.repeats > 1: 101 | keys, values = map(repeat, (keys, values)) 102 | 103 | if cache is not None: 104 | key_cache, value_cache = cache 105 | queries = self.rope(queries, offset=key_cache.shape[2]) 106 | keys = self.rope(keys, offset=key_cache.shape[2]) 107 | keys = mx.concatenate([key_cache, keys], axis=2) 108 | values = mx.concatenate([value_cache, values], axis=2) 109 | else: 110 | queries = self.rope(queries) 111 | keys = self.rope(keys) 112 | 113 | scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) 114 | if mask is not None: 115 | scores += mask 116 | scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) 117 | output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 118 | return self.o_proj(output), (keys, values) 119 | 120 | 121 | class MLP(nn.Module): 122 | def __init__(self, dim, hidden_dim): 123 | super().__init__() 124 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 125 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 126 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 127 | 128 | def __call__(self, x) -> mx.array: 129 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 130 | 131 | 132 | class TransformerBlock(nn.Module): 133 | def __init__(self, args: ModelArgs): 134 | super().__init__() 135 | self.num_attention_heads = args.num_attention_heads 136 | self.hidden_size = args.hidden_size 137 | self.self_attn = Attention(args) 138 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 139 | 
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 140 | self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 141 | self.args = args 142 | 143 | def __call__( 144 | self, 145 | x: mx.array, 146 | mask: Optional[mx.array] = None, 147 | cache: Optional[Tuple[mx.array, mx.array]] = None, 148 | ) -> mx.array: 149 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 150 | h = x + r 151 | r = self.mlp(self.post_attention_layernorm(h)) 152 | out = h + r 153 | return out, cache 154 | 155 | 156 | class LlamaModel(nn.Module): 157 | def __init__(self, args: ModelArgs): 158 | super().__init__() 159 | self.args = args 160 | self.vocab_size = args.vocab_size 161 | self.num_hidden_layers = args.num_hidden_layers 162 | assert self.vocab_size > 0 163 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 164 | self.layers = [ 165 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 166 | ] 167 | self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 168 | 169 | def __call__( 170 | self, 171 | inputs: mx.array, 172 | cache=None, 173 | ): 174 | h = self.embed_tokens(inputs) 175 | 176 | mask = None 177 | if h.shape[1] > 1: 178 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 179 | mask = mask.astype(h.dtype) 180 | 181 | if cache is None: 182 | cache = [None] * len(self.layers) 183 | 184 | for e, layer in enumerate(self.layers): 185 | h, cache[e] = layer(h, mask, cache[e]) 186 | 187 | return self.norm(h), cache 188 | 189 | 190 | class Model(nn.Module): 191 | def __init__(self, args: ModelArgs): 192 | super().__init__() 193 | self.model = LlamaModel(args) 194 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 195 | 196 | def __call__( 197 | self, 198 | inputs: mx.array, 199 | cache=None, 200 | ): 201 | out, cache = self.model(inputs, cache) 202 | return self.lm_head(out), cache 203 | --------------------------------------------------------------------------------