├── src ├── __init__.py ├── data_processing │ ├── book_processing.py │ ├── data_creation_ideas.md │ ├── baseline_chunker.py │ ├── convert_to_markdown.py │ ├── bm25_func.py │ ├── README.md │ ├── semantic_chunker.py │ ├── jsonl_utils.py │ ├── prepare_training_data.py │ └── convert_hf_dataset_format.py ├── finetuning │ ├── lora_config.yaml │ ├── download_qwen3.py │ ├── finetune_qwen3.sh │ └── convert_qwen3.py ├── inference │ ├── generate_qwen_vlm_notebook.py │ ├── generate-qwen-vlm.py │ └── generate_qwen3.py └── evaluations │ └── run_evaluations.py ├── .python-version ├── main.py ├── mlx-quantization ├── eval_mlx-community_GLM-4.5-Air-5bit_0.4.9_mmlu_pro_computer_science ├── requirements.txt ├── LICENSE ├── CHANGELOG.md ├── CLAUDE.md ├── README.md ├── dwq_quantization.ipynb └── awq_quantization.ipynb ├── CONVERT ├── eval_Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr3e-7_0.4.9_mmlu_pro_computer_science ├── eval_Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-8_0.4.9_mmlu_pro_computer_science ├── eval_Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-8_0.4.9_mmlu_pro_computer_science ├── eval_many_4.sh ├── eval_many2.sh ├── eval_many3.sh ├── eval_many_2a.md ├── convert_qwen_coder3.sh ├── convert_many2.sh ├── eval_many.sh ├── convert_qwen_coder.sh ├── convert_qwen_coder2.sh ├── conversion_recipies.md ├── results.md └── convert_many.sh ├── .gitignore ├── project_setup.sh ├── pyproject.toml ├── .cursor └── rules │ └── repo-overview.mdc ├── old_pyproj.txt ├── examples_scratchpad.sh ├── README.md └── split_pdf_pages.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.13.9 2 | -------------------------------------------------------------------------------- /src/data_processing/book_processing.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | def main(): 2 | print("Hello from mlx-finetune-demo!") 3 | 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /mlx-quantization/eval_mlx-community_GLM-4.5-Air-5bit_0.4.9_mmlu_pro_computer_science: -------------------------------------------------------------------------------- 1 | { 2 | "mmlu_pro_computer_science": { 3 | "alias": "computer_science", 4 | "exact_match,custom-extract": 0.7634146341463415, 5 | "exact_match_stderr,custom-extract": 0.021014183737081388 6 | } 7 | } -------------------------------------------------------------------------------- /CONVERT/eval_Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr3e-7_0.4.9_mmlu_pro_computer_science: -------------------------------------------------------------------------------- 1 | { 2 | "mmlu_pro_computer_science": { 3 | "alias": "computer_science", 4 | "exact_match,custom-extract": 0.7926829268292683, 5 | "exact_match_stderr,custom-extract": 0.020044980247224457 6 | } 7 | } -------------------------------------------------------------------------------- /CONVERT/eval_Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-8_0.4.9_mmlu_pro_computer_science: -------------------------------------------------------------------------------- 1 | { 2 | "mmlu_pro_computer_science": { 3 | "alias": "computer_science", 4 | 
"exact_match,custom-extract": 0.7878048780487805, 5 | "exact_match_stderr,custom-extract": 0.02021693788475414 6 | } 7 | } -------------------------------------------------------------------------------- /CONVERT/eval_Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-8_0.4.9_mmlu_pro_computer_science: -------------------------------------------------------------------------------- 1 | { 2 | "mmlu_pro_computer_science": { 3 | "alias": "computer_science", 4 | "exact_match,custom-extract": 0.7926829268292683, 5 | "exact_match_stderr,custom-extract": 0.020044980247224453 6 | } 7 | } -------------------------------------------------------------------------------- /src/finetuning/lora_config.yaml: -------------------------------------------------------------------------------- 1 | lora_parameters: 2 | rank: 256 # LoRA rank (dimension of the adapter matrices) 3 | dropout: 0.05 # Dropout applied to the LoRA matrices 4 | scale: 12.0 # Scaling factor for the LoRA update (higher means more influence) - original 10.0 5 | learning_rate: 8e-6 # Overrides LEARNING_RATE in the bash script 6 | -------------------------------------------------------------------------------- /CONVERT/eval_many_4.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | echo "mlx-community/XBai-o4-8bit" 4 | mlx_lm.evaluate --model mlx-community/mlx-community/XBai-o4-8bit --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 5 | 6 | 7 | echo "mlx-community/XBai-o4-4bit-DWQ" 8 | mlx_lm.evaluate --model mlx-community/XBai-o4-4bit-DWQ --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | DATA 2 | *.json 3 | **/*.json 4 | **/runs 5 | **/MLX-Llama* 6 | mlx-pretrain/runs/* 7 | mlx-pretrain/MLX-Llama* 8 | **/*.jsonl 9 | *.jsonl 10 | mlx-pretrain/MLX-* 11 | **/__pycache__ 12 | __pycache__ 13 | ADAPTERS 14 | mlx_models 15 | **/sacredhunger.txt 16 | **/allthekingsmen.txt 17 | Qwen*DWQ* 18 | **/*.egg-info 19 | mlx-quantization/models 20 | uv.lock 21 | text_output*.md 22 | .DS_Store 23 | output*.md 24 | -------------------------------------------------------------------------------- /CONVERT/eval_many2.sh: -------------------------------------------------------------------------------- 1 | echo "mlx-community/Qwen3-Coder-30B-A3B-Instruct-bf16" 2 | mlx_lm.evaluate --model mlx-community/Qwen3-Coder-30B-A3B-Instruct-bf16 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 3 | 4 | echo "mlx-community/Qwen3-30B-A3B-Thinking-2507-bf16" 5 | mlx_lm.evaluate --model mlx-community/Qwen3-30B-A3B-Thinking-2507-bf16 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template -------------------------------------------------------------------------------- /CONVERT/eval_many3.sh: -------------------------------------------------------------------------------- 1 | echo "mlx-community/cogito-v2-preview-llama-70B-4Bit" 2 | mlx_lm.evaluate --model mlx-community/cogito-v2-preview-llama-70B-4Bit --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 3 | 4 | echo "mlx-community/GLM-4.5-Air-5bit" 5 | mlx_lm.evaluate --model mlx-community/GLM-4.5-Air-5bit --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 6 | 7 | https://github.com/cs2764/mlx-quantization 8 | # dynamic quantization 9 | 10 | 
-------------------------------------------------------------------------------- /project_setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e # Exit immediately if a command exits with a non-zero status 3 | 4 | # Upgrade pip and install poetry 5 | pip install --upgrade pip 6 | pip install poetry 7 | 8 | # Update the lock file if necessary 9 | poetry lock 10 | 11 | # Install dependencies and the project 12 | poetry install 13 | 14 | # Create and install the IPython kernel for the project 15 | python -m ipykernel install --sys-prefix --name=mlx3129 --display-name "MLX 3.12.9" 16 | 17 | echo "Jupyter kernel 'mlx3129' has been installed." 18 | 19 | 20 | echo "Project setup complete!" -------------------------------------------------------------------------------- /CONVERT/eval_many_2a.md: -------------------------------------------------------------------------------- 1 | echo "Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-7" 2 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-7 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 3 | 4 | echo "Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr1e-8" 5 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr1e-8 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 6 | 7 | echo "Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-9" 8 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-9 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mlx-finetune-demo" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.13" 7 | dependencies = [ 8 | #"mlx @ git+https://github.com/ml-explore/mlx", 9 | "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git", 10 | "mlx", 11 | "mlx-vlm", 12 | #"mlx-lm", 13 | "datasets", 14 | "transformers", 15 | "huggingface_hub", 16 | "ipykernel", 17 | "jupyter", 18 | "ipywidgets", 19 | "torch", 20 | "torchvision", 21 | "lm_eval", 22 | "datasets", 23 | "accelerate", 24 | "sentencepiece", 25 | "protobuf", 26 | "evaluate", 27 | "hf_transfer", 28 | "gradio", 29 | "pymupdf", 30 | "pdf2image", 31 | ] 32 | -------------------------------------------------------------------------------- /mlx-quantization/requirements.txt: -------------------------------------------------------------------------------- 1 | # Core MLX Framework 2 | mlx-lm>=0.12.0 3 | 4 | # Machine Learning Libraries 5 | torch>=2.0.0 6 | transformers>=4.40.0 7 | tokenizers>=0.15.0 8 | datasets>=2.16.0 9 | accelerate>=0.25.0 10 | 11 | # Hugging Face Integration 12 | huggingface_hub>=0.20.0 13 | 14 | # Text Processing 15 | sentencepiece>=0.1.99 16 | protobuf>=4.21.0 17 | 18 | # Quantization Support 19 | bitsandbytes>=0.41.0 20 | 21 | # Jupyter Environment 22 | jupyter>=1.0.0 23 | jupyterlab>=4.0.0 24 | ipywidgets>=8.0.0 25 | 26 | # Progress Bars and Utilities 27 | tqdm>=4.65.0 28 | numpy>=1.24.0 29 | scipy>=1.10.0 30 | 31 | # File Handling 32 | safetensors>=0.4.0 33 | 34 | # Optional Performance Libraries 35 | psutil>=5.9.0 36 | matplotlib>=3.7.0 37 | seaborn>=0.12.0 38 | 39 | # Development Tools (Optional) 40 | black>=23.0.0 41 | flake8>=6.0.0 42 | isort>=5.12.0 -------------------------------------------------------------------------------- 
/.cursor/rules/repo-overview.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: 3 | globs: 4 | alwaysApply: true 5 | --- 6 | Most of this project's unique code lives in the `src` folder and the subfolders within it. 7 | There are multiple reference projects included within the project root. These folders are mostly prefixed with `mlx-` (for example `mlx-lm` and `mlx-vlm`), along with others such as `synthetic-data-kit`. Never modify existing files in those folders. Literally never. Instead, when you would like to modify anything there, always make a copy in the `src/copies` folder and make your changes there. 8 | 9 | Also note that `mlx-examples` may have some overlap with some of the other projects, and should in most cases be disregarded, unless the information is not available in the other projects. So, for example, for LLM-related inference and post-training examples, the `mlx-lm` folder should be the preferred source, but there may occasionally be information that is only available in the `mlx-examples` folder. -------------------------------------------------------------------------------- /src/data_processing/data_creation_ideas.md: -------------------------------------------------------------------------------- 1 | ARXIV PAPERS 2 | 1. Get the arxiv paper and convert it to markdown. 3 | 2. Get the abstract and the first section - ask an LLM to summarize these. 4 | 3. Prompt is: Given the above summary, write the next section of the paper titled: . 5 | 4. Repeat, but keep re-doing the previous steps, so each summary will include more and more of the paper. 6 | 7 | Note: Need to have a pretty good rubric for instructing the LLM as to how it should create its summaries. 8 | 9 | BOOK SECTION CONTINUATION 10 | 1. Get the first section. 11 | 2. Prompt is: This is an excerpt from a novel. Write the next {1 paragraph, 2 paragraphs, etc.} of the book. Use the same style as the excerpt. Make sure that while stylistically similar, the new section moves the story forward and/or develops the characters and/or adds new information or in some way continues on meaningfully from the previous section. 12 | 3. Repeat (a sketch of turning these chunk pairs into training records follows below).
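A minimal sketch of how consecutive chunks could become prompt/continuation training records (the real pipeline lives in `prepare_training_data.py`; the `chunks` key matches the JSON emitted by `baseline_chunker.py`, while the `prompt`/`completion` field names and the file names are illustrative assumptions here):

```python
import json
from pathlib import Path

# Prompt wording adapted from the book-continuation idea above.
PROMPT_TEMPLATE = (
    "This is an excerpt from a novel. Write the next section of the book. "
    "Use the same style as the excerpt, and make sure the new section moves "
    "the story forward.\n\n{excerpt}"
)


def chunks_to_pairs(chunk_file: str | Path) -> list[dict]:
    """Pair each chunk with the chunk that follows it as prompt/continuation."""
    data = json.loads(Path(chunk_file).read_text(encoding="utf-8"))
    chunks = data["chunks"]
    return [
        {
            "prompt": PROMPT_TEMPLATE.format(excerpt=excerpt),
            "completion": continuation,
        }
        for excerpt, continuation in zip(chunks, chunks[1:])
    ]


if __name__ == "__main__":
    # Example chunk file name taken from examples_scratchpad.sh.
    records = chunks_to_pairs("semantic_chunks_480.json")
    with open("book_continuation_pairs.jsonl", "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
```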
13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /CONVERT/convert_qwen_coder3.sh: -------------------------------------------------------------------------------- 1 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr3e-7 --max-seq-length 2048 --batch-size 4 --learning-rate 3e-7 --group-size 32 --bits 6 2 | touch Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr3e-7/README.md 3 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr3e-7 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr3e-7 4 | 5 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr9e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 9e-8 --group-size 32 --bits 6 6 | touch Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr9e-8/README.md 7 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr9e-8 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr9e-8 8 | 9 | mlx_lm.evaluate --model Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr3e-7 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 10 | mlx_lm.evaluate --model Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr9e-8 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 11 | -------------------------------------------------------------------------------- /mlx-quantization/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 MLX Quantization Toolkit 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /old_pyproj.txt: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "mlx_finetune_demo" 7 | version = "0.1.0" 8 | description = "MLX Finetuning Demo" 9 | authors = ["Jeff Coggshall <thenextlocalminima@gmail.com>"] 10 | readme = "README.md" 11 | packages = [ 12 | { include = "src" } 13 | ] 14 | license = "MIT" 15 | 16 | [tool.poetry.dependencies] 17 | python = ">=3.12" 18 | pip = "*" 19 | accelerate = "*" 20 | mlx = "*" 21 | mlx-lm = "*" 22 | mlx_optimizers = "*" 23 | markitdown = "*" 24 | PyYAML = "*" 25 | tokenizers = "*" 26 | numpy = "*" 27 | pandas = "*" 28 | matplotlib = "*" 29 | datasets = "*" 30 | transformers = "*" 31 | huggingface_hub = "*" 32 | hf_transfer = "*" 33 | ipykernel = "*" 34 | sentencepiece = "*" 35 | torch = "*" 36 | torchao = "*" 37 | torchvision = "*" 38 | torchaudio = "*" 39 | fairscale = "*" 40 | fire = "*" 41 | jax = "*" 42 | flax = "*" 43 | optax = "*" 44 | einops = "*" 45 | diffusers = "*" 46 | tqdm = "*" 47 | rank-bm25 = {git = "https://github.com/dorianbrown/rank_bm25.git"} 48 | sentence-transformers = "*" 49 | 50 | [tool.poetry.scripts] 51 | # Add command line scripts here 52 | 53 | -------------------------------------------------------------------------------- /CONVERT/convert_many2.sh: -------------------------------------------------------------------------------- 1 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-7 --max-seq-length 2048 --batch-size 4 --learning-rate 8e-7 --group-size 32 --bits 6 2 | touch Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-7/README.md 3 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-7 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-7 4 | 5 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr1e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 1e-8 --group-size 32 --bits 6 6 | touch Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr1e-8/README.md 7 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr1e-8 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr1e-8 8 | 9 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-9 --max-seq-length 2048 --batch-size 4 --learning-rate 5e-9 --group-size 32 --bits 6 10 | touch Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-9/README.md 11 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-9 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-9 12 | 13 | #======================================== 14 | -------------------------------------------------------------------------------- /src/data_processing/baseline_chunker.py: -------------------------------------------------------------------------------- 1 | # baseline_chunker.py 2 | import re 3 | from pathlib import Path 4 | from typing import List, Iterable 5 | 6 | def load_paragraphs(path: str | Path) -> List[str]: 7 | raw = Path(path).read_text(encoding="utf‑8") 8 | # collapse Windows/Mac line endings, then split on 2+ newlines 9 | return [p.strip() for p in re.split(r"\n{2,}", raw) if p.strip()] 10 | 11 | def chunk_paragraphs(paragraphs: Iterable[str], 12 | target_words: int = 1000, 13 | overlap_paragraphs: int = 0) -> List[str]: 14 | chunks, current, 
cur_count = [], [], 0 15 | for p in paragraphs: 16 | p_words = len(p.split()) 17 | # if adding this paragraph would push us *over* the target, 18 | # flush what we’ve got (unless empty) and start anew 19 | if current and cur_count + p_words > target_words: 20 | chunks.append("\n\n".join(current)) 21 | # start next chunk with optional overlap from the *end* 22 | current = current[-overlap_paragraphs:] if overlap_paragraphs else [] 23 | cur_count = sum(len(x.split()) for x in current) 24 | current.append(p) 25 | cur_count += p_words 26 | if current: 27 | chunks.append("\n\n".join(current)) 28 | return chunks 29 | 30 | if __name__ == "__main__": 31 | import argparse, json 32 | ap = argparse.ArgumentParser() 33 | ap.add_argument("book_path") 34 | ap.add_argument("--size", type=int, default=1000, 35 | help="≈ words per chunk (default 1000)") 36 | ap.add_argument("--overlap", type=int, default=0, 37 | help="paragraphs to repeat between chunks") 38 | args = ap.parse_args() 39 | 40 | paras = load_paragraphs(args.book_path) 41 | chunks = chunk_paragraphs(paras, args.size, args.overlap) 42 | print(json.dumps({"chunks": chunks, "count": len(chunks)}, indent=2)) 43 | -------------------------------------------------------------------------------- /CONVERT/eval_many.sh: -------------------------------------------------------------------------------- 1 | # mlx_lm.evaluate --model Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr1e-6 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 2 | 3 | # echo "Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr2e7" 4 | # mlx_lm.evaluate --model Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr2e7 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 5 | 6 | echo "Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr4e7" 7 | mlx_lm.evaluate --model Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr4e7 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 8 | 9 | echo "Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr5e-8" 10 | mlx_lm.evaluate --model Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr5e-8 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 11 | 12 | echo "Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr9e8" 13 | mlx_lm.evaluate --model Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr9e8 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 14 | 15 | echo "Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr1e-7" 16 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr1e-7 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 17 | 18 | echo "Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr2e-7" 19 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr2e-7 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 20 | 21 | echo "Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr3e-8" 22 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr3e-8 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 23 | 24 | echo "Qwen3-30B-A3B-Instruct-2507-bit-DWQ-lr5e-7" 25 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-bit-DWQ-lr5e-7 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 26 | 27 | echo "Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr9e-8" 28 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr9e-8 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 29 | -------------------------------------------------------------------------------- /CONVERT/convert_qwen_coder.sh: 
-------------------------------------------------------------------------------- 1 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr2e7 --max-seq-length 2048 --batch-size 4 --learning-rate 2e-7 --group-size 32 --bits 8 --data-path voxmenthe/merged-sft-coding-mix2 2 | touch Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr2e7/README.md 3 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr2e7 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr2e7 4 | 5 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr4e7 --max-seq-length 2048 --batch-size 4 --learning-rate 4e-7 --group-size 32 --bits 8 --data-path voxmenthe/merged-sft-coding-mix2 6 | touch Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr4e7/README.md 7 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr4e7 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr4e7 8 | 9 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr9e8 --max-seq-length 2048 --batch-size 4 --learning-rate 9e-8 --group-size 32 --bits 8 --data-path voxmenthe/merged-sft-coding-mix2 10 | touch Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr9e8/README.md 11 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr9e8 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr9e8 12 | 13 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr5e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 5e-8 --group-size 32 --bits 8 --data-path voxmenthe/merged-sft-coding-mix2 14 | touch Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr5e-8/README.md 15 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr5e-8 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr5e-8 16 | 17 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr1e-6 --max-seq-length 2048 --batch-size 4 --learning-rate 1e-6 --group-size 32 --bits 8 --data-path voxmenthe/merged-sft-coding-mix2 18 | touch Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr1e-6/README.md 19 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr1e-6 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr1e-6 -------------------------------------------------------------------------------- /CONVERT/convert_qwen_coder2.sh: -------------------------------------------------------------------------------- 1 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr2e7 --max-seq-length 2048 --batch-size 4 --learning-rate 2e-7 --group-size 32 --bits 5 --data-path voxmenthe/merged-sft-coding-mix2 2 | touch Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr2e7/README.md 3 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr2e7 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr2e7 4 | 5 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr4e7 --max-seq-length 2048 --batch-size 4 --learning-rate 4e-7 --group-size 32 --bits 5 --data-path voxmenthe/merged-sft-coding-mix2 6 | touch Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr4e7/README.md 7 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr4e7 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr4e7 8 | 9 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path 
Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr9e8 --max-seq-length 2048 --batch-size 4 --learning-rate 9e-8 --group-size 32 --bits 5 --data-path voxmenthe/merged-sft-coding-mix2 10 | touch Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr9e8/README.md 11 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr9e8 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr9e8 12 | 13 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr5e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 5e-8 --group-size 32 --bits 5 --data-path voxmenthe/merged-sft-coding-mix2 14 | touch Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr5e-8/README.md 15 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr5e-8 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr5e-8 16 | 17 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr1e-6 --max-seq-length 2048 --batch-size 4 --learning-rate 1e-6 --group-size 32 --bits 5 --data-path voxmenthe/merged-sft-coding-mix2 18 | touch Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr1e-6/README.md 19 | mlx_lm.upload --path ./Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr1e-6 --upload-repo mlx-community/Qwen3-Coder-30B-A3B-Instruct-5bit-DWQ-lr1e-6 -------------------------------------------------------------------------------- /mlx-quantization/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to the MLX Quantization Toolkit will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [1.0.0] - 2025-01-30 9 | 10 | ### Added 11 | - Initial release of MLX Quantization Toolkit 12 | - Universal MLX converter for any Hugging Face model 13 | - DeepSeek-R1 AWQ to MLX conversion support 14 | - AWQ (Activation-aware Weight Quantization) implementation 15 | - DWQ (Distilled Weight Quantization) implementation 16 | - Dynamic mixed-precision quantization support 17 | - Comprehensive Jupyter notebook workflow 18 | - Automated model download and upload to Hugging Face 19 | - Performance benchmarking and validation tools 20 | - Apple Silicon optimization for M1/M2/M3/M4 devices 21 | - Robust error handling and fallback mechanisms 22 | - Complete documentation and usage examples 23 | 24 | ### Features 25 | - **5 Quantization Methods**: Universal, DeepSeek-R1, AWQ, DWQ, Dynamic 26 | - **Apple Silicon Optimized**: Native MLX framework integration 27 | - **Automated Workflows**: End-to-end conversion pipelines 28 | - **Performance Testing**: Built-in benchmarking tools 29 | - **Hugging Face Integration**: Seamless model upload/download 30 | - **Error Recovery**: Multiple fallback conversion methods 31 | 32 | ### Technical Specifications 33 | - **Python**: 3.8+ required 34 | - **Hardware**: Apple Silicon (M1/M2/M3/M4) required 35 | - **Storage**: 50GB+ recommended for large models 36 | - **Memory**: 16GB+ RAM recommended 37 | - **MLX Version**: 0.12.0+ supported 38 | 39 | ### Supported Models 40 | - Any Hugging Face transformer model 41 | - DeepSeek-R1 AWQ models (specialized support) 42 | - Large language models up to 70B+ parameters 43 | - Various architectures: Llama, Mistral, Qwen, etc. 
44 | 45 | ### Performance Metrics 46 | - Model size reduction: 60-80% 47 | - Inference speed improvement: 2-4x on Apple Silicon 48 | - Quality retention: 95-99% of original performance 49 | - Memory usage reduction: 50-75% 50 | 51 | ### Documentation 52 | - Comprehensive README with setup instructions 53 | - Individual notebook documentation 54 | - Usage examples and best practices 55 | - Troubleshooting guide 56 | - Performance benchmarking results -------------------------------------------------------------------------------- /src/data_processing/convert_to_markdown.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | convert_to_markdown.py 4 | 5 | Usage 6 | ----- 7 | python convert_to_markdown_jsonl.py \ 8 | --input texts.txt # one snippet per line 9 | --output converted.jsonl 10 | """ 11 | 12 | import argparse 13 | import json 14 | import re 15 | import tempfile 16 | from pathlib import Path 17 | from typing import List 18 | 19 | from markitdown import MarkItDown # official API :contentReference[oaicite:1]{index=1} 20 | 21 | 22 | def guess_suffix(text: str) -> str: 23 | """Very small heuristic to give MarkItDown the right file extension.""" 24 | if re.search(r"<\s*html[^>]*>", text, re.I): 25 | return ".html" 26 | if text.lstrip().startswith("#"): 27 | return ".md" 28 | return ".txt" 29 | 30 | 31 | def convert_snippets(snippets: List[str]) -> List[dict]: 32 | """Return a list of {'raw', 'markdown'} dictionaries.""" 33 | md = MarkItDown(enable_plugins=False) # one converter for all calls 34 | records = [] 35 | 36 | for snippet in snippets: 37 | # Write the snippet to a NamedTemporaryFile so MarkItDown 38 | # can treat it like a real file 39 | suffix = guess_suffix(snippet) 40 | with tempfile.NamedTemporaryFile("w+b", suffix=suffix, delete=True) as tf: 41 | tf.write(snippet.encode("utf-8")) 42 | tf.flush() # ensure bytes are written 43 | result = md.convert(tf.name) # convert returns a DocumentResult 44 | records.append({"raw": snippet, "markdown": result.text_content}) 45 | return records 46 | 47 | 48 | def main() -> None: 49 | parser = argparse.ArgumentParser(description="Bulk convert text → Markdown (JSONL)") 50 | parser.add_argument("--input", required=True, 51 | help="Text file with one snippet per line") 52 | parser.add_argument("--output", required=True, 53 | help="Destination .jsonl file") 54 | args = parser.parse_args() 55 | 56 | # Load snippets (blank lines are ignored) 57 | with Path(args.input).expanduser().open(encoding="utf-8") as f: 58 | snippets = [line.rstrip("\n") for line in f if line.strip()] 59 | 60 | conversions = convert_snippets(snippets) 61 | 62 | # Write JSONL 63 | with Path(args.output).expanduser().open("w", encoding="utf-8") as out: 64 | for rec in conversions: 65 | out.write(json.dumps(rec, ensure_ascii=False) + "\n") 66 | 67 | print(f"✅ Wrote {len(conversions):,} records to {args.output}") 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /CONVERT/conversion_recipies.md: -------------------------------------------------------------------------------- 1 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-4bit-DWQ --max-seq-length 2048 --batch-size 4 --learning-rate 1e-7 --group-size 32 --bits 4 2 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-4bit-DWQ --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-4bit-DWQ 3 | mlx_lm.generate --model 
mlx-community/Qwen3-30B-A3B-Instruct-2507-4bit-DWQ --max-tokens 4096 --temp 0.7 -p "Explain why the Soviet Union didn't collapse earlier than it did" 4 | 5 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-8bit-DWQ --max-seq-length 2048 --batch-size 4 --learning-rate 8e-8 --group-size 32 --bits 8 6 | Qwen3-30B-A3B-Instruct-2507-8bit-DWQ/README.md 7 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-8bit-DWQ --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-8bit-DWQ 8 | 9 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-6bit-DWQ --max-seq-length 2048 --batch-size 4 --learning-rate 1e-7 --group-size 32 --bits 6 10 | Qwen3-30B-A3B-Instruct-2507-6bit-DWQ/README.md 11 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-6bit-DWQ --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-6bit-DWQ 12 | TODO: 13 | mlx_lm.generate --model mlx-community/Qwen3-30B-A3B-Instruct-2507-6bit-DWQ --max-tokens 4096 --temp 0.7 -p "Explain why the Soviet Union didn't collapse earlier than it did" 14 | 15 | 16 | "mlabonne/open-perfectblend" 17 | 18 | ValueError: Unsupported data format, check the supported formats here: 19 | https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/LORA.md#Data. 20 | --data-path 21 | "--data-path", 22 | type=str, 23 | default="allenai/tulu-3-sft-mixture", 24 | 'voxmenthe/merged-sft-coding-mix2' 25 | models: zai-org/GLM-4.5-Air 26 | 27 | mlx_lm.dwq --model zai-org/GLM-4.5-Air --mlx-path mlx-community/GLM-4.5-Air-8bit-DWQ --max-seq-length 2048 --batch-size 4 --learning-rate 8e-8 --group-size 32 --bits 8 28 | 29 | 30 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-6bit-DWQ --max-seq-length 2048 --batch-size 4 --learning-rate 1e-7 --group-size 32 --bits 6 31 | 32 | <<<<<<< HEAD 33 | ======== evals 34 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-8 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 35 | 36 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr3e-7 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 37 | 38 | mlx_lm.evaluate --model Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-8 --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 39 | ======= 40 | mlx_lm.dwq --model Qwen/Qwen3-Coder-30B-A3B-Instruct --mlx-path Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr2e7 --max-seq-length 2048 --batch-size 4 --learning-rate 2e-7 --group-size 32 --bits 8 41 | >>>>>>> df719f3722138db30489c2f1a0920851392e99e9 42 | -------------------------------------------------------------------------------- /examples_scratchpad.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | mlx_lm.lora \ 4 | --model mlx_models/Qwen3-4B-mlx \ 5 | --adapter-path ADAPTERS/qwen3_4b_lora_sacredhunger \ 6 | --data DATA/SACREDHUNGER \ 7 | --test 8 | 9 | 10 | cat temp_prompt.txt | python src/inference/generate_qwen3.py \ 11 | --model-path mlx_models/Qwen3-4B-mlx \ 12 | --adapter-path ADAPTERS/qwen3_4b_lora_sacredhunger \ 13 | --prompt "-" 14 | 15 | cat temp_prompt.txt | python src/inference/generate_qwen3.py \ 16 | --model-path mlx_models/Qwen3-4B-mlx \ 17 | --adapter-path ADAPTERS/qwen3_4b_lora_sacredhunger \ 18 | --prompt "-" \ 19 | --repetition-penalty 1.1 \ 20 | --temp 0.75 \ 21 | --top-p 0.95 22 | 23 | # WITHOUT ADAPTER 24 | cat temp_prompt.txt | python src/inference/generate_qwen3.py \ 25 | --model-path mlx_models/Qwen3-4B-mlx \ 26 | 
--prompt "-" \ 27 | --repetition-penalty 1.1 \ 28 | --temp 0.75 \ 29 | --top-p 0.95 30 | 31 | python prepare_training_data.py \ 32 | --input_files semantic_chunks_480.json semantic_chunks_520.json semantic_chunks_680.json semantic_chunks_790.json \ 33 | --output_dir ../../DATA/SACREDHUNGER/ \ 34 | --train_ratio 0.9 \ 35 | --seed 123 36 | 37 | python prepare_training_data.py \ 38 | --input_files sacredhunger_350.json sacredhunger_480.json sacredhunger_520.json sacredhunger_570.json sacredhunger_680.json sacredhunger_730.json sacredhunger_790.json \ 39 | --output_dir ../../DATA/SACREDHUNGER/ \ 40 | --train_ratio 0.93 \ 41 | --seed 211 42 | 43 | python prepare_training_data.py \ 44 | --input_files allthekingsmen_480.json allthekingsmen_520.json allthekingsmen_680.json allthekingsmen_790.json \ 45 | --output_dir ../../DATA/ALLTHEKINGSMEN/ \ 46 | --train_ratio 0.9 \ 47 | --seed 123 48 | 49 | cat temp_prompt.txt | python src/inference/generate_qwen3.py \ 50 | --model-path mlx_models/Qwen3-14B-mlx \ 51 | --adapter-path ADAPTERS/qwen3_14b_lora_sacredhunger_multi \ 52 | --prompt "-" \ 53 | --repetition-penalty 1.1 \ 54 | --temp 0.75 \ 55 | --top-p 0.95 56 | 57 | cat temp_prompt.txt | python src/inference/generate_qwen3.py \ 58 | --model-path mlx_models/Qwen3-14B-mlx \ 59 | --adapter-path ADAPTERS/qwen3_14b_dora_novels_sh_atkm \ 60 | --prompt "-" \ 61 | --repetition-penalty 1.1 \ 62 | --temp 0.75 \ 63 | --top-p 0.95 64 | 65 | books: 66 | Fatherland_HarrisRobert.txt 67 | Great Gatsby, The - Francis Scott Fitzgerald.txt 68 | Imperium_ANovelofAncientRo_HarrisRobert.txt 69 | LOTR.txt 70 | One Hundred Years of Solitude - Gabriel Garcia Marquez.txt 71 | OldManandtheSeaThe_ErnestHemingway.txt 72 | Pride_and_Prejudice.txt 73 | PaperTowns_JohnGreen.txt 74 | Pachinko_MinJinLee.txt 75 | RedSister-MarkLawrence.txt 76 | ToKillAMockingbird_HarperLee.txt 77 | TheMartian.txt 78 | TheMagicians1.txt 79 | TheMagicians2.txt 80 | TheMagicians3.txt 81 | TheLiontheWitchandtheWar_LewisCS_.txt 82 | TheGodfather.txt 83 | TheGraveyardBook.txt 84 | TheDaVinciCode_BrownDan.txt 85 | AWrinkleinTime(PuffinModer_LengleMadeleine.txt 86 | AdventuresofTomSawyerThe_MarkTwain.txt 87 | Bartimaeus1.txt 88 | Bartimaeus2.txt 89 | Bartimaeus3.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🚀 MLX Finetuning Demo Project 2 | 3 | ✨ A complete guide to setting up and running the MLX finetuning pipeline for your custom datasets 4 | 5 | ## 🛠️ Setup Instructions 6 | 7 | 1. **Create and activate virtual environment** 8 | ```bash 9 | python -m venv mlx_venv 10 | source mlx_venv/bin/activate # Linux/Mac 11 | # OR 12 | mlx_venv\Scripts\activate # Windows 13 | ``` 14 | 15 | 2. **Install dependencies** 16 | ```bash 17 | sh project_setup.sh 18 | ``` 19 | 20 | 3. **Prepare your dataset** 21 | - Place your source data file (long text format) in `/data/raw/` 22 | - Example dataset structure: 23 | ``` 24 | data/ 25 | raw/ 26 | my_dataset.txt 27 | ``` 28 | 29 | ## 🔄 Data Processing Pipeline 30 | 31 | 1. **Run semantic chunker** 32 | ```bash 33 | python src/data_processing/semantic_chunker.py \ 34 | --book_path data/raw/your_book.txt \ 35 | --target 480 \ 36 | --output_path data/processed/chunks.json 37 | ``` 38 | - `--target`: Target word count per chunk (default: 480) 39 | - Uses `lightonai/modernbert-embed-large` model by default 40 | 41 | 2. 
**Prepare training data** 42 | ```bash 43 | python src/data_processing/prepare_training_data.py \ 44 | --input_files data/processed/chunks.json \ 45 | --output_dir data/final \ 46 | --train_ratio 0.85 47 | ``` 48 | - Creates `train.jsonl` and `valid.jsonl` files 49 | - Each sample contains a prompt/continuation pair 50 | - Default 85/15 train/validation split 51 | 52 | ## ⚙️ Configuration 53 | 54 | Edit `src/finetuning/lora_config.yaml` with your settings (the keys below mirror the actual config file): 55 | ```yaml 56 | lora_parameters: 57 | rank: 256 # LoRA rank (dimension of the adapter matrices) 58 | dropout: 0.05 # Dropout applied to the LoRA matrices 59 | scale: 12.0 # Scaling factor for the LoRA update 60 | learning_rate: 8e-6 # Overrides LEARNING_RATE in the finetuning script 61 | # Batch size and iteration count are set in src/finetuning/finetune_qwen3.sh 62 | ``` 63 | 64 | ## 🏋️ Training 65 | 66 | Start finetuning: 67 | ```bash 68 | sh src/finetuning/finetune_qwen3.sh \ 69 | --tune-type dora \ 70 | --config src/finetuning/lora_config.yaml 71 | ``` 72 | 73 | Key parameters (edit in script): 74 | - `MODEL_PATH`: Path to MLX model directory 75 | - `DATA_PATH`: Directory containing `train.jsonl` and `valid.jsonl` 76 | - `ADAPTER_PATH`: Where to save adapters 77 | - `ITERS`: Number of training iterations (default: 5600) 78 | - `BATCH_SIZE`: Batch size (default: 1) 79 | 80 | ## 📊 Evaluation 81 | 82 | Run evaluations: 83 | ```bash 84 | python src/evaluations/run_evaluations.py \ 85 | --model-path mlx_models/Qwen3-14B-mlx \ 86 | --adapter-path ADAPTERS/qwen3_14b_dora_sacredhunger_multi \ 87 | --valid-jsonl-path data/final/valid.jsonl \ 88 | --output-dir eval_outputs \ 89 | --num-examples 50 90 | ``` 91 | 92 | Evaluation parameters: 93 | - `--temp`: Sampling temperature (default: 0.75) 94 | - `--top-p`: Top-p sampling (default: 0.95) 95 | - `--repetition-penalty`: Penalty for repeated tokens (default: 1.1) 96 | 97 | ## 📌 Tips 98 | 99 | - Monitor the training and validation loss printed to the console by the finetuning script 100 | - For large datasets, consider using `--num_workers` in data preparation 101 | - Adjust batch size based on your GPU memory 102 | 103 | 💡 For questions or issues, please open an issue in this repository! -------------------------------------------------------------------------------- /src/finetuning/download_qwen3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import sys 4 | from pathlib import Path 5 | 6 | # Add mlx-lm to the Python path 7 | # Assumes the script is run from the project root or src/finetuning 8 | project_root = Path(__file__).resolve().parents[2] 9 | mlx_lm_path = project_root / "mlx-lm" 10 | sys.path.insert(0, str(mlx_lm_path)) 11 | 12 | def download_and_convert(hf_repo_id: str, output_dir_base: Path): 13 | """ 14 | Downloads a model from Hugging Face and converts it to MLX format. 15 | 16 | Args: 17 | hf_repo_id: The Hugging Face repository ID (e.g., 'Qwen/Qwen3-14B'). 18 | output_dir_base: The base directory to save the converted MLX model within. 19 | A model-specific subdirectory will be created here.
20 | """ 21 | # Ensure the base directory exists 22 | output_dir_base.mkdir(parents=True, exist_ok=True) 23 | 24 | # Construct the final model-specific path 25 | model_name = hf_repo_id.split("/")[-1] + "-mlx" 26 | final_model_path = output_dir_base / model_name 27 | 28 | print(f"Converting {hf_repo_id} to MLX format...") 29 | print(f"Final output directory: {final_model_path}") 30 | 31 | # Check if the final path *already* exists before attempting conversion 32 | if final_model_path.exists(): 33 | print(f"Error: Target directory {final_model_path} already exists.") 34 | print("Please remove it or specify a different base directory if conversion is needed again.") 35 | sys.exit(1) 36 | 37 | try: 38 | # Use our custom conversion script 39 | command = [ 40 | sys.executable, # Use the current Python interpreter 41 | str(project_root / "src" / "finetuning" / "convert_qwen3_custom.py"), # Path to custom script 42 | "--hf-path", 43 | hf_repo_id, 44 | "--mlx-path", 45 | str(final_model_path), 46 | # Potentially add "--dtype" if needed, default is float16 47 | # "--dtype", "bfloat16" 48 | ] 49 | 50 | # Note: Ensure transformers and huggingface_hub are installed: 51 | # pip install transformers huggingface_hub sentencepiece tiktoken 52 | 53 | print(f"Running command: {' '.join(command)}") 54 | subprocess.run(command, check=True, capture_output=True, text=True) 55 | print(f"Successfully converted {hf_repo_id} and saved to {final_model_path}") 56 | 57 | except subprocess.CalledProcessError as e: 58 | print(f"Error during conversion:") 59 | print(f"Command: {' '.join(e.cmd)}") 60 | print(f"Return code: {e.returncode}") 61 | print(f"Stdout: {e.stdout}") 62 | print(f"Stderr: {e.stderr}") 63 | sys.exit(1) 64 | except Exception as e: 65 | print(f"An unexpected error occurred: {e}") 66 | sys.exit(1) 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser( 71 | description="Download and convert a Hugging Face model to MLX format." 72 | ) 73 | parser.add_argument( 74 | "--hf-repo-id", 75 | type=str, 76 | default="Qwen/Qwen3-14B", 77 | help="Hugging Face repository ID of the model to download and convert.", 78 | ) 79 | parser.add_argument( 80 | "--output-dir", 81 | type=Path, 82 | default=Path("mlx_models"), # Base directory 83 | help="Base directory to save the converted MLX model (model-specific subfolder will be created).", 84 | ) 85 | args = parser.parse_args() 86 | 87 | # Ensure output base dir is relative to project root if not absolute 88 | if not args.output_dir.is_absolute(): 89 | args.output_dir = project_root / args.output_dir 90 | 91 | download_and_convert(args.hf_repo_id, args.output_dir) -------------------------------------------------------------------------------- /src/data_processing/bm25_func.py: -------------------------------------------------------------------------------- 1 | import re 2 | try: 3 | from rank_bm25 import BM25Okapi 4 | _HAS_BM25 = True 5 | except ImportError: 6 | _HAS_BM25 = False 7 | 8 | 9 | # --- utilities -------------------------------------------------------------- 10 | def _tokenise(text: str) -> list[str]: 11 | """Very light tokeniser → lowercase words A‑Z.""" 12 | return re.findall(r"[A-Za-z']+", text.lower()) 13 | 14 | # --------------------------------------------------------------------------- 15 | 16 | 17 | def bm25_gap_violation(boundary_chunks: tuple[str, str], 18 | entity: str, 19 | max_gap: int = 4, 20 | bm25_thresh: float = 0.0) -> bool: 21 | """ 22 | Return ``True`` when *entity* (e.g. 
“Erasmus”) vanishes for more than 23 | ``max_gap`` paragraph units **across** the boundary formed by the two 24 | chunks supplied. 25 | 26 | Parameters 27 | ---------- 28 | boundary_chunks : (prev_chunk, next_chunk) 29 | Tuple of the text immediately before and after the boundary. 30 | entity : str 31 | Name / term whose continuity we want to keep. 32 | max_gap : int, default 4 33 | Maximum allowed paragraph distance without seeing the entity. 34 | bm25_thresh : float, default 0.0 35 | Minimum BM25 score regarded as a “hit”. Leave at 0 to treat mere 36 | lexical presence as sufficient. 37 | 38 | Notes 39 | ----- 40 | * If the third‑party package ``rank_bm25`` is present we build a very 41 | small per‑boundary BM25 index so that inflected or approximate 42 | mentions (“Mr Kemp”, “Kemp’s”) still register continuity. 43 | * If the package is missing we fall back to a fast 44 | case‑insensitive regex exact match. 45 | """ 46 | 47 | prev_chunk, next_chunk = boundary_chunks 48 | prev_paras = re.split(r"\n{2,}", prev_chunk) 49 | next_paras = re.split(r"\n{2,}", next_chunk) 50 | 51 | # --------------------- helper to detect "entity present" --------------- 52 | entity_tokens = _tokenise(entity) 53 | if not entity_tokens: 54 | return False # nothing to look for 55 | 56 | if _HAS_BM25: 57 | # Build tiny BM25 index over paragraphs that straddle the boundary 58 | corpus_paras = prev_paras + next_paras 59 | corpus_tok = [_tokenise(p) for p in corpus_paras] 60 | bm25 = BM25Okapi(corpus_tok) 61 | scores = bm25.get_scores(entity_tokens) 62 | # Treat paragraph as containing the entity if BM25 > threshold 63 | contains = [s > bm25_thresh for s in scores] 64 | else: 65 | # Cheap lexical fallback 66 | pat = re.compile(rf"\b{re.escape(entity)}\b", flags=re.I) 67 | contains = [bool(pat.search(p)) for p in prev_paras + next_paras] 68 | 69 | # --------------------- measure the paragraph gap ----------------------- 70 | # Index of *last* hit in the previous chunk 71 | last_prev_idx = None 72 | for i in reversed(range(len(prev_paras))): 73 | if contains[i]: 74 | last_prev_idx = len(prev_paras) - 1 - i # distance back from end 75 | break 76 | 77 | # Index of *first* hit in the next chunk 78 | offset = len(prev_paras) # shift into global index 79 | first_next_idx = None 80 | for j in range(len(next_paras)): 81 | if contains[offset + j]: 82 | first_next_idx = j 83 | break 84 | 85 | # Compute paragraphs without the entity spanning the join 86 | if last_prev_idx is None: 87 | gap_left = len(prev_paras) # no mention in prev ⇒ full length 88 | else: 89 | gap_left = last_prev_idx 90 | 91 | if first_next_idx is None: 92 | gap_right = len(next_paras) # no mention in next ⇒ full length 93 | else: 94 | gap_right = first_next_idx 95 | 96 | total_gap = gap_left + gap_right + 1 # +1 for the boundary itself 97 | 98 | return total_gap > max_gap 99 | -------------------------------------------------------------------------------- /CONVERT/results.md: -------------------------------------------------------------------------------- 1 | All results are from the mmlu_pro_computer_science task. 
2 | mlx_lm.evaluate --model <model> --tasks mmlu_pro_computer_science --max-tokens 5000 --no-apply-chat-template 3 | 4 | ============================================================ 5 | 6 | Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-8 7 | { 8 | "alias": "computer_science", 9 | "exact_match,custom-extract": 0.7926829268292683, 10 | "exact_match_stderr,custom-extract": 0.020044980247224453 11 | } 12 | 13 | 14 | Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr3e-7Results: 15 | { 16 | "alias": "computer_science", 17 | "exact_match,custom-extract": 0.7926829268292683, 18 | "exact_match_stderr,custom-extract": 0.020044980247224457 19 | } 20 | 21 | Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-8 22 | { 23 | "alias": "computer_science", 24 | "exact_match,custom-extract": 0.7878048780487805, 25 | "exact_match_stderr,custom-extract": 0.02021693788475414 26 | } 27 | 28 | Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr1e-6 29 | --data-path voxmenthe/merged-sft-coding-mix2 30 | { 31 | "alias": "computer_science", 32 | "exact_match,custom-extract": 0.6219512195121951, 33 | "exact_match_stderr,custom-extract": 0.023976756269796867 34 | } 35 | 36 | Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr2e7 37 | --data-path voxmenthe/merged-sft-coding-mix2 38 | Results: 39 | { 40 | "alias": "computer_science", 41 | "exact_match,custom-extract": 0.7292682926829268, 42 | "exact_match_stderr,custom-extract": 0.02197108846947813 43 | } 44 | 45 | Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr4e7 46 | --data-path voxmenthe/merged-sft-coding-mix2 47 | Results: 48 | { 49 | "alias": "computer_science", 50 | "exact_match,custom-extract": 0.697560975609756, 51 | "exact_match_stderr,custom-extract": 0.022711632302604486 52 | } 53 | Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr5e-8 54 | Results: 55 | { 56 | "alias": "computer_science", 57 | "exact_match,custom-extract": 0.7048780487804878, 58 | "exact_match_stderr,custom-extract": 0.022552572925167262 59 | } 60 | Qwen3-Coder-30B-A3B-Instruct-8bit-DWQ-lr9e8 61 | { 62 | "alias": "computer_science", 63 | "exact_match,custom-extract": 0.7292682926829268, 64 | "exact_match_stderr,custom-extract": 0.02197108846947813 65 | } 66 | mlx-community/GLM-4.5-Air-5bit 67 | { 68 | "alias": "computer_science",: 100%| 69 | "exact_match,custom-extract": 0.7634146341463415, 70 | "exact_match_stderr,custom-extract": 0.021014183737081388 71 | } 72 | mlx-community/Qwen3-Coder-30B-A3B-Instruct-bf16 73 | { 74 | "alias": "computer_science", 75 | "exact_match,custom-extract": 0.7268292682926829, 76 | "exact_match_stderr,custom-extract": 0.02203289844309934 77 | } 78 | Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr9e-8 79 | { 80 | "alias": "computer_science", 81 | "exact_match,custom-extract": 0.7170731707317073, 82 | "exact_match_stderr,custom-extract": 0.022271893903859002 83 | } 84 | Qwen3-Coder-30B-A3B-Instruct-6bit-DWQ-lr3e-7 85 | { 86 | "alias": "computer_science", 87 | "exact_match,custom-extract": 0.7414634146341463, 88 | "exact_match_stderr,custom-extract": 0.02164931770175753 89 | } 90 | mlx-community/Qwen3-30B-A3B-Thinking-2507-bf16 91 | { 92 | "alias": "computer_science", 93 | "exact_match,custom-extract": 0.7829268292682927, 94 | "exact_match_stderr,custom-extract": 0.020384591313839226 95 | } 96 | Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-8 97 | { 98 | "alias": "computer_science", 99 | "exact_match,custom-extract": 0.7926829268292683, 100 | "exact_match_stderr,custom-extract": 0.020044980247224453 101 | } 102 | Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr3e-7Results: 103 | { 104 | "alias": "computer_science", 105 | 
"exact_match,custom-extract": 0.7926829268292683, 106 | "exact_match_stderr,custom-extract": 0.020044980247224457 107 | } 108 | 109 | Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-8 110 | { 111 | "alias": "computer_science", 112 | "exact_match,custom-extract": 0.7878048780487805, 113 | "exact_match_stderr,custom-extract": 0.02021693788475414 114 | } -------------------------------------------------------------------------------- /src/data_processing/README.md: -------------------------------------------------------------------------------- 1 | ## 1 · Baseline, purely programmatic splitter 2 | 3 | **Core idea** 4 | 5 | 1. Read the text. 6 | 2. Isolate **paragraph units** (book already uses double line‑breaks as separators). 7 | 3. Greedily concatenate whole paragraphs until the running word‑count would exceed `target_size`; then start a new chunk. 8 | 4. Optionally create a *fixed paragraph overlap* for a little context bleed. 9 | 10 | *Configurable knobs* 11 | 12 | | parameter | purpose | sensible range | 13 | | ----------- | ------------------------------------ | -------------- | 14 | | `--size` | target words per chunk | 300‑2 000 | 15 | | `--overlap` | paragraphs repeated at each boundary | 0‑2 | 16 | 17 | --- 18 | 19 | ## 2 · Semantic‑aware splitter (embedding + BM25 refinement) 20 | 21 | ### Rationale 22 | 23 | Even with paragraph‑respecting boundaries you can accidentally: 24 | 25 | * cut **mid‑scene** (a named character vanishes between chunks), 26 | * split **dialogue exchanges**, harming retrieval‑QA accuracy. 27 | 28 | We therefore **post‑process** the baseline boundaries: 29 | 30 | 1. **Initial pass** – call the greedy algorithm above with `size ≈ target × 0.9` (gives headroom for later shuffling). 31 | 32 | 2. **Compute embeddings** for each paragraph (or sentence) using a suitable model (e.g., `lightonai/modernbert-embed-large`). 33 | *Note: This model requires specific prefixes. The script currently uses `search_document:` for the text segments being compared.* 34 | 35 | 3. For every tentative boundary `B` between chunk *i* and *i+1*: 36 | 37 | * Take the last `tail_len` sentences of chunk *i* (`S_tail`) and the first `head_len` sentences of chunk *i+1* (`S_head`). 38 | * `sim = cosine(get_emb(S_tail), get_emb(S_head))`. 39 | * If `sim < thresh_low`, **shift** `B` *forward* until similarity rises or the size budget is hit. 40 | * If `sim > thresh_high`, optionally create a *sentence‑level overlap* so the teaser sentence appears in both chunks. 41 | 42 | 4. **BM25 character check** – build a BM25 index over paragraphs. For each main character name (Erasmus, Paris, Thurso, etc.) ensure that it doesn't disappear for > `gap` paragraphs. If a gap occurs across a boundary, shift the boundary backward by one paragraph. 43 | 44 | 45 | ### Configurable aspects 46 | 47 | * `--target` – desired words / chunk 48 | * `tail_len`, `head_len` – size of *join windows* 49 | * `thresh_low`, `thresh_high` – similarity action thresholds 50 | * `char_names` + `max_bm25_gap` – continuity heuristics 51 | * `max_size` – hard cap after refinement 52 | 53 | --- 54 | 55 | **Measuring chunk boundaries** 56 | 57 | Measure the distance between chunks by counting paragraphs from the last occurrence of an entity in one chunk to the first in the next. I'll look for the entity in both chunks, calculate the gap, and check if it exceeds a defined max number of paragraphs. If it does, I'll return "True." This will be accomplished using a case-insensitive regex for the entity's location. 
Careful indexing is needed to handle both chunks on either side of the boundary. 58 | 59 | ```python 60 | # --- utilities -------------------------------------------------------------- 61 | def _tokenise(text: str) -> list[str]: 62 | """Very light tokeniser → lowercase words A‑Z.""" 63 | return re.findall(r"[A-Za-z']+", text.lower()) 64 | 65 | try: 66 | from rank_bm25 import BM25Okapi 67 | _HAS_BM25 = True 68 | except ImportError: 69 | _HAS_BM25 = False 70 | # --------------------------------------------------------------------------- 71 | ``` 72 | ## `bm25_gap_violation` 73 | 74 | ### How it works 75 | 76 | 1. **Tokenisation** – a tiny regex picks out alphabetic tokens and lower‑cases them, enough for BM25. 77 | 2. **Dual mode** 78 | 79 | * If `rank_bm25` is available, we calculate paragraph‑level BM25 scores for the entity; any paragraph scoring above `bm25_thresh` (0 → "contains at least one query term") counts as a hit. 80 | * Without the library, we revert to a fast exact word‑boundary regex. 81 | 3. **Gap detection** – walk backward from the end of the *previous* chunk and forward from the start of the *next* chunk to locate the two nearest mentions. The sum of paragraphs between those two mentions (inclusive of the join) is the **gap**. If it exceeds `max_gap`, the function flags a violation so your boundary‑refinement logic can shift or duplicate paragraphs. 82 | 83 | You can now import the function directly in `semantic_chunker.py`, run the script, and the continuity‑checking step will operate deterministically—optionally strengthened by BM25 when the library is installed. 84 | 85 | -------------------------------------------------------------------------------- /mlx-quantization/CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 | 5 | ## Project Overview 6 | 7 | This repository contains Jupyter notebooks for converting and quantizing large language models using Apple's MLX framework. The project focuses on optimizing models for Apple Silicon devices through various quantization techniques. 8 | 9 | ## Repository Structure 10 | 11 | The repository is organized around Jupyter notebooks that handle different aspects of MLX model conversion and quantization: 12 | 13 | - **deepseek_r1_mlx_conversion.ipynb**: Converts DeepSeek-R1 AWQ models to MLX format 14 | - **universal_mlx_converter.ipynb**: Universal converter for any Hugging Face model to MLX format 15 | - **dwq_quantization.ipynb**: Distilled Weight Quantization implementation 16 | - **awq_quantization.ipynb**: Activation-aware Weight Quantization implementation 17 | - **dynamic_quantization.ipynb**: Dynamic quantization with mixed precision 18 | - **models/**: Directory for storing downloaded and converted models 19 | 20 | ## Core Dependencies 21 | 22 | All notebooks require these essential packages: 23 | - `mlx-lm`: Apple's MLX language model framework 24 | - `transformers`: Hugging Face transformers library 25 | - `torch`: PyTorch (for compatibility) 26 | - `huggingface_hub`: For model download/upload 27 | - `datasets`, `accelerate`, `sentencepiece`, `protobuf`: Supporting libraries 28 | 29 | ## Common Workflow Pattern 30 | 31 | Each notebook follows a consistent structure: 32 | 1. Environment setup and dependency installation 33 | 2. MLX import testing and verification 34 | 3. Model configuration and directory setup 35 | 4. 
Original model download from Hugging Face 36 | 5. Model conversion/quantization using MLX tools 37 | 6. Converted model testing and validation 38 | 7. Optional performance comparison 39 | 8. Optional Hugging Face upload 40 | 9. Cleanup and summary 41 | 42 | ## MLX Conversion Commands 43 | 44 | The project uses MLX command-line tools for conversions: 45 | 46 | ### Basic Conversion 47 | ```bash 48 | python -m mlx_lm.convert --hf-path <source_dir> --mlx-path <target_dir> 49 | ``` 50 | 51 | ### DWQ Quantization 52 | ```bash 53 | python -m mlx_lm.dwq --model <model_path> --mlx-path <output_path> --bits 4 --num-samples 1024 54 | ``` 55 | 56 | ### AWQ Quantization 57 | ```bash 58 | python -m mlx_lm.awq --model <model_path> --mlx-path <output_path> --bits 4 --num-samples 32 59 | ``` 60 | 61 | ### Dynamic Quantization 62 | ```bash 63 | python -m mlx_lm.dynamic_quant --model <model_path> --mlx-path <output_path> --target-bpw 4.0 64 | ``` 65 | 66 | ## Environment Setup Requirements 67 | 68 | **Critical**: This project requires macOS with Apple Silicon (M1/M2/M3/M4). The notebooks include specific handling for: 69 | - numpy/gfortran library conflicts in JupyterLab Desktop 70 | - MLX framework import verification 71 | - Automatic package installation with error handling 72 | - Kernel restart recommendations for import issues 73 | 74 | ## Model Storage Architecture 75 | 76 | The project uses a standardized directory structure: 77 | - `models/`: Root directory for all model storage 78 | - `models/<model_name_sanitized>/`: Original downloaded models 79 | - `models/<model_name>_<method>_<precision>/`: Quantized outputs 80 | - `sensitivities/`: Layer sensitivity analysis files (for dynamic quantization) 81 | 82 | ## Error Handling Patterns 83 | 84 | All notebooks implement robust error handling: 85 | - Multiple conversion method attempts with fallbacks 86 | - Comprehensive import testing before execution 87 | - File existence checks and cleanup procedures 88 | - Detailed error reporting with troubleshooting guidance 89 | 90 | ## Hugging Face Integration 91 | 92 | The notebooks include full Hugging Face workflow: 93 | - Secure token-based authentication 94 | - Model download with resume capability 95 | - Automatic model card generation 96 | - Repository creation and file upload 97 | - Upload verification and file listing 98 | 99 | ## Performance Testing 100 | 101 | Each quantization method includes: 102 | - Model loading and inference testing 103 | - Multi-prompt validation 104 | - Performance timing comparisons 105 | - Size reduction calculations 106 | - Quality evaluation options using standard datasets 107 | 108 | ## Important Notes 109 | 110 | - AWQ models require dequantization before MLX conversion (`--dequantize` flag) 111 | - Directory paths must be absolute, not relative 112 | - Large models require significant disk space (50GB+ for full-size models) 113 | - Model conversion can be time-intensive depending on model size 114 | - Always test converted models before deployment or upload -------------------------------------------------------------------------------- /CONVERT/convert_many.sh: -------------------------------------------------------------------------------- 1 | # mlx_lm.dwq --model zai-org/GLM-4.5-Air --mlx-path mlx-community/GLM-4.5-Air-8bit-DWQ-lr8e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 8e-8 --group-size 32 --bits 5 2 | # touch GLM-4.5-Air-8bit-DWQ-lr8e-8/README.md 3 | # mlx_lm.upload --path ./GLM-4.5-Air-8bit-DWQ-lr8e-8 --upload-repo 
mlx-community/GLM-4.5-Air-8bit-DWQ-lr8e-8 4 | 5 | # mlx_lm.dwq --model zai-org/GLM-4.5-Air --mlx-path mlx-community/GLM-4.5-Air-8bit-DWQ-lr7e-3 --max-seq-length 2048 --batch-size 4 --learning-rate 7e-3 --group-size 32 --bits 5 6 | # touch GLM-4.5-Air-8bit-DWQ-lr7e-3/README.md 7 | # mlx_lm.upload --path ./GLM-4.5-Air-8bit-DWQ-lr7e-3 --upload-repo mlx-community/GLM-4.5-Air-8bit-DWQ-lr7e-3 8 | 9 | # mlx_lm.dwq --model zai-org/GLM-4.5-Air --mlx-path mlx-community/GLM-4.5-Air-4bit-DWQ-lr8e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 8e-8 --group-size 32 --bits 4 10 | # touch GLM-4.5-Air-4bit-DWQ-lr8e-8/README.md 11 | # mlx_lm.upload --path ./GLM-4.5-Air-4bit-DWQ-lr8e-8 --upload-repo mlx-community/GLM-4.5-Air-4bit-DWQ-lr8e-8 12 | 13 | # mlx_lm.dwq --model zai-org/GLM-4.5-Air --mlx-path mlx-community/GLM-4.5-Air-4bit-DWQ-lr7e-3 --max-seq-length 2048 --batch-size 4 --learning-rate 7e-3 --group-size 32 --bits 4 14 | # touch GLM-4.5-Air-4bit-DWQ-lr7e-3/README.md 15 | # mlx_lm.upload --path ./GLM-4.5-Air-4bit-DWQ-lr7e-3 --upload-repo mlx-community/GLM-4.5-Air-4bit-DWQ-lr7e-3 16 | 17 | #======================================== 18 | 19 | # mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr3e-7 --max-seq-length 2048 --batch-size 4 --learning-rate 3e-7 --group-size 32 --bits 6 20 | # touch Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr3e-7/README.md 21 | # mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr3e-7 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr3e-7 22 | 23 | # mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 5e-8 --group-size 32 --bits 6 24 | # touch Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-8/README.md 25 | # mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-8 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr5e-8 26 | 27 | # mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 8e-8 --group-size 32 --bits 6 28 | # touch Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-8/README.md 29 | # mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-8 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-6bit-DWQ-lr8e-8 30 | 31 | #======================================== 32 | 33 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr5e-7 --max-seq-length 2048 --batch-size 4 --learning-rate 5e-7 --group-size 32 --bits 5 34 | touch Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr5e-7/README.md 35 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr5e-7 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr5e-7 36 | 37 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr3e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 3e-8 --group-size 32 --bits 5 38 | touch Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr3e-8/README.md 39 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr3e-8 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr3e-8 40 | 41 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr2e-7 --max-seq-length 2048 --batch-size 4 --learning-rate 2e-7 --group-size 32 --bits 5 42 | touch Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr2e-7/README.md 43 | mlx_lm.upload --path 
./Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr2e-7 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr2e-7 44 | 45 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr9e-8 --max-seq-length 2048 --batch-size 4 --learning-rate 9e-8 --group-size 32 --bits 5 46 | touch Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr9e-8/README.md 47 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr9e-8 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr9e-8 48 | 49 | mlx_lm.dwq --model Qwen/Qwen3-30B-A3B-Instruct-2507 --mlx-path Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr1e-7 --max-seq-length 2048 --batch-size 4 --learning-rate 1e-7 --group-size 32 --bits 5 50 | touch Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr1e-7/README.md 51 | mlx_lm.upload --path ./Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr1e-7 --upload-repo mlx-community/Qwen3-30B-A3B-Instruct-2507-5bit-DWQ-lr1e-7 52 | 53 | 54 | mlx_lm.dwq --model zai-org/GLM-4.5-Air --mlx-path mlx-community/GLM-4.5-Air-5bit-DWQ-lr2e-7 --max-seq-length 2048 --batch-size 4 --learning-rate 2e-7 --group-size 32 --bits 5 -------------------------------------------------------------------------------- /mlx-quantization/README.md: -------------------------------------------------------------------------------- 1 | # MLX Model Quantization Toolkit 2 | 3 | A comprehensive collection of Jupyter notebooks for converting and quantizing large language models using Apple's MLX framework, optimized for Apple Silicon devices. 4 | 5 | ## 🚀 Features 6 | 7 | - **Universal Model Conversion**: Convert any Hugging Face model to MLX format 8 | - **Multiple Quantization Methods**: Support for AWQ, DWQ, and Dynamic Quantization 9 | - **Apple Silicon Optimized**: Built specifically for M1/M2/M3/M4 devices 10 | - **Automated Workflows**: Complete pipeline from download to deployment 11 | - **Performance Testing**: Built-in benchmarking and validation tools 12 | 13 | ## 📋 Requirements 14 | 15 | - **Hardware**: macOS with Apple Silicon (M1/M2/M3/M4) 16 | - **Python**: 3.8 or higher 17 | - **Storage**: 50GB+ free space for large models 18 | 19 | ## 🛠 Installation 20 | 21 | 1. Clone the repository: 22 | ```bash 23 | git clone https://github.com/cs2764/mlx-quantization.git 24 | cd mlx-quantization 25 | ``` 26 | 27 | 2. Install dependencies: 28 | ```bash 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | 3. Launch Jupyter: 33 | ```bash 34 | jupyter lab 35 | ``` 36 | 37 | ## 📚 Notebooks Overview 38 | 39 | ### Core Notebooks 40 | 41 | | Notebook | Description | Use Case | 42 | |----------|-------------|----------| 43 | | `universal_mlx_converter.ipynb` | Universal converter for any HF model | General model conversion | 44 | | `awq_quantization.ipynb` | Activation-aware Weight Quantization | High-quality 4-bit quantization | 45 | | `dwq_quantization.ipynb` | Distilled Weight Quantization | Fast quantization with good quality | 46 | | `dynamic_quantization.ipynb` | Dynamic mixed-precision quantization | Optimal size/quality balance | 47 | 48 | ### Quantization Methods Comparison 49 | 50 | | Method | Speed | Quality | Size Reduction | Best For | 51 | |--------|-------|---------|----------------|----------| 52 | | **AWQ** | Medium | High | ~75% | Production deployment | 53 | | **DWQ** | Fast | Good | ~70% | Quick prototyping | 54 | | **Dynamic** | Slow | Highest | Variable | Research/experimentation | 55 | 56 | ## 🔄 Common Workflow 57 | 58 | Each notebook follows this standardized pattern: 59 | 60 | 1. 
**Environment Setup** - Dependency installation and MLX verification 61 | 2. **Model Configuration** - Set up directories and parameters 62 | 3. **Model Download** - Fetch original model from Hugging Face 63 | 4. **Conversion/Quantization** - Apply selected quantization method 64 | 5. **Validation** - Test converted model functionality 65 | 6. **Performance Analysis** - Compare speed and quality metrics 66 | 7. **Optional Upload** - Push to Hugging Face Hub 67 | 8. **Cleanup** - Remove temporary files 68 | 69 | ## 📁 Directory Structure 70 | 71 | ``` 72 | mlx-quantization/ 73 | ├── models/ # Model storage 74 | │ ├── <model_name>/ # Original models 75 | │ └── <model_name>_<method>_<bits>/ # Quantized outputs 76 | ├── sensitivities/ # Layer analysis files 77 | ├── *.ipynb # Conversion notebooks 78 | ├── requirements.txt # Dependencies 79 | └── README.md # This file 80 | ``` 81 | 82 | ## 🚀 Quick Start 83 | 84 | 1. **Choose your quantization method** based on your requirements 85 | 2. **Open the corresponding notebook** in Jupyter Lab 86 | 3. **Follow the step-by-step instructions** in each cell 87 | 4. **Monitor the conversion process** and review results 88 | 5. **Test the quantized model** before deployment 89 | 90 | ## 📊 Performance Benchmarks 91 | 92 | Typical results on Apple M2 Pro: 93 | 94 | - **Model Size Reduction**: 60-80% smaller than original 95 | - **Inference Speed**: 2-4x faster on Apple Silicon 96 | - **Quality Retention**: 95-99% of original performance 97 | - **Memory Usage**: 50-75% reduction 98 | 99 | ## 🔧 MLX Commands Reference 100 | 101 | ### Basic Conversion 102 | ```bash 103 | python -m mlx_lm.convert --hf-path <source> --mlx-path <target> 104 | ``` 105 | 106 | ### AWQ Quantization 107 | ```bash 108 | python -m mlx_lm.awq --model <model> --mlx-path <output> --bits 4 109 | ``` 110 | 111 | ### DWQ Quantization 112 | ```bash 113 | python -m mlx_lm.dwq --model <model> --mlx-path <output> --bits 4 114 | ``` 115 | 116 | ### Dynamic Quantization 117 | ```bash 118 | python -m mlx_lm.dynamic_quant --model <model> --mlx-path <output> --target-bpw 4.0 119 | ``` 120 | 121 | ## ⚠️ Important Notes 122 | 123 | - **AWQ models require dequantization** before MLX conversion (`--dequantize` flag) 124 | - **Use absolute paths** - relative paths may cause issues 125 | - **Large models need significant storage** - ensure adequate disk space 126 | - **Test converted models** before production deployment 127 | - **Conversion time varies** based on model size and method 128 | 129 | ## 🤝 Contributing 130 | 131 | 1. Fork the repository 132 | 2. Create a feature branch 133 | 3. Make your changes 134 | 4. Test thoroughly on Apple Silicon 135 | 5. Submit a pull request 136 | 137 | ## 📄 License 138 | 139 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 
140 | 141 | ## 🙏 Acknowledgments 142 | 143 | - Apple MLX Team for the excellent framework 144 | - Hugging Face for model hosting and tools 145 | - The open-source ML community 146 | 147 | ## 📞 Support 148 | 149 | - **Issues**: Report bugs and request features via GitHub Issues 150 | - **Discussions**: Join community discussions in GitHub Discussions 151 | - **Documentation**: Refer to individual notebook markdown cells 152 | 153 | --- 154 | 155 | **Version**: 1.0.0 156 | **Last Updated**: 2025-01-30 157 | **Compatibility**: Apple Silicon (M1/M2/M3/M4) + macOS -------------------------------------------------------------------------------- /src/finetuning/finetune_qwen3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script to fine-tune Qwen3 models using mlx_lm.lora 4 | 5 | # --- Configuration --- 6 | 7 | # Default fine-tuning type 8 | FINE_TUNE_TYPE="dora" # "lora" # Default: "lora". Can be overridden with --tune-type flag ("lora", "dora", "full") 9 | CONFIG_PATH="src/finetuning/lora_config.yaml" # Optional path to a YAML config file for detailed LoRA/optimizer settings 10 | 11 | # --- Argument Parsing --- 12 | while [[ "$#" -gt 0 ]]; do 13 | case $1 in 14 | --tune-type) FINE_TUNE_TYPE="$2"; shift ;; 15 | --config) CONFIG_PATH="$2"; shift ;; 16 | *) echo "Unknown parameter passed: $1"; exit 1 ;; 17 | esac 18 | shift 19 | done 20 | 21 | # Validate FINE_TUNE_TYPE 22 | if [[ "$FINE_TUNE_TYPE" != "lora" && "$FINE_TUNE_TYPE" != "dora" && "$FINE_TUNE_TYPE" != "full" ]]; then 23 | echo "Error: Invalid --tune-type specified: '$FINE_TUNE_TYPE'. Must be 'lora', 'dora', or 'full'." 24 | exit 1 25 | fi 26 | echo "Using Fine-Tuning Type: $FINE_TUNE_TYPE" 27 | 28 | # !!! IMPORTANT: Set this path to your *local* converted MLX model directory !!! 29 | # Example: ../../mlx_models/Qwen3-4B-mlx 30 | MODEL_PATH="mlx_models/Qwen3-14B-mlx" 31 | 32 | # !!! IMPORTANT: Path to the directory containing your training data files !!! 33 | # This directory MUST contain files named exactly 'train.jsonl' and 'valid.jsonl'. 34 | # Optionally, it can contain 'test.jsonl' if RUN_TEST=true. 35 | # This path is relative to where you run the script from. 36 | # See mlx-lm/mlx_lm/LORA.md#data for format details. 37 | DATA_PATH="DATA/SACREDHUNGER" 38 | 39 | # Directory to save the LoRA adapters (relative to where you run the script) 40 | ADAPTER_PATH="ADAPTERS/qwen3_14b_${FINE_TUNE_TYPE}_sacredhunger_multi" # _atkm_multi" # Example, adjust as needed 41 | 42 | # Training parameters (adjust as needed) 43 | ITERS=5600 # Number of training iterations 44 | BATCH_SIZE=1 # Batch size (reduce if hitting memory limits) 45 | LEARNING_RATE=1e-5 # Learning rate 46 | SAVE_EVERY=100 # Save adapter weights every N iterations 47 | NUM_LAYERS=-1 # 16 # Number of layers to apply LoRA to (-1 for all) 48 | MAX_SEQ_LENGTH=3827 # Max sequence length model can handle 49 | 50 | # Evaluation parameters (optional) 51 | RUN_TEST=false # Set to true to run evaluation on test.jsonl after training 52 | VAL_BATCHES=25 # Number of validation batches during training (-1 for full validation set) 53 | TEST_BATCHES=100 # Number of test batches if RUN_TEST=true (-1 for full test set) 54 | 55 | # --- Safety Checks --- 56 | if [ "$MODEL_PATH" == "../../mlx_models/<your_model_name>-mlx" ]; then 57 | echo "Error: Please set the MODEL_PATH variable in the script to your specific MLX model directory." 58 | exit 1 59 | fi 60 | 61 | # Basic check if DATA_PATH directory exists 62 | if [ ! 
-d "$DATA_PATH" ]; then 63 | echo "Error: Data directory '$DATA_PATH' not found." 64 | exit 1 65 | fi 66 | 67 | if [ ! -d "$MODEL_PATH" ]; then 68 | echo "Error: Model directory '$MODEL_PATH' not found." 69 | echo "Did you run a conversion script?" 70 | exit 1 71 | fi 72 | 73 | # Check for the required files within DATA_PATH 74 | if [ ! -f "$DATA_PATH/train.jsonl" ] || [ ! -f "$DATA_PATH/valid.jsonl" ]; then 75 | echo "Error: Could not find required 'train.jsonl' or 'valid.jsonl' in directory '$DATA_PATH'." 76 | echo "Please ensure your training files are named correctly and placed in this directory." 77 | exit 1 78 | fi 79 | 80 | # Add mlx-lm to Python path (adjust if your structure differs) 81 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 82 | PROJECT_ROOT=$( realpath "$SCRIPT_DIR/../.." ) 83 | MLX_LM_PATH="$PROJECT_ROOT/mlx-lm" 84 | export PYTHONPATH="$PYTHONPATH:$MLX_LM_PATH" 85 | 86 | echo "Using MLX Model: $MODEL_PATH" 87 | echo "Using Data Dir : $DATA_PATH" 88 | echo "Saving Adapters to: $ADAPTER_PATH" 89 | echo "PYTHONPATH set to: $PYTHONPATH" 90 | 91 | # --- Build Command --- 92 | # Uses the DATA_PATH directory as input for the --data argument 93 | CMD=( 94 | "python" 95 | "-m" 96 | "mlx_lm" 97 | "lora" 98 | "--model" "$MODEL_PATH" 99 | "--train" 100 | "--data" "$DATA_PATH" 101 | "--adapter-path" "$ADAPTER_PATH" 102 | "--iters" "$ITERS" 103 | "--batch-size" "$BATCH_SIZE" 104 | # "--learning-rate" "$LEARNING_RATE" # Let config file override this if provided 105 | "--save-every" "$SAVE_EVERY" 106 | "--num-layers" "$NUM_LAYERS" 107 | "--max-seq-length" "$MAX_SEQ_LENGTH" 108 | "--val-batches" "$VAL_BATCHES" 109 | "--fine-tune-type" "$FINE_TUNE_TYPE" 110 | # Add optional flags 111 | # "--grad-checkpoint" # Use gradient checkpointing to save memory 112 | # "--mask-prompt" # Ignore prompt tokens in loss calculation 113 | ) 114 | 115 | # Add config file if provided 116 | if [ -n "$CONFIG_PATH" ]; then 117 | if [ ! -f "$CONFIG_PATH" ]; then 118 | echo "Error: Config file specified but not found: '$CONFIG_PATH'" 119 | exit 1 120 | fi 121 | echo "Using config file: $CONFIG_PATH" 122 | CMD+=("--config" "$CONFIG_PATH") 123 | fi 124 | 125 | if [ "$RUN_TEST" = true ]; then 126 | if [ ! -f "$DATA_PATH/test.jsonl" ]; then # Check in DATA_PATH directory 127 | echo "Warning: RUN_TEST is true but 'test.jsonl' not found in '$DATA_PATH'. Skipping test evaluation." 128 | else 129 | CMD+=("--test" "--test-batches" "$TEST_BATCHES") 130 | fi 131 | fi 132 | 133 | # --- Run Training --- 134 | echo "Running command:" 135 | printf "%s " "${CMD[@]}" 136 | echo "\n" 137 | 138 | "${CMD[@]}" 139 | 140 | EXIT_CODE=$? 141 | if [ $EXIT_CODE -eq 0 ]; then 142 | echo "Fine-tuning completed successfully." 143 | echo "Adapters saved in: $ADAPTER_PATH" 144 | else 145 | echo "Fine-tuning failed with exit code $EXIT_CODE." 
146 | fi 147 | 148 | exit $EXIT_CODE -------------------------------------------------------------------------------- /src/inference/generate_qwen_vlm_notebook.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from __future__ import annotations 3 | 4 | import logging 5 | import time 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | from typing import List, Sequence 9 | 10 | from mlx_vlm import generate, load 11 | from mlx_vlm.prompt_utils import apply_chat_template 12 | from mlx_vlm.utils import load_config 13 | 14 | 15 | # %% 16 | CONFIG = { 17 | "model_path": "mlx-community/Qwen3-VL-32B-Instruct-3bit", 18 | #"prompt_template": "Put all the text from this image into markdown format.", 19 | "prompt_template": "Put the table from this image into markdown format.", 20 | "max_tokens": 3_200, 21 | "supported_image_suffixes": {".jpg", ".jpeg", ".png", ".webp"}, 22 | } 23 | 24 | 25 | @dataclass(frozen=True) 26 | class LoadedArtifacts: 27 | model: object 28 | processor: object 29 | config: dict 30 | 31 | 32 | @dataclass(frozen=True) 33 | class SingleImageResult: 34 | image_path: Path 35 | text: str 36 | elapsed_seconds: float 37 | 38 | 39 | # %% 40 | def load_vlm_artifacts(model_path: str) -> LoadedArtifacts: 41 | logging.info("Loading model %s", model_path) 42 | model, processor = load(model_path) 43 | config = load_config(model_path) 44 | return LoadedArtifacts(model=model, processor=processor, config=config) 45 | 46 | 47 | def validate_image_path(image_path: Path, suffixes: Sequence[str]) -> Path: 48 | resolved = image_path.expanduser().resolve() 49 | if not resolved.exists(): 50 | raise FileNotFoundError(f"Image not found: {resolved}") 51 | suffix_set = {suffix.lower() for suffix in suffixes} 52 | if resolved.suffix.lower() not in suffix_set: 53 | raise ValueError(f"Unsupported image suffix: {resolved.suffix}. 
Expected one of {sorted(suffix_set)}") 54 | return resolved 55 | 56 | 57 | def generate_text_for_image( 58 | artifacts: LoadedArtifacts, 59 | *, 60 | image_path: Path, 61 | prompt_template: str, 62 | max_tokens: int, 63 | ) -> SingleImageResult: 64 | validated_path = validate_image_path(image_path, CONFIG["supported_image_suffixes"]) 65 | formatted_prompt = apply_chat_template( 66 | artifacts.processor, 67 | artifacts.config, 68 | prompt_template, 69 | num_images=1, 70 | ) 71 | 72 | image_inputs = [validated_path.as_posix()] 73 | start_time = time.perf_counter() 74 | output = generate( 75 | artifacts.model, 76 | artifacts.processor, 77 | formatted_prompt, 78 | image_inputs, 79 | verbose=False, 80 | max_tokens=max_tokens, 81 | ) 82 | elapsed = time.perf_counter() - start_time 83 | 84 | return SingleImageResult(image_path=validated_path, text=output.text.strip(), elapsed_seconds=elapsed) 85 | 86 | 87 | # %% 88 | def convert_pdf_to_images( 89 | pdf_path: Path, 90 | output_dir: Path | None = None, 91 | *, 92 | algorithm: str = "pymupdf", 93 | dpi: int = 200, 94 | jpeg_quality: int = 90, 95 | ) -> List[Path]: 96 | resolved_pdf = pdf_path.expanduser().resolve() 97 | if not resolved_pdf.exists(): 98 | raise FileNotFoundError(f"PDF not found: {resolved_pdf}") 99 | if output_dir is None: 100 | output_dir = resolved_pdf.with_name(f"{resolved_pdf.stem}_pages") 101 | output_dir.mkdir(parents=True, exist_ok=True) 102 | 103 | if algorithm == "pymupdf": 104 | try: 105 | import fitz # type: ignore 106 | except ImportError as exc: 107 | raise RuntimeError("Algorithm 'pymupdf' requires the 'pymupdf' package.") from exc 108 | 109 | scale = dpi / 72 110 | saved_paths: List[Path] = [] 111 | doc = fitz.open(resolved_pdf) 112 | try: 113 | for page_index in range(doc.page_count): 114 | page = doc.load_page(page_index) 115 | pix = page.get_pixmap(matrix=fitz.Matrix(scale, scale)) 116 | output_path = output_dir / f"{resolved_pdf.stem}_page_{page_index + 1:04d}.jpg" 117 | pix.save(output_path.as_posix(), jpg_quality=jpeg_quality) 118 | saved_paths.append(output_path) 119 | finally: 120 | doc.close() 121 | return saved_paths 122 | 123 | if algorithm == "pdf2image": 124 | try: 125 | from pdf2image import convert_from_path 126 | except ImportError as exc: 127 | raise RuntimeError("Algorithm 'pdf2image' requires the 'pdf2image' package and Poppler.") from exc 128 | 129 | images = convert_from_path(resolved_pdf.as_posix(), dpi=dpi) 130 | saved_paths = [] 131 | for index, image in enumerate(images, start=1): 132 | output_path = output_dir / f"{resolved_pdf.stem}_page_{index:04d}.jpg" 133 | image.save(output_path, format="JPEG", quality=jpeg_quality, optimize=True) 134 | saved_paths.append(output_path) 135 | return saved_paths 136 | 137 | raise ValueError(f"Unsupported algorithm: {algorithm}") 138 | 139 | 140 | # %% 141 | PDF_PATH = "../../DATA/PDFS/CVSPharmacyRateSchedule.pdf" 142 | PDF_PATH = "../../DATA/PDFS/Walgreens.pdf" 143 | PDF_OUTPUT_DIR = "../../DATA/PDFS/Walgreens_pages" 144 | PDF_ALGORITHM = "pymupdf" 145 | PDF_DPI = 200 146 | PDF_JPEG_QUALITY = 90 147 | 148 | converted_pages: List[Path] = [] 149 | if PDF_PATH is not None: 150 | converted_pages = convert_pdf_to_images( 151 | Path(PDF_PATH), 152 | Path(PDF_OUTPUT_DIR) if PDF_OUTPUT_DIR is not None else None, 153 | algorithm=PDF_ALGORITHM, 154 | dpi=PDF_DPI, 155 | jpeg_quality=PDF_JPEG_QUALITY, 156 | ) 157 | converted_pages 158 | 159 | 160 | # %% 161 | logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") 162 | artifacts = 
load_vlm_artifacts(CONFIG["model_path"]) 163 | 164 | 165 | # %% 166 | TARGET_IMAGE_PATH = "../../DATA/TABLE_IMAGES/prior-auth-example.png" 167 | PROMPT_TEMPLATE = CONFIG["prompt_template"] 168 | MAX_TOKENS = CONFIG["max_tokens"] 169 | 170 | 171 | # %% 172 | if TARGET_IMAGE_PATH is None: 173 | raise ValueError("Set TARGET_IMAGE_PATH to an actual image path before running this cell.") 174 | 175 | result = generate_text_for_image( 176 | artifacts, 177 | image_path=Path(TARGET_IMAGE_PATH), 178 | prompt_template=PROMPT_TEMPLATE, 179 | max_tokens=MAX_TOKENS, 180 | ) 181 | 182 | print(result.text) 183 | result.elapsed_seconds 184 | 185 | # %% 186 | -------------------------------------------------------------------------------- /src/evaluations/run_evaluations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import subprocess 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | # Add the project root to the Python path to allow importing from src 9 | project_root = Path(__file__).resolve().parents[1] 10 | sys.path.append(str(project_root)) 11 | 12 | def run_inference(model_path, adapter_path, prompt_text, temp, top_p, rep_penalty): 13 | """Runs the inference script with the given parameters.""" 14 | command = [ 15 | "python", 16 | str(project_root / "src/inference/generate_qwen3.py"), 17 | "--model-path", model_path, 18 | "--prompt", "-", # Indicate prompt comes from stdin 19 | "--temp", str(temp), 20 | "--top-p", str(top_p), 21 | "--repetition-penalty", str(rep_penalty), 22 | # Add other necessary args like max_tokens if needed 23 | # "--max-tokens", "512", 24 | ] 25 | if adapter_path: 26 | command.extend(["--adapter-path", adapter_path]) 27 | 28 | try: 29 | # Use subprocess.run to execute the command 30 | # Pass the prompt via stdin 31 | result = subprocess.run( 32 | command, 33 | input=prompt_text, 34 | text=True, 35 | capture_output=True, 36 | check=True, # Raise an exception if the command fails 37 | encoding='utf-8' 38 | ) 39 | return result.stdout.strip() 40 | except FileNotFoundError: 41 | print(f"Error: The script 'src/inference/generate_qwen3.py' was not found.", file=sys.stderr) 42 | sys.exit(1) 43 | except subprocess.CalledProcessError as e: 44 | print(f"Error during inference subprocess execution:", file=sys.stderr) 45 | print(f"Command: {' '.join(e.cmd)}", file=sys.stderr) 46 | print(f"Return code: {e.returncode}", file=sys.stderr) 47 | print(f"Stderr: {e.stderr}", file=sys.stderr) 48 | print(f"Stdout: {e.stdout}", file=sys.stderr) 49 | return None # Or re-raise the exception if preferred 50 | 51 | def main(): 52 | parser = argparse.ArgumentParser(description="Run batch evaluations using generate_qwen3.py") 53 | parser.add_argument("--model-path", required=True, help="Path to the MLX model directory.") 54 | parser.add_argument("--adapter-path", default=None, help="Path to the adapter directory (optional).") 55 | parser.add_argument("--valid-jsonl-path", required=True, help="Path to the validation JSONL file.") 56 | parser.add_argument("--output-dir", default="eval_outputs", help="Directory to save evaluation outputs.") 57 | parser.add_argument("--num-examples", type=int, default=10, help="Number of examples to run from the JSONL file.") 58 | parser.add_argument("--prompt-key", default="prompt", help="The key in the JSONL file containing the prompt text.") 59 | # Add generation parameters matching the inference script 60 | parser.add_argument("--temp", type=float, default=0.75, help="Sampling 
temperature.") 61 | parser.add_argument("--top-p", type=float, default=0.95, help="Top-p sampling.") 62 | parser.add_argument("--repetition-penalty", type=float, default=1.1, help="Repetition penalty.") 63 | 64 | args = parser.parse_args() 65 | 66 | # Validate paths 67 | model_path_obj = Path(args.model_path) 68 | valid_jsonl_path_obj = Path(args.valid_jsonl_path) 69 | output_dir_obj = Path(args.output_dir) 70 | adapter_path_obj = Path(args.adapter_path) if args.adapter_path else None 71 | 72 | if not model_path_obj.is_dir(): 73 | print(f"Error: Model path '{args.model_path}' not found or not a directory.", file=sys.stderr) 74 | sys.exit(1) 75 | if args.adapter_path and not adapter_path_obj.is_dir(): 76 | print(f"Error: Adapter path '{args.adapter_path}' not found or not a directory.", file=sys.stderr) 77 | sys.exit(1) 78 | if not valid_jsonl_path_obj.is_file(): 79 | print(f"Error: Validation JSONL file '{args.valid_jsonl_path}' not found.", file=sys.stderr) 80 | sys.exit(1) 81 | 82 | # Create output directory 83 | output_dir_obj.mkdir(parents=True, exist_ok=True) 84 | output_file = output_dir_obj / "results.jsonl" 85 | 86 | print(f"Running evaluations for {args.num_examples} examples...") 87 | print(f"Model: {args.model_path}") 88 | if args.adapter_path: 89 | print(f"Adapter: {args.adapter_path}") 90 | print(f"Input: {args.valid_jsonl_path}") 91 | print(f"Output will be saved to: {output_file}") 92 | 93 | count = 0 94 | results = [] 95 | try: 96 | with open(args.valid_jsonl_path, 'r', encoding='utf-8') as infile, \ 97 | open(output_file, 'w', encoding='utf-8') as outfile: 98 | for line in infile: 99 | if count >= args.num_examples: 100 | break 101 | try: 102 | data = json.loads(line.strip()) 103 | except json.JSONDecodeError: 104 | print(f"Warning: Skipping invalid JSON line: {line.strip()}", file=sys.stderr) 105 | continue 106 | 107 | if args.prompt_key not in data: 108 | print(f"Warning: Prompt key '{args.prompt_key}' not found in JSON line: {line.strip()}. Skipping.", file=sys.stderr) 109 | continue 110 | 111 | prompt = data[args.prompt_key] 112 | print(f"\nRunning example {count + 1}/{args.num_examples}...") 113 | # print(f"Prompt: {prompt[:100]}...") # Optionally print start of prompt 114 | 115 | generation = run_inference( 116 | model_path=args.model_path, 117 | adapter_path=args.adapter_path, 118 | prompt_text=prompt, 119 | temp=args.temp, 120 | top_p=args.top_p, 121 | rep_penalty=args.repetition_penalty 122 | ) 123 | 124 | if generation is not None: 125 | print(f"Generation successful.") 126 | result_data = { 127 | "prompt": prompt, 128 | "generation": generation, 129 | "original_data": data # Keep original data for reference 130 | } 131 | # Write result to output file immediately 132 | outfile.write(json.dumps(result_data) + '\\n') 133 | outfile.flush() # Ensure it's written in case of interruption 134 | results.append(result_data) 135 | count += 1 136 | else: 137 | print(f"Generation failed for example {count + 1}.") 138 | # Decide if you want to stop or continue 139 | # break 140 | 141 | except FileNotFoundError: 142 | print(f"Error: Could not open validation file '{args.valid_jsonl_path}'", file=sys.stderr) 143 | sys.exit(1) 144 | except Exception as e: 145 | print(f"An unexpected error occurred: {e}", file=sys.stderr) 146 | sys.exit(1) 147 | 148 | print(f"\nEvaluation complete. 
{count} results saved to {output_file}") 149 | 150 | if __name__ == "__main__": 151 | main() -------------------------------------------------------------------------------- /src/data_processing/semantic_chunker.py: -------------------------------------------------------------------------------- 1 | # semantic_chunker.py 2 | import json, re 3 | import numpy as np 4 | from pathlib import Path 5 | from typing import List 6 | from tqdm import tqdm 7 | 8 | import numpy as np 9 | from baseline_chunker import load_paragraphs, chunk_paragraphs 10 | from src.data_processing.bm25_func import bm25_gap_violation 11 | 12 | from sentence_transformers import SentenceTransformer 13 | 14 | # ==== plug‑in hooks ========================================================= 15 | # Removed embed and similarities helper functions as logic is now inline 16 | # ============================================================================ 17 | 18 | def refine_boundaries(chunks: list[str], 19 | tail_len: int = 2, 20 | head_len: int = 2, 21 | thresh_low: float = 0.15, 22 | thresh_high: float = 0.65, 23 | char_names: list[str] = [ 24 | "Erasmus", "Paris", "Thurso", "William Kemp", 25 | "Sarah", "Blair", "Calley", "Kireku", "Delblanc", "Daniel", 26 | "Liverpool Merchant", "Loango", "Bonny", "Whydah", 27 | "Bight of Benin", "Barbados", "Liverpool", "Mersey" 28 | ], 29 | max_bm25_gap: int = 4, 30 | max_size: int = 1200, 31 | model: SentenceTransformer = None): 32 | new_chunks = [] 33 | modification_count = 0 # Initialize modification counter 34 | modified_current_boundary = False # Flag to track if current boundary was modified 35 | 36 | # Wrap chunks with tqdm for progress bar 37 | for i, chunk in enumerate(tqdm(chunks, desc="Refining chunk boundaries")): 38 | modified_current_boundary = False # Reset flag for each boundary check 39 | if i == 0: 40 | new_chunks.append(chunk) 41 | continue 42 | 43 | prev = new_chunks[-1] 44 | # candidate sentences near the join 45 | prev_sents = re.split(r'(?<=[.!?])\\s+', prev) 46 | next_sents = re.split(r'(?<=[.!?])\\s+', chunk) 47 | tail = " ".join(prev_sents[-tail_len:]) 48 | head = " ".join(next_sents[:head_len]) 49 | 50 | # Encode tail as document, head as query 51 | tail_embedding = model.encode(f"search_document: {tail}") 52 | head_embedding = model.encode(f"search_query: {head}") 53 | # Calculate similarity 54 | sim = model.similarity(tail_embedding, head_embedding)[0, 0] # Access the single similarity score 55 | 56 | if sim < thresh_low: 57 | # move first paragraph of current chunk back to previous 58 | para_split = re.split(r'\\n{2,}', chunk, maxsplit=1) 59 | if len(para_split) == 2: 60 | new_chunks[-1] += "\\n\\n" + para_split[0] 61 | chunk = para_split[1] 62 | modification_count += 1 63 | modified_current_boundary = True 64 | elif sim > thresh_high: 65 | # duplicate a connecting sentence for continuity 66 | new_chunks[-1] += " " + head 67 | modification_count += 1 68 | modified_current_boundary = True 69 | 70 | # BM25 continuity check (pseudo) - only check if not already modified by similarity 71 | if not modified_current_boundary: 72 | for character in char_names: 73 | if bm25_gap_violation((new_chunks[-1], chunk), character, max_bm25_gap): 74 | # pull one paragraph back if gap too wide 75 | para_split = re.split(r'\\n{2,}', chunk, maxsplit=1) 76 | if len(para_split) == 2: 77 | new_chunks[-1] += "\\n\\n" + para_split[0] 78 | chunk = para_split[1] 79 | modification_count += 1 80 | modified_current_boundary = True 81 | break # Stop checking characters for this boundary once 
modified 82 | 83 | # enforce hard upper size 84 | if len(chunk.split()) > max_size: 85 | # optional second pass of baseline chunking just on *this* oversize chunk 86 | mini_chunks = chunk_paragraphs( 87 | re.split(r'\\n{2,}', chunk), target_words=max_size) 88 | new_chunks.extend(mini_chunks) 89 | modification_count += 1 # Count the split as one modification event 90 | # Don't append the original oversized chunk 91 | else: 92 | # Only append if the chunk wasn't replaced by mini_chunks 93 | new_chunks.append(chunk) 94 | 95 | return new_chunks, modification_count # Return modification count 96 | 97 | if __name__ == "__main__": 98 | import argparse 99 | 100 | DEFAULT_TARGET_WORDS = 350 # 350 480 520 570 680 730 790 101 | DEFAULT_MODEL_NAME = "lightonai/modernbert-embed-large" 102 | DEFAULT_BOOK_PATH = "sacredhunger.txt" 103 | DEFAULT_OUTPUT_PATH = f"sacredhunger_{DEFAULT_TARGET_WORDS}.json" 104 | 105 | ap = argparse.ArgumentParser( 106 | description="Chunk a book into semantic segments, refining initial paragraph-based chunks." 107 | ) 108 | ap.add_argument( 109 | "--book_path", 110 | type=str, 111 | default=DEFAULT_BOOK_PATH, 112 | help=f"Path to the input text file (book). Defaults to '{DEFAULT_BOOK_PATH}'.", 113 | ) 114 | ap.add_argument( 115 | "--target", 116 | type=int, 117 | default=DEFAULT_TARGET_WORDS, 118 | help=f"Target number of words per chunk (default: {DEFAULT_TARGET_WORDS}). " 119 | f"Note: Baseline chunking aims for 90% of this target.", 120 | ) 121 | ap.add_argument( 122 | "--model_name", 123 | type=str, 124 | default=DEFAULT_MODEL_NAME, 125 | help=f"Name of the SentenceTransformer model to use for embeddings (default: {DEFAULT_MODEL_NAME})." 126 | ) 127 | ap.add_argument( 128 | "--output_path", 129 | type=str, 130 | default=DEFAULT_OUTPUT_PATH, 131 | help=f"Path to the output file (default: {DEFAULT_OUTPUT_PATH})." 132 | ) 133 | args = ap.parse_args() 134 | 135 | if args.target != DEFAULT_TARGET_WORDS: 136 | args.output_path = f"allthekingsmen_{args.target}.json" 137 | print(f"Using output path: {args.output_path}") 138 | 139 | print(f"Loading sentence transformer model: {args.model_name}") 140 | model = SentenceTransformer(args.model_name, trust_remote_code=True) 141 | 142 | print(f"Loading paragraphs from: {args.book_path}") 143 | paras = load_paragraphs(args.book_path) 144 | 145 | print(f"Creating baseline chunks (target ~{int(args.target * 0.9)} words)...") 146 | base = chunk_paragraphs(paras, int(args.target * 0.9)) 147 | 148 | print(f"Refining {len(base)} baseline chunks...") 149 | refined, mod_count = refine_boundaries(base, model=model) # Capture modification count 150 | print(f"Refinement process modified {mod_count} chunk boundaries.") # Print count 151 | 152 | Path(args.output_path).write_text( 153 | json.dumps({"chunks": refined, "count": len(refined)}, indent=2), 154 | encoding="utf‑8") 155 | print(f"Wrote {len(refined)} refined chunks to {args.output_path}") 156 | -------------------------------------------------------------------------------- /src/data_processing/jsonl_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from pathlib import Path 4 | from typing import List, Any, Dict 5 | 6 | def shuffle_jsonl(input_path: str | Path, output_path: str | Path, seed: int) -> None: 7 | """ 8 | Reads a JSONL file, shuffles its lines randomly using a seed, 9 | and writes the shuffled lines to a new JSONL file. 10 | 11 | Args: 12 | input_path: Path to the input JSONL file. 
13 | output_path: Path to save the shuffled output JSONL file. 14 | seed: Random seed for shuffling. 15 | """ 16 | input_path = Path(input_path) 17 | output_path = Path(output_path) 18 | output_path.parent.mkdir(parents=True, exist_ok=True) 19 | 20 | lines: List[Dict[str, Any]] = [] 21 | with open(input_path, "r", encoding="utf-8") as infile: 22 | for line in infile: 23 | try: 24 | lines.append(json.loads(line.strip())) 25 | except json.JSONDecodeError: 26 | print(f"Warning: Skipping invalid JSON line in {input_path}: {line.strip()}") 27 | 28 | print(f"Read {len(lines)} lines from {input_path}.") 29 | 30 | random.seed(seed) 31 | random.shuffle(lines) 32 | print(f"Shuffled {len(lines)} lines using seed {seed}.") 33 | 34 | with open(output_path, "w", encoding="utf-8") as outfile: 35 | for item in lines: 36 | outfile.write(json.dumps(item, ensure_ascii=False) + "\n") 37 | 38 | print(f"Saved shuffled data to {output_path}.") 39 | 40 | 41 | def concatenate_and_shuffle_jsonl( 42 | input_paths: List[str | Path], output_path: str | Path, seed: int 43 | ) -> None: 44 | """ 45 | Loads multiple JSONL files, concatenates their contents, shuffles the combined 46 | lines randomly using a seed, and saves them as a single JSONL file. 47 | 48 | Args: 49 | input_paths: A list of paths to the input JSONL files. 50 | output_path: Path to save the concatenated and shuffled output JSONL file. 51 | seed: Random seed for shuffling. 52 | """ 53 | output_path = Path(output_path) 54 | output_path.parent.mkdir(parents=True, exist_ok=True) 55 | 56 | all_lines: List[Dict[str, Any]] = [] 57 | total_lines_read = 0 58 | for input_path_str in input_paths: 59 | input_path = Path(input_path_str) 60 | if not input_path.is_file(): 61 | print(f"Warning: Input file not found, skipping: {input_path}") 62 | continue 63 | 64 | current_file_lines: List[Dict[str, Any]] = [] 65 | with open(input_path, "r", encoding="utf-8") as infile: 66 | for line_num, line in enumerate(infile, 1): 67 | stripped_line = line.strip() 68 | if not stripped_line: # Skip empty lines 69 | continue 70 | try: 71 | obj = json.loads(stripped_line) 72 | # Assuming primary use case is JSON objects per line 73 | if isinstance(obj, dict): 74 | current_file_lines.append(obj) 75 | else: 76 | # Handle cases like top-level arrays, strings, numbers if needed, 77 | # but typically JSONL has objects. Warn if not an object. 78 | print(f"Warning: Skipping non-object JSON on line {line_num} in {input_path}: Type={type(obj)}") 79 | # Optionally append if other types are expected: current_file_lines.append(obj) 80 | 81 | except json.JSONDecodeError as e: 82 | # Provide more context about the error 83 | print(f"Warning: Skipping invalid JSON on line {line_num} in {input_path}.") 84 | print(f" Error: {e}") 85 | # Optionally print the problematic line (or part of it) 86 | max_len = 150 # Increased preview length 87 | line_preview = stripped_line[:max_len] + ('...' if len(stripped_line) > max_len else '') 88 | print(f" Line content (preview): {line_preview}") 89 | print(f"Read {len(current_file_lines)} valid JSON objects from {input_path}.") 90 | all_lines.extend(current_file_lines) 91 | total_lines_read += len(current_file_lines) # Count only successfully read objects 92 | 93 | # Adjust total lines read message to reflect valid objects 94 | print(f"Read a total of {total_lines_read} valid JSON objects from {len(input_paths)} files.") 95 | 96 | if not all_lines: 97 | print("Warning: No lines were read from any input file. 
Output file will be empty.") 98 | # Create an empty file 99 | with open(output_path, "w", encoding="utf-8") as outfile: 100 | pass 101 | print(f"Created empty output file at {output_path}.") 102 | return 103 | 104 | random.seed(seed) 105 | random.shuffle(all_lines) 106 | print(f"Shuffled {len(all_lines)} combined lines using seed {seed}.") 107 | 108 | # Revert to default newline handling by removing newline='' 109 | with open(output_path, "w", encoding="utf-8") as outfile: 110 | for i, item in enumerate(all_lines): 111 | # Add a check just in case non-dict items slipped through (shouldn't happen with current read logic) 112 | if not isinstance(item, dict): 113 | print(f"Error: Item at index {i} after shuffle is not a dict: {type(item)}. Skipping.") 114 | continue 115 | try: 116 | # Write the JSON object followed by a newline 117 | outfile.write(json.dumps(item, ensure_ascii=False) + "\n") 118 | except Exception as e: 119 | # Catch potential errors during dumping (e.g., complex objects not serializable) 120 | print(f"Error dumping item at index {i} to JSON. Skipping.") 121 | print(f" Item Preview: {str(item)[:100]}...") # Avoid printing huge items 122 | print(f" Error: {e}") 123 | 124 | print(f"Saved concatenated and shuffled data to {output_path}.") 125 | 126 | if __name__ == '__main__': 127 | # Example Usage (replace with your actual file paths and desired seed) 128 | 129 | # --- Example 1: Shuffle a single file --- 130 | # input_single = Path("path/to/your/input.jsonl") 131 | # output_single = Path("path/to/your/shuffled_output.jsonl") 132 | # random_seed = 42 133 | # if input_single.exists(): 134 | # shuffle_jsonl(input_single, output_single, random_seed) 135 | # else: 136 | # print(f"Example 1 input file not found: {input_single}") 137 | 138 | # --- Example 2: Concatenate and shuffle multiple files --- 139 | 140 | input_multiple = [ 141 | Path("../../DATA/ALLTHEKINGSMEN/train.jsonl"), 142 | Path("../../DATA/SACREDHUNGER/train.jsonl"), 143 | # Path("../../DATA/ALLTHEKINGSMEN/valid.jsonl"), 144 | # Path("../../DATA/SACREDHUNGER/valid.jsonl"), 145 | ] 146 | output_multiple = Path("../../DATA/NOVELS/train.jsonl") 147 | random_seed_concat = 123 148 | 149 | existing_input_multiple = [p for p in input_multiple if p.exists()] 150 | if not existing_input_multiple: 151 | print("No input files found.") 152 | elif len(existing_input_multiple) < len(input_multiple): 153 | print(f"Found {len(existing_input_multiple)} out of {len(input_multiple)} input files.") 154 | 155 | 156 | if existing_input_multiple: 157 | concatenate_and_shuffle_jsonl(existing_input_multiple, output_multiple, random_seed_concat) 158 | 159 | print("\nScript finished.") 160 | -------------------------------------------------------------------------------- /src/inference/generate-qwen-vlm.py: -------------------------------------------------------------------------------- 1 | """Convert a directory of page images into a single markdown document using Qwen VLM.""" 2 | 3 | from __future__ import annotations 4 | 5 | import argparse 6 | import logging 7 | import re 8 | import sys 9 | import time 10 | from dataclasses import dataclass 11 | from pathlib import Path 12 | from typing import List, Sequence 13 | 14 | from mlx_vlm import generate, load 15 | from mlx_vlm.prompt_utils import apply_chat_template 16 | from mlx_vlm.utils import load_config 17 | 18 | 19 | # --------------------------------------------------------------------------- 20 | # Configuration defaults live here to keep tuning data-driven and low blast. 
21 | # --------------------------------------------------------------------------- 22 | CONFIG = { 23 | "model_path": "mlx-community/Qwen3-VL-32B-Instruct-8bit", 24 | "page_prompt_template": ( 25 | "Put all the text from this page into markdown format." 26 | ), 27 | "heading_template": "## Page {page_number}", 28 | "max_tokens": 3_200, 29 | "output_suffix": "_combined.md", 30 | "default_output_filename": "output.md", 31 | "supported_image_suffixes": {".jpg", ".jpeg", ".png", ".webp"}, 32 | } 33 | 34 | 35 | PAGE_NUMBER_PATTERN = re.compile(r"(\d+)$") 36 | 37 | 38 | @dataclass(frozen=True) 39 | class PageInferenceResult: 40 | page_number: int 41 | image_path: Path 42 | markdown: str 43 | elapsed_seconds: float 44 | 45 | 46 | def _parse_arguments() -> argparse.Namespace: 47 | parser = argparse.ArgumentParser( 48 | description="Generate a markdown document by processing page images with Qwen VLM.", 49 | ) 50 | parser.add_argument( 51 | "image_dir", 52 | type=Path, 53 | help="Directory containing per-page images (e.g. *_page_0001.jpg).", 54 | ) 55 | parser.add_argument( 56 | "--output", 57 | type=Path, 58 | default=None, 59 | help="Where to write the markdown output (defaults next to the image directory).", 60 | ) 61 | parser.add_argument( 62 | "--prompt", 63 | type=str, 64 | default=CONFIG["page_prompt_template"], 65 | help="Prompt template applied to each page (supports {page_number} and {image_name}).", 66 | ) 67 | parser.add_argument( 68 | "--heading-template", 69 | type=str, 70 | default=CONFIG["heading_template"], 71 | help="Markdown heading template per page (supports {page_number} and {image_name}).", 72 | ) 73 | parser.add_argument( 74 | "--max-tokens", 75 | type=int, 76 | default=CONFIG["max_tokens"], 77 | help="Maximum tokens to generate per page.", 78 | ) 79 | 80 | args = parser.parse_args() 81 | 82 | if args.max_tokens <= 0: 83 | parser.error("--max-tokens must be positive") 84 | 85 | return args 86 | 87 | 88 | def _resolve_output_path(image_dir: Path, explicit_output: Path | None) -> Path: 89 | if explicit_output is not None: 90 | explicit_output = explicit_output.resolve(strict=False) 91 | if explicit_output.exists() and explicit_output.is_dir(): 92 | return explicit_output / CONFIG["default_output_filename"] 93 | if explicit_output.suffix == "": 94 | return explicit_output / CONFIG["default_output_filename"] 95 | return explicit_output 96 | 97 | default_name = f"{image_dir.name}{CONFIG['output_suffix']}" 98 | return image_dir.parent / default_name 99 | 100 | 101 | def _collect_image_paths(image_dir: Path, suffixes: Sequence[str]) -> List[Path]: 102 | if not image_dir.exists(): 103 | raise FileNotFoundError(f"Image directory not found: {image_dir}") 104 | if not image_dir.is_dir(): 105 | raise NotADirectoryError(f"Expected a directory path: {image_dir}") 106 | 107 | suffix_set = {s.lower() for s in suffixes} 108 | 109 | candidates = [ 110 | path 111 | for path in image_dir.iterdir() 112 | if path.is_file() and path.suffix.lower() in suffix_set 113 | ] 114 | 115 | if not candidates: 116 | raise FileNotFoundError( 117 | f"No supported image files found in {image_dir}. 
Expected suffixes: {sorted(suffix_set)}" 118 | ) 119 | 120 | sortable = [] 121 | unsortable: List[Path] = [] 122 | for path in candidates: 123 | match = PAGE_NUMBER_PATTERN.search(path.stem) 124 | if match: 125 | sortable.append((int(match.group(1)), path)) 126 | else: 127 | unsortable.append(path) 128 | 129 | if unsortable: 130 | logging.warning( 131 | "The following files are missing a trailing page number and will appear last in lexicographic order: %s", 132 | ", ".join(sorted(p.name for p in unsortable)), 133 | ) 134 | 135 | sorted_paths = [path for _, path in sorted(sortable, key=lambda item: item[0])] 136 | sorted_paths.extend(sorted(unsortable)) 137 | return sorted_paths 138 | 139 | 140 | def _load_model(model_path: str): 141 | logging.info("Loading model %s", model_path) 142 | model, processor = load(model_path) 143 | config = load_config(model_path) 144 | return model, processor, config 145 | 146 | 147 | def _run_inference( 148 | model, 149 | processor, 150 | config, 151 | *, 152 | image_path: Path, 153 | prompt_template: str, 154 | heading_template: str, 155 | page_number: int, 156 | max_tokens: int, 157 | ) -> PageInferenceResult: 158 | prompt = prompt_template.format(page_number=page_number, image_name=image_path.name) 159 | formatted_prompt = apply_chat_template( 160 | processor, 161 | config, 162 | prompt, 163 | num_images=1, 164 | ) 165 | 166 | image_inputs = [image_path.as_posix()] 167 | start_time = time.perf_counter() 168 | output = generate( 169 | model, 170 | processor, 171 | formatted_prompt, 172 | image_inputs, 173 | verbose=False, 174 | max_tokens=max_tokens, 175 | ) 176 | elapsed = time.perf_counter() - start_time 177 | 178 | heading = heading_template.format(page_number=page_number, image_name=image_path.name) 179 | markdown_body = output.text.strip() 180 | markdown_block = f"{heading}\n\n{markdown_body}\n" 181 | 182 | logging.info( 183 | "Page %s processed in %.2f seconds (%s)", 184 | page_number, 185 | elapsed, 186 | image_path.name, 187 | ) 188 | 189 | return PageInferenceResult( 190 | page_number=page_number, 191 | image_path=image_path, 192 | markdown=markdown_block, 193 | elapsed_seconds=elapsed, 194 | ) 195 | 196 | 197 | def main() -> int: 198 | logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") 199 | 200 | args = _parse_arguments() 201 | image_dir = args.image_dir.resolve() 202 | output_path = _resolve_output_path(image_dir, args.output) 203 | 204 | image_paths = _collect_image_paths(image_dir, CONFIG["supported_image_suffixes"]) 205 | logging.info("Found %s page images in %s", len(image_paths), image_dir) 206 | 207 | model, processor, config = _load_model(CONFIG["model_path"]) 208 | 209 | output_path.parent.mkdir(parents=True, exist_ok=True) 210 | 211 | page_count = 0 212 | total_time = 0.0 213 | with output_path.open("w", encoding="utf-8") as output_file: 214 | for index, image_path in enumerate(image_paths, start=1): 215 | result = _run_inference( 216 | model, 217 | processor, 218 | config, 219 | image_path=image_path, 220 | prompt_template=args.prompt, 221 | heading_template=args.heading_template, 222 | page_number=index, 223 | max_tokens=args.max_tokens, 224 | ) 225 | output_file.write(result.markdown) 226 | output_file.flush() 227 | page_count += 1 228 | total_time += result.elapsed_seconds 229 | 230 | logging.info( 231 | "Wrote %s pages to %s in %.2f seconds total - avg %.2f seconds/page", 232 | page_count, 233 | output_path, 234 | total_time, 235 | total_time / page_count, 236 | ) 237 | 238 | return 0 239 | 240 | 241 | if 
__name__ == "__main__": 242 | sys.exit(main()) 243 | -------------------------------------------------------------------------------- /split_pdf_pages.py: -------------------------------------------------------------------------------- 1 | """Split a PDF into per-page JPEGs using selectable backends. 2 | 3 | This script provides multiple conversion algorithms so the user can pick the 4 | one that fits their environment. Each algorithm is declared in the CONFIG 5 | section to keep defaults visible and easy to change. 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | import argparse 11 | import logging 12 | import sys 13 | from dataclasses import dataclass 14 | from pathlib import Path 15 | from typing import Callable, Dict, Iterable, List, Protocol 16 | 17 | 18 | # --------------------------------------------------------------------------- 19 | # Configuration: adjust these values to tune defaults without touching logic. 20 | # --------------------------------------------------------------------------- 21 | CONFIG = { 22 | "default_quality": 90, # JPEG quality: 1 (worst) .. 100 (best) 23 | "default_algorithm": "pymupdf", 24 | "algorithm_settings": { 25 | # PyMuPDF offers fast native rendering. Requires "pymupdf" to be installed. 26 | "pymupdf": { 27 | "dpi": 200, 28 | }, 29 | # pdf2image leverages Poppler. Requires "pdf2image" and Poppler utilities. 30 | "pdf2image": { 31 | "dpi": 200, 32 | "thread_count": 2, 33 | }, 34 | }, 35 | } 36 | 37 | 38 | class ConversionError(RuntimeError): 39 | """Base class for conversion-related failures.""" 40 | 41 | 42 | class MissingDependencyError(ConversionError): 43 | """Raised when an algorithm cannot run because a dependency is absent.""" 44 | 45 | 46 | class ConversionAlgorithm(Protocol): 47 | """Callable signature all conversion backends must satisfy.""" 48 | 49 | def __call__( 50 | self, 51 | pdf_path: Path, 52 | output_dir: Path, 53 | *, 54 | jpeg_quality: int, 55 | settings: Dict[str, int | float | str | None], 56 | ) -> List[Path]: 57 | ... 58 | 59 | 60 | @dataclass(frozen=True) 61 | class AlgorithmDefinition: 62 | name: str 63 | executor: ConversionAlgorithm 64 | settings: Dict[str, int | float | str | None] 65 | 66 | 67 | def _convert_with_pymupdf( 68 | pdf_path: Path, 69 | output_dir: Path, 70 | *, 71 | jpeg_quality: int, 72 | settings: Dict[str, int | float | str | None], 73 | ) -> List[Path]: 74 | try: 75 | import fitz # PyMuPDF 76 | except ImportError as exc: 77 | raise MissingDependencyError( 78 | "Algorithm 'pymupdf' requires the 'pymupdf' package." 
79 | ) from exc 80 | 81 | dpi = int(settings.get("dpi", 200) or 200) 82 | scale = dpi / 72 # PDF points per inch 83 | 84 | saved_paths: List[Path] = [] 85 | doc = fitz.open(pdf_path) 86 | try: 87 | for page_index in range(doc.page_count): 88 | page = doc.load_page(page_index) 89 | pix = page.get_pixmap(matrix=fitz.Matrix(scale, scale)) 90 | output_path = output_dir / f"{pdf_path.stem}_page_{page_index + 1:04d}.jpg" 91 | pix.save(output_path.as_posix(), jpg_quality=jpeg_quality) 92 | saved_paths.append(output_path) 93 | logging.debug( 94 | "Rendered page %s with PyMuPDF at %sdpi to %s", 95 | page_index + 1, 96 | dpi, 97 | output_path, 98 | ) 99 | finally: 100 | doc.close() 101 | 102 | return saved_paths 103 | 104 | 105 | def _convert_with_pdf2image( 106 | pdf_path: Path, 107 | output_dir: Path, 108 | *, 109 | jpeg_quality: int, 110 | settings: Dict[str, int | float | str | None], 111 | ) -> List[Path]: 112 | try: 113 | from pdf2image import convert_from_path 114 | except ImportError as exc: 115 | raise MissingDependencyError( 116 | "Algorithm 'pdf2image' requires the 'pdf2image' package and Poppler." 117 | ) from exc 118 | 119 | dpi = int(settings.get("dpi", 200) or 200) 120 | thread_count = settings.get("thread_count") 121 | 122 | images = convert_from_path( 123 | pdf_path.as_posix(), 124 | dpi=dpi, 125 | thread_count=int(thread_count) if thread_count else None, 126 | ) 127 | 128 | saved_paths: List[Path] = [] 129 | for index, image in enumerate(images, start=1): 130 | output_path = output_dir / f"{pdf_path.stem}_page_{index:04d}.jpg" 131 | image.save( 132 | output_path, 133 | format="JPEG", 134 | quality=jpeg_quality, 135 | optimize=True, 136 | ) 137 | saved_paths.append(output_path) 138 | logging.debug( 139 | "Rendered page %s with pdf2image at %sdpi to %s", 140 | index, 141 | dpi, 142 | output_path, 143 | ) 144 | 145 | return saved_paths 146 | 147 | 148 | ALGORITHM_REGISTRY: Dict[str, ConversionAlgorithm] = { 149 | "pymupdf": _convert_with_pymupdf, 150 | "pdf2image": _convert_with_pdf2image, 151 | } 152 | 153 | 154 | def _build_algorithm_definitions() -> Dict[str, AlgorithmDefinition]: 155 | definitions: Dict[str, AlgorithmDefinition] = {} 156 | for name, executor in ALGORITHM_REGISTRY.items(): 157 | settings = CONFIG["algorithm_settings"].get(name, {}) 158 | definitions[name] = AlgorithmDefinition(name=name, executor=executor, settings=settings) 159 | return definitions 160 | 161 | 162 | def _parse_arguments(available_algorithms: Iterable[str]) -> argparse.Namespace: 163 | parser = argparse.ArgumentParser( 164 | description="Split a PDF into JPEG pages using the configured conversion backend.", 165 | ) 166 | parser.add_argument("pdf_path", type=Path, help="Path to the PDF file to split") 167 | parser.add_argument( 168 | "--algorithm", 169 | choices=sorted(available_algorithms), 170 | default=CONFIG["default_algorithm"], 171 | help="Conversion backend to use (default: %(default)s)", 172 | ) 173 | parser.add_argument( 174 | "--quality", 175 | type=int, 176 | default=CONFIG["default_quality"], 177 | help="JPEG quality between 1-100 (default: %(default)s)", 178 | ) 179 | parser.add_argument( 180 | "--output-dir", 181 | type=Path, 182 | default=None, 183 | help="Optional output directory (defaults to <pdf_stem>_pages alongside the PDF)", 184 | ) 185 | 186 | args = parser.parse_args() 187 | 188 | if not 1 <= args.quality <= 100: 189 | parser.error("quality must be between 1 and 100") 190 | 191 | return args 192 | 193 | 194 | def _prepare_output_dir(pdf_path: Path, explicit_output_dir: Path | 
None) -> Path: 195 | if explicit_output_dir is not None: 196 | output_dir = explicit_output_dir 197 | else: 198 | output_dir = pdf_path.with_name(f"{pdf_path.stem}_pages") 199 | 200 | output_dir.mkdir(parents=True, exist_ok=True) 201 | return output_dir 202 | 203 | 204 | def _ensure_pdf_exists(pdf_path: Path) -> None: 205 | if not pdf_path.exists(): 206 | raise FileNotFoundError(f"PDF not found: {pdf_path}") 207 | if not pdf_path.is_file(): 208 | raise ValueError(f"Expected a file path, got: {pdf_path}") 209 | 210 | 211 | def main() -> int: 212 | logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") 213 | 214 | algorithm_definitions = _build_algorithm_definitions() 215 | args = _parse_arguments(algorithm_definitions.keys()) 216 | 217 | pdf_path: Path = args.pdf_path.resolve() 218 | _ensure_pdf_exists(pdf_path) 219 | 220 | output_dir = _prepare_output_dir(pdf_path, args.output_dir) 221 | logging.info("Saving pages to %s", output_dir) 222 | 223 | definition = algorithm_definitions[args.algorithm] 224 | 225 | try: 226 | saved_paths = definition.executor( 227 | pdf_path, 228 | output_dir, 229 | jpeg_quality=args.quality, 230 | settings=definition.settings, 231 | ) 232 | except MissingDependencyError as exc: 233 | logging.error("%s", exc) 234 | return 2 235 | except ConversionError as exc: 236 | logging.error("Conversion failed: %s", exc) 237 | return 3 238 | except Exception as exc: # pragma: no cover - unexpected failure surface 239 | logging.exception("Unexpected error during conversion") 240 | return 4 241 | 242 | logging.info("Wrote %s pages", len(saved_paths)) 243 | for path in saved_paths: 244 | logging.debug("Created %s", path) 245 | 246 | return 0 247 | 248 | 249 | if __name__ == "__main__": 250 | sys.exit(main()) 251 | 252 | -------------------------------------------------------------------------------- /src/finetuning/convert_qwen3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import glob 4 | import json 5 | import shutil 6 | import sys 7 | from pathlib import Path 8 | from typing import Tuple 9 | 10 | import mlx.core as mx 11 | import mlx.nn as nn 12 | from mlx.utils import tree_flatten 13 | 14 | # Add project root to path to find mlx_lm and src.copies 15 | project_root = Path(__file__).resolve().parents[2] 16 | sys.path.insert(0, str(project_root)) 17 | 18 | # Add mlx-lm path specifically for its imports 19 | mlx_lm_path = project_root / "mlx-lm" 20 | sys.path.insert(1, str(mlx_lm_path)) 21 | 22 | # Import necessary functions from original mlx-lm 23 | from mlx_lm.utils import ( 24 | fetch_from_hub, 25 | get_model_path, 26 | make_shards, 27 | save_config, 28 | save_weights, 29 | upload_to_hub, 30 | get_model_path 31 | ) 32 | from mlx_lm.tokenizer_utils import load_tokenizer 33 | 34 | # Import our *custom* Qwen3 model definition 35 | # Need to ensure src/copies/models has an __init__.py if running as script directly 36 | # Or rely on PYTHONPATH containing project root 37 | try: 38 | from src.copies.models.qwen3 import Model, ModelArgs 39 | except ModuleNotFoundError: 40 | print("Error: Could not import custom Qwen3 model from src.copies.models.qwen3") 41 | print("Ensure the project root is in PYTHONPATH or run this script as a module.") 42 | sys.exit(1) 43 | 44 | 45 | def get_custom_model_classes(config: dict) -> Tuple[nn.Module, ModelArgs]: 46 | """Returns our custom Qwen3 model classes if model_type matches.""" 47 | model_type = config.get("model_type") 48 | if model_type == 
"qwen3": 49 | return Model, ModelArgs 50 | else: 51 | # Fallback or error for other model types if needed 52 | raise ValueError(f"This custom converter only supports model_type 'qwen3', got {model_type}") 53 | 54 | 55 | def custom_load_model( 56 | model_path: Path, 57 | lazy: bool = False, 58 | strict: bool = True, 59 | ) -> nn.Module: 60 | """Loads model using the custom qwen3 class. 61 | 62 | Replicates logic from mlx_lm.utils.load_model but uses get_custom_model_classes. 63 | """ 64 | try: 65 | with open(model_path / "config.json", "r") as f: 66 | config = json.load(f) 67 | except FileNotFoundError: 68 | print(f"Config file not found in {model_path}") 69 | # Fallback for older HF format 70 | try: 71 | with open(model_path / "params.json", "r") as f: 72 | config = json.load(f) 73 | config["model_type"] = config["architectures"][0].lower() 74 | except FileNotFoundError: 75 | raise FileNotFoundError( 76 | f"Neither config.json nor params.json found in {model_path}" 77 | ) 78 | 79 | weight_files = glob.glob(str(model_path / "*.safetensors")) 80 | if not weight_files: 81 | # Fallback for older pytorch format 82 | weight_files = glob.glob(str(model_path / "*.bin")) 83 | if not weight_files: 84 | raise FileNotFoundError(f"No weights found in {model_path}") 85 | 86 | weights = {} 87 | for wf in weight_files: 88 | weights.update(mx.load(wf)) 89 | 90 | # --- Debug: Print initial dtypes --- 91 | print("\n--- Initial loaded weight dtypes (sample) ---") 92 | for k, v in list(weights.items())[:5]: # Print first 5 93 | print(f"{k}: {v.dtype}") 94 | if len(weights) > 5: 95 | for k, v in list(weights.items())[-5:]: # Print last 5 96 | print(f"{k}: {v.dtype}") 97 | print("---------------------------------------------") 98 | # ------------------------------------ 99 | 100 | # Use our custom class getter 101 | model_class, model_args_class = get_custom_model_classes(config=config) 102 | 103 | model_args = model_args_class.from_dict(config) 104 | model = model_class(model_args) 105 | 106 | # The sanitize method is part of our custom Model class 107 | if hasattr(model, "sanitize"): 108 | print("Sanitizing weights using custom model...") 109 | weights = model.sanitize(weights) 110 | 111 | # Quantization is not expected/handled for FP8 conversion here 112 | if (quantization := config.get("quantization", None)) is not None: 113 | print("Warning: Found quantization config, but this custom script doesn't apply it.") 114 | 115 | model.load_weights(list(weights.items()), strict=strict) 116 | 117 | if not lazy: 118 | mx.eval(model.parameters()) 119 | 120 | model.eval() 121 | return model, config 122 | 123 | 124 | def custom_convert( 125 | hf_path: str, 126 | mlx_path: str = "mlx_model", 127 | dtype: str = "float16", 128 | upload_repo: str = None, 129 | ): 130 | """Converts HF Qwen3 model using the custom sanitize logic.""" 131 | print(f"[INFO] Loading model from HF path: {hf_path}") 132 | model_path = get_model_path(hf_path) 133 | mlx_path = Path(mlx_path) 134 | mlx_path.mkdir(parents=True, exist_ok=True) 135 | 136 | # Load model using our custom loader which calls the custom sanitize 137 | model, config = custom_load_model(model_path, lazy=True) 138 | 139 | # Load tokenizer using standard mlx-lm function 140 | tokenizer = load_tokenizer(model_path) 141 | 142 | weights = dict(tree_flatten(model.parameters())) 143 | 144 | # --- Debug: Print dtypes before astype conversion --- 145 | print("\n--- Weight dtypes before astype() (sample) ---") 146 | param_list = list(weights.items()) 147 | for k, v in param_list[:5]: # 
Print first 5 148 | print(f"{k}: {v.dtype}") 149 | if len(param_list) > 5: 150 | for k, v in param_list[-5:]: # Print last 5 151 | print(f"{k}: {v.dtype}") 152 | print("---------------------------------------------") 153 | # -------------------------------------------------- 154 | 155 | dtype = getattr(mx, dtype) 156 | print(f"[INFO] Casting weights to target dtype: {dtype}") 157 | weights = {k: v.astype(dtype) for k, v in weights.items()} 158 | 159 | print("[INFO] Saving weights") 160 | save_weights(mlx_path, weights) 161 | 162 | # Save tokenizer 163 | shutil.copyfile( 164 | str(model_path / "tokenizer.json"), str(mlx_path / "tokenizer.json") 165 | ) 166 | if (model_path / "vocab.json").is_file(): # Qwen specific? 167 | shutil.copyfile( 168 | str(model_path / "vocab.json"), str(mlx_path / "vocab.json") 169 | ) 170 | if (model_path / "merges.txt").is_file(): # Qwen specific? 171 | shutil.copyfile( 172 | str(model_path / "merges.txt"), str(mlx_path / "merges.txt") 173 | ) 174 | if (model_path / "tokenizer_config.json").is_file(): 175 | shutil.copyfile( 176 | str(model_path / "tokenizer_config.json"), 177 | str(mlx_path / "tokenizer_config.json"), 178 | ) 179 | 180 | # Save config -- make sure dtype is updated if changed 181 | config["mlx_lm"] = {"dtype": str(dtype).split(".")[-1]} 182 | save_config(config, config_path=mlx_path / "config.json") 183 | 184 | if upload_repo: 185 | upload_to_hub(mlx_path, upload_repo, hf_path) 186 | 187 | print(f"[INFO] Conversion complete. Model saved to: {mlx_path}") 188 | 189 | 190 | if __name__ == "__main__": 191 | parser = argparse.ArgumentParser( 192 | description="Convert Qwen3 HF weights to MLX format using custom sanitize." 193 | ) 194 | parser.add_argument( 195 | "--hf-path", 196 | type=str, 197 | required=True, 198 | help="Path to the Hugging Face model directory or repo ID (e.g., Qwen/Qwen3-14B).", 199 | ) 200 | parser.add_argument( 201 | "--mlx-path", 202 | type=str, 203 | required=True, 204 | help="Path to save the MLX model (e.g., mlx_models/Qwen3-14B-mlx). This path should NOT exist.", 205 | ) 206 | parser.add_argument( 207 | "--dtype", 208 | help="Type to save the parameters, ignored for quantized models. Options: float16, bfloat16, float32.", 209 | type=str, 210 | default="float16", # Keep original precision if possible, but HF FP8 might load as float16/32 initially 211 | ) 212 | parser.add_argument( 213 | "--upload-repo", 214 | help="The Hugging Face repo to upload the model to.", 215 | type=str, 216 | default=None, 217 | ) 218 | 219 | args = parser.parse_args() 220 | 221 | # Basic check for mlx_path existence (convert expects it not to exist) 222 | if Path(args.mlx_path).exists(): 223 | print(f"Error: Output path {args.mlx_path} already exists. Please remove it first.") 224 | sys.exit(1) 225 | 226 | custom_convert(**vars(args)) -------------------------------------------------------------------------------- /src/data_processing/prepare_training_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from pathlib import Path 4 | import random 5 | 6 | # Define the prompt template directly 7 | PROMPT_TEMPLATE = '''This is an excerpt from a novel. Write the next excerpt of similar length. Use the same style as the excerpt. Make sure that while stylistically similar, the new section moves the story forward and/or develops the characters and/or adds new information or in some way continues on meaningfully from the previous section. 
8 | 9 | EXCERPT: 10 | {}''' 11 | 12 | def save_to_jsonl(data: list, file_path: Path): 13 | """Saves a list of strings to a JSONL file, with each string under the 'text' key.""" 14 | with file_path.open("w", encoding="utf-8") as f: 15 | for item in data: 16 | # Ensure the output is a valid JSON line 17 | f.write(json.dumps({"text": item}) + '\n') 18 | 19 | def load_chunks_from_file(file_path: Path) -> list: 20 | """Loads chunks from a single JSON file.""" 21 | if not file_path.is_file(): 22 | print(f"Warning: Input file not found at {file_path}, skipping.") 23 | return [] 24 | try: 25 | with file_path.open("r", encoding="utf-8") as f: 26 | data = json.load(f) 27 | if "chunks" not in data or not isinstance(data["chunks"], list): 28 | print(f"Warning: Input file {file_path} does not contain a list under the 'chunks' key, skipping.") 29 | return [] 30 | return data["chunks"] 31 | except json.JSONDecodeError: 32 | print(f"Warning: Could not decode JSON from {file_path}, skipping.") 33 | return [] 34 | except Exception as e: 35 | print(f"Warning: An error occurred while reading {file_path}: {e}, skipping.") 36 | return [] 37 | 38 | 39 | def create_pairs_from_chunks(chunks: list) -> list: 40 | """Creates prompt/completion pairs from a list of chunks.""" 41 | paired_texts = [] 42 | if len(chunks) < 2: 43 | return [] # Not enough chunks to form pairs 44 | # Iterate up to the second-to-last chunk to form pairs 45 | for i in range(len(chunks) - 1): 46 | prompt_part = PROMPT_TEMPLATE.format(chunks[i]) 47 | completion_part = chunks[i+1] 48 | combined_text = f"{prompt_part}\nNEXT EXCERPT:\n{completion_part}" 49 | paired_texts.append(combined_text) 50 | return paired_texts 51 | 52 | if __name__ == "__main__": 53 | DEFAULT_INPUT_FILES = ["semantic_chunks.json"] # Default is now a list 54 | DEFAULT_OUTPUT_DIR = "." 55 | DEFAULT_TRAIN_RATIO = 0.85 56 | DEFAULT_SEED = 42 57 | 58 | parser = argparse.ArgumentParser( 59 | description="Prepare training and validation data from semantic chunks in one or more input files." 60 | ) 61 | parser.add_argument( 62 | "--input_files", 63 | type=str, 64 | nargs='+', # Accept one or more arguments 65 | default=DEFAULT_INPUT_FILES, 66 | help=f"Path(s) to the input JSON file(s) containing semantic chunks (default: {DEFAULT_INPUT_FILES[0]})." 67 | ) 68 | parser.add_argument( 69 | "--output_dir", 70 | type=str, 71 | default=DEFAULT_OUTPUT_DIR, 72 | help=f"Directory to save the train.jsonl and valid.jsonl files (default: {DEFAULT_OUTPUT_DIR})." 73 | ) 74 | parser.add_argument( 75 | "--train_ratio", 76 | type=float, 77 | default=DEFAULT_TRAIN_RATIO, 78 | help=f"Proportion of data to use for the training set (default: {DEFAULT_TRAIN_RATIO})." 79 | ) 80 | parser.add_argument( 81 | "--seed", 82 | type=int, 83 | default=DEFAULT_SEED, 84 | help=f"Random seed for shuffling data (default: {DEFAULT_SEED})." 
85 | ) 86 | args = parser.parse_args() 87 | 88 | output_dir = Path(args.output_dir) 89 | output_dir.mkdir(parents=True, exist_ok=True) # Ensure output directory exists 90 | 91 | all_chunks_data = [] 92 | max_chunks_len = 0 93 | file_with_max_chunks = None 94 | 95 | print("Loading chunks from input files...") 96 | for file_path_str in args.input_files: 97 | file_path = Path(file_path_str) 98 | print(f" Processing: {file_path}") 99 | chunks = load_chunks_from_file(file_path) 100 | if chunks: 101 | all_chunks_data.append({"path": file_path, "chunks": chunks}) 102 | if len(chunks) > max_chunks_len: 103 | max_chunks_len = len(chunks) 104 | file_with_max_chunks = file_path 105 | else: 106 | print(f" No valid chunks loaded from {file_path}.") 107 | 108 | 109 | if not all_chunks_data: 110 | print("Error: No valid chunks loaded from any input file.") 111 | exit(1) 112 | 113 | if max_chunks_len < 2: 114 | print("Error: The file with the most chunks has fewer than 2 chunks. Cannot create pairs.") 115 | exit(1) 116 | 117 | print(f"\nFile determining shuffle/split order (max {max_chunks_len} chunks): {file_with_max_chunks}") 118 | max_pairs_len = max_chunks_len - 1 # Number of pairs is one less than chunks 119 | 120 | # Create shuffled indices based on the file with the maximum number of pairs 121 | print(f"Creating shuffle order based on {max_pairs_len} potential pairs...") 122 | indices = list(range(max_pairs_len)) 123 | random.seed(args.seed) 124 | random.shuffle(indices) 125 | 126 | # Determine split point based on the max number of pairs 127 | split_index = int(max_pairs_len * args.train_ratio) 128 | train_indices_full = set(indices[:split_index]) 129 | valid_indices_full = set(indices[split_index:]) 130 | 131 | print(f"Determined split: {len(train_indices_full)} train indices, {len(valid_indices_full)} valid indices (based on max pairs).") 132 | 133 | combined_train_data = [] 134 | combined_valid_data = [] 135 | 136 | print("\nProcessing pairs and splitting for each file...") 137 | for file_data in all_chunks_data: 138 | file_path = file_data["path"] 139 | chunks = file_data["chunks"] 140 | print(f" Processing pairs from: {file_path} ({len(chunks)} chunks)") 141 | 142 | paired_texts = create_pairs_from_chunks(chunks) 143 | num_pairs_in_file = len(paired_texts) 144 | 145 | if num_pairs_in_file == 0: 146 | print(f" No pairs created for {file_path}, skipping.") 147 | continue 148 | 149 | print(f" Created {num_pairs_in_file} pairs.") 150 | 151 | file_train_data = [] 152 | file_valid_data = [] 153 | 154 | # Use the pre-calculated shuffled indices, but only those that fall within this file's number of pairs. 155 | # Every file shares the same globally shuffled index list, so a given pair index always lands in the same 156 | # split (train or valid) regardless of which file it comes from; indices beyond this file's pair count 157 | # are simply skipped. 158 | 159 | current_train_count = 0 160 | current_valid_count = 0 161 | # Iterate through the globally determined shuffled indices 162 | for idx_in_shuffled_list, original_pair_index in enumerate(indices): 163 | # Check if this original_pair_index is valid for the current file's pair list 164 | if original_pair_index < num_pairs_in_file: 165 | pair = paired_texts[original_pair_index] 166 | # Determine if this index belongs to the train or validation set based on the split point of the *shuffled* list 167 | if idx_in_shuffled_list < split_index: # Check position in the shuffled list 168 | file_train_data.append(pair) 169 | current_train_count += 1 170 | else: 171 | file_valid_data.append(pair) 172 | current_valid_count += 1 173 | 174 | 175 | print(f" Added {current_train_count} pairs to train set, {current_valid_count} pairs to valid set.") 176 | combined_train_data.extend(file_train_data) 177 | combined_valid_data.extend(file_valid_data) 178 | 179 | 180 | print(f"\nTotal training samples: {len(combined_train_data)}") 181 | print(f"Total validation samples: {len(combined_valid_data)}") 182 | 183 | # Optionally shuffle the combined sets again; helpful if downstream consumers are sensitive to sample order. 184 | # random.shuffle(combined_train_data) 185 | # random.shuffle(combined_valid_data) 186 | 187 | train_output_path = output_dir / "train.jsonl" 188 | valid_output_path = output_dir / "valid.jsonl" 189 | 190 | print(f"\nSaving combined training data to: {train_output_path}") 191 | save_to_jsonl(combined_train_data, train_output_path) 192 | 193 | print(f"Saving combined validation data to: {valid_output_path}") 194 | save_to_jsonl(combined_valid_data, valid_output_path) 195 | 196 | print("\nDone.") 197 | 198 | """ 199 | Example Usage (Multiple Files): 200 | python src/data_processing/prepare_training_data.py \ 201 | --input_files semantic_chunks_part1.json semantic_chunks_part2.json \ 202 | --output_dir DATA/processed_novel_combined \ 203 | --train_ratio 0.9 \ 204 | --seed 123 205 | """ --------------------------------------------------------------------------------
/src/inference/generate_qwen3.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Usage: 4 | 5 | # WITH ADAPTER 6 | cat temp_prompt.txt | python src/inference/generate_qwen3.py \ 7 | --model-path mlx_models/Qwen3-4B-mlx \ 8 | --adapter-path ADAPTERS/qwen3_4b_lora_sacredhunger \ 9 | --prompt "-" \ 10 | --repetition-penalty 1.1 \ 11 | --temp 0.75 \ 12 | --top-p 0.95 13 | 14 | # WITHOUT ADAPTER 15 | cat temp_prompt.txt | python src/inference/generate_qwen3.py \ 16 | --model-path mlx_models/Qwen3-4B-mlx \ 17 | --prompt "-" \ 18 | --repetition-penalty 1.1 \ 19 | --temp 0.75 20 | """ 21 | 22 | import argparse 23 | import sys 24 | import time 25 | from pathlib import Path 26 | 27 | import mlx.core as mx 28 | 29 | # Add mlx-lm to the Python path 30 | # Assumes the script is run from the project root or src/inference 31 | project_root = Path(__file__).resolve().parents[2] 32 | mlx_lm_path = project_root / "mlx-lm" 33 | 34 | if str(mlx_lm_path) not in sys.path: 35 | sys.path.insert(0, str(mlx_lm_path)) 36 | 37 | try: 38 | from mlx_lm.utils import load 39 | from mlx_lm import generate 40 | from mlx_lm.sample_utils import make_sampler, make_logits_processors 41 | except ImportError as e: 42 | print("Error 
importing mlx_lm. Please ensure mlx-lm is installed and discoverable.") 43 | print(f"Attempted path: {mlx_lm_path}") 44 | print(f"PYTHONPATH: {sys.path}") 45 | print(f"ImportError: {e}") 46 | sys.exit(1) 47 | 48 | 49 | def main(args): 50 | mx.random.seed(args.seed) 51 | 52 | model_path = Path(args.model_path) 53 | if not model_path.exists(): 54 | print(f"Error: Model path does not exist: {model_path}") 55 | print("Make sure you have downloaded and converted the model using a script like") 56 | print("src/finetuning/download_qwen3.py") 57 | sys.exit(1) 58 | 59 | adapter_path = args.adapter_path 60 | if adapter_path: 61 | adapter_path = Path(adapter_path) 62 | if not adapter_path.exists(): 63 | print(f"Error: Adapter path does not exist: {adapter_path}") 64 | sys.exit(1) 65 | print(f"Loading model from {model_path} with adapter from {adapter_path}...") 66 | else: 67 | print(f"Loading model from {model_path}...") 68 | 69 | try: 70 | model, tokenizer = load(args.model_path, adapter_path=str(adapter_path) if adapter_path else None) 71 | except Exception as e: 72 | print(f"Error loading the model or adapter: {e}") 73 | sys.exit(1) 74 | 75 | print("Model loaded.") 76 | 77 | # Handle prompt input (stdin or argument) 78 | if args.prompt == "-": 79 | print("Reading prompt from stdin...") 80 | try: 81 | prompt_input = sys.stdin.read() 82 | except EOFError: 83 | print("Error: Reached end of input while reading from stdin.") 84 | sys.exit(1) 85 | if not prompt_input: 86 | print("Error: Received empty prompt from stdin.") 87 | sys.exit(1) 88 | else: 89 | # Replace escaped newlines/tabs if coming from command line 90 | prompt_input = args.prompt.replace("\\n", "\n").replace("\\t", "\t") 91 | 92 | # Prepare the prompt (apply chat template if requested and available) 93 | if args.use_chat_template and hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None: 94 | print("Applying chat template...") 95 | messages = [] 96 | if args.system_prompt: 97 | messages.append({"role": "system", "content": args.system_prompt}) 98 | messages.append({"role": "user", "content": prompt_input}) 99 | 100 | try: 101 | prompt_str = tokenizer.apply_chat_template( 102 | messages, 103 | tokenize=False, 104 | add_generation_prompt=True 105 | ) 106 | # Encode the templated prompt, template likely includes BOS/EOS handling 107 | encoded_prompt = tokenizer.encode(prompt_str, add_special_tokens=False) 108 | print(f"Using templated prompt: {prompt_str[:200]}...") # Print start of templated prompt 109 | except Exception as e: 110 | print(f"Error applying chat template: {e}") 111 | print("Falling back to raw prompt encoding.") 112 | # Fallback if template application fails 113 | encoded_prompt = tokenizer.encode(prompt_input, add_special_tokens=True) # Use original input 114 | 115 | else: 116 | if not args.use_chat_template: 117 | print("Chat template disabled by user.") 118 | elif not hasattr(tokenizer, 'chat_template') or tokenizer.chat_template is None: 119 | print("No chat template found in tokenizer.") 120 | print("Encoding raw prompt...") 121 | # Encode raw prompt, assuming default BOS/EOS handling is desired 122 | encoded_prompt = tokenizer.encode(prompt_input, add_special_tokens=True) # Use original input 123 | 124 | encoded_prompt = mx.array(encoded_prompt) 125 | 126 | print("Generating response...") 127 | start_time = time.time() 128 | 129 | # Create the sampler instance 130 | sampler = make_sampler(temp=args.temp, top_p=args.top_p) 131 | 132 | # Create the logits processors list 133 | logits_processors = 
make_logits_processors( 134 | repetition_penalty=args.repetition_penalty, 135 | repetition_context_size=args.repetition_context_size 136 | ) 137 | 138 | # Pass sampler and processors instead of individual args 139 | response = generate( 140 | model, 141 | tokenizer, 142 | prompt=encoded_prompt, # Use the processed prompt 143 | max_tokens=args.max_tokens, 144 | sampler=sampler, 145 | logits_processors=logits_processors, 146 | verbose=args.verbose 147 | ) 148 | 149 | # If verbose=False, generate returns the full string, otherwise it prints token by token and returns None 150 | if not args.verbose: 151 | print(response) 152 | 153 | end_time = time.time() 154 | elapsed_time = end_time - start_time 155 | 156 | print(f"\n\nGeneration complete in {elapsed_time:.2f} seconds.") 157 | 158 | 159 | if __name__ == "__main__": 160 | parser = argparse.ArgumentParser(description="Run inference with a converted MLX model, optionally using LoRA adapters.") 161 | 162 | # --- Model and Prompt Arguments --- 163 | parser.add_argument( 164 | "--model-path", 165 | type=str, 166 | required=True, 167 | help="Path to the directory containing the converted MLX model files (e.g., <project_root>/mlx_models/Qwen3-14B-mlx)." 168 | ) 169 | parser.add_argument( 170 | "--adapter-path", 171 | type=str, 172 | default=None, 173 | help="Optional path to the directory containing the trained LoRA adapter weights (`adapters.safetensors`) and config (`adapter_config.json`)." 174 | ) 175 | parser.add_argument( 176 | "--prompt", 177 | type=str, 178 | required=True, 179 | help="The input prompt for the model, or '-' to read from stdin." 180 | ) 181 | parser.add_argument( 182 | "--use-chat-template", 183 | action="store_true", 184 | default=True, 185 | help="Use the tokenizer's chat template if available. Set --no-use-chat-template to disable." 186 | ) 187 | parser.add_argument( 188 | "--no-use-chat-template", 189 | action="store_false", 190 | dest="use_chat_template", 191 | help="Disable the use of the tokenizer's chat template." 192 | ) 193 | parser.add_argument( 194 | "--system-prompt", 195 | type=str, 196 | default=None, 197 | help="Optional system prompt to prepend when using the chat template." 198 | ) 199 | 200 | # --- Generation Control Arguments --- 201 | parser.add_argument( 202 | "--max-tokens", 203 | type=int, 204 | default=3200, #32768, 205 | help="Maximum number of tokens to generate. [Default: 3200]" 206 | ) 207 | parser.add_argument( 208 | "--temp", 209 | type=float, 210 | default=0.6, 211 | help=( 212 | "Sampling temperature. Controls randomness. Lower values (e.g., 0.1) make the output " 213 | "more deterministic, higher values (e.g., 1.0) make it more random. [Default: 0.6]" 214 | ) 215 | ) 216 | parser.add_argument( 217 | "--top-p", 218 | type=float, 219 | default=1.0, 220 | help=( 221 | "Top-p (nucleus) sampling probability. Selects tokens from the smallest set whose cumulative " 222 | "probability exceeds top_p. A value of 1.0 considers all tokens. Lower values (e.g., 0.9) " 223 | "restrict sampling to more likely tokens. [Default: 1.0]" 224 | ) 225 | ) 226 | parser.add_argument( 227 | "--repetition-penalty", 228 | type=float, 229 | default=None, 230 | help=( 231 | "Penalty applied to repeated tokens. Values > 1.0 discourage repetition. " 232 | "A value of 1.0 means no penalty. [Default: None (no penalty)]" 233 | ) 234 | ) 235 | parser.add_argument( 236 | "--repetition-context-size", 237 | type=int, 238 | default=20, 239 | help=( 240 | "The number of previous tokens to consider for the repetition penalty. 
" 241 | "[Default: 20]" 242 | ) 243 | ) 244 | parser.add_argument( 245 | "--seed", 246 | type=int, 247 | default=0, 248 | help="Seed for the random number generator. [Default: 0]" 249 | ) 250 | parser.add_argument( 251 | "--verbose", 252 | action="store_true", 253 | help="Stream the generated text token by token instead of printing the full output at the end." 254 | ) 255 | 256 | args = parser.parse_args() 257 | main(args) 258 | 259 | # --- Example Test Cases --- 260 | # (Replace paths as needed) 261 | # 262 | # 1. Basic generation (no adapter): 263 | # python src/inference/generate_qwen3.py \ 264 | # --model-path mlx_models/Qwen3-4B-mlx \ 265 | # --prompt "Tell me a short story about a brave knight." 266 | # 267 | # 2. Generation with an adapter: 268 | # python src/inference/generate_qwen3.py \ 269 | # --model-path mlx_models/Qwen3-4B-mlx \ 270 | # --adapter-path ADAPTERS/qwen3_4b_lora_sacredhunger \ 271 | # --prompt "Write a paragraph in the style of the finetuning data." 272 | # 273 | # 2b. Generation with adapter, reading long prompt from stdin: 274 | # cat prompt.txt | python src/inference/generate_qwen3.py \\ 275 | # --model-path mlx_models/Qwen3-4B-mlx \\ 276 | # --adapter-path ADAPTERS/qwen3_4b_lora_sacredhunger \\ 277 | # --prompt "-" 278 | # (Where prompt.txt contains your long prompt) 279 | # 280 | # 2c. Generation with adapter, using a 'here document' for prompt: 281 | # python src/inference/generate_qwen3.py \\ 282 | # --model-path mlx_models/Qwen3-4B-mlx \\ 283 | # --adapter-path ADAPTERS/qwen3_4b_lora_sacredhunger \\ 284 | # --prompt "-" <<EOF 285 | # This is a very long prompt 286 | # that spans multiple lines. 287 | # The shell will pass this text to the script's stdin. 288 | # EOF 289 | # 290 | # 3. More creative output (higher temperature): 291 | # python src/inference/generate_qwen3.py \ 292 | # --model-path mlx_models/Qwen3-4B-mlx \ 293 | # --prompt "What if the moon was made of cheese?" \ 294 | # --temp 1.0 \ 295 | # --max-tokens 150 296 | # 297 | # 4. More focused output (lower temperature, top-p): 298 | # python src/inference/generate_qwen3.py \ 299 | # --model-path mlx_models/Qwen3-4B-mlx \ 300 | # --prompt "Explain the concept of photosynthesis in simple terms." \ 301 | # --temp 0.3 \ 302 | # --top-p 0.9 \ 303 | # --max-tokens 200 304 | # 305 | # 5. Discourage repetition: 306 | # python src/inference/generate_qwen3.py \ 307 | # --model-path mlx_models/Qwen3-4B-mlx \ 308 | # --prompt "List the planets in our solar system, starting from the sun." \ 309 | # --repetition-penalty 1.2 \ 310 | # --max-tokens 50 311 | # 312 | # 6. Stream output token by token (with adapter): 313 | # python src/inference/generate_qwen3.py \ 314 | # --model-path mlx_models/Qwen3-4B-mlx \ 315 | # --adapter-path ADAPTERS/qwen3_4b_lora_sacredhunger \ 316 | # --prompt "Write a haiku about a cat." \ 317 | # --verbose -------------------------------------------------------------------------------- /src/data_processing/convert_hf_dataset_format.py: -------------------------------------------------------------------------------- 1 | """ 2 | convert_hf_dataset_format.py 3 | 4 | Converts Hugging Face datasets from the "conversations" format to the standard "messages" format. 5 | This script handles the conversion between different chat dataset formats, following Hugging Face 6 | best practices for dataset transformation. 
7 | 8 | Source format: 9 | - "conversations" field with "from" and "value" keys 10 | - Roles: "human", "gpt", "system" 11 | 12 | Target format: 13 | - "messages" field with "role" and "content" keys 14 | - Roles: "user", "assistant", "system" 15 | - Additional fields: "id", "source" 16 | """ 17 | 18 | import json 19 | import argparse 20 | from pathlib import Path 21 | from typing import Dict, List, Any, Optional, Union 22 | from datasets import Dataset, load_dataset, DatasetDict, IterableDataset, IterableDatasetDict 23 | import logging 24 | 25 | # Configure logging 26 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 27 | logger = logging.getLogger(__name__) 28 | 29 | # Role mapping from source to target format 30 | ROLE_MAPPING = { 31 | "human": "user", 32 | "gpt": "assistant", 33 | "system": "system" 34 | } 35 | 36 | def convert_conversation_format(example: Dict[str, Any], sample_number: int, dataset_name: str, source_name: str) -> Dict[str, Any]: 37 | """ 38 | Converts a single example from conversations format to messages format. 39 | 40 | Args: 41 | example: Dictionary containing the conversation data 42 | sample_number: The row number for generating the ID 43 | dataset_name: Name to use for ID generation 44 | source_name: Source name for the dataset 45 | 46 | Returns: 47 | Dictionary in the target format with messages, id, and source fields 48 | """ 49 | if "conversations" not in example: 50 | logger.warning(f"Sample {sample_number}: No 'conversations' field found") 51 | return None 52 | 53 | conversations = example["conversations"] 54 | if not isinstance(conversations, list): 55 | logger.warning(f"Sample {sample_number}: 'conversations' is not a list") 56 | return None 57 | 58 | messages = [] 59 | for turn in conversations: 60 | if not isinstance(turn, dict) or "from" not in turn or "value" not in turn: 61 | logger.warning(f"Sample {sample_number}: Invalid conversation turn format") 62 | continue 63 | 64 | role = turn["from"] 65 | content = turn["value"] 66 | 67 | # Skip system messages as specified in requirements 68 | if role == "system": 69 | continue 70 | 71 | # Map roles according to the specification 72 | if role in ROLE_MAPPING: 73 | mapped_role = ROLE_MAPPING[role] 74 | messages.append({ 75 | "role": mapped_role, 76 | "content": content 77 | }) 78 | else: 79 | logger.warning(f"Sample {sample_number}: Unknown role '{role}', skipping turn") 80 | 81 | # Only return valid examples with at least one message 82 | if not messages: 83 | logger.warning(f"Sample {sample_number}: No valid messages after conversion") 84 | return None 85 | 86 | return { 87 | "messages": messages, 88 | "id": f"{dataset_name}_{sample_number}", 89 | "source": source_name 90 | } 91 | 92 | def convert_dataset_batch(batch: Dict[str, List[Any]], start_idx: int, dataset_name: str, source_name: str) -> Dict[str, List[Any]]: 93 | """ 94 | Converts a batch of examples for efficient processing. 
95 | 96 | Args: 97 | batch: Dictionary containing batched data 98 | start_idx: Starting index for ID generation 99 | dataset_name: Name to use for ID generation 100 | source_name: Source name for the dataset 101 | 102 | Returns: 103 | Dictionary containing converted batch data 104 | """ 105 | converted_messages = [] 106 | converted_ids = [] 107 | converted_sources = [] 108 | 109 | batch_size = len(batch["conversations"]) 110 | 111 | for i in range(batch_size): 112 | sample_number = start_idx + i 113 | example = {key: batch[key][i] for key in batch.keys()} 114 | 115 | converted = convert_conversation_format(example, sample_number, dataset_name, source_name) 116 | 117 | if converted is not None: 118 | converted_messages.append(converted["messages"]) 119 | converted_ids.append(converted["id"]) 120 | converted_sources.append(converted["source"]) 121 | 122 | return { 123 | "messages": converted_messages, 124 | "id": converted_ids, 125 | "source": converted_sources 126 | } 127 | 128 | def convert_single_example(example: Dict[str, Any], idx: int, dataset_name: str, source_name: str) -> Dict[str, Any]: 129 | """ 130 | Converts a single example (for streaming datasets). 131 | 132 | Args: 133 | example: Dictionary containing the conversation data 134 | idx: The sample index for generating the ID 135 | dataset_name: Name to use for ID generation 136 | source_name: Source name for the dataset 137 | 138 | Returns: 139 | Dictionary in the target format with messages, id, and source fields 140 | """ 141 | return convert_conversation_format(example, idx, dataset_name, source_name) 142 | 143 | def extract_dataset_names(dataset_path: str) -> tuple[str, str]: 144 | """ 145 | Extracts dataset name and source name from the dataset path. 146 | 147 | Args: 148 | dataset_path: Path to dataset (local path or HF Hub identifier) 149 | 150 | Returns: 151 | Tuple of (dataset_name, source_name) 152 | """ 153 | # Convert to Path object to handle both local and remote paths 154 | path = Path(dataset_path) 155 | 156 | # Check if it looks like a HuggingFace Hub identifier (contains '/') 157 | if '/' in dataset_path and not dataset_path.startswith('./') and not dataset_path.startswith('/'): 158 | # HuggingFace Hub format: "org/repo" or "org/repo/subfolder" 159 | parts = dataset_path.split('/') 160 | if len(parts) >= 2: 161 | # Use the full path as source, just the repo name for ID 162 | source_name = dataset_path 163 | dataset_name = parts[1] # Use just the repo name part 164 | else: 165 | # Fallback 166 | source_name = dataset_path 167 | dataset_name = dataset_path 168 | else: 169 | # Local path - use the filename without extension as both 170 | stem = path.stem if path.suffix else path.name 171 | source_name = stem 172 | dataset_name = stem 173 | 174 | return dataset_name, source_name 175 | 176 | def load_and_convert_dataset( 177 | dataset_path: str, 178 | split: Optional[str] = None, 179 | streaming: bool = False, 180 | max_samples: Optional[int] = None 181 | ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: 182 | """ 183 | Loads and converts a dataset from conversations format to messages format. 
184 | 185 | Args: 186 | dataset_path: Path to the dataset (local or HF Hub) 187 | split: Which split to load (if None, loads all splits) 188 | streaming: Whether to use streaming mode 189 | max_samples: Maximum number of samples to process (for testing) 190 | 191 | Returns: 192 | Converted Dataset object 193 | """ 194 | logger.info(f"Loading dataset from: {dataset_path}") 195 | 196 | # Extract dataset names for dynamic ID and source generation 197 | dataset_name, source_name = extract_dataset_names(dataset_path) 198 | logger.info(f"Using dataset_name='{dataset_name}', source_name='{source_name}'") 199 | 200 | try: 201 | # Load the dataset 202 | if split: 203 | dataset = load_dataset(dataset_path, split=split, streaming=streaming) 204 | else: 205 | dataset = load_dataset(dataset_path, streaming=streaming) 206 | 207 | if streaming: 208 | # Handle streaming datasets 209 | if isinstance(dataset, IterableDatasetDict): 210 | logger.info(f"Found streaming splits: {list(dataset.keys())}") 211 | converted_datasets = {} 212 | 213 | for split_name, split_dataset in dataset.items(): 214 | logger.info(f"Converting streaming split: {split_name}") 215 | converted_datasets[split_name] = _convert_streaming_dataset( 216 | split_dataset, max_samples, dataset_name, source_name 217 | ) 218 | 219 | return IterableDatasetDict(converted_datasets) 220 | else: 221 | # Single streaming dataset 222 | return _convert_streaming_dataset(dataset, max_samples, dataset_name, source_name) 223 | else: 224 | # Handle regular datasets 225 | if isinstance(dataset, DatasetDict): 226 | logger.info(f"Found splits: {list(dataset.keys())}") 227 | converted_datasets = {} 228 | 229 | for split_name, split_dataset in dataset.items(): 230 | logger.info(f"Converting split: {split_name}") 231 | converted_datasets[split_name] = _convert_regular_dataset( 232 | split_dataset, max_samples, dataset_name, source_name 233 | ) 234 | 235 | return DatasetDict(converted_datasets) 236 | else: 237 | # Single dataset 238 | return _convert_regular_dataset(dataset, max_samples, dataset_name, source_name) 239 | 240 | except Exception as e: 241 | logger.error(f"Error loading dataset: {e}") 242 | raise 243 | 244 | def _convert_regular_dataset(dataset: Dataset, max_samples: Optional[int] = None, dataset_name: str = "", source_name: str = "") -> Dataset: 245 | """Helper function to convert a regular (non-streaming) dataset.""" 246 | 247 | # Limit samples if specified (useful for testing) 248 | if max_samples: 249 | logger.info(f"Limiting to {max_samples} samples for testing") 250 | dataset = dataset.select(range(min(max_samples, len(dataset)))) 251 | 252 | logger.info(f"Converting {len(dataset)} samples...") 253 | 254 | # Convert using map with batching for efficiency 255 | converted_dataset = dataset.map( 256 | lambda batch, idx: convert_dataset_batch(batch, idx[0], dataset_name, source_name), # with batched=True and with_indices=True, idx is the list of row indices; its first element is this batch's start index 257 | batched=True, 258 | batch_size=1000, # Process in batches of 1000 259 | with_indices=True, 260 | remove_columns=dataset.column_names, # Remove original columns 261 | desc="Converting format" 262 | ) 263 | 264 | # Filter out None values (invalid conversions) 265 | original_size = len(converted_dataset) 266 | converted_dataset = converted_dataset.filter( 267 | lambda x: x["messages"] is not None and len(x["messages"]) > 0 268 | ) 269 | final_size = len(converted_dataset) 270 | 271 | logger.info(f"Conversion complete: {final_size}/{original_size} samples retained") 272 | 273 | return converted_dataset 274 | 275 | def _convert_streaming_dataset(dataset: IterableDataset, 
max_samples: Optional[int] = None, dataset_name: str = "", source_name: str = "") -> IterableDataset: 276 | """Helper function to convert a streaming dataset.""" 277 | 278 | logger.info("Converting streaming dataset...") 279 | 280 | # For streaming datasets, we need to use a different approach 281 | def convert_example_with_index(example, idx): 282 | converted = convert_single_example(example, idx, dataset_name, source_name) 283 | return converted 284 | 285 | # Convert using map (no batching for streaming datasets) 286 | converted_dataset = dataset.map( 287 | convert_example_with_index, 288 | with_indices=True, 289 | remove_columns=dataset.column_names if hasattr(dataset, 'column_names') else None 290 | ) 291 | 292 | # Filter out None values 293 | converted_dataset = converted_dataset.filter( 294 | lambda x: x is not None and "messages" in x and x["messages"] is not None and len(x["messages"]) > 0 295 | ) 296 | 297 | # Limit samples if specified 298 | if max_samples: 299 | logger.info(f"Limiting to {max_samples} samples for testing") 300 | converted_dataset = converted_dataset.take(max_samples) 301 | 302 | logger.info("Streaming dataset conversion setup complete") 303 | 304 | return converted_dataset 305 | 306 | def save_dataset( 307 | dataset: Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict], 308 | output_path: str, 309 | save_format: str = "json" 310 | ) -> None: 311 | """ 312 | Saves the converted dataset to the specified format. 313 | 314 | Args: 315 | dataset: The converted dataset 316 | output_path: Path where to save the dataset 317 | save_format: Format to save in ('json', 'jsonl', 'parquet', 'hf_hub') 318 | """ 319 | output_path = Path(output_path) 320 | 321 | if save_format == "hf_hub": 322 | # Push to Hugging Face Hub 323 | logger.info(f"Pushing dataset to Hub: {output_path}") 324 | 325 | # For streaming datasets, we need to collect the data first 326 | if isinstance(dataset, (IterableDataset, IterableDatasetDict)): 327 | logger.info("Converting streaming dataset to regular dataset for upload...") 328 | if isinstance(dataset, IterableDatasetDict): 329 | regular_datasets = {} 330 | for split_name, streaming_split in dataset.items(): 331 | logger.info(f"Collecting data for split: {split_name}") 332 | data = list(streaming_split) 333 | logger.info(f"Collected {len(data)} samples for split: {split_name}") 334 | regular_datasets[split_name] = Dataset.from_list(data) 335 | dataset = DatasetDict(regular_datasets) 336 | else: 337 | logger.info("Collecting streaming data...") 338 | data = list(dataset) 339 | logger.info(f"Collected {len(data)} samples") 340 | dataset = Dataset.from_list(data) 341 | 342 | dataset.push_to_hub(str(output_path)) 343 | else: 344 | # Save locally 345 | output_path.parent.mkdir(parents=True, exist_ok=True) 346 | 347 | # For streaming datasets, collect data first 348 | if isinstance(dataset, (IterableDataset, IterableDatasetDict)): 349 | logger.info("Converting streaming dataset for local save...") 350 | if isinstance(dataset, IterableDatasetDict): 351 | for split_name, streaming_split in dataset.items(): 352 | split_output = output_path.parent / f"{output_path.name}_{split_name}" 353 | data = list(streaming_split) 354 | split_dataset = Dataset.from_list(data) 355 | _save_single_dataset(split_dataset, split_output, save_format) 356 | return 357 | else: 358 | data = list(dataset) 359 | dataset = Dataset.from_list(data) 360 | 361 | if isinstance(dataset, DatasetDict): 362 | for split_name, split_dataset in dataset.items(): 363 | split_output = 
output_path.parent / f"{output_path.name}_{split_name}" 364 | _save_single_dataset(split_dataset, split_output, save_format) 365 | else: 366 | _save_single_dataset(dataset, output_path, save_format) 367 | 368 | def _save_single_dataset(dataset: Dataset, output_path: Path, save_format: str): 369 | """Helper function to save a single dataset.""" 370 | if save_format == "json": 371 | output_file = output_path.with_suffix('.json') 372 | logger.info(f"Saving dataset to: {output_file}") 373 | dataset.to_json(str(output_file), orient="records", force_ascii=False) 374 | 375 | elif save_format == "jsonl": 376 | output_file = output_path.with_suffix('.jsonl') 377 | logger.info(f"Saving dataset to: {output_file}") 378 | dataset.to_json(str(output_file), orient="records", lines=True, force_ascii=False) 379 | 380 | elif save_format == "parquet": 381 | output_file = output_path.with_suffix('.parquet') 382 | logger.info(f"Saving dataset to: {output_file}") 383 | dataset.to_parquet(str(output_file)) 384 | 385 | else: 386 | raise ValueError(f"Unsupported save format: {save_format}") 387 | 388 | def validate_converted_sample(sample: Dict[str, Any]) -> bool: 389 | """ 390 | Validates that a converted sample has the correct format. 391 | 392 | Args: 393 | sample: The converted sample to validate 394 | 395 | Returns: 396 | True if valid, False otherwise 397 | """ 398 | required_fields = ["messages", "id", "source"] 399 | 400 | # Check required fields 401 | for field in required_fields: 402 | if field not in sample: 403 | return False 404 | 405 | # Validate messages structure 406 | messages = sample["messages"] 407 | if not isinstance(messages, list) or len(messages) == 0: 408 | return False 409 | 410 | # Validate each message 411 | for message in messages: 412 | if not isinstance(message, dict): 413 | return False 414 | if "role" not in message or "content" not in message: 415 | return False 416 | if message["role"] not in ["user", "assistant", "system"]: 417 | return False 418 | 419 | return True 420 | 421 | def main(): 422 | parser = argparse.ArgumentParser( 423 | description="Convert Hugging Face datasets from conversations to messages format" 424 | ) 425 | parser.add_argument( 426 | "dataset_path", 427 | help="Path to dataset (local path or HF Hub identifier)" 428 | ) 429 | parser.add_argument( 430 | "--output-path", 431 | required=True, 432 | help="Output path for converted dataset" 433 | ) 434 | parser.add_argument( 435 | "--split", 436 | help="Specific split to convert (default: all splits)" 437 | ) 438 | parser.add_argument( 439 | "--format", 440 | choices=["json", "jsonl", "parquet", "hf_hub"], 441 | default="jsonl", 442 | help="Output format (default: jsonl)" 443 | ) 444 | parser.add_argument( 445 | "--max-samples", 446 | type=int, 447 | help="Maximum number of samples to process (for testing)" 448 | ) 449 | parser.add_argument( 450 | "--streaming", 451 | action="store_true", 452 | help="Use streaming mode for large datasets" 453 | ) 454 | parser.add_argument( 455 | "--validate", 456 | action="store_true", 457 | help="Validate a sample of converted data" 458 | ) 459 | 460 | args = parser.parse_args() 461 | 462 | try: 463 | # Load and convert the dataset 464 | converted_dataset = load_and_convert_dataset( 465 | args.dataset_path, 466 | split=args.split, 467 | streaming=args.streaming, 468 | max_samples=args.max_samples 469 | ) 470 | 471 | # Validate if requested 472 | if args.validate: 473 | logger.info("Validating converted data...") 474 | 475 | # Get a sample for validation 476 | if 
isinstance(converted_dataset, (DatasetDict, IterableDatasetDict)): 477 | first_split = list(converted_dataset.keys())[0] 478 | sample_dataset = converted_dataset[first_split] 479 | else: 480 | sample_dataset = converted_dataset 481 | 482 | # Get first sample 483 | if isinstance(sample_dataset, IterableDataset): 484 | sample = next(iter(sample_dataset)) 485 | else: 486 | sample = sample_dataset[0] 487 | 488 | if validate_converted_sample(sample): 489 | logger.info("Validation passed ✓") 490 | logger.info(f"Sample: {json.dumps(sample, indent=2, ensure_ascii=False)}") 491 | else: 492 | logger.error("Validation failed ✗") 493 | return 1 494 | 495 | # Save the dataset 496 | save_dataset(converted_dataset, args.output_path, args.format) 497 | 498 | logger.info("Conversion completed successfully!") 499 | return 0 500 | 501 | except Exception as e: 502 | logger.error(f"Conversion failed: {e}") 503 | return 1 504 | 505 | if __name__ == "__main__": 506 | exit(main()) -------------------------------------------------------------------------------- /mlx-quantization/dwq_quantization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# DWQ (Distilled Weight Quantization) with MLX-LM\n", 8 | "\n", 9 | "This notebook demonstrates how to use DWQ (Distilled Weight Quantization) with MLX-LM to reduce model quality loss during quantization.\n", 10 | "\n", 11 | "## What is DWQ?\n", 12 | "DWQ is designed to minimize quality loss when quantizing models to lower bit precision. It works best for 2-4 bit models and uses calibration samples to maintain model performance.\n", 13 | "\n", 14 | "## Requirements\n", 15 | "- macOS with Apple Silicon (M1/M2/M3/M4)\n", 16 | "- Python 3.9+\n", 17 | "- MLX framework\n", 18 | "- Sufficient disk space for model storage" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Step 1: Environment Setup and Dependencies" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import os\n", 35 | "import sys\n", 36 | "import subprocess\n", 37 | "from pathlib import Path\n", 38 | "\n", 39 | "# Environment setup\n", 40 | "print(\"Setting up environment for DWQ quantization...\")\n", 41 | "\n", 42 | "# Create project directories\n", 43 | "project_dir = Path.cwd()\n", 44 | "models_dir = project_dir / \"models\"\n", 45 | "models_dir.mkdir(exist_ok=True)\n", 46 | "\n", 47 | "print(f\"Project directory: {project_dir}\")\n", 48 | "print(f\"Models directory: {models_dir}\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Step 2: Install MLX and Dependencies" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Install required packages\n", 65 | "print(\"Installing MLX and dependencies...\")\n", 66 | "\n", 67 | "packages = [\n", 68 | " \"mlx-lm\",\n", 69 | " \"transformers\",\n", 70 | " \"torch\", \n", 71 | " \"huggingface_hub\",\n", 72 | " \"datasets\",\n", 73 | " \"accelerate\",\n", 74 | " \"sentencepiece\",\n", 75 | " \"protobuf\"\n", 76 | "]\n", 77 | "\n", 78 | "for package in packages:\n", 79 | " try:\n", 80 | " print(f\"Installing {package}...\")\n", 81 | " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", package], \n", 82 | " check=True, capture_output=True, text=True)\n", 
83 | " print(f\"✅ {package} installed successfully\")\n", 84 | " except subprocess.CalledProcessError as e:\n", 85 | " print(f\"⚠️ Warning installing {package}: {e}\")\n", 86 | "\n", 87 | "print(\"\\n📦 All packages installation completed!\")" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "## Step 3: Test MLX Imports" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# Test imports\n", 104 | "print(\"Testing MLX imports...\")\n", 105 | "\n", 106 | "try:\n", 107 | " import mlx.core as mx\n", 108 | " from mlx_lm import load, generate\n", 109 | " from huggingface_hub import login, snapshot_download\n", 110 | " print(\"✅ All imports successful!\")\n", 111 | " \n", 112 | " # Test MLX functionality\n", 113 | " test_array = mx.array([1, 2, 3])\n", 114 | " print(f\"✅ MLX test array: {test_array}\")\n", 115 | " \n", 116 | "except ImportError as e:\n", 117 | " print(f\"❌ Import failed: {e}\")\n", 118 | " print(\"Please restart kernel and try again.\")" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Step 4: Configuration" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "# DWQ Configuration\n", 135 | "print(\"=== DWQ Configuration ===\\n\")\n", 136 | "\n", 137 | "# Model to quantize (you can change this)\n", 138 | "MODEL_NAME = \"Qwen/Qwen2.5-0.5B\" # Small model for demonstration\n", 139 | "\n", 140 | "# DWQ Parameters\n", 141 | "DWQ_CONFIG = {\n", 142 | " \"bits\": 4, # Quantization precision (2-4 bits work best)\n", 143 | " \"num_samples\": 1024, # Calibration samples (default: 1024)\n", 144 | " \"batch_size\": 8, # Batch size to reduce memory footprint\n", 145 | " \"group_size\": 64, # Group size (smaller can improve results)\n", 146 | " \"learning_rate\": 0.01, # Learning rate (adjust based on precision)\n", 147 | "}\n", 148 | "\n", 149 | "print(f\"Model: {MODEL_NAME}\")\n", 150 | "print(f\"Target bits: {DWQ_CONFIG['bits']}\")\n", 151 | "print(f\"Calibration samples: {DWQ_CONFIG['num_samples']}\")\n", 152 | "print(f\"Batch size: {DWQ_CONFIG['batch_size']}\")\n", 153 | "print(f\"Group size: {DWQ_CONFIG['group_size']}\")\n", 154 | "\n", 155 | "# Set up directories\n", 156 | "original_model_dir = models_dir / MODEL_NAME.replace(\"/\", \"_\")\n", 157 | "dwq_model_dir = models_dir / f\"{MODEL_NAME.replace('/', '_')}_DWQ_{DWQ_CONFIG['bits']}bit\"\n", 158 | "\n", 159 | "print(f\"\\nOriginal model dir: {original_model_dir}\")\n", 160 | "print(f\"DWQ model dir: {dwq_model_dir}\")" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "## Step 5: Download Original Model" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "from datetime import datetime\n", 177 | "\n", 178 | "print(f\"Downloading {MODEL_NAME}...\")\n", 179 | "print(\"This may take a while depending on model size and internet connection.\")\n", 180 | "\n", 181 | "# Create directories\n", 182 | "original_model_dir.mkdir(parents=True, exist_ok=True)\n", 183 | "\n", 184 | "# Check if model already exists\n", 185 | "if list(original_model_dir.glob(\"*\")):\n", 186 | " print(f\"Model files found in {original_model_dir}\")\n", 187 | " use_existing = input(\"Use existing model files? 
(y/n): \").strip().lower()\n", 188 | " if use_existing != 'y':\n", 189 | " import shutil\n", 190 | " shutil.rmtree(original_model_dir)\n", 191 | " original_model_dir.mkdir(parents=True, exist_ok=True)\n", 192 | "\n", 193 | "if not list(original_model_dir.glob(\"*\")):\n", 194 | " try:\n", 195 | " start_time = datetime.now()\n", 196 | " \n", 197 | " downloaded_path = snapshot_download(\n", 198 | " repo_id=MODEL_NAME,\n", 199 | " local_dir=str(original_model_dir),\n", 200 | " local_dir_use_symlinks=False\n", 201 | " )\n", 202 | " \n", 203 | " end_time = datetime.now()\n", 204 | " duration = end_time - start_time\n", 205 | " \n", 206 | " print(f\"✅ Model downloaded successfully in {duration}\")\n", 207 | " \n", 208 | " except Exception as e:\n", 209 | " print(f\"❌ Download failed: {e}\")\n", 210 | " print(\"Please check the model name and internet connection.\")\n", 211 | "\n", 212 | "# List downloaded files\n", 213 | "print(\"\\nModel files:\")\n", 214 | "total_size = 0\n", 215 | "for file in original_model_dir.glob(\"*\"):\n", 216 | " if file.is_file():\n", 217 | " size_mb = file.stat().st_size / 1024 / 1024\n", 218 | " total_size += size_mb\n", 219 | " print(f\" {file.name} ({size_mb:.2f} MB)\")\n", 220 | "\n", 221 | "print(f\"\\nTotal model size: {total_size:.2f} MB\")" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "## Step 6: DWQ Quantization" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "import subprocess\n", 238 | "import shutil\n", 239 | "from datetime import datetime\n", 240 | "\n", 241 | "print(\"Starting DWQ quantization...\")\n", 242 | "print(f\"Source: {original_model_dir}\")\n", 243 | "print(f\"Target: {dwq_model_dir}\")\n", 244 | "print(f\"Configuration: {DWQ_CONFIG}\")\n", 245 | "\n", 246 | "# Clean up existing DWQ directory\n", 247 | "if dwq_model_dir.exists():\n", 248 | " print(f\"Removing existing DWQ directory: {dwq_model_dir}\")\n", 249 | " shutil.rmtree(dwq_model_dir)\n", 250 | "\n", 251 | "dwq_model_dir.mkdir(parents=True, exist_ok=True)\n", 252 | "\n", 253 | "# Build DWQ command\n", 254 | "dwq_cmd = [\n", 255 | " \"python\", \"-m\", \"mlx_lm.dwq\",\n", 256 | " \"--model\", str(original_model_dir),\n", 257 | " \"--mlx-path\", str(dwq_model_dir),\n", 258 | " \"--bits\", str(DWQ_CONFIG[\"bits\"]),\n", 259 | " \"--num-samples\", str(DWQ_CONFIG[\"num_samples\"]),\n", 260 | " \"--batch-size\", str(DWQ_CONFIG[\"batch_size\"])\n", 261 | "]\n", 262 | "\n", 263 | "print(f\"\\nRunning command: {' '.join(dwq_cmd)}\")\n", 264 | "\n", 265 | "try:\n", 266 | " start_time = datetime.now()\n", 267 | " \n", 268 | " # Run DWQ quantization\n", 269 | " result = subprocess.run(\n", 270 | " dwq_cmd,\n", 271 | " capture_output=True,\n", 272 | " text=True,\n", 273 | " cwd=str(project_dir)\n", 274 | " )\n", 275 | " \n", 276 | " end_time = datetime.now()\n", 277 | " duration = end_time - start_time\n", 278 | " \n", 279 | " if result.returncode == 0:\n", 280 | " print(f\"\\n✅ DWQ quantization completed successfully in {duration}!\")\n", 281 | " print(\"STDOUT:\", result.stdout)\n", 282 | " else:\n", 283 | " print(f\"\\n❌ DWQ quantization failed!\")\n", 284 | " print(\"STDERR:\", result.stderr)\n", 285 | " print(\"STDOUT:\", result.stdout)\n", 286 | " \n", 287 | "except Exception as e:\n", 288 | " print(f\"❌ Error running DWQ: {e}\")\n", 289 | "\n", 290 | "# Check results\n", 291 | "if dwq_model_dir.exists() and 
list(dwq_model_dir.glob(\"*\")):\n", 292 | " print(\"\\nDWQ quantized files:\")\n", 293 | " total_dwq_size = 0\n", 294 | " for file in dwq_model_dir.glob(\"*\"):\n", 295 | " if file.is_file():\n", 296 | " size_mb = file.stat().st_size / 1024 / 1024\n", 297 | " total_dwq_size += size_mb\n", 298 | " print(f\" {file.name} ({size_mb:.2f} MB)\")\n", 299 | " \n", 300 | " print(f\"\\nTotal DWQ model size: {total_dwq_size:.2f} MB\")\n", 301 | " print(f\"Size reduction: {(((total_size - total_dwq_size) / total_size * 100) if total_size > 0 else 0):.1f}%\")" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "## Step 7: Test DWQ Model" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [ 317 | "# Test the DWQ quantized model\n", 318 | "if dwq_model_dir.exists() and list(dwq_model_dir.glob(\"*\")):\n", 319 | " print(\"Testing DWQ quantized model...\")\n", 320 | " \n", 321 | " try:\n", 322 | " # Load the DWQ model\n", 323 | " model, tokenizer = load(str(dwq_model_dir))\n", 324 | " print(\"✅ DWQ model loaded successfully!\")\n", 325 | " \n", 326 | " # Test generation\n", 327 | " test_prompts = [\n", 328 | " \"Hello, how are you?\",\n", 329 | " \"The weather today is\",\n", 330 | " \"Artificial intelligence is\"\n", 331 | " ]\n", 332 | " \n", 333 | " print(\"\\n=== DWQ Model Test Results ===\")\n", 334 | " for prompt in test_prompts:\n", 335 | " print(f\"\\nPrompt: '{prompt}'\")\n", 336 | " \n", 337 | " response = generate(\n", 338 | " model, \n", 339 | " tokenizer, \n", 340 | " prompt=prompt, \n", 341 | " max_tokens=50,\n", 342 | " temp=0.7\n", 343 | " )\n", 344 | " \n", 345 | " print(f\"Response: {response}\")\n", 346 | " \n", 347 | " print(\"\\n✅ DWQ model is working correctly!\")\n", 348 | " \n", 349 | " except Exception as e:\n", 350 | " print(f\"❌ Error testing DWQ model: {e}\")\nelse:\n", 351 | " print(\"❌ DWQ model not found. Quantization may have failed.\")" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "## Step 8: Evaluate Model Quality" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "# Optional: Evaluate the quantized model\n", 368 | "# This requires a dataset for evaluation\n", 369 | "\n", 370 | "print(\"=== Model Quality Evaluation ===\\n\")\n", 371 | "\n", 372 | "evaluate_model = input(\"Do you want to evaluate model quality? 
(y/n): \").strip().lower()\n", 373 | "\n", 374 | "if evaluate_model == 'y':\n", 375 | " # You can use mlx_lm.evaluate for this\n", 376 | " eval_cmd = [\n", 377 | " \"python\", \"-m\", \"mlx_lm.evaluate\",\n", 378 | " \"--model\", str(dwq_model_dir),\n", 379 | " \"--dataset\", \"wikitext\", # or your preferred dataset\n", 380 | " \"--few-shot\", \"5\"\n", 381 | " ]\n", 382 | " \n", 383 | " print(f\"Running evaluation: {' '.join(eval_cmd)}\")\n", 384 | " \n", 385 | " try:\n", 386 | " result = subprocess.run(eval_cmd, capture_output=True, text=True)\n", 387 | " \n", 388 | " if result.returncode == 0:\n", 389 | " print(\"\\n✅ Evaluation completed!\")\n", 390 | " print(result.stdout)\n", 391 | " else:\n", 392 | " print(\"\\n❌ Evaluation failed!\")\n", 393 | " print(result.stderr)\n", 394 | " \n", 395 | " except Exception as e:\n", 396 | " print(f\"❌ Error running evaluation: {e}\")\nelse:\n", 397 | " print(\"Skipping evaluation.\")" 398 | ] 399 | }, 400 | { 401 | "cell_type": "markdown", 402 | "metadata": {}, 403 | "source": [ 404 | "## Step 9: Upload to Hugging Face (Optional)" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "metadata": {}, 411 | "outputs": [], 412 | "source": [ 413 | "from huggingface_hub import HfApi, upload_folder\n", 414 | "import getpass\n", 415 | "\n", 416 | "upload_to_hf = input(\"Do you want to upload the DWQ model to Hugging Face? (y/n): \").strip().lower()\n", 417 | "\n", 418 | "if upload_to_hf == 'y':\n", 419 | " # Get Hugging Face credentials\n", 420 | " print(\"Please enter your Hugging Face token:\")\n", 421 | " hf_token = getpass.getpass(\"HF Token: \")\n", 422 | " \n", 423 | " try:\n", 424 | " login(token=hf_token)\n", 425 | " print(\"✅ Successfully logged in to Hugging Face!\")\n", 426 | " \n", 427 | " # Get repository name\n", 428 | " repo_name = input(\"Enter repository name (e.g., 'username/model-name-dwq'): \").strip()\n", 429 | " \n", 430 | " # Create repository\n", 431 | " api = HfApi()\n", 432 | " api.create_repo(repo_id=repo_name, repo_type=\"model\", exist_ok=True)\n", 433 | " print(f\"✅ Repository {repo_name} created!\")\n", 434 | " \n", 435 | " # Create model card\n", 436 | " model_card = f\"\"\"---\n", 437 | "license: apache-2.0\n", 438 | "base_model: {MODEL_NAME}\n", 439 | "tags:\n", 440 | "- mlx\n", 441 | "- dwq\n", 442 | "- quantized\n", 443 | "- {DWQ_CONFIG['bits']}-bit\n", 444 | "---\n", 445 | "\n", 446 | "# {MODEL_NAME.split('/')[-1]} - DWQ {DWQ_CONFIG['bits']}-bit\n", 447 | "\n", 448 | "This is a DWQ (Distilled Weight Quantization) {DWQ_CONFIG['bits']}-bit quantized version of [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}).\n", 449 | "\n", 450 | "## Quantization Details\n", 451 | "- Method: DWQ (Distilled Weight Quantization)\n", 452 | "- Precision: {DWQ_CONFIG['bits']}-bit\n", 453 | "- Calibration samples: {DWQ_CONFIG['num_samples']}\n", 454 | "- Group size: {DWQ_CONFIG['group_size']}\n", 455 | "\n", 456 | "## Usage\n", 457 | "```python\n", 458 | "from mlx_lm import load, generate\n", 459 | "\n", 460 | "model, tokenizer = load(\"{repo_name}\")\n", 461 | "response = generate(model, tokenizer, prompt=\"Hello\", max_tokens=100)\n", 462 | "```\n", 463 | "\"\"\"\n", 464 | " \n", 465 | " # Save model card\n", 466 | " with open(dwq_model_dir / \"README.md\", \"w\") as f:\n", 467 | " f.write(model_card)\n", 468 | " \n", 469 | " # Upload\n", 470 | " print(f\"Uploading to {repo_name}...\")\n", 471 | " upload_folder(\n", 472 | " folder_path=str(dwq_model_dir),\n", 473 | " repo_id=repo_name,\n", 474 | " 
repo_type=\"model\",\n", 475 | " commit_message=f\"Add DWQ {DWQ_CONFIG['bits']}-bit quantized model\"\n", 476 | " )\n", 477 | " \n", 478 | " print(f\"✅ Model uploaded successfully!\")\n", 479 | " print(f\"🔗 https://huggingface.co/{repo_name}\")\n", 480 | " \n", 481 | " except Exception as e:\n", 482 | " print(f\"❌ Upload failed: {e}\")\nelse:\n", 483 | " print(\"Skipping upload.\")" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "metadata": {}, 489 | "source": [ 490 | "## Step 10: Summary" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": {}, 497 | "outputs": [], 498 | "source": [ 499 | "# Final summary\n", 500 | "print(\"\\n\" + \"=\"*60)\n", 501 | "print(\"🎉 DWQ QUANTIZATION SUMMARY\")\n", 502 | "print(\"=\"*60)\n", 503 | "\n", 504 | "print(f\"\\n📋 Configuration:\")\n", 505 | "print(f\" Base Model: {MODEL_NAME}\")\n", 506 | "print(f\" Target Bits: {DWQ_CONFIG['bits']}\")\n", 507 | "print(f\" Calibration Samples: {DWQ_CONFIG['num_samples']}\")\n", 508 | "print(f\" Group Size: {DWQ_CONFIG['group_size']}\")\n", 509 | "\n", 510 | "print(f\"\\n📁 Directories:\")\n", 511 | "print(f\" Original: {original_model_dir}\")\n", 512 | "print(f\" DWQ Model: {dwq_model_dir}\")\n", 513 | "\n", 514 | "# Check if quantization was successful\n", 515 | "if dwq_model_dir.exists() and list(dwq_model_dir.glob(\"*\")):\n", 516 | " print(f\"\\n✅ Status: DWQ quantization completed successfully!\")\n", 517 | " \n", 518 | " # Calculate size reduction if possible\n", 519 | " original_size = sum(f.stat().st_size for f in original_model_dir.glob(\"*\") if f.is_file()) / 1024 / 1024\n", 520 | " dwq_size = sum(f.stat().st_size for f in dwq_model_dir.glob(\"*\") if f.is_file()) / 1024 / 1024\n", 521 | " \n", 522 | " print(f\" Original size: {original_size:.2f} MB\")\n", 523 | " print(f\" DWQ size: {dwq_size:.2f} MB\")\n", 524 | " print(f\" Size reduction: {((original_size - dwq_size) / original_size * 100):.1f}%\")\nelse:\n", 525 | " print(f\"\\n❌ Status: DWQ quantization failed or incomplete\")\n", 526 | "\n", 527 | "print(f\"\\n💡 Tips for DWQ:\")\n", 528 | "print(f\" • Works best for 2-4 bit quantization\")\n", 529 | "print(f\" • Decreasing group size can improve results\")\n", 530 | "print(f\" • Adjust learning rate based on precision\")\n", 531 | "print(f\" • More calibration samples = better quality\")\n", 532 | "\n", 533 | "print(\"\\n\" + \"=\"*60)\n", 534 | "print(\"Thank you for using DWQ quantization!\")\n", 535 | "print(\"=\"*60)" 536 | ] 537 | } 538 | ], 539 | "metadata": { 540 | "kernelspec": { 541 | "display_name": "Python 3", 542 | "language": "python", 543 | "name": "python3" 544 | }, 545 | "language_info": { 546 | "codemirror_mode": { 547 | "name": "ipython", 548 | "version": 3 549 | }, 550 | "file_extension": ".py", 551 | "name": "python", 552 | "nbconvert_exporter": "python", 553 | "pygments_lexer": "ipython3", 554 | "version": "3.9.0" 555 | } 556 | }, 557 | "nbformat": 4, 558 | "nbformat_minor": 4 559 | } -------------------------------------------------------------------------------- /mlx-quantization/awq_quantization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# AWQ (Activation-aware Weight Quantization) with MLX-LM\n", 8 | "\n", 9 | "This notebook demonstrates how to use AWQ (Activation-aware Weight Quantization) with MLX-LM to scale and clip weights before quantization.\n", 10 | "\n", 11 | "## 
What is AWQ?\n", 12 | "AWQ is a quantization method that scales and clips weights before quantization to preserve model quality. It uses calibration samples to determine optimal scaling factors for different weights.\n", 13 | "\n", 14 | "## Requirements\n", 15 | "- macOS with Apple Silicon (M1/M2/M3/M4)\n", 16 | "- Python 3.9+\n", 17 | "- MLX framework\n", 18 | "- Sufficient disk space for model storage" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Step 1: Environment Setup and Dependencies" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import os\n", 35 | "import sys\n", 36 | "import subprocess\n", 37 | "from pathlib import Path\n", 38 | "\n", 39 | "# Environment setup\n", 40 | "print(\"Setting up environment for AWQ quantization...\")\n", 41 | "\n", 42 | "# Create project directories\n", 43 | "project_dir = Path.cwd()\n", 44 | "models_dir = project_dir / \"models\"\n", 45 | "models_dir.mkdir(exist_ok=True)\n", 46 | "\n", 47 | "print(f\"Project directory: {project_dir}\")\n", 48 | "print(f\"Models directory: {models_dir}\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Step 2: Install MLX and Dependencies" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Install required packages\n", 65 | "print(\"Installing MLX and dependencies...\")\n", 66 | "\n", 67 | "packages = [\n", 68 | " \"mlx-lm\",\n", 69 | " \"transformers\",\n", 70 | " \"torch\", \n", 71 | " \"huggingface_hub\",\n", 72 | " \"datasets\",\n", 73 | " \"accelerate\",\n", 74 | " \"sentencepiece\",\n", 75 | " \"protobuf\"\n", 76 | "]\n", 77 | "\n", 78 | "for package in packages:\n", 79 | " try:\n", 80 | " print(f\"Installing {package}...\")\n", 81 | " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", package], \n", 82 | " check=True, capture_output=True, text=True)\n", 83 | " print(f\"✅ {package} installed successfully\")\n", 84 | " except subprocess.CalledProcessError as e:\n", 85 | " print(f\"⚠️ Warning installing {package}: {e}\")\n", 86 | "\n", 87 | "print(\"\\n📦 All packages installation completed!\")" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "## Step 3: Test MLX Imports" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# Test imports\n", 104 | "print(\"Testing MLX imports...\")\n", 105 | "\n", 106 | "try:\n", 107 | " import mlx.core as mx\n", 108 | " from mlx_lm import load, generate\n", 109 | " from huggingface_hub import login, snapshot_download\n", 110 | " print(\"✅ All imports successful!\")\n", 111 | " \n", 112 | " # Test MLX functionality\n", 113 | " test_array = mx.array([1, 2, 3])\n", 114 | " print(f\"✅ MLX test array: {test_array}\")\n", 115 | " \n", 116 | "except ImportError as e:\n", 117 | " print(f\"❌ Import failed: {e}\")\n", 118 | " print(\"Please restart kernel and try again.\")" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Step 4: Configuration" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "# AWQ Configuration\n", 135 | "print(\"=== AWQ Configuration ===\\n\")\n", 136 | "\n", 137 | "# Model to 
quantize (you can change this)\n", 138 | "MODEL_NAME = \"Qwen/Qwen2.5-0.5B\" # Small model for demonstration\n", 139 | "\n", 140 | "# AWQ Parameters\n", 141 | "AWQ_CONFIG = {\n", 142 | " \"bits\": 4, # Quantization precision (typically 4 bits)\n", 143 | " \"num_samples\": 32, # Calibration samples (default: 32)\n", 144 | " \"n_grid\": 10, # Search granularity (default: 10)\n", 145 | " \"group_size\": 128, # Group size for quantization\n", 146 | "}\n", 147 | "\n", 148 | "print(f\"Model: {MODEL_NAME}\")\n", 149 | "print(f\"Target bits: {AWQ_CONFIG['bits']}\")\n", 150 | "print(f\"Calibration samples: {AWQ_CONFIG['num_samples']}\")\n", 151 | "print(f\"Search grid: {AWQ_CONFIG['n_grid']}\")\n", 152 | "print(f\"Group size: {AWQ_CONFIG['group_size']}\")\n", 153 | "\n", 154 | "# Set up directories\n", 155 | "original_model_dir = models_dir / MODEL_NAME.replace(\"/\", \"_\")\n", 156 | "awq_model_dir = models_dir / f\"{MODEL_NAME.replace('/', '_')}_AWQ_{AWQ_CONFIG['bits']}bit\"\n", 157 | "\n", 158 | "print(f\"\\nOriginal model dir: {original_model_dir}\")\n", 159 | "print(f\"AWQ model dir: {awq_model_dir}\")" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "## Step 5: Download Original Model" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "from datetime import datetime\n", 176 | "\n", 177 | "print(f\"Downloading {MODEL_NAME}...\")\n", 178 | "print(\"This may take a while depending on model size and internet connection.\")\n", 179 | "\n", 180 | "# Create directories\n", 181 | "original_model_dir.mkdir(parents=True, exist_ok=True)\n", 182 | "\n", 183 | "# Check if model already exists\n", 184 | "if list(original_model_dir.glob(\"*\")):\n", 185 | " print(f\"Model files found in {original_model_dir}\")\n", 186 | " use_existing = input(\"Use existing model files? 
(y/n): \").strip().lower()\n", 187 | " if use_existing != 'y':\n", 188 | " import shutil\n", 189 | " shutil.rmtree(original_model_dir)\n", 190 | " original_model_dir.mkdir(parents=True, exist_ok=True)\n", 191 | "\n", 192 | "if not list(original_model_dir.glob(\"*\")):\n", 193 | " try:\n", 194 | " start_time = datetime.now()\n", 195 | " \n", 196 | " downloaded_path = snapshot_download(\n", 197 | " repo_id=MODEL_NAME,\n", 198 | " local_dir=str(original_model_dir),\n", 199 | " local_dir_use_symlinks=False\n", 200 | " )\n", 201 | " \n", 202 | " end_time = datetime.now()\n", 203 | " duration = end_time - start_time\n", 204 | " \n", 205 | " print(f\"✅ Model downloaded successfully in {duration}\")\n", 206 | " \n", 207 | " except Exception as e:\n", 208 | " print(f\"❌ Download failed: {e}\")\n", 209 | " print(\"Please check the model name and internet connection.\")\n", 210 | "\n", 211 | "# List downloaded files\n", 212 | "print(\"\\nModel files:\")\n", 213 | "total_size = 0\n", 214 | "for file in original_model_dir.glob(\"*\"):\n", 215 | " if file.is_file():\n", 216 | " size_mb = file.stat().st_size / 1024 / 1024\n", 217 | " total_size += size_mb\n", 218 | " print(f\" {file.name} ({size_mb:.2f} MB)\")\n", 219 | "\n", 220 | "print(f\"\\nTotal model size: {total_size:.2f} MB\")" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "## Step 6: AWQ Quantization" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "import subprocess\n", 237 | "import shutil\n", 238 | "from datetime import datetime\n", 239 | "\n", 240 | "print(\"Starting AWQ quantization...\")\n", 241 | "print(f\"Source: {original_model_dir}\")\n", 242 | "print(f\"Target: {awq_model_dir}\")\n", 243 | "print(f\"Configuration: {AWQ_CONFIG}\")\n", 244 | "\n", 245 | "# Clean up existing AWQ directory\n", 246 | "if awq_model_dir.exists():\n", 247 | " print(f\"Removing existing AWQ directory: {awq_model_dir}\")\n", 248 | " shutil.rmtree(awq_model_dir)\n", 249 | "\n", 250 | "awq_model_dir.mkdir(parents=True, exist_ok=True)\n", 251 | "\n", 252 | "# Build AWQ command\n", 253 | "awq_cmd = [\n", 254 | " \"python\", \"-m\", \"mlx_lm.awq\",\n", 255 | " \"--model\", str(original_model_dir),\n", 256 | " \"--mlx-path\", str(awq_model_dir),\n", 257 | " \"--bits\", str(AWQ_CONFIG[\"bits\"]),\n", 258 | " \"--num-samples\", str(AWQ_CONFIG[\"num_samples\"]),\n", 259 | " \"--n-grid\", str(AWQ_CONFIG[\"n_grid\"])\n", 260 | "]\n", 261 | "\n", 262 | "print(f\"\\nRunning command: {' '.join(awq_cmd)}\")\n", 263 | "\n", 264 | "try:\n", 265 | " start_time = datetime.now()\n", 266 | " \n", 267 | " # Run AWQ quantization\n", 268 | " result = subprocess.run(\n", 269 | " awq_cmd,\n", 270 | " capture_output=True,\n", 271 | " text=True,\n", 272 | " cwd=str(project_dir)\n", 273 | " )\n", 274 | " \n", 275 | " end_time = datetime.now()\n", 276 | " duration = end_time - start_time\n", 277 | " \n", 278 | " if result.returncode == 0:\n", 279 | " print(f\"\\n✅ AWQ quantization completed successfully in {duration}!\")\n", 280 | " print(\"STDOUT:\", result.stdout)\n", 281 | " else:\n", 282 | " print(f\"\\n❌ AWQ quantization failed!\")\n", 283 | " print(\"STDERR:\", result.stderr)\n", 284 | " print(\"STDOUT:\", result.stdout)\n", 285 | " \n", 286 | "except Exception as e:\n", 287 | " print(f\"❌ Error running AWQ: {e}\")\n", 288 | "\n", 289 | "# Check results\n", 290 | "if awq_model_dir.exists() and list(awq_model_dir.glob(\"*\")):\n", 
291 | " print(\"\\nAWQ quantized files:\")\n", 292 | " total_awq_size = 0\n", 293 | " for file in awq_model_dir.glob(\"*\"):\n", 294 | " if file.is_file():\n", 295 | " size_mb = file.stat().st_size / 1024 / 1024\n", 296 | " total_awq_size += size_mb\n", 297 | " print(f\" {file.name} ({size_mb:.2f} MB)\")\n", 298 | " \n", 299 | " print(f\"\\nTotal AWQ model size: {total_awq_size:.2f} MB\")\n", 300 | " if total_size > 0:\n", 301 | " print(f\"Size reduction: {((total_size - total_awq_size) / total_size * 100):.1f}%\")" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "## Step 7: Test AWQ Model" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [ 317 | "# Test the AWQ quantized model\n", 318 | "if awq_model_dir.exists() and list(awq_model_dir.glob(\"*\")):\n", 319 | " print(\"Testing AWQ quantized model...\")\n", 320 | " \n", 321 | " try:\n", 322 | " # Load the AWQ model\n", 323 | " model, tokenizer = load(str(awq_model_dir))\n", 324 | " print(\"✅ AWQ model loaded successfully!\")\n", 325 | " \n", 326 | " # Test generation\n", 327 | " test_prompts = [\n", 328 | " \"Hello, how are you?\",\n", 329 | " \"The weather today is\",\n", 330 | " \"Artificial intelligence is\",\n", 331 | " \"Machine learning can be used for\"\n", 332 | " ]\n", 333 | " \n", 334 | " print(\"\\n=== AWQ Model Test Results ===\")\n", 335 | " for prompt in test_prompts:\n", 336 | " print(f\"\\nPrompt: '{prompt}'\")\n", 337 | " \n", 338 | " response = generate(\n", 339 | " model, \n", 340 | " tokenizer, \n", 341 | " prompt=prompt, \n", 342 | " max_tokens=50,\n", 343 | " temp=0.7\n", 344 | " )\n", 345 | " \n", 346 | " print(f\"Response: {response}\")\n", 347 | " \n", 348 | " print(\"\\n✅ AWQ model is working correctly!\")\n", 349 | " \n", 350 | " except Exception as e:\n", 351 | " print(f\"❌ Error testing AWQ model: {e}\")\nelse:\n", 352 | " print(\"❌ AWQ model not found. Quantization may have failed.\")" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "metadata": {}, 358 | "source": [ 359 | "## Step 8: Compare Original vs AWQ Performance" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "# Optional: Compare original vs AWQ model performance\n", 369 | "import time\n", 370 | "\n", 371 | "compare_models = input(\"Do you want to compare original vs AWQ model performance? 
(y/n): \").strip().lower()\n", 372 | "\n", 373 | "if compare_models == 'y':\n", 374 | " print(\"\\n=== Model Performance Comparison ===\")\n", 375 | " \n", 376 | " test_prompt = \"The future of artificial intelligence\"\n", 377 | " max_tokens = 100\n", 378 | " \n", 379 | " try:\n", 380 | " # Test original model\n", 381 | " print(\"\\n🔄 Testing original model...\")\n", 382 | " original_model, original_tokenizer = load(str(original_model_dir))\n", 383 | " \n", 384 | " start_time = time.time()\n", 385 | " original_response = generate(\n", 386 | " original_model, \n", 387 | " original_tokenizer, \n", 388 | " prompt=test_prompt, \n", 389 | " max_tokens=max_tokens,\n", 390 | " temp=0.7\n", 391 | " )\n", 392 | " original_time = time.time() - start_time\n", 393 | " \n", 394 | " print(f\"Original response: {original_response}\")\n", 395 | " print(f\"Original generation time: {original_time:.2f}s\")\n", 396 | " \n", 397 | " except Exception as e:\n", 398 | " print(f\"❌ Error testing original model: {e}\")\n", 399 | " original_response = None\n", 400 | " original_time = None\n", 401 | " \n", 402 | " try:\n", 403 | " # Test AWQ model (already loaded above)\n", 404 | " print(\"\\n🔄 Testing AWQ model...\")\n", 405 | " \n", 406 | " start_time = time.time()\n", 407 | " awq_response = generate(\n", 408 | " model, \n", 409 | " tokenizer, \n", 410 | " prompt=test_prompt, \n", 411 | " max_tokens=max_tokens,\n", 412 | " temp=0.7\n", 413 | " )\n", 414 | " awq_time = time.time() - start_time\n", 415 | " \n", 416 | " print(f\"AWQ response: {awq_response}\")\n", 417 | " print(f\"AWQ generation time: {awq_time:.2f}s\")\n", 418 | " \n", 419 | " # Compare performance\n", 420 | " if original_time and awq_time:\n", 421 | " speedup = original_time / awq_time\n", 422 | " print(f\"\\n📊 Performance comparison:\")\n", 423 | " print(f\" Speedup: {speedup:.2f}x\")\n", 424 | " print(f\" Time saved: {original_time - awq_time:.2f}s\")\n", 425 | " \n", 426 | " except Exception as e:\n", 427 | " print(f\"❌ Error testing AWQ model: {e}\")\nelse:\n", 428 | " print(\"Skipping performance comparison.\")" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "metadata": {}, 434 | "source": [ 435 | "## Step 9: Evaluate Model Quality" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "# Optional: Evaluate the quantized model\n", 445 | "print(\"=== Model Quality Evaluation ===\\n\")\n", 446 | "\n", 447 | "evaluate_model = input(\"Do you want to evaluate model quality? 
(y/n): \").strip().lower()\n", 448 | "\n", 449 | "if evaluate_model == 'y':\n", 450 | " # You can use mlx_lm.evaluate for this\n", 451 | " eval_cmd = [\n", 452 | " \"python\", \"-m\", \"mlx_lm.evaluate\",\n", 453 | " \"--model\", str(awq_model_dir),\n", 454 | " \"--dataset\", \"wikitext\", # or your preferred dataset\n", 455 | " \"--few-shot\", \"5\"\n", 456 | " ]\n", 457 | " \n", 458 | " print(f\"Running evaluation: {' '.join(eval_cmd)}\")\n", 459 | " \n", 460 | " try:\n", 461 | " result = subprocess.run(eval_cmd, capture_output=True, text=True)\n", 462 | " \n", 463 | " if result.returncode == 0:\n", 464 | " print(\"\\n✅ Evaluation completed!\")\n", 465 | " print(result.stdout)\n", 466 | " else:\n", 467 | " print(\"\\n❌ Evaluation failed!\")\n", 468 | " print(result.stderr)\n", 469 | " \n", 470 | " except Exception as e:\n", 471 | " print(f\"❌ Error running evaluation: {e}\")\nelse:\n", 472 | " print(\"Skipping evaluation.\")" 473 | ] 474 | }, 475 | { 476 | "cell_type": "markdown", 477 | "metadata": {}, 478 | "source": [ 479 | "## Step 10: Upload to Hugging Face (Optional)" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": null, 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "from huggingface_hub import HfApi, upload_folder\n", 489 | "import getpass\n", 490 | "\n", 491 | "upload_to_hf = input(\"Do you want to upload the AWQ model to Hugging Face? (y/n): \").strip().lower()\n", 492 | "\n", 493 | "if upload_to_hf == 'y':\n", 494 | " # Get Hugging Face credentials\n", 495 | " print(\"Please enter your Hugging Face token:\")\n", 496 | " hf_token = getpass.getpass(\"HF Token: \")\n", 497 | " \n", 498 | " try:\n", 499 | " login(token=hf_token)\n", 500 | " print(\"✅ Successfully logged in to Hugging Face!\")\n", 501 | " \n", 502 | " # Get repository name\n", 503 | " repo_name = input(\"Enter repository name (e.g., 'username/model-name-awq'): \").strip()\n", 504 | " \n", 505 | " # Create repository\n", 506 | " api = HfApi()\n", 507 | " api.create_repo(repo_id=repo_name, repo_type=\"model\", exist_ok=True)\n", 508 | " print(f\"✅ Repository {repo_name} created!\")\n", 509 | " \n", 510 | " # Create model card\n", 511 | " model_card = f\"\"\"---\n", 512 | "license: apache-2.0\n", 513 | "base_model: {MODEL_NAME}\n", 514 | "tags:\n", 515 | "- mlx\n", 516 | "- awq\n", 517 | "- quantized\n", 518 | "- {AWQ_CONFIG['bits']}-bit\n", 519 | "---\n", 520 | "\n", 521 | "# {MODEL_NAME.split('/')[-1]} - AWQ {AWQ_CONFIG['bits']}-bit\n", 522 | "\n", 523 | "This is an AWQ (Activation-aware Weight Quantization) {AWQ_CONFIG['bits']}-bit quantized version of [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}).\n", 524 | "\n", 525 | "## Quantization Details\n", 526 | "- Method: AWQ (Activation-aware Weight Quantization)\n", 527 | "- Precision: {AWQ_CONFIG['bits']}-bit\n", 528 | "- Calibration samples: {AWQ_CONFIG['num_samples']}\n", 529 | "- Search grid: {AWQ_CONFIG['n_grid']}\n", 530 | "- Group size: {AWQ_CONFIG['group_size']}\n", 531 | "\n", 532 | "## Features\n", 533 | "- Scales and clips weights before quantization\n", 534 | "- Optimized for Apple Silicon devices\n", 535 | "- Maintains model quality through activation-aware scaling\n", 536 | "\n", 537 | "## Usage\n", 538 | "```python\n", 539 | "from mlx_lm import load, generate\n", 540 | "\n", 541 | "model, tokenizer = load(\"{repo_name}\")\n", 542 | "response = generate(model, tokenizer, prompt=\"Hello\", max_tokens=100)\n", 543 | "```\n", 544 | "\"\"\"\n", 545 | " \n", 546 | " # Save model card\n", 547 | " with 
open(awq_model_dir / \"README.md\", \"w\") as f:\n", 548 | " f.write(model_card)\n", 549 | " \n", 550 | " # Upload\n", 551 | " print(f\"Uploading to {repo_name}...\")\n", 552 | " upload_folder(\n", 553 | " folder_path=str(awq_model_dir),\n", 554 | " repo_id=repo_name,\n", 555 | " repo_type=\"model\",\n", 556 | " commit_message=f\"Add AWQ {AWQ_CONFIG['bits']}-bit quantized model\"\n", 557 | " )\n", 558 | " \n", 559 | " print(f\"✅ Model uploaded successfully!\")\n", 560 | " print(f\"🔗 https://huggingface.co/{repo_name}\")\n", 561 | " \n", 562 | " except Exception as e:\n", 563 | " print(f\"❌ Upload failed: {e}\")\nelse:\n", 564 | " print(\"Skipping upload.\")" 565 | ] 566 | }, 567 | { 568 | "cell_type": "markdown", 569 | "metadata": {}, 570 | "source": [ 571 | "## Step 11: Summary" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": null, 577 | "metadata": {}, 578 | "outputs": [], 579 | "source": [ 580 | "# Final summary\n", 581 | "print(\"\\n\" + \"=\"*60)\n", 582 | "print(\"🎉 AWQ QUANTIZATION SUMMARY\")\n", 583 | "print(\"=\"*60)\n", 584 | "\n", 585 | "print(f\"\\n📋 Configuration:\")\n", 586 | "print(f\" Base Model: {MODEL_NAME}\")\n", 587 | "print(f\" Target Bits: {AWQ_CONFIG['bits']}\")\n", 588 | "print(f\" Calibration Samples: {AWQ_CONFIG['num_samples']}\")\n", 589 | "print(f\" Search Grid: {AWQ_CONFIG['n_grid']}\")\n", 590 | "print(f\" Group Size: {AWQ_CONFIG['group_size']}\")\n", 591 | "\n", 592 | "print(f\"\\n📁 Directories:\")\n", 593 | "print(f\" Original: {original_model_dir}\")\n", 594 | "print(f\" AWQ Model: {awq_model_dir}\")\n", 595 | "\n", 596 | "# Check if quantization was successful\n", 597 | "if awq_model_dir.exists() and list(awq_model_dir.glob(\"*\")):\n", 598 | " print(f\"\\n✅ Status: AWQ quantization completed successfully!\")\n", 599 | " \n", 600 | " # Calculate size reduction if possible\n", 601 | " if 'total_size' in locals() and 'total_awq_size' in locals():\n", 602 | " print(f\" Original size: {total_size:.2f} MB\")\n", 603 | " print(f\" AWQ size: {total_awq_size:.2f} MB\")\n", 604 | " print(f\" Size reduction: {((total_size - total_awq_size) / total_size * 100):.1f}%\")\nelse:\n", 605 | " print(f\"\\n❌ Status: AWQ quantization failed or incomplete\")\n", 606 | "\n", 607 | "print(f\"\\n💡 AWQ Advantages:\")\n", 608 | "print(f\" • Scales and clips weights before quantization\")\n", 609 | "print(f\" • Preserves model quality through activation awareness\")\n", 610 | "print(f\" • Efficient search for optimal scaling factors\")\n", 611 | "print(f\" • Good balance between size and performance\")\n", 612 | "\n", 613 | "print(f\"\\n🔧 Tuning Tips:\")\n", 614 | "print(f\" • Increase num_samples for better quality\")\n", 615 | "print(f\" • Increase n_grid for more thorough search\")\n", 616 | "print(f\" • Adjust group_size based on model architecture\")\n", 617 | "\n", 618 | "print(\"\\n\" + \"=\"*60)\n", 619 | "print(\"Thank you for using AWQ quantization!\")\n", 620 | "print(\"=\"*60)" 621 | ] 622 | } 623 | ], 624 | "metadata": { 625 | "kernelspec": { 626 | "display_name": "Python 3", 627 | "language": "python", 628 | "name": "python3" 629 | }, 630 | "language_info": { 631 | "codemirror_mode": { 632 | "name": "ipython", 633 | "version": 3 634 | }, 635 | "file_extension": ".py", 636 | "name": "python", 637 | "nbconvert_exporter": "python", 638 | "pygments_lexer": "ipython3", 639 | "version": "3.9.0" 640 | } 641 | }, 642 | "nbformat": 4, 643 | "nbformat_minor": 4 644 | } --------------------------------------------------------------------------------