├── README-assets └── example-loss-llama-2m.png ├── requirements.txt ├── tokenizer-config-sample.yaml ├── .gitignore ├── model-config-sample.yaml ├── examine.py ├── generate.py ├── README.md ├── train-tokenizer.py ├── convert-to-mlx-lm.py ├── LICENSE.txt ├── plot-logs.py ├── generate_lite.py └── train.py /README-assets/example-loss-llama-2m.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/N8python/mlx-pretrain/HEAD/README-assets/example-loss-llama-2m.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.10.1 2 | mlx==0.25.2 3 | mlx_lm==0.23.2 4 | mlx_optimizers==0.4.1 5 | numpy==2.2.5 6 | PyYAML==6.0.2 7 | tokenizers==0.21.1 8 | tqdm==4.67.1 9 | -------------------------------------------------------------------------------- /tokenizer-config-sample.yaml: -------------------------------------------------------------------------------- 1 | name: "BPE Tokenizer Training" 2 | data: 3 | input_file: "train.jsonl" 4 | max_texts_to_train_on: 32768 5 | tokenizer: 6 | special_tokens: 7 | pad: "" 8 | bos: "" 9 | eos: "" 10 | 11 | tokenizer: 12 | vocab_size: 8192 13 | output_dir: "tokenizer" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | env/ 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Jupyter Notebook 24 | .ipynb_checkpoints 25 | 26 | # Model checkpoints and runs (keeping final checkpoints and metadata) 27 | runs/*/checkpoints/*.safetensors 28 | !runs/*/checkpoints/step_final.safetensors 29 | !runs/*/.gitkeep 30 | !runs/*/config.yaml 31 | !runs/*/metadata.json 32 | !runs/*/log.txt 33 | 34 | # Node.js 35 | node_modules/ 36 | npm-debug.log 37 | yarn-debug.log 38 | yarn-error.log 39 | 40 | # Environment variables and secrets 41 | .env 42 | .env.local 43 | .env.*.local 44 | *.env 45 | *.pem 46 | API-KEY 47 | 48 | # OS-specific 49 | .DS_Store 50 | .DS_Store?
51 | ._* 52 | .Spotlight-V100 53 | .Trashes 54 | ehthumbs.db 55 | Thumbs.db 56 | 57 | # Data files 58 | *.parquet 59 | *.jsonl 60 | !train.jsonl 61 | !val.jsonl 62 | !validation_test_set.jsonl 63 | CLAUDE.md 64 | # Logs 65 | *.log 66 | logs/ 67 | 68 | # MLX cache 69 | .mlx_cache/ 70 | 71 | # Editor specific 72 | .vscode/ 73 | .idea/ 74 | *.swp 75 | *.swo 76 | 77 | MLX-Llama-*/* 78 | evals/* 79 | runs/* 80 | tokenizer/* 81 | MLX-Llama-*/ 82 | evals/ 83 | runs/ 84 | tokenizer/ -------------------------------------------------------------------------------- /model-config-sample.yaml: -------------------------------------------------------------------------------- 1 | name: "Llama (2M)" 2 | overwrite: true 3 | data: 4 | input_file: "train.jsonl" 5 | # Optional validation file 6 | validation_file: "val.jsonl" 7 | # Optional external tokenizer path 8 | tokenizer_path: "tokenizer" # Path to a directory containing tokenizer.json 9 | preprocessing: 10 | max_context_size: 1024 11 | chunk_overlap: 0 # If > 0, chunks will overlap by this many tokens 12 | 13 | tokenizer: 14 | normal_vocab_size: 256 15 | special_tokens: 16 | pad: "" 17 | bos: "" 18 | eos: "" 19 | # Add custom tokens if needed: 20 | # ctrl1: "" 21 | # ctrl2: "" 22 | 23 | model: 24 | architecture: "llama" 25 | dimensions: 26 | hidden_size: 128 27 | intermediate_size: 256 # 2 * hidden_size 28 | num_layers: 4 29 | attention: 30 | num_heads: 8 31 | num_kv_heads: null # If null, equals num_heads 32 | head_dim: null # If null, computed from hidden_size/num_heads 33 | max_position_embeddings: null 34 | normalization: 35 | rms_norm_eps: 1.0e-5 36 | rope: 37 | theta: 10000 38 | traditional: false 39 | scaling: null 40 | misc: 41 | attention_bias: false 42 | mlp_bias: false 43 | tie_word_embeddings: true 44 | 45 | training: 46 | # Number of epochs to train for (optional) 47 | epochs: 1 48 | hyperparameters: 49 | batch_size: 16 50 | learning_rate: 2.0e-2 51 | weight_decay: 0.01 52 | # iters: 10000 # If epochs is provided, this is ignored 53 | 54 | scheduler: 55 | type: "cosine" # Options: linear, cosine, cosine_with_warmup 56 | min_lr_ratio: 0.01 # Minimum LR as a ratio of initial LR 57 | 58 | optimization: 59 | optimizer: "muon" # Options: adam, adamw, muon, sgd 60 | 61 | logging: 62 | log_dir: "logs" 63 | checkpoint_dir: "checkpoints" 64 | steps: 65 | logging_interval: 1 66 | checkpoint_interval: 10000 67 | validation_interval: 1000 # Run validation every 1000 steps (0 to disable) 68 | metrics: 69 | log_loss: true 70 | log_perplexity: true 71 | log_tokens_per_second: true 72 | log_learning_rate: true 73 | log_tokens_processed: true 74 | 75 | system: 76 | seed: 42 77 | device: "gpu" # Options: cpu, gpu 78 | -------------------------------------------------------------------------------- /examine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Inspection utility for MLX-Pretrain data files.
4 | 5 | Usage: 6 | python examine.py --count-tokens [data_path] [tokenizer_path] 7 | 8 | Examples: 9 | python examine.py --count-tokens train.jsonl tokenizer/tokenizer.json 10 | python examine.py --count-tokens val.jsonl tokenizer 11 | """ 12 | 13 | import argparse 14 | import json 15 | import os 16 | from typing import List, Dict, Any 17 | from tokenizers import Tokenizer 18 | from tqdm import tqdm 19 | 20 | def load_jsonl(data_path: str) -> List[Dict[str, Any]]: 21 | """Load data from a JSONL file.""" 22 | data = [] 23 | line_count = 0 24 | with open(data_path, "r", encoding="utf-8") as f: 25 | for _ in f: 26 | line_count += 1 27 | 28 | with open(data_path, "r", encoding="utf-8") as f: 29 | for line in tqdm(f, total=line_count, desc="Loading data"): 30 | if line.strip(): # Skip empty lines 31 | data.append(json.loads(line)) 32 | return data 33 | 34 | def count_tokens(data_path: str, tokenizer_path: str): 35 | """Count the total number of tokens in a JSONL file.""" 36 | print(f"Loading data from {data_path}") 37 | data = load_jsonl(data_path) 38 | 39 | print(f"Loading tokenizer from {tokenizer_path}") 40 | if os.path.isdir(tokenizer_path): 41 | tokenizer_file = os.path.join(tokenizer_path, "tokenizer.json") 42 | else: 43 | tokenizer_file = tokenizer_path 44 | tokenizer = Tokenizer.from_file(tokenizer_file) 45 | 46 | total_tokens = 0 47 | for item in tqdm(data, desc="Counting tokens"): 48 | if "text" in item: 49 | encoding = tokenizer.encode(item["text"]) 50 | total_tokens += len(encoding.ids) 51 | 52 | print(f"Total tokens in {data_path}: {total_tokens}") 53 | return total_tokens 54 | 55 | def main(): 56 | parser = argparse.ArgumentParser(description="Inspection utilities for MLX-Pretrain data") 57 | 58 | # This approach allows for adding other inspection commands in the future 59 | parser.add_argument("--count-tokens", action="store_true", 60 | help="Count tokens in a JSONL file") 61 | parser.add_argument("data_path", nargs="?", help="Path to the JSONL data file") 62 | parser.add_argument("tokenizer_path", nargs="?", help="Path to the tokenizer file") 63 | 64 | args = parser.parse_args() 65 | 66 | if args.count_tokens: 67 | if not args.data_path or not args.tokenizer_path: 68 | parser.error("--count-tokens requires data_path and tokenizer_path") 69 | count_tokens(args.data_path, args.tokenizer_path) 70 | else: 71 | parser.print_help() 72 | 73 | if __name__ == "__main__": 74 | main() -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import mlx.core as mx 4 | from train import Trainer 5 | from mlx_lm.sample_utils import make_sampler, make_logits_processors 6 | import mlx.nn as nn 7 | import time 8 | from generate_lite import generate_lite, beam_search 9 | mx.set_default_device(mx.gpu) 10 | def main(): 11 | parser = argparse.ArgumentParser(description='Generate text using a trained model') 12 | parser.add_argument('--run', type=str, required=True, 13 | help='Name of the training run to use') 14 | parser.add_argument('--prompt', type=str, required=True, 15 | help='Text prompt to start generation') 16 | parser.add_argument('--max-tokens', type=int, default=256, 17 | help='Maximum number of tokens to generate') 18 | parser.add_argument('--temperature', type=float, default=1.0, 19 | help='Sampling temperature') 20 | parser.add_argument('--min-p', type=float, default=0.05, 21 | help='Minimum probability for nucleus 
sampling') 22 | parser.add_argument('--repetition-penalty', type=float, default=1.1, 23 | help='Repetition penalty factor') 24 | parser.add_argument('--repetition-context-size', type=int, default=20, 25 | help='Context size for repetition penalty') 26 | args = parser.parse_args() 27 | 28 | # Load run configuration and initialize trainer 29 | config_path = Path('runs') / args.run / 'config.yaml' 30 | if not config_path.exists(): 31 | raise ValueError(f"Config not found for run: {args.run}") 32 | 33 | trainer = Trainer(str(config_path), for_training=False) 34 | 35 | # Load the final checkpoint 36 | checkpoint_path = Path('runs') / args.run / 'checkpoints' / 'step_final_model.safetensors' 37 | if not checkpoint_path.exists(): 38 | checkpoint_path = Path('runs') / args.run / 'checkpoints' / 'step_final.safetensors' 39 | if not checkpoint_path.exists(): 40 | raise ValueError(f"Final checkpoint not found for run: {args.run}") 41 | checkpoint_path = str(checkpoint_path) 42 | 43 | trainer.model.load_weights(checkpoint_path) 44 | 45 | # Prepare input 46 | tokens = [trainer.tokenizer.BOS_TOKEN] + trainer.tokenizer.tokenize(args.prompt) 47 | 48 | # Setup generation parameters 49 | sampler = make_sampler(temp=args.temperature, min_p=args.min_p) 50 | logits_processors = make_logits_processors( 51 | repetition_penalty=args.repetition_penalty, 52 | repetition_context_size=args.repetition_context_size 53 | ) 54 | 55 | # Generate 56 | """output = beam_search( 57 | trainer.model, 58 | mx.array(tokens), 59 | max_tokens=args.max_tokens, # Limit the max tokens to generate 60 | verbose=True, 61 | n_beams=4, # Use beam search for generation 62 | stop_tokens=[trainer.tokenizer.EOS_TOKEN] 63 | )""" 64 | mx.random.seed(int(time.time() * 1000)) 65 | greedy_output, greedy_score = generate_lite( 66 | trainer.model, 67 | mx.array(tokens), 68 | max_tokens=args.max_tokens, 69 | sampler=sampler, 70 | verbose=False, 71 | stop_tokens=[trainer.tokenizer.EOS_TOKEN], 72 | logits_processors=logits_processors 73 | ) 74 | print(f"Greedy Output: {trainer.tokenizer.detokenize(greedy_output)}") 75 | #print(f"Greedy Output (Score: {greedy_score:.3f}): {trainer.tokenizer.detokenize(greedy_output)}") 76 | 77 | # Print result 78 | #print(f"Greedy (Score: {score:.3f}): {trainer.tokenizer.detokenize(output)}") 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLX-Pretrain 2 | 3 | `mlx-pretrain` is a library that allows easy pretraining of large language models (LLMs) using MLX on Apple Silicon. Instructions below: 4 | 5 | ## Installation 6 | 7 | 1. Clone the repository: 8 | ```bash 9 | git clone https://github.com/N8python/mlx-pretrain.git 10 | cd mlx-pretrain 11 | ``` 12 | 2. Create a virtual environment through any means you prefer and install the dependencies: 13 | ```bash 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | 18 | Make sure to use Python 3.10 or 3.11; Python 3.13 causes issues with `sentencepiece`. 19 | 20 | ## Training a Toy Model 21 | 22 | Download the toy dataset - 200M tokens of Fineweb-Edu (and a validation set): 23 | 24 | ```bash 25 | wget https://huggingface.co/datasets/N8Programs/mlx-pretrain-ex/resolve/main/train.jsonl 26 | wget https://huggingface.co/datasets/N8Programs/mlx-pretrain-ex/resolve/main/val.jsonl 27 | ``` 28 | 29 | Make sure these are in the same directory as the `train.py` script.
You can adjust the exact path in the config if you want to keep them somewhere else. 30 | 31 | Now, we will first train a tokenizer on the dataset. This is a simple BPE tokenizer, and it will be saved to `tokenizer/tokenizer.json`: 32 | 33 | ```bash 34 | python train-tokenizer.py --config tokenizer-config-sample.yaml 35 | ``` 36 | 37 | This will create a `tokenizer` directory with a `tokenizer.json` file inside (this should take 5-15 minutes). 38 | 39 | Now we can train the toy model. Simply run: 40 | 41 | ```bash 42 | python train.py --config model-config-sample.yaml 43 | ``` 44 | 45 | This will train a 2M-parameter Llama model on 200M tokens of Fineweb-Edu, which takes around 2 hours on an M3 Max. If you wish to shorten the training time, modify the config file as follows: 46 | 47 | ```yaml 48 | training: 49 | # Number of epochs to train for (optional) 50 | # epochs: 1 (Remove epochs: 1) 51 | hyperparameters: 52 | batch_size: 16 53 | learning_rate: 2.0e-2 54 | weight_decay: 0.01 55 | iters: 10000 # Uncomment "iters" - 10000 iterations should complete in ~20 minutes 56 | ``` 57 | 58 | Once the model is done training, it will be saved in the `runs` directory under the folder `Llama (2M)`. 59 | 60 | You can view the loss curve by running: 61 | 62 | ```bash 63 | python plot-logs.py "Llama (2M)" 64 | ``` 65 | 66 | You should see an image like this: 67 | 68 | ![Loss Curve](README-assets/example-loss-llama-2m.png) 69 | 70 | You can now generate text with the model. To do this, run: 71 | 72 | ```bash 73 | python generate.py --run "Llama (2M)" --prompt "It is recommended to eat three apples a day, because if you don't, then " 74 | ``` 75 | 76 | This will generate text using the model (by default, at temperature 1.0). Example output: 77 | 78 | ``` 79 | It is recommended to eat three apples a day, because if you don't, then 80 | -> 81 | you will need to have any different benefits and for you. 82 | What are the steps in the work? 83 | Typically, if you have to talk about the workplace when you are receiving by doing that. When you do this, it would probably be an open water source... 84 | ``` 85 | 86 | Now we can convert the model to MLX-LM format to use it with `mlx-lm` more generally. This is dead simple; run: 87 | 88 | ```bash 89 | python convert-to-mlx-lm.py --run "Llama (2M)" --out-path "MLX-Llama-2M" 90 | ``` 91 | 92 | The resulting model can be used with any MLX-LM script. For example, you can evaluate it on ARC-Easy (if you `pip install lm-eval`), via: 93 | 94 | ```bash 95 | python -m mlx_lm evaluate --model MLX-Llama-2M --tasks arc_easy 96 | ``` 97 | 98 | You should see: 99 | 100 | ``` 101 | { 102 | "alias": "arc_easy", 103 | "acc,none": 0.31607744107744107, 104 | "acc_stderr,none": 0.009540440071928285, 105 | "acc_norm,none": 0.30934343434343436, 106 | "acc_norm_stderr,none": 0.009484615220606835 107 | } 108 | ``` 109 | 110 | This shows the model gets ~31% accuracy on ARC-Easy, which surpasses the random baseline of 25% and shows our model did actually learn something. 111 | 112 | Now that you have the MLX-LM model, you can proceed as you wish - upload it to HuggingFace, use it locally for evaluation purposes, etc.
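You can also load the converted model from Python with `mlx_lm`'s API. Here is a minimal sketch (assuming the `MLX-Llama-2M` output directory created by the conversion step above; the prompt text is just an example):

```python
from mlx_lm import load, generate

# Load the converted model and its tokenizer from the output directory
model, tokenizer = load("MLX-Llama-2M")

# Sample a short continuation from the toy model
text = generate(
    model,
    tokenizer,
    prompt="It is recommended to eat three apples a day, because ",
    max_tokens=64,
)
print(text)
```

As with the `generate.py` example above, expect fairly rough output from a 2M-parameter model.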
113 | 114 | # Related Projects 115 | @arthurcolle's [MLX + Cuda Pretraining](https://github.com/arthurcolle/mlx-cuda-distributed-pretraining/tree/muon) -------------------------------------------------------------------------------- /train-tokenizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import yaml 4 | import os 5 | from tokenizers import Tokenizer 6 | from tokenizers.models import BPE 7 | from tokenizers.trainers import BpeTrainer 8 | from tokenizers.pre_tokenizers import ByteLevel 9 | from tokenizers.normalizers import NFKC 10 | from tokenizers.decoders import ByteLevel as ByteLevelDecoder 11 | 12 | 13 | def load_config(config_path): 14 | """Load the YAML configuration file.""" 15 | with open(config_path, 'r') as file: 16 | return yaml.safe_load(file) 17 | 18 | 19 | def load_jsonl_texts(file_path): 20 | """Load and extract text from a JSONL file.""" 21 | texts = [] 22 | with open(file_path, 'r', encoding='utf-8') as f: 23 | for line in f: 24 | try: 25 | item = json.loads(line.strip()) 26 | if 'text' in item: 27 | texts.append(item['text']) 28 | except json.JSONDecodeError: 29 | print(f"Warning: Could not parse line: {line}") 30 | return texts 31 | 32 | 33 | def batch_iterator(texts, batch_size=1000): 34 | """Creates batches of texts for tokenizer training.""" 35 | for i in range(0, len(texts), batch_size): 36 | yield texts[i:i+batch_size] 37 | 38 | 39 | def train_tokenizer(config): 40 | """Train a byte-level BPE tokenizer based on the provided configuration.""" 41 | 42 | # Initialize the tokenizer with a BPE model 43 | tokenizer = Tokenizer(BPE()) 44 | 45 | # Configure pre-tokenizer for superword BPE (no word boundaries) 46 | tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False, use_regex=False) 47 | 48 | # Set up the normalizer 49 | tokenizer.normalizer = NFKC() 50 | 51 | # Set up the decoder 52 | tokenizer.decoder = ByteLevelDecoder() 53 | 54 | # Get special tokens from config 55 | special_tokens = [] 56 | if 'data' in config and 'tokenizer' in config['data'] and 'special_tokens' in config['data']['tokenizer']: 57 | special_tokens = list(config['data']['tokenizer']['special_tokens'].values()) 58 | 59 | # Get vocab size from config 60 | vocab_size = 32000 # Default 61 | if 'tokenizer' in config and 'vocab_size' in config['tokenizer']: 62 | vocab_size = config['tokenizer']['vocab_size'] 63 | 64 | # Set up the trainer 65 | trainer = BpeTrainer( 66 | vocab_size=vocab_size, 67 | min_frequency=2, 68 | special_tokens=special_tokens, 69 | show_progress=True 70 | ) 71 | 72 | # Load training data 73 | input_file = config['data']['input_file'] if 'data' in config and 'input_file' in config['data'] else 'train.jsonl' 74 | texts = load_jsonl_texts(input_file) 75 | 76 | if 'data' in config and 'max_texts_to_train_on' in config['data']: 77 | texts = texts[:config['data']['max_texts_to_train_on']] 78 | print(f"Training tokenizer on {len(texts)} texts with vocab size {vocab_size}") 79 | 80 | # Train the tokenizer 81 | tokenizer.train_from_iterator(batch_iterator(texts), trainer=trainer) 82 | 83 | # Create output directory if it doesn't exist 84 | output_dir = config['tokenizer']['output_dir'] if 'tokenizer' in config and 'output_dir' in config['tokenizer'] else 'tokenizer' 85 | os.makedirs(output_dir, exist_ok=True) 86 | 87 | # Save the tokenizer 88 | output_path = os.path.join(output_dir, "tokenizer.json") 89 | tokenizer.save(output_path) 90 | print(f"Tokenizer saved to {output_path}") 91 | 92 | # Test the 
tokenizer 93 | if texts: 94 | test_text = texts[0][:100] # Take first 100 chars of first text for testing 95 | encoded = tokenizer.encode(test_text) 96 | print("\nTest encoding:") 97 | print(f"Text: {test_text}") 98 | print(f"Tokens: {encoded.tokens}") 99 | print(f"IDs: {encoded.ids}") 100 | 101 | return tokenizer 102 | 103 | 104 | def main(): 105 | parser = argparse.ArgumentParser(description="Train a BPE tokenizer using a YAML configuration") 106 | parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file") 107 | args = parser.parse_args() 108 | 109 | config = load_config(args.config) 110 | train_tokenizer(config) 111 | 112 | 113 | if __name__ == "__main__": 114 | main() -------------------------------------------------------------------------------- /convert-to-mlx-lm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import mlx.core as mx 4 | from train import Trainer 5 | from mlx_lm.sample_utils import make_sampler, make_logits_processors 6 | import mlx.nn as nn 7 | from mlx_lm import load, generate 8 | import time 9 | from generate_lite import generate_lite, beam_search 10 | mx.set_default_device(mx.gpu) 11 | import os 12 | import json 13 | def main(): 14 | parser = argparse.ArgumentParser(description='Convert a model to MLX format') 15 | parser.add_argument('--run', type=str, required=True, 16 | help='Name of the training run to use') 17 | parser.add_argument('--out-path', type=str, default='output', 18 | help='Path for MLX-LM Model output directory') 19 | args = parser.parse_args() 20 | 21 | # Load run configuration and initialize trainer 22 | config_path = Path('runs') / args.run / 'config.yaml' 23 | if not config_path.exists(): 24 | raise ValueError(f"Config not found for run: {args.run}") 25 | 26 | trainer = Trainer(str(config_path), for_training=False) 27 | 28 | # Load the config 29 | 30 | 31 | # Load the final checkpoint 32 | checkpoint_path = Path('runs') / args.run / 'checkpoints' / 'step_final_model.safetensors' 33 | if not checkpoint_path.exists(): 34 | checkpoint_path = Path('runs') / args.run / 'checkpoints' / 'step_final.safetensors' 35 | if not checkpoint_path.exists(): 36 | raise ValueError(f"Final checkpoint not found for run: {args.run}") 37 | checkpoint_path = str(checkpoint_path) 38 | 39 | trainer.model.load_weights(checkpoint_path) 40 | 41 | # Create output directory 42 | out_dir = Path(args.out_path) 43 | os.makedirs(out_dir, exist_ok=True) 44 | 45 | # Set the output path for the model file 46 | out_path_model = out_dir / 'model.safetensors' 47 | 48 | # Copy the model file 49 | import shutil 50 | print(f"Copying model from {checkpoint_path} to {out_path_model}") 51 | shutil.copy2(checkpoint_path, out_path_model) 52 | 53 | # Copy the tokenizer 54 | 55 | tokenizer_path = Path('runs') / args.run / 'tokenizer' / 'tokenizer.json' 56 | 57 | shutil.copy2(tokenizer_path, out_dir / 'tokenizer.json') 58 | 59 | config = { 60 | "architectures": [ 61 | "LlamaForCausalLM" 62 | ], 63 | "attention_bias": False, 64 | "attention_dropout": 0.0, 65 | } 66 | config["attention_bias"] = trainer.config.model.misc['attention_bias'] 67 | config["bos_token_id"] = trainer.tokenizer.tokenize(trainer.config.data.tokenizer['special_tokens']['bos'])[0] 68 | config["eos_token_id"] = [trainer.tokenizer.tokenize(trainer.config.data.tokenizer['special_tokens']['eos'])[0]] 69 | #print(trainer.config.model) 70 | config["hidden_act"] = "silu" 71 | config["hidden_size"] = 
trainer.config.model.dimensions["hidden_size"] 72 | config["intermediate_size"] = trainer.config.model.dimensions["intermediate_size"] 73 | config["max_position_embeddings"] = trainer.config.data.preprocessing["max_context_size"] 74 | config["mlp_bias"] = trainer.config.model.misc['mlp_bias'] 75 | config["model_type"] = trainer.config.model.architecture 76 | config["num_attention_heads"] = trainer.config.model.attention["num_heads"] 77 | config["num_hidden_layers"] = trainer.config.model.dimensions["num_layers"] 78 | config["rms_norm_eps"] = trainer.config.model.normalization['rms_norm_eps'] 79 | config["rope_scaling"] = trainer.config.model.rope['scaling'] 80 | config["rope_theta"] = trainer.config.model.rope['theta'] 81 | config["tie_word_embeddings"] = trainer.config.model.misc['tie_word_embeddings'] 82 | config["torch_dtype"] = "float32" # Only support float32 for now 83 | config["use_cache"] = True 84 | config["vocab_size"] = trainer.tokenizer.VOCAB_SIZE 85 | 86 | # Save the config 87 | config_path = out_dir / 'config.json' 88 | with open(config_path, 'w') as f: 89 | json.dump(config, f, indent=4) 90 | 91 | tokenizer_config = { 92 | "bos_token": trainer.config.data.tokenizer['special_tokens']['bos'], 93 | "eos_token": trainer.config.data.tokenizer['special_tokens']['eos'], 94 | "model_input_names": [ 95 | "input_ids", 96 | "attention_mask" 97 | ], 98 | "model_max_length": trainer.config.data.preprocessing["max_context_size"], 99 | "tokenizer_class": "PreTrainedTokenizerFast", 100 | 101 | 102 | } 103 | 104 | # Save the tokenizer config 105 | tokenizer_config_path = out_dir / 'tokenizer_config.json' 106 | with open(tokenizer_config_path, 'w') as f: 107 | json.dump(tokenizer_config, f, indent=4) 108 | 109 | # Modify the tokenizer to start with BOS using a post-processor 110 | 111 | tokenizer_path = out_dir / 'tokenizer.json' 112 | bos_token = tokenizer_config["bos_token"] 113 | bos_id = trainer.tokenizer.tokenize(bos_token)[0] 114 | with open(tokenizer_path, 'r') as f: 115 | tokenizer_data = json.load(f) 116 | tokenizer_data['post_processor'] = { 117 | "type": "Sequence", 118 | "processors": [ 119 | { 120 | "type": "TemplateProcessing", 121 | "single": [ 122 | { 123 | "SpecialToken": { 124 | "id": bos_token, 125 | "type_id": 0 126 | } 127 | }, 128 | { 129 | "Sequence": { 130 | "id": "A", 131 | "type_id": 0 132 | } 133 | } 134 | ], 135 | "pair": [ 136 | { 137 | "SpecialToken": { 138 | "id": bos_token, 139 | "type_id": 0 140 | } 141 | }, 142 | { 143 | "Sequence": { 144 | "id": "A", 145 | "type_id": 0 146 | } 147 | }, 148 | { 149 | "SpecialToken": { 150 | "id": bos_token, 151 | "type_id": 1 152 | } 153 | }, 154 | { 155 | "Sequence": { 156 | "id": "B", 157 | "type_id": 1 158 | } 159 | } 160 | ], 161 | "special_tokens": { 162 | bos_token: { 163 | "id": bos_token, 164 | "ids": [ 165 | bos_id 166 | ], 167 | "tokens": [ 168 | bos_token, 169 | ] 170 | } 171 | } 172 | } 173 | ] 174 | } 175 | # Save 176 | with open(tokenizer_path, 'w') as f: 177 | json.dump(tokenizer_data, f, indent=4) 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. 
CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. 
To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. 
Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /plot-logs.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import argparse 4 | import json 5 | from pathlib import Path 6 | 7 | def process_log(log_file: Path) -> tuple[list, list, list, list, list]: 8 | """Process a single log file and return tokens, training losses, and validation data.""" 9 | with open(log_file, 'r') as f: 10 | lines = f.readlines() 11 | 12 | # Parse training losses from regular log entries 13 | #train_losses = [] 14 | #tokens = [0] 15 | train_steps = [] 16 | 17 | # Parse validation losses 18 | val_steps = [] 19 | val_losses = [] 20 | 21 | for line in lines: 22 | if line.startswith("Step") and "validation:" not in line: 23 | step = int(line.split()[1][:-1]) 24 | # Regular training log 25 | parts = line.split("|") 26 | # First part contains loss 27 | loss_part = next((p for p in parts if "loss=" in p), None) 28 | loss = float(loss_part.split("=")[1].strip()) 29 | 30 | toks_part = next((p for p in parts if "toks=" in p), None) 31 | toks = float(toks_part.split("=")[1].strip()) 32 | train_steps.append((step, loss, toks)) 33 | """"if loss_part: 34 | loss = float(loss_part.split("=")[1].strip()) 35 | train_losses.append(loss) 36 | 37 | # Find tokens processed 38 | toks_part = next((p for p in parts if "toks=" in p), None) 39 | if toks_part: 40 | toks = float(toks_part.split("=")[1].strip()) 41 | tokens.append(toks + tokens[-1])""" 42 | 43 | elif "validation:" in line: 44 | # Validation log 45 | step = int(line.split()[1]) 46 | val_loss = float(line.split("val_loss=")[1].split()[0]) 47 | val_steps.append(step) 48 | val_losses.append(val_loss) 49 | # Sort train_steps 50 | train_steps.sort(key=lambda x: x[0]) 51 | deduped_train_steps = [] 52 | for step, loss, toks in train_steps: 53 | if len(deduped_train_steps) == 0 or deduped_train_steps[-1][0] != step: 54 | deduped_train_steps.append((step, loss, toks)) 55 | train_losses = [] 56 | tokens = [0] 57 | for step, loss, toks in deduped_train_steps: 58 | train_losses.append(loss) 59 | # Append tokens processed 60 | tokens.append(toks + tokens[-1]) 61 | # Deduplicate steps 62 | # Ensure tokens list has same length as losses 63 | if len(tokens) > len(train_losses) + 1: 64 | tokens = tokens[:len(train_losses) + 1] 65 | tokens = tokens[1:] 66 | # Validation data might also be in metadata 67 | metadata_path = log_file.parent / "metadata.json" 68 | if metadata_path.exists(): 69 | try: 70 | with open(metadata_path, 'r') as f: 71 | metadata = json.load(f) 72 | 73 | if 'validation' in metadata and len(metadata['validation']['steps']) > 0: 74 | # Use metadata for validation data as it's more reliable 75 | val_steps = metadata['validation']['steps'] 76 | val_losses = metadata['validation']['losses'] 77 | except: 78 | # Fallback to log-parsed validation data 79 | pass 80 | 81 | # EMA smoothing for training 
loss 82 | ema = 0.9 83 | smoothed_train_losses = [train_losses[0]] 84 | for loss in train_losses[1:]: 85 | smoothed_train_losses.append(ema * smoothed_train_losses[-1] + (1 - ema) * loss) 86 | 87 | # EMA smoothing for validation loss 88 | ema_val = 0.0 89 | smoothed_val_losses = [] 90 | if val_losses: 91 | smoothed_val_losses = [val_losses[0]] 92 | for loss in val_losses[1:]: 93 | smoothed_val_losses.append(ema_val * smoothed_val_losses[-1] + (1 - ema_val) * loss) 94 | ema_val = ema ** (1000/16) 95 | 96 | return tokens, smoothed_train_losses, val_steps, val_losses, smoothed_val_losses 97 | 98 | def main(): 99 | parser = argparse.ArgumentParser(description='Plot training logs for multiple runs') 100 | parser.add_argument('run_names', type=str, nargs='+', help='Names of the training runs to plot') 101 | parser.add_argument('--no-val', action='store_true', help='Ignore validation data when plotting') 102 | args = parser.parse_args() 103 | 104 | # Create a figure with 1 row, 2 columns 105 | plt.figure(figsize=(16, 8)) 106 | 107 | # Full range training and validation loss plot 108 | plt.subplot(1, 2, 1) 109 | has_validation_data = False 110 | 111 | for run_name in args.run_names: 112 | log_file = Path("runs") / run_name / "log.txt" 113 | if not log_file.exists(): 114 | print(f"Error: Log file not found at {log_file}") 115 | continue 116 | 117 | tokens, train_losses, val_steps, val_losses, smoothed_val_losses = process_log(log_file) 118 | 119 | # Plot training losses 120 | plt.plot(tokens, train_losses, label=f"{run_name} (train EMA)") 121 | 122 | # Plot validation losses if available and not disabled 123 | if not args.no_val and val_steps and val_losses: 124 | has_validation_data = True 125 | val_tokens = [] 126 | for step in val_steps: 127 | if step < len(tokens): 128 | val_tokens.append(tokens[step]) 129 | else: 130 | # Estimate based on last available token count 131 | val_tokens.append(tokens[-1] * step / len(tokens)) 132 | 133 | #plt.plot(val_tokens, val_losses, 'o', alpha=0.5, label=f"{run_name} (val)") 134 | plt.plot(val_tokens, smoothed_val_losses, '-', label=f"{run_name} (val EMA)") 135 | 136 | plt.xlabel("Total tokens processed") 137 | plt.ylabel("Loss") 138 | title = "Training Loss (Full Range)" if args.no_val else "Training and Validation Loss (Full Range)" 139 | plt.title(title) 140 | plt.legend() 141 | plt.grid(True, alpha=0.3) 142 | 143 | # Last 80% training and validation loss plot 144 | plt.subplot(1, 2, 2) 145 | 146 | for run_name in args.run_names: 147 | log_file = Path("runs") / run_name / "log.txt" 148 | if not log_file.exists(): 149 | continue 150 | 151 | tokens, train_losses, val_steps, val_losses, smoothed_val_losses = process_log(log_file) 152 | 153 | # Calculate 20% cutoff point 154 | cutoff = int(0.2 * len(tokens)) 155 | tokens_last_80 = tokens[cutoff:] 156 | train_losses_last_80 = train_losses[cutoff:] 157 | 158 | # Plot training losses for last 80% 159 | plt.plot(tokens_last_80, train_losses_last_80, label=f"{run_name} (train EMA)") 160 | 161 | # Plot validation losses for last 80% if available and not disabled 162 | if not args.no_val and val_steps and val_losses: 163 | val_tokens = [] 164 | for step in val_steps: 165 | if step < len(tokens): 166 | val_tokens.append(tokens[step]) 167 | else: 168 | # Estimate based on last available token count 169 | val_tokens.append(tokens[-1] * step / len(tokens)) 170 | 171 | # Filter validation points to only include those in the last 80% 172 | last_80_points = [(t, l, s) for t, l, s in zip(val_tokens, val_losses, 
smoothed_val_losses) 173 | if t >= tokens_last_80[0]] 174 | 175 | if last_80_points: 176 | last_tokens, last_losses, last_smoothed = zip(*last_80_points) 177 | #plt.plot(last_tokens, last_losses, 'o', alpha=0.5, label=f"{run_name} (val)") 178 | plt.plot(last_tokens, last_smoothed, '-', label=f"{run_name} (val EMA)") 179 | 180 | plt.xlabel("Total tokens processed") 181 | plt.ylabel("Loss") 182 | title = "Training Loss (Last 80%)" if args.no_val else "Training and Validation Loss (Last 80%)" 183 | plt.title(title) 184 | plt.legend() 185 | plt.grid(True, alpha=0.3) 186 | 187 | plt.tight_layout() 188 | plt.show() 189 | 190 | if __name__ == "__main__": 191 | main() 192 | -------------------------------------------------------------------------------- /generate_lite.py: -------------------------------------------------------------------------------- 1 | # generate_lite.py 2 | # Copyright © 2023-2024 Apple Inc. 3 | 4 | import time 5 | import logging 6 | from contextlib import contextmanager 7 | from typing import Any, Callable, Generator, List, Optional, Tuple 8 | 9 | import mlx.core as mx 10 | import mlx.nn as nn 11 | from mlx_lm.models.cache import make_prompt_cache 12 | 13 | ############################################################################## 14 | # Minimal Utilities (wired_limit, generation_stream, maybe_quantize_kv_cache) 15 | ############################################################################## 16 | 17 | # A stream on the default device just for generation 18 | generation_stream = mx.new_stream(mx.default_device()) 19 | 20 | class ModelNotFoundError(Exception): 21 | """Exception for missing model files.""" 22 | def __init__(self, message: str): 23 | self.message = message 24 | super().__init__(self.message) 25 | 26 | @contextmanager 27 | def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): 28 | """ 29 | A context manager to temporarily change the wired limit, synchronizing 30 | streams on exit to prevent overlapping changes in asynchronous contexts. 31 | """ 32 | model_bytes = 0 33 | # Recursively sum up all array nbytes in the model 34 | def _tree_reduce(m, acc=0): 35 | if isinstance(m, mx.array): 36 | return acc + m.nbytes 37 | if isinstance(m, nn.Module): 38 | for child in m.children(): 39 | acc = _tree_reduce(child, acc) 40 | return acc 41 | model_bytes = _tree_reduce(model) 42 | 43 | max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"] 44 | if model_bytes > 0.9 * max_rec_size: 45 | model_mb = model_bytes // 2**20 46 | max_rec_mb = max_rec_size // 2**20 47 | print( 48 | f"[WARNING] Generating with a model that requires {model_mb} MB, " 49 | f"close to the max recommended {max_rec_mb} MB. This can be slow." 50 | ) 51 | 52 | old_limit = mx.set_wired_limit(max_rec_size) 53 | try: 54 | yield 55 | finally: 56 | if streams is not None: 57 | for s in streams: 58 | mx.synchronize(s) 59 | else: 60 | mx.synchronize() 61 | mx.set_wired_limit(old_limit) 62 | 63 | def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): 64 | """ 65 | If we've passed 'quantized_kv_start', convert the KV cache to a quantized 66 | variant (if not already), using the specified group size and bits. 
67 | """ 68 | if ( 69 | kv_bits is not None 70 | and not isinstance(prompt_cache[0], nn.cache.QuantizedKVCache) 71 | and prompt_cache[0].offset > quantized_kv_start 72 | ): 73 | for i in range(len(prompt_cache)): 74 | if isinstance(prompt_cache[i], nn.cache.KVCache): 75 | prompt_cache[i] = prompt_cache[i].to_quantized( 76 | group_size=kv_group_size, bits=kv_bits 77 | ) 78 | 79 | 80 | ############################################################################## 81 | # The Core Generator (generate_step) 82 | ############################################################################## 83 | 84 | def generate_step( 85 | prompt: mx.array, 86 | model: nn.Module, 87 | *, 88 | max_tokens: int = 256, 89 | sampler: Optional[Callable[[mx.array], mx.array]] = None, 90 | logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, 91 | max_kv_size: Optional[int] = None, 92 | prompt_cache: Optional[Any] = None, 93 | prefill_step_size: int = 512, 94 | kv_bits: Optional[int] = None, 95 | kv_group_size: int = 64, 96 | quantized_kv_start: int = 0, 97 | prompt_progress_callback: Optional[Callable[[int, int], None]] = None, 98 | ) -> Generator[Tuple[mx.array, mx.array], None, None]: 99 | """ 100 | A low-level generator producing token ids from 'prompt', using 'model'. 101 | Yields (token, logprobs) as we generate one token at a time. 102 | """ 103 | y = prompt 104 | tokens = None # for the optional logits processors 105 | 106 | # Create (or reuse) the key-value cache for generation 107 | if prompt_cache is None: 108 | prompt_cache = make_prompt_cache(model, max_kv_size=max_kv_size) 109 | elif len(prompt_cache) != len(model.layers): 110 | raise ValueError("Wrong number of layers in the prompt cache.") 111 | 112 | sampler = sampler or (lambda logprobs: mx.argmax(logprobs, axis=-1)) 113 | logits_processors = logits_processors or [] 114 | prompt_progress_callback = prompt_progress_callback or (lambda *_: None) 115 | 116 | def _step(y_tok: mx.array): 117 | """One forward pass step: produce next-token logits, apply processors.""" 118 | logits = model(y_tok[None], cache=prompt_cache) 119 | # logits shape: [1, seq_len=1, vocab_size] 120 | logits = logits[:, -1, :] # take the last token 121 | if logits_processors: 122 | nonlocal tokens 123 | tokens = mx.concat([tokens, y_tok]) if tokens is not None else y_tok 124 | for processor in logits_processors: 125 | logits = processor(tokens, logits) 126 | 127 | maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits) 128 | logprobs = logits - mx.logsumexp(logits, keepdims=True) 129 | next_token = sampler(logprobs) # shape [1] or [batch_size=1] 130 | return next_token, logprobs.squeeze(0) 131 | 132 | # Prefill stage: feed large chunks of the prompt to fill the cache 133 | total_prompt_tokens = y.size 134 | prompt_processed_tokens = 0 135 | while y.size > prefill_step_size: 136 | model(y[:prefill_step_size][None], cache=prompt_cache) 137 | maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits) 138 | mx.eval([c.state for c in prompt_cache]) 139 | prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) 140 | prompt_processed_tokens += prefill_step_size 141 | y = y[prefill_step_size:] 142 | mx.clear_cache() 143 | 144 | # Process the remainder of the prompt in a single step 145 | y, logprobs = _step(y) 146 | mx.async_eval(y, logprobs) 147 | 148 | # Generate tokens up to max_tokens 149 | n = 0 150 | while True: 151 | if n != max_tokens: 152 | next_y, next_logprobs = _step(y) 153 | 
mx.async_eval(next_y, next_logprobs) 154 | if n == 0: 155 | mx.eval(y) 156 | prompt_progress_callback(total_prompt_tokens, total_prompt_tokens) 157 | if n == max_tokens: 158 | break 159 | # Output the current token & logprobs 160 | yield y.item(), logprobs 161 | if n % 256 == 0: 162 | mx.clear_cache() 163 | y, logprobs = next_y, next_logprobs 164 | n += 1 165 | 166 | 167 | ############################################################################## 168 | # The "generate_lite" Function 169 | ############################################################################## 170 | 171 | def generate_lite( 172 | model: nn.Module, 173 | prompt: mx.array, 174 | *, 175 | max_tokens: int = 256, 176 | sampler: Optional[Callable[[mx.array], mx.array]] = None, 177 | logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, 178 | max_kv_size: Optional[int] = None, 179 | prompt_cache: Optional[Any] = None, 180 | prefill_step_size: int = 512, 181 | kv_bits: Optional[int] = None, 182 | kv_group_size: int = 64, 183 | quantized_kv_start: int = 0, 184 | prompt_progress_callback: Optional[Callable[[int, int], None]] = None, 185 | stop_tokens: Optional[List[int]] = None, 186 | verbose: bool = False, 187 | ): 188 | """ 189 | A compact function that generates tokens from an mx.array prompt, 190 | without requiring any tokenizer. It supports: 191 | - caching (prompt_cache) 192 | - samplers 193 | - logits processors 194 | - custom stopping tokens 195 | - kv cache quantization 196 | - prefill steps 197 | - optional verbose logging 198 | 199 | Args: 200 | model (nn.Module): The model to use for generation. 201 | prompt (mx.array): The prompt tokens. 202 | max_tokens (int): Maximum new tokens to generate. 203 | sampler (Callable[[mx.array], mx.array], optional): Sampler for picking the next token 204 | from logprobs. Defaults to argmax if not provided. 205 | logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): 206 | Functions to transform the logits at each step, e.g. repetition penalty. 207 | max_kv_size (int, optional): Maximum capacity of the KV cache. 208 | prompt_cache (Any, optional): Existing cache to reuse; updated in-place. 209 | prefill_step_size (int): How many tokens to feed at once for the prompt. 210 | kv_bits (int, optional): Bits for KV cache quantization. 211 | kv_group_size (int): Group size for KV quantization. 212 | quantized_kv_start (int): Step index at which to begin quantizing the KV cache. 213 | prompt_progress_callback (Callable[[int, int], None], optional): 214 | Callback that receives (#tokens_processed, total_prompt_tokens). 215 | stop_tokens (List[int], optional): If a generated token is in this list, generation stops. 216 | verbose (bool): Print basic debug info (timing, memory usage, etc.). 217 | 218 | Returns: 219 | mx.array: The concatenated tokens (original prompt + newly generated tokens). 
220 | """ 221 | if stop_tokens is None: 222 | stop_tokens = [] 223 | lps = [] 224 | with wired_limit(model, [generation_stream]): 225 | start_time = time.perf_counter() 226 | generated_tokens = [] 227 | 228 | # Loop over generate_step 229 | for i, (token, logprobs) in enumerate( 230 | generate_step( 231 | prompt, 232 | model, 233 | max_tokens=max_tokens, 234 | sampler=sampler, 235 | logits_processors=logits_processors, 236 | max_kv_size=max_kv_size, 237 | prompt_cache=prompt_cache, 238 | prefill_step_size=prefill_step_size, 239 | kv_bits=kv_bits, 240 | kv_group_size=kv_group_size, 241 | quantized_kv_start=quantized_kv_start, 242 | prompt_progress_callback=prompt_progress_callback, 243 | ) 244 | ): 245 | # On the first iteration, measure how long the prompt took 246 | if i == 0: 247 | prompt_time = time.perf_counter() - start_time 248 | prompt_tps = (prompt.size / prompt_time) if prompt_time > 0 else 0.0 249 | # Reset timer for generation 250 | start_time = time.perf_counter() 251 | generated_tokens.append(token) 252 | lps.append(logprobs[token]) 253 | 254 | # Stop if we hit any user-defined stop token 255 | if token in stop_tokens: 256 | break 257 | 258 | # Final stats 259 | generation_time = time.perf_counter() - start_time 260 | generation_tps = (len(generated_tokens) / generation_time) if generation_time > 0 else 0.0 261 | 262 | # Print debug info if requested 263 | if verbose: 264 | print("=" * 10) 265 | if len(generated_tokens) == 0: 266 | print("No tokens generated for this prompt.") 267 | else: 268 | print(f"Prompt: {prompt.size} tokens, {prompt_tps:.3f} tokens/sec") 269 | print( 270 | f"Generation: {len(generated_tokens)} tokens, " 271 | f"{generation_tps:.3f} tokens/sec" 272 | ) 273 | used_mem_gb = mx.get_peak_memory() / 1e9 274 | print(f"Peak memory: {used_mem_gb:.3f} GB") 275 | lps_avg = sum(lps) 276 | # Return the combined sequence: original prompt + newly generated tokens 277 | if generated_tokens: 278 | return mx.array(generated_tokens, dtype=prompt.dtype), lps_avg 279 | else: 280 | return mx.array([], dtype=prompt.dtype), lps_avg 281 | 282 | def beam_search(model, input_tokens, max_tokens=512, verbose=False, n_beams=4, stop_tokens=None): 283 | """ 284 | Perform beam search to generate text from the model. 285 | """ 286 | # Repeat the input for each beam and initialize beam scores. 287 | beams = mx.repeat(mx.array([input_tokens]), n_beams, axis=0).tolist() 288 | beam_scores = [0] * n_beams 289 | finished_beams = [] 290 | l_prefix = len(input_tokens) # To later remove the input prefix from the output. 291 | 292 | for step in range(max_tokens): 293 | # Use the current number of beams instead of the constant n_beams. 294 | current_beam_count = len(beams) 295 | logits = model(mx.array(beams))[:, -1, :] # Get logits for the last token in each beam. 296 | logprobs = nn.log_softmax(logits, axis=-1) # Convert logits to log probabilities. 297 | # For each beam, pick the top n_beams candidate tokens. 298 | top_indices = mx.argsort(-logprobs, axis=-1)[:, :n_beams] 299 | #top_logprobs = logprobs[mx.arange(current_beam_count), top_indices] 300 | top_logprobs = mx.take_along_axis(logprobs, top_indices, axis=-1) 301 | top_indices = top_indices.tolist() 302 | top_logprobs = top_logprobs.tolist() 303 | 304 | # Build candidate extensions for each beam. 
305 | beam_possibilities = [] 306 | for beam_idx in range(current_beam_count): 307 | token_and_score = [] 308 | for k in range(n_beams): 309 | token_and_score.append((top_indices[beam_idx][k], top_logprobs[beam_idx][k])) 310 | beam_possibilities.append(token_and_score) 311 | 312 | # Extend each current beam with every candidate token. 313 | new_beams = [] 314 | for beam_idx in range(current_beam_count): 315 | base_beam = beams[beam_idx] 316 | base_beam_score = beam_scores[beam_idx] 317 | for token, logprob in beam_possibilities[beam_idx]: 318 | new_beam = base_beam + [token] 319 | mix = 1/(len(new_beam) - l_prefix) 320 | #new_score = base_beam_score * (1-mix) + logprob * (mix) 321 | new_score = base_beam_score + logprob 322 | new_beams.append((new_beam, new_score)) 323 | 324 | # Sort and de-duplicate the candidate beams. 325 | seen_beams = set() 326 | dedup_new_beams = [] 327 | for beam, score in new_beams: 328 | hash_beam = tuple(beam) 329 | if hash_beam not in seen_beams: 330 | seen_beams.add(hash_beam) 331 | dedup_new_beams.append((beam, score)) 332 | new_beams = dedup_new_beams 333 | new_beams.sort(key=lambda x: x[1], reverse=True) 334 | # Select the top candidates while checking for stop tokens. 335 | chosen_beams = [] 336 | while len(chosen_beams) < n_beams and new_beams: 337 | possible_beam, possible_score = new_beams.pop(0) 338 | if stop_tokens is not None and possible_beam[-1] in stop_tokens: 339 | if len(possible_beam) - l_prefix == 1: # its just an EOS token 340 | possible_score = -float('inf') # Penalize EOS to avoid it being chosen unless it's the only option. 341 | finished_beams.append((possible_beam[:-1], possible_score)) 342 | n_beams -= 1 # Reduce the beam count since we finished one. 343 | else: 344 | chosen_beams.append((possible_beam, possible_score)) 345 | 346 | # Update the beams and scores for the next iteration. 347 | beams = [beam for beam, score in chosen_beams] 348 | beam_scores = [score for beam, score in chosen_beams] 349 | 350 | # Exit early if no beams are left to extend. 351 | if len(beams) == 0: 352 | if verbose: 353 | print("All beams finished.") 354 | break 355 | 356 | # If no beams finished with a stop token, use the current beams. 357 | if not finished_beams: 358 | finished_beams = list(zip(beams, beam_scores)) 359 | else: 360 | finished_beams.extend( 361 | [(beam, score) for beam, score in zip(beams, beam_scores) if len(beam) > l_prefix] 362 | ) 363 | finished_beams.sort(key=lambda x: x[1], reverse=True) 364 | # Remove the input prefix from the output beams. 
365 | finished_beams = [(beam[l_prefix:], score) for beam, score in finished_beams] 366 | return finished_beams 367 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Dict, Any, Optional, List, Tuple 6 | import yaml 7 | import mlx.optimizers as optim 8 | import mlx_optimizers as optim_x 9 | import mlx.core as mx 10 | import mlx.nn as nn 11 | import numpy as np 12 | from tqdm import tqdm 13 | import time 14 | from datetime import datetime 15 | import os 16 | #from mlx_lm.models.llama import Model, ModelArgs 17 | import importlib 18 | from mlx.utils import tree_flatten, tree_map, tree_unflatten 19 | import inspect 20 | 21 | def filter_valid_args(cls, arg_dict): 22 | valid_params = inspect.signature(cls).parameters 23 | return {k: v for k, v in arg_dict.items() if k in valid_params} 24 | 25 | 26 | @dataclass 27 | class DataConfig: 28 | input_file: str 29 | preprocessing: Dict[str, int] 30 | tokenizer: Dict[str, Any] 31 | tokenizer_path: Optional[str] = None # Path to a directory containing a tokenizer.json file 32 | validation_file: Optional[str] = None 33 | weight_path: Optional[str] = None 34 | 35 | @dataclass 36 | class ModelConfig: 37 | architecture: str 38 | dimensions: Dict[str, int] 39 | attention: Dict[str, Any] 40 | normalization: Dict[str, float] 41 | rope: Dict[str, Any] 42 | misc: Dict[str, bool] 43 | 44 | @dataclass 45 | class TrainingConfig: 46 | hyperparameters: Dict[str, Any] 47 | scheduler: Dict[str, Any] 48 | optimization: Dict[str, Any] 49 | epochs: Optional[int] = None 50 | 51 | @dataclass 52 | class LoggingConfig: 53 | log_dir: str 54 | checkpoint_dir: str 55 | steps: Dict[str, int] 56 | metrics: Dict[str, bool] 57 | # Default to 0 (no validation) if not specified 58 | 59 | @dataclass 60 | class SystemConfig: 61 | seed: int 62 | device: str 63 | 64 | @dataclass 65 | class ResumeConfig: 66 | checkpoint: str # Path to checkpoint base name 67 | reset_optimizer: bool = False # Optional flag to reset optimizer state 68 | 69 | @dataclass 70 | class Config: 71 | name: str # New field for run name 72 | data: DataConfig 73 | model: ModelConfig 74 | training: TrainingConfig 75 | logging: LoggingConfig 76 | system: SystemConfig 77 | resume: Optional[ResumeConfig] = None 78 | overwrite: bool = False 79 | 80 | @classmethod 81 | def from_yaml(cls, yaml_path: str) -> 'Config': 82 | with open(yaml_path, 'r') as f: 83 | config_dict = yaml.safe_load(f) 84 | 85 | # Validate that name is present 86 | if 'name' not in config_dict: 87 | raise ValueError("Config must specify a 'name' field at the top level") 88 | 89 | # Extract epochs if it exists at the top level of training config 90 | training_config = config_dict['training'].copy() 91 | epochs = training_config.pop('epochs', None) 92 | 93 | # Extract resume config if present 94 | resume_config = None 95 | if 'resume' in config_dict: 96 | resume_config = ResumeConfig(**config_dict['resume']) 97 | 98 | return cls( 99 | name=config_dict['name'], 100 | overwrite=config_dict.get('overwrite', False), 101 | data=DataConfig(**config_dict['data']), 102 | model=ModelConfig(**config_dict['model']), 103 | training=TrainingConfig(**training_config, epochs=epochs), 104 | logging=LoggingConfig(**config_dict['logging']), 105 | system=SystemConfig(**config_dict['system']), 106 | resume=resume_config 107 | ) 108 | 
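# Usage sketch (illustrative comment only, not executed by train.py): parsing the sample
# config shown earlier yields nested dataclasses, e.g.
#   cfg = Config.from_yaml("model-config-sample.yaml")
#   cfg.name                                    -> "Llama (2M)"
#   cfg.model.dimensions["hidden_size"]         -> 128
#   cfg.training.hyperparameters["batch_size"]  -> 16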
109 | class CheckpointManager: 110 | @staticmethod 111 | def validate_unique_name(name: str) -> None: 112 | """Validates that the run directory doesn't already exist""" 113 | run_path = Path('runs') / name 114 | if run_path.exists(): 115 | raise ValueError(f"Run directory already exists for name '{name}'") 116 | 117 | @staticmethod 118 | def setup_run_directory(name: str) -> tuple[Path, Path, Path]: 119 | """Creates and returns paths for run directory structure""" 120 | run_dir = Path('runs') / name 121 | checkpoint_dir = run_dir / 'checkpoints' 122 | 123 | # Create directory structure 124 | run_dir.mkdir(parents=True, exist_ok=True) 125 | checkpoint_dir.mkdir(exist_ok=True) 126 | 127 | return run_dir, run_dir / 'log.txt', checkpoint_dir 128 | 129 | @staticmethod 130 | def get_checkpoint_paths(checkpoint_path: str) -> tuple[str, str, str]: 131 | """Returns the paths for model, optimizer, and state files""" 132 | model_path = f"{checkpoint_path}_model.safetensors" 133 | optimizer_path = f"{checkpoint_path}_optimizer.safetensors" 134 | state_path = f"{checkpoint_path}_state.json" 135 | return model_path, optimizer_path, state_path 136 | 137 | class TokenizerManager: 138 | def __init__(self, config: DataConfig, run_dir: Optional[Path] = None): 139 | self.config = config 140 | self.external_tokenizer = None 141 | 142 | # Check if an external tokenizer path is provided 143 | if config.tokenizer_path is not None: 144 | self.use_external_tokenizer(config.tokenizer_path) 145 | 146 | # If we have a run directory, copy the tokenizer to it 147 | if run_dir is not None: 148 | self.copy_tokenizer_to_run_dir(config.tokenizer_path, run_dir) 149 | else: 150 | # Fall back to byte-level tokenization 151 | self.setup_vocabulary() 152 | 153 | def use_external_tokenizer(self, tokenizer_path: str): 154 | """Load and use an external tokenizer from the specified path.""" 155 | from tokenizers import Tokenizer 156 | import os 157 | tokenizer_file = os.path.join(tokenizer_path, "tokenizer.json") 158 | 159 | if not os.path.exists(tokenizer_file): 160 | raise ValueError(f"Tokenizer file not found at {tokenizer_file}") 161 | 162 | print(f"Loading external tokenizer from {tokenizer_file}") 163 | self.external_tokenizer = Tokenizer.from_file(tokenizer_file) 164 | 165 | # Extract special token IDs 166 | vocab = self.external_tokenizer.get_vocab() 167 | special_tokens = self.config.tokenizer['special_tokens'] 168 | 169 | # Map special tokens to their IDs 170 | self.PAD_TOKEN = vocab.get(special_tokens['pad']) 171 | self.BOS_TOKEN = vocab.get(special_tokens['bos']) 172 | self.EOS_TOKEN = vocab.get(special_tokens['eos']) 173 | self.VOCAB_SIZE = len(vocab) 174 | 175 | if self.PAD_TOKEN is None or self.BOS_TOKEN is None or self.EOS_TOKEN is None: 176 | raise ValueError(f"One or more special tokens not found in the external tokenizer vocabulary") 177 | 178 | def copy_tokenizer_to_run_dir(self, tokenizer_path: str, run_dir: Path): 179 | """Copy the tokenizer files to the run directory.""" 180 | import shutil 181 | import os 182 | 183 | # Create tokenizer directory in run_dir 184 | run_tokenizer_dir = run_dir / 'tokenizer' 185 | os.makedirs(run_tokenizer_dir, exist_ok=True) 186 | 187 | # Copy tokenizer.json 188 | tokenizer_file = os.path.join(tokenizer_path, "tokenizer.json") 189 | shutil.copy2(tokenizer_file, run_tokenizer_dir / "tokenizer.json") 190 | 191 | print(f"Copied tokenizer to {run_tokenizer_dir}") 192 | 193 | def setup_vocabulary(self): 194 | """Set up the byte-level tokenization vocabulary.""" 195 | normal_vocab_size 
= self.config.tokenizer['normal_vocab_size'] 196 | special_tokens = self.config.tokenizer['special_tokens'] 197 | 198 | # Create vocabulary mapping 199 | self.special_token_map = { 200 | token: normal_vocab_size + idx 201 | for idx, token in enumerate(special_tokens.values()) 202 | } 203 | 204 | # Store common tokens 205 | self.PAD_TOKEN = self.special_token_map[special_tokens['pad']] 206 | self.BOS_TOKEN = self.special_token_map[special_tokens['bos']] 207 | self.EOS_TOKEN = self.special_token_map[special_tokens['eos']] 208 | self.VOCAB_SIZE = normal_vocab_size + len(self.special_token_map) 209 | 210 | def tokenize(self, text: str) -> list: 211 | if self.external_tokenizer is not None: 212 | # Use external tokenizer 213 | encoded = self.external_tokenizer.encode(text) 214 | return encoded.ids 215 | else: 216 | # Use byte-level tokenization 217 | return list(text.encode('utf-8')) 218 | 219 | def detokenize(self, tokens: list) -> str: 220 | if self.external_tokenizer is not None: 221 | # Use external tokenizer 222 | return self.external_tokenizer.decode(tokens.tolist()) 223 | else: 224 | # Use byte-level detokenization 225 | return bytes(tokens).decode('utf-8', errors='ignore') 226 | 227 | def tokenize_doc(self, doc: str) -> list: 228 | """Tokenize a document, ensuring it doesn't exceed the max context size. 229 | 230 | Args: 231 | doc: The text to tokenize 232 | 233 | Returns: 234 | A list of token IDs, including BOS and EOS tokens 235 | """ 236 | max_length = self.config.preprocessing['max_context_size'] 237 | 238 | if self.external_tokenizer is not None: 239 | # Use external tokenizer 240 | encoded = self.external_tokenizer.encode(doc) 241 | tokens = encoded.ids[:max_length] 242 | return [self.BOS_TOKEN] + tokens + [self.EOS_TOKEN] 243 | else: 244 | # Use byte-level tokenization 245 | return [self.BOS_TOKEN] + self.tokenize(doc)[:max_length] + [self.EOS_TOKEN] 246 | 247 | class DataManager: 248 | def __init__(self, config: DataConfig, tokenizer: TokenizerManager, batch_size: int = 1): 249 | self.config = config 250 | self.tokenizer = tokenizer 251 | self.train_docs = [] 252 | self.val_docs = [] 253 | self.train_idx = None 254 | self.val_idx = None 255 | self.batch_size = batch_size 256 | self.load_data() 257 | 258 | def load_data(self): 259 | # Load training data 260 | self._load_file(self.config.input_file, self.train_docs) 261 | 262 | # Set up training batches 263 | self.train_idx = sorted(range(len(self.train_docs)), key=lambda idx: len(self.train_docs[idx])) 264 | random.shuffle(self.train_idx) 265 | self.train_batch_idx = [ 266 | self.train_idx[i : i + self.batch_size : 1] 267 | for i in range(0, len(self.train_idx) - self.batch_size + 1, self.batch_size) 268 | ] 269 | self.train_indices = np.random.permutation(len(self.train_batch_idx)) 270 | 271 | # Load validation data if specified 272 | if self.config.validation_file: 273 | self._load_file(self.config.validation_file, self.val_docs) 274 | 275 | # Set up validation batches 276 | self.val_idx = sorted(range(len(self.val_docs)), key=lambda idx: len(self.val_docs[idx])) 277 | self.val_batch_idx = [ 278 | self.val_idx[i : i + self.batch_size : 1] 279 | for i in range(0, len(self.val_idx) - self.batch_size + 1, self.batch_size) 280 | ] 281 | self.val_indices = np.random.permutation(len(self.val_batch_idx)) 282 | self.val_ptr = 0 # Pointer for validation batches 283 | 284 | def _load_file(self, file_path: str, docs_list: list): 285 | """Helper method to load documents from a file.""" 286 | with open(file_path, 'r') as f: 287 | for line 
in f: 288 | d = json.loads(line) 289 | text = d["text"] 290 | """chunk_size = self.config.preprocessing['max_context_size'] 291 | overlap = self.config.preprocessing.get('chunk_overlap', 0) 292 | 293 | # Handle overlapping chunks if specified 294 | stride = chunk_size - overlap 295 | for i in range(0, len(text), stride): 296 | chunk_text = text[i : i + chunk_size] 297 | docs_list.append(chunk_text)""" 298 | docs_list.append(text) 299 | 300 | def generate_batch(self, step: int) -> mx.array: 301 | """Generate a training batch.""" 302 | indices = self.train_batch_idx[self.train_indices[step % len(self.train_indices)]] 303 | return self._create_batch([self.train_docs[i] for i in indices]) 304 | 305 | def generate_validation_batch(self, batch_idx: int) -> mx.array: 306 | """Generate a validation batch.""" 307 | if not self.config.validation_file or batch_idx >= len(self.val_batch_idx): 308 | raise ValueError("No validation data available or batch index out of range") 309 | 310 | indices = self.val_batch_idx[self.val_indices[self.val_ptr % len(self.val_indices)]] 311 | self.val_ptr += 1 312 | return self._create_batch([self.val_docs[i] for i in indices]) 313 | 314 | def _create_batch(self, docs: list) -> mx.array: 315 | """Helper method to create and pad a batch from documents.""" 316 | batch = [self.tokenizer.tokenize_doc(doc) for doc in docs] 317 | max_len = max(len(x) for x in batch) 318 | 319 | # Pad sequences 320 | for i in range(len(batch)): 321 | batch[i] += [self.tokenizer.PAD_TOKEN] * (max_len - len(batch[i])) 322 | 323 | return mx.array(batch) 324 | 325 | @property 326 | def has_validation_data(self) -> bool: 327 | """Check if validation data is available.""" 328 | return self.config.validation_file is not None and len(self.val_docs) > 0 329 | 330 | @property 331 | def num_validation_batches(self) -> int: 332 | """Get the number of validation batches.""" 333 | return len(self.val_batch_idx) if self.has_validation_data else 0 334 | 335 | class OptimizationManager: 336 | def __init__(self, config: TrainingConfig, num_training_steps: int): 337 | self.config = config 338 | self.num_training_steps = num_training_steps 339 | 340 | def create_scheduler(self) -> Any: 341 | cfg = self.config.scheduler 342 | initial_lr = self.config.hyperparameters['learning_rate'] 343 | 344 | if cfg['type'] == 'cosine_with_warmup': 345 | warmup = optim.linear_schedule(0, initial_lr, steps=cfg['warmup_steps']) 346 | cosine = optim.cosine_decay(initial_lr, self.num_training_steps, initial_lr * cfg['min_lr_ratio']) 347 | return optim.join_schedules([warmup, cosine], [cfg['warmup_steps']]) 348 | elif cfg['type'] == 'cosine': 349 | return optim.cosine_decay(initial_lr, self.num_training_steps, initial_lr * cfg['min_lr_ratio']) 350 | elif cfg['type'] == 'linear': 351 | return optim.linear_schedule(initial_lr, 0, steps=self.num_training_steps) 352 | else: 353 | raise ValueError(f"Unsupported scheduler type: {cfg['type']}") 354 | 355 | def create_optimizer(self, schedule: Any) -> optim.Optimizer: 356 | cfg = self.config.optimization 357 | kwargs = { 358 | 'learning_rate': schedule, 359 | } 360 | if 'betas' in cfg: 361 | kwargs['betas'] = tuple(cfg['betas']) 362 | if 'eps' in cfg: 363 | kwargs['eps'] = cfg['eps'] 364 | if 'weight_decay' in cfg: 365 | kwargs['weight_decay'] = self.config.hyperparameters['weight_decay'] 366 | if cfg['optimizer'] == 'adamw': 367 | return optim.AdamW(**kwargs) 368 | elif cfg['optimizer'] == 'adam': 369 | return optim.Adam(**kwargs) 370 | elif cfg['optimizer'] == 'muon': 371 | return 
optim_x.Muon(**kwargs, alternate_optimizer=optim.AdamW(**kwargs)) 372 | elif cfg['optimizer'] == 'sgd': 373 | return optim.SGD(**kwargs) 374 | else: 375 | raise ValueError(f"Unsupported optimizer: {cfg['optimizer']}") 376 | 377 | class Trainer: 378 | def __init__(self, config_path: str, for_training=True): 379 | self.config = Config.from_yaml(config_path) 380 | self.config_path = config_path 381 | 382 | # Initialize tracking variables 383 | self.total_tokens = 0 384 | self.start_step = 0 385 | 386 | # Validate unique run name before proceeding 387 | if for_training and not self.config.overwrite and not (self.config.resume and self.config.resume.checkpoint): 388 | CheckpointManager.validate_unique_name(self.config.name) 389 | 390 | self.setup_system() 391 | 392 | # Create run directory early so we can copy tokenizer to it 393 | if for_training: 394 | self.run_dir, self.log_file, self.checkpoint_dir = CheckpointManager.setup_run_directory(self.config.name) 395 | else: 396 | self.run_dir = None 397 | 398 | # Initialize tokenizer with run directory if available 399 | self.tokenizer = TokenizerManager(self.config.data, self.run_dir) 400 | 401 | self.setup_model() 402 | if for_training: 403 | self.data_manager = DataManager(self.config.data, self.tokenizer, batch_size=self.config.training.hyperparameters['batch_size']) 404 | self.setup_training() 405 | self.setup_logging() 406 | 407 | # Initialize validation metrics tracking 408 | self.validation_steps = self.config.logging.steps.get('validation_interval', 0) 409 | self.validation_losses = [] 410 | 411 | def setup_system(self): 412 | # Set random seeds 413 | random.seed(self.config.system.seed) 414 | np.random.seed(self.config.system.seed) 415 | mx.random.seed(self.config.system.seed) 416 | 417 | def setup_model(self): 418 | model_cfg = self.config.model 419 | arch_file = f"arch.{model_cfg.architecture}" 420 | mlx_lm_file = f"mlx_lm.models.{model_cfg.architecture}" 421 | Model = None 422 | ModelArgs = None 423 | try: 424 | module = importlib.import_module(arch_file) 425 | Model = getattr(module, 'Model') 426 | ModelArgs = getattr(module, 'ModelArgs') 427 | except ImportError: 428 | try: 429 | module = importlib.import_module(mlx_lm_file) 430 | Model = getattr(module, 'Model') 431 | ModelArgs = getattr(module, 'ModelArgs') 432 | except ImportError: 433 | raise ImportError(f"Model architecture '{model_cfg.architecture}' not found in both {arch_file} and {mlx_lm_file}") 434 | 435 | all_args = { 436 | 'model_type': model_cfg.architecture, 437 | 'hidden_size': model_cfg.dimensions['hidden_size'], 438 | 'num_hidden_layers': model_cfg.dimensions.get('num_layers', 8), 439 | 'intermediate_size': model_cfg.dimensions['intermediate_size'], 440 | 'num_attention_heads': model_cfg.attention['num_heads'], 441 | 'rms_norm_eps': model_cfg.normalization['rms_norm_eps'], 442 | 'vocab_size': self.tokenizer.VOCAB_SIZE, 443 | 'head_dim': model_cfg.attention['head_dim'], 444 | 'max_position_embeddings': model_cfg.attention['max_position_embeddings'], 445 | 'num_key_value_heads': model_cfg.attention['num_kv_heads'], 446 | 'attention_bias': model_cfg.misc['attention_bias'], 447 | 'mlp_bias': model_cfg.misc['mlp_bias'], 448 | 'rope_theta': model_cfg.rope['theta'], 449 | 'rope_traditional': model_cfg.rope['traditional'], 450 | 'rope_scaling': model_cfg.rope['scaling'], 451 | 'tie_word_embeddings': model_cfg.misc['tie_word_embeddings'], 452 | 'logit_scale': model_cfg.misc.get('logit_scale', None), 453 | 'num_local_experts': model_cfg.misc.get('num_local_experts', 0), 454 | 
'num_experts_per_tok': model_cfg.misc.get('num_experts_per_tok', 0), 455 | } 456 | valid_args = filter_valid_args(ModelArgs, all_args) 457 | args = ModelArgs(**valid_args) 458 | 459 | self.model = Model(args) 460 | 461 | if self.config.data.weight_path is not None: 462 | print(f"Loading weights from {self.config.data.weight_path}") 463 | self.model.load_weights(self.config.data.weight_path, strict=False) 464 | # Log model size 465 | p = sum(v.size for _, v in tree_flatten(self.model.trainable_parameters())) / 10**6 466 | print(f"Model has {p:.2f}M parameters") 467 | 468 | def setup_training(self): 469 | # Calculate number of training steps 470 | num_samples = len(self.data_manager.train_docs) 471 | batch_size = self.config.training.hyperparameters['batch_size'] 472 | steps_per_epoch = num_samples // batch_size 473 | 474 | if self.config.training.epochs is not None: 475 | # If epochs is set, calculate total steps based on epochs 476 | self.total_steps = steps_per_epoch * self.config.training.epochs 477 | else: 478 | # Otherwise use specified iters or default to one epoch 479 | self.total_steps = self.config.training.hyperparameters.get('iters', steps_per_epoch) 480 | 481 | # Store steps_per_epoch for logging 482 | self.steps_per_epoch = steps_per_epoch 483 | 484 | # Setup optimization 485 | opt_manager = OptimizationManager(self.config.training, self.total_steps) 486 | self.lr_schedule = opt_manager.create_scheduler() 487 | self.optimizer = opt_manager.create_optimizer(self.lr_schedule) 488 | 489 | def setup_logging(self): 490 | # Run directory structure should already be set up in __init__ 491 | 492 | # Create initial metadata file 493 | metadata = { 494 | 'name': self.config.name, 495 | 'created_at': datetime.now().isoformat(), 496 | 'config': { 497 | 'model': self.config.model.__dict__, 498 | 'training': self.config.training.__dict__, 499 | 'system': self.config.system.__dict__ 500 | }, 501 | 'training_info': { 502 | 'steps_per_epoch': self.steps_per_epoch, 503 | 'total_steps': self.total_steps, 504 | 'epochs': self.config.training.epochs 505 | } 506 | } 507 | 508 | # Add tokenizer information to metadata 509 | if self.config.data.tokenizer_path: 510 | metadata['tokenizer'] = { 511 | 'type': 'external', 512 | 'path': self.config.data.tokenizer_path, 513 | 'vocab_size': self.tokenizer.VOCAB_SIZE 514 | } 515 | else: 516 | metadata['tokenizer'] = { 517 | 'type': 'byte-level', 518 | 'vocab_size': self.tokenizer.VOCAB_SIZE 519 | } 520 | 521 | with open(self.run_dir / 'metadata.json', 'w') as f: 522 | json.dump(metadata, f, indent=2) 523 | 524 | # Save the config used to the run directory 525 | with open(self.run_dir / 'config.yaml', 'w') as f: 526 | with open(self.config_path, 'r') as config_file: 527 | f.write(config_file.read()) 528 | 529 | def compute_loss(self, model, inputs: mx.array, targets: mx.array) -> Tuple[mx.array, int]: 530 | logits = model(inputs) 531 | logits = logits.astype(mx.float32) 532 | loss = nn.losses.cross_entropy(logits, targets) 533 | # Mask padding tokens 534 | pad_mask = (targets != self.tokenizer.PAD_TOKEN) 535 | loss = loss * pad_mask 536 | ntoks = pad_mask.sum() 537 | 538 | return loss.sum() / ntoks, ntoks 539 | 540 | def validate(self) -> float: 541 | """Run validation on the validation dataset. 
542 | 543 | Returns: 544 | float: Average validation loss 545 | """ 546 | if not self.data_manager.has_validation_data: 547 | return None 548 | 549 | # Ensure we're in evaluation mode (no need for gradients) 550 | total_loss = 0.0 551 | total_tokens = 0 552 | 553 | # Process all validation batches 554 | num_batches = min(self.data_manager.num_validation_batches, 50) # Cap at 50 batches to avoid too long validation 555 | 556 | for batch_idx in range(num_batches): 557 | batch = self.data_manager.generate_validation_batch(batch_idx) 558 | 559 | # Forward pass only 560 | loss, tokens = self.compute_loss(self.model, batch[:, :-1], batch[:, 1:]) 561 | 562 | # Accumulate metrics 563 | total_loss += float(loss) 564 | total_tokens += tokens 565 | 566 | # Clear GPU cache if needed 567 | if self.config.system.device == "gpu": 568 | mx.clear_cache() 569 | 570 | # Calculate average loss 571 | avg_loss = total_loss / num_batches 572 | 573 | return avg_loss 574 | 575 | def save_checkpoint(self, step: int | str, val_loss: float = None): 576 | # Save model weights 577 | weights = dict(tree_flatten(self.model.parameters())) 578 | model_path = self.checkpoint_dir / f'step_{step}_model.safetensors' 579 | mx.save_safetensors(str(model_path), weights) 580 | 581 | # Save optimizer state 582 | optimizer_state = dict(tree_flatten(self.optimizer.state)) 583 | optimizer_path = self.checkpoint_dir / f'step_{step}_optimizer.safetensors' 584 | mx.save_safetensors(str(optimizer_path), optimizer_state) 585 | 586 | # Save training state 587 | training_state = { 588 | 'step': step if isinstance(step, int) else self.total_steps, 589 | 'val_ptr': self.data_manager.val_ptr, 590 | 'total_tokens': self.total_tokens.item(), 591 | 'validation_losses': self.validation_losses, 592 | } 593 | state_path = self.checkpoint_dir / f'step_{step}_state.json' 594 | with open(state_path, 'w') as f: 595 | json.dump(training_state, f) 596 | 597 | # Update metadata with checkpoint info 598 | metadata_path = self.run_dir / 'metadata.json' 599 | with open(metadata_path, 'r') as f: 600 | metadata = json.load(f) 601 | 602 | if 'checkpoints' not in metadata: 603 | metadata['checkpoints'] = [] 604 | 605 | checkpoint_info = { 606 | 'step': step, 607 | 'timestamp': datetime.now().isoformat(), 608 | 'paths': { 609 | 'model': f'checkpoints/step_{step}_model.safetensors', 610 | 'optimizer': f'checkpoints/step_{step}_optimizer.safetensors', 611 | 'state': f'checkpoints/step_{step}_state.json' 612 | } 613 | } 614 | 615 | # Include validation loss if available 616 | if val_loss is not None: 617 | checkpoint_info['validation_loss'] = val_loss 618 | 619 | metadata['checkpoints'].append(checkpoint_info) 620 | 621 | with open(metadata_path, 'w') as f: 622 | json.dump(metadata, f, indent=2) 623 | 624 | def log_metrics(self, step: int, loss: float, tokens: int, 625 | total_tokens: int, start_time: float, val_loss: float = None) -> str: 626 | metrics = [] 627 | 628 | # Add epoch information if epochs are configured 629 | if self.config.training.epochs is not None: 630 | current_epoch = step // self.steps_per_epoch + 1 631 | epoch_step = step % self.steps_per_epoch + 1 632 | metrics.append(f"epoch={current_epoch}/{self.config.training.epochs} ({epoch_step}/{self.steps_per_epoch})") 633 | 634 | if self.config.logging.metrics['log_loss']: 635 | metrics.append(f"loss={loss:.3e}") 636 | 637 | # Add validation loss if available 638 | if val_loss is not None: 639 | metrics.append(f"val_loss={val_loss:.3e}") 640 | 641 | if self.config.logging.metrics['log_perplexity']: 642 
| metrics.append(f"ppl={np.exp(loss):.2f}") 643 | 644 | # Add validation perplexity if available 645 | if val_loss is not None: 646 | metrics.append(f"val_ppl={np.exp(val_loss):.2f}") 647 | 648 | if self.config.logging.metrics['log_tokens_per_second']: 649 | tokens_per_sec = total_tokens / (1000 * (time.time() - start_time)) 650 | metrics.append(f"tok/s={tokens_per_sec:.2f}K") 651 | 652 | if self.config.logging.metrics['log_tokens_processed']: 653 | metrics.append(f"toks={tokens}") 654 | 655 | if self.config.logging.metrics['log_learning_rate']: 656 | metrics.append(f"lr={self.lr_schedule(step):.3e}") 657 | 658 | return " | ".join(metrics) 659 | 660 | def load_checkpoint(self, checkpoint_path: str, reset_optimizer: bool = False): 661 | """Load a checkpoint and restore model, optimizer, and training state""" 662 | # Extract step from checkpoint path 663 | step_str = checkpoint_path.split('step_')[-1] 664 | 665 | # Get checkpoint file paths 666 | model_path, optimizer_path, state_path = CheckpointManager.get_checkpoint_paths(checkpoint_path) 667 | 668 | # Load model weights 669 | print(f"Loading model weights from {model_path}") 670 | #weights = mx.load(model_path) 671 | self.model.load_weights(model_path) 672 | # Load optimizer state if not resetting 673 | if not reset_optimizer: 674 | print(f"Loading optimizer state from {optimizer_path}") 675 | state_dict = mx.load(optimizer_path) 676 | state = tree_unflatten(list(state_dict.items())) 677 | self.optimizer.state = state 678 | 679 | # Load training state 680 | print(f"Loading training state from {state_path}") 681 | with open(state_path, 'r') as f: 682 | training_state = json.load(f) 683 | 684 | # Restore training state 685 | self.start_step = training_state['step'] if isinstance(training_state['step'], int) else 0 686 | self.data_manager.val_ptr = training_state['val_ptr'] 687 | self.total_tokens = training_state['total_tokens'] 688 | self.validation_losses = training_state['validation_losses'] 689 | 690 | print(f"Resumed training from checkpoint {checkpoint_path} at step {self.start_step}") 691 | 692 | return self.start_step 693 | 694 | def train(self): 695 | # Initialize variables 696 | total_tokens = self.total_tokens 697 | start_step = 0 698 | 699 | # Check if resuming from checkpoint 700 | if self.config.resume and self.config.resume.checkpoint: 701 | checkpoint_path = self.config.resume.checkpoint 702 | reset_optimizer = self.config.resume.reset_optimizer 703 | start_step = self.load_checkpoint(checkpoint_path, reset_optimizer) 704 | 705 | # If we're resuming, we should skip the initial validation 706 | skip_initial_validation = True 707 | else: 708 | skip_initial_validation = False 709 | 710 | loss_value_and_grad = nn.value_and_grad(self.model, self.compute_loss) 711 | start_time = time.time() 712 | # Create progress bar with adjusted range for resuming 713 | progress_bar = tqdm(range(self.total_steps), desc="Training", initial=start_step) 714 | 715 | 716 | # Initialize logging 717 | with open(self.log_file, 'a' if start_step > 0 else 'w') as log_file: 718 | if start_step == 0: 719 | log_file.write(f"Training started at {datetime.now()}\n") 720 | log_file.write(f"Total steps: {self.total_steps}\n") 721 | if self.config.training.epochs is not None: 722 | log_file.write(f"Training for {self.config.training.epochs} epochs with {self.steps_per_epoch} steps per epoch\n") 723 | if self.data_manager.has_validation_data: 724 | log_file.write(f"Validation data: {self.config.data.validation_file}\n") 725 | log_file.write(f"Validation batches: 
{self.data_manager.num_validation_batches}\n") 726 | log_file.write("=" * 50 + "\n\n") 727 | else: 728 | log_file.write(f"\nResuming training at step {start_step} at {datetime.now()}\n") 729 | log_file.write(f"Remaining steps: {self.total_steps - start_step}\n") 730 | log_file.write("=" * 50 + "\n\n") 731 | 732 | # Log initial validation loss if validation data is available and not resuming 733 | val_loss = None 734 | if self.validation_steps > 0 and self.data_manager.has_validation_data and not skip_initial_validation: 735 | val_loss = self.validate() 736 | log_file.write(f"Initial validation loss: {val_loss:.4e} (ppl={np.exp(val_loss):.2f})\n\n") 737 | # Add to validation loss history 738 | self.validation_losses.append((0, val_loss)) 739 | 740 | for step in progress_bar: 741 | step += start_step 742 | if step >= self.total_steps: 743 | break 744 | # Generate batch 745 | batch = self.data_manager.generate_batch(step) 746 | 747 | # Forward and backward pass 748 | (loss, tokens), grad = loss_value_and_grad( 749 | self.model, batch[:, :-1], batch[:, 1:] 750 | ) 751 | 752 | # Gradient clipping if configured 753 | if 'gradient_clip' in self.config.training.hyperparameters: 754 | clip_value = self.config.training.hyperparameters['gradient_clip'] 755 | grad = tree_map(lambda x: mx.clip(x, -clip_value, clip_value), grad) 756 | 757 | # Update model 758 | total_tokens += tokens 759 | self.optimizer.update(self.model, grad) 760 | mx.eval(loss) 761 | 762 | if self.config.system.device == "gpu": 763 | mx.clear_cache() 764 | 765 | # Run validation 766 | if self.validation_steps > 0 and self.data_manager.has_validation_data and (step + 1) % self.validation_steps == 0: 767 | val_loss = self.validate() 768 | # Add to validation loss history 769 | self.validation_losses.append((step + 1, val_loss)) 770 | 771 | # Log validation separately for clear visibility 772 | val_metrics = f"val_loss={val_loss:.3e} | val_ppl={np.exp(val_loss):.2f}" 773 | log_file.write(f"Step {step + 1} validation: {val_metrics}\n") 774 | log_file.flush() 775 | 776 | # Logging 777 | if step % self.config.logging.steps['logging_interval'] == 0: 778 | # Only include val_loss if it was just calculated 779 | current_val_loss = val_loss if self.validation_steps > 0 and (step + 1) % self.validation_steps == 0 else None 780 | metrics = self.log_metrics(step, loss, tokens, total_tokens, start_time, current_val_loss) 781 | 782 | # Update progress bar 783 | progress_bar.set_description(metrics) 784 | 785 | # Write to log file 786 | log_message = f"Step {step}: {metrics}\n" 787 | log_file.write(log_message) 788 | log_file.flush() 789 | 790 | # Save checkpoint 791 | if (1 + step) % self.config.logging.steps['checkpoint_interval'] == 0: 792 | # Find the most recent validation loss if available 793 | last_val_loss = val_loss if val_loss is not None else None 794 | # Update total_tokens in the trainer instance for checkpoint saving 795 | self.total_tokens = total_tokens 796 | self.save_checkpoint(step + 1, last_val_loss) 797 | 798 | # Final validation 799 | final_val_loss = None 800 | if self.validation_steps > 0 and self.data_manager.has_validation_data: 801 | final_val_loss = self.validate() 802 | self.validation_losses.append((self.total_steps, final_val_loss)) 803 | 804 | # Save final checkpoint with validation loss 805 | self.total_tokens = total_tokens 806 | self.save_checkpoint("final", final_val_loss) 807 | 808 | # Save validation losses to metadata 809 | if self.validation_losses: 810 | metadata_path = self.run_dir / 'metadata.json' 811 | 
with open(metadata_path, 'r') as f: 812 | metadata = json.load(f) 813 | 814 | metadata['validation'] = { 815 | 'steps': [step for step, _ in self.validation_losses], 816 | 'losses': [float(loss) for _, loss in self.validation_losses] 817 | } 818 | 819 | with open(metadata_path, 'w') as f: 820 | json.dump(metadata, f, indent=2) 821 | 822 | # Write final summary 823 | with open(self.log_file, 'a') as log_file: 824 | log_file.write("\n" + "=" * 50 + "\n") 825 | log_file.write(f"Training completed at {datetime.now()}\n") 826 | log_file.write(f"Final training metrics: {metrics}\n") 827 | if final_val_loss is not None: 828 | log_file.write(f"Final validation loss: {final_val_loss:.4e} (ppl={np.exp(final_val_loss):.2f})\n") 829 | log_file.write(f"Total tokens processed: {total_tokens/1000:.2f}K\n") 830 | 831 | def main(): 832 | import argparse 833 | parser = argparse.ArgumentParser(description='Train a language model with MLX') 834 | parser.add_argument('--config', type=str, required=True, 835 | help='Path to YAML configuration file') 836 | args = parser.parse_args() 837 | # Make 'runs' directory if it doesn't exist 838 | os.makedirs('runs', exist_ok=True) 839 | trainer = Trainer(args.config) 840 | trainer.train() 841 | 842 | if __name__ == "__main__": 843 | main() 844 | --------------------------------------------------------------------------------
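For reference, a minimal sketch of launching a run programmatically, mirroring what main() in train.py does above; "my-config.yaml" is a hypothetical placeholder for a YAML config shaped like the sample config in this repository:

    import os
    from train import Trainer

    os.makedirs('runs', exist_ok=True)   # main() creates the runs/ directory before constructing the Trainer
    trainer = Trainer('my-config.yaml')  # parses the YAML, then sets up tokenizer, model, data, and optimizer
    trainer.train()                      # runs the training loop; checkpoints and logs land under runs/<name>/

Equivalently, from the command line: python train.py --config my-config.yaml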