├── .gitignore ├── LICENSE ├── README.md ├── book ├── .gitignore ├── book.toml └── src │ ├── SUMMARY.md │ ├── glossary.md │ ├── preface.md │ ├── week1-overview.md │ ├── week2-overview.md │ └── week3-overview.md ├── main.py ├── poetry.lock ├── pyproject.toml ├── src └── tiny_llm │ ├── __init__.py │ ├── funcs.py │ └── layers.py └── tests ├── test_attention.py ├── test_funcs.py ├── test_rope.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tiny-llm 2 | 3 | Still WIP and in very early stage. A tutorial on LLM serving using MLX for system engineers. The codebase 4 | is solely (almost!) based on MLX array/matrix APIs without any high-level neural network APIs, so that we 5 | can build the model serving infrastructure from scratch and dig into the optimizations. 
6 | 7 | We test the implementation against PyTorch's CPU implementation to ensure correctness. The main codebase uses MLX 8 | instead of PyTorch because nowadays it's easier to get an Apple Silicon MacBook than an NVIDIA GPU. In theory you can 9 | implement everything using PyTorch tensor APIs, but we don't have the test infra to support that. 10 | 11 | (TODO: maybe we should test against MLX? PyTorch APIs sometimes don't align with MLX; but I also want to ensure the computation 12 | precision is enough to load any model directly from PyTorch tensors without converting to MLX format.) 13 | 14 | The goal is to learn the techniques behind efficiently serving an LLM (i.e., the Qwen2 models). 15 | 16 | * Week 1: serve Qwen2 with purely Python APIs. No fancy optimizations, just Python. 17 | * Week 2: optimizations, implement C++/Metal custom kernels to make the model run faster. 18 | * Week 3: more optimizations, batch the requests to serve the model with high throughput. 19 | 20 | TBD: implement a leaderboard service? 21 | 22 | ## Usage 23 | 24 | ```bash 25 | poetry install 26 | poetry run pytest 27 | poetry run python main.py 28 | ``` 29 | 30 | ## Week 1: LLM from Scratch 31 | 32 | This tutorial works a little differently: instead of explaining everything myself, I collect the materials, blog posts, and source 33 | code I read while implementing the components. The repo provides API specifications and test infra to compare your results with 34 | PyTorch/MLX. You may read the materials and implement the components :) 35 | 36 | ### Day 1: Attention is All You Need 37 | 38 | Implement `scaled_dot_product_attention`. The function takes key, value, and query of the same dimensions. 39 | 40 | ``` 41 | K: N.. x H x L x E 42 | V: N.. x H x L x E 43 | Q: N.. x H x L x E 44 | ``` 45 | 46 | Where `N..` is zero or more batch dimensions. Within each batch, `H` is the number of heads, 47 | `L` is the sequence length, and `E` is the embedding/hidden size. The output is `softmax(Q K^T * scale + mask) V`, where `scale` defaults to `1/sqrt(E)`. 48 | 49 | You may use the `softmax` provided by MLX for now and implement it yourself in week 2. 50 | 51 | **References** 52 | 53 | * Annotated Transformer https://nlp.seas.harvard.edu/annotated-transformer/ 54 | * PyTorch API (the case where enable_gqa=False) https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html 55 | * MLX API https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html 56 | * https://arxiv.org/abs/1706.03762 57 | 58 | Implement `MultiHeadAttention`. The layer takes a batch of vectors `x`, maps them through the K, V, Q weight matrices, and 59 | uses the attention function we implemented on day 1 to compute the result. The output then needs to be projected with the O 60 | weight matrix. You will also need to implement the `linear` function. 61 | 62 | ``` 63 | x: N x L x D 64 | D = num_heads x head_dim 65 | ``` 66 | 67 | **References** 68 | 69 | * Annotated Transformer https://nlp.seas.harvard.edu/annotated-transformer/ 70 | * PyTorch API https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html 71 | * MLX API https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html 72 | 73 | ### Day 2: RoPE Embedding 74 | 75 | Note that there are traditional and non-traditional variants of RoPE; they differ in how the dimensions of each head are paired up for rotation (see the sketch below).
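As a rough illustration of the difference (a sketch only, not the reference solution for this repo; it assumes `x` of shape `(seq_len, head_dim)` and precomputed `cos`/`sin` tables of shape `(seq_len, head_dim // 2)`, and the helper names are made up for this example):

```python
import mlx.core as mx

def rope_traditional(x, cos, sin):
    # traditional: adjacent dimensions (2i, 2i + 1) form a pair and are rotated together
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = pairs[..., 0], pairs[..., 1]
    out = mx.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return out.reshape(*x.shape)

def rope_half(x, cos, sin):
    # non-traditional: dimensions (i, i + head_dim // 2) form a pair instead
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return mx.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
```

Both apply the same per-pair rotation; only the pairing (and thus the memory layout of the result) differs, so the convention you implement has to match the one the model checkpoint was trained with.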
76 | 77 | **References** 78 | 79 | * https://pytorch.org/torchtune/stable/generated/torchtune.modules.RotaryPositionalEmbeddings.html 80 | * https://github.com/pytorch/torchtune/blob/main/torchtune/modules/position_embeddings.py 81 | * https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py 82 | * https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.RoPE.html 83 | * https://arxiv.org/abs/2104.09864 84 | 85 | ### Day 3: Grouped Query Attention 86 | 87 | The Qwen2 models use Grouped Query Attention (GQA). GQA uses fewer key/value heads than query heads: several query heads share the same key/value head, which shrinks the K/V projection weights and the KV cache. 88 | 89 | **References** 90 | 91 | * Qwen layers implementation in mlx-lm https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py 92 | * PyTorch API (the case where enable_gqa=True) https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html 93 | * torchtune.modules.MultiHeadAttention https://pytorch.org/torchtune/0.3/generated/torchtune.modules.MultiHeadAttention.html 94 | * https://arxiv.org/abs/2305.13245v1 95 | 96 | ### Day 4: RMSNorm and MLP 97 | 98 | Implement `RMSNorm` and the gated MLP used by Qwen2. Note that RMSNorm needs to accumulate in float32 to keep enough precision. 99 | 100 | * Qwen layers implementation in mlx-lm https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py 101 | * SiLU https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html 102 | * RMSNorm (note that it needs to accumulate at float32) 103 | 104 | ### Day 5: Transformer Block 105 | 106 | * Qwen layers implementation in mlx-lm https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py 107 | 108 | ### Day 6: Load the Model 109 | 110 | We will use mlx-lm's loader to load the model. We will _steal_ the loaded parameters from the mlx model and 111 | plug them into our own operators. 112 | 113 | ### Day 7: Generate Responses 114 | 115 | * Qwen layers implementation in mlx-lm https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py 116 | 117 | Run `python main.py` and it should give you a reasonable response. 118 | 119 | On my M4 Pro Mac Mini, my implementation gives 17 tokens per sec on Metal, versus 50 tokens per sec from the mlx-lm 120 | Qwen2 implementation. Sadly, it also takes 4x the memory of the mlx-lm components, as it does not support computation 121 | over quantized parameters. 122 | 123 | ## Week 2 124 | 125 | Quantization, implement softmax/linear/silu kernels, implement attention kernels, key-value cache and compression, attention masks, prompt cache. 126 | 127 | ## Week 3 128 | 129 | Continuous batching, OpenAPI HTTP endpoint, integrate with other services.
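To give a rough flavor of the continuous batching idea before week 3 (this is only a conceptual sketch, not the design we will end up with; `model_step` is a hypothetical callback that decodes one new token for every running sequence):

```python
from dataclasses import dataclass, field

@dataclass
class Request:
    prompt_tokens: list[int]
    max_new_tokens: int
    generated: list[int] = field(default_factory=list)

def continuous_batching_loop(pending: list[Request], model_step, eos_token_id: int, max_batch: int = 8):
    running, finished = [], []
    while pending or running:
        # admit new requests whenever there is a free slot in the batch
        while pending and len(running) < max_batch:
            running.append(pending.pop(0))
        # one decode step for every running sequence, producing one new token each
        batch = [r.prompt_tokens + r.generated for r in running]
        for r, tok in zip(running, model_step(batch)):
            r.generated.append(tok)
        # retire finished sequences immediately so their slots can be reused
        still_running = []
        for r in running:
            done = r.generated[-1] == eos_token_id or len(r.generated) >= r.max_new_tokens
            (finished if done else still_running).append(r)
        running = still_running
    return finished
```

The point is that requests enter and leave the batch between decode steps instead of waiting for the whole batch to finish, which keeps the accelerator busy and improves throughput under mixed request lengths.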
130 | -------------------------------------------------------------------------------- /book/.gitignore: -------------------------------------------------------------------------------- 1 | book 2 | -------------------------------------------------------------------------------- /book/book.toml: -------------------------------------------------------------------------------- 1 | [book] 2 | authors = ["Alex Chi"] 3 | language = "en" 4 | multilingual = false 5 | src = "src" 6 | title = "Tiny LLM - LLM Serving in a Week" 7 | 8 | [preprocessor.toc] 9 | command = "mdbook-toc" 10 | renderer = ["html"] 11 | 12 | [output.html] 13 | git-repository-url = "https://github.com/skyzh/tiny-llm" 14 | -------------------------------------------------------------------------------- /book/src/SUMMARY.md: -------------------------------------------------------------------------------- 1 | # LLM Serving in a Week 2 | 3 | [Preface](./preface.md) 4 | [Setting Up the Environment]() 5 | 6 | --- 7 | 8 | - [Week 1: From Matmul to Text]() 9 | - [Attention and Multi-Head Attention, Linear]() 10 | - [Positional Embeddings and RoPE]() 11 | - [Grouped/Multi Query Attention, Embedding, Silu]() 12 | - [Multilayer Perceptron Layer and Transformer, RMSNorm]() 13 | - [Wiring and Loading the Model, Dequantize]() 14 | - [Tokenize and Generating Response]() 15 | 16 | - [Week 2: Optimizing]() 17 | 18 | - [Week 3: Serving]() 19 | 20 | --- 21 | 22 | [Glossary Index](./glossary.md) 23 | -------------------------------------------------------------------------------- /book/src/glossary.md: -------------------------------------------------------------------------------- 1 | # Glossary Index 2 | 3 | The functionality is covered in which days? 4 | -------------------------------------------------------------------------------- /book/src/preface.md: -------------------------------------------------------------------------------- 1 | # Preface 2 | 3 | LLM serving in a week for systems engineers! 
4 | -------------------------------------------------------------------------------- /book/src/week1-overview.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyzh/tiny-llm/c9be20df58b98dc6d2e7846364091439539ed712/book/src/week1-overview.md -------------------------------------------------------------------------------- /book/src/week2-overview.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyzh/tiny-llm/c9be20df58b98dc6d2e7846364091439539ed712/book/src/week2-overview.md -------------------------------------------------------------------------------- /book/src/week3-overview.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyzh/tiny-llm/c9be20df58b98dc6d2e7846364091439539ed712/book/src/week3-overview.md -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from mlx_lm import load, generate 2 | from mlx_lm.sample_utils import make_sampler, make_logits_processors 3 | from tiny_llm.layers import Qwen2Model 4 | import mlx.core as mx 5 | 6 | with mx.stream(mx.gpu): 7 | mlx_model, tokenizer = load( 8 | "Qwen/Qwen2-7B-Instruct-MLX", 9 | tokenizer_config={"eos_token": "<|im_end|>"}, 10 | model_config={"tie_word_embeddings": False, "rope_traditional": True}, 11 | ) 12 | tiny_llm_model = Qwen2Model(mlx_model) 13 | 14 | prompt = "Give me a short introduction to large language model." 15 | messages = [ 16 | {"role": "system", "content": "You are a helpful assistant."}, 17 | {"role": "user", "content": prompt}, 18 | ] 19 | prompt = tokenizer.apply_chat_template( 20 | messages, tokenize=False, add_generation_prompt=True 21 | ) 22 | def _step(model, y, offset): 23 | logits = model(y[None], offset) 24 | logits = logits[:, -1, :] 25 | logprobs = logits - mx.logsumexp(logits, keepdims=True) 26 | sampler = lambda x: mx.argmax(x, axis=-1) 27 | y = sampler(logprobs) 28 | return y, logprobs.squeeze(0) 29 | # prefill with the prompt 30 | tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False)) 31 | offset = tokens.size 32 | detokenizer = tokenizer.detokenizer 33 | detokenizer.reset() 34 | # generate 35 | while True: 36 | token, _ = _step(tiny_llm_model, tokens, offset) 37 | tokens = mx.concat([tokens, token]) 38 | if token.item() == tokenizer.eos_token_id: 39 | break 40 | detokenizer.add_token(token.item()) 41 | print(detokenizer.last_segment, end="", flush=True) 42 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "tiny-llm" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Alex Chi Z "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.12" 10 | mlx = "^0.25.0" 11 | torch = "^2.6.0" 12 | mlx-lm = "^0.23.0" 13 | torchtune = "^0.6.1" 14 | torchao = "^0.10.0" 15 | 16 | 17 | [tool.poetry.group.dev.dependencies] 18 | pytest = "^8.3.5" 19 | numpy = "^2.2.4" 20 | ruff = "^0.11.6" 21 | 22 | [build-system] 23 | requires = ["poetry-core"] 24 | build-backend = "poetry.core.masonry.api" 25 | 26 | [project] 27 | name = "tiny-llm" 28 | version = "0.1.0" 29 | 30 | [tool.pytest.ini_options] 31 | addopts = [ 32 | "--import-mode=importlib", 33 | ] 34 | pythonpath = "src" 35 | 
-------------------------------------------------------------------------------- /src/tiny_llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyzh/tiny-llm/c9be20df58b98dc6d2e7846364091439539ed712/src/tiny_llm/__init__.py -------------------------------------------------------------------------------- /src/tiny_llm/funcs.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import math 3 | 4 | 5 | def softmax(x: mx.array, axis: int) -> mx.array: 6 | # TODO: manual implementation 7 | return mx.softmax(x, axis=axis) 8 | 9 | 10 | def scaled_dot_product_attention( 11 | query: mx.array, 12 | key: mx.array, 13 | value: mx.array, 14 | scale: float | None = None, 15 | mask: mx.array | None = None, 16 | stream: mx.Stream | mx.Device | None = None, 17 | ) -> mx.array: 18 | """ 19 | Compute scaled dot-product attention. 20 | 21 | query: batch_size x 22 | """ 23 | factor = mx.rsqrt(query.shape[-1]) if scale is None else scale 24 | scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor 25 | if mask is not None: 26 | scores = scores + mask 27 | return mx.matmul(softmax(scores, axis=-1), value) 28 | 29 | 30 | def scaled_dot_product_attention_grouped( 31 | query: mx.array, 32 | key: mx.array, 33 | value: mx.array, 34 | scale: float | None = None, 35 | mask: mx.array | None = None, 36 | ) -> mx.array: 37 | """ 38 | Compute scaled dot-product attention. 39 | 40 | query: batch_size x 41 | """ 42 | factor = mx.rsqrt(query.shape[-1]) if scale is None else scale 43 | expected_shape = query.shape 44 | query = query.reshape(-1, query.shape[-3], query.shape[-2], query.shape[-1]) 45 | key = key.reshape(-1, key.shape[-3], key.shape[-2], key.shape[-1]) 46 | value = value.reshape(-1, value.shape[-3], value.shape[-2], value.shape[-1]) 47 | B, H_q, L, E = query.shape 48 | _, H, S, _ = key.shape 49 | assert H_q % H == 0 50 | n_repeats = H_q // H 51 | query = query.reshape((B, H, n_repeats, L, E)) 52 | key = key.reshape((B, H, 1, S, E)) 53 | value = value.reshape((B, H, 1, S, E)) 54 | scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor 55 | if mask is not None: 56 | mask = mask.reshape(-1, H, n_repeats, mask.shape[-2], mask.shape[-1]) 57 | scores = scores + mask 58 | result = mx.matmul( 59 | softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype), value 60 | ) 61 | return result.reshape(expected_shape) 62 | 63 | 64 | def linear( 65 | x: mx.array, 66 | w: mx.array, 67 | bias: mx.array | None = None, 68 | ) -> mx.array: 69 | if bias is not None: 70 | return mx.matmul(x, w.T) + bias 71 | else: 72 | return mx.matmul(x, w.T) 73 | 74 | 75 | def silu(x: mx.array) -> mx.array: 76 | return x / (1 + mx.exp(-x)) 77 | -------------------------------------------------------------------------------- /src/tiny_llm/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, Optional 3 | 4 | import mlx.core as mx 5 | from mlx_lm.models.cache import KVCache 6 | from .funcs import ( 7 | linear, 8 | scaled_dot_product_attention, 9 | scaled_dot_product_attention_grouped, 10 | silu, 11 | ) 12 | 13 | # TODO: add license for those heavily based on mlx-lm/PyTorch 14 | 15 | 16 | class MultiHeadAttention: 17 | def __init__( 18 | self, 19 | hidden_size: int, 20 | num_heads: int, 21 | wq: mx.array, 22 | wk: mx.array, 23 | wv: mx.array, 24 | wo: mx.array, 25 | ): 26 | self.hidden_size = hidden_size 27 | self.num_heads = 
num_heads 28 | assert hidden_size % num_heads == 0 29 | self.head_dim = hidden_size // num_heads 30 | self.scale = mx.rsqrt(self.head_dim) 31 | assert wq.shape == (hidden_size, num_heads * self.head_dim) 32 | assert wk.shape == (hidden_size, num_heads * self.head_dim) 33 | assert wv.shape == (hidden_size, num_heads * self.head_dim) 34 | assert wo.shape == (num_heads * self.head_dim, hidden_size) 35 | self.wq = wq 36 | self.wk = wk 37 | self.wv = wv 38 | self.wo = wo 39 | 40 | def __call__( 41 | self, 42 | query: mx.array, 43 | key: mx.array, 44 | value: mx.array, 45 | mask: mx.array | None = None, 46 | ) -> mx.array: 47 | n_batches = query.shape[0] 48 | batch_size = query.shape[1] 49 | projection_q = ( 50 | linear(query, self.wq) 51 | .reshape(n_batches, self.num_heads * batch_size, self.head_dim) 52 | .transpose(1, 0, 2) 53 | ) 54 | projection_k = ( 55 | linear(key, self.wk) 56 | .reshape(n_batches, self.num_heads * batch_size, self.head_dim) 57 | .transpose(1, 0, 2) 58 | ) 59 | projection_v = ( 60 | linear(value, self.wv) 61 | .reshape(n_batches, self.num_heads * batch_size, self.head_dim) 62 | .transpose(1, 0, 2) 63 | ) 64 | x = scaled_dot_product_attention( 65 | projection_q, 66 | projection_k, 67 | projection_v, 68 | scale=self.scale, 69 | mask=mask, 70 | ) 71 | x = x.transpose(1, 0, 2).reshape(n_batches, batch_size, self.hidden_size) 72 | return linear(x, self.wo) 73 | 74 | 75 | class Qwen2MultiHeadAttention: 76 | def __init__( 77 | self, 78 | hidden_size: int, 79 | num_heads: int, 80 | num_kv_heads: int, 81 | wq: mx.array, 82 | wk: mx.array, 83 | wv: mx.array, 84 | wo: mx.array, 85 | bq: mx.array, 86 | bk: mx.array, 87 | bv: mx.array, 88 | max_seq_len: int = 32768, 89 | theta: int = 1000000, 90 | ): 91 | self.hidden_size = hidden_size 92 | self.num_heads = num_heads 93 | self.num_kv_heads = num_kv_heads 94 | assert hidden_size % num_heads == 0, ( 95 | f"hidden_size {hidden_size} must be divisible by num_heads {num_heads}" 96 | ) 97 | assert num_heads % num_kv_heads == 0, ( 98 | f"num_heads {num_heads} must be divisible by num_kv_heads {num_kv_heads}" 99 | ) 100 | self.head_dim = hidden_size // num_heads 101 | self.scale = mx.rsqrt(self.head_dim) 102 | self.wq = wq 103 | self.wk = wk 104 | self.wv = wv 105 | self.wo = wo 106 | self.bq = bq 107 | self.bk = bk 108 | self.bv = bv 109 | self.rope = RoPE(self.head_dim, max_seq_len, theta) 110 | 111 | def __call__( 112 | self, 113 | x: mx.array, 114 | offset: int, 115 | mask: mx.array | None = None, 116 | cache: KVCache | None = None, 117 | ) -> mx.array: 118 | B, L, _ = x.shape 119 | orig_dtype = x.dtype 120 | projection_q = ( 121 | linear(x, self.wq, bias=self.bq) 122 | .reshape(B, L, self.num_heads, self.head_dim) 123 | .astype(mx.float32) 124 | ) 125 | projection_k = ( 126 | linear(x, self.wk, bias=self.bk) 127 | .reshape(B, L, self.num_kv_heads, self.head_dim) 128 | .astype(mx.float32) 129 | ) 130 | projection_v = ( 131 | linear(x, self.wv, bias=self.bv) 132 | .reshape(B, L, self.num_kv_heads, self.head_dim) 133 | .astype(mx.float32) 134 | ) 135 | # offset = cache.offset 136 | projection_q = self.rope(projection_q, offset=slice(offset, offset + L)) 137 | projection_k = self.rope(projection_k, offset=slice(offset, offset + L)) 138 | projection_q = projection_q.transpose(0, 2, 1, 3) 139 | projection_k = projection_k.transpose(0, 2, 1, 3) 140 | projection_v = projection_v.transpose(0, 2, 1, 3) 141 | # TODO: it is possible to get a sensible result without using a kv-cache? Otherwise we have to include kv-cache in week 1. 
142 | # mlx-lm's KvCache seems to do more than just caching, we could extract something out of it. 143 | # projection_k, projection_v = cache.update_and_fetch(projection_k, projection_v) 144 | assert ( 145 | projection_k.dtype == mx.float32 146 | ) # TODO: can we use float16? also a test framework to ensure all data types are casted correctly. 147 | assert projection_v.dtype == mx.float32 148 | x = scaled_dot_product_attention_grouped( 149 | projection_q, 150 | projection_k, 151 | projection_v, 152 | scale=self.scale, 153 | mask=mask, 154 | ).astype(orig_dtype) 155 | x = x.transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size) 156 | return linear(x, self.wo) 157 | 158 | 159 | class RoPE: 160 | def __init__( 161 | self, 162 | dims: int, 163 | seq_len: int, 164 | base: int = 10000, 165 | traditional: bool = False, 166 | ): 167 | self.dims = dims 168 | self.seq_len = seq_len 169 | half_dims = dims // 2 170 | inner = mx.arange(0, half_dims, dtype=mx.float32) / half_dims 171 | freqs = mx.power(base, -inner) 172 | t = mx.arange(seq_len) 173 | freqs = mx.outer(t, freqs) 174 | self.cos_freqs = mx.cos(freqs) 175 | self.sin_freqs = mx.sin(freqs) 176 | self.base = base 177 | self.half_dims = half_dims 178 | self.traditional = traditional 179 | def __call__( 180 | self, x: mx.array, offset: slice | None = None 181 | ) -> tuple[mx.array, mx.array]: 182 | # input x: (b, s, n_heads, head_dim) 183 | *N, S, H, D = x.shape 184 | # if offset is not None: 185 | # assert len(offset) == S, f"offset {len(offset)} must be of length {s}" 186 | cos_basis = ( 187 | self.cos_freqs[:S, :] if offset is None else self.cos_freqs[offset, :] 188 | ) 189 | sin_basis = ( 190 | self.sin_freqs[:S, :] if offset is None else self.sin_freqs[offset, :] 191 | ) 192 | # reshape x: (b, s, n_heads, head_dim // 2, 2) 193 | if self.traditional: 194 | x = x.reshape(*N, S, H, self.half_dims, 2) 195 | x1 = x[..., 0] 196 | x2 = x[..., 1] 197 | else: 198 | x1 = x[..., 0:self.half_dims] 199 | x2 = x[..., self.half_dims:self.dims] 200 | # reshape basis: (1, s, 1, dims // 2, 2) 201 | cos_basis = cos_basis.reshape(S, 1, self.half_dims) 202 | sin_basis = sin_basis.reshape(S, 1, self.half_dims) 203 | # manually doing complex number multiplication.. 
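        # i.e. (x1 + i*x2) * (cos t + i*sin t) = (x1*cos t - x2*sin t) + i*(x2*cos t + x1*sin t); the next two lines compute the real and imaginary parts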
204 | real = mx.multiply(x1, cos_basis) - mx.multiply(x2, sin_basis) 205 | imag = mx.multiply(x2, cos_basis) + mx.multiply(x1, sin_basis) 206 | if self.traditional: 207 | y = mx.stack([real, imag], axis=-1) 208 | y = y.reshape(*N, S, H, D) 209 | else: 210 | y = mx.concat([real, imag], axis=-1) 211 | y = y.reshape(*N, S, H, D) 212 | return y 213 | 214 | 215 | class Qwen2MLP: 216 | def __init__( 217 | self, 218 | dim: int, 219 | hidden_dim: int, 220 | w_gate: mx.array, 221 | w_up: mx.array, 222 | w_down: mx.array, 223 | ): 224 | self.dim = dim 225 | self.hidden_dim = hidden_dim 226 | self.w_gate = w_gate 227 | self.w_up = w_up 228 | self.w_down = w_down 229 | 230 | def __call__(self, x: mx.array) -> mx.array: 231 | return linear(silu(linear(x, self.w_gate)) * linear(x, self.w_up), self.w_down) 232 | 233 | 234 | class RMSNorm: 235 | def __init__(self, dim: int, weight: mx.array, eps: float = 1e-5): 236 | self.dim = dim 237 | self.eps = eps 238 | self.weight = weight.astype(mx.float32) 239 | 240 | def __call__(self, x: mx.array) -> mx.array: 241 | # TODO: tests to ensure the precision of this function 242 | orig_dtype = x.dtype 243 | x = x.astype(mx.float32) 244 | return ( 245 | self.weight 246 | * x 247 | * mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + self.eps) 248 | ).astype(orig_dtype) 249 | 250 | 251 | class Qwen2TransformerBlock: 252 | def __init__( 253 | self, 254 | num_attention_heads: int, 255 | num_kv_heads: int, 256 | hidden_size: int, 257 | intermediate_size: int, 258 | rms_norm_eps: float, 259 | wq: mx.array, 260 | wk: mx.array, 261 | wv: mx.array, 262 | wo: mx.array, 263 | bq: mx.array, 264 | bk: mx.array, 265 | bv: mx.array, 266 | w_gate: mx.array, 267 | w_up: mx.array, 268 | w_down: mx.array, 269 | w_input_layernorm: mx.array, 270 | w_post_attention_layernorm: mx.array, 271 | max_seq_len: int = 32768, 272 | theta: int = 1000000, 273 | ): 274 | self.num_attention_heads = num_attention_heads 275 | self.hidden_size = hidden_size 276 | self.mlp = Qwen2MLP(hidden_size, intermediate_size, w_gate, w_up, w_down) 277 | self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, eps=rms_norm_eps) 278 | self.post_attention_layernorm = RMSNorm( 279 | hidden_size, w_post_attention_layernorm, eps=rms_norm_eps 280 | ) 281 | self.self_attn = Qwen2MultiHeadAttention( 282 | num_heads=num_attention_heads, 283 | hidden_size=hidden_size, 284 | num_kv_heads=num_kv_heads, 285 | wq=wq, 286 | wk=wk, 287 | wv=wv, 288 | wo=wo, 289 | bq=bq, 290 | bk=bk, 291 | bv=bv, 292 | max_seq_len=max_seq_len, 293 | theta=theta, 294 | ) 295 | 296 | def __call__( 297 | self, 298 | x: mx.array, 299 | offset: int, 300 | mask: mx.array | None = None, 301 | cache: KVCache | None = None, 302 | ) -> mx.array: 303 | r = self.self_attn(self.input_layernorm(x), offset, mask, cache) 304 | h = x + r 305 | r = self.mlp(self.post_attention_layernorm(h)) 306 | out = h + r 307 | return out 308 | 309 | 310 | def dequantize_linear(mx_layer: Any) -> mx.array: 311 | w = mx.dequantize( 312 | mx_layer.weight, 313 | mx_layer.scales, 314 | mx_layer.biases, 315 | mx_layer.group_size, 316 | mx_layer.bits, 317 | ) 318 | return w 319 | 320 | class Embedding: 321 | def __init__(self, vocab_size: int, embedding_dim: int, weight: mx.array): 322 | self.vocab_size = vocab_size 323 | self.embedding_dim = embedding_dim 324 | self.weight = weight 325 | 326 | def __call__(self, x: mx.array) -> mx.array: 327 | return self.weight[x, :] 328 | 329 | class Qwen2Model: 330 | def __init__( 331 | self, 332 | mlx_model: Any, 333 | ): 334 | 
self.num_hidden_layers = mlx_model.args.num_hidden_layers 335 | self.hidden_size = mlx_model.args.hidden_size 336 | self.vocab_size = mlx_model.args.vocab_size 337 | precision = mx.float16 338 | self.precision = precision 339 | 340 | self.embedding = Embedding( 341 | vocab_size=self.vocab_size, 342 | embedding_dim=self.hidden_size, 343 | weight=dequantize_linear(mlx_model.model.embed_tokens).astype(precision), 344 | ) 345 | self.layers_inner = [] 346 | 347 | 348 | for i in range(mlx_model.args.num_hidden_layers): 349 | wq = dequantize_linear(mlx_model.model.layers[i].self_attn.q_proj) 350 | wk = dequantize_linear(mlx_model.model.layers[i].self_attn.k_proj) 351 | wv = dequantize_linear(mlx_model.model.layers[i].self_attn.v_proj) 352 | wo = dequantize_linear(mlx_model.model.layers[i].self_attn.o_proj) 353 | w_gate = dequantize_linear(mlx_model.model.layers[i].mlp.gate_proj) 354 | w_up = dequantize_linear(mlx_model.model.layers[i].mlp.up_proj) 355 | w_down = dequantize_linear(mlx_model.model.layers[i].mlp.down_proj) 356 | 357 | layer = Qwen2TransformerBlock( 358 | num_attention_heads=mlx_model.args.num_attention_heads, 359 | num_kv_heads=mlx_model.args.num_key_value_heads, 360 | hidden_size=mlx_model.args.hidden_size, 361 | intermediate_size=mlx_model.args.intermediate_size, 362 | rms_norm_eps=mlx_model.args.rms_norm_eps, 363 | wq=wq.astype(precision), 364 | wk=wk.astype(precision), 365 | wv=wv.astype(precision), 366 | wo=wo.astype(precision), 367 | bq=mlx_model.model.layers[i].self_attn.q_proj.bias.astype(precision), 368 | bk=mlx_model.model.layers[i].self_attn.k_proj.bias.astype(precision), 369 | bv=mlx_model.model.layers[i].self_attn.v_proj.bias.astype(precision), 370 | w_gate=w_gate.astype(precision), 371 | w_up=w_up.astype(precision), 372 | w_down=w_down.astype(precision), 373 | w_input_layernorm=mlx_model.model.layers[ 374 | i 375 | ].input_layernorm.weight.astype(precision), 376 | w_post_attention_layernorm=mlx_model.model.layers[ 377 | i 378 | ].post_attention_layernorm.weight.astype(precision), 379 | max_seq_len=mlx_model.args.max_position_embeddings, 380 | theta=mlx_model.args.rope_theta, 381 | ) 382 | self.layers_inner.append(layer) 383 | self.norm = RMSNorm( 384 | mlx_model.args.hidden_size, 385 | weight=mlx_model.model.norm.weight.astype(precision), 386 | eps=mlx_model.args.rms_norm_eps, 387 | ) 388 | self.w_lm_head = dequantize_linear(mlx_model.lm_head) 389 | self.mlx_model = mlx_model 390 | 391 | def __call__( 392 | self, 393 | inputs: mx.array, 394 | offset: int, 395 | mask: mx.array | None = None, 396 | cache: KVCache | None = None, 397 | ) -> mx.array: 398 | h = self.embedding(inputs) 399 | for layer in range(self.num_hidden_layers): 400 | h = self.layers_inner[layer](h, offset, None, cache[layer] if cache else None) 401 | h = self.norm(h) 402 | return linear(h, self.w_lm_head) 403 | 404 | def sanitize(self, weights: dict): 405 | assert False, "not implemented" 406 | 407 | @property 408 | def layers(self): 409 | return self.layers_inner 410 | -------------------------------------------------------------------------------- /tests/test_attention.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import mlx.core as mx 3 | import torch 4 | from tiny_llm.funcs import * 5 | from tiny_llm.layers import * 6 | import numpy as np 7 | from .utils import * 8 | 9 | 10 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) 11 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) 12 | def 
test_softmax(stream: mx.Stream, precision: np.dtype): 13 | with mx.stream(stream): 14 | BATCH_SIZE = 10 15 | DIM = 10 16 | for _ in range(100): 17 | x = np.random.rand(BATCH_SIZE, DIM).astype(precision) 18 | user_output = softmax(mx.array(x), axis=-1) 19 | reference_output = torch.nn.functional.softmax( 20 | torch.tensor(x, device=TORCH_DEVICE), dim=-1 21 | ) 22 | assert_allclose(user_output, reference_output, precision=precision) 23 | 24 | 25 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) 26 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) 27 | def test_attention(stream: mx.Stream, precision: np.dtype): 28 | with mx.stream(stream): 29 | BATCH_SIZE = 3 30 | DIM_N = 4 31 | DIM_M = 5 32 | for _ in range(100): 33 | query = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision) 34 | key = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision) 35 | value = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision) 36 | reference_output = torch.nn.functional.scaled_dot_product_attention( 37 | torch.tensor(query, device=TORCH_DEVICE), 38 | torch.tensor(key, device=TORCH_DEVICE), 39 | torch.tensor(value, device=TORCH_DEVICE), 40 | ) 41 | user_output = scaled_dot_product_attention( 42 | mx.array(query), 43 | mx.array(key), 44 | mx.array(value), 45 | ) 46 | assert_allclose(user_output, reference_output, precision=precision) 47 | 48 | 49 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) 50 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) 51 | @pytest.mark.parametrize( 52 | "qkv_shape", [True, False], ids=["with_seq_len", "without_seq_len"] 53 | ) 54 | def test_attention_with_mask(stream: mx.Stream, precision: np.dtype, qkv_shape: bool): 55 | with mx.stream(stream): 56 | BATCH_SIZE = 3 57 | SEQ_LEN = 10 58 | DIM_N = 4 59 | DIM_M = 5 60 | if qkv_shape: 61 | qkv_shape = (BATCH_SIZE, SEQ_LEN, DIM_N, DIM_M) 62 | mask_shape = (BATCH_SIZE, SEQ_LEN, DIM_N, DIM_N) 63 | else: 64 | qkv_shape = (BATCH_SIZE, DIM_N, DIM_M) 65 | mask_shape = (BATCH_SIZE, DIM_N, DIM_N) 66 | for _ in range(100): 67 | query = np.random.rand(*qkv_shape).astype(precision) 68 | key = np.random.rand(*qkv_shape).astype(precision) 69 | value = np.random.rand(*qkv_shape).astype(precision) 70 | scale = 0.8 71 | mask = np.random.rand(*mask_shape).astype(precision) 72 | reference_output = torch.nn.functional.scaled_dot_product_attention( 73 | torch.tensor(query, device=TORCH_DEVICE), 74 | torch.tensor(key, device=TORCH_DEVICE), 75 | torch.tensor(value, device=TORCH_DEVICE), 76 | scale=scale, 77 | attn_mask=torch.tensor(mask, device=TORCH_DEVICE), 78 | ) 79 | user_output = scaled_dot_product_attention( 80 | mx.array(query), 81 | mx.array(key), 82 | mx.array(value), 83 | scale=scale, 84 | mask=mx.array(mask), 85 | ) 86 | assert_allclose(user_output, reference_output, precision=precision) 87 | 88 | 89 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) 90 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) 91 | def test_multi_head_attention(stream: mx.Stream, precision: np.dtype): 92 | with mx.stream(stream): 93 | BATCH_SIZE = 7 94 | DIM_N = 11 95 | DIM_M = 9 96 | NUM_HEADS = 3 97 | for _ in range(100): 98 | query = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision) 99 | key = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision) 100 | value = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision) 101 | q_proj_weight = np.random.rand(DIM_M, 
DIM_M).astype(precision) 102 | k_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision) 103 | v_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision) 104 | out_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision) 105 | mask = np.random.rand(DIM_N * NUM_HEADS, BATCH_SIZE, BATCH_SIZE).astype( 106 | precision 107 | ) 108 | reference_output, _ = torch.nn.functional.multi_head_attention_forward( 109 | torch.tensor(query, device=TORCH_DEVICE), 110 | torch.tensor(key, device=TORCH_DEVICE), 111 | torch.tensor(value, device=TORCH_DEVICE), 112 | num_heads=NUM_HEADS, 113 | q_proj_weight=torch.tensor(q_proj_weight, device=TORCH_DEVICE), 114 | k_proj_weight=torch.tensor(k_proj_weight, device=TORCH_DEVICE), 115 | v_proj_weight=torch.tensor(v_proj_weight, device=TORCH_DEVICE), 116 | out_proj_weight=torch.tensor(out_proj_weight, device=TORCH_DEVICE), 117 | embed_dim_to_check=DIM_M, 118 | in_proj_weight=None, 119 | in_proj_bias=None, 120 | bias_k=None, 121 | bias_v=None, 122 | add_zero_attn=False, 123 | dropout_p=0.0, 124 | out_proj_bias=None, 125 | use_separate_proj_weight=True, 126 | attn_mask=torch.tensor(mask, device=TORCH_DEVICE), 127 | ) 128 | user_output = MultiHeadAttention( 129 | DIM_M, 130 | NUM_HEADS, 131 | mx.array(q_proj_weight), 132 | mx.array(k_proj_weight), 133 | mx.array(v_proj_weight), 134 | mx.array(out_proj_weight), 135 | )( 136 | mx.array(query), 137 | mx.array(key), 138 | mx.array(value), 139 | mask=mx.array(mask), 140 | ) 141 | assert_allclose(user_output, reference_output, precision=precision) 142 | 143 | 144 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) 145 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) 146 | @pytest.mark.parametrize( 147 | "batch_dimension", [0, 1, 2], ids=["batch_0", "batch_1", "batch_2"] 148 | ) 149 | @pytest.mark.parametrize("scale", [None, 0.8]) 150 | def test_attention_grouped( 151 | stream: mx.Stream, precision: np.dtype, batch_dimension: int, scale: float | None 152 | ): 153 | with mx.stream(stream): 154 | H_q = 18 155 | H = 6 156 | L = 7 157 | E = 5 158 | S = 3 159 | BATCH = 10 160 | BATCH_2 = 2 161 | if batch_dimension == 0: 162 | q_shape = (H_q, L, E) 163 | kv_shape = (H, S, E) 164 | mask_shape = (H_q, L, S) 165 | elif batch_dimension == 1: 166 | q_shape = (BATCH, H_q, L, E) 167 | kv_shape = (BATCH, H, S, E) 168 | mask_shape = (BATCH, H_q, L, S) 169 | elif batch_dimension == 2: 170 | q_shape = (BATCH_2, BATCH, H_q, L, E) 171 | kv_shape = (BATCH_2, BATCH, H, S, E) 172 | mask_shape = (BATCH_2, BATCH, H_q, L, S) 173 | for _ in range(100): 174 | query = np.random.rand(*q_shape).astype(precision) 175 | key = np.random.rand(*kv_shape).astype(precision) 176 | value = np.random.rand(*kv_shape).astype(precision) 177 | mask = np.random.rand(*mask_shape).astype(precision) 178 | reference_output = torch.nn.functional.scaled_dot_product_attention( 179 | torch.tensor(query, device=TORCH_DEVICE), 180 | torch.tensor(key, device=TORCH_DEVICE), 181 | torch.tensor(value, device=TORCH_DEVICE), 182 | scale=scale, 183 | attn_mask=torch.tensor(mask, device=TORCH_DEVICE), 184 | enable_gqa=True, 185 | ) 186 | user_output = scaled_dot_product_attention_grouped( 187 | mx.array(query), 188 | mx.array(key), 189 | mx.array(value), 190 | scale=scale, 191 | mask=mx.array(mask), 192 | ) 193 | assert_allclose(user_output, reference_output, precision=precision) 194 | -------------------------------------------------------------------------------- /tests/test_funcs.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import mlx.core as mx 3 | import torch 4 | from tiny_llm.funcs import * 5 | from tiny_llm.layers import * 6 | import numpy as np 7 | from .utils import * 8 | import torchtune 9 | 10 | 11 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) 12 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) 13 | def test_silu(stream: mx.Stream, precision: np.dtype): 14 | SIZE = 100 15 | 16 | with mx.stream(stream): 17 | for _ in range(100): 18 | data = np.random.rand(SIZE).astype(precision) 19 | reference_output = torch.nn.functional.silu( 20 | torch.tensor(data, device=TORCH_DEVICE) 21 | ) 22 | user_output = silu(mx.array(data)) 23 | assert_allclose(user_output, reference_output, precision) 24 | 25 | 26 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) 27 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) 28 | def test_rms_norm(stream: mx.Stream, precision: np.dtype): 29 | SIZE = 100 30 | SIZE_Y = 111 31 | with mx.stream(stream): 32 | for _ in range(100): 33 | data = np.random.rand(SIZE, SIZE_Y).astype(precision) 34 | weight = np.random.rand(SIZE_Y).astype(precision) 35 | eps = np.finfo(precision).eps 36 | reference_output = torch.nn.functional.rms_norm( 37 | torch.tensor(data, device=TORCH_DEVICE), 38 | (SIZE_Y,), 39 | torch.tensor(weight, device=TORCH_DEVICE), 40 | eps=eps, 41 | ) 42 | user_output = RMSNorm(SIZE_Y, mx.array(weight), eps=eps)(mx.array(data)) 43 | assert_allclose(user_output, reference_output, precision) 44 | -------------------------------------------------------------------------------- /tests/test_rope.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import mlx.core as mx 3 | import torch 4 | from tiny_llm.funcs import * 5 | from tiny_llm.layers import * 6 | import numpy as np 7 | from .utils import * 8 | import torchtune 9 | import random 10 | 11 | 12 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) 13 | @pytest.mark.parametrize("with_input", [True, False]) 14 | def test_rope(stream: mx.Stream, with_input: bool): 15 | BATCH_SIZE = 1 16 | NUM_HEADS = 8 17 | NUM_KV_HEADS = 6 18 | HEAD_DIM = 4 19 | MAX_SEQ_LEN = 20 20 | SEQ_LEN = 10 21 | BASE = 10000.0 22 | 23 | with mx.stream(stream): 24 | for _ in range(100): 25 | reference_layer = ( 26 | torchtune.modules.position_embeddings.RotaryPositionalEmbeddings( 27 | HEAD_DIM, 28 | MAX_SEQ_LEN, 29 | BASE, 30 | ) 31 | ) 32 | user_layer = RoPE(HEAD_DIM, MAX_SEQ_LEN, BASE, traditional=True) 33 | x = np.random.rand(BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM) 34 | 35 | if with_input: 36 | input_pos = np.random.randint(0, MAX_SEQ_LEN - SEQ_LEN) 37 | input_pos_mx = input_pos 38 | input_pos_user = slice(input_pos, input_pos + SEQ_LEN) 39 | input_pos_torch = torch.tensor([i for i in range(input_pos, input_pos + SEQ_LEN)], device=TORCH_DEVICE, dtype=torch.int32) 40 | else: 41 | input_pos = None 42 | input_pos_mx = None 43 | input_pos_user = None 44 | input_pos_torch = None 45 | 46 | reference_output = reference_layer.forward( 47 | torch.tensor(x, device=TORCH_DEVICE), input_pos=input_pos_torch 48 | ) 49 | user_output = user_layer(mx.array(x), input_pos_user) 50 | assert_allclose(user_output, reference_output, np.float32, atol=1e-6) 51 | 52 | user_layer = RoPE(HEAD_DIM, MAX_SEQ_LEN, BASE, traditional=False) 53 | reference_output = mx.fast.rope( 54 | 
mx.array(x).transpose(0, 2, 1, 3), 55 | dims=HEAD_DIM, 56 | traditional=False, 57 | base=BASE, 58 | scale=1.0, 59 | offset=input_pos_mx or 0, 60 | ).transpose(0, 2, 1, 3) 61 | user_output = user_layer(mx.array(x), input_pos_user) 62 | assert_allclose(user_output, reference_output, np.float32, atol=1e-6) -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mlx.core as mx 3 | import torch 4 | 5 | AVAILABLE_STREAMS = [mx.cpu, mx.gpu] 6 | AVAILABLE_STREAMS_IDS = ["cpu", "gpu"] 7 | PRECISIONS = [np.float32, np.float16] 8 | PRECISION_IDS = ["f32", "f16"] 9 | TORCH_DEVICE = torch.device("cpu") 10 | 11 | 12 | def assert_allclose( 13 | a: mx.array, 14 | b: torch.Tensor | mx.array, 15 | precision: np.dtype, 16 | rtol: float | None = None, 17 | atol: float | None = None, 18 | ): 19 | a = np.array(a) 20 | if isinstance(b, torch.Tensor): 21 | b = b.cpu().numpy() 22 | elif isinstance(b, mx.array): 23 | b = np.array(b) 24 | else: 25 | raise ValueError(f"Unsupported type: {type(b)}") 26 | if precision == np.float32: 27 | rtol = rtol or 1.0e-5 28 | atol = atol or 1.0e-8 29 | elif precision == np.float16: 30 | rtol = rtol or 1.0e-2 31 | atol = atol or 1.0e-7 32 | assert a.shape == b.shape 33 | if not np.allclose(a, b, rtol=rtol, atol=atol): 34 | print("a=", a) 35 | print("b=", b) 36 | diff = np.invert(np.isclose(a, b, rtol=rtol, atol=atol)) 37 | print("diff_a=", a * diff) 38 | print("diff_b=", b * diff) 39 | assert False, f"result mismatch" 40 | 41 | 42 | def np_type_to_mx_type(np_type: np.dtype) -> mx.Dtype: 43 | if np_type == np.float32: 44 | return mx.float32 45 | elif np_type == np.float16: 46 | return mx.float16 47 | else: 48 | raise ValueError(f"Unsupported numpy type: {np_type}") 49 | --------------------------------------------------------------------------------