├── .github └── workflows │ ├── build.yml │ ├── python-publish.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── tests ├── test_activations.py ├── test_ff.py ├── test_grouped_ff.py ├── test_grouped_mlp.py └── test_mlp.py └── x_mlps_pytorch ├── __init__.py ├── activations.py ├── ff.py ├── grouped_ff.py ├── grouped_mlp.py ├── mlp.py └── nff.py /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build with Rye 2 | on: push 3 | 4 | jobs: 5 | build: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v4 9 | - name: Install Python 10 | uses: actions/setup-python@v4 11 | - name: Install the latest version of rye 12 | uses: eifinger/setup-rye@v2 13 | - name: Build with Rye 14 | run: rye build 15 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests the examples in README 2 | on: push 3 | 4 | jobs: 5 | test: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v4 9 | - name: Install Python 10 | uses: actions/setup-python@v4 11 | - name: Install the latest version of rye 12 | uses: eifinger/setup-rye@v2 13 | - name: Use UV instead of pip 14 | run: rye config --set-bool behavior.use-uv=true 15 | - name: Install dependencies 16 | run: | 17 | pip install .[test] 18 | - name: Run pytest 19 | run: pytest tests 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as 
to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## x-mlps-pytorch 2 | 3 | Just a repository that will house MLPs for Pytorch, from garden variety to the exotic, so as to avoid having to reimplement them again and again for different projects (especially RL) 4 | 5 | Will also be the repository I use for testing out [Jules](https://jules.google.com/) and other AI assisted tools. 6 | 7 | 8 | ## Install 9 | 10 | ```bash 11 | $ pip install x-mlps-pytorch 12 | ``` 13 | 14 | ## Usage 15 | 16 | ```python 17 | import torch 18 | from x_mlps_pytorch import MLP 19 | 20 | actor = MLP(10, 16, 5) 21 | 22 | critic = MLP(10, 32, 16, 1) 23 | 24 | state = torch.randn(10) 25 | 26 | action_logits = actor(state) # (5,) 27 | 28 | values = critic(state) # (1,) 29 | ``` 30 | 31 | ## Citations 32 | 33 | ```bibtex 34 | @article{So2021PrimerSF, 35 | title = {Primer: Searching for Efficient Transformers for Language Modeling}, 36 | author = {David R. So and Wojciech Ma'nke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. 
Le}, 37 | journal = {ArXiv}, 38 | year = {2021}, 39 | volume = {abs/2109.08668}, 40 | url = {https://api.semanticscholar.org/CorpusID:237563187} 41 | } 42 | ``` 43 | 44 | ```bibtex 45 | @article{Zhang2024ReLU2WD, 46 | title = {ReLU2 Wins: Discovering Efficient Activation Functions for Sparse LLMs}, 47 | author = {Zhengyan Zhang and Yixin Song and Guanghui Yu and Xu Han and Yankai Lin and Chaojun Xiao and Chenyang Song and Zhiyuan Liu and Zeyu Mi and Maosong Sun}, 48 | journal = {ArXiv}, 49 | year = {2024}, 50 | volume = {abs/2402.03804}, 51 | url = {https://api.semanticscholar.org/CorpusID:267499856} 52 | } 53 | ``` 54 | 55 | ```bibtex 56 | @inproceedings{Horuz2025TheRO, 57 | title = {The Resurrection of the ReLU}, 58 | author = {Cocsku Can Horuz and Geoffrey Kasenbacher and Saya Higuchi and Sebastian Kairat and Jendrik Stoltz and Moritz Pesl and Bernhard A. Moser and Christoph Linse and Thomas Martinetz and Sebastian Otte}, 59 | year = {2025}, 60 | url = {https://api.semanticscholar.org/CorpusID:278959515} 61 | } 62 | ``` 63 | 64 | ```bibtex 65 | @article{Loshchilov2024nGPTNT, 66 | title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere}, 67 | author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg}, 68 | journal = {ArXiv}, 69 | year = {2024}, 70 | volume = {abs/2410.01131}, 71 | url = {https://api.semanticscholar.org/CorpusID:273026160} 72 | } 73 | ``` 74 | 75 | ```bibtex 76 | @article{Lee2025HypersphericalNF, 77 | title = {Hyperspherical Normalization for Scalable Deep Reinforcement Learning}, 78 | author = {Hojoon Lee and Youngdo Lee and Takuma Seno and Donghu Kim and Peter Stone and Jaegul Choo}, 79 | journal = {ArXiv}, 80 | year = {2025}, 81 | volume = {abs/2502.15280}, 82 | url = {https://api.semanticscholar.org/CorpusID:276558261} 83 | } 84 | ``` 85 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "x-mlps-pytorch" 3 | version = "0.0.20" 4 | description = "A collection of MLPs / Feedforwards for Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.8" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'mlps', 15 | 'feedforwards' 16 | ] 17 | classifiers=[ 18 | 'Development Status :: 4 - Beta', 19 | 'Intended Audience :: Developers', 20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python :: 3.8', 23 | ] 24 | 25 | dependencies = [ 26 | 'einops>=0.8.0', 27 | 'torch>=2.4', 28 | ] 29 | 30 | [project.urls] 31 | Homepage = "https://pypi.org/project/x-mlps-pytorch/" 32 | Repository = "https://github.com/lucidrains/x-mlps" 33 | 34 | [build-system] 35 | requires = ["hatchling"] 36 | build-backend = "hatchling.build" 37 | 38 | [project.optional-dependencies] 39 | test = [ 40 | "pytest", 41 | ] 42 | 43 | [tool.pytest.ini_options] 44 | pythonpath = [ 45 | "." 
46 | ] 47 | 48 | [tool.hatch.metadata] 49 | allow-direct-references = true 50 | 51 | [tool.hatch.build.targets.wheel] 52 | packages = ["x_mlps_pytorch"] 53 | -------------------------------------------------------------------------------- /tests/test_activations.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | 4 | import torch 5 | from x_mlps_pytorch.activations import ReluNelu 6 | 7 | def test_relu_nelu(): 8 | inp = torch.randn(3) 9 | out = ReluNelu(0.01)(inp) 10 | 11 | assert inp.shape == out.shape 12 | -------------------------------------------------------------------------------- /tests/test_ff.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | def test_ff(): 5 | from x_mlps_pytorch.ff import Feedforwards 6 | 7 | ff = Feedforwards(256, 4, dim_in = 128, dim_out = 128) 8 | 9 | x = torch.randn(7, 3, 128) 10 | 11 | assert ff(x).shape == x.shape 12 | 13 | @pytest.mark.parametrize('preserve_magnitude', (False, True)) 14 | def test_nff( 15 | preserve_magnitude 16 | ): 17 | from x_mlps_pytorch.nff import nFeedforwards, norm_weights_ 18 | 19 | ff = nFeedforwards(256, 4, input_preserve_magnitude = preserve_magnitude) 20 | 21 | x = torch.randn(7, 3, 256) 22 | 23 | assert ff(x).shape == x.shape 24 | 25 | norm_weights_(ff) 26 | -------------------------------------------------------------------------------- /tests/test_grouped_ff.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | def test_grouped_ff_one_group(): 5 | from x_mlps_pytorch.grouped_ff import GroupedFeedforwards 6 | 7 | ff = GroupedFeedforwards(256, 4, dim_in = 128, dim_out = 128, squeeze_if_one_group = True) 8 | 9 | x = torch.randn(7, 3, 128) 10 | 11 | assert ff(x).shape == x.shape 12 | 13 | 14 | def test_grouped_ff_two_groups(): 15 | from x_mlps_pytorch.grouped_ff import GroupedFeedforwards 16 | 17 | ff = GroupedFeedforwards(256, 4, dim_in = 128, dim_out = 128, squeeze_if_one_group = True, groups = 2) 18 | 19 | x = torch.randn(7, 3, 128) 20 | 21 | assert ff(x).shape == (7, 3, 2, 128) 22 | -------------------------------------------------------------------------------- /tests/test_grouped_mlp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | @pytest.mark.parametrize('groups', (1, 4)) 5 | def test_mlp(groups): 6 | from x_mlps_pytorch.grouped_mlp import GroupedMLP 7 | 8 | mlp = GroupedMLP(256, 128, 64, groups = groups) 9 | 10 | x = torch.randn(7, 3, 256) 11 | 12 | assert mlp(x).shape == (7, 3, groups, 64) 13 | 14 | # with depth 15 | 16 | @pytest.mark.parametrize('groups', (1, 4)) 17 | def test_create_mlp(groups): 18 | from x_mlps_pytorch.grouped_mlp import create_grouped_mlp 19 | 20 | mlp = create_grouped_mlp( 21 | dim = 128, 22 | dim_in = 256, 23 | dim_out = 64, 24 | depth = 4, 25 | groups = groups 26 | ) 27 | 28 | # same as GroupedMLP(256, 128, 128, 128, 128, 64) 29 | 30 | x = torch.randn(7, 3, 256) 31 | 32 | assert mlp(x).shape == (7, 3, groups, 64) 33 | 34 | # test auto squeeze 1 group, so it can act as regular MLP 35 | 36 | def test_squeeze(): 37 | from x_mlps_pytorch.grouped_mlp import GroupedMLP 38 | 39 | mlp = GroupedMLP(256, 128, 64, squeeze_if_one_group = True) 40 | 41 | x = torch.randn(7, 3, 256) 42 | 43 | assert mlp(x).shape == (7, 3, 64) 44 | --------------------------------------------------------------------------------
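The grouped tests above pin down the shape contract: with `groups = g`, `GroupedMLP` (and `GroupedFeedforwards`) maps an input of shape `(..., dim_in)` to `(..., g, dim_out)`, repeating an ungrouped input across the groups, and `squeeze_if_one_group = True` drops the group axis when `g == 1`. A minimal sketch of that contract outside of pytest; the ensemble-of-critics framing and the names below are only illustrative, not something the repository prescribes:

```python
import torch
from x_mlps_pytorch.grouped_mlp import GroupedMLP

# hypothetical ensemble of 4 independent value heads, computed in one grouped forward pass
ensemble_critic = GroupedMLP(10, 64, 1, groups = 4)

state = torch.randn(2, 10)        # (batch, state_dim)
values = ensemble_critic(state)   # (batch, groups, 1) - one value per ensemble member

assert values.shape == (2, 4, 1)

# with one group and squeeze_if_one_group, the group axis is dropped and it acts like a plain MLP
single_critic = GroupedMLP(10, 64, 1, squeeze_if_one_group = True)

assert single_critic(state).shape == (2, 1)
```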
/tests/test_mlp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | def test_mlp(): 5 | from x_mlps_pytorch.mlp import MLP 6 | 7 | mlp = MLP(256, 128, 64) 8 | 9 | x = torch.randn(7, 3, 256) 10 | 11 | assert mlp(x).shape == (7, 3, 64) 12 | 13 | # with depth 14 | 15 | def test_create_mlp(): 16 | from x_mlps_pytorch.mlp import create_mlp 17 | 18 | mlp = create_mlp( 19 | dim = 128, 20 | dim_in = 256, 21 | dim_out = 64, 22 | depth = 4 23 | ) 24 | 25 | # same as MLP(256, 128, 128, 128, 128, 64) 26 | 27 | x = torch.randn(7, 3, 256) 28 | 29 | assert mlp(x).shape == (7, 3, 64) 30 | -------------------------------------------------------------------------------- /x_mlps_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from x_mlps_pytorch.mlp import ( 2 | MLP 3 | ) 4 | 5 | from x_mlps_pytorch.ff import ( 6 | Feedforwards 7 | ) 8 | 9 | from x_mlps_pytorch.nff import ( 10 | nFeedforwards 11 | ) -------------------------------------------------------------------------------- /x_mlps_pytorch/activations.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch.nn import Module, ReLU 5 | 6 | # relu squared with optional signing 7 | 8 | class ReluSquared(Module): 9 | def __init__(self, signed = False): 10 | super().__init__() 11 | self.signed = signed 12 | 13 | def forward(self, x): 14 | out = x.relu().square() 15 | 16 | if not self.signed: 17 | return out 18 | 19 | return out * x.sign() 20 | 21 | # sugar-(bsilu | nelu) 22 | 23 | class BSiLU(Module): 24 | # eq (7) in paper 25 | 26 | def __init__(self, alpha = 1.67): 27 | super().__init__() 28 | self.alpha = alpha 29 | 30 | def forward(self, x): 31 | α = self.alpha 32 | return (x + α) * x.sigmoid() - α / 2 33 | 34 | class NeLU(Module): 35 | def __init__(self, alpha = 0.05): 36 | super().__init__() 37 | self.alpha = alpha 38 | 39 | def forward(self, x): 40 | α = self.alpha 41 | return -α / (1. 
+ x.square()) 42 | 43 | class StraightThrough(Module): 44 | def __init__( 45 | self, 46 | forward_fn: Module, 47 | backward_fn: Module 48 | ): 49 | super().__init__() 50 | self.forward_fn = forward_fn 51 | self.backward_fn = backward_fn 52 | 53 | def forward(self, x): 54 | hard = self.forward_fn(x) 55 | 56 | if not x.requires_grad: 57 | return hard 58 | 59 | soft = self.backward_fn(x) 60 | 61 | # straight-through during training 62 | 63 | return soft + (hard - soft).detach() 64 | 65 | class Sugar(Module): 66 | def __init__( 67 | self, 68 | forward_fn: Module, 69 | backward_fn: Module 70 | ): 71 | super().__init__() 72 | self.forward_fn = forward_fn 73 | self.backward_fn = backward_fn 74 | 75 | def forward(self, x): 76 | forward_out = self.forward_fn(x) 77 | 78 | if not x.requires_grad: 79 | return forward_out 80 | 81 | backward_out = self.backward_fn(x) 82 | 83 | # only neg region for backward function gradients 84 | 85 | soft = torch.where(x > 0, forward_out, backward_out) 86 | 87 | # straight-through during training 88 | 89 | return soft + (forward_out - soft).detach() 90 | 91 | # the one that beat gelu in transformer setting for me 92 | 93 | def ReluNelu(alpha = 0.05): 94 | return Sugar(ReLU(), NeLU(alpha)) 95 | -------------------------------------------------------------------------------- /x_mlps_pytorch/ff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module, ModuleList 4 | 5 | # functions 6 | 7 | def exists(v): 8 | return v is not None 9 | 10 | # main class 11 | 12 | class Feedforwards(Module): 13 | 14 | def __init__( 15 | self, 16 | dim, 17 | depth, 18 | *, 19 | dim_in = None, 20 | dim_out = None, 21 | activation = nn.GELU(), 22 | bias = True, 23 | expansion_factor = 4., 24 | final_norm = False 25 | ): 26 | super().__init__() 27 | 28 | layers = [] 29 | 30 | dim_hidden = int(dim * expansion_factor) 31 | 32 | # layers 33 | 34 | for _ in range(depth): 35 | 36 | layer = nn.Sequential( 37 | nn.RMSNorm(dim), 38 | nn.Linear(dim, dim_hidden, bias = bias), 39 | activation, 40 | nn.Linear(dim_hidden, dim, bias = bias) 41 | ) 42 | 43 | layers.append(layer) 44 | 45 | self.layers = ModuleList(layers) 46 | 47 | # maybe final norm 48 | 49 | self.norm = nn.RMSNorm(dim) if final_norm else nn.Identity() 50 | 51 | # proj in and out 52 | 53 | self.proj_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() 54 | self.proj_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() 55 | 56 | def forward( 57 | self, 58 | x 59 | ): 60 | 61 | x = self.proj_in(x) 62 | 63 | for layer in self.layers: 64 | x = layer(x) + x 65 | 66 | x = self.norm(x) 67 | 68 | return self.proj_out(x) 69 | -------------------------------------------------------------------------------- /x_mlps_pytorch/grouped_ff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import Module, ModuleList 5 | 6 | from einops import rearrange, repeat, pack, unpack 7 | 8 | # functions 9 | 10 | def exists(v): 11 | return v is not None 12 | 13 | def first(arr): 14 | return arr[0] 15 | 16 | def pack_with_inverse(t, pattern): 17 | packed, shape = pack([t], pattern) 18 | 19 | def inverse(out): 20 | return first(unpack(out, shape, pattern)) 21 | 22 | return packed, inverse 23 | 24 | # modules 25 | 26 | class GroupedRMSNorm(Module): 27 | def __init__( 28 | self, 29 | dim, 30 | groups = 1 31 | ): 32 | 
super().__init__() 33 | self.groups = groups 34 | self.scale = dim ** 0.5 35 | self.gamma = nn.Parameter(torch.ones(dim)) 36 | 37 | def forward(self, x): 38 | # grouped l2norm 39 | 40 | x = rearrange(x, '... (g d) n -> ... g d n ', g = self.groups) 41 | x = F.normalize(x, dim = -2, p = 2) 42 | x = rearrange(x, '... g d n -> ... (g d) n') 43 | 44 | gamma = rearrange(self.gamma, 'd -> d 1') # channel first 45 | 46 | return gamma * x * self.scale 47 | 48 | # main class 49 | 50 | class GroupedFeedforwards(Module): 51 | 52 | def __init__( 53 | self, 54 | dim, 55 | depth, 56 | *, 57 | dim_in = None, 58 | dim_out = None, 59 | activation = nn.GELU(), 60 | bias = True, 61 | expansion_factor = 4., 62 | final_norm = False, 63 | groups = 1, 64 | squeeze_if_one_group = False 65 | ): 66 | super().__init__() 67 | 68 | layers = [] 69 | 70 | # take care of groups 71 | 72 | self.groups = groups 73 | self.squeeze_if_one_group = squeeze_if_one_group 74 | 75 | dim = dim * groups 76 | first_dim = dim * groups 77 | 78 | if exists(dim_in): 79 | dim_in *= groups 80 | first_dim = dim_in 81 | 82 | if exists(dim_out): 83 | dim_out *= groups 84 | 85 | dim_hidden = int(dim * expansion_factor) 86 | 87 | self.first_dim = first_dim 88 | 89 | # layers 90 | 91 | for _ in range(depth): 92 | 93 | layer = nn.Sequential( 94 | GroupedRMSNorm(dim, groups = groups), 95 | nn.Conv1d(dim, dim_hidden, 1, bias = bias, groups = groups), 96 | activation, 97 | nn.Conv1d(dim_hidden, dim, 1, bias = bias, groups = groups) 98 | ) 99 | 100 | layers.append(layer) 101 | 102 | self.layers = ModuleList(layers) 103 | 104 | # maybe final norm 105 | 106 | self.norm = GroupedRMSNorm(dim, groups = groups) if final_norm else nn.Identity() 107 | 108 | # proj in and out 109 | 110 | self.proj_in = nn.Conv1d(dim_in, dim, 1, groups = groups) if exists(dim_in) else nn.Identity() 111 | self.proj_out = nn.Conv1d(dim, dim_out, 1, groups = groups) if exists(dim_out) else nn.Identity() 112 | 113 | def forward( 114 | self, 115 | x 116 | ): 117 | 118 | dim = x.shape[-1] 119 | 120 | # channel first 121 | 122 | x = rearrange(x, 'b ... d -> b d ...') 123 | 124 | # repeat for groups if needed 125 | 126 | if dim != self.first_dim: 127 | x = repeat(x, 'b d ... -> b (g d) ...', g = self.groups) 128 | 129 | # pack 130 | 131 | x, inv_pack = pack_with_inverse(x, 'b d *') 132 | 133 | # project in 134 | 135 | x = self.proj_in(x) 136 | 137 | for layer in self.layers: 138 | x = layer(x) + x 139 | 140 | x = self.norm(x) 141 | 142 | x = self.proj_out(x) 143 | 144 | # get back the spatial dimensions 145 | 146 | x = inv_pack(x) 147 | x = rearrange(x, 'b d ... -> b ... d') 148 | 149 | x = rearrange(x, 'b ... (g d) -> b ... g d', g = self.groups) 150 | 151 | if self.squeeze_if_one_group and self.groups == 1: 152 | x = rearrange(x, 'b ... 1 d -> b ... 
d') 153 | 154 | return x 155 | -------------------------------------------------------------------------------- /x_mlps_pytorch/grouped_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module, ModuleList 4 | 5 | from einops import pack, unpack, rearrange, repeat 6 | 7 | # functions 8 | 9 | def exists(v): 10 | return v is not None 11 | 12 | def first(arr): 13 | return arr[0] 14 | 15 | def pack_with_inverse(t, pattern): 16 | packed, shape = pack([t], pattern) 17 | 18 | def inverse(out): 19 | return first(unpack(out, shape, pattern)) 20 | 21 | return packed, inverse 22 | 23 | # main class 24 | 25 | class GroupedMLP(Module): 26 | def __init__( 27 | self, 28 | *dims, 29 | activation = nn.ReLU(), 30 | bias = True, 31 | activate_last = False, 32 | groups = 1, 33 | squeeze_if_one_group = False 34 | ): 35 | super().__init__() 36 | assert len(dims) > 1, f'must have more than 1 layer' 37 | 38 | layers = [] 39 | 40 | # input output dimension pairs 41 | 42 | dims = tuple(dim * groups for dim in dims) 43 | first_dim = first(dims) 44 | dim_in_out = tuple(zip(dims[:-1], dims[1:])) 45 | 46 | # layers 47 | 48 | for i, (dim_in, dim_out) in enumerate(dim_in_out, start = 1): 49 | is_last = i == len(dim_in_out) 50 | 51 | layer = nn.Conv1d(dim_in, dim_out, 1, groups = groups, bias = bias) 52 | 53 | # if not last, add an activation after each linear layer 54 | 55 | if not is_last or activate_last: 56 | layer = nn.Sequential(layer, activation) 57 | 58 | layers.append(layer) 59 | 60 | self.layers = ModuleList(layers) 61 | 62 | # groups 63 | 64 | self.groups = groups 65 | self.first_dim = first_dim 66 | self.squeeze_if_one_group = squeeze_if_one_group and groups == 1 67 | 68 | def forward( 69 | self, 70 | x 71 | ): 72 | dim = x.shape[-1] 73 | 74 | # channel first 75 | 76 | x = rearrange(x, 'b ... d -> b d ...') 77 | 78 | # repeat for groups if needed 79 | 80 | if dim != self.first_dim: 81 | x = repeat(x, 'b d ... -> b (g d) ...', g = self.groups) 82 | 83 | # pack 84 | 85 | x, inv_pack = pack_with_inverse(x, 'b d *') 86 | 87 | # layers 88 | 89 | for layer in self.layers: 90 | x = layer(x) 91 | 92 | # get back the spatial dimensions 93 | 94 | x = inv_pack(x) 95 | x = rearrange(x, 'b d ... -> b ... d') 96 | 97 | x = rearrange(x, 'b ... (g d) -> b ... g d', g = self.groups) 98 | 99 | if self.squeeze_if_one_group and self.groups == 1: 100 | x = rearrange(x, 'b ... 1 d -> b ... 
d') 101 | 102 | return x 103 | 104 | # factory function 105 | 106 | def create_grouped_mlp( 107 | dim, 108 | depth, 109 | *, 110 | dim_in = None, 111 | dim_out = None, 112 | **mlp_kwargs 113 | ): 114 | dims = (dim,) * depth 115 | 116 | if exists(dim_in): 117 | dims = (dim_in, *dims) 118 | 119 | if exists(dim_out): 120 | dims = (*dims, dim_out) 121 | 122 | return GroupedMLP(*dims, **mlp_kwargs) 123 | -------------------------------------------------------------------------------- /x_mlps_pytorch/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module, ModuleList 4 | 5 | # functions 6 | 7 | def exists(v): 8 | return v is not None 9 | 10 | # main class 11 | 12 | class MLP(Module): 13 | def __init__( 14 | self, 15 | *dims, 16 | activation = nn.ReLU(), 17 | bias = True, 18 | activate_last = False 19 | ): 20 | super().__init__() 21 | assert len(dims) > 1, f'must have more than 1 layer' 22 | 23 | layers = [] 24 | 25 | # input output dimension pairs 26 | 27 | dim_in_out = tuple(zip(dims[:-1], dims[1:])) 28 | 29 | # layers 30 | 31 | for i, (dim_in, dim_out) in enumerate(dim_in_out, start = 1): 32 | is_last = i == len(dim_in_out) 33 | 34 | layer = nn.Linear(dim_in, dim_out, bias = bias) 35 | 36 | # if not last, add an activation after each linear layer 37 | 38 | if not is_last or activate_last: 39 | layer = nn.Sequential(layer, activation) 40 | 41 | layers.append(layer) 42 | 43 | self.layers = ModuleList(layers) 44 | 45 | def forward( 46 | self, 47 | x 48 | ): 49 | 50 | for layer in self.layers: 51 | x = layer(x) 52 | 53 | return x 54 | 55 | # factory function 56 | 57 | def create_mlp( 58 | dim, 59 | depth, 60 | *, 61 | dim_in = None, 62 | dim_out = None, 63 | **mlp_kwargs 64 | ): 65 | dims = (dim,) * depth 66 | 67 | if exists(dim_in): 68 | dims = (dim_in, *dims) 69 | 70 | if exists(dim_out): 71 | dims = (*dims, dim_out) 72 | 73 | return MLP(*dims, **mlp_kwargs) 74 | -------------------------------------------------------------------------------- /x_mlps_pytorch/nff.py: -------------------------------------------------------------------------------- 1 | # https://arxiv.org/abs/2410.01131 2 | 3 | from __future__ import annotations 4 | from functools import partial 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from torch.nn import Module, ModuleList 10 | from torch.nn.utils.parametrize import register_parametrization 11 | 12 | # functions 13 | 14 | def exists(v): 15 | return v is not None 16 | 17 | def default(v, d): 18 | return v if exists(v) else d 19 | 20 | def cast_tuple(t, length = 1): 21 | out = t if isinstance(t, tuple) else ((t,) * length) 22 | assert len(out) == length 23 | return out 24 | 25 | def l2norm(t, dim = -1): 26 | return F.normalize(t, dim = dim) 27 | 28 | # norming of the weights 29 | 30 | def norm_weights_(parent_module: Module): 31 | for module in parent_module.modules(): 32 | if not isinstance(module, NormLinear): 33 | continue 34 | 35 | module.norm_weights_() 36 | 37 | # scale 38 | 39 | class Scale(Module): 40 | def __init__( 41 | self, 42 | dim, 43 | init = 1., 44 | scale = 1. 
45 | ): 46 | super().__init__() 47 | self.dim = dim 48 | self.scale = nn.Parameter(torch.ones(dim) * scale) 49 | self.forward_scale = init / scale 50 | 51 | def forward(self): 52 | return self.scale * self.forward_scale 53 | 54 | # residual slerp update with learned scale 55 | 56 | class Residual(Module): 57 | def __init__( 58 | self, 59 | fn: Module, 60 | dim: int, 61 | init: float, 62 | scale: float | None = None, 63 | ): 64 | super().__init__() 65 | self.fn = fn 66 | self.branch_scale = Scale(dim, init, default(scale, dim ** -0.5)) 67 | 68 | def forward(self, x, **kwargs): 69 | residual = x 70 | 71 | out = self.fn(x, **kwargs) 72 | 73 | tuple_output = isinstance(out, tuple) 74 | 75 | if tuple_output: 76 | out, *rest = out 77 | 78 | out = l2norm(out) 79 | out = l2norm(residual.lerp(out, self.branch_scale())) 80 | 81 | if tuple_output: 82 | out = (out, *rest) 83 | 84 | return out 85 | 86 | # for use with parametrize 87 | 88 | class L2Norm(Module): 89 | def __init__(self, dim = -1): 90 | super().__init__() 91 | self.dim = dim 92 | 93 | def forward(self, t): 94 | return l2norm(t, dim = self.dim) 95 | 96 | class NormLinear(Module): 97 | def __init__( 98 | self, 99 | dim, 100 | dim_out, 101 | norm_dim_in = True, 102 | parametrize = True 103 | ): 104 | super().__init__() 105 | self.dim = dim 106 | self.dim_out = dim_out 107 | 108 | self.linear = nn.Linear(dim, dim_out, bias = False) 109 | 110 | self.parametrize = parametrize 111 | self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0) 112 | 113 | if parametrize: 114 | register_parametrization( 115 | self.linear, 116 | 'weight', 117 | self.l2norm 118 | ) 119 | 120 | self.norm_weights_() 121 | 122 | @torch.no_grad() 123 | def norm_weights_(self): 124 | if self.parametrize: 125 | normed = self.weight 126 | original = self.linear.parametrizations.weight.original 127 | 128 | original.copy_(normed) 129 | else: 130 | self.weight.copy_(self.l2norm(self.weight)) 131 | 132 | @property 133 | def weight(self): 134 | return self.linear.weight 135 | 136 | def forward(self, x): 137 | return self.linear(x) 138 | 139 | # feedforward 140 | 141 | class nFeedforward(Module): 142 | def __init__( 143 | self, 144 | dim, 145 | *, 146 | expand_factor = 4, 147 | manual_norm_weights = False, 148 | s_hidden_init = 1., 149 | s_hidden_scale = 1., 150 | s_gate_init = 1., 151 | s_gate_scale = 1., 152 | ): 153 | super().__init__() 154 | NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights) 155 | 156 | self.dim = dim 157 | self.expand_factor = expand_factor 158 | 159 | dim_inner = int(dim * expand_factor * 2 / 3) 160 | 161 | self.dim_inner = dim_inner 162 | 163 | self.to_hidden = NormLinear_(dim, dim_inner) 164 | self.to_gate = NormLinear_(dim, dim_inner) 165 | 166 | self.hidden_scale = Scale(dim_inner, s_hidden_init, s_hidden_scale) 167 | self.gate_scale = Scale(dim_inner, s_gate_init, s_gate_scale) 168 | 169 | self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False) 170 | 171 | def forward(self, x): 172 | hidden, gate = self.to_hidden(x), self.to_gate(x) 173 | 174 | hidden = hidden * self.hidden_scale() 175 | gate = gate * self.gate_scale() * (self.dim ** 0.5) 176 | 177 | hidden = F.silu(gate) * hidden 178 | return self.to_out(hidden) 179 | 180 | # classes 181 | 182 | class nFeedforwards(Module): 183 | def __init__( 184 | self, 185 | dim, 186 | depth, 187 | *, 188 | dim_in = None, 189 | dim_out = None, 190 | ff_expand_factor = 4., 191 | input_preserve_magnitude = False, 192 | constant_shift = 3., # simbav2 concatted a constant of 3. 
before l2norm of the input to preserve magnitude information 193 | manual_norm_weights = False, 194 | # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network 195 | alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_ff_init if they are specified 196 | alpha_attn_init: float | tuple[float, ...] | None = None, 197 | alpha_attn_scale: float | tuple[float, ...] | None = None, 198 | alpha_ff_init: float | tuple[float, ...] | None = None, 199 | alpha_ff_scale: float | tuple[float, ...] | None = None, 200 | s_ff_hidden_init: float | tuple[float, ...] = 1., 201 | s_ff_hidden_scale: float | tuple[float, ...] = 1., 202 | s_ff_gate_init: float | tuple[float, ...] = 1., 203 | s_ff_gate_scale: float | tuple[float, ...] = 1., 204 | 205 | ): 206 | super().__init__() 207 | NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights) 208 | 209 | self.dim = dim 210 | self.depth = depth 211 | self.ff_expand_factor = ff_expand_factor 212 | 213 | alpha_init = default(alpha_init, 1. / depth) 214 | 215 | self.layers = ModuleList([]) 216 | 217 | scale_hparams = ( 218 | alpha_attn_init, 219 | alpha_attn_scale, 220 | alpha_ff_init, 221 | alpha_ff_scale, 222 | s_ff_hidden_init, 223 | s_ff_hidden_scale, 224 | s_ff_gate_init, 225 | s_ff_gate_scale 226 | ) 227 | 228 | scale_hparams = tuple(cast_tuple(hparam, depth) for hparam in scale_hparams) 229 | 230 | for ( 231 | alpha_attn_init_, 232 | alpha_attn_scale_, 233 | alpha_ff_init_, 234 | alpha_ff_scale_, 235 | s_ff_hidden_init_, 236 | s_ff_hidden_scale_, 237 | s_ff_gate_init_, 238 | s_ff_gate_scale_ 239 | ) in zip(*scale_hparams): 240 | 241 | ff = nFeedforward( 242 | dim, 243 | expand_factor = ff_expand_factor, 244 | manual_norm_weights = manual_norm_weights, 245 | s_hidden_init = s_ff_hidden_init_, 246 | s_hidden_scale = s_ff_hidden_scale_, 247 | s_gate_init = s_ff_gate_init_, 248 | s_gate_scale = s_ff_gate_scale_, 249 | ) 250 | 251 | ff_with_residual = Residual( 252 | ff, 253 | dim, 254 | default(alpha_ff_init_, alpha_init), 255 | default(alpha_ff_scale_, dim ** -0.5) 256 | ) 257 | 258 | self.layers.append(ff_with_residual) 259 | 260 | # appending the magnitude 261 | 262 | self.input_preserve_magnitude = input_preserve_magnitude 263 | self.constant_shift = constant_shift 264 | 265 | # projecting in 266 | 267 | self.need_proj_in = exists(dim_in) or input_preserve_magnitude 268 | 269 | if self.need_proj_in: 270 | dim_in = default(dim_in, dim) 271 | dim_constant_shift = int(input_preserve_magnitude) 272 | 273 | self.proj_in = NormLinear(dim_in + dim_constant_shift, dim, norm_dim_in = False) 274 | self.proj_in_scale = Scale(dim) 275 | 276 | # projecting out 277 | 278 | self.need_proj_out = exists(dim_out) 279 | 280 | if self.need_proj_out: 281 | self.proj_out = NormLinear_(dim, dim_out) 282 | self.proj_out_scale = Scale(dim_out, 1., dim ** -0.5) 283 | 284 | @torch.no_grad() 285 | def norm_weights_(self): 286 | norm_weights_(self) 287 | 288 | def forward( 289 | self, 290 | x 291 | ): 292 | 293 | if self.input_preserve_magnitude: 294 | x = F.pad(x, (0, 1), value = self.constant_shift) 295 | x = l2norm(x) 296 | 297 | if self.need_proj_in: 298 | x = self.proj_in(x) * self.proj_in_scale() 299 | x = l2norm(x) 300 | 301 | for ff in self.layers: 302 | x = ff(x) 303 | 304 | if self.need_proj_out: 305 | x = self.proj_out(x) * self.proj_out_scale() 306 | 307 | return x 308 | 309 | # copy-pasteable file 310 | 311 | if __name__ == '__main__': 
312 | 313 | nff = nFeedforwards(512, 4, dim_in = 128, dim_out = 128, input_preserve_magnitude = True) 314 | x = torch.randn((2, 128)) 315 | 316 | assert nff(x).shape == x.shape 317 | --------------------------------------------------------------------------------
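As a closing note, a sketch of how `nFeedforwards` could be trained in the nGPT style the file cites: when constructed with `manual_norm_weights = True`, no parametrization keeps the `NormLinear` weights unit-norm, so `norm_weights_` is called explicitly after each optimizer step to project them back onto the hypersphere. The regression data, learning rate, and loop length below are illustrative assumptions only:

```python
import torch
from x_mlps_pytorch.nff import nFeedforwards, norm_weights_

net = nFeedforwards(512, 4, dim_in = 128, dim_out = 128, manual_norm_weights = True)
optim = torch.optim.Adam(net.parameters(), lr = 3e-4)

# toy regression target, purely for illustration
data = torch.randn(32, 128)
target = torch.randn(32, 128)

for _ in range(10):
    loss = (net(data) - target).pow(2).mean()
    loss.backward()

    optim.step()
    optim.zero_grad()

    # re-normalize all NormLinear weights after the update, as the nGPT recipe does
    norm_weights_(net)
```

With the default `manual_norm_weights = False`, the `register_parametrization` path re-normalizes the weights on every access, so the explicit `norm_weights_` call is not needed in that mode.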