├── .gitignore
├── LICENSE
├── README.md
├── book
│   ├── .gitignore
│   ├── book.toml
│   └── src
│       ├── SUMMARY.md
│       ├── glossary.md
│       ├── preface.md
│       ├── week1-overview.md
│       ├── week2-overview.md
│       └── week3-overview.md
├── main.py
├── poetry.lock
├── pyproject.toml
├── src
│   └── tiny_llm
│       ├── __init__.py
│       ├── funcs.py
│       └── layers.py
└── tests
    ├── test_attention.py
    ├── test_funcs.py
    ├── test_rope.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tiny-llm
2 |
3 | Still a WIP and at a very early stage. A tutorial on LLM serving using MLX for systems engineers. The codebase
4 | is (almost!) solely based on MLX array/matrix APIs without any high-level neural network APIs, so that we
5 | can build the model serving infrastructure from scratch and dig into the optimizations.
6 |
7 | We test the implementation against PyTorch's CPU implementation to ensure correctness. The main codebase uses MLX
8 | instead of PyTorch because nowadays it's easier to get an Apple Silicon MacBook than an NVIDIA GPU. In theory you could
9 | implement everything using PyTorch tensor APIs, but we don't have the test infra to support that.
10 |
11 | (TODO: maybe we should test against MLX? PyTorch APIs sometimes don't align with MLX; but I also want to ensure the computation
12 | precision is enough to load any model directly from PyTorch tensors without converting to MLX format.)
13 |
14 | The goal is to learn the techniques behind efficiently serving a large language model (e.g., the Qwen2 family of models).
15 |
16 | * Week 1: serve Qwen2 with purely Python APIs. No fancy optimizations, just Python.
17 | * Week 2: optimizations, implement C++/Metal custom kernels to make the model run faster.
18 | * Week 3: more optimizations, batch the requests to serve the model with high throughput.
19 |
20 | TBD: implement a leaderboard service?
21 |
22 | ## Usage
23 |
24 | ```bash
25 | poetry install
26 | poetry run pytest
27 | poetry run python main.py
28 | ```
29 |
30 | ## Week 1: LLM from Scratch
31 |
32 | This tutorial works a bit differently: instead of explaining everything here, I collect all the materials, blog posts, and source
33 | code I read while implementing the components. The repo contains API specifications and test infra to compare your results with
34 | PyTorch/MLX. You read the materials and implement the components yourself :)
35 |
36 | ### Day 1: Attention is All You Need
37 |
38 | Implement `scaled_dot_product_attention`. The function takes key, value, and query of the same dimensions.
39 |
40 | ```
41 | K: N.. x H x L x E
42 | V: N.. x H x L x E
43 | Q: N.. x H x L x E
44 | ```
45 |
46 | Here, `N..` is zero or more batch dimensions. Within each batch, `H` is the number of heads,
47 | `L` is the sequence length, and `E` is the embedding/hidden size.
48 |
49 | You may use the `softmax` provided by MLX for now and implement it yourself in week 2.
50 |
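To make the contract concrete, here is a minimal sketch of the computation in MLX. The `sdpa_sketch` name is just for illustration (your implementation goes into `src/tiny_llm/funcs.py`), and it uses MLX's built-in `mx.softmax` as suggested above:

```python
import math
import mlx.core as mx

def sdpa_sketch(query: mx.array, key: mx.array, value: mx.array,
                scale: float | None = None,
                mask: mx.array | None = None) -> mx.array:
    # query/key/value: N.. x H x L x E
    scale = 1.0 / math.sqrt(query.shape[-1]) if scale is None else scale
    scores = (query @ key.swapaxes(-2, -1)) * scale  # N.. x H x L x L
    if mask is not None:
        scores = scores + mask                       # additive mask
    return mx.softmax(scores, axis=-1) @ value       # N.. x H x L x E
```
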
51 | **References**
52 |
53 | * Annotated Transformer https://nlp.seas.harvard.edu/annotated-transformer/
54 | * PyTorch API (the case where enable_gqa=False) https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
55 | * MLX API https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html
56 | * https://arxiv.org/abs/1706.03762
57 |
58 | Implement `MultiHeadAttention`. The layer takes a batch of vectors `x`, maps it through the K, V, Q weight matrices, and
59 | uses the attention function implemented above to compute the result. The output then needs to be mapped through the O
60 | weight matrix. You will also need to implement the `linear` function.
61 |
62 | ```
63 | x: N x L x D
64 | D = num_heads x head_dim
65 | ```
66 |
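Here is a rough sketch of the head-splitting idea, reusing `sdpa_sketch` from the Day 1 sketch above. It assumes a batch-first `N x L x D` layout; the PyTorch functional reference used by the tests expects a different (sequence-first) input layout, so treat this as a guide rather than a drop-in solution:

```python
def multi_head_attention_sketch(x, wq, wk, wv, wo, num_heads: int):
    # x: N x L x D with D = num_heads * head_dim; all weights are D x D
    N, L, D = x.shape
    head_dim = D // num_heads

    def project(w):
        # linear projection, then split D into heads and move heads before L
        return (x @ w.T).reshape(N, L, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = project(wq), project(wk), project(wv)
    out = sdpa_sketch(q, k, v)                        # N x num_heads x L x head_dim
    out = out.transpose(0, 2, 1, 3).reshape(N, L, D)  # merge heads back into D
    return out @ wo.T
```
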
67 | **References**
68 |
69 | * Annotated Transformer https://nlp.seas.harvard.edu/annotated-transformer/
70 | * PyTorch API https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
71 | * MLX API https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html
72 |
73 | ### Day 2: RoPE Embedding
74 |
75 | Note that there are traditional and non-traditional variants of RoPE: the traditional form rotates interleaved pairs of dimensions, while the non-traditional form rotates the two halves of the head dimension.
76 |
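As a reference point, here is a minimal sketch of the non-traditional variant for a `... x L x D` input (no separate heads dimension and no position offset, both of which the real layer needs). The traditional variant would instead rotate the interleaved pairs `x[..., 0::2]` and `x[..., 1::2]`:

```python
import mlx.core as mx

def rope_sketch(x: mx.array, base: float = 10000.0) -> mx.array:
    # x: .. x L x D; rotate the first and second halves of the last dim
    *_, L, D = x.shape
    half = D // 2
    freqs = mx.power(base, -mx.arange(0, half, dtype=mx.float32) / half)  # (half,)
    angles = mx.outer(mx.arange(L, dtype=mx.float32), freqs)              # (L, half)
    cos, sin = mx.cos(angles), mx.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    # complex rotation written out with real arithmetic
    return mx.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
```
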
77 | **References**
78 |
79 | * https://pytorch.org/torchtune/stable/generated/torchtune.modules.RotaryPositionalEmbeddings.html
80 | * https://github.com/pytorch/torchtune/blob/main/torchtune/modules/position_embeddings.py
81 | * https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
82 | * https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.RoPE.html
83 | * https://arxiv.org/abs/2104.09864
84 |
85 | ### Day 3: Grouped Query Attention
86 |
87 | The Qwen2 models use Grouped Query Attention (GQA). GQA allows the number of query heads to differ from the number of key/value heads: several query heads attend using the same key/value head.
88 |
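The grouped computation only needs a reshape so that each key/value head is broadcast over its group of query heads. A condensed sketch for a single explicit batch dimension, reusing `sdpa_sketch` from Day 1 (the repo's version in `src/tiny_llm/funcs.py` also handles arbitrary leading batch dims):

```python
def gqa_sketch(query, key, value, scale=None, mask=None):
    # query: B x H_q x L x E, key/value: B x H x S x E, with H_q % H == 0
    B, H_q, L, E = query.shape
    _, H, S, _ = key.shape
    n_repeats = H_q // H
    query = query.reshape(B, H, n_repeats, L, E)   # group query heads per kv head
    key = key.reshape(B, H, 1, S, E)               # kv broadcasts across the group
    value = value.reshape(B, H, 1, S, E)
    # mask (if any) must broadcast to B x H x n_repeats x L x S
    out = sdpa_sketch(query, key, value, scale=scale, mask=mask)
    return out.reshape(B, H_q, L, E)
```
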
89 | **References**
90 |
91 | * Qwen layers implementation in mlx-lm https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py
92 | * PyTorch API (the case where enable_gqa=True) https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
93 | * torchtune.modules.MultiHeadAttention https://pytorch.org/torchtune/0.3/generated/torchtune.modules.MultiHeadAttention.html
94 | * https://arxiv.org/abs/2305.13245v1
95 |
96 | ### Day 4: RMSNorm and MLP
97 |
98 | RMSNorm needs to accumulate in float32 to preserve precision (see the sketch below).
99 |
100 | * Qwen layers implementation in mlx-lm https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py
101 | * SiLU https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
102 | * RMSNorm (note that it needs to accumulate in float32)
103 |
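A minimal sketch of both pieces, assuming MLX arrays (the `eps` default here is a placeholder; use the model's configured value):

```python
import mlx.core as mx

def rms_norm_sketch(x: mx.array, weight: mx.array, eps: float = 1e-6) -> mx.array:
    orig_dtype = x.dtype
    x = x.astype(mx.float32)   # accumulate in float32, then cast back at the end
    y = x * mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + eps)
    return (weight.astype(mx.float32) * y).astype(orig_dtype)

def silu_sketch(x: mx.array) -> mx.array:
    return x * mx.sigmoid(x)   # SiLU(x) = x * sigmoid(x)
```
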
104 | ### Day 5: Transformer Block
105 |
106 | * Qwen layers implementation in mlx-lm https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py
107 |
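The block itself is plain pre-norm residual wiring around the attention and MLP layers built on the previous days. A condensed sketch (the argument list is illustrative; the real block holds these sub-layers as attributes):

```python
def transformer_block_sketch(x, offset, self_attn, mlp,
                             input_layernorm, post_attention_layernorm, mask=None):
    # pre-norm residual structure: attention sub-block, then MLP sub-block
    h = x + self_attn(input_layernorm(x), offset, mask)
    return h + mlp(post_attention_layernorm(h))
```
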
108 | ### Day 6: Load the Model
109 |
110 | We will use mlx-lm's loader to load the model, _steal_ the loaded parameters from the mlx-lm model, and
111 | plug them into our own operators.
112 |
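A sketch of what the _stealing_ looks like for one projection matrix, mirroring `dequantize_linear` in `src/tiny_llm/layers.py` (the model name is the one used in `main.py`):

```python
import mlx.core as mx
from mlx_lm import load

def dequantize_linear_sketch(layer) -> mx.array:
    # turn an mlx-lm quantized linear layer back into a plain weight matrix
    return mx.dequantize(layer.weight, layer.scales, layer.biases,
                         layer.group_size, layer.bits)

mlx_model, tokenizer = load("Qwen/Qwen2-7B-Instruct-MLX")
wq_0 = dequantize_linear_sketch(mlx_model.model.layers[0].self_attn.q_proj)
```

The resulting weight can then be passed straight into the `linear` function from Day 1.
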
113 | ### Day 7: Generate Responses
114 |
115 | * Qwen layers implementation in mlx-lm https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py
116 |
117 | Run `python main.py` and it should give you a reasonable response.
118 |
119 | On my M4 Pro Mac Mini, my implementation generates 17 tokens per second on Metal, versus 50 tokens per second from the mlx-lm
120 | Qwen2 implementation. Sadly, it also uses 4x the memory of the mlx-lm components, as it does not support computation
121 | over quantized parameters.
122 |
123 | ## Week 2
124 |
125 | Quantization, custom softmax/linear/silu kernels, attention kernels, key-value cache and compression, attention masks, and prompt caching.
126 |
127 | ## Week 3
128 |
129 | Continuous batching, an OpenAPI HTTP endpoint, and integration with other services.
130 |
--------------------------------------------------------------------------------
/book/.gitignore:
--------------------------------------------------------------------------------
1 | book
2 |
--------------------------------------------------------------------------------
/book/book.toml:
--------------------------------------------------------------------------------
1 | [book]
2 | authors = ["Alex Chi"]
3 | language = "en"
4 | multilingual = false
5 | src = "src"
6 | title = "Tiny LLM - LLM Serving in a Week"
7 |
8 | [preprocessor.toc]
9 | command = "mdbook-toc"
10 | renderer = ["html"]
11 |
12 | [output.html]
13 | git-repository-url = "https://github.com/skyzh/tiny-llm"
14 |
--------------------------------------------------------------------------------
/book/src/SUMMARY.md:
--------------------------------------------------------------------------------
1 | # LLM Serving in a Week
2 |
3 | [Preface](./preface.md)
4 | [Setting Up the Environment]()
5 |
6 | ---
7 |
8 | - [Week 1: From Matmul to Text]()
9 | - [Attention and Multi-Head Attention, Linear]()
10 | - [Positional Embeddings and RoPE]()
11 | - [Grouped/Multi Query Attention, Embedding, Silu]()
12 | - [Multilayer Perceptron Layer and Transformer, RMSNorm]()
13 | - [Wiring and Loading the Model, Dequantize]()
14 | - [Tokenize and Generating Response]()
15 |
16 | - [Week 2: Optimizing]()
17 |
18 | - [Week 3: Serving]()
19 |
20 | ---
21 |
22 | [Glossary Index](./glossary.md)
23 |
--------------------------------------------------------------------------------
/book/src/glossary.md:
--------------------------------------------------------------------------------
1 | # Glossary Index
2 |
3 | Which day covers which piece of functionality?
4 |
--------------------------------------------------------------------------------
/book/src/preface.md:
--------------------------------------------------------------------------------
1 | # Preface
2 |
3 | LLM serving in a week for systems engineers!
4 |
--------------------------------------------------------------------------------
/book/src/week1-overview.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skyzh/tiny-llm/c9be20df58b98dc6d2e7846364091439539ed712/book/src/week1-overview.md
--------------------------------------------------------------------------------
/book/src/week2-overview.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skyzh/tiny-llm/c9be20df58b98dc6d2e7846364091439539ed712/book/src/week2-overview.md
--------------------------------------------------------------------------------
/book/src/week3-overview.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skyzh/tiny-llm/c9be20df58b98dc6d2e7846364091439539ed712/book/src/week3-overview.md
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from mlx_lm import load, generate
2 | from mlx_lm.sample_utils import make_sampler, make_logits_processors
3 | from tiny_llm.layers import Qwen2Model
4 | import mlx.core as mx
5 |
6 | with mx.stream(mx.gpu):
7 | mlx_model, tokenizer = load(
8 | "Qwen/Qwen2-7B-Instruct-MLX",
9 | tokenizer_config={"eos_token": "<|im_end|>"},
10 | model_config={"tie_word_embeddings": False, "rope_traditional": True},
11 | )
12 | tiny_llm_model = Qwen2Model(mlx_model)
13 |
14 | prompt = "Give me a short introduction to large language model."
15 | messages = [
16 | {"role": "system", "content": "You are a helpful assistant."},
17 | {"role": "user", "content": prompt},
18 | ]
19 | prompt = tokenizer.apply_chat_template(
20 | messages, tokenize=False, add_generation_prompt=True
21 | )
22 | def _step(model, y, offset):
23 | logits = model(y[None], offset)
24 | logits = logits[:, -1, :]
25 | logprobs = logits - mx.logsumexp(logits, keepdims=True)
26 | sampler = lambda x: mx.argmax(x, axis=-1)
27 | y = sampler(logprobs)
28 | return y, logprobs.squeeze(0)
29 | # prefill with the prompt
30 | tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False))
31 | offset = tokens.size
32 | detokenizer = tokenizer.detokenizer
33 | detokenizer.reset()
34 | # generate one token at a time (greedy); no KV cache yet, so the full token sequence is re-encoded each step
35 | while True:
36 | token, _ = _step(tiny_llm_model, tokens, offset)
37 | tokens = mx.concat([tokens, token])
38 | if token.item() == tokenizer.eos_token_id:
39 | break
40 | detokenizer.add_token(token.item())
41 | print(detokenizer.last_segment, end="", flush=True)
42 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "tiny-llm"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["Alex Chi Z "]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = "^3.12"
10 | mlx = "^0.25.0"
11 | torch = "^2.6.0"
12 | mlx-lm = "^0.23.0"
13 | torchtune = "^0.6.1"
14 | torchao = "^0.10.0"
15 |
16 |
17 | [tool.poetry.group.dev.dependencies]
18 | pytest = "^8.3.5"
19 | numpy = "^2.2.4"
20 | ruff = "^0.11.6"
21 |
22 | [build-system]
23 | requires = ["poetry-core"]
24 | build-backend = "poetry.core.masonry.api"
25 |
26 | [project]
27 | name = "tiny-llm"
28 | version = "0.1.0"
29 |
30 | [tool.pytest.ini_options]
31 | addopts = [
32 | "--import-mode=importlib",
33 | ]
34 | pythonpath = "src"
35 |
--------------------------------------------------------------------------------
/src/tiny_llm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skyzh/tiny-llm/c9be20df58b98dc6d2e7846364091439539ed712/src/tiny_llm/__init__.py
--------------------------------------------------------------------------------
/src/tiny_llm/funcs.py:
--------------------------------------------------------------------------------
1 | import mlx.core as mx
2 | import math
3 |
4 |
5 | def softmax(x: mx.array, axis: int) -> mx.array:
6 | # TODO: manual implementation
7 | return mx.softmax(x, axis=axis)
8 |
9 |
10 | def scaled_dot_product_attention(
11 | query: mx.array,
12 | key: mx.array,
13 | value: mx.array,
14 | scale: float | None = None,
15 | mask: mx.array | None = None,
16 | stream: mx.Stream | mx.Device | None = None,
17 | ) -> mx.array:
18 | """
19 | Compute scaled dot-product attention.
20 |
21 |     query/key/value: N.. x H x L x E, where N.. is zero or more batch dimensions
22 | """
23 | factor = mx.rsqrt(query.shape[-1]) if scale is None else scale
24 | scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor
25 | if mask is not None:
26 | scores = scores + mask
27 | return mx.matmul(softmax(scores, axis=-1), value)
28 |
29 |
30 | def scaled_dot_product_attention_grouped(
31 | query: mx.array,
32 | key: mx.array,
33 | value: mx.array,
34 | scale: float | None = None,
35 | mask: mx.array | None = None,
36 | ) -> mx.array:
37 | """
38 |     Compute scaled dot-product attention with grouped query heads (GQA).
39 |
40 |     query: N.. x H_q x L x E; key/value: N.. x H x S x E, where H_q is a multiple of H
41 | """
42 | factor = mx.rsqrt(query.shape[-1]) if scale is None else scale
43 | expected_shape = query.shape
44 | query = query.reshape(-1, query.shape[-3], query.shape[-2], query.shape[-1])
45 | key = key.reshape(-1, key.shape[-3], key.shape[-2], key.shape[-1])
46 | value = value.reshape(-1, value.shape[-3], value.shape[-2], value.shape[-1])
47 | B, H_q, L, E = query.shape
48 | _, H, S, _ = key.shape
49 | assert H_q % H == 0
50 |     n_repeats = H_q // H
51 |     query = query.reshape((B, H, n_repeats, L, E))  # group query heads: n_repeats query heads per kv head
52 |     key = key.reshape((B, H, 1, S, E))  # kv broadcasts across the group of query heads
53 |     value = value.reshape((B, H, 1, S, E))
54 | scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor
55 | if mask is not None:
56 | mask = mask.reshape(-1, H, n_repeats, mask.shape[-2], mask.shape[-1])
57 | scores = scores + mask
58 | result = mx.matmul(
59 | softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype), value
60 | )
61 | return result.reshape(expected_shape)
62 |
63 |
64 | def linear(
65 | x: mx.array,
66 | w: mx.array,
67 | bias: mx.array | None = None,
68 | ) -> mx.array:
69 | if bias is not None:
70 | return mx.matmul(x, w.T) + bias
71 | else:
72 | return mx.matmul(x, w.T)
73 |
74 |
75 | def silu(x: mx.array) -> mx.array:
76 |     return x / (1 + mx.exp(-x))  # SiLU(x) = x * sigmoid(x)
77 |
--------------------------------------------------------------------------------
/src/tiny_llm/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Any, Optional
3 |
4 | import mlx.core as mx
5 | from mlx_lm.models.cache import KVCache
6 | from .funcs import (
7 | linear,
8 | scaled_dot_product_attention,
9 | scaled_dot_product_attention_grouped,
10 | silu,
11 | )
12 |
13 | # TODO: add license for those heavily based on mlx-lm/PyTorch
14 |
15 |
16 | class MultiHeadAttention:
17 | def __init__(
18 | self,
19 | hidden_size: int,
20 | num_heads: int,
21 | wq: mx.array,
22 | wk: mx.array,
23 | wv: mx.array,
24 | wo: mx.array,
25 | ):
26 | self.hidden_size = hidden_size
27 | self.num_heads = num_heads
28 | assert hidden_size % num_heads == 0
29 | self.head_dim = hidden_size // num_heads
30 | self.scale = mx.rsqrt(self.head_dim)
31 | assert wq.shape == (hidden_size, num_heads * self.head_dim)
32 | assert wk.shape == (hidden_size, num_heads * self.head_dim)
33 | assert wv.shape == (hidden_size, num_heads * self.head_dim)
34 | assert wo.shape == (num_heads * self.head_dim, hidden_size)
35 | self.wq = wq
36 | self.wk = wk
37 | self.wv = wv
38 | self.wo = wo
39 |
40 | def __call__(
41 | self,
42 | query: mx.array,
43 | key: mx.array,
44 | value: mx.array,
45 | mask: mx.array | None = None,
46 | ) -> mx.array:
47 | n_batches = query.shape[0]
48 | batch_size = query.shape[1]
49 | projection_q = (
50 | linear(query, self.wq)
51 | .reshape(n_batches, self.num_heads * batch_size, self.head_dim)
52 | .transpose(1, 0, 2)
53 | )
54 | projection_k = (
55 | linear(key, self.wk)
56 | .reshape(n_batches, self.num_heads * batch_size, self.head_dim)
57 | .transpose(1, 0, 2)
58 | )
59 | projection_v = (
60 | linear(value, self.wv)
61 | .reshape(n_batches, self.num_heads * batch_size, self.head_dim)
62 | .transpose(1, 0, 2)
63 | )
64 | x = scaled_dot_product_attention(
65 | projection_q,
66 | projection_k,
67 | projection_v,
68 | scale=self.scale,
69 | mask=mask,
70 | )
71 | x = x.transpose(1, 0, 2).reshape(n_batches, batch_size, self.hidden_size)
72 | return linear(x, self.wo)
73 |
74 |
75 | class Qwen2MultiHeadAttention:
76 | def __init__(
77 | self,
78 | hidden_size: int,
79 | num_heads: int,
80 | num_kv_heads: int,
81 | wq: mx.array,
82 | wk: mx.array,
83 | wv: mx.array,
84 | wo: mx.array,
85 | bq: mx.array,
86 | bk: mx.array,
87 | bv: mx.array,
88 | max_seq_len: int = 32768,
89 | theta: int = 1000000,
90 | ):
91 | self.hidden_size = hidden_size
92 | self.num_heads = num_heads
93 | self.num_kv_heads = num_kv_heads
94 | assert hidden_size % num_heads == 0, (
95 | f"hidden_size {hidden_size} must be divisible by num_heads {num_heads}"
96 | )
97 | assert num_heads % num_kv_heads == 0, (
98 | f"num_heads {num_heads} must be divisible by num_kv_heads {num_kv_heads}"
99 | )
100 | self.head_dim = hidden_size // num_heads
101 | self.scale = mx.rsqrt(self.head_dim)
102 | self.wq = wq
103 | self.wk = wk
104 | self.wv = wv
105 | self.wo = wo
106 | self.bq = bq
107 | self.bk = bk
108 | self.bv = bv
109 | self.rope = RoPE(self.head_dim, max_seq_len, theta)
110 |
111 | def __call__(
112 | self,
113 | x: mx.array,
114 | offset: int,
115 | mask: mx.array | None = None,
116 | cache: KVCache | None = None,
117 | ) -> mx.array:
118 | B, L, _ = x.shape
119 | orig_dtype = x.dtype
120 | projection_q = (
121 | linear(x, self.wq, bias=self.bq)
122 | .reshape(B, L, self.num_heads, self.head_dim)
123 | .astype(mx.float32)
124 | )
125 | projection_k = (
126 | linear(x, self.wk, bias=self.bk)
127 | .reshape(B, L, self.num_kv_heads, self.head_dim)
128 | .astype(mx.float32)
129 | )
130 | projection_v = (
131 | linear(x, self.wv, bias=self.bv)
132 | .reshape(B, L, self.num_kv_heads, self.head_dim)
133 | .astype(mx.float32)
134 | )
135 | # offset = cache.offset
136 | projection_q = self.rope(projection_q, offset=slice(offset, offset + L))
137 | projection_k = self.rope(projection_k, offset=slice(offset, offset + L))
138 | projection_q = projection_q.transpose(0, 2, 1, 3)
139 | projection_k = projection_k.transpose(0, 2, 1, 3)
140 | projection_v = projection_v.transpose(0, 2, 1, 3)
141 |         # TODO: is it possible to get a sensible result without using a kv-cache? Otherwise we have to include the kv-cache in week 1.
142 |         # mlx-lm's KVCache seems to do more than just caching; we could extract something out of it.
143 | # projection_k, projection_v = cache.update_and_fetch(projection_k, projection_v)
144 | assert (
145 | projection_k.dtype == mx.float32
146 |         ) # TODO: can we use float16? also a test framework to ensure all data types are cast correctly.
147 | assert projection_v.dtype == mx.float32
148 | x = scaled_dot_product_attention_grouped(
149 | projection_q,
150 | projection_k,
151 | projection_v,
152 | scale=self.scale,
153 | mask=mask,
154 | ).astype(orig_dtype)
155 | x = x.transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size)
156 | return linear(x, self.wo)
157 |
158 |
159 | class RoPE:
160 | def __init__(
161 | self,
162 | dims: int,
163 | seq_len: int,
164 | base: int = 10000,
165 | traditional: bool = False,
166 | ):
167 | self.dims = dims
168 | self.seq_len = seq_len
169 | half_dims = dims // 2
170 | inner = mx.arange(0, half_dims, dtype=mx.float32) / half_dims
171 | freqs = mx.power(base, -inner)
172 | t = mx.arange(seq_len)
173 | freqs = mx.outer(t, freqs)
174 | self.cos_freqs = mx.cos(freqs)
175 | self.sin_freqs = mx.sin(freqs)
176 | self.base = base
177 | self.half_dims = half_dims
178 | self.traditional = traditional
179 | def __call__(
180 | self, x: mx.array, offset: slice | None = None
181 |     ) -> mx.array:
182 | # input x: (b, s, n_heads, head_dim)
183 | *N, S, H, D = x.shape
184 | # if offset is not None:
185 | # assert len(offset) == S, f"offset {len(offset)} must be of length {s}"
186 | cos_basis = (
187 | self.cos_freqs[:S, :] if offset is None else self.cos_freqs[offset, :]
188 | )
189 | sin_basis = (
190 | self.sin_freqs[:S, :] if offset is None else self.sin_freqs[offset, :]
191 | )
192 |         # traditional: split the last dim into interleaved pairs; otherwise split it into two halves
193 | if self.traditional:
194 | x = x.reshape(*N, S, H, self.half_dims, 2)
195 | x1 = x[..., 0]
196 | x2 = x[..., 1]
197 | else:
198 | x1 = x[..., 0:self.half_dims]
199 | x2 = x[..., self.half_dims:self.dims]
200 |         # reshape basis: (s, 1, dims // 2) so it broadcasts over the heads dimension
201 | cos_basis = cos_basis.reshape(S, 1, self.half_dims)
202 | sin_basis = sin_basis.reshape(S, 1, self.half_dims)
203 | # manually doing complex number multiplication..
204 | real = mx.multiply(x1, cos_basis) - mx.multiply(x2, sin_basis)
205 | imag = mx.multiply(x2, cos_basis) + mx.multiply(x1, sin_basis)
206 | if self.traditional:
207 | y = mx.stack([real, imag], axis=-1)
208 | y = y.reshape(*N, S, H, D)
209 | else:
210 | y = mx.concat([real, imag], axis=-1)
211 | y = y.reshape(*N, S, H, D)
212 | return y
213 |
214 |
215 | class Qwen2MLP:
216 | def __init__(
217 | self,
218 | dim: int,
219 | hidden_dim: int,
220 | w_gate: mx.array,
221 | w_up: mx.array,
222 | w_down: mx.array,
223 | ):
224 | self.dim = dim
225 | self.hidden_dim = hidden_dim
226 | self.w_gate = w_gate
227 | self.w_up = w_up
228 | self.w_down = w_down
229 |
230 | def __call__(self, x: mx.array) -> mx.array:
231 | return linear(silu(linear(x, self.w_gate)) * linear(x, self.w_up), self.w_down)
232 |
233 |
234 | class RMSNorm:
235 | def __init__(self, dim: int, weight: mx.array, eps: float = 1e-5):
236 | self.dim = dim
237 | self.eps = eps
238 | self.weight = weight.astype(mx.float32)
239 |
240 | def __call__(self, x: mx.array) -> mx.array:
241 | # TODO: tests to ensure the precision of this function
242 | orig_dtype = x.dtype
243 | x = x.astype(mx.float32)
244 | return (
245 | self.weight
246 | * x
247 | * mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + self.eps)
248 | ).astype(orig_dtype)
249 |
250 |
251 | class Qwen2TransformerBlock:
252 | def __init__(
253 | self,
254 | num_attention_heads: int,
255 | num_kv_heads: int,
256 | hidden_size: int,
257 | intermediate_size: int,
258 | rms_norm_eps: float,
259 | wq: mx.array,
260 | wk: mx.array,
261 | wv: mx.array,
262 | wo: mx.array,
263 | bq: mx.array,
264 | bk: mx.array,
265 | bv: mx.array,
266 | w_gate: mx.array,
267 | w_up: mx.array,
268 | w_down: mx.array,
269 | w_input_layernorm: mx.array,
270 | w_post_attention_layernorm: mx.array,
271 | max_seq_len: int = 32768,
272 | theta: int = 1000000,
273 | ):
274 | self.num_attention_heads = num_attention_heads
275 | self.hidden_size = hidden_size
276 | self.mlp = Qwen2MLP(hidden_size, intermediate_size, w_gate, w_up, w_down)
277 | self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, eps=rms_norm_eps)
278 | self.post_attention_layernorm = RMSNorm(
279 | hidden_size, w_post_attention_layernorm, eps=rms_norm_eps
280 | )
281 | self.self_attn = Qwen2MultiHeadAttention(
282 | num_heads=num_attention_heads,
283 | hidden_size=hidden_size,
284 | num_kv_heads=num_kv_heads,
285 | wq=wq,
286 | wk=wk,
287 | wv=wv,
288 | wo=wo,
289 | bq=bq,
290 | bk=bk,
291 | bv=bv,
292 | max_seq_len=max_seq_len,
293 | theta=theta,
294 | )
295 |
296 | def __call__(
297 | self,
298 | x: mx.array,
299 | offset: int,
300 | mask: mx.array | None = None,
301 | cache: KVCache | None = None,
302 | ) -> mx.array:
303 | r = self.self_attn(self.input_layernorm(x), offset, mask, cache)
304 | h = x + r
305 | r = self.mlp(self.post_attention_layernorm(h))
306 | out = h + r
307 | return out
308 |
309 |
310 | def dequantize_linear(mx_layer: Any) -> mx.array:
311 | w = mx.dequantize(
312 | mx_layer.weight,
313 | mx_layer.scales,
314 | mx_layer.biases,
315 | mx_layer.group_size,
316 | mx_layer.bits,
317 | )
318 | return w
319 |
320 | class Embedding:
321 | def __init__(self, vocab_size: int, embedding_dim: int, weight: mx.array):
322 | self.vocab_size = vocab_size
323 | self.embedding_dim = embedding_dim
324 | self.weight = weight
325 |
326 | def __call__(self, x: mx.array) -> mx.array:
327 | return self.weight[x, :]
328 |
329 | class Qwen2Model:
330 | def __init__(
331 | self,
332 | mlx_model: Any,
333 | ):
334 | self.num_hidden_layers = mlx_model.args.num_hidden_layers
335 | self.hidden_size = mlx_model.args.hidden_size
336 | self.vocab_size = mlx_model.args.vocab_size
337 | precision = mx.float16
338 | self.precision = precision
339 |
340 | self.embedding = Embedding(
341 | vocab_size=self.vocab_size,
342 | embedding_dim=self.hidden_size,
343 | weight=dequantize_linear(mlx_model.model.embed_tokens).astype(precision),
344 | )
345 | self.layers_inner = []
346 |
347 |
348 | for i in range(mlx_model.args.num_hidden_layers):
349 | wq = dequantize_linear(mlx_model.model.layers[i].self_attn.q_proj)
350 | wk = dequantize_linear(mlx_model.model.layers[i].self_attn.k_proj)
351 | wv = dequantize_linear(mlx_model.model.layers[i].self_attn.v_proj)
352 | wo = dequantize_linear(mlx_model.model.layers[i].self_attn.o_proj)
353 | w_gate = dequantize_linear(mlx_model.model.layers[i].mlp.gate_proj)
354 | w_up = dequantize_linear(mlx_model.model.layers[i].mlp.up_proj)
355 | w_down = dequantize_linear(mlx_model.model.layers[i].mlp.down_proj)
356 |
357 | layer = Qwen2TransformerBlock(
358 | num_attention_heads=mlx_model.args.num_attention_heads,
359 | num_kv_heads=mlx_model.args.num_key_value_heads,
360 | hidden_size=mlx_model.args.hidden_size,
361 | intermediate_size=mlx_model.args.intermediate_size,
362 | rms_norm_eps=mlx_model.args.rms_norm_eps,
363 | wq=wq.astype(precision),
364 | wk=wk.astype(precision),
365 | wv=wv.astype(precision),
366 | wo=wo.astype(precision),
367 | bq=mlx_model.model.layers[i].self_attn.q_proj.bias.astype(precision),
368 | bk=mlx_model.model.layers[i].self_attn.k_proj.bias.astype(precision),
369 | bv=mlx_model.model.layers[i].self_attn.v_proj.bias.astype(precision),
370 | w_gate=w_gate.astype(precision),
371 | w_up=w_up.astype(precision),
372 | w_down=w_down.astype(precision),
373 | w_input_layernorm=mlx_model.model.layers[
374 | i
375 | ].input_layernorm.weight.astype(precision),
376 | w_post_attention_layernorm=mlx_model.model.layers[
377 | i
378 | ].post_attention_layernorm.weight.astype(precision),
379 | max_seq_len=mlx_model.args.max_position_embeddings,
380 | theta=mlx_model.args.rope_theta,
381 | )
382 | self.layers_inner.append(layer)
383 | self.norm = RMSNorm(
384 | mlx_model.args.hidden_size,
385 | weight=mlx_model.model.norm.weight.astype(precision),
386 | eps=mlx_model.args.rms_norm_eps,
387 | )
388 | self.w_lm_head = dequantize_linear(mlx_model.lm_head)
389 | self.mlx_model = mlx_model
390 |
391 | def __call__(
392 | self,
393 | inputs: mx.array,
394 | offset: int,
395 | mask: mx.array | None = None,
396 | cache: KVCache | None = None,
397 | ) -> mx.array:
398 | h = self.embedding(inputs)
399 | for layer in range(self.num_hidden_layers):
400 | h = self.layers_inner[layer](h, offset, None, cache[layer] if cache else None)
401 | h = self.norm(h)
402 | return linear(h, self.w_lm_head)
403 |
404 | def sanitize(self, weights: dict):
405 | assert False, "not implemented"
406 |
407 | @property
408 | def layers(self):
409 | return self.layers_inner
410 |
--------------------------------------------------------------------------------
/tests/test_attention.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import mlx.core as mx
3 | import torch
4 | from tiny_llm.funcs import *
5 | from tiny_llm.layers import *
6 | import numpy as np
7 | from .utils import *
8 |
9 |
10 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
11 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
12 | def test_softmax(stream: mx.Stream, precision: np.dtype):
13 | with mx.stream(stream):
14 | BATCH_SIZE = 10
15 | DIM = 10
16 | for _ in range(100):
17 | x = np.random.rand(BATCH_SIZE, DIM).astype(precision)
18 | user_output = softmax(mx.array(x), axis=-1)
19 | reference_output = torch.nn.functional.softmax(
20 | torch.tensor(x, device=TORCH_DEVICE), dim=-1
21 | )
22 | assert_allclose(user_output, reference_output, precision=precision)
23 |
24 |
25 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
26 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
27 | def test_attention(stream: mx.Stream, precision: np.dtype):
28 | with mx.stream(stream):
29 | BATCH_SIZE = 3
30 | DIM_N = 4
31 | DIM_M = 5
32 | for _ in range(100):
33 | query = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
34 | key = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
35 | value = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
36 | reference_output = torch.nn.functional.scaled_dot_product_attention(
37 | torch.tensor(query, device=TORCH_DEVICE),
38 | torch.tensor(key, device=TORCH_DEVICE),
39 | torch.tensor(value, device=TORCH_DEVICE),
40 | )
41 | user_output = scaled_dot_product_attention(
42 | mx.array(query),
43 | mx.array(key),
44 | mx.array(value),
45 | )
46 | assert_allclose(user_output, reference_output, precision=precision)
47 |
48 |
49 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
50 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
51 | @pytest.mark.parametrize(
52 | "qkv_shape", [True, False], ids=["with_seq_len", "without_seq_len"]
53 | )
54 | def test_attention_with_mask(stream: mx.Stream, precision: np.dtype, qkv_shape: bool):
55 | with mx.stream(stream):
56 | BATCH_SIZE = 3
57 | SEQ_LEN = 10
58 | DIM_N = 4
59 | DIM_M = 5
60 | if qkv_shape:
61 | qkv_shape = (BATCH_SIZE, SEQ_LEN, DIM_N, DIM_M)
62 | mask_shape = (BATCH_SIZE, SEQ_LEN, DIM_N, DIM_N)
63 | else:
64 | qkv_shape = (BATCH_SIZE, DIM_N, DIM_M)
65 | mask_shape = (BATCH_SIZE, DIM_N, DIM_N)
66 | for _ in range(100):
67 | query = np.random.rand(*qkv_shape).astype(precision)
68 | key = np.random.rand(*qkv_shape).astype(precision)
69 | value = np.random.rand(*qkv_shape).astype(precision)
70 | scale = 0.8
71 | mask = np.random.rand(*mask_shape).astype(precision)
72 | reference_output = torch.nn.functional.scaled_dot_product_attention(
73 | torch.tensor(query, device=TORCH_DEVICE),
74 | torch.tensor(key, device=TORCH_DEVICE),
75 | torch.tensor(value, device=TORCH_DEVICE),
76 | scale=scale,
77 | attn_mask=torch.tensor(mask, device=TORCH_DEVICE),
78 | )
79 | user_output = scaled_dot_product_attention(
80 | mx.array(query),
81 | mx.array(key),
82 | mx.array(value),
83 | scale=scale,
84 | mask=mx.array(mask),
85 | )
86 | assert_allclose(user_output, reference_output, precision=precision)
87 |
88 |
89 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
90 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
91 | def test_multi_head_attention(stream: mx.Stream, precision: np.dtype):
92 | with mx.stream(stream):
93 | BATCH_SIZE = 7
94 | DIM_N = 11
95 | DIM_M = 9
96 | NUM_HEADS = 3
97 | for _ in range(100):
98 | query = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
99 | key = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
100 | value = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
101 | q_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
102 | k_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
103 | v_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
104 | out_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
105 | mask = np.random.rand(DIM_N * NUM_HEADS, BATCH_SIZE, BATCH_SIZE).astype(
106 | precision
107 | )
108 | reference_output, _ = torch.nn.functional.multi_head_attention_forward(
109 | torch.tensor(query, device=TORCH_DEVICE),
110 | torch.tensor(key, device=TORCH_DEVICE),
111 | torch.tensor(value, device=TORCH_DEVICE),
112 | num_heads=NUM_HEADS,
113 | q_proj_weight=torch.tensor(q_proj_weight, device=TORCH_DEVICE),
114 | k_proj_weight=torch.tensor(k_proj_weight, device=TORCH_DEVICE),
115 | v_proj_weight=torch.tensor(v_proj_weight, device=TORCH_DEVICE),
116 | out_proj_weight=torch.tensor(out_proj_weight, device=TORCH_DEVICE),
117 | embed_dim_to_check=DIM_M,
118 | in_proj_weight=None,
119 | in_proj_bias=None,
120 | bias_k=None,
121 | bias_v=None,
122 | add_zero_attn=False,
123 | dropout_p=0.0,
124 | out_proj_bias=None,
125 | use_separate_proj_weight=True,
126 | attn_mask=torch.tensor(mask, device=TORCH_DEVICE),
127 | )
128 | user_output = MultiHeadAttention(
129 | DIM_M,
130 | NUM_HEADS,
131 | mx.array(q_proj_weight),
132 | mx.array(k_proj_weight),
133 | mx.array(v_proj_weight),
134 | mx.array(out_proj_weight),
135 | )(
136 | mx.array(query),
137 | mx.array(key),
138 | mx.array(value),
139 | mask=mx.array(mask),
140 | )
141 | assert_allclose(user_output, reference_output, precision=precision)
142 |
143 |
144 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
145 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
146 | @pytest.mark.parametrize(
147 | "batch_dimension", [0, 1, 2], ids=["batch_0", "batch_1", "batch_2"]
148 | )
149 | @pytest.mark.parametrize("scale", [None, 0.8])
150 | def test_attention_grouped(
151 | stream: mx.Stream, precision: np.dtype, batch_dimension: int, scale: float | None
152 | ):
153 | with mx.stream(stream):
154 | H_q = 18
155 | H = 6
156 | L = 7
157 | E = 5
158 | S = 3
159 | BATCH = 10
160 | BATCH_2 = 2
161 | if batch_dimension == 0:
162 | q_shape = (H_q, L, E)
163 | kv_shape = (H, S, E)
164 | mask_shape = (H_q, L, S)
165 | elif batch_dimension == 1:
166 | q_shape = (BATCH, H_q, L, E)
167 | kv_shape = (BATCH, H, S, E)
168 | mask_shape = (BATCH, H_q, L, S)
169 | elif batch_dimension == 2:
170 | q_shape = (BATCH_2, BATCH, H_q, L, E)
171 | kv_shape = (BATCH_2, BATCH, H, S, E)
172 | mask_shape = (BATCH_2, BATCH, H_q, L, S)
173 | for _ in range(100):
174 | query = np.random.rand(*q_shape).astype(precision)
175 | key = np.random.rand(*kv_shape).astype(precision)
176 | value = np.random.rand(*kv_shape).astype(precision)
177 | mask = np.random.rand(*mask_shape).astype(precision)
178 | reference_output = torch.nn.functional.scaled_dot_product_attention(
179 | torch.tensor(query, device=TORCH_DEVICE),
180 | torch.tensor(key, device=TORCH_DEVICE),
181 | torch.tensor(value, device=TORCH_DEVICE),
182 | scale=scale,
183 | attn_mask=torch.tensor(mask, device=TORCH_DEVICE),
184 | enable_gqa=True,
185 | )
186 | user_output = scaled_dot_product_attention_grouped(
187 | mx.array(query),
188 | mx.array(key),
189 | mx.array(value),
190 | scale=scale,
191 | mask=mx.array(mask),
192 | )
193 | assert_allclose(user_output, reference_output, precision=precision)
194 |
--------------------------------------------------------------------------------
/tests/test_funcs.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import mlx.core as mx
3 | import torch
4 | from tiny_llm.funcs import *
5 | from tiny_llm.layers import *
6 | import numpy as np
7 | from .utils import *
8 | import torchtune
9 |
10 |
11 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
12 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
13 | def test_silu(stream: mx.Stream, precision: np.dtype):
14 | SIZE = 100
15 |
16 | with mx.stream(stream):
17 | for _ in range(100):
18 | data = np.random.rand(SIZE).astype(precision)
19 | reference_output = torch.nn.functional.silu(
20 | torch.tensor(data, device=TORCH_DEVICE)
21 | )
22 | user_output = silu(mx.array(data))
23 | assert_allclose(user_output, reference_output, precision)
24 |
25 |
26 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
27 | @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
28 | def test_rms_norm(stream: mx.Stream, precision: np.dtype):
29 | SIZE = 100
30 | SIZE_Y = 111
31 | with mx.stream(stream):
32 | for _ in range(100):
33 | data = np.random.rand(SIZE, SIZE_Y).astype(precision)
34 | weight = np.random.rand(SIZE_Y).astype(precision)
35 | eps = np.finfo(precision).eps
36 | reference_output = torch.nn.functional.rms_norm(
37 | torch.tensor(data, device=TORCH_DEVICE),
38 | (SIZE_Y,),
39 | torch.tensor(weight, device=TORCH_DEVICE),
40 | eps=eps,
41 | )
42 | user_output = RMSNorm(SIZE_Y, mx.array(weight), eps=eps)(mx.array(data))
43 | assert_allclose(user_output, reference_output, precision)
44 |
--------------------------------------------------------------------------------
/tests/test_rope.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import mlx.core as mx
3 | import torch
4 | from tiny_llm.funcs import *
5 | from tiny_llm.layers import *
6 | import numpy as np
7 | from .utils import *
8 | import torchtune
9 | import random
10 |
11 |
12 | @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
13 | @pytest.mark.parametrize("with_input", [True, False])
14 | def test_rope(stream: mx.Stream, with_input: bool):
15 | BATCH_SIZE = 1
16 | NUM_HEADS = 8
17 | NUM_KV_HEADS = 6
18 | HEAD_DIM = 4
19 | MAX_SEQ_LEN = 20
20 | SEQ_LEN = 10
21 | BASE = 10000.0
22 |
23 | with mx.stream(stream):
24 | for _ in range(100):
25 | reference_layer = (
26 | torchtune.modules.position_embeddings.RotaryPositionalEmbeddings(
27 | HEAD_DIM,
28 | MAX_SEQ_LEN,
29 | BASE,
30 | )
31 | )
32 | user_layer = RoPE(HEAD_DIM, MAX_SEQ_LEN, BASE, traditional=True)
33 | x = np.random.rand(BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM)
34 |
35 | if with_input:
36 | input_pos = np.random.randint(0, MAX_SEQ_LEN - SEQ_LEN)
37 | input_pos_mx = input_pos
38 | input_pos_user = slice(input_pos, input_pos + SEQ_LEN)
39 | input_pos_torch = torch.tensor([i for i in range(input_pos, input_pos + SEQ_LEN)], device=TORCH_DEVICE, dtype=torch.int32)
40 | else:
41 | input_pos = None
42 | input_pos_mx = None
43 | input_pos_user = None
44 | input_pos_torch = None
45 |
46 | reference_output = reference_layer.forward(
47 | torch.tensor(x, device=TORCH_DEVICE), input_pos=input_pos_torch
48 | )
49 | user_output = user_layer(mx.array(x), input_pos_user)
50 | assert_allclose(user_output, reference_output, np.float32, atol=1e-6)
51 |
52 | user_layer = RoPE(HEAD_DIM, MAX_SEQ_LEN, BASE, traditional=False)
53 | reference_output = mx.fast.rope(
54 | mx.array(x).transpose(0, 2, 1, 3),
55 | dims=HEAD_DIM,
56 | traditional=False,
57 | base=BASE,
58 | scale=1.0,
59 | offset=input_pos_mx or 0,
60 | ).transpose(0, 2, 1, 3)
61 | user_output = user_layer(mx.array(x), input_pos_user)
62 | assert_allclose(user_output, reference_output, np.float32, atol=1e-6)
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mlx.core as mx
3 | import torch
4 |
5 | AVAILABLE_STREAMS = [mx.cpu, mx.gpu]
6 | AVAILABLE_STREAMS_IDS = ["cpu", "gpu"]
7 | PRECISIONS = [np.float32, np.float16]
8 | PRECISION_IDS = ["f32", "f16"]
9 | TORCH_DEVICE = torch.device("cpu")
10 |
11 |
12 | def assert_allclose(
13 | a: mx.array,
14 | b: torch.Tensor | mx.array,
15 | precision: np.dtype,
16 | rtol: float | None = None,
17 | atol: float | None = None,
18 | ):
19 | a = np.array(a)
20 | if isinstance(b, torch.Tensor):
21 | b = b.cpu().numpy()
22 | elif isinstance(b, mx.array):
23 | b = np.array(b)
24 | else:
25 | raise ValueError(f"Unsupported type: {type(b)}")
26 | if precision == np.float32:
27 | rtol = rtol or 1.0e-5
28 | atol = atol or 1.0e-8
29 | elif precision == np.float16:
30 | rtol = rtol or 1.0e-2
31 | atol = atol or 1.0e-7
32 | assert a.shape == b.shape
33 | if not np.allclose(a, b, rtol=rtol, atol=atol):
34 | print("a=", a)
35 | print("b=", b)
36 | diff = np.invert(np.isclose(a, b, rtol=rtol, atol=atol))
37 | print("diff_a=", a * diff)
38 | print("diff_b=", b * diff)
39 |         assert False, "result mismatch"
40 |
41 |
42 | def np_type_to_mx_type(np_type: np.dtype) -> mx.Dtype:
43 | if np_type == np.float32:
44 | return mx.float32
45 | elif np_type == np.float16:
46 | return mx.float16
47 | else:
48 | raise ValueError(f"Unsupported numpy type: {np_type}")
49 |
--------------------------------------------------------------------------------