├── models ├── __init__.py ├── base.py ├── lora.py ├── phi2.py └── llama.py ├── .gitignore ├── requirements.txt ├── data ├── test.jsonl ├── valid.jsonl └── train.jsonl ├── README.md ├── fuse.py ├── convert.py ├── whatsapp.py ├── utils.py ├── lora.py └── models.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | chat.txt 2 | __pycache__ 3 | mlx_model 4 | *.npz 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.0.9 2 | sentencepiece 3 | torch 4 | numpy 5 | transformers -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class BaseModelArgs: 7 | @classmethod 8 | def from_dict(cls, params): 9 | return cls( 10 | **{ 11 | k: v 12 | for k, v in params.items() 13 | if k in inspect.signature(cls).parameters 14 | } 15 | ) 16 | -------------------------------------------------------------------------------- /data/test.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "Timon: Pumbaa, what do you think stars are?\nPumbaa: They're giant fireflies, Timon!"} 2 | {"text": "Woody: Buzz, we need to find Bo Peep!\nBuzz Lightyear: To infinity and beyond, Woody!"} 3 | {"text": "Alice: Cheshire Cat, which way should I go?\nCheshire Cat: That depends on where you want to end up."} 4 | {"text": "Mowgli: Baloo, what's the best part of the jungle?\nBaloo: The bare necessities, Mowgli!"} 5 | -------------------------------------------------------------------------------- /data/valid.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "Belle: Gaston, I will never marry you.\nGaston: You'll change your mind, Belle!"} 2 | {"text": "Jasmine: Aladdin, where will we go next on the carpet?\nAladdin: Anywhere you want, Jasmine!"} 3 | {"text": "Aurora: I keep dreaming of someone I've never met.\nFairy Godmother: Dreams can be real, Aurora."} 4 | {"text": "Lilo: Stitch, what's more important than 'ohana?\nStitch: Nothing, Lilo! 'Ohana means family."} 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mlx-whatsapp 2 | 3 | This is an experimental project that converts your WhatsApp chat backups into training data and fine-tunes Mistral on them using MLX. The `lora.py`, `models.py`, `models` directory and `convert.py` are taken from [https://github.com/ml-explore/mlx-examples](https://github.com/ml-explore/mlx-examples). 4 | 5 | ## How to back up your chats 6 | 7 | Go to WhatsApp -> Settings -> Export Chat -> Select group conversation -> Without Media 8 | 9 | 10 | ## Download Mistral and convert to a quantized version 11 | 12 | Install the dependencies: 13 | 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | Next, download and convert the model.
The following command will download Mistral from Hugging Face and convert it to a quantized version: 19 | 20 | ``` 21 | python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q 22 | ``` 23 | 24 | 25 | ## Converting the files 26 | 27 | Save the chat exported from WhatsApp as `chat.txt`, then create the training files with the command below: 28 | 29 | ```bash 30 | python whatsapp.py --input_file chat.txt --output_file chat.jsonl --test_file data/test.jsonl --train_file data/train.jsonl --valid_file data/valid.jsonl 31 | ``` 32 | 33 | By default the test and validation files each take 30 samples; adjust the `test_size` and `validate_size` values in `whatsapp.py` if needed. 34 | 35 | 36 | ## Training 37 | 38 | ```bash 39 | python lora.py --model mlx_model --train --iters 600 --data ./data --batch-size 2 --adapter-file whatsapp.npz 40 | ``` 41 | 42 | ## Inference 43 | 44 | ```bash 45 | python lora.py --model ./mlx_model \ 46 | --adapter-file ./whatsapp.npz \ 47 | --max-tokens 500 \ 48 | --prompt \ 49 | "Mickey Mouse: Hey Minnie, are we going to the fair? 50 | Minnie: " 51 | ``` 52 | 53 | ## Combine your adapter and model together 54 | 55 | ```bash 56 | python fuse.py --model mlx_model --adapter-file whatsapp.npz --save-path fused 57 | ``` 58 | 59 | The `fused` folder now contains `safetensors` weights that can be used directly with `transformers`. 60 | 61 | ## Warning 62 | 63 | A word of caution: don't upload your fused models to public sites such as Hugging Face, as the model can leak the personal data it was trained on. -------------------------------------------------------------------------------- /data/train.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "Mickey Mouse: Hey Minnie, are we still on for the picnic today?\nMinnie Mouse: Yes, I'm looking forward to it, Mickey!\nDonald Duck: Can Daisy and I join you guys?"} 2 | {"text": "Goofy: Hiya, Pluto! Wanna go to the park?\nPluto: Woof woof! (Yes, let's go!)\nGoofy: Awesome, I'll bring the frisbee."} 3 | {"text": "Cinderella: Have you seen my glass slipper?\nFairy Godmother: Let me use my magic wand. Bibbidi-Bobbidi-Boo!\nCinderella: Oh, there it is! Thank you!"} 4 | {"text": "Aladdin: Ready for a magic carpet ride, Jasmine?\nJasmine: Absolutely, Aladdin. Let's explore the world!"} 5 | {"text": "Elsa: Do you want to build a snowman, Anna?\nAnna: Of course, Elsa! Let's make the biggest snowman in Arendelle."} 6 | {"text": "Buzz Lightyear: Woody, are you ready for a new adventure?\nWoody: Always ready, Buzz! To infinity and beyond!"} 7 | {"text": "Ariel: I wonder what this thingamabob is for?\nSebastian: Ariel, you really should stay away from human stuff!"} 8 | {"text": "Simba: Nala, let's go to the Elephant Graveyard.\nNala: That sounds dangerous, Simba. Are you sure?"} 9 | {"text": "Belle: I've never seen so many books!\nBeast: You like it? It's yours!"} 10 | {"text": "Mulan: Mushu, do you think I can pass as a soldier?\nMushu: Of course, Mulan! 
You're as tough as the best of them."} 11 | {"text": "Pocahontas: The wind is speaking to us, Meeko.\nMeeko: (Chitters excitedly)"} 12 | {"text": "Tiana: I'm dreaming of opening my own restaurant someday.\nLouis: And I'll play the best jazz there!"} 13 | {"text": "Rapunzel: I just can't believe I'm finally seeing the floating lights!\nFlynn Rider: It's all for you, Rapunzel."} 14 | {"text": "Merida: I want to change my fate!\nQueen Elinor: Merida, we must be careful with such wishes."} 15 | {"text": "Hercules: Meg, have you ever heard of the Hydra?\nMegara: Oh, Hercules, isn't that a dangerous monster?"} 16 | {"text": "Snow White: Oh, what a cute little cottage!\nGrumpy: Well, it's not much, but it's home."} 17 | {"text": "Peter Pan: Let's fly to Neverland!\nWendy: Oh, Peter, that sounds like an adventure!"} 18 | {"text": "Alice: This place is curiouser and curiouser.\nThe Mad Hatter: Welcome to Wonderland, Alice!"} 19 | {"text": "Moana: I'm destined to cross the sea.\nMaui: You? Cross the sea? Ha! It's not that easy."} 20 | {"text": "Aurora: I've been dreaming of a true love's kiss.\nPrince Phillip: Maybe dreams do come true."} 21 | -------------------------------------------------------------------------------- /fuse.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import argparse 4 | from pathlib import Path 5 | 6 | import mlx.core as mx 7 | import utils 8 | from mlx.utils import tree_flatten, tree_unflatten 9 | from models.lora import LoRALinear 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") 13 | parser.add_argument( 14 | "--model", 15 | default="mlx_model", 16 | help="The path to the local model directory or Hugging Face repo.", 17 | ) 18 | parser.add_argument( 19 | "--save-path", 20 | default="lora_fused_model", 21 | help="The path to save the fused model.", 22 | ) 23 | parser.add_argument( 24 | "--adapter-file", 25 | type=str, 26 | default="adapters.npz", 27 | help="Path to the trained adapter weights (npz or safetensors).", 28 | ) 29 | parser.add_argument( 30 | "--hf-path", 31 | help=( 32 | "Path to the original Hugging Face model. This is " 33 | "required for upload if --model is a local directory." 
34 | ), 35 | type=str, 36 | default=None, 37 | ) 38 | parser.add_argument( 39 | "--upload-name", 40 | help="The name of model to upload to Hugging Face MLX Community", 41 | type=str, 42 | default=None, 43 | ) 44 | 45 | print("Loading pretrained model") 46 | args = parser.parse_args() 47 | 48 | model, tokenizer, config = utils.load(args.model) 49 | 50 | # Load adapters and get number of LoRA layers 51 | adapters = list(mx.load(args.adapter_file).items()) 52 | lora_layers = len([m for m in adapters if "q_proj.lora_a" in m[0]]) 53 | 54 | # Freeze all layers other than LORA linears 55 | model.freeze() 56 | for l in model.model.layers[-lora_layers:]: 57 | l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) 58 | l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) 59 | 60 | model.update(tree_unflatten(adapters)) 61 | fused_linears = [ 62 | (n, m.to_linear()) 63 | for n, m in model.named_modules() 64 | if isinstance(m, LoRALinear) 65 | ] 66 | 67 | model.update_modules(tree_unflatten(fused_linears)) 68 | weights = dict(tree_flatten(model.parameters())) 69 | utils.save_model(args.save_path, weights, tokenizer, config) 70 | 71 | if args.upload_name is not None: 72 | hf_path = args.hf_path 73 | if not Path(args.model).exists(): 74 | # If the model path doesn't exist, assume it's an HF repo 75 | hf_path = args.model 76 | elif hf_path is None: 77 | raise ValueError( 78 | "Must provide original Hugging Face repo to upload local model." 79 | ) 80 | utils.upload_to_hub(args.save_path, args.upload_name, hf_path) 81 | -------------------------------------------------------------------------------- /models/lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | 6 | 7 | class LoRALinear(nn.Module): 8 | @staticmethod 9 | def from_linear(linear: nn.Linear, rank: int = 8): 10 | # TODO remove when input_dims and output_dims are attributes 11 | # on linear and quantized linear 12 | output_dims, input_dims = linear.weight.shape 13 | if isinstance(linear, nn.QuantizedLinear): 14 | input_dims *= 32 // linear.bits 15 | lora_lin = LoRALinear(input_dims, output_dims, rank) 16 | lora_lin.linear = linear 17 | return lora_lin 18 | 19 | def to_linear(self): 20 | linear = self.linear 21 | bias = "bias" in linear 22 | weight = linear.weight 23 | is_quantized = isinstance(linear, nn.QuantizedLinear) 24 | 25 | # Use the same type as the linear weight if not quantized 26 | dtype = weight.dtype 27 | 28 | if is_quantized: 29 | dtype = mx.float16 30 | weight = mx.dequantize( 31 | weight, 32 | linear.scales, 33 | linear.biases, 34 | linear.group_size, 35 | linear.bits, 36 | ) 37 | output_dims, input_dims = weight.shape 38 | fused_linear = nn.Linear(input_dims, output_dims, bias=bias) 39 | 40 | lora_b = (self.scale * self.lora_b.T).astype(dtype) 41 | lora_a = self.lora_a.T.astype(dtype) 42 | fused_linear.weight = weight + lora_b @ lora_a 43 | if bias: 44 | fused_linear.bias = linear.bias 45 | 46 | if is_quantized: 47 | fused_linear = nn.QuantizedLinear.from_linear( 48 | fused_linear, 49 | linear.group_size, 50 | linear.bits, 51 | ) 52 | 53 | return fused_linear 54 | 55 | def __init__( 56 | self, 57 | input_dims: int, 58 | output_dims: int, 59 | lora_rank: int = 8, 60 | bias: bool = False, 61 | scale: float = 20.0, 62 | ): 63 | super().__init__() 64 | 65 | # Regular linear layer weights 66 | self.linear = nn.Linear(input_dims, output_dims, bias=bias) 67 | 68 | # Scale for low-rank update 69 | self.scale = 
scale 70 | 71 | # Low rank lora weights 72 | scale = 1 / math.sqrt(input_dims) 73 | self.lora_a = mx.random.uniform( 74 | low=-scale, 75 | high=scale, 76 | shape=(input_dims, lora_rank), 77 | ) 78 | self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) 79 | 80 | def __call__(self, x): 81 | dtype = self.linear.weight.dtype 82 | if isinstance(self.linear, nn.QuantizedLinear): 83 | dtype = self.linear.scales.dtype 84 | y = self.linear(x.astype(dtype)) 85 | z = (x @ self.lora_a) @ self.lora_b 86 | return y + self.scale * z 87 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import argparse 4 | import copy 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | import utils 9 | from mlx.utils import tree_flatten 10 | 11 | 12 | def quantize(weights, config, args): 13 | quantized_config = copy.deepcopy(config) 14 | 15 | # Get model classes 16 | model_class, model_args_class = utils._get_classes(config=config) 17 | 18 | # Load the model: 19 | model = model_class(model_args_class.from_dict(config)) 20 | model.load_weights(list(weights.items())) 21 | 22 | # Quantize the model: 23 | nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) 24 | 25 | # Update the config: 26 | quantized_config["quantization"] = { 27 | "group_size": args.q_group_size, 28 | "bits": args.q_bits, 29 | } 30 | quantized_weights = dict(tree_flatten(model.parameters())) 31 | 32 | return quantized_weights, quantized_config 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser( 37 | description="Convert Hugging Face model to MLX format" 38 | ) 39 | parser.add_argument( 40 | "--hf-path", 41 | type=str, 42 | help="Path to the Hugging Face model.", 43 | ) 44 | parser.add_argument( 45 | "--mlx-path", 46 | type=str, 47 | default="mlx_model", 48 | help="Path to save the MLX model.", 49 | ) 50 | parser.add_argument( 51 | "-q", 52 | "--quantize", 53 | help="Generate a quantized model.", 54 | action="store_true", 55 | ) 56 | parser.add_argument( 57 | "--q-group-size", 58 | help="Group size for quantization.", 59 | type=int, 60 | default=64, 61 | ) 62 | parser.add_argument( 63 | "--q-bits", 64 | help="Bits per weight for quantization.", 65 | type=int, 66 | default=4, 67 | ) 68 | parser.add_argument( 69 | "--dtype", 70 | help="Type to save the parameters, ignored if -q is given.", 71 | type=str, 72 | choices=["float16", "bfloat16", "float32"], 73 | default="float16", 74 | ) 75 | parser.add_argument( 76 | "--upload-name", 77 | help="The name of model to upload to Hugging Face MLX Community", 78 | type=str, 79 | default=None, 80 | ) 81 | 82 | args = parser.parse_args() 83 | 84 | print("[INFO] Loading") 85 | weights, config, tokenizer = utils.fetch_from_hub(args.hf_path) 86 | 87 | dtype = mx.float16 if args.quantize else getattr(mx, args.dtype) 88 | weights = {k: v.astype(dtype) for k, v in weights.items()} 89 | if args.quantize: 90 | print("[INFO] Quantizing") 91 | weights, config = quantize(weights, config, args) 92 | 93 | utils.save_model(args.mlx_path, weights, tokenizer, config) 94 | if args.upload_name is not None: 95 | utils.upload_to_hub(args.mlx_path, args.upload_name, args.hf_path) 96 | -------------------------------------------------------------------------------- /whatsapp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import argparse 4 | 5 | def 
clean_and_format(input_file, output_file, max_split_len): 6 | # Function to check if the line is the start of a new conversation 7 | def is_start_of_conversation(line): 8 | return bool(re.match(r'\[\d{1,2}/\d{1,2}/\d{2}, \d{1,2}:\d{2}:\d{2}\s[APM]{2}\]', line)) 9 | 10 | # Function to strip the leading [timestamp] prefix from a message line 11 | def remove_timestamp_and_convert_newlines(line): 12 | if ']' in line: 13 | line = line.split(']', 1)[1] 14 | return line 15 | 16 | # Read the input file 17 | with open(input_file, 'r', encoding='utf-8') as file: 18 | lines = file.readlines() 19 | 20 | # Process lines 21 | chunks = [] 22 | current_chunk = '' 23 | for line in lines: 24 | line = remove_timestamp_and_convert_newlines(line)  # timestamps are already stripped here, so the check below effectively never matches; chunks are built up to max_split_len characters 25 | if is_start_of_conversation(line) and current_chunk: 26 | chunks.append(current_chunk.strip()) 27 | current_chunk = line 28 | elif len(current_chunk) + len(line) < max_split_len: 29 | current_chunk += line + ' ' 30 | else: 31 | chunks.append(current_chunk.strip()) 32 | current_chunk = line 33 | 34 | if current_chunk: 35 | chunks.append(current_chunk.strip()) 36 | 37 | # Write to a JSONL file 38 | with open(output_file, 'w', encoding='utf-8') as out_file: 39 | for chunk in chunks: 40 | json_record = json.dumps({"text": chunk}) 41 | out_file.write(json_record + '\n') 42 | 43 | def split_jsonl(input_file, test_file, validate_file, train_file, test_size=30, validate_size=30): 44 | with open(input_file, 'r') as infile: 45 | lines = infile.readlines() 46 | 47 | # Ensure there are enough lines to split as requested 48 | if len(lines) < test_size + validate_size: 49 | raise ValueError("Not enough data to split as requested.") 50 | 51 | # Split data 52 | test_data = lines[:test_size] 53 | validate_data = lines[test_size:test_size + validate_size] 54 | train_data = lines[test_size + validate_size:] 55 | 56 | # Write to files 57 | for data, file in zip([test_data, validate_data, train_data], [test_file, validate_file, train_file]): 58 | with open(file, 'w') as outfile: 59 | for line in data: 60 | outfile.write(line) 61 | 62 | # Usage 63 | 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser(description="Clean and format chat data into JSONL format.") 68 | parser.add_argument("--input_file", required=True, help="The input text file containing chat data.") 69 | parser.add_argument("--output_file", default="output_file.jsonl", help="The output JSONL file where the cleaned data will be stored. Default: output_file.jsonl") 70 | parser.add_argument("--max_split_len", type=int, default=2000, help="Maximum number of characters per training chunk; consecutive messages are merged until this length is reached. Default: 2000") 71 | parser.add_argument("--test_file", default="test.jsonl", help="The output JSONL file for test data. Default: test.jsonl") 72 | parser.add_argument("--valid_file", default="valid.jsonl", help="The output JSONL file for validation data. Default: valid.jsonl") 73 | parser.add_argument("--train_file", default="train.jsonl", help="The output JSONL file for training data. 
Default: train.jsonl") 74 | 75 | args = parser.parse_args() 76 | 77 | clean_and_format(args.input_file, args.output_file, args.max_split_len) 78 | split_jsonl(args.output_file, args.test_file, args.valid_file, args.train_file) -------------------------------------------------------------------------------- /models/phi2.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | from .base import BaseModelArgs 8 | 9 | 10 | @dataclass 11 | class ModelArgs(BaseModelArgs): 12 | n_positions: int = 2048 13 | vocab_size: int = 51200 14 | n_embd: int = 2560 15 | n_head: int = 32 16 | n_layer: int = 32 17 | rotary_dim: int = 32 18 | 19 | 20 | class LayerNorm(nn.LayerNorm): 21 | def __call__(self, x: mx.array) -> mx.array: 22 | return super().__call__(x.astype(mx.float32)).astype(x.dtype) 23 | 24 | 25 | class RoPEAttention(nn.Module): 26 | def __init__(self, dims: int, n_head: int, rotary_dim: int): 27 | super().__init__() 28 | 29 | self.n_head = n_head 30 | 31 | self.q_proj = nn.Linear(dims, dims) 32 | self.k_proj = nn.Linear(dims, dims) 33 | self.v_proj = nn.Linear(dims, dims) 34 | self.dense = nn.Linear(dims, dims) 35 | 36 | self.rope = nn.RoPE(rotary_dim, traditional=False) 37 | 38 | def __call__(self, x, mask=None, cache=None): 39 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 40 | 41 | # Extract some shapes 42 | n_head = self.n_head 43 | B, L, D = queries.shape 44 | 45 | # Prepare the queries, keys and values for the attention computation 46 | queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) 47 | keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) 48 | values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) 49 | 50 | # Add RoPE to the queries and keys and combine them with the cache 51 | if cache is not None: 52 | key_cache, value_cache = cache 53 | queries = self.rope(queries, offset=key_cache.shape[2]) 54 | keys = self.rope(keys, offset=key_cache.shape[2]) 55 | keys = mx.concatenate([key_cache, keys], axis=2) 56 | values = mx.concatenate([value_cache, values], axis=2) 57 | else: 58 | queries = self.rope(queries) 59 | keys = self.rope(keys) 60 | 61 | queries = queries.astype(mx.float32) 62 | keys = keys.astype(mx.float32) 63 | 64 | # Finally perform the attention computation 65 | scale = math.sqrt(1 / queries.shape[-1]) 66 | scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) 67 | if mask is not None: 68 | scores = scores + mask 69 | 70 | scores = mx.softmax(scores, axis=-1).astype(values.dtype) 71 | values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 72 | 73 | return self.dense(values_hat), (keys, values) 74 | 75 | 76 | class MLP(nn.Module): 77 | def __init__(self, dim, hidden_dim): 78 | super().__init__() 79 | self.fc1 = nn.Linear(dim, hidden_dim) 80 | self.fc2 = nn.Linear(hidden_dim, dim) 81 | self.act = nn.GELU(approx="precise") 82 | 83 | def __call__(self, x) -> mx.array: 84 | return self.fc2(self.act(self.fc1(x))) 85 | 86 | 87 | class ParallelBlock(nn.Module): 88 | def __init__(self, config: ModelArgs): 89 | super().__init__() 90 | dims = config.n_embd 91 | mlp_dims = dims * 4 92 | self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim) 93 | self.input_layernorm = LayerNorm(dims) 94 | self.mlp = MLP(dims, mlp_dims) 95 | 96 | def __call__(self, x, mask, cache): 97 | h = self.input_layernorm(x) 98 | attn_h, cache = self.self_attn(h, mask, cache) 99 | ff_h 
= self.mlp(h) 100 | return attn_h + ff_h + x, cache 101 | 102 | 103 | class Transformer(nn.Module): 104 | def __init__(self, config: ModelArgs): 105 | super().__init__() 106 | self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd) 107 | self.layers = [ParallelBlock(config) for i in range(config.n_layer)] 108 | self.final_layernorm = LayerNorm(config.n_embd) 109 | 110 | def __call__(self, x, mask, cache): 111 | x = self.embed_tokens(x) 112 | if cache is None: 113 | cache = [None] * len(self.layers) 114 | 115 | for e, layer in enumerate(self.layers): 116 | x, cache[e] = layer(x, mask, cache[e]) 117 | return self.final_layernorm(x), cache 118 | 119 | 120 | class Model(nn.Module): 121 | def __init__(self, config: ModelArgs): 122 | super().__init__() 123 | self.model = Transformer(config) 124 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size) 125 | 126 | def __call__( 127 | self, 128 | x: mx.array, 129 | mask: mx.array = None, 130 | cache: mx.array = None, 131 | ) -> tuple[mx.array, mx.array]: 132 | mask = None 133 | if x.shape[1] > 1: 134 | mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) 135 | mask = mask.astype(x.dtype) 136 | 137 | y, cache = self.model(x, mask, cache) 138 | return self.lm_head(y), cache 139 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import glob 4 | import json 5 | import logging 6 | from pathlib import Path 7 | from typing import Generator 8 | 9 | import mlx.core as mx 10 | import mlx.nn as nn 11 | import models.llama as llama 12 | import models.phi2 as phi2 13 | import transformers 14 | from huggingface_hub import snapshot_download 15 | 16 | # Constants 17 | MODEL_MAPPING = { 18 | "llama": llama, 19 | "mistral": llama, # mistral is compatible with llama 20 | "phi": phi2, 21 | } 22 | 23 | 24 | def _get_classes(config: dict): 25 | """ 26 | Retrieve the model and model args classes based on the configuration. 27 | 28 | Args: 29 | config (dict): The model configuration. 30 | 31 | Returns: 32 | A tuple containing the Model class and the ModelArgs class. 33 | """ 34 | model_type = config["model_type"] 35 | if model_type not in MODEL_MAPPING: 36 | msg = f"Model type {model_type} not supported." 
37 | logging.error(msg) 38 | raise ValueError(msg) 39 | 40 | arch = MODEL_MAPPING[model_type] 41 | return arch.Model, arch.ModelArgs 42 | 43 | 44 | def fetch_from_hub(hf_path: str): 45 | model_path = snapshot_download( 46 | repo_id=hf_path, 47 | allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], 48 | ) 49 | weight_files = glob.glob(f"{model_path}/*.safetensors") 50 | if len(weight_files) == 0: 51 | raise FileNotFoundError("No safetensors found in {}".format(model_path)) 52 | 53 | weights = {} 54 | for wf in weight_files: 55 | weights.update(mx.load(wf).items()) 56 | 57 | config = transformers.AutoConfig.from_pretrained(hf_path) 58 | tokenizer = transformers.AutoTokenizer.from_pretrained( 59 | hf_path, 60 | ) 61 | return weights, config.to_dict(), tokenizer 62 | 63 | 64 | def upload_to_hub(path: str, name: str, hf_path: str): 65 | import os 66 | 67 | from huggingface_hub import HfApi, ModelCard, logging 68 | 69 | repo_id = f"mlx-community/{name}" 70 | 71 | card = ModelCard.load(hf_path) 72 | card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] 73 | card.text = f""" 74 | # {name} 75 | This model was converted to MLX format from [`{hf_path}`](). 76 | Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model. 77 | ## Use with mlx 78 | ```bash 79 | pip install mlx 80 | git clone https://github.com/ml-explore/mlx-examples.git 81 | cd mlx-examples/llms/hf_llm 82 | python generate.py --model {repo_id} --prompt "My name is" 83 | ``` 84 | """ 85 | card.save(os.path.join(path, "README.md")) 86 | 87 | logging.set_verbosity_info() 88 | 89 | api = HfApi() 90 | api.create_repo(repo_id=repo_id, exist_ok=True) 91 | api.upload_folder( 92 | folder_path=path, 93 | repo_id=repo_id, 94 | repo_type="model", 95 | ) 96 | 97 | 98 | def make_shards(weights: dict, max_file_size_gibibyte: int = 15): 99 | max_file_size_bytes = max_file_size_gibibyte << 30 100 | shards = [] 101 | shard, shard_size = {}, 0 102 | for k, v in weights.items(): 103 | estimated_size = v.size * v.dtype.size 104 | if shard_size + estimated_size > max_file_size_bytes: 105 | shards.append(shard) 106 | shard, shard_size = {}, 0 107 | shard[k] = v 108 | shard_size += estimated_size 109 | shards.append(shard) 110 | return shards 111 | 112 | 113 | def save_model(save_dir: str, weights, tokenizer, config): 114 | save_dir = Path(save_dir) 115 | save_dir.mkdir(parents=True, exist_ok=True) 116 | shards = make_shards(weights) 117 | for i, shard in enumerate(shards): 118 | # TODO use HF file name scheme for simplicity 119 | mx.save_safetensors(str(save_dir / f"weights.{i:02d}.safetensors"), shard) 120 | tokenizer.save_pretrained(save_dir) 121 | with open(save_dir / "config.json", "w") as fid: 122 | json.dump(config, fid, indent=4) 123 | 124 | 125 | def load(path_or_hf_repo: str): 126 | # If the path exists, it will try to load model form it 127 | # otherwise download and cache from the hf_repo and cache 128 | model_path = Path(path_or_hf_repo) 129 | if not model_path.exists(): 130 | model_path = Path( 131 | snapshot_download( 132 | repo_id=path_or_hf_repo, 133 | allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], 134 | ) 135 | ) 136 | 137 | with open(model_path / "config.json", "r") as f: 138 | config = json.loads(f.read()) 139 | quantization = config.get("quantization", None) 140 | 141 | weight_files = glob.glob(str(model_path / "*.safetensors")) 142 | if len(weight_files) == 0: 143 | raise FileNotFoundError("No safetensors found in {}".format(model_path)) 144 | 
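# Load every *.safetensors shard found in the model directory and merge them into a single flat parameter dict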
145 | weights = {} 146 | for wf in weight_files: 147 | weights.update(mx.load(wf).items()) 148 | 149 | model_class, model_args_class = _get_classes(config=config) 150 | model_args = model_args_class.from_dict(config) 151 | model = model_class(model_args) 152 | if quantization is not None: 153 | nn.QuantizedLinear.quantize_module(model, **quantization) 154 | 155 | model.load_weights(list(weights.items())) 156 | 157 | mx.eval(model.parameters()) 158 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) 159 | return model, tokenizer, config 160 | 161 | 162 | def generate( 163 | prompt: mx.array, model: nn.Module, temp: float = 0.0 164 | ) -> Generator[mx.array, None, None]: 165 | """ 166 | Generate text based on the given prompt and model. 167 | 168 | Args: 169 | prompt (mx.array): The input prompt. 170 | model (nn.Module): The model to use for generation. 171 | temp (float): The temperature for sampling. If temp is 0, use max sampling. 172 | 173 | Yields: 174 | mx.array: The generated text. 175 | """ 176 | 177 | def sample(logits: mx.array) -> mx.array: 178 | return ( 179 | mx.argmax(logits, axis=-1) 180 | if temp == 0 181 | else mx.random.categorical(logits * (1 / temp)) 182 | ) 183 | 184 | y = prompt 185 | cache = None 186 | while True: 187 | logits, cache = model(y[None], cache=cache) 188 | logits = logits[:, -1, :] 189 | y = sample(logits) 190 | yield y 191 | -------------------------------------------------------------------------------- /models/llama.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | from .base import BaseModelArgs 8 | 9 | 10 | @dataclass 11 | class ModelArgs(BaseModelArgs): 12 | hidden_size: int 13 | num_hidden_layers: int 14 | intermediate_size: int 15 | num_attention_heads: int 16 | rms_norm_eps: float 17 | vocab_size: int 18 | num_key_value_heads: int = None 19 | rope_theta: float = 10000 20 | rope_traditional: bool = False 21 | model_type: str = None 22 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 23 | 24 | def __post_init__(self): 25 | if self.num_key_value_heads is None: 26 | self.num_key_value_heads = self.num_attention_heads 27 | 28 | if self.rope_scaling: 29 | required_keys = {"factor", "type"} 30 | if not all(key in self.rope_scaling for key in required_keys): 31 | raise ValueError(f"rope_scaling must contain keys {required_keys}") 32 | 33 | if self.rope_scaling["type"] != "linear": 34 | raise ValueError("rope_scaling 'type' currently only supports 'linear'") 35 | 36 | 37 | class RMSNorm(nn.Module): 38 | def __init__(self, dims: int, eps: float = 1e-5): 39 | super().__init__() 40 | self.weight = mx.ones((dims,)) 41 | self.eps = eps 42 | 43 | def _norm(self, x): 44 | return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) 45 | 46 | def __call__(self, x): 47 | output = self._norm(x.astype(mx.float32)).astype(x.dtype) 48 | return self.weight * output 49 | 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, args: ModelArgs): 53 | super().__init__() 54 | 55 | dim = args.hidden_size 56 | self.n_heads = n_heads = args.num_attention_heads 57 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 58 | 59 | self.repeats = n_heads // n_kv_heads 60 | 61 | head_dim = args.hidden_size // n_heads 62 | self.scale = head_dim**-0.5 63 | 64 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 65 | self.k_proj = nn.Linear(dim, 
n_kv_heads * head_dim, bias=False) 66 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 67 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 68 | 69 | rope_scale = ( 70 | 1 / args.rope_scaling["factor"] 71 | if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" 72 | else 1 73 | ) 74 | self.rope = nn.RoPE( 75 | head_dim, 76 | traditional=args.rope_traditional, 77 | base=args.rope_theta, 78 | scale=rope_scale, 79 | ) 80 | 81 | def __call__( 82 | self, 83 | x: mx.array, 84 | mask: Optional[mx.array] = None, 85 | cache: Optional[Tuple[mx.array, mx.array]] = None, 86 | ) -> mx.array: 87 | B, L, D = x.shape 88 | 89 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 90 | 91 | # Prepare the queries, keys and values for the attention computation 92 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 93 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 94 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 95 | 96 | def repeat(a): 97 | a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) 98 | return a.reshape([B, self.n_heads, L, -1]) 99 | 100 | if self.repeats > 1: 101 | keys, values = map(repeat, (keys, values)) 102 | 103 | if cache is not None: 104 | key_cache, value_cache = cache 105 | queries = self.rope(queries, offset=key_cache.shape[2]) 106 | keys = self.rope(keys, offset=key_cache.shape[2]) 107 | keys = mx.concatenate([key_cache, keys], axis=2) 108 | values = mx.concatenate([value_cache, values], axis=2) 109 | else: 110 | queries = self.rope(queries) 111 | keys = self.rope(keys) 112 | 113 | scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) 114 | if mask is not None: 115 | scores += mask 116 | scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) 117 | output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 118 | return self.o_proj(output), (keys, values) 119 | 120 | 121 | class MLP(nn.Module): 122 | def __init__(self, dim, hidden_dim): 123 | super().__init__() 124 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 125 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 126 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 127 | 128 | def __call__(self, x) -> mx.array: 129 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 130 | 131 | 132 | class TransformerBlock(nn.Module): 133 | def __init__(self, args: ModelArgs): 134 | super().__init__() 135 | self.num_attention_heads = args.num_attention_heads 136 | self.hidden_size = args.hidden_size 137 | self.self_attn = Attention(args) 138 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 139 | self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 140 | self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 141 | self.args = args 142 | 143 | def __call__( 144 | self, 145 | x: mx.array, 146 | mask: Optional[mx.array] = None, 147 | cache: Optional[Tuple[mx.array, mx.array]] = None, 148 | ) -> mx.array: 149 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 150 | h = x + r 151 | r = self.mlp(self.post_attention_layernorm(h)) 152 | out = h + r 153 | return out, cache 154 | 155 | 156 | class LlamaModel(nn.Module): 157 | def __init__(self, args: ModelArgs): 158 | super().__init__() 159 | self.args = args 160 | self.vocab_size = args.vocab_size 161 | self.num_hidden_layers = args.num_hidden_layers 162 | assert self.vocab_size > 0 163 | self.embed_tokens = 
nn.Embedding(args.vocab_size, args.hidden_size) 164 | self.layers = [ 165 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 166 | ] 167 | self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 168 | 169 | def __call__( 170 | self, 171 | inputs: mx.array, 172 | cache=None, 173 | ): 174 | h = self.embed_tokens(inputs) 175 | 176 | mask = None 177 | if h.shape[1] > 1: 178 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 179 | mask = mask.astype(h.dtype) 180 | 181 | if cache is None: 182 | cache = [None] * len(self.layers) 183 | 184 | for e, layer in enumerate(self.layers): 185 | h, cache[e] = layer(h, mask, cache[e]) 186 | 187 | return self.norm(h), cache 188 | 189 | 190 | class Model(nn.Module): 191 | def __init__(self, args: ModelArgs): 192 | super().__init__() 193 | self.model = LlamaModel(args) 194 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 195 | 196 | def __call__( 197 | self, 198 | inputs: mx.array, 199 | cache=None, 200 | ): 201 | out, cache = self.model(inputs, cache) 202 | return self.lm_head(out), cache 203 | -------------------------------------------------------------------------------- /lora.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import argparse 4 | import json 5 | import math 6 | import time 7 | from pathlib import Path 8 | 9 | import mlx.core as mx 10 | import mlx.nn as nn 11 | import mlx.optimizers as optim 12 | import numpy as np 13 | import utils as lora_utils 14 | from mlx.utils import tree_flatten, tree_unflatten 15 | from models.lora import LoRALinear 16 | 17 | 18 | def build_parser(): 19 | parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") 20 | parser.add_argument( 21 | "--model", 22 | default="mlx_model", 23 | help="The path to the local model directory or Hugging Face repo.", 24 | ) 25 | # Generation args 26 | parser.add_argument( 27 | "--max-tokens", 28 | "-m", 29 | type=int, 30 | default=100, 31 | help="The maximum number of tokens to generate", 32 | ) 33 | parser.add_argument( 34 | "--temp", type=float, default=0.8, help="The sampling temperature" 35 | ) 36 | parser.add_argument( 37 | "--prompt", 38 | "-p", 39 | type=str, 40 | help="The prompt for generation", 41 | default=None, 42 | ) 43 | 44 | # Training args 45 | parser.add_argument( 46 | "--train", 47 | action="store_true", 48 | help="Do training", 49 | ) 50 | parser.add_argument( 51 | "--data", 52 | type=str, 53 | default="data/", 54 | help="Directory with {train, valid, test}.jsonl files", 55 | ) 56 | parser.add_argument( 57 | "--lora-layers", 58 | type=int, 59 | default=16, 60 | help="Number of layers to fine-tune", 61 | ) 62 | parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.") 63 | parser.add_argument( 64 | "--iters", type=int, default=1000, help="Iterations to train for." 65 | ) 66 | parser.add_argument( 67 | "--val-batches", 68 | type=int, 69 | default=25, 70 | help="Number of validation batches, -1 uses the entire validation set.", 71 | ) 72 | parser.add_argument( 73 | "--learning-rate", type=float, default=1e-5, help="Adam learning rate." 
74 | ) 75 | parser.add_argument( 76 | "--steps-per-report", 77 | type=int, 78 | default=10, 79 | help="Number of training steps between loss reporting.", 80 | ) 81 | parser.add_argument( 82 | "--steps-per-eval", 83 | type=int, 84 | default=200, 85 | help="Number of training steps between validations.", 86 | ) 87 | parser.add_argument( 88 | "--resume-adapter-file", 89 | type=str, 90 | default=None, 91 | help="Load path to resume training with the given adapter weights.", 92 | ) 93 | parser.add_argument( 94 | "--adapter-file", 95 | type=str, 96 | default="adapters.npz", 97 | help="Save/load path for the trained adapter weights.", 98 | ) 99 | parser.add_argument( 100 | "--test", 101 | action="store_true", 102 | help="Evaluate on the test set after training", 103 | ) 104 | parser.add_argument( 105 | "--test-batches", 106 | type=int, 107 | default=500, 108 | help="Number of test set batches, -1 uses the entire test set.", 109 | ) 110 | parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") 111 | return parser 112 | 113 | 114 | class Dataset: 115 | """ 116 | Light-weight wrapper to hold lines from a jsonl file 117 | """ 118 | 119 | def __init__(self, path: Path, key: str = "text"): 120 | if not path.exists(): 121 | self._data = None 122 | else: 123 | with open(path, "r") as fid: 124 | self._data = [json.loads(l) for l in fid] 125 | self._key = key 126 | 127 | def __getitem__(self, idx: int): 128 | return self._data[idx][self._key] 129 | 130 | def __len__(self): 131 | return len(self._data) 132 | 133 | 134 | def load(args): 135 | names = ("train", "valid", "test") 136 | train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names) 137 | if args.train and len(train) == 0: 138 | raise ValueError( 139 | "Training set not found or empty. Must provide training set for fine-tuning." 140 | ) 141 | if args.train and len(valid) == 0: 142 | raise ValueError( 143 | "Validation set not found or empty. Must provide validation set for fine-tuning." 144 | ) 145 | if args.test and len(test) == 0: 146 | raise ValueError( 147 | "Test set not found or empty. Must provide test set for evaluation." 148 | ) 149 | return train, valid, test 150 | 151 | 152 | def loss(model, inputs, targets, lengths): 153 | # Run model on inputs 154 | logits, _ = model(inputs) 155 | logits = logits.astype(mx.float32) 156 | 157 | # Mask padding tokens 158 | length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] 159 | 160 | # Calculate the loss 161 | ce = nn.losses.cross_entropy(logits, targets) * length_mask 162 | ntoks = length_mask.sum() 163 | ce = ce.sum() / ntoks 164 | return ce, ntoks 165 | 166 | 167 | def iterate_batches(dset, tokenizer, batch_size, train=False): 168 | # Shuffle indices 169 | while True: 170 | indices = np.arange(len(dset)) 171 | if train: 172 | indices = np.random.permutation(indices) 173 | 174 | # Collect batches from dataset 175 | for i in range(0, len(indices) - batch_size + 1, batch_size): 176 | # Encode batch 177 | batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)] 178 | lengths = [len(x) for x in batch] 179 | 180 | # Check if any sequence is longer than 2048 tokens 181 | if max(lengths) > 2048: 182 | print( 183 | "[WARNING] Some sequences are longer than 2048 tokens. " 184 | "Consider pre-splitting your data to save memory." 
185 | ) 186 | 187 | # Pad to the max length 188 | batch_arr = np.zeros((batch_size, max(lengths)), np.int32) 189 | 190 | for j in range(batch_size): 191 | batch_arr[j, : lengths[j]] = batch[j] 192 | batch = mx.array(batch_arr) 193 | yield batch[:, :-1], batch[:, 1:], mx.array(lengths) 194 | 195 | if not train: 196 | break 197 | 198 | 199 | def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches): 200 | all_losses = [] 201 | ntokens = 0 202 | for it, batch in zip( 203 | range(num_batches), 204 | iterate_batches(dataset, tokenizer, batch_size), 205 | ): 206 | losses, toks = loss(model, *batch) 207 | all_losses.append((losses * toks).item()) 208 | ntokens += toks.item() 209 | 210 | return np.sum(all_losses) / ntokens 211 | 212 | 213 | def train(model, train_set, val_set, optimizer, loss, tokenizer, args): 214 | # Create value and grad function for loss 215 | loss_value_and_grad = nn.value_and_grad(model, loss) 216 | 217 | losses = [] 218 | n_tokens = 0 219 | 220 | # Main training loop 221 | start = time.perf_counter() 222 | for it, batch in zip( 223 | range(args.iters), 224 | iterate_batches(train_set, tokenizer, args.batch_size, train=True), 225 | ): 226 | # Forward and backward pass 227 | (lvalue, toks), grad = loss_value_and_grad(model, *batch) 228 | 229 | # Model update 230 | optimizer.update(model, grad) 231 | mx.eval(model.parameters(), optimizer.state, lvalue) 232 | 233 | # Record loss 234 | losses.append(lvalue.item()) 235 | n_tokens += toks.item() 236 | 237 | # Report training loss if needed 238 | if (it + 1) % args.steps_per_report == 0: 239 | train_loss = np.mean(losses) 240 | 241 | stop = time.perf_counter() 242 | print( 243 | f"Iter {it + 1}: Train loss {train_loss:.3f}, " 244 | f"It/sec {args.steps_per_report / (stop - start):.3f}, " 245 | f"Tokens/sec {float(n_tokens) / (stop - start):.3f}" 246 | ) 247 | losses = [] 248 | n_tokens = 0 249 | start = time.perf_counter() 250 | 251 | # Report validation loss if needed 252 | if it == 0 or (it + 1) % args.steps_per_eval == 0: 253 | stop = time.perf_counter() 254 | val_loss = evaluate( 255 | model, val_set, loss, tokenizer, args.batch_size, args.val_batches 256 | ) 257 | print( 258 | f"Iter {it + 1}: " 259 | f"Val loss {val_loss:.3f}, " 260 | f"Val took {(time.perf_counter() - stop):.3f}s" 261 | ) 262 | 263 | start = time.perf_counter() 264 | 265 | 266 | def generate(model, prompt, tokenizer, args): 267 | print(prompt, end="", flush=True) 268 | 269 | prompt = mx.array(tokenizer.encode(prompt)) 270 | 271 | tokens = [] 272 | skip = 0 273 | for token, n in zip( 274 | lora_utils.generate(prompt, model, args.temp), 275 | range(args.max_tokens), 276 | ): 277 | if token == tokenizer.eos_token_id: 278 | break 279 | 280 | tokens.append(token.item()) 281 | s = tokenizer.decode(tokens) 282 | print(s[skip:], end="", flush=True) 283 | skip = len(s) 284 | print(tokenizer.decode(tokens)[skip:], flush=True) 285 | print("=" * 10) 286 | if len(tokens) == 0: 287 | print("No tokens generated for this prompt") 288 | return 289 | 290 | 291 | if __name__ == "__main__": 292 | parser = build_parser() 293 | args = parser.parse_args() 294 | 295 | np.random.seed(args.seed) 296 | 297 | print("Loading pretrained model") 298 | model, tokenizer, _ = lora_utils.load(args.model) 299 | 300 | # Freeze all layers other than LORA linears 301 | model.freeze() 302 | for l in model.model.layers[len(model.model.layers) - args.lora_layers :]: 303 | l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) 304 | l.self_attn.v_proj = 
LoRALinear.from_linear(l.self_attn.v_proj) 305 | 306 | p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 307 | print(f"Total parameters {p:.3f}M") 308 | p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 309 | print(f"Trainable parameters {p:.3f}M") 310 | 311 | print("Loading datasets") 312 | train_set, valid_set, test_set = load(args) 313 | 314 | # Resume training the given adapters. 315 | if args.resume_adapter_file is not None: 316 | print(f"Loading pretrained adapters from {args.resume_adapter_file}") 317 | model.load_weights(args.resume_adapter_file, strict=False) 318 | 319 | if args.train: 320 | print("Training") 321 | opt = optim.Adam(learning_rate=args.learning_rate) 322 | 323 | # Train model 324 | train(model, train_set, valid_set, opt, loss, tokenizer, args) 325 | 326 | # Save adapter weights 327 | mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters()))) 328 | 329 | # Load the LoRA adapter weights which we assume should exist by this point 330 | if not Path(args.adapter_file).is_file(): 331 | raise ValueError( 332 | f"Adapter file {args.adapter_file} missing. " 333 | "Use --train to learn and save the adapters.npz." 334 | ) 335 | model.load_weights(args.adapter_file, strict=False) 336 | 337 | if args.test: 338 | print("Testing") 339 | 340 | test_loss = evaluate( 341 | model, 342 | test_set, 343 | loss, 344 | tokenizer, 345 | args.batch_size, 346 | num_batches=args.test_batches, 347 | ) 348 | test_ppl = math.exp(test_loss) 349 | 350 | print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") 351 | 352 | if args.prompt is not None: 353 | print("Generating") 354 | generate(model, args.prompt, tokenizer, args) 355 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
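# NOTE: this file is a self-contained copy of the Llama/Mistral model, the LoRA layer, and the load/generate helpers from mlx-examples; it duplicates models/llama.py, models/lora.py, and parts of utils.py, and is not imported by lora.py, fuse.py, or convert.py (those use utils.py and the models/ package instead).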
2 | 3 | import glob 4 | import inspect 5 | import json 6 | import math 7 | from dataclasses import dataclass 8 | from pathlib import Path 9 | from typing import Dict, List, Optional, Tuple, Union 10 | 11 | import mlx.core as mx 12 | import mlx.nn as nn 13 | import numpy as np 14 | from huggingface_hub import snapshot_download 15 | from transformers import AutoTokenizer 16 | 17 | 18 | @dataclass 19 | class ModelArgs: 20 | hidden_size: int 21 | num_hidden_layers: int 22 | intermediate_size: int 23 | num_attention_heads: int 24 | rms_norm_eps: float 25 | vocab_size: int 26 | num_key_value_heads: int = None 27 | rope_theta: float = 10000 28 | rope_traditional: bool = False 29 | model_type: str = None 30 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 31 | 32 | def __post_init__(self): 33 | if self.num_key_value_heads is None: 34 | self.num_key_value_heads = self.num_attention_heads 35 | 36 | if self.rope_scaling: 37 | required_keys = {"factor", "type"} 38 | if not all(key in self.rope_scaling for key in required_keys): 39 | raise ValueError(f"rope_scaling must contain keys {required_keys}") 40 | 41 | if self.rope_scaling["type"] != "linear": 42 | raise ValueError("rope_scaling 'type' currently only supports 'linear'") 43 | 44 | @classmethod 45 | def from_dict(cls, params): 46 | return cls( 47 | **{ 48 | k: v 49 | for k, v in params.items() 50 | if k in inspect.signature(cls).parameters 51 | } 52 | ) 53 | 54 | 55 | class LoRALinear(nn.Module): 56 | @staticmethod 57 | def from_linear(linear: nn.Linear, rank: int = 8): 58 | # TODO remove when input_dims and output_dims are attributes 59 | # on linear and quantized linear 60 | output_dims, input_dims = linear.weight.shape 61 | if isinstance(linear, nn.QuantizedLinear): 62 | input_dims *= 32 // linear.bits 63 | lora_lin = LoRALinear(input_dims, output_dims, rank) 64 | lora_lin.linear = linear 65 | return lora_lin 66 | 67 | def to_linear(self): 68 | linear = self.linear 69 | bias = "bias" in linear 70 | weight = linear.weight 71 | is_quantized = isinstance(linear, nn.QuantizedLinear) 72 | 73 | # Use the same type as the linear weight if not quantized 74 | dtype = weight.dtype 75 | 76 | if is_quantized: 77 | dtype = mx.float16 78 | weight = mx.dequantize( 79 | weight, 80 | linear.scales, 81 | linear.biases, 82 | linear.group_size, 83 | linear.bits, 84 | ) 85 | output_dims, input_dims = weight.shape 86 | fused_linear = nn.Linear(input_dims, output_dims, bias=bias) 87 | 88 | lora_b = (self.scale * self.lora_b.T).astype(dtype) 89 | lora_a = self.lora_a.T.astype(dtype) 90 | fused_linear.weight = weight + lora_b @ lora_a 91 | if bias: 92 | fused_linear.bias = linear.bias 93 | 94 | if is_quantized: 95 | fused_linear = nn.QuantizedLinear.from_linear( 96 | fused_linear, 97 | linear.group_size, 98 | linear.bits, 99 | ) 100 | 101 | return fused_linear 102 | 103 | def __init__( 104 | self, 105 | input_dims: int, 106 | output_dims: int, 107 | lora_rank: int = 8, 108 | bias: bool = False, 109 | scale: float = 20.0, 110 | ): 111 | super().__init__() 112 | 113 | # Regular linear layer weights 114 | self.linear = nn.Linear(input_dims, output_dims, bias=bias) 115 | 116 | # Scale for low-rank update 117 | self.scale = scale 118 | 119 | # Low rank lora weights 120 | scale = 1 / math.sqrt(input_dims) 121 | self.lora_a = mx.random.uniform( 122 | low=-scale, 123 | high=scale, 124 | shape=(input_dims, lora_rank), 125 | ) 126 | self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) 127 | 128 | def __call__(self, x): 129 | dtype = self.linear.weight.dtype 130 | 
if isinstance(self.linear, nn.QuantizedLinear): 131 | dtype = self.linear.scales.dtype 132 | y = self.linear(x.astype(dtype)) 133 | z = (x @ self.lora_a) @ self.lora_b 134 | return y + self.scale * z 135 | 136 | 137 | class RMSNorm(nn.Module): 138 | def __init__(self, dims: int, eps: float = 1e-5): 139 | super().__init__() 140 | self.weight = mx.ones((dims,)) 141 | self.eps = eps 142 | 143 | def _norm(self, x): 144 | return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) 145 | 146 | def __call__(self, x): 147 | output = self._norm(x.astype(mx.float32)).astype(x.dtype) 148 | return self.weight * output 149 | 150 | 151 | class Attention(nn.Module): 152 | def __init__(self, args: ModelArgs): 153 | super().__init__() 154 | 155 | dim = args.hidden_size 156 | self.n_heads = n_heads = args.num_attention_heads 157 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 158 | 159 | self.repeats = n_heads // n_kv_heads 160 | 161 | head_dim = args.hidden_size // n_heads 162 | self.scale = head_dim**-0.5 163 | 164 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) 165 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 166 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) 167 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) 168 | rope_scale = ( 169 | 1 / args.rope_scaling["factor"] 170 | if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" 171 | else 1 172 | ) 173 | self.rope = nn.RoPE( 174 | head_dim, 175 | traditional=args.rope_traditional, 176 | base=args.rope_theta, 177 | scale=rope_scale, 178 | ) 179 | 180 | def __call__( 181 | self, 182 | x: mx.array, 183 | mask: Optional[mx.array] = None, 184 | cache: Optional[Tuple[mx.array, mx.array]] = None, 185 | ) -> mx.array: 186 | B, L, D = x.shape 187 | 188 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 189 | 190 | # Prepare the queries, keys and values for the attention computation 191 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 192 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 193 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 194 | 195 | def repeat(a): 196 | a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) 197 | return a.reshape([B, self.n_heads, L, -1]) 198 | 199 | if self.repeats > 1: 200 | keys, values = map(repeat, (keys, values)) 201 | 202 | if cache is not None: 203 | key_cache, value_cache = cache 204 | queries = self.rope(queries, offset=key_cache.shape[2]) 205 | keys = self.rope(keys, offset=key_cache.shape[2]) 206 | keys = mx.concatenate([key_cache, keys], axis=2) 207 | values = mx.concatenate([value_cache, values], axis=2) 208 | else: 209 | queries = self.rope(queries) 210 | keys = self.rope(keys) 211 | 212 | scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) 213 | if mask is not None: 214 | scores += mask 215 | scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) 216 | output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 217 | return self.o_proj(output), (keys, values) 218 | 219 | 220 | class MLP(nn.Module): 221 | def __init__(self, dim, hidden_dim): 222 | super().__init__() 223 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) 224 | self.down_proj = nn.Linear(hidden_dim, dim, bias=False) 225 | self.up_proj = nn.Linear(dim, hidden_dim, bias=False) 226 | 227 | def __call__(self, x) -> mx.array: 228 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 229 
| 230 | 231 | class TransformerBlock(nn.Module): 232 | def __init__(self, args: ModelArgs): 233 | super().__init__() 234 | self.num_attention_heads = args.num_attention_heads 235 | self.hidden_size = args.hidden_size 236 | self.self_attn = Attention(args) 237 | self.mlp = MLP(args.hidden_size, args.intermediate_size) 238 | self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 239 | self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 240 | self.args = args 241 | 242 | def __call__( 243 | self, 244 | x: mx.array, 245 | mask: Optional[mx.array] = None, 246 | cache: Optional[Tuple[mx.array, mx.array]] = None, 247 | ) -> mx.array: 248 | r, cache = self.self_attn(self.input_layernorm(x), mask, cache) 249 | h = x + r 250 | r = self.mlp(self.post_attention_layernorm(h)) 251 | out = h + r 252 | return out, cache 253 | 254 | 255 | class LlamaModel(nn.Module): 256 | def __init__(self, args: ModelArgs): 257 | super().__init__() 258 | self.args = args 259 | self.vocab_size = args.vocab_size 260 | self.num_hidden_layers = args.num_hidden_layers 261 | assert self.vocab_size > 0 262 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 263 | self.layers = [ 264 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 265 | ] 266 | self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 267 | 268 | def __call__( 269 | self, 270 | inputs: mx.array, 271 | cache=None, 272 | ): 273 | h = self.embed_tokens(inputs) 274 | 275 | mask = None 276 | if h.shape[1] > 1: 277 | mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) 278 | mask = mask.astype(h.dtype) 279 | 280 | if cache is None: 281 | cache = [None] * len(self.layers) 282 | 283 | for e, layer in enumerate(self.layers): 284 | h, cache[e] = layer(h, mask, cache[e]) 285 | 286 | return self.norm(h), cache 287 | 288 | 289 | class Model(nn.Module): 290 | def __init__(self, args: ModelArgs): 291 | super().__init__() 292 | self.model = LlamaModel(args) 293 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 294 | 295 | def __call__( 296 | self, 297 | inputs: mx.array, 298 | cache=None, 299 | ): 300 | out, cache = self.model(inputs, cache) 301 | return self.lm_head(out), cache 302 | 303 | 304 | def load(path_or_hf_repo: str): 305 | # If the path exists, it will try to load model form it 306 | # otherwise download and cache from the hf_repo and cache 307 | model_path = Path(path_or_hf_repo) 308 | if not model_path.exists(): 309 | model_path = Path( 310 | snapshot_download( 311 | repo_id=path_or_hf_repo, 312 | allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], 313 | ) 314 | ) 315 | 316 | with open(model_path / "config.json", "r") as f: 317 | config = json.loads(f.read()) 318 | quantization = config.get("quantization", None) 319 | model_args = ModelArgs.from_dict(config) 320 | 321 | weight_files = glob.glob(str(model_path / "*.safetensors")) 322 | if len(weight_files) == 0: 323 | raise FileNotFoundError("No safetensors found in {}".format(model_path)) 324 | 325 | weights = {} 326 | for wf in weight_files: 327 | weights.update(mx.load(wf).items()) 328 | 329 | model = Model(model_args) 330 | if quantization is not None: 331 | nn.QuantizedLinear.quantize_module(model, **quantization) 332 | 333 | model.load_weights(list(weights.items())) 334 | 335 | mx.eval(model.parameters()) 336 | tokenizer = AutoTokenizer.from_pretrained(model_path) 337 | return model, tokenizer, config 338 | 339 | 340 | def generate(prompt: mx.array, model: Model, temp: 
float = 0.0): 341 | def sample(logits): 342 | if temp == 0: 343 | return mx.argmax(logits, axis=-1) 344 | else: 345 | return mx.random.categorical(logits * (1 / temp)) 346 | 347 | y = prompt 348 | cache = None 349 | while True: 350 | logits, cache = model(y[None], cache=cache) 351 | logits = logits[:, -1, :] 352 | y = sample(logits) 353 | yield y 354 | --------------------------------------------------------------------------------
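For reference, here is a minimal sketch (not part of the repository) of how the `load` and `generate` helpers in `models.py` can be driven directly from Python. It assumes `mlx_model` was produced by `convert.py` and it runs the base model without the WhatsApp adapter; `lora.py --prompt ...` wraps the same loop and additionally applies the trained adapter, as shown in the README. The prompt text, the sampling temperature, and the 100-token cap are illustrative.

```python
import mlx.core as mx

import models

# Load the converted (and optionally quantized) model plus its tokenizer
model, tokenizer, _config = models.load("mlx_model")

prompt = "Mickey Mouse: Hey Minnie, are we going to the fair?\nMinnie: "
prompt_tokens = mx.array(tokenizer.encode(prompt))

# models.generate is an endless token generator; cap it at 100 new tokens
tokens = []
for token, _ in zip(models.generate(prompt_tokens, model, temp=0.8), range(100)):
    if token.item() == tokenizer.eos_token_id:
        break
    tokens.append(token.item())

print(prompt + tokenizer.decode(tokens))
```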