├── LICENSE
├── README.md
├── doc
│   ├── CONTRIBUTING.md
│   ├── README.md
│   ├── fig
│   │   ├── flashNorm_fig1.pdf
│   │   ├── flashNorm_fig1.svg
│   │   ├── flashNorm_fig2.pdf
│   │   ├── flashNorm_fig3.pdf
│   │   ├── flashNorm_fig4.pdf
│   │   ├── flashNorm_fig5.pdf
│   │   ├── flashNorm_fig6.pdf
│   │   ├── flashNorm_fig7.pdf
│   │   ├── flashNorm_fig8.pdf
│   │   ├── flashNorm_figA.pdf
│   │   ├── flashNorm_figB.pdf
│   │   ├── matShrink_fig1.pdf
│   │   ├── matShrink_fig2.pdf
│   │   ├── matShrink_fig3.pdf
│   │   ├── precomp1stLayer_fig1.pdf
│   │   ├── precomp1stLayer_fig2.pdf
│   │   ├── removeWeights_fig1.pdf
│   │   ├── removeWeights_fig2.pdf
│   │   ├── removeWeights_fig3.pdf
│   │   ├── removeWeights_fig4.pdf
│   │   ├── slimAttn_fig1.pdf
│   │   ├── slimAttn_fig1.svg
│   │   ├── slimAttn_fig2.pdf
│   │   ├── slimAttn_fig3.pdf
│   │   ├── slimAttn_fig4.pdf
│   │   ├── slimAttn_fig5.pdf
│   │   ├── slimAttn_fig6.pdf
│   │   └── slimAttn_fig7.pdf
│   ├── flashNorm.md
│   ├── flashNorm.pdf
│   ├── matShrink.pdf
│   ├── precomp1stLayer.pdf
│   ├── removeWeights.pdf
│   ├── slimAttn.md
│   └── slimAttn.pdf
├── flashNorm_example.py
├── flashNorm_modeling_llama.py
├── flashNorm_test.py
├── notebooks
│   ├── README.md
│   ├── flashNorm_example.ipynb
│   ├── flashNorm_paper.ipynb
│   ├── removeWeights_paper.ipynb
│   ├── slimAttn_paper.ipynb
│   └── update_packages.ipynb
├── pyproject.toml
├── requirements.txt
├── slimAttn_paper.py
├── tex
│   ├── README.md
│   ├── arxiv.sty
│   ├── clean
│   ├── flashNorm.tex
│   ├── matShrink.tex
│   ├── neurips_2025.sty
│   ├── neurips_2025.tex
│   ├── precomp1stLayer.tex
│   ├── references.bib
│   ├── removeWeights.tex
│   ├── run
│   ├── slimAttn.tex
│   └── submit
├── transformer_tricks.py
└── util
    ├── .aspell
    ├── clean_all
    ├── gen_notebooks
    ├── gen_pdf
    ├── push_pypi
    └── spell_check
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 OpenMachine
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
7 |
8 | A collection of tricks to simplify and speed up transformer models:
9 | - Slim attention: [[paper]](https://arxiv.org/abs/2503.05840), [[video]](https://youtu.be/uVtk3B6YO4Y), [[podcast]](https://notebooklm.google.com/notebook/ac47a53c-866b-4271-ab79-bc48d1b41722/audio), [[notebook]](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/slimAttn_paper.ipynb), [[code-readme]](doc/slimAttn.md), :hugs: [[article]](https://huggingface.co/blog/Kseniase/attentions), [[reddit]](https://www.reddit.com/r/LocalLLaMA/comments/1j9wkc2/slim_attention_cut_your_context_memory_in_half)
10 | - Flash normalization: [[paper]](https://arxiv.org/abs/2407.09577), [[video]](https://youtu.be/GEuJv34_XgU), [[podcast]](https://notebooklm.google.com/notebook/0877599c-720c-49b5-b451-8a41af592dd1/audio), [[notebook]](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/flashNorm_paper.ipynb), [[code-readme]](doc/flashNorm.md)
11 | - Matrix-shrink \[work in progress\]: [[paper]](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/doc/matShrink.pdf)
12 | - Precomputing the first layer: [[paper]](https://arxiv.org/abs/2402.13388), [[podcast]](https://notebooklm.google.com/notebook/7794278e-de6a-40fc-ab1c-3240a40e55d5/audio)
13 | - Removing weights from skipless transformers: [[paper]](https://arxiv.org/abs/2404.12362), [[podcast]](https://notebooklm.google.com/notebook/0875eef7-094e-4c30-bc13-90a1a074c949/audio), [[notebook]](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/removeWeights_paper.ipynb)
14 |
15 | Many of these tricks follow a recent trend of removing parts from neural networks, such as [RMSNorm's](https://arxiv.org/abs/1910.07467) removal of mean centering from LayerNorm, [PaLM's](https://arxiv.org/abs/2204.02311) removal of bias parameters, the [decoder-only transformer's](https://arxiv.org/abs/1801.10198) removal of the encoder stack, and of course the [transformer's](https://arxiv.org/abs/1706.03762) revolutionary removal of recurrent layers.
16 |
17 | For example, our FlashNorm removes the weights from RMSNorm and merges them into the next linear layer, and slim attention removes the entire V-cache from the context memory of MHA transformers.
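As a rough illustration of the FlashNorm idea, here is a minimal numpy sketch (mirroring the `flashify` helper in `notebooks/flashNorm_paper.ipynb`; all tensors below are random toy data, not real model weights):

```python
import numpy as np

def r_rms(x):  # reciprocal of the RMS, as used by RMSNorm
  return 1 / np.sqrt(np.mean(x**2))

n = 32
a = np.random.rand(n)     # activation vector
g = np.random.rand(n)     # RMSNorm weights
W = np.random.rand(n, n)  # next linear layer

W_star = np.diag(g) @ W   # fold the RMSNorm weights g into W

z_ref  = (r_rms(a) * a * g) @ W    # RMSNorm with weights, then linear layer
z_fast = (r_rms(a) * a) @ W_star   # FlashNorm: weight-free normalization
print(np.allclose(z_ref, z_fast))  # True
```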
18 |
19 | ---
20 |
21 | ## Explainer videos
22 |
23 | [Slim attention (video)](https://www.youtube.com/watch?v=uVtk3B6YO4Y)
24 | [Flash normalization (video)](https://www.youtube.com/watch?v=GEuJv34_XgU)
25 |
26 | ---
27 |
28 | ## Installation
29 |
30 | Install the transformer tricks package:
31 | ```bash
32 | pip install transformer-tricks
33 | ```
34 |
35 | Alternatively, to run from the latest repo:
36 | ```bash
37 | git clone https://github.com/OpenMachine-ai/transformer-tricks.git
38 | python3 -m venv .venv
39 | source .venv/bin/activate
40 | pip3 install --quiet -r requirements.txt
41 | ```
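You can then run the example and test scripts from the repo root, for example:
```bash
python3 flashNorm_example.py
python3 flashNorm_test.py
```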
42 |
43 | ---
44 |
45 | ## Documentation
46 | Follow the links below for documentation of the Python code in this directory:
47 | - [Slim attention](doc/slimAttn.md)
48 | - [Flash normalization](doc/flashNorm.md)
49 |
50 | ---
51 |
52 | ## Notebooks
53 | The papers are accompanied by the following Jupyter notebooks:
54 | - Slim attention: [slimAttn_paper.ipynb](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/slimAttn_paper.ipynb)
55 | - Flash normalization: [flashNorm_paper.ipynb](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/flashNorm_paper.ipynb)
56 | - Removing weights from skipless transformers: [removeWeights_paper.ipynb](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/removeWeights_paper.ipynb)
57 |
58 | ---
59 | ## Newsletter
60 | Please subscribe to our [newsletter](https://transformertricks.substack.com) on Substack to get the latest news about this project. We will never send you more than one email per month.
61 |
62 | [Subscribe](https://transformertricks.substack.com)
63 |
64 | ---
65 |
66 | ## Contributing
67 | We pay cash for high-impact contributions. Please check out [CONTRIBUTING](doc/CONTRIBUTING.md) for how to get involved.
68 |
69 | ---
70 |
71 | ## Sponsors
72 | The Transformer Tricks project is currently sponsored by [OpenMachine](https://openmachine.ai). We'd love to hear from you if you'd like to join us in supporting this project.
73 |
74 | ---
75 |
76 | ### Please give us a ⭐ if you like this repo, and check out [TinyFive](https://github.com/OpenMachine-ai/tinyfive)
77 |
78 | ---
79 |
--------------------------------------------------------------------------------
/doc/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | We pay cash for your high-impact contributions; please contact us for details.
2 |
3 | Before submitting a PR, please do the following:
4 | - Make sure to minimize the number of files and lines of code; we strive for simple and readable code.
5 | - Format your code by typing `autopep8 *.py`. It uses the config in `pyproject.toml`.
6 | - Generate the notebooks from the Python files by typing `util/gen_notebooks`.
7 | - Whenever you change `transformer_tricks.py`, we will publish a new version of the package as follows:
8 | - First, update the version number in `pyproject.toml` and in `requirements.txt`
9 |     - Then, push the package to PyPI by typing `./push_pypi.sh`
10 | - Links for python package: [pypi](https://pypi.org/project/transformer-tricks/), [stats](https://www.pepy.tech/projects/transformer-tricks)
11 |
--------------------------------------------------------------------------------
/doc/README.md:
--------------------------------------------------------------------------------
1 | Click on the links below for better PDF viewing:
2 | - [[flashNorm.pdf]](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/doc/flashNorm.pdf)
3 | - [[matShrink.pdf]](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/doc/matShrink.pdf)
4 | - [[precomp1stLayer.pdf]](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/doc/precomp1stLayer.pdf)
5 | - [[removeWeights.pdf]](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/doc/removeWeights.pdf)
6 | - [[slimAttn.pdf]](https://docs.google.com/viewer?url=https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/doc/slimAttn.pdf)
7 |
--------------------------------------------------------------------------------
/doc/fig/flashNorm_fig1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_fig1.pdf
--------------------------------------------------------------------------------
/doc/fig/flashNorm_fig2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_fig2.pdf
--------------------------------------------------------------------------------
/doc/fig/flashNorm_fig3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_fig3.pdf
--------------------------------------------------------------------------------
/doc/fig/flashNorm_fig4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_fig4.pdf
--------------------------------------------------------------------------------
/doc/fig/flashNorm_fig5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_fig5.pdf
--------------------------------------------------------------------------------
/doc/fig/flashNorm_fig6.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_fig6.pdf
--------------------------------------------------------------------------------
/doc/fig/flashNorm_fig7.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_fig7.pdf
--------------------------------------------------------------------------------
/doc/fig/flashNorm_fig8.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_fig8.pdf
--------------------------------------------------------------------------------
/doc/fig/flashNorm_figA.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_figA.pdf
--------------------------------------------------------------------------------
/doc/fig/flashNorm_figB.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/flashNorm_figB.pdf
--------------------------------------------------------------------------------
/doc/fig/matShrink_fig1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/matShrink_fig1.pdf
--------------------------------------------------------------------------------
/doc/fig/matShrink_fig2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/matShrink_fig2.pdf
--------------------------------------------------------------------------------
/doc/fig/matShrink_fig3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/matShrink_fig3.pdf
--------------------------------------------------------------------------------
/doc/fig/precomp1stLayer_fig1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/precomp1stLayer_fig1.pdf
--------------------------------------------------------------------------------
/doc/fig/precomp1stLayer_fig2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/precomp1stLayer_fig2.pdf
--------------------------------------------------------------------------------
/doc/fig/removeWeights_fig1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/removeWeights_fig1.pdf
--------------------------------------------------------------------------------
/doc/fig/removeWeights_fig2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/removeWeights_fig2.pdf
--------------------------------------------------------------------------------
/doc/fig/removeWeights_fig3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/removeWeights_fig3.pdf
--------------------------------------------------------------------------------
/doc/fig/removeWeights_fig4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/removeWeights_fig4.pdf
--------------------------------------------------------------------------------
/doc/fig/slimAttn_fig1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/slimAttn_fig1.pdf
--------------------------------------------------------------------------------
/doc/fig/slimAttn_fig1.svg:
--------------------------------------------------------------------------------
(SVG version of slimAttn_fig1; the 215-line SVG source is omitted in this export)
--------------------------------------------------------------------------------
/doc/fig/slimAttn_fig2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/slimAttn_fig2.pdf
--------------------------------------------------------------------------------
/doc/fig/slimAttn_fig3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/slimAttn_fig3.pdf
--------------------------------------------------------------------------------
/doc/fig/slimAttn_fig4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/slimAttn_fig4.pdf
--------------------------------------------------------------------------------
/doc/fig/slimAttn_fig5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/slimAttn_fig5.pdf
--------------------------------------------------------------------------------
/doc/fig/slimAttn_fig6.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/slimAttn_fig6.pdf
--------------------------------------------------------------------------------
/doc/fig/slimAttn_fig7.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/fig/slimAttn_fig7.pdf
--------------------------------------------------------------------------------
/doc/flashNorm.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Setup
4 | ```
5 | pip3 install transformer-tricks
6 | ```
7 |
8 | ## Example
9 | The example below converts SmolLM-135M to [FlashNorm](https://arxiv.org/pdf/2407.09577) and measures perplexity of the original and the modified model.
10 | ```python
11 | import transformer_tricks as tt
12 |
13 | # convert model and store the new model in ./SmolLM-135M_flashNorm_test
14 | tt.flashify_repo('HuggingFaceTB/SmolLM-135M')
15 |
16 | # run example inference of original and modified model
17 | tt.hello_world('HuggingFaceTB/SmolLM-135M')
18 | tt.hello_world('./SmolLM-135M_flashNorm_test')
19 |
20 | # measure perplexity of original and modified model
21 | tt.perplexity('HuggingFaceTB/SmolLM-135M', speedup=16)
22 | tt.perplexity('./SmolLM-135M_flashNorm_test', speedup=16)
23 | ```
24 | Results:
25 | ```
26 | Once upon a time there was a curious little girl
27 | Once upon a time there was a curious little girl
28 | perplexity = 16.083
29 | perplexity = 16.083
30 | ```
31 |
32 | You can run the example in your browser by clicking on this notebook: [flashNorm_example.ipynb](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/flashNorm_example.ipynb). Hit "cancel" when it says "Notebook does not have secret access", because we don't need an HF_TOKEN for SmolLM.
33 |
34 | TODO: [our HuggingFace repo](https://huggingface.co/open-machine/FlashNorm)
35 |
36 | ## Test FlashNorm
37 | ```shell
38 | # setup
39 | git clone https://github.com/OpenMachine-ai/transformer-tricks.git
40 | pip3 install --quiet -r requirements.txt
41 |
42 | # run tests
43 | python3 flashNorm_test.py
44 | ```
45 | Results:
46 | ```
47 | Once upon a time there was a curious little girl
48 | Once upon a time there was a curious little girl
49 | Once upon a time there was a little girl named
50 | Once upon a time there was a little girl named
51 | perplexity = 16.083
52 | perplexity = 16.083
53 | perplexity = 12.086
54 | perplexity = 12.086
55 | ```
56 | To run Llama and other LLMs that require a license agreement (not needed for SmolLM), you first have to type the following, which will ask for your `hf_token`:
57 | ```
58 | huggingface-cli login
59 | ```
60 |
61 | ## Please give us a ⭐ if you like this repo, thanks!
62 |
--------------------------------------------------------------------------------
/doc/flashNorm.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/flashNorm.pdf
--------------------------------------------------------------------------------
/doc/matShrink.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/matShrink.pdf
--------------------------------------------------------------------------------
/doc/precomp1stLayer.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/precomp1stLayer.pdf
--------------------------------------------------------------------------------
/doc/removeWeights.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/removeWeights.pdf
--------------------------------------------------------------------------------
/doc/slimAttn.md:
--------------------------------------------------------------------------------
1 | coming soon
2 |
3 | For now, see the [[notebook]](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/slimAttn_paper.ipynb)
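In the meantime, here is a minimal numpy sketch of the core idea (random toy weights for illustration only; the repo's `slimAttn_paper.py` runs the same check with SmolLM2-1.7B weights). For MHA, the key projection is square and typically invertible, so V can be reconstructed from K and the V-cache can be dropped:

```python
import numpy as np

d, n = 64, 10                 # toy model dimension and number of tokens
Wk = np.random.rand(d, d)     # key projection (square, almost surely invertible)
Wv = np.random.rand(d, d)     # value projection
Wkv = np.linalg.inv(Wk) @ Wv  # W_KV maps keys to values

X = np.random.rand(n, d)      # token activations
K = X @ Wk
V = X @ Wv
print(np.allclose(K @ Wkv, V))  # True: V is recovered from K, so no V-cache needed
```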
4 |
5 | Feature requests for frameworks to implement slim attention:
6 | - [llama.cpp and whisper.cpp](https://github.com/ggml-org/llama.cpp/issues/12359)
7 | - [vLLM](https://github.com/vllm-project/vllm/issues/14937)
8 | - [SGLang](https://github.com/sgl-project/sglang/issues/4496)
9 |
--------------------------------------------------------------------------------
/doc/slimAttn.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/doc/slimAttn.pdf
--------------------------------------------------------------------------------
/flashNorm_example.py:
--------------------------------------------------------------------------------
1 | # This example converts SmolLM-135M to FlashNorm and measures its perplexity
2 | # before and after the conversion.
3 | # Usage: python3 flashNorm_example.py
4 |
5 | # !wget -q https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/flashNorm_modeling_llama.py
6 | # %pip install --quiet transformer_tricks
7 | import transformer_tricks as tt
8 |
9 | tt.quiet_hf() # calm down HuggingFace
10 |
11 | # %%
12 | #-------------------------------------------------------------------------------
13 | # Example 1
14 | #-------------------------------------------------------------------------------
15 |
16 | # convert model and store the new model in ./SmolLM-135M_flashNorm_test
17 | tt.flashify_repo('HuggingFaceTB/SmolLM-135M')
18 |
19 | # run example inference of original and modified model
20 | tt.hello_world('HuggingFaceTB/SmolLM-135M')
21 | tt.hello_world('./SmolLM-135M_flashNorm_test')
22 |
23 | # measure perplexity of original and modified model
24 | tt.perplexity('HuggingFaceTB/SmolLM-135M', speedup=16)
25 | tt.perplexity('./SmolLM-135M_flashNorm_test', speedup=16)
26 |
27 | # %%
28 | #-------------------------------------------------------------------------------
29 | # Example 2
30 | #-------------------------------------------------------------------------------
31 |
32 | # convert model and store the new model in ./SmolLM-135M_flashNorm
33 | tt.flashify_repo('HuggingFaceTB/SmolLM-135M')
34 |
35 | # run example inference of original and modified model
36 | tt.hello_world('HuggingFaceTB/SmolLM-135M')
37 | tt.hello_world('./SmolLM-135M_flashNorm', arch='LlamaFlashNorm')
38 |
39 | # measure perplexity of original and modified model
40 | tt.perplexity('HuggingFaceTB/SmolLM-135M', speedup=16)
41 | tt.perplexity('./SmolLM-135M_flashNorm', speedup=16, arch='LlamaFlashNorm')
42 |
43 | # %% [markdown]
44 | # Whenever you change this file, make sure to regenerate the jupyter notebook by typing:
45 | # `util/gen_notebooks`
46 |
--------------------------------------------------------------------------------
/flashNorm_test.py:
--------------------------------------------------------------------------------
1 | # flashify LLMs and run inference and perplexity to make sure that
2 | # the flashified models are equivalent to the original ones
3 | # Usage: python3 flashNorm_test.py
4 |
5 | import transformer_tricks as tt
6 |
7 | tt.quiet_hf() # calm down HuggingFace
8 |
9 | # convert models to flashNorm
10 | tt.flashify_repo('HuggingFaceTB/SmolLM-135M')
11 | tt.flashify_repo('HuggingFaceTB/SmolLM-360M')
12 | #tt.flashify_repo('HuggingFaceTB/SmolLM-1.7B', bars=True)
13 | #tt.flashify_repo('microsoft/Phi-3-mini-4k-instruct', bars=True)
14 |
15 | # run models
16 | tt.hello_world('HuggingFaceTB/SmolLM-135M')
17 | tt.hello_world( 'SmolLM-135M_flashNorm_test')
18 | tt.hello_world( 'SmolLM-135M_flashNorm', arch='LlamaFlashNorm')
19 | tt.hello_world('HuggingFaceTB/SmolLM-360M')
20 | tt.hello_world( 'SmolLM-360M_flashNorm_test')
21 | tt.hello_world( 'SmolLM-360M_flashNorm', arch='LlamaFlashNorm')
22 | #tt.hello_world('HuggingFaceTB/SmolLM-1.7B')
23 | #tt.hello_world( 'SmolLM-1.7B_flashNorm')
24 | #tt.hello_world('microsoft/Phi-3-mini-4k-instruct')
25 | #tt.hello_world( 'Phi-3-mini-4k-instruct_flashNorm')
26 |
27 | # measure perplexity
28 | tt.perplexity('HuggingFaceTB/SmolLM-135M', speedup=16)
29 | tt.perplexity( 'SmolLM-135M_flashNorm_test', speedup=16)
30 | tt.perplexity( 'SmolLM-135M_flashNorm', arch='LlamaFlashNorm', speedup=16)
31 | tt.perplexity('HuggingFaceTB/SmolLM-360M', speedup=16)
32 | tt.perplexity( 'SmolLM-360M_flashNorm_test', speedup=16)
33 | tt.perplexity( 'SmolLM-360M_flashNorm', arch='LlamaFlashNorm', speedup=16)
34 | #tt.perplexity('HuggingFaceTB/SmolLM-1.7B', speedup=64)
35 | #tt.perplexity( 'SmolLM-1.7B_flashNorm', speedup=64)
36 | #tt.perplexity('microsoft/Phi-3-mini-4k-instruct', speedup=64, bars=True)
37 | #tt.perplexity( 'Phi-3-mini-4k-instruct_flashNorm', speedup=64, bars=True)
38 |
39 | # TODO: add more LLMs
40 | #python3 gen.py stabilityai/stablelm-2-1_6b # doesn't use RMSNorm, but LayerNorm
41 | #python3 gen.py meta-llama/Meta-Llama-3.1-8B
42 | #python3 gen.py mistralai/Mistral-7B-v0.3
43 |
44 | # Notes for running larger models:
45 | # - To run llama and other semi-secret LLMs, you first have to type the following:
46 | # huggingface-cli login
47 | # above will ask you for the hf_token, which is the same you use e.g. in colab
48 | #
49 | # - On MacBook, open the 'Activity Monitor' and check your memory usage. If your
50 | # MacBook has only 8GB of DRAM, then you have only about 6GB available. Many LLMs
51 | # use float32, so a 1.5B model needs at least 6GB of DRAM.
52 | #
53 | # - Running gen.py is limited by DRAM bandwidth, not compute. Running ppl.py is
54 | # usually limited by compute (rather than by memory bandwidth), so only having
55 | # 8GB of DRAM is likely not an issue for running ppl.py on larger LLMs. That's
56 | # because ppl.py doesn't do the auto-regressive generation phase but only the
57 | # prompt phase (where all input tokens are batched).
58 | #
59 | # - The models get cached on your system at ~/.cache/huggingface, which can grow
60 | # very big, see du -h -d 3 ~/.cache/huggingface
61 |
--------------------------------------------------------------------------------
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | Click on the links below to run the notebooks in your browser. You can hit 'cancel' when it says 'Notebook does not have secret access' because we don't need an HF_TOKEN:
2 | - flashNorm_example.ipynb: [open in Colab](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/flashNorm_example.ipynb)
3 | - flashNorm_paper.ipynb: [open in Colab](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/flashNorm_paper.ipynb)
4 | - removeWeights_paper.ipynb: [open in Colab](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/removeWeights_paper.ipynb)
5 | - slimAttn_paper.ipynb: [open in Colab](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/slimAttn_paper.ipynb)
6 | - update_packages.ipynb: [open in Colab](https://colab.research.google.com/github/OpenMachine-ai/transformer-tricks/blob/main/notebooks/update_packages.ipynb)
7 |
--------------------------------------------------------------------------------
/notebooks/flashNorm_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "456ef5b0",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# This example converts SmolLM-135M to FlashNorm and measures its perplexity\n",
11 | "# before and after the conversion.\n",
12 | "# Usage: python3 flashNorm_example.py\n",
13 | "\n",
14 | "!wget -q https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/flashNorm_modeling_llama.py\n",
15 | "%pip install --quiet transformer_tricks\n",
16 | "import transformer_tricks as tt\n",
17 | "\n",
18 | "tt.quiet_hf() # calm down HuggingFace"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "a8cd3fd7",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "#-------------------------------------------------------------------------------\n",
29 | "# Example 1\n",
30 | "#-------------------------------------------------------------------------------\n",
31 | "\n",
32 | "# convert model and store the new model in ./SmolLM-135M_flashNorm_test\n",
33 | "tt.flashify_repo('HuggingFaceTB/SmolLM-135M')\n",
34 | "\n",
35 | "# run example inference of original and modified model\n",
36 | "tt.hello_world('HuggingFaceTB/SmolLM-135M')\n",
37 | "tt.hello_world('./SmolLM-135M_flashNorm_test')\n",
38 | "\n",
39 | "# measure perplexity of original and modified model\n",
40 | "tt.perplexity('HuggingFaceTB/SmolLM-135M', speedup=16)\n",
41 | "tt.perplexity('./SmolLM-135M_flashNorm_test', speedup=16)"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "id": "31381dd7",
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "#-------------------------------------------------------------------------------\n",
52 | "# Example 2\n",
53 | "#-------------------------------------------------------------------------------\n",
54 | "\n",
55 | "# convert model and store the new model in ./SmolLM-135M_flashNorm\n",
56 | "tt.flashify_repo('HuggingFaceTB/SmolLM-135M')\n",
57 | "\n",
58 | "# run example inference of original and modified model\n",
59 | "tt.hello_world('HuggingFaceTB/SmolLM-135M')\n",
60 | "tt.hello_world('./SmolLM-135M_flashNorm', arch='LlamaFlashNorm')\n",
61 | "\n",
62 | "# measure perplexity of original and modified model\n",
63 | "tt.perplexity('HuggingFaceTB/SmolLM-135M', speedup=16)\n",
64 | "tt.perplexity('./SmolLM-135M_flashNorm', speedup=16, arch='LlamaFlashNorm')"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "id": "46c7d85f",
70 | "metadata": {},
71 | "source": [
72 | "Whenever you change this file, make sure to regenerate the jupyter notebook by typing:\n",
73 | " `util/gen_notebooks`"
74 | ]
75 | }
76 | ],
77 | "metadata": {
78 | "jupytext": {
79 | "cell_metadata_filter": "-all",
80 | "main_language": "python",
81 | "notebook_metadata_filter": "-all"
82 | }
83 | },
84 | "nbformat": 4,
85 | "nbformat_minor": 5
86 | }
87 |
--------------------------------------------------------------------------------
/notebooks/flashNorm_paper.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "id": "P9KAn1NGtAX3"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# code for paper 'Transformer tricks: flash normalization'\n",
12 | "\n",
13 | "import numpy as np\n",
14 | "\n",
15 | "# reciprocal of RMS and activation functions\n",
16 | "def r_rms(x): return 1 / np.sqrt(np.mean(x**2))\n",
17 | "def r_ms(x): return 1 / np.mean(x**2)\n",
18 | "def relu(x): return np.maximum(0, x)\n",
19 | "def sigmoid(x): return 1 / (1 + np.exp(-x))\n",
20 | "def silu(x): return x * sigmoid(x) # often known as swish\n",
21 | "\n",
22 | "# merge normalization weights g into weight matrix W\n",
23 | "def flashify(g, W):\n",
24 | " Wnew = np.empty(W.shape)\n",
25 | " for i in range(g.shape[0]):\n",
26 | " Wnew[i, :] = g[i] * W[i, :]\n",
27 | " return Wnew\n",
28 | "\n",
29 | "# alternative flashify (same as above but fewer lines)\n",
30 | "#def flashify_alt(g, W):\n",
31 | "# G = np.repeat(g, W.shape[1]).reshape(W.shape)\n",
32 | "# return G * W # elementwise multiply"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "source": [
38 | "# variables\n",
39 | "n = 32\n",
40 | "f = 128\n",
41 | "a = np.random.rand(n) # row-vector\n",
42 | "g = np.random.rand(n) # row-vector\n",
43 | "W = np.random.rand(n, n)\n",
44 | "UP = np.random.rand(n, f)\n",
45 | "GATE = np.random.rand(n, f)\n",
46 | "DOWN = np.random.rand(f, n)\n",
47 | "\n",
48 | "# derived variables\n",
49 | "s = r_rms(a) # scaling factor\n",
50 | "Wstar = flashify(g, W)\n",
51 | "UPstar = flashify(g, UP)\n",
52 | "GATEstar = flashify(g, GATE)"
53 | ],
54 | "metadata": {
55 | "id": "yLty-2szRrDM"
56 | },
57 | "execution_count": 2,
58 | "outputs": []
59 | },
60 | {
61 | "cell_type": "code",
62 | "source": [
63 | "# code for section 1 of paper\n",
64 | "\n",
65 | "# figures 1(a), 1(b), and 1(c) of paper\n",
66 | "z_fig1a = (r_rms(a) * a * g) @ W\n",
67 | "z_fig1b = (r_rms(a) * a) @ Wstar\n",
68 | "z_fig1c = (a @ Wstar) * r_rms(a)\n",
69 | "\n",
70 | "# compare against z_fig1a\n",
71 | "print(np.allclose(z_fig1b, z_fig1a), ' (fig1b is close to fig1a if True)')\n",
72 | "print(np.allclose(z_fig1c, z_fig1a), ' (fig1c is close to fig1a if True)')"
73 | ],
74 | "metadata": {
75 | "colab": {
76 | "base_uri": "https://localhost:8080/"
77 | },
78 | "id": "4YSv5p16ScE2",
79 | "outputId": "e9e6ec90-04e7-4d80-98dc-374e8202f60e"
80 | },
81 | "execution_count": 3,
82 | "outputs": [
83 | {
84 | "output_type": "stream",
85 | "name": "stdout",
86 | "text": [
87 | "True (fig1b is close to fig1a if True)\n",
88 | "True (fig1c is close to fig1a if True)\n"
89 | ]
90 | }
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "source": [
96 | "# code for section 2.1 of paper\n",
97 | "\n",
98 | "# reference and figures 2(a) and 2(b) of paper\n",
99 | "y_ref2 = relu((s * a * g) @ UP) @ DOWN\n",
100 | "y_fig2a = relu((a @ UPstar) * s) @ DOWN\n",
101 | "y_fig2b = (relu(a @ UPstar) @ DOWN) * s\n",
102 | "\n",
103 | "# compare against y_ref\n",
104 | "print(np.allclose(y_fig2a, y_ref2), ' (fig2a is close to reference if True)')\n",
105 | "print(np.allclose(y_fig2b, y_ref2), ' (fig2b is close to reference if True)')"
106 | ],
107 | "metadata": {
108 | "colab": {
109 | "base_uri": "https://localhost:8080/"
110 | },
111 | "id": "HlUzIzR-VXlg",
112 | "outputId": "6f90bd82-b003-4431-e071-a251b4906d8a"
113 | },
114 | "execution_count": 4,
115 | "outputs": [
116 | {
117 | "output_type": "stream",
118 | "name": "stdout",
119 | "text": [
120 | "True (fig2a is close to reference if True)\n",
121 | "True (fig2b is close to reference if True)\n"
122 | ]
123 | }
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "source": [
129 | "# code for section 2.2 of paper\n",
130 | "\n",
131 | "# shortcuts\n",
132 | "a_norm = s * a * g\n",
133 | "a_gate, a_up = (a @ GATEstar), (a @ UPstar)\n",
134 | "\n",
135 | "# figure 3: reference and figures 3(a) and 3(b) of paper\n",
136 | "y_ref3 = ((a_norm @ GATE) * silu(a_norm @ UP)) @ DOWN\n",
137 | "y_fig3a = (a_gate * s * silu(a_up * s)) @ DOWN\n",
138 | "y_fig3b = ((a_gate * silu(a_up * s)) @ DOWN) * s\n",
139 | "\n",
140 | "# compare against y_ref3\n",
141 | "print(np.allclose(y_fig3a, y_ref3), ' (fig3a is close to reference if True)')\n",
142 | "print(np.allclose(y_fig3b, y_ref3), ' (fig3b is close to reference if True)')\n",
143 | "\n",
144 | "# figure 4: reference and figures 4(a) and 4(b) of paper\n",
145 | "y_ref4 = ((a_norm @ GATE) * relu(a_norm @ UP)) @ DOWN\n",
146 | "y_fig4a = (a_gate * s * relu(a_up * s)) @ DOWN\n",
147 | "y_fig4b = ((a_gate * relu(a_up)) @ DOWN) * r_ms(a)\n",
148 | "\n",
149 | "# compare against y_ref4\n",
150 | "print(np.allclose(y_fig4a, y_ref4), ' (fig4a is close to reference if True)')\n",
151 | "print(np.allclose(y_fig4b, y_ref4), ' (fig4b is close to reference if True)')"
152 | ],
153 | "metadata": {
154 | "colab": {
155 | "base_uri": "https://localhost:8080/"
156 | },
157 | "id": "3sVxcH8Q0MHt",
158 | "outputId": "5a7a7d7c-f935-4e80-fc34-9b59662e39d7"
159 | },
160 | "execution_count": 5,
161 | "outputs": [
162 | {
163 | "output_type": "stream",
164 | "name": "stdout",
165 | "text": [
166 | "True (fig3a is close to reference if True)\n",
167 | "True (fig3b is close to reference if True)\n",
168 | "True (fig4a is close to reference if True)\n",
169 | "True (fig4b is close to reference if True)\n"
170 | ]
171 | }
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "source": [
177 | "# code for section 3 of paper\n",
178 | "\n",
179 | "# TODO"
180 | ],
181 | "metadata": {
182 | "id": "xW9Sz-Dy5H0e"
183 | },
184 | "execution_count": null,
185 | "outputs": []
186 | }
187 | ],
188 | "metadata": {
189 | "colab": {
190 | "provenance": []
191 | },
192 | "kernelspec": {
193 | "display_name": "Python 3",
194 | "name": "python3"
195 | },
196 | "language_info": {
197 | "name": "python"
198 | }
199 | },
200 | "nbformat": 4,
201 | "nbformat_minor": 0
202 | }
--------------------------------------------------------------------------------
/notebooks/removeWeights_paper.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "id": "P9KAn1NGtAX3"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "import numpy as np\n",
13 | "from huggingface_hub import hf_hub_download\n",
14 | "import gc # garbage collection needed for low RAM footprint\n",
15 | "\n",
16 | "# download Mistral-7B from https://huggingface.co/mistralai/Mistral-7B-v0.1\n",
17 | "hf_hub_download(repo_id='mistralai/Mistral-7B-v0.1', filename='pytorch_model-00001-of-00002.bin', local_dir='.')\n",
18 | "hf_hub_download(repo_id='mistralai/Mistral-7B-v0.1', filename='pytorch_model-00002-of-00002.bin', local_dir='.')\n",
19 | "\n",
20 | "# load model files, use mmap to keep RAM footprint low\n",
21 | "m1 = torch.load('pytorch_model-00001-of-00002.bin', weights_only=True, mmap=True)\n",
22 | "m2 = torch.load('pytorch_model-00002-of-00002.bin', weights_only=True, mmap=True)\n",
23 | "\n",
24 | "def get_weights(model, layer, name):\n",
25 | " \"\"\"returns weight matrix of specific layer and name (such as Q, K, V)\"\"\"\n",
26 | " layer_str = 'layers.' + str(layer)\n",
27 | " match name:\n",
28 | " case 'Q': suffix = layer_str + '.self_attn.q_proj.weight'\n",
29 | " case 'K': suffix = layer_str + '.self_attn.k_proj.weight'\n",
30 | " case 'V': suffix = layer_str + '.self_attn.v_proj.weight'\n",
31 | " case 'P': suffix = layer_str + '.self_attn.o_proj.weight'\n",
32 | " case 'O': suffix = layer_str + '.mlp.down_proj.weight'\n",
33 | " case 'E': suffix = 'embed_tokens.weight'\n",
34 | " W = model['model.' + suffix].to(torch.float64).numpy() # convert to float64\n",
35 | " return W if name == 'E' else W.T # transpose weights, except for 'E'\n",
36 | "\n",
37 | "for layer in range(0, 32):\n",
38 | " print('layer', layer)\n",
39 | "\n",
40 | " # get weights Q, K, V, P, O\n",
41 | " model = m1 if layer < 23 else m2 # use m1 for layers 0 to 22\n",
42 | " Q = get_weights(model, layer, 'Q')\n",
43 | " K = get_weights(model, layer, 'K')\n",
44 | " V = get_weights(model, layer, 'V')\n",
45 | " P = get_weights(model, layer, 'P')\n",
46 | " O = get_weights(model, layer - 1, 'E' if layer == 0 else 'O') # use embedding for 1st layer\n",
47 | "\n",
48 | " # check if weight elimination is numerically identical\n",
49 | " Q_inv = np.linalg.inv(Q) # errors out if matrix is not invertible\n",
50 | " K_star = Q_inv @ K\n",
51 | " V_star = Q_inv @ V\n",
52 | " O_star = O @ Q\n",
53 | " print(' is O* @ K* close to O @ K ? ', np.allclose(O_star @ K_star, O @ K))\n",
54 | " print(' is O* @ V* close to O @ V ? ', np.allclose(O_star @ V_star, O @ V))\n",
55 | "\n",
56 | " # also check if P is invertible\n",
57 | " P_inv = np.linalg.inv(P) # errors out if matrix is not invertible\n",
58 | "\n",
59 | "# garbage collection (to avoid colab's RAM limit)\n",
60 | "del m1, m2, model, Q, K, V, P, O, Q_inv, P_inv, K_star, V_star, O_star\n",
61 | "gc.collect()"
62 | ]
63 | }
64 | ],
65 | "metadata": {
66 | "colab": {
67 | "provenance": []
68 | },
69 | "kernelspec": {
70 | "display_name": "Python 3",
71 | "name": "python3"
72 | },
73 | "language_info": {
74 | "name": "python"
75 | }
76 | },
77 | "nbformat": 4,
78 | "nbformat_minor": 0
79 | }
--------------------------------------------------------------------------------
/notebooks/slimAttn_paper.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "9e3a70d2",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# Proof of concept for paper \"Slim Attention: cut your context memory in half\"\n",
11 | "# Usage: python3 slimAttn_paper.py\n",
12 | "\n",
13 | "%pip install --quiet transformer_tricks\n",
14 | "import transformer_tricks as tt\n",
15 | "import numpy as np\n",
16 | "import torch\n",
17 | "from transformers import AutoConfig\n",
18 | "\n",
19 | "#-------------------------------------------------------------------------------\n",
20 | "# defs\n",
21 | "#-------------------------------------------------------------------------------\n",
22 | "def softmax(x, axis=-1):\n",
23 | " \"\"\"softmax along 'axis', default is the last axis\"\"\"\n",
24 | " e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n",
25 | " return e_x / np.sum(e_x, axis=axis, keepdims=True)\n",
26 | "\n",
27 | "def msplit(M, h):\n",
28 | " \"\"\"shortcut to split matrix M into h chunks\"\"\"\n",
29 | " return np.array_split(M, h, axis=-1)\n",
30 | "\n",
31 | "def ops(A, B):\n",
32 | " \"\"\"number of OPs (operations) for matmul of A and B:\n",
33 | " - A and B must be 2D arrays, and their inner dimensions must agree!\n",
34 | " - A is an m × n matrix, and B is an n × p matrix, then the resulting product\n",
35 | " of A and B is an m × p matrix.\n",
36 | " - Each element (i,j) of the m x p result matrix is computed by the dotproduct\n",
37 | " of the i-th row of A and the j-th column of B.\n",
38 | " - Each dotproduct takes n multiplications and n - 1 additions, so total\n",
39 | " number of OPs is 2n - 1 per dotproduct.\n",
40 | " - There are m * p elements in the result matrix, so m * p dotproducts, so in\n",
41 | " total we need m * p * (2n - 1) OPs, which is approximately 2*m*p*n OPs\n",
42 | " - For simplicity, let's just use the simple approximation of OPs = 2*m*p*n\"\"\"\n",
43 | " m, n = A.shape\n",
44 | " p = B.shape[1]\n",
45 | " return 2 * m * n * p\n",
46 | "\n",
47 | "#-------------------------------------------------------------------------------\n",
48 | "# setup for model SmolLM2-1.7B\n",
49 | "#-------------------------------------------------------------------------------\n",
50 | "tt.quiet_hf() # calm down HuggingFace\n",
51 | "\n",
52 | "repo = 'HuggingFaceTB/SmolLM2-1.7B'\n",
53 | "param = tt.get_param(repo)\n",
54 | "config = AutoConfig.from_pretrained(repo)\n",
55 | "\n",
56 | "h = config.num_attention_heads\n",
57 | "d = config.hidden_size\n",
58 | "dk = config.head_dim"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "6498fa85",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "#-------------------------------------------------------------------------------\n",
69 | "# check if we can accurately compute V from K for each layer\n",
70 | "#-------------------------------------------------------------------------------\n",
71 | "for layer in range(config.num_hidden_layers):\n",
72 | " # convert to float64 for better accuracy of matrix inversion\n",
73 | " # note that all weights are transposed in tensorfile (per pytorch convention)\n",
74 | " Wk = param[tt.weight('K', layer)].to(torch.float64).numpy().T\n",
75 | " Wv = param[tt.weight('V', layer)].to(torch.float64).numpy().T\n",
76 | " Wkv = np.linalg.inv(Wk) @ Wv\n",
77 | " print(layer, ':', np.allclose(Wk @ Wkv, Wv)) # check if Wk @ Wkv close to Wv"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "id": "a0f6735d",
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "#-------------------------------------------------------------------------------\n",
88 | "# compare options 1 and 2 for calculating equation (5) of paper\n",
89 | "#-------------------------------------------------------------------------------\n",
90 | "\n",
91 | "# get weights for Q, K, V and convert to float64\n",
92 | "# note that all weights are transposed in tensorfile (per pytorch convention)\n",
93 | "Wq = param[tt.weight('Q', 0)].to(torch.float64).numpy().T\n",
94 | "Wk = param[tt.weight('K', 0)].to(torch.float64).numpy().T\n",
95 | "Wv = param[tt.weight('V', 0)].to(torch.float64).numpy().T\n",
96 | "Wkv = np.linalg.inv(Wk) @ Wv # calculate Wkv (aka W_KV)\n",
97 | "# print('Is Wk @ Wkv close to Wv?', np.allclose(Wk @ Wkv, Wv))\n",
98 | "\n",
99 | "# generate random input X\n",
100 | "n = 100 # number of tokens\n",
101 | "X = np.random.rand(n, d).astype(np.float64) # range [0,1]\n",
102 | "Xn = np.expand_dims(X[n-1, :], axis=0) # n-th row of X; make it a 1 x d matrix\n",
103 | "\n",
104 | "Q = Xn @ Wq # only for the last row of X (for the generate-phase)\n",
105 | "K = X @ Wk\n",
106 | "V = X @ Wv\n",
107 | "\n",
108 | "# only consider the first head\n",
109 | "Q0, K0, V0 = msplit(Q, h)[0], msplit(K, h)[0], msplit(V, h)[0]\n",
110 | "Wkv0 = msplit(Wkv, h)[0]\n",
111 | "\n",
112 | "# baseline reference\n",
113 | "scores = softmax((Q0 @ K0.T) / np.sqrt(dk))\n",
114 | "head_ref = scores @ V0\n",
115 | "\n",
116 | "# head option1 and option2\n",
117 | "head_o1 = scores @ (K @ Wkv0) # option 1\n",
118 | "head_o2 = (scores @ K) @ Wkv0 # option 2\n",
119 | "\n",
120 | "# compare\n",
121 | "print('Is head_o1 close to head_ref?', np.allclose(head_o1, head_ref))\n",
122 | "print('Is head_o2 close to head_ref?', np.allclose(head_o2, head_ref))\n",
123 | "\n",
124 | "# computational complexity for both options\n",
125 | "o1_step1, o1_step2 = ops(K, Wkv0), ops(scores, (K @ Wkv0))\n",
126 | "o2_step1, o2_step2 = ops(scores, K), ops(scores @ K, Wkv0)\n",
127 | "\n",
128 | "print(f'Option 1 OPs: step 1 = {o1_step1:,}; step 2 = {o1_step2:,}; total = {(o1_step1 + o1_step2):,}')\n",
129 | "print(f'Option 2 OPs: step 1 = {o2_step1:,}; step 2 = {o2_step2:,}; total = {(o2_step1 + o2_step2):,}')\n",
130 | "print(f'speedup of option 2 over option 1: {((o1_step1 + o1_step2) / (o2_step1 + o2_step2)):.1f}')"
131 | ]
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "id": "fe352e46",
136 | "metadata": {},
137 | "source": [
138 | "Whenever you change this file, make sure to regenerate the jupyter notebook by typing:\n",
139 | " `util/gen_notebooks`"
140 | ]
141 | }
142 | ],
143 | "metadata": {
144 | "jupytext": {
145 | "cell_metadata_filter": "-all",
146 | "main_language": "python",
147 | "notebook_metadata_filter": "-all"
148 | }
149 | },
150 | "nbformat": 4,
151 | "nbformat_minor": 5
152 | }
153 |
--------------------------------------------------------------------------------
/notebooks/update_packages.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": []
7 | },
8 | "kernelspec": {
9 | "name": "python3",
10 | "display_name": "Python 3"
11 | },
12 | "language_info": {
13 | "name": "python"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {
21 | "id": "lOsSebCDGVTh"
22 | },
23 | "outputs": [],
24 | "source": [
25 | "# This colab helps you to fix python package issues:\n",
26 | "# - The huggingface (HF) packages are updated very often\n",
27 | "# - We want the transformer-tricks package to work with the (almost) latest\n",
28 | "# HF packages\n",
29 | "# - We only need 3 HF packages: transformers, accelerate, datasets\n",
30 | "# - These 3 HF packages will load many other HF packages (such as safetensors,\n",
31 | "# hub), numpy and torch\n",
32 | "\n",
33 | "# remove all pip packages, except for pip, dateutil, certifi, etc.\n",
34 | "!pip list --format=freeze | grep -v -E \"pip|dateutil|certifi|psutil|_distutils_hack|pkg_resources\" | xargs pip uninstall -y --quiet\n",
35 | "\n",
36 | "# above takes about 6 minutes !!!"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "source": [
42 | "# install latest versions of the 3 HF packages\n",
43 | "!pip install transformers --quiet\n",
44 | "!pip install accelerate --quiet\n",
45 | "!pip install datasets --quiet\n",
46 | "\n",
47 | "!pip list | grep -E \"transformers|accelerate|datasets\"\n",
48 | "!pip list | grep -E \"torch|numpy|safetensors|huggingface\"\n",
49 | "!python -V"
50 | ],
51 | "metadata": {
52 | "id": "kWz9vRphZL2A"
53 | },
54 | "execution_count": null,
55 | "outputs": []
56 | },
57 | {
58 | "cell_type": "code",
59 | "source": [
60 | "# download files transformer_tricks.py and flashNorm_test.py and test if it\n",
61 | "# works with the latest version of the HF packages\n",
62 | "!wget -q https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/transformer_tricks.py\n",
63 | "!wget -q https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/flashNorm_example.py\n",
64 | "!wget -q https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/refs/heads/main/flashNorm_modeling_llama.py\n",
65 | "\n",
66 | "!python flashNorm_example.py"
67 | ],
68 | "metadata": {
69 | "id": "FCh28_1fauTX"
70 | },
71 | "execution_count": null,
72 | "outputs": []
73 | }
74 | ]
75 | }
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # config file for creating the PyPI package. See
2 | # https://packaging.python.org/en/latest/tutorials/packaging-projects/
3 |
4 | [build-system]
5 | requires = ["hatchling"]
6 | build-backend = "hatchling.build"
7 |
8 | [project]
9 | name = "transformer-tricks"
10 | # version numbering A.B.C: A is major version, B is minor version, and C is patch
11 | version = "0.3.4"
12 | authors = [
13 | {name="Open Machine", email="info@openmachine.ai"},
14 | ]
15 | description = "A collection of tricks to speed up LLMs, see our transformer-tricks papers on arXiv"
16 | readme = "README.md"
17 | requires-python = ">=3.11"
18 | classifiers = [
19 | "Programming Language :: Python :: 3",
20 | "License :: OSI Approved :: MIT License",
21 | "Operating System :: OS Independent",
22 | ]
23 | dependencies = [
24 | "transformers>=4.52.3",
25 | "accelerate>=1.7.0",
26 | "datasets>=3.6.0",
27 | ]
28 |
29 | [project.urls]
30 | "Homepage" = "https://github.com/OpenMachine-ai/transformer-tricks"
31 | "Bug Tracker" = "https://github.com/OpenMachine-ai/transformer-tricks/issues"
32 |
33 | [tool.autopep8]
34 | indent-size = 2
35 | in-place = true
36 | max-line-length = 120
37 | ignore = "E265, E401, E70, E20, E241, E11"
38 | # type 'autopep8 --list-fixes' to see a list of all rules
39 | # for debug, type 'autopep8 --verbose *.py'
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformer-tricks>=0.3.4
2 | jupytext>=1.16.4
3 | autopep8>=2.3.1
4 | twine>=6.1.0
5 | build>=1.2.2
6 |
7 | # pip list # see all versions
8 | #
9 | # Phi-3 needs flash-attn, but this requires CUDA
10 | # flash-attn==2.5.8
11 |
--------------------------------------------------------------------------------
/slimAttn_paper.py:
--------------------------------------------------------------------------------
1 | # Proof of concept for paper "Slim Attention: cut your context memory in half"
2 | # Usage: python3 slimAttn_paper.py
3 |
4 | # %pip install --quiet transformer_tricks
5 | import transformer_tricks as tt
6 | import numpy as np
7 | import torch
8 | from transformers import AutoConfig
9 |
10 | #-------------------------------------------------------------------------------
11 | # defs
12 | #-------------------------------------------------------------------------------
13 | def softmax(x, axis=-1):
14 | """softmax along 'axis', default is the last axis"""
15 | e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
16 | return e_x / np.sum(e_x, axis=axis, keepdims=True)
17 |
18 | def msplit(M, h):
19 | """shortcut to split matrix M into h chunks"""
20 | return np.array_split(M, h, axis=-1)
21 |
22 | def ops(A, B):
23 | """number of OPs (operations) for matmul of A and B:
24 | - A and B must be 2D arrays, and their inner dimensions must agree!
25 | - A is an m × n matrix, and B is an n × p matrix, then the resulting product
26 | of A and B is an m × p matrix.
27 | - Each element (i,j) of the m x p result matrix is computed by the dotproduct
28 | of the i-th row of A and the j-th column of B.
29 | - Each dotproduct takes n multiplications and n - 1 additions, so total
30 | number of OPs is 2n - 1 per dotproduct.
31 | - There are m * p elements in the result matrix, so m * p dotproducts, so in
32 | total we need m * p * (2n - 1) OPs, which is approximately 2*m*p*n OPs
33 | - For simplicity, let's just use the simple approximation of OPs = 2*m*p*n"""
34 | m, n = A.shape
35 | p = B.shape[1]
36 | return 2 * m * n * p
37 |
38 | #-------------------------------------------------------------------------------
39 | # setup for model SmolLM2-1.7B
40 | #-------------------------------------------------------------------------------
41 | tt.quiet_hf() # calm down HuggingFace
42 |
43 | repo = 'HuggingFaceTB/SmolLM2-1.7B'
44 | param = tt.get_param(repo)
45 | config = AutoConfig.from_pretrained(repo)
46 |
47 | h = config.num_attention_heads
48 | d = config.hidden_size
49 | dk = config.head_dim
50 |
51 | # %%
52 | #-------------------------------------------------------------------------------
53 | # check if we can accurately compute V from K for each layer
54 | #-------------------------------------------------------------------------------
55 | for layer in range(config.num_hidden_layers):
56 | # convert to float64 for better accuracy of matrix inversion
57 | # note that all weights are transposed in tensorfile (per pytorch convention)
58 | Wk = param[tt.weight('K', layer)].to(torch.float64).numpy().T
59 | Wv = param[tt.weight('V', layer)].to(torch.float64).numpy().T
60 | Wkv = np.linalg.inv(Wk) @ Wv
61 | print(layer, ':', np.allclose(Wk @ Wkv, Wv)) # check if Wk @ Wkv close to Wv
62 |
63 | # %%
64 | #-------------------------------------------------------------------------------
65 | # compare options 1 and 2 for calculating equation (5) of paper
66 | #-------------------------------------------------------------------------------
67 |
68 | # get weights for Q, K, V and convert to float64
69 | # note that all weights are transposed in tensorfile (per pytorch convention)
70 | Wq = param[tt.weight('Q', 0)].to(torch.float64).numpy().T
71 | Wk = param[tt.weight('K', 0)].to(torch.float64).numpy().T
72 | Wv = param[tt.weight('V', 0)].to(torch.float64).numpy().T
73 | Wkv = np.linalg.inv(Wk) @ Wv # calculate Wkv (aka W_KV)
74 | # print('Is Wk @ Wkv close to Wv?', np.allclose(Wk @ Wkv, Wv))
75 |
76 | # generate random input X
77 | n = 100 # number of tokens
78 | X = np.random.rand(n, d).astype(np.float64) # range [0,1]
79 | Xn = np.expand_dims(X[n-1, :], axis=0) # n-th row of X; make it a 1 x d matrix
80 |
81 | Q = Xn @ Wq # only for the last row of X (for the generate-phase)
82 | K = X @ Wk
83 | V = X @ Wv
84 |
85 | # only consider the first head
86 | Q0, K0, V0 = msplit(Q, h)[0], msplit(K, h)[0], msplit(V, h)[0]
87 | Wkv0 = msplit(Wkv, h)[0]
88 |
89 | # baseline reference
90 | scores = softmax((Q0 @ K0.T) / np.sqrt(dk))
91 | head_ref = scores @ V0
92 |
93 | # head option1 and option2
94 | head_o1 = scores @ (K @ Wkv0) # option 1
95 | head_o2 = (scores @ K) @ Wkv0 # option 2
96 |
97 | # compare
98 | print('Is head_o1 close to head_ref?', np.allclose(head_o1, head_ref))
99 | print('Is head_o2 close to head_ref?', np.allclose(head_o2, head_ref))
100 |
101 | # computational complexity for both options
102 | o1_step1, o1_step2 = ops(K, Wkv0), ops(scores, (K @ Wkv0))
103 | o2_step1, o2_step2 = ops(scores, K), ops(scores @ K, Wkv0)
104 |
105 | print(f'Option 1 OPs: step 1 = {o1_step1:,}; step 2 = {o1_step2:,}; total = {(o1_step1 + o1_step2):,}')
106 | print(f'Option 2 OPs: step 1 = {o2_step1:,}; step 2 = {o2_step2:,}; total = {(o2_step1 + o2_step2):,}')
107 | print(f'speedup of option 2 over option 1: {((o1_step1 + o1_step2) / (o2_step1 + o2_step2)):.1f}')
108 |
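# editorial note: the printed OP counts can also be derived symbolically. With
# scores of shape (1, n), K of shape (n, d), and Wkv0 of shape (d, dk):
#   option 1: ops(K, Wkv0) + ops(scores, K @ Wkv0) = 2*n*d*dk + 2*n*dk
#   option 2: ops(scores, K) + ops(scores @ K, Wkv0) = 2*n*d + 2*d*dk
# so for the generate-phase (a single query row), option 2 needs roughly
# n*dk / (n + dk) times fewer OPs, e.g. about 39x for n = 100 and dk = 64.
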
109 | # %% [markdown]
110 | # Whenever you change this file, make sure to regenerate the jupyter notebook by typing:
111 | # `util/gen_notebooks`
112 |
--------------------------------------------------------------------------------
/tex/README.md:
--------------------------------------------------------------------------------
1 | This folder contains the latex files for the Transformer Tricks papers. The flow is as follows:
2 | 1) Write first draft and drawings in Google docs.
3 | 2) Create file `foo.tex` and copy text from the Google doc.
4 | - Copy each drawing into a separate Google Drawings file, adjust the bounding box, and "download" it as PDF. This PDF is then used by LaTeX.
5 | - For references, see the comments in file `references.bib`
6 | 3) Type `./run foo.tex` to create PDF.
7 | 4) Use spell checker as follows: `cd ..; util/spell_check`
8 | 5) Note: I converted some figures from PDF to SVG (so that I can use them in markdown) as follows: `pdftocairo -svg foo.pdf foo.svg`. TODO: maybe use only SVG drawings, even for tex.
9 | 6) Submit to arXiv:
10 | - To submit `foo.tex`, type: `./submit foo.tex`
11 | - To double-check if everything works, run `pdflatex foo` two times (or sometimes three times) as follows:
12 | `cd foo_submit` and `pdflatex foo && pdflatex foo`
13 | - Then upload the generated `*.tar.gz` file to arXiv.
14 | - Notes for filling out the abstract field in the online form:
15 | - Make sure to remove citations or replace them by `arXiv:YYMM.NNNNN`
16 | - You can add hyperlinks to the abstract as follows: `See https://github.com/blabla for code`
17 | - You can force a new paragraph in the abstract by typing a carriage return followed by one white space in the new line (i.e. indent the new line after the carriage return)
18 | - Keep in mind: papers that are short (e.g. 6 pages or less) are automatically put on `hold` and need to be reviewed by a moderator, which can take several weeks.
19 |
--------------------------------------------------------------------------------
/tex/arxiv.sty:
--------------------------------------------------------------------------------
1 | % This file is copied from
2 | % https://github.com/kourgeorge/arxiv-style/blob/master/arxiv.sty
3 | % my changes are marked by "OM-mods"
4 |
5 | \NeedsTeXFormat{LaTeX2e}
6 |
7 | \ProcessOptions\relax
8 |
9 | % fonts
10 | \renewcommand{\rmdefault}{ptm}
11 | \renewcommand{\sfdefault}{phv}
12 |
13 | % set page geometry
14 | \usepackage[verbose=true,letterpaper]{geometry}
15 | \AtBeginDocument{
16 | \newgeometry{
17 | textheight=9in,
18 | textwidth=6.5in,
19 | top=1in,
20 | % OM-mods: headheight=14pt,
21 | % OM-mods: headsep=25pt,
22 | footskip=30pt
23 | }
24 | }
25 |
26 | \widowpenalty=10000
27 | \clubpenalty=10000
28 | \flushbottom
29 | \sloppy
30 |
31 | % OM-mods: \newcommand{\headeright}{A Preprint}
32 | % OM-mods: \newcommand{\undertitle}{A Preprint}
33 | % OM-mods: \newcommand{\shorttitle}{\@title}
34 |
35 | \usepackage{fancyhdr}
36 | \usepackage{lastpage}
37 | \fancyhf{}
38 | \pagestyle{fancy}
39 | % OM-mods: \renewcommand{\headrulewidth}{0.4pt}
40 | % OM-mods: below line removes the header
41 | \renewcommand{\headrulewidth}{0pt}
42 | % OM-mods: \fancyheadoffset{0pt}
43 | % OM-mods: \rhead{\scshape \footnotesize \headeright}
44 | % OM-mods: \chead{\shorttitle}
45 | \cfoot{{\thepage} of \pageref*{LastPage}}
46 |
47 | % OM-mods: %Handling Keywords
48 | % OM-mods: \def\keywordname{{\bfseries \emph{Keywords}}}%
49 | % OM-mods: \def\keywords#1{\par\addvspace\medskipamount{\rightskip=0pt plus1cm
50 | % OM-mods: \def\and{\ifhmode\unskip\nobreak\fi\ $\cdot$
51 | % OM-mods: }\noindent\keywordname\enspace\ignorespaces#1\par}}
52 |
53 | % font sizes with reduced leading
54 | \renewcommand{\normalsize}{%
55 | \@setfontsize\normalsize\@xpt\@xipt
56 | \abovedisplayskip 7\p@ \@plus 2\p@ \@minus 5\p@
57 | \abovedisplayshortskip \z@ \@plus 3\p@
58 | \belowdisplayskip \abovedisplayskip
59 | \belowdisplayshortskip 4\p@ \@plus 3\p@ \@minus 3\p@
60 | }
61 | \normalsize
62 | \renewcommand{\small}{%
63 | \@setfontsize\small\@ixpt\@xpt
64 | \abovedisplayskip 6\p@ \@plus 1.5\p@ \@minus 4\p@
65 | \abovedisplayshortskip \z@ \@plus 2\p@
66 | \belowdisplayskip \abovedisplayskip
67 | \belowdisplayshortskip 3\p@ \@plus 2\p@ \@minus 2\p@
68 | }
69 | \renewcommand{\footnotesize}{\@setfontsize\footnotesize\@ixpt\@xpt}
70 | \renewcommand{\scriptsize}{\@setfontsize\scriptsize\@viipt\@viiipt}
71 | \renewcommand{\tiny}{\@setfontsize\tiny\@vipt\@viipt}
72 | \renewcommand{\large}{\@setfontsize\large\@xiipt{14}}
73 | \renewcommand{\Large}{\@setfontsize\Large\@xivpt{16}}
74 | \renewcommand{\LARGE}{\@setfontsize\LARGE\@xviipt{20}}
75 | \renewcommand{\huge}{\@setfontsize\huge\@xxpt{23}}
76 | \renewcommand{\Huge}{\@setfontsize\Huge\@xxvpt{28}}
77 |
78 | % sections with less space
79 | \providecommand{\section}{}
80 | \renewcommand{\section}{%
81 | \@startsection{section}{1}{\z@}%
82 | {-2.0ex \@plus -0.5ex \@minus -0.2ex}%
83 | { 1.5ex \@plus 0.3ex \@minus 0.2ex}%
84 | {\large\bf\raggedright}%
85 | }
86 | \providecommand{\subsection}{}
87 | \renewcommand{\subsection}{%
88 | \@startsection{subsection}{2}{\z@}%
89 | {-1.8ex \@plus -0.5ex \@minus -0.2ex}%
90 | { 0.8ex \@plus 0.2ex}%
91 | {\normalsize\bf\raggedright}%
92 | }
93 | \providecommand{\subsubsection}{}
94 | \renewcommand{\subsubsection}{%
95 | \@startsection{subsubsection}{3}{\z@}%
96 | {-1.5ex \@plus -0.5ex \@minus -0.2ex}%
97 | { 0.5ex \@plus 0.2ex}%
98 | {\normalsize\bf\raggedright}%
99 | }
100 | \providecommand{\paragraph}{}
101 | \renewcommand{\paragraph}{%
102 | \@startsection{paragraph}{4}{\z@}%
103 | {1.5ex \@plus 0.5ex \@minus 0.2ex}%
104 | {-1em}%
105 | {\normalsize\bf}%
106 | }
107 | \providecommand{\subparagraph}{}
108 | \renewcommand{\subparagraph}{%
109 | \@startsection{subparagraph}{5}{\z@}%
110 | {1.5ex \@plus 0.5ex \@minus 0.2ex}%
111 | {-1em}%
112 | {\normalsize\bf}%
113 | }
114 | \providecommand{\subsubsubsection}{}
115 | \renewcommand{\subsubsubsection}{%
116 | \vskip5pt{\noindent\normalsize\rm\raggedright}%
117 | }
118 |
119 | % float placement
120 | \renewcommand{\topfraction }{0.85}
121 | \renewcommand{\bottomfraction }{0.4}
122 | \renewcommand{\textfraction }{0.1}
123 | \renewcommand{\floatpagefraction}{0.7}
124 |
125 | \newlength{\@abovecaptionskip}\setlength{\@abovecaptionskip}{7\p@}
126 | \newlength{\@belowcaptionskip}\setlength{\@belowcaptionskip}{\z@}
127 |
128 | \setlength{\abovecaptionskip}{\@abovecaptionskip}
129 | \setlength{\belowcaptionskip}{\@belowcaptionskip}
130 |
131 | % swap above/belowcaptionskip lengths for tables
132 | \renewenvironment{table}
133 | {\setlength{\abovecaptionskip}{\@belowcaptionskip}%
134 | \setlength{\belowcaptionskip}{\@abovecaptionskip}%
135 | \@float{table}}
136 | {\end@float}
137 |
138 | % footnote formatting
139 | \setlength{\footnotesep }{6.65\p@}
140 | \setlength{\skip\footins}{9\p@ \@plus 4\p@ \@minus 2\p@}
141 | \renewcommand{\footnoterule}{\kern-3\p@ \hrule width 12pc \kern 2.6\p@}
142 | \setcounter{footnote}{0}
143 |
144 | % paragraph formatting
145 | \setlength{\parindent}{\z@}
146 | \setlength{\parskip }{5.5\p@}
147 |
148 | % list formatting
149 | % OM-mods: commented out below stuff
150 | %\setlength{\topsep }{4\p@ \@plus 1\p@ \@minus 2\p@}
151 | %\setlength{\partopsep }{1\p@ \@plus 0.5\p@ \@minus 0.5\p@}
152 | %\setlength{\itemsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@}
153 | %\setlength{\parsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@}
154 | %\setlength{\topsep }{-1pt}
155 | %\setlength{\partopsep }{-1pt}
156 | %\setlength{\parsep }{-1pt}
157 | %\setlength{\itemsep }{-1pt}
158 | \setlength{\leftmargin }{3pc}
159 | \setlength{\leftmargini }{\leftmargin}
160 | \setlength{\leftmarginii }{2em}
161 | \setlength{\leftmarginiii}{1.5em}
162 | \setlength{\leftmarginiv }{1.0em}
163 | \setlength{\leftmarginv }{0.5em}
164 | % OM-mods: commented out below stuff
165 | %\def\@listi {\leftmargin\leftmargini}
166 | %\def\@listii {\leftmargin\leftmarginii
167 | % \labelwidth\leftmarginii
168 | % \advance\labelwidth-\labelsep}
169 | %\topsep 2\p@ \@plus 1\p@ \@minus 0.5\p@
170 | %\parsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@
171 | %\itemsep \parsep}
172 | %\def\@listiii{\leftmargin\leftmarginiii
173 | % \labelwidth\leftmarginiii
174 | % \advance\labelwidth-\labelsep}
175 | %\topsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@
176 | %\parsep \z@
177 | %\partopsep 0.5\p@ \@plus 0\p@ \@minus 0.5\p@
178 | %\itemsep \topsep}
179 | %\def\@listiv {\leftmargin\leftmarginiv
180 | % \labelwidth\leftmarginiv
181 | % \advance\labelwidth-\labelsep}
182 | %\def\@listv {\leftmargin\leftmarginv
183 | % \labelwidth\leftmarginv
184 | % \advance\labelwidth-\labelsep}
185 | %\def\@listvi {\leftmargin\leftmarginvi
186 | % \labelwidth\leftmarginvi
187 | % \advance\labelwidth-\labelsep}
188 |
189 | % create title
190 | \providecommand{\maketitle}{}
191 | \renewcommand{\maketitle}{%
192 | \par
193 | \begingroup
194 | \renewcommand{\thefootnote}{\fnsymbol{footnote}}
195 | % for perfect author name centering
196 | %\renewcommand{\@makefnmark}{\hbox to \z@{$^{\@thefnmark}$\hss}}
197 | % The footnote-mark was overlapping the footnote-text,
198 | % added the following to fix this problem (MK)
199 | \long\def\@makefntext##1{%
200 | \parindent 1em\noindent
201 | \hbox to 1.8em{\hss $\m@th ^{\@thefnmark}$}##1
202 | }
203 | \thispagestyle{empty}
204 | \@maketitle
205 | \@thanks
206 | %\@notice
207 | \endgroup
208 | \let\maketitle\relax
209 | \let\thanks\relax
210 | }
211 |
212 | % rules for title box at top of first page
213 | \newcommand{\@toptitlebar}{
214 | \hrule height 2\p@
215 | \vskip 0.25in
216 | \vskip -\parskip%
217 | }
218 | \newcommand{\@bottomtitlebar}{
219 | \vskip 0.29in
220 | \vskip -\parskip
221 | \hrule height 2\p@
222 | \vskip 0.09in%
223 | }
224 |
225 | % create title (includes both anonymized and non-anonymized versions)
226 | \providecommand{\@maketitle}{}
227 | \renewcommand{\@maketitle}{%
228 | \vbox{%
229 | \hsize\textwidth
230 | \linewidth\hsize
231 | \vskip 0.1in
232 | \@toptitlebar
233 | \centering
234 | % OM-mods: {\LARGE\sc \@title\par}
235 | {\LARGE\bf \@title\par}
236 | \@bottomtitlebar
237 | % OM-mods: \textsc{\undertitle}\\
238 | % \vskip 0.1in
239 | \vskip -0.15in
240 | \def\And{%
241 | \end{tabular}\hfil\linebreak[0]\hfil%
242 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces%
243 | }
244 | \def\AND{%
245 | \end{tabular}\hfil\linebreak[4]\hfil%
246 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces%
247 | }
248 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\@author\end{tabular}%
249 | % OM-mods: \vskip 0.4in \@minus 0.1in \center{\@date} \vskip 0.2in
250 | % \vskip 0.4in
251 | \vskip 0.2in
252 | }
253 | }
254 |
255 | % add conference notice to bottom of first page
256 | \newcommand{\ftype@noticebox}{8}
257 | \newcommand{\@notice}{%
258 | % give a bit of extra room back to authors on first page
259 | \enlargethispage{2\baselineskip}%
260 | \@float{noticebox}[b]%
261 | \footnotesize\@noticestring%
262 | \end@float%
263 | }
264 |
265 | % abstract styling
266 | \renewenvironment{abstract}
267 | {
268 | \centerline
269 | % OM-mods: {\large \bfseries \scshape Abstract}
270 | {\large \bfseries \upshape Abstract}
271 | \begin{quote}
272 | }
273 | {
274 | \end{quote}
275 | }
276 |
277 | % OM-mods: moved below lines from *.tex file here
278 | \usepackage[utf8]{inputenc}
279 | \usepackage[T1]{fontenc} % use 8-bit T1 fonts
280 | % \usepackage{hyperref}
281 | \usepackage[hidelinks,colorlinks=true,linkcolor=blue,citecolor=blue,urlcolor=blue]{hyperref}
282 | \usepackage{url}
283 | \usepackage{booktabs} % professional-quality tables
284 | \usepackage{amsfonts} % blackboard math symbols
285 | \usepackage{amsmath}
286 | \usepackage{amssymb}
287 | \usepackage{nicefrac} % compact symbols for 1/2, etc.
288 | \usepackage{microtype} % microtypography
289 | \usepackage{cleveref} % smart cross-referencing
290 | \usepackage{graphicx}
291 | \usepackage{doi}
292 | \usepackage{enumitem}
293 | \usepackage{makecell}
294 | \usepackage{multirow}
295 |
296 | % set bold font for tabular heads
297 | \renewcommand\theadfont{\bfseries}
298 |
299 | \endinput
300 |
--------------------------------------------------------------------------------
/tex/clean:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # script to clean up this directory
4 | # usage: ./clean
5 |
6 | \rm -f *.bbl *.aux *.blg *.log *.out
7 |
--------------------------------------------------------------------------------
/tex/flashNorm.tex:
--------------------------------------------------------------------------------
1 | % To generate PDF, type ./run flashNorm
2 |
3 | \documentclass{article}
4 |
5 | \usepackage[preprint, nonatbib]{neurips_2025}
6 | % for submission to neurips:
7 | % \usepackage[nonatbib]{neurips_2025}
8 | % to compile a preprint version, e.g., for submission to arXiv:
9 | % \usepackage[preprint, nonatbib]{neurips_2025}
10 | % to compile a camera-ready version, add the [final] option, e.g.:
11 | % \usepackage[final, nonatbib]{neurips_2025}
12 |
13 | \usepackage[utf8]{inputenc} % allow utf-8 input
14 | \usepackage[T1]{fontenc} % use 8-bit T1 fonts
15 | %% I removed this: \usepackage{hyperref} % hyperlinks
16 | \usepackage{url} % simple URL typesetting
17 | \usepackage{booktabs} % professional-quality tables
18 | \usepackage{amsfonts} % blackboard math symbols
19 | \usepackage{nicefrac} % compact symbols for 1/2, etc.
20 | \usepackage{microtype} % microtypography
21 | \usepackage{xcolor} % colors
22 |
23 | %% I added the following packages
24 | \usepackage[hidelinks,colorlinks=true,linkcolor=blue,citecolor=blue,urlcolor=blue]{hyperref}
25 | \usepackage{amsmath}
26 | \usepackage{amssymb}
27 | \usepackage{graphicx}
28 | \usepackage{makecell}
29 | \usepackage{multirow}
30 | %\usepackage{tablefootnote}
31 | \usepackage{enumitem}
32 | \usepackage{pythonhighlight} % for python listings
33 | \usepackage[numbers]{natbib}
34 | \usepackage{caption}
35 | \captionsetup[figure]{skip=5pt} % reduce the space between figure and caption
36 | %\captionsetup[table]{skip=10pt}
37 |
38 | % shortcuts
39 | \newcommand{\mat}[1]{\mathbf{#1}} % shortcut for matrix
40 | \newcommand{\RMS}[1]{\text{RMS}(#1)} % shortcut for RMS(x)
41 | \def\rms{\text{RMS}(\vec{a})} % RMS(a)
42 | \def\f1n{\frac{1}{n}} % 1/n
43 | \def\sas{\sum_{i=1}^n a_i^2} % sum over a_i squared
44 | \def\W*{\mat{W}^\ast} % matrix W*
45 | \def\V*{\mat{V}^\ast} % matrix V*
46 | \def\mW{\mat{W}} % matrix W
47 | \def\mV{\mat{V}} % matrix V
48 | \def\a{\vec{a}} % vector a
49 | \def\b{\vec{b}} % vector b
50 | \def\c{\vec{c}} % vector c
51 | \def\vb{\vec{\beta}} % vector beta
52 | \def\vx{\vec{x}} % vector x
53 | \def\vy{\vec{y}} % vector y
54 | \def\vz{\vec{z}} % vector z
55 | \def\vg{\vec{g}} % vector g
56 | \def\vs{\vec{s}} % vector s
57 | \def\cosi{\cos{(\cdot)}} % cos(.)
58 | \def\sini{\sin{(\cdot)}} % sin(.)
59 |
60 | \title{FlashNorm: fast normalization for LLMs}
61 | %\title{Flash normalization: fast normalization for LLMs}
62 | %\title{Flash normalization: fast RMSNorm for LLMs}
63 |
64 | \author{Nils Graef\thanks{\texttt{info@openmachine.ai}}, \, Andrew Wasielewski, \, Matthew Clapp \\
65 | \href{https://openmachine.ai}{OpenMachine}}
66 |
67 | \begin{document} \maketitle
68 |
69 | \begin{abstract}
70 | This paper presents FlashNorm, which is an exact but faster implementation of RMSNorm followed by linear layers. RMSNorm \citep{rms} is used by many LLMs such as Llama, Mistral, and OpenELM \citep{LLaMA, mistral, openelm}. FlashNorm also speeds up Layer Normalization \citep{layerNorm} and its recently proposed replacement Dynamic Tanh (DyT) \citep{DyT}. FlashNorm also reduces the number of parameter tensors by simply merging the normalization weights with the weights of the next linear layer. See \citep{slimAttn, tricks, remove, precompute} for code and more transformer tricks.
71 | \end{abstract}
72 |
73 | \section{Flash normalization}
74 | \begin{figure}[h!] \centering % the [h!] tries to place the picture right here
75 | \includegraphics[scale=1.0]{../doc/fig/flashNorm_fig1.pdf}
76 | \caption{Mathematically identical implementations of RMSNorm followed by a linear layer: (a) unoptimized version with weight matrix $\mat{W}$; (b) optimized version with normalization weights $g_i$ merged into the linear layer with new weights $\W*$; (c) optimized version with deferred normalization. The $\triangleq$ symbol denotes mathematical identity.}
77 | \label{fig1} \end{figure}
78 |
79 | RMSNorm \citep{rms} normalizes the elements $a_i$ of vector $\a$ as $y_i = \frac{a_i}{\rms} \cdot g_i$ with $\rms = \sqrt{\f1n \sas}$ and normalization weights $g_i$. In transformer \citep{vanilla} and other neural networks, RMSNorm is often followed by a linear layer as illustrated in Fig. \ref{fig1}(a), which we optimize as follows:
80 | \begin{itemize}[topsep=-1pt]
81 | \item \textbf{Weightless normalization (aka non-parametric normalization)}: We merge the normalization weights $g_i$ into the linear layer with weights $\mat{W}$, resulting in a modified weight matrix $\W*$ with $W_{i,j}^\ast = g_i \cdot W_{i,j}$ as illustrated in Fig. \ref{fig1}(b). This works for linear layers with and without bias.
82 | \item \textbf{Deferred normalization}: Instead of normalizing before the linear layer, we normalize after the linear layer, as shown in Fig. \ref{fig1}(c). This only works if the linear layer is bias-free, which is the case for many LLMs such as Llama, Mistral, and OpenELM. Specifically, the output of the linear layer in Fig. \ref{fig1}(b) is $\vz = \left( \a \cdot \frac{1}{\rms} \right) \W*$, which is identical to $\vz = \left( \a \, \W* \right) \cdot \frac{1}{\rms}$ because matrix multiplication by a scalar is commutative. If the linear layer has a bias at its output, then the normalization (i.e. scaling by $\frac{1}{\rms}$) must be done before adding the bias.
83 | \end{itemize}
84 |
85 | In summary, FlashNorm eliminates the normalization weights and defers the normalization to the output of the linear layer, which removes a compute bottleneck described at the end of this paper. Deferring the normalization is similar to Flash Attention \citep{flash-attention}, where the normalization by the softmax denominator is done after the multiplication of softmax arguments with value projections (V) (so that keys and values can be processed in \emph{parallel}). Therefore, we call our implementation \emph{flash} normalization (or FlashNorm), which allows us to compute the linear layer and $\rms$ in \emph{parallel} (instead of sequentially).
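
% editor's note: the following sketch is an editorial addition, not part of the original text
As a quick numerical sanity check (a minimal sketch with illustrative names), the identity between Fig. \ref{fig1}(a) and Fig. \ref{fig1}(c) can be verified with a few lines of NumPy:
\begin{verbatim}
import numpy as np
n, m = 512, 1024                    # input and output width of the linear layer
a = np.random.randn(n)              # activation vector
g = np.random.randn(n)              # RMSNorm weights
W = np.random.randn(n, m)           # bias-free linear layer
rms = np.sqrt(np.mean(a**2))
y_ref   = ((a / rms) * g) @ W       # Fig. 1(a): RMSNorm, then linear layer
W_star  = W * g[:, None]            # Fig. 1(b): merge g into the weights
y_flash = (a @ W_star) / rms        # Fig. 1(c): deferred normalization
print(np.allclose(y_ref, y_flash))  # expected: True (up to float rounding)
\end{verbatim}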
86 |
87 | \citeauthor{openelm} report significant changes in the overall tokens-per-second throughput when they modify the layer normalization implementation, which they attribute to a lack of kernel fusion for the underlying GPU. The simplifications presented here reduce the number of operations and thus the number of the individual kernel launches mentioned in \citep{openelm}.
88 |
89 | \subsection{Support for normalization bias and DyT bias}
90 | Layer normalization (LayerNorm) \citep{layerNorm} and DyT \citep{DyT} can have a bias vector $\vb$ right after scaling by weights $g_i$. Figure \ref{figA} illustrates how the bias vector $\vb$ can be moved to the output of the linear layer and then be added to the bias vector $\c$ of the linear layer, resulting in the new bias term $\c^{\, \ast} = \c + \vb \, \mW$, see Fig. \ref{figA}(b). After this elimination of $\vb$, the normalization weights $g_i$ can be merged into the linear layer as described in the previous section and illustrated in Fig. \ref{fig1}(b).
91 |
92 | \begin{figure}[h!] \centering % the [h!] tries to place the picture right here
93 | \includegraphics[scale=0.9]{../doc/fig/flashNorm_figA.pdf}
94 | \caption{Elimination of bias vector $\vb$: (a) Before elimination with $\vb$ between normalization weights $\vg$ and linear layer. (b) Optimized version with new bias term $\c^{\, \ast} = \c + \vb \, \mW$ at the output.}
95 | \label{figA} \end{figure}
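
As an editorial aside (not part of the original text), the bias move of Fig. \ref{figA} follows directly from distributivity:
\begin{equation*}
\left( \tfrac{\a}{\rms} \cdot \vg + \vb \right) \mW + \c
= \left( \tfrac{\a}{\rms} \cdot \vg \right) \mW + \vb \, \mW + \c
= \left( \tfrac{\a}{\rms} \cdot \vg \right) \mW + \c^{\, \ast}
\end{equation*}
with $\c^{\, \ast} = \c + \vb \, \mW$ as stated above.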
96 |
97 | \subsection{Merging mean centering into a preceding linear layer}
98 | Note that LayerNorm consists of mean centering followed by RMSNorm. If the mean centering is preceded by a linear layer with weight matrix $\mV$, then we can eliminate the entire mean centering by modifying the weight matrix as explained in this section. Fig. \ref{figB}(a) shows the weight matrix $\mV$ followed by the mean centering, which is followed by RMSNorm.
99 |
100 | \begin{figure}[h!] \centering % the [h!] tries to place the picture right here
101 | \includegraphics[scale=0.9]{../doc/fig/flashNorm_figB.pdf}
102 | \caption{Elimination of mean centering: (a) Original weight matrix $\mV$ followed by mean centering. (b) Optimized version where the mean centering is merged into the modified weight matrix $\V*$.}
103 | \label{figB} \end{figure}
104 |
105 | The mean $\mu$ is calculated from the linear layer outputs $y_j$ as $\mu = \frac{1}{n} \sum_{j=1}^n y_j$. Note that $\vy = \vx \, \mV$, i.e. $y_j = \sum_{i=1}^n x_i v_{i, j}$ where $v_{i, j}$ are the weights of matrix $\mV$. Plugging the last equation into the $\mu$ expression lets us calculate $\mu$ directly from the input $\vx$ as
106 | \begin{equation*}
107 | \mu = \frac{1}{n} \sum_{j=1}^n \sum_{i=1}^n x_i v_{i, j} = \frac{1}{n} \sum_{i=1}^n x_i \left[ \sum_{j=1}^n v_{i, j} \right] = \frac{1}{n} \sum_{i=1}^n x_i s_i
108 | \end{equation*}
109 | where we define vector $\vs$ with $s_i = \sum_{j=1}^n v_{i, j}$ the sum of row $i$ of weight matrix $\mV$. In other words, $\mu$ is the inner-product of vectors $\vx$ and $\vs$ divided by $n$. The outputs $a_j$ of the mean centering are
110 | \begin{equation*}
111 | \a_j = y_j - \mu = \sum_{i=1}^n x_i v_{i, j} - \mu = \sum_{i=1}^n x_i v_{i, j} - \frac{1}{n} \sum_{i=1}^n x_i s_i = \sum_{i=1}^n x_i \left( v_{i, j} - \frac{1}{n} s_i \right)
112 | = \sum_{i=1}^n x_i v^{\, \ast}_{i, j}
113 | \end{equation*}
114 | From the last identity follows that the new weights $v^{\, \ast}_{i, j}$ of matrix $\V*$ of Fig. \ref{figB}(b) are computed as $v^{\, \ast}_{i, j} = v_{i, j} - \frac{1}{n} s_i$. This trick can be used to retrofit existing LayerNorm models with RMSNorm without any retraining.
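
As a quick numerical check (an editorial sketch with illustrative names, not part of the original text), the weight update $v^{\, \ast}_{i, j} = v_{i, j} - \frac{1}{n} s_i$ can be verified in NumPy:
\begin{verbatim}
import numpy as np
n = 64
x = np.random.randn(n)
V = np.random.randn(n, n)
y = x @ V
a_ref  = y - y.mean()                          # mean centering after the linear layer
V_star = V - V.sum(axis=1, keepdims=True) / n  # v*_ij = v_ij - s_i / n (s_i = row sum)
a_new  = x @ V_star                            # mean centering folded into the weights
print(np.allclose(a_ref, a_new))               # expected: True
\end{verbatim}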
115 |
116 | \section{Flash normalization for FFN}
117 | For the feed-forward networks (FFN) of LLMs, the linear layers at the FFN input usually have more output channels than input channels. In this case, deferring the normalization requires more scaling operations (i.e. more multiplications). This section details ways to reduce the number of scaling operations for bias-free FFNs.
118 |
119 | \subsection{Flash normalization for FFNs with ReLU}
120 | \begin{figure}[h!] \centering
121 | \includegraphics[scale=1.0]{../doc/fig/flashNorm_fig2.pdf}
122 | \caption{FFN with ReLU and preceding flash normalization: (a) unoptimized version; (b) optimized version where the normalization is deferred to the output of the FFN. Up and Down denote the linear layers for up and down projections.}
123 | \label{fig2} \end{figure}
124 |
125 | Even though ReLU is a nonlinear function, multiplying its argument by a non-negative scaling factor $s$ is the same as scaling its output by $s$, i.e. $\text{ReLU}(s \cdot \a) = s \cdot \text{ReLU}(\a)$ for $s \ge 0$ \citep{ReLU}. Because of this scale-invariance, we can defer the normalization to the output of the FFN as illustrated in Fig. \ref{fig2}(b), which saves $f - n$ multipliers.
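
This scale-invariance can be checked with a one-liner (an editorial sketch, not part of the original text): for random \texttt{a} and \texttt{s >= 0}, \texttt{np.allclose(np.maximum(s * a, 0), s * np.maximum(a, 0))} is expected to return \texttt{True}.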
126 |
127 | \subsection{Flash normalization for FFNs with GLU variant}
128 | Fig. \ref{fig3}(a) shows an FFN with a GLU variant \citep{GLU} and flash normalization at its input. The flash normalization requires two sets of $f$ multipliers at the outputs of the Gate and Up linear layers in Fig. \ref{fig3}(a). One set can be deferred to the FFN output in Fig. \ref{fig3}(b), which saves $f - n$ multipliers.
129 | \begin{figure}[h!] \centering
130 | \includegraphics[scale=0.9]{../doc/fig/flashNorm_fig3.pdf}
131 | \caption{FFN with GLU variant and preceding flash normalization: (a) unoptimized version; (b) optimized version with fewer scaling multipliers. Gate, Up, and Down denote the linear layers for gate, up, and down projections.}
132 | \label{fig3} \end{figure}
133 |
134 | \textbf{Special case for ReGLU and Bilinear GLU}: If the activation function is ReLU (aka ReGLU \citep{GLU}) or just linear (aka bilinear GLU \citep{GLU}), then we can also eliminate the scaling before the activation function and combine it with the scaling at the output as illustrated in Fig. \ref{fig4}(b), which saves $2f - n$ multipliers. Now the output scaling is using the reciprocal of the squared RMS as scaling value, which is the same as the reciprocal of the mean-square (MS):
135 | \begin{equation*}
136 | \frac{1}{(\rms)^2} = \frac{1}{\text{MS}(\a)}
137 | = \frac{1}{\f1n \sas} = \frac{n}{\sas}
138 | \end{equation*}
139 |
140 | \begin{figure}[h!] \centering
141 | \includegraphics[scale=0.9]{../doc/fig/flashNorm_fig4.pdf}
142 | \caption{FFN with ReGLU (or bilinear GLU) and preceding flash normalization: (a) unoptimized version; (b) optimized version with fewer scaling multipliers.}
143 | \label{fig4} \end{figure}
144 |
145 | \section{Flash normalization for attention with RoPE}
146 | Fig. \ref{fig5}(a) shows the Q and K linear layers with flash normalization followed by RoPE \citep{RoPE} and scaled dot-product attention \citep{vanilla}. More details on Figure \ref{fig5}:
147 | \begin{itemize}[topsep=-1pt]
148 | \item Q* and K* are the linear layers for Q (queries) and K (keys) fused with the normalization weights of the activation vector $\a$ (according to flash normalization).
149 | \item $h$ is the dimension of the attention heads.
150 | \item The boxes labeled cos, sin, and RoPE perform $\vy = \vx \cdot \cosi + \text{permute}(\vx) \cdot \sini$, where
151 | \begin{itemize}[topsep=-1pt]
152 | \item $\text{permute}(\vx) = (-x_2, x_1, -x_4, x_3, \dots, -x_h, x_{h-1})$, see equation (34) of \citep{RoPE} for more details.
153 | \item $\cosi = (\cos m \theta_1, \cos m \theta_1, \cos m \theta_2, \cos m \theta_2, \dots, \cos m \theta_{h/2}, \cos m \theta_{h/2})$ for position $m$.
154 | \item $\sini = (\sin m \theta_1, \sin m \theta_1, \sin m \theta_2, \sin m \theta_2, \dots, \sin m \theta_{h/2}, \sin m \theta_{h/2})$ for position $m$.
155 | \end{itemize}
156 | \item Note that $\cosi$ and $\sini$ only depend on the position of activation vector $\a$ and are shared among all attention heads. Therefore, it’s more efficient to first scale $\cosi$ and $\sini$ by $1/ \rms$ as illustrated in Fig. \ref{fig5}(b). This saves $2hH - h$ multipliers, where $H$ is the number of attention heads.
157 | \item Furthermore, we can fuse the scaling factor $1/ \sqrt{h}$ of the scaled dot-product with the $1/ \rms$ factor (note that we need to use $\sqrt{1/ \sqrt{h}}$ as a scaling factor for this).
158 | \item Unfortunately, the V linear layer (value projection) still needs the normalization at its output.
159 | \end{itemize}
160 | \begin{figure}[h!] \centering
161 | \includegraphics[scale=0.9]{../doc/fig/flashNorm_fig5.pdf}
162 | \caption{Flash normalization for scaled dot-product attention with RoPE: (a) unoptimized version; (b) optimized version where the normalization is fused with $\cosi$ and $\sini$.}
163 | \label{fig5} \end{figure}
164 |
165 | \section{Optimizations for QK-normalization with RoPE}
166 | Some LLMs use query-key normalization \citep{QKnorm}. For example, each layer of OpenELM \citep{openelm} has the following two sets of normalization weights:
167 | \begin{itemize}[topsep=-1pt]
168 | \item \verb+q_norm_weight+: query normalization weights for all heads of this layer
169 | \item \verb+k_norm_weight+: key normalization weights for all heads of this layer
170 | \end{itemize}
171 | Unfortunately, FlashNorm can't be applied for QK-normalization. But for the type of QK-normalization used in OpenELM, we can apply the following two optimizations detailed in the next sections:
172 | \begin{enumerate}[topsep=-1pt]
173 | \item Eliminate the RMS calculation before the Q and K linear layers.
174 | \item Fuse the normalization weights with RoPE.
175 | \end{enumerate}
176 |
177 | \subsection{Eliminate RMS calculation before QK linear layers}
178 | Fig. \ref{fig6}(a) shows a linear layer with flash normalization followed by an additional normalization. The weights of the first normalization are already merged into the linear layer weights $\W*$. Note that $\RMS{s \cdot \a} = s \cdot \rms$ where $s$ is scalar and $\a$ is a vector. Due to this scale-invariance of the RMS function, the second multiplier (scaler $s_c$) in the pipeline of Fig. \ref{fig6}(a) cancels out the first multiplier (scaler $s_a$). Fig. \ref{fig6}(b) takes advantage of this property. We can express this by using the vectors $\a, \b, \c$ along the datapath in Fig. \ref{fig6} as follows:
179 | \begin{itemize}[topsep=-1pt]
180 | \item Note that $s_c = \frac{1}{\RMS{\c}} = \frac{1}{\RMS{\b \cdot s_a}} = \frac{1}{s_a \cdot \RMS{\b}} = \frac{s_b}{s_a}$.
181 | \item With above, we can show that the $y$ outputs of figures \ref{fig6}(a) and \ref{fig6}(b) are identical:
182 | \begin{equation*}
183 | y = \a \cdot \W* \cdot s_a \cdot s_c \cdot \vg = \a \cdot \W* \cdot s_a \cdot \frac{s_b}{s_a} \cdot \vg
184 | = \a \cdot \W* \cdot s_b \cdot \vg
185 | \end{equation*}
186 | \end{itemize}
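
As a quick numerical check (an editorial sketch with illustrative names, not part of the original text), this cancellation can be verified in NumPy:
\begin{verbatim}
import numpy as np
rms = lambda v: np.sqrt(np.mean(v**2))
n = 64
a, g   = np.random.randn(n), np.random.randn(n)
W_star = np.random.randn(n, n)
b   = a @ W_star
c   = b / rms(a)                   # Fig. 6(a): scale by s_a = 1/RMS(a)
y_a = c / rms(c) * g               # Fig. 6(a): second normalization s_c, then g
y_b = b / rms(b) * g               # Fig. 6(b): single normalization s_b = 1/RMS(b)
print(np.allclose(y_a, y_b))       # expected: True
\end{verbatim}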
187 |
188 | \begin{figure}[h!] \centering
189 | \includegraphics[scale=0.9]{../doc/fig/flashNorm_fig6.pdf}
190 | \caption{Linear layer with flash normalization followed by a second normalization: (a) unoptimized version; (b) optimized version.}
191 | \label{fig6} \end{figure}
192 |
193 | The scale-invariance property of $\rms$ doesn’t hold exactly true for RMS with epsilon (see appendix). This should not matter because the epsilon only makes an impact if the RMS (or energy) of the activation vector is very small, in which case the epsilon limits the up-scaling of this low-energy activation vector.
194 |
195 | \begin{figure}[h!] \centering
196 | \includegraphics[scale=0.9]{../doc/fig/flashNorm_fig7.pdf}
197 | \caption{QK-normalization with RoPE: (a) unoptimized version; (b) optimized version.}
198 | \label{fig7} \end{figure}
199 |
200 | \subsection{Fuse normalization weights with RoPE}
201 | Fig. \ref{fig7}(a) illustrates QK-normalization with RoPE. If the QK-normalization weights are the same for all heads of a layer, as is the case for OpenELM \citep{openelm}, then we can fuse them with RoPE's $\cosi$ and $\sini$ as follows: multiply $\cosi$ and $\sini$ with the normalization weights and then share the fused $\cosi$ and $\sini$ vectors across all heads of the LLM layer as shown in Fig. \ref{fig7}(b). This requires permutation of the normalization weights $\vg$ so that the boxes labeled cos, sin, and RoPE in Fig. \ref{fig7}(b) perform $\vy = \vx \cdot \left( \cosi \cdot \vg \right) + \text{permute}(\vx) \cdot \left( \sini \cdot \text{permuteg}(\vg) \right)$, where $\text{permuteg}(\vg) = (g_2, g_1, g_4, g_3, \dots, g_h, g_{h-1})$. For simplicity, Fig. \ref{fig7}(b) doesn't show the permutation of the normalization weights.
202 |
203 | \section{Bottleneck of RMS normalization for batch 1}
204 | This section describes the compute bottleneck of RMS normalization that exists for batch size 1. This bottleneck is different from the bottleneck detailed in \citep{openelm}. Let’s consider a processor with one vector unit and one matrix unit:
205 | \begin{itemize}[topsep=-1pt]
206 | \item The matrix multiplications of the linear layers are performed by the matrix unit, while the vector unit performs vector-wise operations such as RMSNorm and FlashNorm.
207 | \item Let’s assume that the vector unit can perform $m$ operations per cycle and the matrix unit can perform $m^2$ operations per cycle, where $m$ is the processor width. Specifically:
208 | \begin{itemize}[topsep=-1pt]
209 | \item Multiplying an $n$-element vector with an $n \times n$ matrix takes $n^2$ MAD (multiply-add) operations, which takes $n^2/m^2$ cycles with our matrix unit.
210 | \item Calculating $1/\rms$ takes $n$ MAD operations (for squaring and adding) plus 2 scalar operations (for $\sqrt{n/x}$), which takes $n/m$ cycles with our vector unit if we ignore the 2 scalar operations.
211 | \item Scaling an $n$-element vector by a scaling factor takes $n$ multiply operations, which takes $n/m$ cycles.
212 | \end{itemize}
213 | \end{itemize}
214 |
215 | For the example $n = 512, m = 128$ and batch 1, Fig. \ref{fig8} shows timing diagrams without and with deferred normalization:
216 | \begin{itemize}[topsep=-1pt]
217 | \item Without deferred normalization, the matrix unit has to wait for 8 cycles ($n/m = 4$ cycles for the RMS calculation plus another 4 cycles for the scaling by $1/ \rms$) until the vector unit is done, as illustrated in Fig. \ref{fig8}(a).
218 | \item As shown in Fig. \ref{fig8}(b), it is possible to start the matrix unit 3 cycles earlier if the weight matrix $\mat{W}$ is processed in row-major order for example. But the RMS calculation still presents a bottleneck.
219 | \item FlashNorm eliminates this bottleneck: With deferred normalization, the matrix unit computes the vector-matrix multiplication in parallel to the vector unit's RMS calculation as shown in Fig. \ref{fig8}(c). The scaling at the end can be performed in parallel to the matrix unit if $\mat{W}$ is processed in column-major order for example.
220 | \end{itemize}
221 |
222 | \begin{figure}[h!] \centering
223 | \includegraphics[scale=1.0]{../doc/fig/flashNorm_fig8.pdf}
224 | \caption{Timing diagrams for $n = 512, m = 128$: (a) without deferred normalization; (b) with interleaved scaling and vector-matrix multiplication; (c) with deferred normalization.}
225 | \label{fig8} \end{figure}
226 |
227 | \section{Experiments and conclusions}
228 | Refer to \citep{hfFlashNorm, tricks} for Python code that demonstrates the mathematical equivalency of the optimizations presented in this paper. The overall speedup of FlashNorm is modest: We measured a throughput of 204 tokens per second for OpenELM-270M with 4-bit weight quantization using the MLX framework on an M1 MacBook Air. This throughput increases to only 225 tokens per second when we remove RMSNorm entirely. Therefore, the maximum possible speedup of any RMSNorm optimization is $\leq$ 10\% for this model.
229 |
230 | For many applications, the main advantage of FlashNorm is simplification. This is similar to the simplifications we get from using RMSNorm over Layer Normalization (LayerNorm \citep{layerNorm}), and from PaLM's removal of bias-parameters from all linear layers \citep{PaLM}.
231 |
232 | Future work includes integrating FlashNorm into popular frameworks such as HuggingFace Transformers \citep{HFtransformers}, whisper.cpp \citep{whisper-cpp}, llama.cpp \citep{llama-cpp}, vLLM \citep{vLLM}, llamafile \citep{llamafile}, LM Studio \citep{lmstudio}, Ollama \citep{ollama}, SGLang \citep{sglang}, and combining it with parameter quantization.
233 |
234 | \section*{Acknowledgments}
235 | We would like to thank Dmitry Belenko for helpful feedback on this work.
236 |
237 | \appendix
238 |
239 | \section{RMS with epsilon}
240 | Many implementations add a small epsilon $\epsilon$ to the RMS value to limit the resulting scaling factor $1/\rms$ and to avoid division by zero as follows:
241 | \begin{equation*}
242 | \text{RMSe}(\a) = \sqrt{\epsilon + \f1n \sas} = \sqrt{\epsilon + \left( \rms \right)^2}
243 | \end{equation*}
244 |
245 | $\text{RMSe}(\a)$ can be used as a drop-in-replacement for RMS. The popular HuggingFace transformer library calls this epsilon \verb+rms_norm_eps+, which is set to $10^{-5}$ for Llama3.
246 |
247 | \section{Eliminating $1/n$}
248 | This section details a small optimization that eliminates the constant term $1/n$ from the RMS calculation. First, we factor out $1/n$ as follows:
249 | \begin{equation*}
250 | \rms = \sqrt{\f1n \sas} = \sqrt{\f1n} \sqrt{\sas} = \sqrt{\f1n} \cdot \text{RSS}(\a)
251 | \end{equation*}
252 | where $\text{RSS}(\a) = \sqrt{\sas}$. We can now merge the constant term into the normalization weights $g_i$ as follows:
253 | \begin{equation*}
254 | y_i = \frac{a_i}{\rms} \cdot g_i =
255 | \frac{a_i}{\text{RSS}(\a)} \sqrt{n} \cdot g_i =
256 | \frac{a_i}{\text{RSS}(\a)} \cdot g_i^\ast
257 | \end{equation*}
258 | with new normalization weights $g_i^\ast = \sqrt{n} \cdot g_i$ . These new normalization weights can now be merged with the weights $\mat{W}$ of the following linear layer as shown in the previous sections. This optimization also applies for the case where we add an epsilon as detailed in the previous section. In this case, we factor out $1/n$ as follows:
259 | \begin{equation*}
260 | \text{RMSe}(\a) = \sqrt{\epsilon + \f1n \sas}
261 | = \sqrt{\f1n \left( n \epsilon + \sas \right)}
262 | %= \sqrt{\f1n} \sqrt{n \epsilon + \sas}
263 | = \sqrt{\f1n} \cdot \text{RSSe}(\a)
264 | \end{equation*}
265 | where $\text{RSSe}(\a) = \sqrt{n \epsilon + \sas}$.
266 |
267 | \bibliographystyle{unsrtnat}
268 | \bibliography{references}
269 |
270 | \end{document}
271 |
--------------------------------------------------------------------------------
/tex/matShrink.tex:
--------------------------------------------------------------------------------
1 | % To generate PDF, type ./run matShrink.tex
2 |
3 | \documentclass{article}
4 | \usepackage{arxiv}
5 | \usepackage[numbers]{natbib} % for author-year citation style: \usepackage{natbib}
6 |
7 | \usepackage{tablefootnote}
8 |
9 | % shortcuts
10 | \newcommand{\WW}[1]{W_\text{#1}} % for W_\text{...}
11 | \newcommand{\eR}[2]{$\in \mathbb{R}^{#1 \times #2}$} % element of R^{1x2}
12 | \newcommand{\mc}[2]{\multicolumn{#1}{c}{#2}} % table multicolumn
13 | \def\fline{\Xhline{2\arrayrulewidth}} % fat-line for table
14 |
15 | \title{[work in progress]: \\ Matrix-shrink for transformers without loss of accuracy}
16 |
17 | \author{Nils Graef\thanks{\texttt{info@openmachine.ai}}, TBD \\
18 | \href{https://openmachine.ai}{OpenMachine}}
19 |
20 | \begin{document} \maketitle
21 |
22 | \begin{abstract}
23 | Matrix-shrink reduces the number of weights for back-to-back matrices. It uses matrix inversion to reduce weights in a mathematically equivalent way and thus without compromising model accuracy. Matrix-shrink is applicable to both inference and training. It can be used for inference of existing models without fine-tuning or re-training. We also propose a simplified MLA (multi-head latent attention) scheme. See \citep{tricks} for code and more transformer tricks.
24 | \end{abstract}
25 |
26 | For two back-to-back weight matrices $W_A$ and $W_B$, Fig. \ref{fig1} illustrates how we can reduce the size of $W_B$ in a mathematically equivalent way by using matrix inversion.
27 | \begin{figure}[h!] \centering % the [h!] tries to place the picture right here
28 | \includegraphics[scale=0.88]{../doc/fig/matShrink_fig1.pdf}
29 | \caption{Mathematically equivalent implementations of two back-to-back weight matrices $W_A$ and $W_B$ with rank $r$, where $d > r$ and $e > r$. We can split $W_B$ into two submatrices $W_{B1}$ \eR{r}{r} and $W_{B2}$. We can eliminate $W_{B1}$ if it is invertible by merging it into $W_A$ as $W_A^\ast = W_A W_{B1}$ and by changing $W_{B2}$ to $W_{B2}^\ast = W_{B1}^{-1} W_{B2}$. This saves $r^2$ weights and $r^2$ multiply operations per token $x$.}
30 | \label{fig1} \end{figure}
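
As a quick numerical check (an editorial sketch with illustrative names, not part of the original text), the identity of Fig. \ref{fig1} can be verified in NumPy:
\begin{verbatim}
import numpy as np
d, r, e = 8, 4, 6
W_A = np.random.randn(d, r)
W_B = np.random.randn(r, e)
W_B1, W_B2 = W_B[:, :r], W_B[:, r:]     # split W_B into r x r and r x (e - r)
W_A_star  = W_A @ W_B1                  # merge W_B1 into W_A
W_B2_star = np.linalg.inv(W_B1) @ W_B2
ref = W_A @ W_B
new = np.hstack([W_A_star, W_A_star @ W_B2_star])  # W_A* applied, then [I | W_B2*]
print(np.allclose(ref, new))            # expected: True
\end{verbatim}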
31 |
32 | Matrix-shrink reduces the number of weights for the following back-to-back weight matrices:
33 |
34 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
35 | \item The V and O projections for each attention-head
36 | \item The Q and K projections for each attention-head (without the RoPE portion)
37 | \item The latent projections of MLA (multi-head latent attention)
38 | \end{itemize}
39 | \textbf{Related work.} Matrix-shrink is similar to slim attention \citep{slimAttn} in its use of matrix inversion to compute projections from each other. TODO: also mention papers about matrix approximation / compression schemes such as SVD and others.
40 |
41 | \textbf{Alternative way.} Alternatively, we can split matrix $W_A$ into two submatrices $W_{A1}$ \eR{r}{r} and $W_{A2}$ such that $W_A = [W_{A1}; W_{A2}]$. We can then eliminate $W_{A1}$ if it is invertible as $W = [W_{A1}; W_{A2}] W_B = [I; W_{A2}^\ast] W_B^\ast$ with identity matrix $I$ \eR{r}{r} and where $W_B^\ast = W_{A1} W_B$ and $W_{A2}^\ast = W_{A2} W_{A1}^{-1}$, see Fig. \ref{fig2}.
42 | \begin{figure}[h!] \centering
43 | \includegraphics[scale=0.88]{../doc/fig/matShrink_fig2.pdf}
44 | \caption{Alternative way of shrinking $W_A$ instead of $W_B$}
45 | \label{fig2} \end{figure}
46 |
47 | \section{Matrix-shrink for MHA}
48 | Note that the value (V) and output (O) projections for head $i$ of multi-head attention (MHA) are two back-to-back weight matrices $W_{V,i}$ and $W_{O,i}$. Therefore, we can apply the matrix-shrink scheme to each head. Specifically:
49 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
50 | \item For the vanilla MHA with $h$ heads, each head has dimension $d_k = d / h$, and $d = d_\text{model}$.
51 | \item So for the dimensions $r$ and $e$ of Fig. \ref{fig1}, we have $r = d / h$ and $e = d$.
52 | \item This saves $r^2 = d^2 / h^2$ weights for each head, so $d^2 / h$ weights in total.
53 | \item Note: for single-head attention (where $h = 1$), we can save $2 d^2$ weights (i.e. we can merge the V and O weight matrices into a single $d \times d$ matrix; and the Q and K weight matrices into a single $d \times d$ matrix if there is no RoPE).
54 | \end{itemize}
55 |
56 | For models that don’t use RoPE (such as Whisper or T5 models), the query (Q) and key (K) projections for each head $i$ of MHA are two back-to-back weight matrices $W_{Q,i}$ and $W_{K,i}$. As with V-O weight matrices, this saves $d^2 / h$ weights.
57 |
58 | For many models that use RoPE, we can also apply this trick as follows: Many implementations apply RoPE to only a portion of the head-dimension $d_k = d / h$, usually only to one half of $d_k$. So in this case $r = d_k / 2 = d / (2h)$, which saves only $r^2 = d^2 / (4h^2)$ weights for each head, so $d^2 / (4h)$ weights in total.
59 |
60 | \begingroup \renewcommand{\arraystretch}{1.3} % increase table row height by 1.3x
61 | \begin{table}[h!] \centering
62 | \begin{tabular}{lcccccc} \fline
63 | \thead[l]{Model} & \thead{$d$} & \thead{$d_k$} & \thead{$h$} & \thead{weights \\ $d \times (d_k h)$} & \thead{savings \\ $d_k^2 h$} & \thead{savings \\ \%} \\ \hline
64 | Whisper-tiny & 384 & 64 & 6 & 147K & 25K & 17\% \\
65 | CodeGemma-7B & 3,072 & 256 & 16 & 12.6M & 1.0M & 8\% \\
66 | T5-3B & 1,024 & 128 & 32 & 4.2M & 0.5M & 12\% \\
67 | T5-11B & 1,024 & 128 & 128 & 16.8M & 2.1M & 13\% \\ \fline
68 | \end{tabular} \end{table} \endgroup
69 |
70 | \section{Matrix-shrink for MLA}
71 | DeepSeek's MLA (multi-head latent attention) scheme \citep{deepseek-v2} has two latent projections, one for Q (queries) and one for KV (keys and values). We can apply matrix-shrink to each of them:
72 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
73 | \item The Q-latent projection and query (Q) projections are two back-to-back weight matrices $\WW{DQ}$ and $\WW{UQ}$.
74 | \item The KV-latent projection and key/value (KV) projections are two back-to-back weight matrices $\WW{DKV}$ and the union of $\WW{UK}$ and $\WW{UV}$.
75 | \end{itemize}
76 | We can also apply matrix-shrink to each V-O head and the non-RoPE portion of the Q-K heads. Specifically, we can apply the matrix-shrink to the MLA weight matrices in the following order:
77 | \begin{enumerate}[topsep=-1pt, itemsep=-1pt]
78 | \item Apply matrix-shrink to the V-O weight matrices.
79 | \item Apply matrix-shrink to the NoPE portion (i.e. the non-RoPE portion) of the Q-K weight matrices.
80 | \item Apply matrix-shrink to the Q-latent projections. This step must be done after applying matrix-shrink to the Q-K weights.
81 | \item Apply matrix-shrink to the KV-latent projections. This step must be done after applying matrix-shrink to the V-O weights.
82 | \end{enumerate}
83 |
84 | Applying matrix-shrink to the KV-latent projections not only reduces weight matrices and corresponding compute, it can also reduce the compute complexity as follows, where $r_\text{KV}$ is the rank of the KV-latent projections.
85 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
86 | \item Option 1: Use the $r_\text{KV}$ neurons that don’t require a weight matrix as keys. The number of those keys is $r_\text{KV} / d_\text{NOPE}$. Then these keys can be directly used for the softmax arguments, which saves some computation complexity.
87 | \item Option 2: Use the $r_\text{KV}$ neurons as values (instead of keys). Then these values can be directly multiplied with the softmax scores, which saves some compute complexity.
88 | \end{itemize}
89 |
90 | We are using the following parameter names similar to \citep{deepseek-v2}:
91 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
92 | \item For Q (query):
93 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
94 | \item $r_\text{Q}$: rank of Q-latent projection
95 | \item $\WW{DQ}$: down-projection for Q
96 | \item $\WW{UQ}$: up-projection for Q-part without RoPE (aka NoPE)
97 | \item $\WW{QR}$: up-projection for Q-part with RoPE
98 | \end{itemize}
99 | \item For KV (key-value):
100 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
101 | \item $r_\text{KV}$: rank of KV-latent projection
102 | \item $\WW{KR}$: projection for K-part with RoPE (has its own cache, used for all queries as MQA)
103 | \item $\WW{DKV}$: down-projection for KV
104 | \item $\WW{UK}$: up-projection for K-part without RoPE (aka NoPE)
105 | \item $\WW{UV}$: up-projection for V
106 | \end{itemize}
107 | \end{itemize}
108 |
109 | % shortcuts (only letters are allowed in macro names, no numbers and dashes)
110 | \def\dsRone {\href{https://huggingface.co/deepseek-ai/DeepSeek-R1} {DeepSeek-R1}}
111 | \def\pplRone {\href{https://huggingface.co/perplexity-ai/r1-1776} {R1-1776}}
112 | \def\dsVthree {\href{https://huggingface.co/deepseek-ai/DeepSeek-V3} {V3}}
113 | \def\dsVtwoFive {\href{https://huggingface.co/deepseek-ai/DeepSeek-V2.5} {DeepSeek-V2.5}}
114 | \def\dsVtwoL {\href{https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite} {DeepSeek-V2-lite}}
115 | \def\dsVLtwoS {\href{https://huggingface.co/deepseek-ai/deepseek-vl2-small} {DeepSeek-VL2-small}}
116 | \def\MiniCPM {\href{https://huggingface.co/openbmb/MiniCPM3-4B} {MiniCPM3-4B}}
117 |
118 | \begingroup \renewcommand{\arraystretch}{1.3} % increase table row height by 1.3x
119 | \begin{table}[h!] \centering
120 | \begin{tabular}{lcccccccc} \fline
121 | \thead[l]{Model} & \thead{Params} & $d$ & $r_\text{Q}$ & $r_\text{KV}$ & $h$ & $d_\text{NOPE}$ & $h \cdot d_\text{NOPE}$ & $d_\text{ROPE}$ \\ \hline
122 | Perplexity \pplRone, \dsRone, and \dsVthree & 685B & 7,168 & 1,536 & 512 & 128 & 128 & 16,384 & 64 \\
123 | \dsVtwoFive & 236B & 5,120 & 1,536 & 512 & 128 & 128 & 16,384 & 64 \\
124 | \dsVtwoL, \dsVLtwoS & 16B & 2,048 & N/A & 512 & 16 & 128 & 2,048 & 64 \\
125 | OpenBMB \MiniCPM & 4B & 2,560 & 768 & 256 & 40 & 64 & 2,560 & 32 \\ \fline
126 | \end{tabular} \end{table} \endgroup
127 |
128 | TODO: add savings to the table above (or a new table)
129 |
130 | \section{Simplified MLA}
131 | In this section we propose a simplification for DeepSeek’s MLA (multi-head latent attention).
132 | \begin{figure}[h!] \centering
133 | \includegraphics[scale=0.88]{../doc/fig/matShrink_fig3.pdf}
134 | \caption{K and V projections for MLA. (a) original version; (b) equivalent version optimized by matrix-shrink; (c) proposed simplification}
135 | \label{fig3} \end{figure}
136 |
137 | Fig. \ref{fig3} shows the K and V projections of MLA and the proposed simplification:
138 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
139 | \item Fig. \ref{fig3}(a) shows the MLA projections for K (keys) and V (values). Note that a single $d_\text{ROPE}$ head is shared among all query-heads, where $d_\text{ROPE} = 64$ or $32$ usually.
140 | \item Fig. \ref{fig3}(b) shows the mathematically equivalent version with matrix-shrink applied to the weight matrices $\WW{DKV}$ and $\WW{UK}$.
141 | \item Fig. \ref{fig3}(c) shows the proposed simplified MLA scheme where the $d_\text{ROPE}$ units (or channels) are sourced directly from the latent cache, instead of having a separate cache and $\WW{KR}$:
142 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
143 | \item Note that this simplified scheme is not mathematically identical to the standard MLA scheme shown in Fig. \ref{fig3}(a).
144 | \item The rank $s$ of the simplified scheme could be larger than $r$ (e.g. $s = r + d_\text{ROPE}$) or slightly lower than this (e.g. $s = r$).
145 | \item Advantages include: If $s > r$, then there is more usable rank for the keys and values. So the cached latent space is better utilized. And if $s < r + d_\text{ROPE}$ then the total cache size is reduced.
146 | \end{itemize}
147 | \end{itemize}
148 |
149 | \section{Matrix-shrink for GQA and MQA}
150 | Matrix-shrink is not limited to MHA and MLA only. It’s also applicable to GQA (grouped query attention) and MQA (multi-query attention). However, the savings are smaller than for MHA and MLA. Specifically, the savings are reduced by a factor $g$, where $g$ is the number of queries that are shared among a single KV-pair, or in other words $g = n_\text{heads} / n_\text{KV-heads}$ (where $n_\text{heads}$ is the number of query-heads, and $n_\text{KV-heads}$ is the number of KV-heads).
151 |
152 | \section{Matrix-shrink for SVD}
153 | In some cases, we can first use SVD (singular value decomposition) to compress the rank of any weight matrix $W$ by a certain percentage. This is applicable, for example, to the large weight matrices of the transformer’s FFN (feedforward networks). SVD factorizes the original matrix $W$ \eR{d}{e} into two matrices $W_A$ \eR{d}{r} and $W_B$ \eR{r}{e}, where $r$ is the compressed rank. After performing SVD and compressing the rank, we can then eliminate $r^2$ weights using our matrix-shrink scheme. Note that reducing the rank in this way yields an approximation of the original matrix $W$, not an exact implementation.
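
As an editorial sketch (illustrative names, not part of the original text), the SVD split into $W_A$ and $W_B$ can be obtained as follows, after which the matrix-shrink step of Fig. \ref{fig1} applies unchanged:
\begin{verbatim}
import numpy as np
d, e, r = 8, 6, 3
W = np.random.randn(d, e)
U, S, Vt = np.linalg.svd(W)
W_A = U[:, :r] * S[:r]          # d x r, singular values absorbed into W_A
W_B = Vt[:r, :]                 # r x e
W_approx = W_A @ W_B            # rank-r approximation of W
\end{verbatim}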
154 |
155 | %\section{Conclusion}
156 | %Slim attention offers a simple trick for halving the context memory of existing MHA transformer models without sacrificing accuracy. Future work includes integrating slim attention into popular frameworks such as HuggingFace Transformers \citep{HFtransformers}, llama.cpp \citep{llama-cpp}, vLLM \citep{vLLM}, llamafile \citep{llamafile}, Ollama \citep{ollama}, SGLang \citep{sglang}, and combining it with existing context memory management schemes such as PagedAttention \citep{pagedAttn} and other compression schemes such as Dynamic Memory Compression DMC \citep{DMC} and VL-cache \citep{VL-cache}.
157 |
158 | %\section*{Acknowledgments}
159 | %We would like to thank TBD for helpful feedback on this work.
160 |
161 | \bibliographystyle{unsrtnat}
162 | \bibliography{references}
163 |
164 | \end{document}
165 |
--------------------------------------------------------------------------------
/tex/neurips_2025.sty:
--------------------------------------------------------------------------------
1 | % I downloaded this file from NeurIPS website:
2 | % https://media.neurips.cc/Conferences/NeurIPS2025/Styles.zip
3 |
4 | % partial rewrite of the LaTeX2e package for submissions to the
5 | % Conference on Neural Information Processing Systems (NeurIPS):
6 | %
7 | % - uses more LaTeX conventions
8 | % - line numbers at submission time replaced with aligned numbers from
9 | % lineno package
10 | % - \nipsfinalcopy replaced with [final] package option
11 | % - automatically loads times package for authors
12 | % - loads natbib automatically; this can be suppressed with the
13 | % [nonatbib] package option
14 | % - adds foot line to first page identifying the conference
15 | % - adds preprint option for submission to e.g. arXiv
16 | % - conference acronym modified
17 | %
18 | % Roman Garnett (garnett@wustl.edu) and the many authors of
19 | % nips15submit_e.sty, including MK and drstrip@sandia
20 | %
21 | % last revision: April 2025
22 |
23 | \NeedsTeXFormat{LaTeX2e}
24 | \ProvidesPackage{neurips_2025}[2025/04/02 NeurIPS 2025 submission/camera-ready style file]
25 |
26 | % declare final option, which creates camera-ready copy
27 | \newif\if@neuripsfinal\@neuripsfinalfalse
28 | \DeclareOption{final}{
29 | \@neuripsfinaltrue
30 | }
31 |
32 | % declare nonatbib option, which does not load natbib in case of
33 | % package clash (users can pass options to natbib via
34 | % \PassOptionsToPackage)
35 | \newif\if@natbib\@natbibtrue
36 | \DeclareOption{nonatbib}{
37 | \@natbibfalse
38 | }
39 |
40 | % declare preprint option, which creates a preprint version ready for
41 | % upload to, e.g., arXiv
42 | \newif\if@preprint\@preprintfalse
43 | \DeclareOption{preprint}{
44 | \@preprinttrue
45 | }
46 |
47 | \ProcessOptions\relax
48 |
49 | % determine whether this is an anonymized submission
50 | \newif\if@submission\@submissiontrue
51 | \if@neuripsfinal\@submissionfalse\fi
52 | \if@preprint\@submissionfalse\fi
53 |
54 | % fonts
55 | \renewcommand{\rmdefault}{ptm}
56 | \renewcommand{\sfdefault}{phv}
57 |
58 | % change this every year for notice string at bottom
59 | \newcommand{\@neuripsordinal}{39th}
60 | \newcommand{\@neuripsyear}{2025}
61 | \newcommand{\@neuripslocation}{San Diego}
62 |
63 | % acknowledgments
64 | \usepackage{environ}
65 | \newcommand{\acksection}{\section*{Acknowledgments and Disclosure of Funding}}
66 | \NewEnviron{ack}{%
67 | \acksection
68 | \BODY
69 | }
70 |
71 |
72 | % load natbib unless told otherwise
73 | \if@natbib
74 | \RequirePackage{natbib}
75 | \fi
76 |
77 | % set page geometry
78 | \usepackage[verbose=true,letterpaper]{geometry}
79 | \AtBeginDocument{
80 | \newgeometry{
81 | textheight=9in,
82 | textwidth=5.5in,
83 | top=1in,
84 | headheight=12pt,
85 | headsep=25pt,
86 | footskip=30pt
87 | }
88 | \@ifpackageloaded{fullpage}
89 | {\PackageWarning{neurips_2025}{fullpage package not allowed! Overwriting formatting.}}
90 | {}
91 | }
92 |
93 | \widowpenalty=10000
94 | \clubpenalty=10000
95 | \flushbottom
96 | \sloppy
97 |
98 |
99 | % font sizes with reduced leading
100 | \renewcommand{\normalsize}{%
101 | \@setfontsize\normalsize\@xpt\@xipt
102 | \abovedisplayskip 7\p@ \@plus 2\p@ \@minus 5\p@
103 | \abovedisplayshortskip \z@ \@plus 3\p@
104 | \belowdisplayskip \abovedisplayskip
105 | \belowdisplayshortskip 4\p@ \@plus 3\p@ \@minus 3\p@
106 | }
107 | \normalsize
108 | \renewcommand{\small}{%
109 | \@setfontsize\small\@ixpt\@xpt
110 | \abovedisplayskip 6\p@ \@plus 1.5\p@ \@minus 4\p@
111 | \abovedisplayshortskip \z@ \@plus 2\p@
112 | \belowdisplayskip \abovedisplayskip
113 | \belowdisplayshortskip 3\p@ \@plus 2\p@ \@minus 2\p@
114 | }
115 | \renewcommand{\footnotesize}{\@setfontsize\footnotesize\@ixpt\@xpt}
116 | \renewcommand{\scriptsize}{\@setfontsize\scriptsize\@viipt\@viiipt}
117 | \renewcommand{\tiny}{\@setfontsize\tiny\@vipt\@viipt}
118 | \renewcommand{\large}{\@setfontsize\large\@xiipt{14}}
119 | \renewcommand{\Large}{\@setfontsize\Large\@xivpt{16}}
120 | \renewcommand{\LARGE}{\@setfontsize\LARGE\@xviipt{20}}
121 | \renewcommand{\huge}{\@setfontsize\huge\@xxpt{23}}
122 | \renewcommand{\Huge}{\@setfontsize\Huge\@xxvpt{28}}
123 |
124 | % sections with less space
125 | \providecommand{\section}{}
126 | \renewcommand{\section}{%
127 | \@startsection{section}{1}{\z@}%
128 | {-2.0ex \@plus -0.5ex \@minus -0.2ex}%
129 | { 1.5ex \@plus 0.3ex \@minus 0.2ex}%
130 | {\large\bf\raggedright}%
131 | }
132 | \providecommand{\subsection}{}
133 | \renewcommand{\subsection}{%
134 | \@startsection{subsection}{2}{\z@}%
135 | {-1.8ex \@plus -0.5ex \@minus -0.2ex}%
136 | { 0.8ex \@plus 0.2ex}%
137 | {\normalsize\bf\raggedright}%
138 | }
139 | \providecommand{\subsubsection}{}
140 | \renewcommand{\subsubsection}{%
141 | \@startsection{subsubsection}{3}{\z@}%
142 | {-1.5ex \@plus -0.5ex \@minus -0.2ex}%
143 | { 0.5ex \@plus 0.2ex}%
144 | {\normalsize\bf\raggedright}%
145 | }
146 | \providecommand{\paragraph}{}
147 | \renewcommand{\paragraph}{%
148 | \@startsection{paragraph}{4}{\z@}%
149 | {1.5ex \@plus 0.5ex \@minus 0.2ex}%
150 | {-1em}%
151 | {\normalsize\bf}%
152 | }
153 | \providecommand{\subparagraph}{}
154 | \renewcommand{\subparagraph}{%
155 | \@startsection{subparagraph}{5}{\z@}%
156 | {1.5ex \@plus 0.5ex \@minus 0.2ex}%
157 | {-1em}%
158 | {\normalsize\bf}%
159 | }
160 | \providecommand{\subsubsubsection}{}
161 | \renewcommand{\subsubsubsection}{%
162 | \vskip5pt{\noindent\normalsize\rm\raggedright}%
163 | }
164 |
165 | % float placement
166 | \renewcommand{\topfraction }{0.85}
167 | \renewcommand{\bottomfraction }{0.4}
168 | \renewcommand{\textfraction }{0.1}
169 | \renewcommand{\floatpagefraction}{0.7}
170 |
171 | \newlength{\@neuripsabovecaptionskip}\setlength{\@neuripsabovecaptionskip}{7\p@}
172 | \newlength{\@neuripsbelowcaptionskip}\setlength{\@neuripsbelowcaptionskip}{\z@}
173 |
174 | \setlength{\abovecaptionskip}{\@neuripsabovecaptionskip}
175 | \setlength{\belowcaptionskip}{\@neuripsbelowcaptionskip}
176 |
177 | % swap above/belowcaptionskip lengths for tables
178 | \renewenvironment{table}
179 | {\setlength{\abovecaptionskip}{\@neuripsbelowcaptionskip}%
180 | \setlength{\belowcaptionskip}{\@neuripsabovecaptionskip}%
181 | \@float{table}}
182 | {\end@float}
183 |
184 | % footnote formatting
185 | \setlength{\footnotesep }{6.65\p@}
186 | \setlength{\skip\footins}{9\p@ \@plus 4\p@ \@minus 2\p@}
187 | \renewcommand{\footnoterule}{\kern-3\p@ \hrule width 12pc \kern 2.6\p@}
188 | \setcounter{footnote}{0}
189 |
190 | % paragraph formatting
191 | \setlength{\parindent}{\z@}
192 | \setlength{\parskip }{5.5\p@}
193 |
194 | % list formatting
195 | \setlength{\topsep }{4\p@ \@plus 1\p@ \@minus 2\p@}
196 | \setlength{\partopsep }{1\p@ \@plus 0.5\p@ \@minus 0.5\p@}
197 | \setlength{\itemsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@}
198 | \setlength{\parsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@}
199 | \setlength{\leftmargin }{3pc}
200 | \setlength{\leftmargini }{\leftmargin}
201 | \setlength{\leftmarginii }{2em}
202 | \setlength{\leftmarginiii}{1.5em}
203 | \setlength{\leftmarginiv }{1.0em}
204 | \setlength{\leftmarginv }{0.5em}
205 | \def\@listi {\leftmargin\leftmargini}
206 | \def\@listii {\leftmargin\leftmarginii
207 | \labelwidth\leftmarginii
208 | \advance\labelwidth-\labelsep
209 | \topsep 2\p@ \@plus 1\p@ \@minus 0.5\p@
210 | \parsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@
211 | \itemsep \parsep}
212 | \def\@listiii{\leftmargin\leftmarginiii
213 | \labelwidth\leftmarginiii
214 | \advance\labelwidth-\labelsep
215 | \topsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@
216 | \parsep \z@
217 | \partopsep 0.5\p@ \@plus 0\p@ \@minus 0.5\p@
218 | \itemsep \topsep}
219 | \def\@listiv {\leftmargin\leftmarginiv
220 | \labelwidth\leftmarginiv
221 | \advance\labelwidth-\labelsep}
222 | \def\@listv {\leftmargin\leftmarginv
223 | \labelwidth\leftmarginv
224 | \advance\labelwidth-\labelsep}
225 | \def\@listvi {\leftmargin\leftmarginvi
226 | \labelwidth\leftmarginvi
227 | \advance\labelwidth-\labelsep}
228 |
229 | % create title
230 | \providecommand{\maketitle}{}
231 | \renewcommand{\maketitle}{%
232 | \par
233 | \begingroup
234 | \renewcommand{\thefootnote}{\fnsymbol{footnote}}
235 | % for perfect author name centering
236 | \renewcommand{\@makefnmark}{\hbox to \z@{$^{\@thefnmark}$\hss}}
237 | % The footnote-mark was overlapping the footnote-text,
238 | % added the following to fix this problem (MK)
239 | \long\def\@makefntext##1{%
240 | \parindent 1em\noindent
241 | \hbox to 1.8em{\hss $\m@th ^{\@thefnmark}$}##1
242 | }
243 | \thispagestyle{empty}
244 | \@maketitle
245 | \@thanks
246 | \@notice
247 | \endgroup
248 | \let\maketitle\relax
249 | \let\thanks\relax
250 | }
251 |
252 | % rules for title box at top of first page
253 | \newcommand{\@toptitlebar}{
254 | \hrule height 4\p@
255 | \vskip 0.25in
256 | \vskip -\parskip%
257 | }
258 | \newcommand{\@bottomtitlebar}{
259 | \vskip 0.29in
260 | \vskip -\parskip
261 | \hrule height 1\p@
262 | \vskip 0.09in%
263 | }
264 |
265 | % create title (includes both anonymized and non-anonymized versions)
266 | \providecommand{\@maketitle}{}
267 | \renewcommand{\@maketitle}{%
268 | \vbox{%
269 | \hsize\textwidth
270 | \linewidth\hsize
271 | \vskip 0.1in
272 | \@toptitlebar
273 | \centering
274 | {\LARGE\bf \@title\par}
275 | \@bottomtitlebar
276 | \if@submission
277 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}
278 | Anonymous Author(s) \\
279 | Affiliation \\
280 | Address \\
281 | \texttt{email} \\
282 | \end{tabular}%
283 | \else
284 | \def\And{%
285 | \end{tabular}\hfil\linebreak[0]\hfil%
286 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces%
287 | }
288 | \def\AND{%
289 | \end{tabular}\hfil\linebreak[4]\hfil%
290 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces%
291 | }
292 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\@author\end{tabular}%
293 | \fi
294 | \vskip 0.3in \@minus 0.1in
295 | }
296 | }
297 |
298 | % add conference notice to bottom of first page
299 | \newcommand{\ftype@noticebox}{8}
300 | \newcommand{\@notice}{%
301 | % give a bit of extra room back to authors on first page
302 | \enlargethispage{2\baselineskip}%
303 | \@float{noticebox}[b]%
304 | \footnotesize\@noticestring%
305 | \end@float%
306 | }
307 |
308 | % abstract styling
309 | \renewenvironment{abstract}%
310 | {%
311 | \vskip 0.075in%
312 | \centerline%
313 | {\large\bf Abstract}%
314 | \vspace{0.5ex}%
315 | \begin{quote}%
316 | }
317 | {
318 | \par%
319 | \end{quote}%
320 | \vskip 1ex%
321 | }
322 |
323 | % For the paper checklist
324 | \newcommand{\answerYes}[1][]{\textcolor{blue}{[Yes] #1}}
325 | \newcommand{\answerNo}[1][]{\textcolor{orange}{[No] #1}}
326 | \newcommand{\answerNA}[1][]{\textcolor{gray}{[NA] #1}}
327 | \newcommand{\answerTODO}[1][]{\textcolor{red}{\bf [TODO]}}
328 | \newcommand{\justificationTODO}[1][]{\textcolor{red}{\bf [TODO]}}
329 |
330 | % handle tweaks for camera-ready copy vs. submission copy
331 | \if@preprint
332 | \newcommand{\@noticestring}{%
333 | Preprint. Under review.%
334 | }
335 | \else
336 | \if@neuripsfinal
337 | \newcommand{\@noticestring}{%
338 | \@neuripsordinal\/ Conference on Neural Information Processing Systems
339 | (NeurIPS \@neuripsyear).%, \@neuripslocation.%
340 | }
341 | \else
342 | \newcommand{\@noticestring}{%
343 | Submitted to \@neuripsordinal\/ Conference on Neural Information
344 | Processing Systems (NeurIPS \@neuripsyear). Do not distribute.%
345 | }
346 |
347 | % hide the acknowledgements
348 | \NewEnviron{hide}{}
349 | \let\ack\hide
350 | \let\endack\endhide
351 |
352 | % line numbers for submission
353 | \RequirePackage{lineno}
354 | \linenumbers
355 |
356 | % fix incompatibilities between lineno and amsmath, if required, by
357 | % transparently wrapping linenomath environments around amsmath
358 | % environments
359 | \AtBeginDocument{%
360 | \@ifpackageloaded{amsmath}{%
361 | \newcommand*\patchAmsMathEnvironmentForLineno[1]{%
362 | \expandafter\let\csname old#1\expandafter\endcsname\csname #1\endcsname
363 | \expandafter\let\csname oldend#1\expandafter\endcsname\csname end#1\endcsname
364 | \renewenvironment{#1}%
365 | {\linenomath\csname old#1\endcsname}%
366 | {\csname oldend#1\endcsname\endlinenomath}%
367 | }%
368 | \newcommand*\patchBothAmsMathEnvironmentsForLineno[1]{%
369 | \patchAmsMathEnvironmentForLineno{#1}%
370 | \patchAmsMathEnvironmentForLineno{#1*}%
371 | }%
372 | \patchBothAmsMathEnvironmentsForLineno{equation}%
373 | \patchBothAmsMathEnvironmentsForLineno{align}%
374 | \patchBothAmsMathEnvironmentsForLineno{flalign}%
375 | \patchBothAmsMathEnvironmentsForLineno{alignat}%
376 | \patchBothAmsMathEnvironmentsForLineno{gather}%
377 | \patchBothAmsMathEnvironmentsForLineno{multline}%
378 | }
379 | {}
380 | }
381 | \fi
382 | \fi
383 |
384 |
385 | \endinput
386 |
--------------------------------------------------------------------------------
/tex/precomp1stLayer.tex:
--------------------------------------------------------------------------------
1 | % To generate PDF, type ./run precomp1stLayer
2 |
3 | \documentclass{article}
4 | \usepackage{arxiv}
5 | \usepackage[numbers]{natbib} % for author-year citation style: \usepackage{natbib}
6 |
7 | \title{Transformer tricks: Precomputing the first layer}
8 |
9 | \author{Nils Graef \\ \href{https://openmachine.ai}{OpenMachine},
10 | South San Francisco, CA 94080, \texttt{info@openmachine.ai}}
11 |
12 | \begin{document} \maketitle
13 |
14 | \begin{abstract}
15 | This micro-paper \cite{micro-paper} describes a trick to speed up inference of transformers with RoPE \citep{RoPE} (such as LLaMA, Mistral, PaLM, and Gemma \citep{gemma}). For these models, a large portion of the first transformer layer can be precomputed, which results in slightly lower latency and lower cost-per-token.
16 | Because this trick optimizes only one layer, the relative savings depend on the total number of layers. For example, the maximum savings for a model with only 4 layers (such as Whisper tiny \citep{Whisper}) is limited to 25\%, while a 32-layer model is limited to 3\% savings. See \citep{tricks} for code and more transformer tricks. \\
17 | The next two sections detail the precompute for transformers with parallel attention/FFN \citep{parallel} (such as GPT-J, Pythia, and PaLM \citep{parallel, Pythia, PaLM}) and without (such as Llama 2, Mistral, and Mixtral \citep{LLaMA, Llama2, Mistral, Mixtral}).
18 | \end{abstract}
19 |
20 | \section{Precompute for parallel transformers}
21 | Figure \ref{fig1}(a) shows the first layer of a transformer with RoPE and parallel attention/FFN. Because the inputs of Q, K, V, and FFN only depend on the embedding, we can precompute their outputs and store them in memory instead of the input embeddings, see Figure \ref{fig1}(b).
22 |
23 | \begin{figure}[h!] \centering % the [h!] tries to place the picture right here
24 | \includegraphics[scale=0.86]{../doc/fig/precomp1stLayer_fig1.pdf}
25 | \caption{First layer of parallel transformer (a) without precompute; and (b) with precompute of FFN and linear layers Q, K, and V.}
26 | \label{fig1} \end{figure}
27 |
28 | Figure \ref{fig1} uses the following dimensions, based on the type of attention such as multi-head attention (MHA) \citep{vanilla}, multi-query attention (MQA) \citep{MQA}, and grouped-query attention (GQA) \citep{GQA}:
29 |
30 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
31 | \item $d$: embedding dimension.
32 | \item $e$: $e = d$ for MHA. For MQA, $e = d / n_{heads}$. And for GQA, $e = d \cdot n_{kv\_heads} / n_{heads}$.
33 | \item Q, K, V, P are the linear layers for query, keys, values, and post-attention projection.
34 | \item FFN (feedforward network) is usually a two-layer MLP (multi-layer perceptron). Mistral and Llama 2 use a two-layer MLP with a GLU variant \citep{GLU} for the first layer. And MoE models (mixture-of-experts) \citep{MoE} such as Mixtral use a switch FFN.
35 | \item The embedding layer is implemented by a simple memory read operation, where the token-ID provides the read-address to read $d$ values from memory.
36 | \end{itemize}
37 |
38 | The precompute is done as follows: For each token stored in the embedding table, perform the calculations needed for the first layer normalization, FFN, skip-connection, and linear layers Q, K, V, and store the results in memory instead of the original input-embeddings. This precompute is done offline only once and stored in the parameter memory (along with weights, biases, and output-embeddings).
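
As a rough sketch (placeholder names only, not an exact implementation; note that RoPE is applied to the precomputed queries and keys at runtime, so it is not part of the precompute), the offline loop over the embedding table could look as follows:
\begin{verbatim}
import numpy as np

def rmsnorm(x, eps=1e-6):               # first-layer normalization (scale omitted for brevity)
  return x / np.sqrt(np.mean(x * x) + eps)

def precompute_first_layer(E, Wq, Wk, Wv, ffn):
  """E: (vocab_size, d) embedding table; Wq: (d, d); Wk, Wv: (d, e); ffn: any callable
  implementing the feedforward network. Returns a (vocab_size, 2*(d+e)) table that
  replaces the input embeddings."""
  rows = []
  for x in E:                           # one row per token of the vocabulary
    h = rmsnorm(x)
    q, k, v = h @ Wq, h @ Wk, h @ Wv    # outputs of the first-layer Q, K, V projections
    partial = x + ffn(h)                # skip-connection plus parallel FFN output
    rows.append(np.concatenate([q, k, v, partial]))
  return np.stack(rows)
\end{verbatim}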
39 |
40 | The benefits of precompute include:
41 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
42 | \item \textbf{Lower computational complexity per token}: For each token, we save the operations needed for FFN and the linear layers Q, K, V. This can speed up inference if the system is limited by compute.
43 | \item \textbf{Fewer memory reads for low batch sizes}: This can speed up inference for systems that are memory bandwidth limited, especially during the autoregressive next-token-prediction phase, see the table below and section \ref{sec:examples} for examples.
44 | \end{itemize}
45 |
46 | \begingroup \renewcommand{\arraystretch}{1.3} % increase table row height by 1.3x
47 | \begin{center} \begin{tabular}{r|l|l}
48 | & \textbf{Without precompute} & \textbf{With precompute} \\ \hline
49 | & \makecell[l]{1) For each token, read $d$ embedding values \\
50 | 2) Plus, for each batch, read weights for Q, K, V, FFN}
51 | & \makecell[l]{For each token, read $2(d+e)$ \\ precomputed values} \\ \hline
52 | \makecell[l]{Reads per batch: \\ ($B$ is batch-size)} & $B \cdot d + \verb+num_weights_Q_K_V_FFN+$ & $B \cdot 2(d+e)$
53 | \end{tabular} \end{center} \endgroup
54 | % TODO: https://stackoverflow.com/questions/56197968/how-can-i-make-a-list-with-itemize-in-a-cell-of-table
55 |
56 | Notes on batch size:
57 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
58 | \item During the prefill phase, many implementations use a batch size larger than 1, because the input tokens can be processed in parallel.
59 | \item During the autoregressive next-token-generation phase, single-user implementations often use a batch size of \verb+num_beams+ (i.e. the width of the beam search, such as \verb+num_beams+ = 4), while multi-user implementations use larger batch sizes. However, the maximum batch size for multi-user applications can be limited by the total memory capacity as the number of KV-caches increases linearly with the batch size.
60 | \end{itemize}
61 |
62 | However, precomputing the first layer can increase (or decrease) the total memory size, which depends on the vocabulary size and the number of eliminated weights as shown in the table below. For example, the total memory size of Mistral-7B only increases by 2\%, see section \ref{sec:examples} for more details.
63 |
64 | \begingroup \renewcommand{\arraystretch}{1.3} % increase table row height by 1.3x
65 | \begin{center} \begin{tabular}{l|l}
66 | \textbf{Without precompute} & \textbf{With precompute} \\ \hline
67 | 1) Store embeddings: $d \cdot \verb+vocab_size+$ & Store precomputed values: $2(d+e) \cdot \verb+vocab_size+$ \\
68 | 2) Store weights for Q, K, V, and FFN & \\ \hline
69 | \end{tabular} \end{center} \endgroup
70 |
71 | \section{Precompute for serial transformers}
72 | Transformers without the parallel attention/FFN scheme can also benefit from precompute, but the savings are smaller: As shown in Figure \ref{fig2}(c), we can only precompute Q, K, and V, but not the FFN. For reference, Figure \ref{fig2}(a) shows the vanilla transformer with absolute positional encoding (PE) instead of RoPE and with pre-normalization \citep{pre-norm}. The PE is located right after the embedding layer, which prevents us from precomputing the first layer. But replacing the PE by RoPE, as done in Figure \ref{fig2}(b), allows us to precompute the linear layers Q, K, and V and store the precomputed values along the embeddings in memory as illustrated in Figure \ref{fig2}(c).
73 |
74 | \begin{figure} \centering
75 | \includegraphics[scale=0.86]{../doc/fig/precomp1stLayer_fig2.pdf}
76 | \caption{First transformer layer. (a) Vanilla with pre-normalization and vanilla PE; (b) Vanilla with RoPE; (c) Precomputing linear layers Q, K, V.}
77 | \label{fig2} \end{figure}
78 |
79 | \section{Examples} \label{sec:examples}
80 |
81 | \begingroup
82 | \renewcommand{\arraystretch}{1.3} % increase table row height by 1.3x
83 | \begin{center} \begin{tabular}{|l|c|c|c|l|} \hline
84 | \textbf{Parameter} & \textbf{Pythia-6.9B} & \textbf{Mistral-7B} & \textbf{Mixtral-8x7B} & \textbf{Notes} \\ \hline
85 | Parallel attention/FFN? & parallel & \multicolumn{2}{c|}{serial} & \citep{parallel} \\ \hline
86 | MHA, MQA, or GQA? & MHA & \multicolumn{2}{c|}{GQA} & \citep{vanilla, MQA, GQA} \\ \hline
87 | % Positional encoding & \multicolumn{3}{c|}{RoPE} & \citep{RoPE} \\ \hline % this line didn't fit on the page
88 | \verb+dim+ (aka $d$) & \multicolumn{3}{c|}{4,096} & embedding dimension \\ \hline
89 | \verb+n_layers+ & \multicolumn{3}{c|}{32} & number of layers \\ \hline
90 | \verb+n_heads+, \verb+n_kv_heads+ & 32, 32 & \multicolumn{2}{c|}{32, 8} & number of heads, KV-heads \\ \hline
91 | \verb+e+ (output dim. of K, V) & 4,096 & \multicolumn{2}{c|}{1,024} & \verb+e = d * n_kv_heads / n_heads+ \\ \hline
92 | FFN type & 2-layer MLP & SwiGLU *) & SwiGLU MoE & *) MLP with SwiGLU (GLU variant) \citep{GLU, MoE} \\ \hline
93 | FFN \verb+hidden_dim+ & 16,384 & \multicolumn{2}{c|}{14,336} & FFN hidden dimension \\ \hline
94 | FFN \verb+n_experts+ & \multicolumn{2}{c|}{1} & 8 & FFN number of experts \\ \hline
95 | \verb+vocab_size+ & 50,400 & \multicolumn{2}{c|}{32,000} & vocabulary size \\ \hline
96 |
97 | \multicolumn{5}{|l|}{\textbf{Number of weights (calculated from above parameters):}} \\ \hline
98 | Q+P weights per layer & \multicolumn{3}{c|}{33,554,432} & \verb+2 * dim * dim+ \\ \hline
99 | K+V weights per layer & 33,554,432 & \multicolumn{2}{c|}{8,388,608} & \verb+2 * dim * dim / n_heads * n_kv_heads+ \\ \hline
100 | FFN weights per layer & 134,217,728 & 176,160,768 & 1,409,286,144 & \verb+(2 or 3) * dim * hidden_dim * n_exp.+ \\ \hline
101 | Input+output embed. & 412,876,800 & \multicolumn{2}{c|}{262,144,000} & \verb+2 * dim * vocab_size+ \\ \hline
102 | \multicolumn{1}{|r|}{\textbf{Total weights:}} & 6.9B & 7.2B & 46.7B & \\ \hline
103 | \end{tabular} \end{center}
104 | \endgroup
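
As a sanity check, the Mixtral-8x7B weight count above can be reproduced with a few lines of Python (router and normalization weights are ignored, as in the table):
\begin{verbatim}
d, e, f, n_exp, vocab, layers = 4096, 1024, 14336, 8, 32000, 32
qp  = 2*d*d               # Q and P weights per layer
kv  = 2*d*e               # K and V weights per layer
ffn = 3*d*f*n_exp         # SwiGLU FFN weights per layer, times 8 experts
emb = 2*d*vocab           # input and output embeddings
print(layers*(qp + kv + ffn) + emb)   # 46,701,477,888, i.e. about 46.7B
\end{verbatim}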
105 |
106 | The table above compares the configurations and number of weights of Pythia-6.9B, Mistral-7B, and Mixtral-8x7B. The next table shows the memory read savings and memory size increases for Pythia-6.9B, Mistral-7B, and a hypothetical Mixtral-8x7B with parallel attention/FFN layers.
107 |
108 | \begingroup
109 | \renewcommand{\arraystretch}{1.4} % increase table row height by 1.4x
110 | \begin{center} \begin{tabular}{|l|c|c|>{\centering\arraybackslash}m{7.8em}|} \hline
111 | & \textbf{Pythia-6.9B} & \textbf{Mistral-7B} & \textbf{Hypothetical Mixtral-8x7B with parallel attn./FFN} \\ \hline
112 | Number of weights that can be eliminated & 184,549,376 & 25,165,824 & 1,434,451,968 \\ \hline
113 | Number of reads w/o precompute for batch 1 & 184,553,472 & 25,169,920 & 1,434,456,064 \\ \hline
114 | Number of reads with precompute for batch 1 & 16,384 & 10,240 & 10,240 \\ \hline
115 | \multicolumn{1}{|r|}{\textbf{First layer reduction factor for batch size 1:}} & \textbf{11,264x} & \textbf{2,458x} & \textbf{140,084x} \\ \hline
116 | \multicolumn{1}{|r|}{\textbf{First layer reduction factor for batch size 16:}} & 704x & 154x & 8,756x \\ \hline
117 | \multicolumn{1}{|r|}{\textbf{First layer reduction factor for batch size 256:}} & 44x & 10x & 548x \\ \hline
118 | \multicolumn{1}{|r|}{\textbf{First layer reduction factor for batch size 1,024:}} & 11x & 3x & 137x \\ \hline
119 |
120 | \multicolumn{4}{|l|}{\textbf{Increase (or decrease) of total weight memory size:}} \\ \hline
121 | Increase embedding memory by $(2e + d) \cdot \verb+vocab_size+$ & 619,315,200 & \multicolumn{2}{c|}{196,608,000} \\ \hline
122 | Memory decrease due to elimination of weights & –184,549,376 & –25,165,824 & –1,434,451,968 \\ \hline
123 | \multicolumn{1}{|r|}{\textbf{Total absolute memory increase (or decrease):}} & 434,765,824 & 171,442,176 & \textbf{–1,237,843,968} \\ \hline
124 | \multicolumn{1}{|r|}{\textbf{Total relative memory increase (or decrease):}} & 6\% & \textbf{2\%} & \textbf{–3\%} \\ \hline
125 | \end{tabular} \end{center}
126 | \endgroup
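
The Mistral-7B column above can be double-checked with a short calculation that uses the parameters from the first table (serial model, so only the Q, K, V weights of the first layer are eliminated):
\begin{verbatim}
d, e, vocab = 4096, 1024, 32000
w_qkv = d*d + 2*d*e                    # first-layer Q, K, V weights
for B in (1, 16, 256, 1024):           # reads without vs. with precompute
  print(B, round((B*d + w_qkv) / (B*2*(d + e))))   # reduction factors 2458, 154, 10, 3
print((2*e + d)*vocab - w_qkv)         # 171,442,176 = total absolute memory increase
\end{verbatim}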
127 |
128 | \section*{Acknowledgments}
129 | We would like to thank \href{https://scholar.google.com/citations?user=LlK_saMAAAAJ&hl=en}{James Martens (DeepMind)} for his generous support and endorsement for the arXiv submission process.
130 |
131 | \bibliographystyle{unsrtnat}
132 | \bibliography{references}
133 |
134 | \end{document}
135 |
--------------------------------------------------------------------------------
/tex/removeWeights.tex:
--------------------------------------------------------------------------------
1 | % To generate PDF, type ./run removeWeights
2 |
3 | \documentclass{article}
4 | \usepackage{arxiv} % see file arxiv.sty
5 | \usepackage[numbers]{natbib} % number citation style (remove "numbers" for author-year style)
6 |
7 | % shortcuts for matrices Q, K, V, P, O, M and vectors u, x, y, z
8 | \newcommand{\mat}[1]{\mathbf{#1}} % shortcut for matrix
9 | \def\Q{\mat{Q}_i}
10 | \def\K{\mat{K}_i}
11 | \def\V{\mat{V}_i}
12 | \def\P{\mat{P}_i}
13 | \def\O{\mat{O}_{i-1}}
14 | \def\M{\mat{M}_i}
15 | \def\u{\vec{u}}
16 | \def\x{\vec{x}}
17 | \def\y{\vec{y}}
18 | \def\z{\vec{z}}
19 |
20 | \title{Transformer tricks: Removing weights for skipless transformers}
21 | %\title{Transformer tricks: Removing weights from skipless transformers}
22 | %\title{Transformer tricks: Reducing weights for skipless transformers}
23 | %\title{Transformer tricks: Merging linear layers for skipless transformers}
24 | %\title{Transformer tricks: Eliminating linear layers}
25 |
26 | \author{Nils Graef \\ \href{https://openmachine.ai}{OpenMachine},
27 | South San Francisco, CA 94080, \texttt{info@openmachine.ai}}
28 |
29 | \begin{document} \maketitle
30 |
31 | \begin{abstract}
32 | \citet{simplified} detailed a skipless transformer without the V and P (post-attention projection) linear layers, which reduces the total number of weights. However, this scheme is only applicable to MHA (multi-head attention) \cite{vanilla}, but not to MQA (multi-query attention) \cite{MQA} or GQA (grouped-query attention) \cite{GQA}. The latter schemes are used by many popular LLMs such as Llama 2, Mistral, Mixtral, PaLM, and Gemma \cite{Llama2, mistral, mixtral, PaLM, gemma}. Therefore, this micro-paper \cite{micro-paper} proposes mathematically equivalent versions that are suitable for MQA and GQA. For example, removing Q and P from a skipless version of Mistral-7B would remove 15\% of its weights (and thus reduce its compute and memory complexity). See \cite{tricks, precompute} for code and more transformer tricks.
33 | \end{abstract}
34 |
35 | \section{Vanilla transformer without skip connections}
36 |
37 | \citet{skipless} have shown how transformers without skip connections and normalization (see Figure \ref{fig1}(a)) can be trained successfully.
38 |
39 | \begin{figure}[h!] \centering % the [h!] tries to place the picture right here
40 | \includegraphics[scale=0.90]{../doc/fig/removeWeights_fig1.pdf}
41 | \caption{(a) Skipless vanilla transformer; equivalent versions with (b) Q and P merged into the FFN (feedforward network); (c) K and P merged into FFN; (d) V and P merged into FFN. $\M^*, \Q^*, \K^*, \V^*, \O^*$ are defined in table \ref{tab1}.}
42 | \label{fig1} \end{figure}
43 |
44 | Removing skip connections and normalization allows us to merge linear layers in a mathematically identical way as shown in Figures \ref{fig1}(b) to (d). This reduces the number of weights without changing the functionality as follows:
45 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
46 | \item Figure \ref{fig1}(b) is mathematically identical to Figure \ref{fig1}(a) and eliminates $2d^2$ weights per transformer block by merging $\P$ into $\M^*$ and $\Q$ into $\O^*$.
47 | \item For MHA where $e = d$, Figures \ref{fig1}(c) and (d) are mathematically identical to Figure \ref{fig1}(a) and eliminate $2d^2$ weights per transformer block by merging $\P$ into $\M^*$ and $\K$ or $\V$ into $\O^*$.
48 | \item This requires that $\Q, \K, \V$ are invertible (i.e. nonsingular). It is extremely rare that a square matrix with random values is not invertible \cite{invertible} (which requires its determinant to be exactly 0).
49 | \end{itemize}
50 |
51 | Figure \ref{fig1} uses the following dimensions and weight matrices, based on the type of attention:
52 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
53 | \item $d$: embedding dimension
54 | \item $e$: $e = d$ for MHA. For MQA, $e = d / n_{heads}$. And for GQA, $e = d \cdot n_{kv\_heads} / n_{heads}$.
55 | \item $f$: hidden dimension of the FFN. $f = 4d$ in the vanilla transformer; \citet{MQA} uses $f > 4d$. For models that use a GLU variant \cite{GLU} (such as Llama and Mistral), the effective $f'$ for the first linear layer M is $f' = 2f$, because the GLU variant uses two linear layers that are combined (via pointwise multiplication) with a non-linear activation function.
56 | \item $\Q, \K, \V, \P$: The weight matrices of the linear layers for query, keys, values, and the post-attention projection of transformer block $i$.
57 | \item $\M, \mat{O}_i$: The weight matrices of the FFN input and output linear layers.
58 | \end{itemize}
59 |
60 | \begin{figure}[h!] \centering % the [h!] tries to place the picture right here
61 | \includegraphics[scale=0.92]{../doc/fig/removeWeights_fig2.pdf}
62 | \caption{(a) Merging P and M; (b) eliminating Q; (c) eliminating K; (d) eliminating V.}
63 | \label{fig2} \end{figure}
64 |
65 | Figure \ref{fig2} details how the linear layers are merged:
66 | \begin{itemize}[topsep=-1pt, itemsep=-1pt]
67 | \item Figure \ref{fig2}(a) shows how the two linear layers with weight matrices $\P$ and $\M$ are collapsed and replaced by a single linear layer with weight matrix $\M^* = \P \M$, which eliminates $d^2$ weights.
68 | \item Figure \ref{fig2}(b) illustrates how to merge $\Q$ into the preceding $\O$-matrix, which eliminates $d^2$ weights and requires $\Q$ to be invertible. Note that $\y = \u \O (\Q \Q^{-1}) \K = \u \O \K$ and $\z = \u \O (\Q \Q^{-1}) \V = \u \O \V$.
69 | \item For MHA where $e = d$, $\K$ can be removed as shown in Figure \ref{fig2}(c), which eliminates $d^2$ weights. Note that $\x = \u \O (\K \K^{-1}) \Q = \u \O \Q$ and $\z = \u \O (\K \K^{-1}) \V = \u \O \V$. This requires that $\K$ is invertible.
70 | \item For MHA where $e = d$, $\V$ can be removed as shown in Figure \ref{fig2}(d), which eliminates $d^2$ weights. Note that $\x = \u \O (\V \V^{-1}) \Q = \u \O \Q$ and $\y = \u \O (\V \V^{-1}) \K = \u \O \K$. This requires that $\V$ is invertible.
71 | \end{itemize}
72 |
73 | Table \ref{tab1} specifies how the new weight matrices ($\M^*, \Q^*, \K^*, \V^*, \O^*$) of Figure \ref{fig1} are calculated from the original ones. For the first transformer block ($i = 1$), we use the input embedding instead of $\O$ (because there is no $\O$ for $i = 1$).
74 |
75 | \begingroup \begin{table} [h!] \centering % the [h!] tries to place the picture right here
76 | \renewcommand{\arraystretch}{1.3} % increase table row height by 1.3x
77 | \begin{tabular}{c|c|c|c}
78 | & Figure 1(b) & Figure 1(c) & Figure 1(d) \\ \hline
79 | $\O^*$ & $\O \Q$ & $\O \K$ & $\O \V$ \\ \hline
80 | $\Q^*$ & 1 (eliminated) & $\K^{-1} \Q$ & $\V^{-1} \Q$ \\ \hline
81 | $\K^*$ & $\Q^{-1} \K$ & 1 (eliminated) & $\V^{-1} \K$ \\ \hline
82 | $\V^*$ & $\Q^{-1} \V$ & $\K^{-1} \V$ & 1 (eliminated) \\ \hline
83 | $\M^*$ & \multicolumn{3}{c}{$\P \M$}
84 | \end{tabular}
85 | \caption{How to calculate the new weight matrices from the original ones for Figure \ref{fig1}.}
86 | \label{tab1} \end{table} \endgroup
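
The identities of Table \ref{tab1} are easy to check numerically. The toy example below (random matrices, MHA case with $e = d$; a sketch only, not the test code of \cite{tricks}) verifies the column for Figure \ref{fig1}(b):
\begin{verbatim}
import numpy as np
rng = np.random.default_rng(0)
d = 8                                            # toy dimension
O, Q, K, V, P, M = (rng.normal(size=(d, d)) for _ in range(6))
u = rng.normal(size=d)

x, y, z = u @ O @ Q, u @ O @ K, u @ O @ V        # original query, key, and value paths
Os = O @ Q                                       # O* per Table 1, Q* is eliminated
Ks, Vs = np.linalg.inv(Q) @ K, np.linalg.inv(Q) @ V   # K* and V* per Table 1
print(np.allclose(x, u @ Os),
      np.allclose(y, u @ Os @ Ks),
      np.allclose(z, u @ Os @ Vs))               # prints: True True True
print(np.allclose(x @ P @ M, x @ (P @ M)))       # M* = P M merges P into the FFN
\end{verbatim}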
87 |
88 | \section{Parallel transformer without skip connections}
89 | Similar to the parallel transformer \cite{parallel}, Figure \ref{fig3} shows parallel versions of Figures \ref{fig1}(b) to (d). Here, “parallel” refers to having the attention (including its linear layers) in parallel to the FFN.
90 |
91 | \begin{figure} [h!] \centering % the [h!] tries to place the picture right here
92 | \includegraphics[scale=0.92]{../doc/fig/removeWeights_fig3.pdf}
93 | \caption{Parallel skipless transformers (a) without Q and P; (b) without K and P; (c) without V and P.}
94 | \label{fig3} \end{figure}
95 |
96 | Figures \ref{fig3}(b) and (c) require that $e = d$, so they are only suitable for MHA, but not for MQA and GQA. Figure \ref{fig3}(a) is suitable for MHA, MQA, and GQA. Figure \ref{fig3}(c) is identical to the simplified transformer proposed in \cite{simplified}.
97 |
98 | \section{Examples}
99 | The table below lists the configurations and weight counts for Pythia-6.9B and Mistral-7B. For a skipless version of Mistral-7B, we would save 15\% of its weights by merging the Q and P linear layers into the FFN layers. For a batch-1 system that is limited by memory bandwidth, these 15\% weight savings can speed up inference by 1.17x during the autoregressive next-token-generation phase, as shown in the table below.
100 |
101 | \begingroup
102 | \renewcommand{\arraystretch}{1.5} % increase table row height by 1.5x
103 | \begin{center} \begin{tabular}{|l|c|c|l|} \hline
104 | \textbf{Parameter} & \textbf{Pythia-6.9B} & \textbf{Mistral-7B} & \textbf{Notes} \\ \hline
105 | Parallel attention/FFN? & parallel & serial & \cite{parallel} \\ \hline
106 | MHA, MQA, or GQA? & MHA & GQA & \cite{vanilla, MQA, GQA} \\ \hline
107 | \verb+dim+ (aka $d$) & \multicolumn{2}{c|}{4,096} & embedding dimension \\ \hline
108 | \verb+n_layers+ & \multicolumn{2}{c|}{32} & number of layers \\ \hline
109 | \verb+n_heads+ & \multicolumn{2}{c|}{32} & number of heads \\ \hline
110 | \verb+n_kv_heads+ & 32 & 8 & number of KV-heads \\ \hline
111 | \verb+e+ (output dim. of K, V) & 4,096 & 1,024 & \verb+e = d * n_kv_heads / n_heads+ \\ \hline
112 | FFN type & MLP & MLP with SwiGLU & \cite{GLU} \\ \hline
113 | FFN \verb+hidden_dim+ & 16,384 & 14,336 & FFN hidden dimension \\ \hline
114 | \verb+vocab_size+ & 50,400 & 32,000 & vocabulary size \\ \hline
115 |
116 | \multicolumn{4}{|l|}{\textbf{Number of weights (calculated from above parameters):}} \\ \hline
117 | Q+P weights per layer & \multicolumn{2}{c|}{33,554,432} & \verb+2 * dim * dim+ \\ \hline
118 | K+V weights per layer & 33,554,432 & 8,388,608 & \verb+2 * dim * dim / n_heads * n_kv_heads+ \\ \hline
119 | FFN weights per layer & 134,217,728 & 176,160,768 & \verb+(2 or 3) * dim * hidden_dim+ \\ \hline
120 | Input+output embed. & 412,876,800 & 262,144,000 & \verb+2 * dim * vocab_size+ \\ \hline
121 | \multicolumn{1}{|r|}{\textbf{Total weights:}} & 6.9B & 7.2B & \\ \hline
122 |
123 | \multicolumn{4}{|l|}{\textbf{Weight savings and speedup after removing Q and P:}} \\ \hline
124 | Total w/o Q+P weights: & 5.8B & 6.2B & total after removing Q and P \\ \hline
125 | \multicolumn{1}{|r|}{\textbf{Weight savings:}} & \textbf{16\%} & \textbf{15\%} & \\ \hline
126 | \multicolumn{1}{|r|}{\textbf{Possible speedup:}} & \textbf{1.19x} & \textbf{1.17x} & assumes batch size 1 \\ \hline
127 | \end{tabular} \end{center}
128 | \endgroup
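
The Mistral-7B savings and speedup above can be reproduced with a short calculation (normalization weights are ignored, as in the table):
\begin{verbatim}
d, e, f, vocab, layers = 4096, 1024, 14336, 32000, 32
qp, kv, ffn, emb = 2*d*d, 2*d*e, 3*d*f, 2*d*vocab
total = layers*(qp + kv + ffn) + emb        # about 7.2B weights
slim  = total - layers*qp                   # about 6.2B weights after removing Q and P
print(f'{layers*qp/total:.0%} savings, {total/slim:.2f}x speedup')   # 15% savings, 1.17x
\end{verbatim}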
129 |
130 | \section{Experiments}
131 | Refer to \cite{tricks} for Python code that demonstrates the numerical equivalency of the weight reduction illustrated in Figures \ref{fig1}(b) and \ref{fig2}(b). The code also confirms that all square matrices of Mistral-7B are invertible.
132 |
133 | \section{Future work}
134 | Because skipless transformers are not very popular right now, future work should investigate whether removing P and Q (or K or V) is also beneficial for transformers with normalization and skip connections as illustrated in Figure \ref{fig4}. Adding normalization and skip connections again could simplify and speed up training relative to skipless transformers.
135 |
136 | \begin{figure} \centering
137 | \includegraphics[scale=0.92]{../doc/fig/removeWeights_fig4.pdf}
138 | \caption{(a) Transformer block without Q and P; (b) version with parallel attention / FFN.}
139 | \label{fig4} \end{figure}
140 |
141 | \section*{Acknowledgments}
142 | We would like to thank \href{https://scholar.google.com/citations?user=HKft_LAAAAAJ&hl=en}{Bobby He (ETH Zürich)} and \href{https://scholar.google.com/citations?user=LlK_saMAAAAJ&hl=en}{James Martens (DeepMind)} for helpful discussions on this work.
143 |
144 | \bibliographystyle{unsrtnat}
145 | \bibliography{references}
146 |
147 | \end{document}
148 |
--------------------------------------------------------------------------------
/tex/run:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # script to convert foo.tex to PDF
4 | # usage: ./run foo.tex or ./run foo
5 | # above generates foo.pdf and other files
6 |
7 | # remove filename extension from arg
8 | file="${1%.*}"
9 |
10 | ./clean
11 | pdflatex "$file"
12 | bibtex "$file"
13 | pdflatex "$file"
14 | pdflatex "$file"
15 | pdflatex "$file" # we sometimes need to run pdflatex 3 times
16 |
17 | # in case you want to diff your changes visually
18 | diff-pdf --view "$file".pdf ../doc/"$file".pdf
19 |
20 | mv "$file".pdf ../doc
21 |
22 | echo "--------------------------------------------------------------------------------"
23 | grep Warning "$file".log
24 |
25 | #./clean
26 |
27 | # note: to diff 2 PDF files visually, type the following:
28 | # diff-pdf --view file1.pdf file2.pdf
29 | # alternatively, use pdftotext to convert each PDF to text and then diff the text files:
30 | # pdftotext file1.pdf # this generates file1.txt
31 | # pdftotext file2.pdf # this generates file2.txt
32 | # diff file1.txt file2.txt
33 |
--------------------------------------------------------------------------------
/tex/submit:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # script to submit foo.tex to arXiv
4 | # usage: ./submit foo.tex
5 | # above generates a directory foo_submit and a tar gz file foo_submit.tar.gz
6 | # only upload this tar file to arXiv and see the notes in README on how to submit
7 |
8 | # note: to double-check if everything works, run pdflatex foo two times
9 | # (or sometimes three times) as follows:
10 | # cd foo_submit
11 | # pdflatex foo && pdflatex foo
12 |
13 | # remove filename extension from arg
14 | file="${1%.*}"
15 |
16 | DIR="$file"_submit
17 |
18 | rm -Rf "$DIR" "$DIR".tar.gz
19 | mkdir "$DIR"
20 |
21 | ./run "$file"
22 |
23 | cp *.sty references.bib "$file".bbl "$file".tex "$DIR"
24 | cp ../doc/fig/"$file"_fig*.pdf "$DIR"
25 |
26 | # modify the figure-paths in the tex file: ../doc/fig/ -> ./
27 | sed -i "" "s,../doc/fig/,./,g" "$DIR"/"$file".tex
28 | # the two quotes (empty string) at the beginning of sed are for running on mac
29 |
30 | # TODO: it might also work without the sed command, but only if the tar.gz
31 | # archive is flat and has only one directory to the tex file (right now the
32 | # archive includes the directory 'foo_submit')
33 |
34 | tar -czvf "$DIR".tar.gz "$DIR"
35 |
--------------------------------------------------------------------------------
/transformer_tricks.py:
--------------------------------------------------------------------------------
1 | # tricks and tools for speeding up LLMs
2 |
3 | import gc, os, time, torch, datasets, glob
4 | import torch.nn as nn
5 | from tqdm import tqdm
6 | from huggingface_hub import snapshot_download, repo_exists
7 | from safetensors.torch import load_file, save_file, safe_open
8 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, logging, utils
9 | try:
10 | from flashNorm_modeling_llama import * # import local file if it exists
11 | except ImportError:
12 | pass
13 |
14 |
15 | #-------------------------------------------------------------------------------------
16 | # tools for working with safetensors and HuggingFace repos
17 | #-------------------------------------------------------------------------------------
18 | def quiet_hf():
19 | """reduce verbosity of HuggingFace"""
20 | logging.set_verbosity_error()
21 | utils.logging.disable_progress_bar()
22 | datasets.logging.disable_progress_bar()
23 | os.environ['TOKENIZERS_PARALLELISM'] = 'true'
24 | os.environ['HF_HUB_VERBOSITY'] = 'error'
25 | # for more env variables, see link below
26 | # https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables
27 |
28 |
29 | def weight(name, layer=0):
30 | """get dictionary key of specific weight (such as Q from layer 0)"""
31 | layer_str = 'model.layers.' + str(layer) + '.'
32 | match name:
33 | # weights of each layer
34 | case 'Inorm': key = layer_str + 'input_layernorm.weight'
35 | case 'Anorm': key = layer_str + 'post_attention_layernorm.weight'
36 | case 'QKV' : key = layer_str + 'self_attn.qkv_proj.weight'
37 | case 'Q' : key = layer_str + 'self_attn.q_proj.weight'
38 | case 'K' : key = layer_str + 'self_attn.k_proj.weight'
39 | case 'V' : key = layer_str + 'self_attn.v_proj.weight'
40 | case 'O' : key = layer_str + 'self_attn.o_proj.weight'
41 | case 'GU' : key = layer_str + 'mlp.gate_up_proj.weight'
42 | case 'G' : key = layer_str + 'mlp.gate_proj.weight'
43 | case 'U' : key = layer_str + 'mlp.up_proj.weight'
44 | case 'D' : key = layer_str + 'mlp.down_proj.weight'
45 | # embedding weights
46 | case 'Hnorm': key = 'model.norm.weight' # normalization of lm_head
47 | case 'H' : key = 'lm_head.weight' # output embeddings
48 | case 'E' : key = 'model.embed_tokens.weight' # input embeddings
49 | return key
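  # e.g. weight('Q', 3) returns 'model.layers.3.self_attn.q_proj.weight'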
50 |
51 |
52 | def get_param(repo, get_meta=False):
53 | """download all *.safetensors files from repo (or local dir) and return a single
54 | param dict, and optionally also return the metadata"""
55 |
56 | # download and get list of files
57 | if repo_exists(repo):
58 | dir = 'get_param_tmp'
59 | snapshot_download(repo_id=repo, allow_patterns='*.safetensors', local_dir=dir)
60 | else: # if repo doesn't exist on HuggingFace, then 'repo' specifies local dir
61 | dir = repo
62 | files = glob.glob(dir + '/*.safetensors')
63 |
64 | # get parameters
65 | param = {}
66 | for file in files:
67 | param.update(load_file(file)) # concatenate all parameters into a single dict
68 |
69 | # return param only, or param and metadata
70 | if get_meta == False:
71 | return param
72 | else:
73 | with safe_open(files[0], framework='pt') as f: # use the first file
74 | return param, f.metadata()
75 |
76 |
77 | def save_repo(repo, param, config, dir):
78 | """save tokenizer, config, and param in local dir"""
79 | tok = AutoTokenizer.from_pretrained(repo)
80 | tok.save_pretrained(dir, from_pt=True)
81 | config.save_pretrained(dir, from_pt=True)
82 | save_file(param, dir + '/model.safetensors', metadata={'format': 'pt'})
83 |
84 |
85 | #-------------------------------------------------------------------------------------
86 | # functions for flashNorm, see paper https://arxiv.org/abs/2407.09577
87 | #-------------------------------------------------------------------------------------
88 | def merge_norm_proj(param, norm, proj, layer=0):
89 | """merge norm weights into projection weights"""
90 | n_key = weight(norm, layer)
91 | p_key = weight(proj, layer)
92 | param[p_key] = nn.Parameter(param[p_key] @ torch.diag(param[n_key])).detach() # flipped order
93 | # TODO: consider first converting to float64, then merge norm into projections,
94 | # and then convert back to float32. Example: torch.ones(4, dtype=torch.float32)
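  # Why this is equivalent: nn.Linear computes y = x @ W.T and RMSNorm with weight g
  # returns (x / rms(x)) * g. Since (x * g) @ W.T == x @ (W @ torch.diag(g)).T, folding
  # diag(g) into W (and later setting g to all-ones via set_norm_one) leaves y unchanged.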
95 |
96 |
97 | def set_norm_one(param, norm, layer=0):
98 | """set all norm weights to 1.0"""
99 | n_key = weight(norm, layer)
100 |   size = param[n_key].shape[0]
101 |   param[n_key] = nn.Parameter(torch.ones(size)).detach()
102 |
103 |
104 | def flashify(param, config, bars):
105 | """merge norm weights into projection weights as per flashNorm"""
106 | with torch.no_grad(): # prevent autograd from tracking changes
107 |
108 | # check if model uses fused projections (such as in Phi-3)
109 | fused_proj = weight('QKV') in param
110 |
111 | # perform flashNorm merging for each layer
112 | for layer in tqdm(range(config.num_hidden_layers), disable=not bars):
113 |
114 | # merge input-layernorm into QKV projections
115 | if fused_proj:
116 | merge_norm_proj(param, 'Inorm', 'QKV', layer)
117 | else:
118 | merge_norm_proj(param, 'Inorm', 'Q', layer)
119 | merge_norm_proj(param, 'Inorm', 'K', layer)
120 | merge_norm_proj(param, 'Inorm', 'V', layer)
121 | set_norm_one(param, 'Inorm', layer)
122 |
123 | # merge post-attention layernorm 'Anorm' into Gate and Up projections
124 | if fused_proj:
125 | merge_norm_proj(param, 'Anorm', 'GU', layer)
126 | else:
127 | merge_norm_proj(param, 'Anorm', 'G', layer)
128 | merge_norm_proj(param, 'Anorm', 'U', layer)
129 | set_norm_one(param, 'Anorm', layer)
130 |
131 | # if the model has untied embeddings, then merge 'Hnorm' into 'lm_head'
132 | # see also https://huggingface.co/HuggingFaceTB/SmolLM-135M/discussions/15
133 | if config.tie_word_embeddings == False:
134 | merge_norm_proj(param, 'Hnorm', 'H')
135 | set_norm_one(param, 'Hnorm')
136 |
137 |
138 | def flashify_repo(repo, dir=None, bars=False, test=True):
139 | """convert LLM repo to flashNorm, store the new model in local dir"""
140 | with torch.no_grad(): # prevent autograd from tracking changes
141 |
142 | if dir == None: # append '_flashNorm' if no output dir is defined
143 | dir = os.path.basename(repo) + '_flashNorm'
144 |
145 | # get config, download safetensors, and flashify params
146 | config = AutoConfig.from_pretrained(repo)
147 | param = get_param(repo)
148 | flashify(param, config, bars)
149 | if test: # optionally, save a test-repo in directory *_test
150 | save_repo(repo, param, config, dir + '_test')
151 |
152 | # delete norm weights from param
153 | for layer in range(config.num_hidden_layers):
154 | del param[weight('Inorm', layer)]
155 | del param[weight('Anorm', layer)]
156 | if config.tie_word_embeddings == False:
157 | del param[weight('Hnorm')]
158 |
159 | # TODO:
160 | #config.architectures = ['LlamaForCausalLM_flashNorm']
161 | #config.auto_map = {'AutoModelForCausalLM': 'flashNorm_modeling_llama.LlamaForCausalLM_flashNorm'}
162 | #config.model_type = 'flashNorm'
163 | save_repo(repo, param, config, dir)
164 |
165 | del param; gc.collect() # run garbage collection
166 |
167 |
168 | #-------------------------------------------------------------------------------------
169 | # functions for testing
170 | #-------------------------------------------------------------------------------------
171 | def hello_world(repo, max_new_tok=4, arch='AutoModelForCausalLM', perf=False):
172 | """run example inference of an LLM from HuggingFace repo or local directory"""
173 | tok = AutoTokenizer.from_pretrained(repo)
174 | model = eval(f'{arch}.from_pretrained(repo, low_cpu_mem_usage=True)')
175 |   # to use FP16 or bfloat16: torch_dtype=torch.float16 or torch_dtype=torch.bfloat16
176 | # note: FP16 is 30x slower than FP32 on my Mac M1, not sure why
177 |
178 | prompt = 'Once upon a time there was'
179 | start_time = time.perf_counter()
180 | inp = tok.encode(prompt, return_tensors='pt').to('cpu')
181 | out = model.generate(inp, pad_token_id=0, max_new_tokens=max_new_tok).ravel()
182 | print(tok.decode(out),
183 | f' (time: {time.perf_counter() - start_time:.2f}s)' if perf else '')
184 | del tok, model; gc.collect() # run garbage collection
185 | # TODO: especially for Phi-3, set verbosity to quiet as follows
186 | # transformers.logging.set_verbosity_error()
187 |
188 |
189 | def perplexity(repo, speedup=1, arch='AutoModelForCausalLM', bars=False, perf=False):
190 | """calculate perplexity of an LLM with wikitext2
191 | this def is copied from https://huggingface.co/docs/transformers/perplexity
192 | I made the following changes to adapt it for SmolLM (was GPT2 before):
193 | - changed model and tokenizer
194 |   - changed 'from transformers import' to point to 'Auto*' (was 'GPT2*' before)
195 | - changed 'max_length' to 'config.max_position_embeddings'
196 | - changed 'device' from 'cuda' to 'cpu'
197 | - changed 'stride' to be 'max_length' (was 512 or 'max_length//2' before)
198 | - removed 'with torch.no_grad()' and added global 'torch.set_grad_enabled(False)'
199 | Perhaps a simpler and cleaner way is given here:
200 | https://huggingface.co/spaces/evaluate-metric/perplexity"""
201 |
202 | torch.set_grad_enabled(False) # speed up torch
203 | # TODO: consider using instead 'with torch.no_grad():'
204 |
205 | tok = AutoTokenizer.from_pretrained(repo)
206 | model = eval(f'{arch}.from_pretrained(repo, low_cpu_mem_usage=True)')
207 |
208 | # tokenize wikitext2
209 | test = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
210 | encodings = tok('\n\n'.join(test['text']), return_tensors='pt')
211 | del tok; gc.collect() # run garbage collection
212 |
213 | max_length = model.config.max_position_embeddings
214 | stride = max_length # before it was 512 or max_length // 2
215 | seq_len = encodings.input_ids.size(1) // speedup
216 |
217 | start_time = time.perf_counter()
218 | nlls = []
219 | prev_end_loc = 0
220 | for begin_loc in tqdm(range(0, seq_len, stride), disable=not bars):
221 | end_loc = min(begin_loc + max_length, seq_len)
222 | trg_len = end_loc - prev_end_loc # may be different from stride on last loop
223 | input_ids = encodings.input_ids[:, begin_loc:end_loc].to('cpu')
224 | target_ids = input_ids.clone()
225 | target_ids[:, :-trg_len] = -100
226 | outputs = model(input_ids, labels=target_ids)
227 |
228 | # loss is calculated using CrossEntropyLoss which averages over valid labels
229 | # N.B. the model only calculates loss over trg_len - 1 labels, because it
230 | # internally shifts the labels to the left by 1.
231 | neg_log_likelihood = outputs.loss
232 | nlls.append(neg_log_likelihood)
233 |
234 | prev_end_loc = end_loc
235 | if end_loc == seq_len:
236 | break
237 |
238 | ppl = torch.exp(torch.stack(nlls).mean())
239 | print(f'perplexity = {ppl:.3f}',
240 | f' (time: {time.perf_counter() - start_time:.2f}s)' if perf else '')
241 | # print('nlls:', nlls)
242 | del model; gc.collect() # run garbage collection
243 |
244 |
245 | #-------------------------------------------------------------------------------------
246 | # debug tools
247 | #-------------------------------------------------------------------------------------
248 | def diff_safetensors(repo1, repo2):
249 | """compare differences of safetensor file(s) between repo1 and repo2"""
250 | param1, meta1 = get_param(repo1, get_meta=True)
251 | param2, meta2 = get_param(repo2, get_meta=True)
252 | set1, set2 = set(param1.keys()), set(param2.keys())
253 |
254 | # diff keys
255 | if set1 == set2:
256 | print('>>> SAFE-DIFF: both repos have the same safetensor keys')
257 | else:
258 | if set1 - set2:
259 | print(f'>>> SAFE-DIFF: these keys are only in repo {repo1}: {set1 - set2}')
260 | if set2 - set1:
261 | print(f'>>> SAFE-DIFF: these keys are only in repo {repo2}: {set2 - set1}')
262 |
263 | # diff tensors
264 | found_diff = False
265 | for key in set1.intersection(set2):
266 | if not torch.equal(param1[key], param2[key]):
267 | found_diff = True
268 | print(f'>>> SAFE-DIFF: tensors {key} are not equal')
269 | if not found_diff:
270 | print('>>> SAFE-DIFF: all intersecting tensors are equal')
271 |
272 | # diff metadata
273 | if meta1 == meta2:
274 | print('>>> SAFE-DIFF: both repos have the same safetensor metadata')
275 | else:
276 | print(f'>>> SAFE-DIFF: metadata of repo {repo1}: {meta1}')
277 | print(f'>>> SAFE-DIFF: metadata of repo {repo2}: {meta2}')
278 |
279 |
280 | # misc TODOs:
281 | # - do we really need 'with torch.no_grad():' everywhere?
282 | # - do we really need garbage collection 'gc'?
283 | # - would 'torch.set_grad_enabled(False)' speed up things?
284 |
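# Example usage (a rough sketch, not the repo's example script; the model repo
# 'HuggingFaceTB/SmolLM-135M' and the speedup value are just illustrations):
if __name__ == '__main__':
  quiet_hf()
  flashify_repo('HuggingFaceTB/SmolLM-135M')       # writes ./SmolLM-135M_flashNorm
  hello_world('SmolLM-135M_flashNorm')             # quick generation sanity check
  perplexity('SmolLM-135M_flashNorm', speedup=16)  # fast perplexity estimate on wikitext2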
--------------------------------------------------------------------------------
/util/.aspell:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMachine-ai/transformer-tricks/490f7a07f6b11f0b15105acadfa26be927150cad/util/.aspell
--------------------------------------------------------------------------------
/util/clean_all:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # cleans up the entire repo by removing generated files
4 | # usage: util/clean_all (run from the root dir of this repo)
5 |
6 | # clean up ./util
7 | cd util
8 | rm -rf pypi
9 | cd -
10 |
11 | # clean up ./tex
12 | cd tex
13 | ./clean
14 | \rm -rf *_submit *_submit.tar.gz
15 | cd -
16 |
--------------------------------------------------------------------------------
/util/gen_notebooks:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # generate Jupyter notebooks from python
4 | # usage: util/gen_notebooks (run from the root dir of this repo)
5 |
6 | for fname in slimAttn_paper flashNorm_example; do
7 | jupytext "$fname".py -o notebooks/"$fname".ipynb
8 | done
9 |
--------------------------------------------------------------------------------
/util/gen_pdf:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # generate PDF from tex for all files
4 | # usage: util/gen_pdf (run from the root dir of this repo)
5 |
6 | cd tex
7 | for fname in *.tex; do
8 | ./run "$fname"
9 | done
10 |
11 | ./clean
12 | cd -
13 |
--------------------------------------------------------------------------------
/util/push_pypi:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Make sure to increment the version number in pyproject.toml and
4 | # requirements.txt before running this script! See below link for details:
5 | # https://packaging.python.org/en/latest/tutorials/packaging-projects/
6 | #
7 | # Setup: install build and twine via 'pip3 install build twine'
8 | # To upgrade: pip3 install --upgrade build twine pkginfo packaging
9 | # Usage: ./push_pypi
10 |
11 | # create folder 'pypi' and copy all relevant files
12 | rm -rf pypi
13 | mkdir pypi
14 | cp ../LICENSE ../README.md ../pyproject.toml ../transformer_tricks.py pypi
15 |
16 | # build and upload
17 | cd pypi
18 | python3 -m build
19 | python3 -m twine upload dist/*
20 |
21 | #rm -rf pypi
22 |
--------------------------------------------------------------------------------
/util/spell_check:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # spell check all tex and markdown files
4 | # usage: util/spell_check (run from the root dir of this repo)
5 |
6 | # notes:
7 | # - option -M is for markdown; -t is for tex
8 | # - file util/.aspell is our personal dictionary
9 | # - however, aspell seems to have a bug, the personal dictionary file
10 | # must be located in the same dir from which you call aspell, that's
11 | # why we first do 'cd util'
12 |
13 | cd util
14 |
15 | # all markdown files
16 | for file in ../*.md ../*/*.md; do
17 | aspell -d en_US -l en_US --personal=./.aspell -M -c "$file"
18 | done
19 |
20 | # all tex files
21 | for file in ../tex/*.tex; do
22 | aspell -d en_US -l en_US --personal=./.aspell -t -c "$file"
23 | done
24 |
25 | cd -
26 |
--------------------------------------------------------------------------------