├── examples
│   ├── data
│   │   ├── __init__.py
│   │   ├── cifar10.py
│   │   ├── mnist.py
│   │   └── shakespeare.py
│   ├── assets
│   │   └── erasure.png
│   └── hello-world.ipynb
├── docs
│   ├── source
│   │   ├── examples
│   │   │   ├── assets
│   │   │   │   └── README.md
│   │   │   ├── hello-gpt.nblink
│   │   │   ├── hello-mnist.nblink
│   │   │   ├── hello-world.nblink
│   │   │   └── weight-erasure.nblink
│   │   ├── favicon.ico
│   │   ├── figure
│   │   │   ├── platonize.png
│   │   │   ├── alignment.py
│   │   │   ├── tangent.py
│   │   │   ├── nn.py
│   │   │   └── sweeps.py
│   │   ├── _static
│   │   │   ├── logo-square.jpeg
│   │   │   ├── custom.css
│   │   │   ├── logo-dark.svg
│   │   │   └── logo-light.svg
│   │   ├── theory
│   │   │   ├── compound
│   │   │   │   ├── gpt.rst
│   │   │   │   └── index.rst
│   │   │   ├── module.rst
│   │   │   ├── vector.rst
│   │   │   ├── atom
│   │   │   │   ├── conv2d.rst
│   │   │   │   ├── linear.rst
│   │   │   │   ├── embed.rst
│   │   │   │   └── index.rst
│   │   │   └── bond
│   │   │       ├── nonlinearities.rst
│   │   │       └── index.rst
│   │   ├── algorithms
│   │   │   ├── manifold
│   │   │   │   ├── index.rst
│   │   │   │   ├── hypersphere.rst
│   │   │   │   ├── orthogonal.rst
│   │   │   │   └── stiefel.rst
│   │   │   └── newton-schulz.rst
│   │   ├── intro
│   │   │   ├── quickstart.rst
│   │   │   ├── reading-list.rst
│   │   │   └── whats-in-a-norm.rst
│   │   ├── bad-scaling.rst
│   │   ├── conf.py
│   │   ├── index.rst
│   │   ├── history.rst
│   │   ├── golden-rules.rst
│   │   └── faq.rst
│   ├── requirements.txt
│   ├── README.md
│   ├── make.bat
│   └── Makefile
├── modula
│   ├── __init__.py
│   ├── compound.py
│   ├── atom.py
│   ├── bond.py
│   └── abstract.py
├── .gitignore
├── pyproject.toml
├── LICENSE
├── .github
│   └── workflows
│       └── sphinx.yml
├── README.md
└── assets
    ├── modula.svg
    ├── modula_light.svg
    └── gpt-owt-context.svg
/examples/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/examples/assets/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/examples/hello-gpt.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../examples/hello-gpt.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modula-systems/modula/HEAD/docs/source/favicon.ico -------------------------------------------------------------------------------- /docs/source/examples/hello-mnist.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../examples/hello-mnist.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/examples/hello-world.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../examples/hello-world.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /examples/assets/erasure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modula-systems/modula/HEAD/examples/assets/erasure.png -------------------------------------------------------------------------------- /modula/__init__.py: -------------------------------------------------------------------------------- 1 | from . import abstract 2 | from . import atom 3 | from . import bond 4 | from . 
import compound 5 | -------------------------------------------------------------------------------- /docs/source/figure/platonize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modula-systems/modula/HEAD/docs/source/figure/platonize.png -------------------------------------------------------------------------------- /docs/source/_static/logo-square.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modula-systems/modula/HEAD/docs/source/_static/logo-square.jpeg -------------------------------------------------------------------------------- /docs/source/theory/compound/gpt.rst: -------------------------------------------------------------------------------- 1 | GPT 2 | ==== 3 | 4 | .. admonition:: Warning 5 | :class: warning 6 | 7 | This page is still under construction. -------------------------------------------------------------------------------- /docs/source/theory/module.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ======== 3 | 4 | .. admonition:: Warning 5 | :class: warning 6 | 7 | This page is still under construction. -------------------------------------------------------------------------------- /docs/source/theory/vector.rst: -------------------------------------------------------------------------------- 1 | Vectors 2 | ======== 3 | 4 | .. admonition:: Warning 5 | :class: warning 6 | 7 | This page is still under construction. -------------------------------------------------------------------------------- /docs/source/theory/atom/conv2d.rst: -------------------------------------------------------------------------------- 1 | Conv2d 2 | ======= 3 | 4 | .. admonition:: Warning 5 | :class: warning 6 | 7 | This page is still under construction. -------------------------------------------------------------------------------- /docs/source/theory/atom/linear.rst: -------------------------------------------------------------------------------- 1 | Linear 2 | ======= 3 | 4 | .. admonition:: Warning 5 | :class: warning 6 | 7 | This page is still under construction. -------------------------------------------------------------------------------- /docs/source/theory/atom/embed.rst: -------------------------------------------------------------------------------- 1 | Embedding 2 | ========== 3 | 4 | .. admonition:: Warning 5 | :class: warning 6 | 7 | This page is still under construction. -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | ul.starlist { 2 | list-style-type: "🌟 "; 3 | margin-bottom: -5px; 4 | } 5 | 6 | ul.starlist li { 7 | margin-bottom: 15px; 8 | } -------------------------------------------------------------------------------- /docs/source/theory/bond/nonlinearities.rst: -------------------------------------------------------------------------------- 1 | Nonlinearities 2 | =============== 3 | 4 | .. admonition:: Warning 5 | :class: warning 6 | 7 | This page is still under construction. 
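In the meantime, the nonlinearities that ship with Modula are defined as bond modules in :code:`modula/bond.py`. For example, here is the GeLU bond, reproduced from the codebase: it rescales :math:`\mathrm{gelu}(x)` by the constant 1.1289, the maximum derivative of :math:`\mathrm{gelu}`, so that the module has sensitivity one:

.. code-block:: python

    import jax
    from modula.abstract import Bond

    class GeLU(Bond):
        def __init__(self):
            super().__init__()
            self.smooth = False
            self.sensitivity = 1

        def forward(self, x, w):
            return jax.nn.gelu(x) / 1.1289  # 1.1289 is the max derivative of gelu(x)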
-------------------------------------------------------------------------------- /docs/source/examples/weight-erasure.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../examples/weight-erasure.ipynb", 3 | "extra-media": [ 4 | "../../../examples/assets/erasure.png" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | furo 2 | sphinx_copybutton 3 | sphinx_inline_tabs 4 | sphinx-autobuild 5 | sphinxext.opengraph 6 | matplotlib 7 | sphinx-design 8 | sphinxcontrib-youtube 9 | nbsphinx 10 | nbsphinx_link 11 | IPython 12 | -------------------------------------------------------------------------------- /docs/source/theory/bond/index.rst: -------------------------------------------------------------------------------- 1 | Bond modules 2 | ============= 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | nonlinearities 8 | 9 | .. admonition:: Warning 10 | :class: warning 11 | 12 | This page is still under construction. -------------------------------------------------------------------------------- /docs/source/theory/compound/index.rst: -------------------------------------------------------------------------------- 1 | Compound modules 2 | ================= 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | gpt 8 | 9 | .. admonition:: Warning 10 | :class: warning 11 | 12 | This page is still under construction. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | modula.egg-info/ 4 | modula/__pycache__/ 5 | 6 | docs/build/ 7 | docs/jupyter_execute/ 8 | 9 | docs/source/examples/assets/ 10 | 11 | examples/.ipynb_checkpoints/ 12 | examples/data/* 13 | !examples/data/*.py 14 | -------------------------------------------------------------------------------- /docs/source/theory/atom/index.rst: -------------------------------------------------------------------------------- 1 | Atomic modules 2 | =============== 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | linear 8 | embed 9 | conv2d 10 | 11 | .. admonition:: Warning 12 | :class: warning 13 | 14 | This page is still under construction. -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ## Served docs 2 | 3 | The docs are automatically served to [https://docs.modula.systems/](https://docs.modula.systems/). 4 | 5 | ## Building the docs locally 6 | 7 | To build these docs locally do: 8 | ```bash 9 | cd docs 10 | pip install -r requirements.txt 11 | conda install -c conda-forge pandoc 12 | make livedirhtml 13 | ``` 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "modula" 7 | version = "0.0.0.0.0.2" 8 | authors = [ 9 | { name="Jeremy Bernstein", email="jbernstein@mit.edu" }, 10 | ] 11 | description = "Numerically sound neural network library." 
12 | readme = "README.md" 13 | requires-python = ">=3.9" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ] 19 | dependencies = [ 20 | "jax>=0.5.0" 21 | ] 22 | 23 | [tool.setuptools] 24 | packages = ["modula"] 25 | 26 | [project.urls] 27 | "Homepage" = "https://modula.systems/" 28 | "Repository" = "https://github.com/modula-systems/modula" 29 | "Bug Tracker" = "https://github.com/modula-systems/modula/issues" 30 | "Documentation" = "https://docs.modula.systems/" 31 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | # Catch-all target: route all unknown targets to Sphinx using the new 16 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
17 | %: Makefile 18 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 19 | 20 | livehtml: 21 | sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | 23 | livedirhtml: 24 | sphinx-autobuild -b dirhtml --watch ../examples --ignore source/examples/assets "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 25 | 26 | .PHONY: help Makefile livehtml livedirhtml -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jeremy Bernstein 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/sphinx.yml: -------------------------------------------------------------------------------- 1 | name: "Sphinx: Render docs" 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | contents: write 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Build HTML 16 | uses: ammaraskar/sphinx-action@7.0.0 17 | with: 18 | docs-folder: docs/ 19 | pre-build-command: > 20 | apt-get update -y; 21 | apt-get install -y wget unzip pandoc; 22 | wget -nc https://github.com/ipython/xkcd-font/blob/master/xkcd-script/font/xkcd-script.ttf?raw=true -O /usr/local/share/fonts/xkcd-Script.ttf; 23 | wget -nc https://github.com/ipython/xkcd-font/blob/master/xkcd/build/xkcd.otf?raw=true -O /usr/local/share/fonts/xkcd.otf; 24 | wget -nc https://github.com/antimatter15/doge/blob/master/Comic%20Sans%20MS.ttf?raw=true -O /usr/local/share/fonts/comic-sans.otf; 25 | fc-cache -f -v; 26 | build-command: "sphinx-build -b dirhtml source build" 27 | - name: Upload artifacts 28 | uses: actions/upload-artifact@v4 29 | with: 30 | name: html-docs 31 | path: docs/build/ 32 | - name: Deploy 33 | uses: peaceiris/actions-gh-pages@v3 34 | if: github.ref == 'refs/heads/main' 35 | with: 36 | github_token: ${{ secrets.GITHUB_TOKEN }} 37 | publish_dir: docs/build 38 | cname: docs.modula.systems 39 | -------------------------------------------------------------------------------- /docs/source/algorithms/manifold/index.rst: -------------------------------------------------------------------------------- 1 | Manifold duality maps 2 | ====================== 3 | 4 | .. 
toctree:: 5 | :hidden: 6 | 7 | hypersphere 8 | orthogonal 9 | stiefel 10 | 11 | In this section we will derive steepest descent optimization algorithms for manifolds equipped with certain norms. These algorithms will be useful if we want to construct steepest descent optimizers for modules where the tensors obey some natural constraints. In particular, we will consider: 12 | 13 | For vectors: 14 | 15 | - :doc:`steepest descent under the Euclidean norm on the hypersphere <hypersphere>`; 16 | 17 | For matrices: 18 | 19 | - :doc:`steepest descent under the spectral norm on the orthogonal manifold <orthogonal>`; 20 | - :doc:`steepest descent under the spectral norm on the Stiefel manifold <stiefel>`. 21 | 22 | In each case, we will adopt the following strategy: 23 | 24 | 1. characterize the "tangent space" to the manifold; 25 | 2. solve for the steepest direction in the tangent space under the given norm; 26 | 3. work out a "retraction map", which projects a step taken in the tangent space back to the manifold. 27 | 28 | We can think of the tangent space to the manifold as a plane lying tangent to the manifold. So each point on the manifold has its own tangent space. Since the manifold is curved, taking a discrete step in the tangent space will leave the manifold. Therefore we need to project back to the manifold using the retraction map. It's always helpful to keep in mind the following picture: 29 | 30 | .. plot:: figure/tangent.py 31 | 32 | -------------------------------------------------------------------------------- /modula/compound.py: -------------------------------------------------------------------------------- 1 | from modula.abstract import * 2 | from modula.atom import * 3 | from modula.bond import * 4 | 5 | def MLP(output_dim, input_dim, width, depth): 6 | m = Linear(output_dim, width) @ ReLU() 7 | for _ in range(depth-2): 8 | m = m @ Linear(width, width) @ ReLU() 9 | return m @ Linear(width, input_dim) 10 | 11 | def Attention(num_heads, d_embed, d_query, d_value, softmax_scale, causal): 12 | """Multi-head attention""" 13 | Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed) 14 | K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed) 15 | V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed) 16 | W = Linear(d_embed, num_heads * d_value) @ MergeHeads() 17 | 18 | mask = CausalMask() if causal else Identity() # respect the causal flag rather than always masking 19 | AttentionScores = Softmax(softmax_scale) @ mask @ AttentionQK() @ Rope(d_query) @ (Q, K) 20 | return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores) 21 | 22 | def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0): 23 | embed = Embed(d_embed, vocab_size) 24 | embed.tare() 25 | 26 | att = Attention(num_heads, d_embed, d_query, d_value, attention_scale, causal=True) 27 | mlp = Linear(d_embed, 4*d_embed) @ GeLU() @ Linear(4*d_embed, d_embed) 28 | att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att 29 | mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp 30 | blocks = (mlp_block @ att_block) ** num_blocks 31 | blocks.tare(absolute=blocks_mass) 32 | 33 | out = final_scale * Linear(vocab_size, d_embed) 34 | 35 | return out @ blocks @ embed -------------------------------------------------------------------------------- /docs/source/figure/alignment.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | # Enable xkcd style 4 | plt.xkcd() 5 | 6 | # Number of layers in the neural network 7 | L = 9 8 | mid_layer = 4 9 | 
layers = [' W ' if i == mid_layer else r"$\quad$" for i in range(L)] 10 | 11 | # Plot the layers horizontally with arrows and separate block for Delta W closer 12 | fig, ax = plt.subplots(figsize=(12, 2.9)) 13 | fig.patch.set_alpha(0) # Make the background transparent 14 | 15 | # Create blocks for each layer and add arrows 16 | for i, layer in enumerate(layers): 17 | if i < mid_layer: 18 | facecolor = '#FF6961' 19 | elif i == mid_layer: 20 | facecolor = '#FDFD96' 21 | else: 22 | facecolor = '#77B5FE' 23 | 24 | with plt.style.context({'path.effects': []}): 25 | ax.text(i-0.1, 0.8, layer, ha='center', va='center', fontsize=16, 26 | bbox=dict(facecolor=facecolor, edgecolor='black', pad=10.0)) 27 | if i < L - 1: 28 | ax.arrow(i, 0.8, 0.6, 0, head_width=0.1, head_length=0.1, fc='black', ec='black') 29 | 30 | # Highlight the middle layer with an update in a separate block above, closer 31 | 32 | with plt.style.context({'path.effects': []}): 33 | ax.text(mid_layer-0.1, 1.8, 'ΔW', ha='center', va='center', fontsize=16, 34 | bbox=dict(facecolor='#FDFD96', edgecolor='black', pad=10.0)) 35 | ax.arrow(mid_layer-0.1, 1.8, 0, -0.70, head_width=0.1, head_length=0.1, fc='black', ec='black') 36 | 37 | ax.text(1.5-0.1, 0.3, "head of the network", ha='center', va='center', fontsize=16) 38 | ax.text(6.5-0.1, 0.3, "tail of the network", ha='center', va='center', fontsize=16) 39 | ax.text(4-0.1, 0.3, "middle layer", ha='center', va='center', fontsize=16) 40 | 41 | # Show the plot 42 | ax.axis('off') 43 | ax.set_aspect('equal') 44 | plt.tight_layout() 45 | plt.show() 46 | -------------------------------------------------------------------------------- /examples/data/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | import urllib.request 5 | import tarfile 6 | 7 | def load_cifar10(normalize=True): 8 | """ 9 | Downloads (if needed) and loads the CIFAR-10 dataset. 
10 | Returns: (train_images, train_labels, test_images, test_labels) 11 | """ 12 | # Create data directory 13 | data_dir = os.path.join(os.path.dirname(__file__), "cifar10_files") 14 | if not os.path.exists(data_dir): 15 | os.makedirs(data_dir) 16 | 17 | # Download files if needed 18 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 19 | filepath = os.path.join(data_dir, "cifar-10-python.tar.gz") 20 | extracted_dir = os.path.join(data_dir, "cifar-10-batches-py") 21 | 22 | if not os.path.exists(extracted_dir): 23 | if not os.path.isfile(filepath): 24 | print(f"Downloading {url}") 25 | urllib.request.urlretrieve(url, filepath) 26 | with tarfile.open(filepath, 'r:gz') as tar: 27 | tar.extractall(data_dir) 28 | 29 | def load_batch(filename): 30 | with open(os.path.join(extracted_dir, filename), 'rb') as f: 31 | batch = pickle.load(f, encoding='bytes') 32 | return batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1), np.array(batch[b'labels']) 33 | 34 | # Load training data 35 | train_images, train_labels = [], [] 36 | for i in range(1, 6): 37 | images, labels = load_batch(f'data_batch_{i}') 38 | train_images.append(images) 39 | train_labels.append(labels) 40 | 41 | train_images = np.concatenate(train_images) 42 | train_labels = np.concatenate(train_labels) 43 | 44 | # Load test data 45 | test_images, test_labels = load_batch('test_batch') 46 | 47 | if normalize: 48 | train_images = train_images.astype(np.float32) / 255.0 49 | test_images = test_images.astype(np.float32) / 255.0 50 | 51 | return train_images, train_labels, test_images, test_labels -------------------------------------------------------------------------------- /docs/source/intro/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | =========== 3 | 4 | Modula is a neural networks library built on top of `JAX `_. 5 | 6 | Installation 7 | ------------- 8 | 9 | Modula can be installed using pip: 10 | 11 | .. code-block:: bash 12 | 13 | pip install git+https://github.com/modula-systems/modula.git 14 | 15 | Or you can clone the repository and install locally: 16 | 17 | .. code-block:: bash 18 | 19 | git clone https://github.com/modula-systems/modula.git 20 | cd modula 21 | pip install -e . 22 | 23 | Functionality 24 | -------------- 25 | 26 | Modula provides a set of architecture-specific helper functions that are automatically constructed along with the network architecture itself. As an example, let's build a multi-layer perceptron: 27 | 28 | .. code-block:: python 29 | 30 | from modula.atom import Linear 31 | from modula.bond import ReLU 32 | 33 | mlp = Linear(10, 256) 34 | mlp @= ReLU() 35 | mlp @= Linear(256, 256) 36 | mlp @= ReLU() 37 | mlp @= Linear(256, 784) 38 | 39 | mlp.jit() # makes everything run faster 40 | 41 | Behind the scenes, Modula builds a function to randomly initialize the weights of the network: 42 | 43 | .. code-block:: python 44 | 45 | import jax 46 | 47 | key = jax.random.PRNGKey(0) 48 | weights = mlp.initialize(key) 49 | 50 | Supposing we have used JAX to compute the gradient of our loss and stored this as :code:`grad`, then we can use Modula to dualize the gradient, thereby accelerating our gradient descent training: 51 | 52 | .. code-block:: python 53 | 54 | dualized_grad = mlp.dualize(grad) 55 | weights = [w - 0.1 * dg for w, dg in zip(weights, dualized_grad)] 56 | 57 | And after the weight update, we can project the weights back to their natural constraint set: 58 | 59 | .. 
code-block:: python 60 | 61 | weights = mlp.project(weights) 62 | 63 | In short, Modula lets us think about the weight space of our neural network as a somewhat classical optimization space, complete with duality and projection operations. 64 | -------------------------------------------------------------------------------- /examples/data/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gzip 3 | import os 4 | import struct 5 | import urllib.request 6 | 7 | def load_mnist(normalize=True): 8 | """ 9 | Downloads (if needed) and loads the MNIST dataset. 10 | Returns: (train_images, train_labels, test_images, test_labels) 11 | """ 12 | # Create data directory 13 | data_dir = os.path.join(os.path.dirname(__file__), "mnist_files") 14 | if not os.path.exists(data_dir): 15 | os.makedirs(data_dir) 16 | 17 | # Download files if needed 18 | base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" 19 | files = ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", 20 | "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"] 21 | 22 | for filename in files: 23 | filepath = os.path.join(data_dir, filename) 24 | if not os.path.isfile(filepath): 25 | url = base_url + filename 26 | print(f"Downloading {url}") 27 | urllib.request.urlretrieve(url, filepath) 28 | 29 | # Load the data 30 | def parse_images(filepath): 31 | with gzip.open(filepath, "rb") as f: 32 | _, num_images, rows, cols = struct.unpack(">IIII", f.read(16)) 33 | return np.frombuffer(f.read(), dtype=np.uint8).reshape(num_images, rows, cols) 34 | 35 | def parse_labels(filepath): 36 | with gzip.open(filepath, "rb") as f: 37 | _, num_labels = struct.unpack(">II", f.read(8)) 38 | return np.frombuffer(f.read(), dtype=np.uint8) 39 | 40 | # Load and optionally normalize images 41 | train_images = parse_images(os.path.join(data_dir, "train-images-idx3-ubyte.gz")) 42 | test_images = parse_images(os.path.join(data_dir, "t10k-images-idx3-ubyte.gz")) 43 | train_labels = parse_labels(os.path.join(data_dir, "train-labels-idx1-ubyte.gz")) 44 | test_labels = parse_labels(os.path.join(data_dir, "t10k-labels-idx1-ubyte.gz")) 45 | 46 | if normalize: 47 | train_images = train_images.astype(np.float32) / 255.0 48 | test_images = test_images.astype(np.float32) / 255.0 49 | 50 | return train_images, train_labels, test_images, test_labels -------------------------------------------------------------------------------- /docs/source/intro/reading-list.rst: -------------------------------------------------------------------------------- 1 | Reading list 2 | ============= 3 | 4 | Core papers 5 | ------------ 6 | 7 | - `Scalable optimization in the modular norm `_ 8 | - `Modular duality in deep learning `_ 9 | 10 | Optimization 11 | ------------- 12 | 13 | - `Preconditioned spectral descent for deep learning `_ 14 | - `The duality structure gradient descent algorithm: analysis and applications to neural networks `_ 15 | - `On the distance between two neural networks and the stability of learning `_ 16 | - `Automatic gradient descent: Deep learning without hyperparameters `_ 17 | - `A spectral condition for feature learning `_ 18 | - `Universal majorization-minimization algorithms `_ 19 | - `An isometric stochastic optimizer `_ 20 | - `Old optimizer, new norm: An anthology `_ 21 | - `Muon: An optimizer for hidden layers in neural networks `_ 22 | 23 | Generalization 24 | --------------- 25 | 26 | - `Spectrally-normalized margin bounds for neural networks `_ 27 | - `A PAC-Bayesian 
approach to spectrally-normalized margin bounds for neural networks `_ 28 | - `Investigating generalization by controlling normalized margin `_ 29 | 30 | New developments 31 | ----------------- 32 | 33 | - `Preconditioning and normalization in optimizing deep neural networks `_ 34 | - `Improving SOAP using iterative whitening and Muon `_ 35 | - `On the concurrence of layer-wise preconditioning methods and provable feature learning `_ 36 | - `A note on the convergence of Muon and further `_ 37 | - `Training deep learning models with norm-constrained LMOs `_ 38 | - `Muon is scalable for LLM training `_ 39 | -------------------------------------------------------------------------------- /docs/source/figure/tangent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | from matplotlib.colors import LightSource 5 | 6 | def create_tangent_space_plot(): 7 | fig = plt.figure(figsize=(16, 9)) 8 | ax = fig.add_subplot(111, projection='3d', computed_zorder=False) 9 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1) 10 | fig.patch.set_alpha(0) 11 | ax.patch.set_alpha(0) 12 | 13 | ls = LightSource(azdeg=160, altdeg=120) 14 | 15 | scale = 3.0 16 | resolution = 150 17 | z_offset = -0.9 18 | x_offset = -1.5 19 | 20 | u = np.linspace(0, 2 * np.pi, resolution) 21 | v = np.linspace(0, np.pi, resolution) 22 | x = scale * np.outer(np.cos(u), np.sin(v)) + x_offset 23 | y = scale * np.outer(np.sin(u), np.sin(v)) 24 | z = scale * np.outer(np.ones(np.size(u)), np.cos(v)) + z_offset 25 | 26 | ax.plot_surface(x, y, z, alpha=1.0, color='royalblue', antialiased=False, shade=True, lightsource=ls) 27 | 28 | point = scale * np.array([1/np.sqrt(2.5), 1/np.sqrt(2.5), 1/np.sqrt(5)]) + np.array([x_offset, 0, z_offset]) 29 | v1 = np.array([-1/np.sqrt(2), 1/np.sqrt(2), 0]) 30 | v2 = np.array([-1/np.sqrt(10), -1/np.sqrt(10), 2*np.sqrt(3/10)]) 31 | 32 | grid_size = 0.5 33 | grid_points = 10 34 | xx = np.linspace(-grid_size, grid_size, grid_points) 35 | yy = np.linspace(-grid_size, grid_size, grid_points) 36 | XX, YY = np.meshgrid(xx, yy) 37 | plane_points = np.zeros((XX.shape[0], XX.shape[1], 3)) 38 | 39 | for i in range(XX.shape[0]): 40 | for j in range(XX.shape[1]): 41 | plane_points[i,j] = point + scale * (XX[i,j]*v1 + YY[i,j]*v2) 42 | 43 | ax.plot_surface(plane_points[:,:,0], plane_points[:,:,1], 44 | plane_points[:,:,2], alpha=0.8, color='crimson', 45 | antialiased=False, shade=True, lightsource=ls, edgecolor='firebrick') 46 | 47 | ax.scatter([point[0]], [point[1]], [point[2]], 48 | color='black', s=150, alpha=0.6, zorder=100) 49 | 50 | ax.set_xlim(-3.5, 1.5) 51 | ax.set_ylim(-2.5, 2.5) 52 | ax.set_zlim(-2.5, 1.5) 53 | 54 | ax.set_axis_off() 55 | ax.view_init(elev=20, azim=-30) 56 | ax.set_box_aspect([1,1,0.8]) 57 | 58 | return fig 59 | 60 | fig = create_tangent_space_plot() 61 | plt.show() -------------------------------------------------------------------------------- /docs/source/bad-scaling.rst: -------------------------------------------------------------------------------- 1 | Bad scaling 2 | ============ 3 | 4 | At the simplest level, neural networks are trained by iterating the following operation: 5 | 6 | .. code:: python 7 | 8 | weights -= learning_rate * gradient 9 | 10 | where :python:`learning_rate` is a :python:`float` and :python:`gradient` is the gradient of the loss function with respect to the :python:`weights` of the network. 
Of course, in practice, we may want to use additional tricks such as momentum, but let's ignore details like that for now. 11 | 12 | Unfortunately, this simple "gradient descent" operation does not scale well if we scale up the network architecture. What does this mean? Suppose that, before training, we "grow" the network by increasing its *width* (the number of neurons in a layer) or its *depth* (the number of layers): 13 | 14 | .. plot:: figure/nn.py 15 | 16 | In practice, we might like to grow other dimensions such as the number of residual blocks in a transformer, but let's stick with this simplified picture for now. 17 | 18 | Under these scaling operations, gradient descent training can break in two main ways. The first problem is that the optimal learning rate can *drift* as we scale certain dimensions. This is a problem because it means we need to re-tune the learning rate as we scale things up---which is expensive and time-consuming. The second problem is that sometimes performance can actually get worse as we grow the network, even if the optimal learning rate remains stable. This is a problem because we grew the network hoping to make performance better, not worse! 19 | 20 | .. plot:: figure/sweeps.py 21 | 22 | These cartoons illustrate typical bad scaling behaviours. On the left, the optimal learning rate drifts with increasing width. On the right, performance deteriorates with increasing depth. 23 | 24 | The good news is that we have developed machinery that largely solves these scaling woes. It turns out that the problem is solved by defining a simple weight initializer along with a special :python:`normalize` function which acts on gradients, leading to a new "normalized" gradient descent algorithm: 25 | 26 | .. code:: python 27 | 28 | weights -= learning_rate * normalize(gradient) 29 | 30 | This initialization and gradient normalization removes drift in the optimal learning rate, and causes performance to improve with increasing scale. Modula automatically infers the necessary initialize and normalize functions from the architecture of the network. So the user can focus on writing their neural network architecture while Modula will handle properly normalizing the training. 31 | 32 | These docs are intended to explain how Modula works and also introduce the Modula API. In case you don't care about Modula or automatic gradient normalization, the next section will explain how you can normalize training manually in a different framework like `PyTorch `_ or `JAX `_. -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | # -- Project information ----------------------------------------------------- 4 | project = 'Modula' 5 | copyright = '2024, Jeremy Bernstein' 6 | author = 'Jeremy Bernstein' 7 | 8 | # -- General configuration --------------------------------------------------- 9 | extensions = [ 10 | "sphinx_copybutton", 11 | "sphinx_inline_tabs", 12 | "sphinx.ext.autodoc", 13 | "matplotlib.sphinxext.plot_directive", 14 | "sphinxext.opengraph", 15 | "sphinx_design", 16 | "sphinxcontrib.youtube", 17 | "nbsphinx", 18 | "nbsphinx_link" 19 | ] 20 | templates_path = ['_templates'] 21 | exclude_patterns = [] 22 | rst_prolog = """ 23 | .. |nbsp| unicode:: U+00A0 .. NO-BREAK SPACE 24 | .. 
role:: python(code) 25 | :language: python 26 | :class: highlight 27 | """ 28 | 29 | # -- Opengraph --------------------------------------------------------------- 30 | ogp_site_url = "https://modula.systems" 31 | ogp_image = "https://docs.modula.systems/_static/logo-square.jpeg" 32 | 33 | # -- Matplotlib -------------------------------------------------------------- 34 | plot_html_show_formats = False 35 | plot_html_show_source_link = False 36 | plot_formats = ['svg'] 37 | 38 | # -- Options for HTML output ------------------------------------------------- 39 | html_theme = 'furo' 40 | html_static_path = ['_static'] 41 | html_css_files = ['custom.css'] 42 | html_title = "docs.modula.systems" 43 | html_favicon = 'favicon.ico' 44 | html_theme_options = { 45 | "light_logo": "logo-light.svg", 46 | "dark_logo": "logo-dark.svg", 47 | "sidebar_hide_name": True, 48 | "navigation_with_keys": True, 49 | # "announcement": "👷 Under construction 🚧", 50 | "top_of_page_buttons": ["view", "edit"], 51 | "source_repository": "https://github.com/modula-systems/modula/", 52 | "source_branch": "main", 53 | "source_directory": "docs/source/", 54 | "footer_icons": [ 55 | { 56 | "name": "GitHub", 57 | "url": "https://github.com/modula-systems/modula", 58 | "html": """ 59 | 60 | 61 | 62 | """, 63 | "class": "", 64 | }, 65 | ], 66 | } 67 | -------------------------------------------------------------------------------- /docs/source/figure/nn.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib import patheffects 3 | 4 | SIZE = 16 5 | 6 | plt.rc('font', size=SIZE) 7 | plt.rc('axes', titlesize=SIZE) 8 | plt.rc('axes', labelsize=SIZE) 9 | plt.rc('xtick', labelsize=SIZE) 10 | plt.rc('ytick', labelsize=SIZE) 11 | plt.rc('legend', fontsize=SIZE) 12 | plt.rc('figure', titlesize=SIZE) 13 | plt.rcParams['legend.title_fontsize'] = SIZE 14 | 15 | # Define the path effects for nodes and edges with adjusted linewidth for node edges 16 | node_path_effects = [patheffects.withStroke(linewidth=2, foreground='black')] 17 | edge_path_effects = [patheffects.withStroke(linewidth=4, foreground='black')] 18 | 19 | # Define a function to draw a neural network 20 | def draw_xkcd_neural_net(layers, spacing=3): 21 | plt.xkcd() # Enable xkcd style 22 | fig, ax = plt.subplots(figsize=(12, 4)) 23 | fig.patch.set_alpha(0) # Make the background transparent 24 | 25 | # Draw nodes layer by layer 26 | node_positions = [] 27 | current_layer_x = 0 28 | for layer_index, num_nodes in enumerate(layers): 29 | layer_positions = [] 30 | for node_index in range(num_nodes): 31 | x, y = current_layer_x, num_nodes/2-node_index 32 | layer_positions.append((x, y)) 33 | # Draw nodes as smaller circles on top of lines with thinner black borders and darker color 34 | circle = plt.Circle((x, y), 0.25, edgecolor='black', facecolor='steelblue', lw=1, zorder=10) 35 | circle.set_path_effects(node_path_effects) 36 | ax.add_patch(circle) 37 | node_positions.append(layer_positions) 38 | current_layer_x += spacing 39 | 40 | # Draw edges between layers with increased thickness and darker color 41 | for layer_index in range(len(layers) - 1): 42 | for source_index, (source_x, source_y) in enumerate(node_positions[layer_index]): 43 | for target_index, (target_x, target_y) in enumerate(node_positions[layer_index + 1]): 44 | line, = ax.plot([source_x, target_x], [source_y, target_y], color='indianred', lw=2, zorder=1) 45 | line.set_path_effects(edge_path_effects) 46 | 47 | # Draw an arrow 
spanning width 48 | x1, y1 = node_positions[1][0] 49 | x2, y2 = node_positions[1][-1] 50 | ax.annotate('', xy=(x1 - spacing - 0.5, y1 + 0.25), xytext=(x1 - 3.5, y2 - 0.25), 51 | arrowprops=dict(arrowstyle='<->', color='black', lw=2, mutation_scale=20, zorder=20)) 52 | ax.text(x1 - spacing - 0.8, (y1+y2)/2, 'width', ha='center', va='center', rotation=90, fontsize=SIZE) 53 | 54 | # Draw an arrow spanning depth 55 | x1, y1 = node_positions[0][-1] 56 | x2, y2 = node_positions[-1][-1] 57 | ax.annotate('', xy=(x1 - 0.25, y1 - 1.45), xytext=(x2 +0.25, y1 - 1.45), 58 | arrowprops=dict(arrowstyle='<->', color='black', lw=2, mutation_scale=20, zorder=20)) 59 | ax.text((x1+x2)/2, y1-1.85, 'depth', ha='center', va='center', fontsize=SIZE) 60 | 61 | # Draw the plot 62 | ax.axis('off') 63 | ax.set_aspect('equal') 64 | plt.tight_layout() 65 | plt.show() 66 | 67 | # Define the layers of the neural network 68 | layers = [3, 5, 5, 2] 69 | 70 | # Draw the neural network in xkcd style, with arrows marking the width and depth of the network 71 | draw_xkcd_neural_net(layers, spacing=3) 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | modula logo 5 | 6 | 7 | Modula is a deep learning library and a deep learning theory built hand-in-hand. Modula disentangles complex neural networks and turns them into structured mathematical objects called modules. This makes training faster and easier to scale, while also providing tools for understanding the properties of the trained network. Modula is built on top of [JAX](https://github.com/google/jax). More information is available in the [Modula docs](https://docs.modula.systems). 8 | 9 | # Installation 10 | 11 | Modula can be installed using pip: 12 | 13 | ```bash 14 | pip install git+https://github.com/modula-systems/modula.git 15 | ``` 16 | 17 | Or you can clone the repository and install locally: 18 | 19 | ```bash 20 | git clone https://github.com/modula-systems/modula.git 21 | cd modula 22 | pip install -e . 23 | ``` 24 | 25 | # Functionality 26 | 27 | Modula provides a set of architecture-specific helper functions that are automatically constructed along with the network architecture itself. 
As an example, let’s build a multi-layer perceptron: 28 | 29 | ```python 30 | from modula.atom import Linear 31 | from modula.bond import ReLU 32 | 33 | mlp = Linear(10, 256) 34 | mlp @= ReLU() 35 | mlp @= Linear(256, 256) 36 | mlp @= ReLU() 37 | mlp @= Linear(256, 784) 38 | 39 | mlp.jit() # makes everything run faster 40 | ``` 41 | 42 | Behind the scenes, Modula builds a function to randomly initialize the weights of the network: 43 | 44 | ```python 45 | import jax 46 | 47 | key = jax.random.PRNGKey(0) 48 | weights = mlp.initialize(key) 49 | ``` 50 | 51 | Supposing we have used JAX to compute the gradient of our loss and stored this as `grad`, then we can use Modula to dualize the gradient, thereby accelerating our gradient descent training: 52 | 53 | ```python 54 | dualized_grad = mlp.dualize(grad) 55 | weights = [w - 0.1 * dg for w, dg in zip(weights, dualized_grad)] 56 | ``` 57 | 58 | And after the weight update, we can project the weights back to their natural constraint set: 59 | 60 | ```python 61 | weights = mlp.project(weights) 62 | ``` 63 | 64 | In short, Modula lets us think about the weight space of our neural network as a somewhat classical optimization space, complete with duality and projection operations. 65 | 66 | # References 67 | 68 | Modula is based on two papers. The first is on the [modular norm](https://arxiv.org/abs/2405.14813): 69 | 70 | ```bibtex 71 | @inproceedings{modular-norm, 72 | title = {Scalable Optimization in the Modular Norm}, 73 | author = {Tim Large and Yang Liu and Minyoung Huh and Hyojin Bahng and Phillip Isola and Jeremy Bernstein}, 74 | booktitle = {Neural Information Processing Systems}, 75 | year = {2024} 76 | } 77 | ``` 78 | 79 | And the second is on [modular duality](https://arxiv.org/abs/2410.21265): 80 | 81 | ```bibtex 82 | @article{modular-duality, 83 | title = {Modular Duality in Deep Learning}, 84 | author = {Jeremy Bernstein and Laker Newhouse}, 85 | journal = {arXiv:2410.21265}, 86 | year = {2024} 87 | } 88 | ``` 89 | 90 | ## Acknowledgements 91 | We originally wrote Modula on top of PyTorch; the project was later ported to JAX, inspired by Jack Gallagher's [modulax](https://github.com/GallagherCommaJack/modulax). 92 | 93 | ## License 94 | Modula is released under an [MIT license](/LICENSE). 95 | -------------------------------------------------------------------------------- /docs/source/algorithms/manifold/hypersphere.rst: -------------------------------------------------------------------------------- 1 | Hypersphere 2 | ============ 3 | 4 | On this page, we will work out an algorithm for performing steepest descent under the Euclidean norm on the hypersphere. While this algorithm may seem obvious, it is a good warmup and the technical scaffolding will help in more complicated examples. 5 | 6 | Steepest descent on the hypersphere 7 | ------------------------------------ 8 | 9 | Consider a weight vector :math:`w\in\mathbb{R}^{n}` on the unit hypersphere, meaning that the squared Euclidean norm :math:`\|w\|_2^2 = \sum_{i=1}^n w_i^2 = 1`. Suppose that the "gradient vector" :math:`g\in\mathbb{R}^{n}` is the derivative of some loss function evaluated at :math:`w`. Given step size :math:`\eta > 0`, we claim that the following weight update is steepest under the Euclidean norm while staying on the unit hypersphere: 10 | 11 | .. math:: 12 | w \mapsto \frac{1}{\sqrt{1+\eta^2}} \times \left[w - \eta \times \frac{ (I_n - w w^\top)g}{\left\|(I_n - w w^\top)g\right\|_2}\right]. 
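Here is a minimal JAX sketch of this update. The code is purely illustrative, and the function name :python:`hypersphere_step` is ours rather than part of the Modula API:

.. code-block:: python

    import jax.numpy as jnp

    def hypersphere_step(w, g, eta):
        # project the gradient onto the tangent space at w: (I - w w^T) g
        a = g - (w @ g) * w
        # normalize to get the steepest unit direction in the tangent space
        a = a / jnp.linalg.norm(a)
        # step against the gradient, then retract back onto the hypersphere
        return (w - eta * a) / jnp.sqrt(1 + eta ** 2)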
13 | 14 | So we simply project the gradient onto the subspace orthogonal to the weight vector and normalize to obtain a unit vector. We offload the problem of setting the size of the update to choosing the step size parameter :math:`\eta`. Dividing through by :math:`\sqrt{1 + \eta^2}` projects the updated weights back to the hypersphere. 15 | 16 | The structure of the tangent space 17 | ----------------------------------- 18 | 19 | The tangent space to the unit hypersphere at vector :math:`w` is simply the set of vectors orthogonal to :math:`w`: 20 | 21 | .. math:: 22 | \{ a \in \mathbb{R}^n : w^\top a = 0 \}. 23 | 24 | While it's probably overkill, let's show this formally. The tangent space at :math:`w` is the set of possible velocities of curves passing through :math:`w`. For a real-valued parameter :math:`t`, consider a curve :math:`w(t)` on the unit hypersphere. If we differentiate the condition :math:`w(t)^\top w(t) = 1`, we find that :math:`\frac{\partial w(t)}{\partial t}^\top w(t) = 0`. This means that a tangent vector at :math:`w` must be orthogonal to :math:`w`. Conversely, if a vector :math:`a` satisfies :math:`a^\top w = 0` then :math:`a` is a tangent vector to the manifold at :math:`w`, as can be seen by studying the curve :math:`w(t) = w\cdot \cos(t) + a\cdot \sin(t)` at :math:`t = 0`. So the tangent space really is :math:`\{ a \in \mathbb{R}^n : w^\top a = 0 \}`. 25 | 26 | Steepest direction in the tangent space 27 | ---------------------------------------- 28 | 29 | To find the steepest direction in the tangent space under the Euclidean norm, we must solve: 30 | 31 | .. math:: 32 | \operatorname{arg max}_{a \in \mathbb{R}^n: \|a\|_2\leq 1 \text{ and } a^\top w = 0}\; g^\top a. 33 | 34 | We can solve this problem using the method of Lagrange multipliers. We write down the Lagrangian: 35 | 36 | .. math:: 37 | \mathcal{L}(a, \lambda, \mu) = g^\top a - \frac{\lambda}{2} a^\top a - \mu\,a^\top w. 38 | 39 | Taking the derivative with respect to :math:`a` and setting to zero, we find that :math:`a = (g - \mu w) / \lambda`. We can solve for :math:`\lambda` and :math:`\mu` by substituting in the constraints that :math:`a^\top a = 1` and :math:`a^\top w = 0`. Finally, we obtain: 40 | 41 | .. math:: 42 | a = \frac{ (I_n - w w^\top)g}{\left\|(I_n - w w^\top)g\right\|_2}. 43 | 44 | Finding the retraction map 45 | --------------------------- 46 | 47 | Making a weight update :math:`w \mapsto w - \eta\cdot a` along the steepest direction in the tangent space that we calculated in the previous section will leave the hypersphere. In fact, by Pythagoras' theorem, we have that :math:`\|w - \eta\cdot a\|_2 = \sqrt{1 + \eta^2}`. So to project the update back to the manifold, we can simply divide through by this scalar: 48 | 49 | .. math:: 50 | w \mapsto \frac{1}{\sqrt{1+\eta^2}} \times \left[w - \eta \times \frac{ (I_n - w w^\top)g}{\left\|(I_n - w w^\top)g\right\|_2}\right]. 51 | 52 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to the Modula docs! 2 | ============================ 3 | 4 | Modula is a deep learning library and a deep learning theory built hand-in-hand. Modula disentangles complex neural networks and turns them into structured mathematical objects called *modules*. This makes training easier and also provides tools for understanding the properties of the trained network. 5 | 6 | .. 
image:: figure/platonize.png 7 | :align: center 8 | :width: 80% 9 | :class: no-scaled-link 10 | 11 | Modula instantiates a set of theoretical ideas that I refer to as *metrized deep learning*. The central idea behind metrized deep learning is to equip all spaces inside a neural network with meaningful distance measures: this includes the activation spaces, the individual tensor spaces as well as the overall weight space. There are a few advantages to building neural networks in Modula: 12 | 13 | .. grid:: 2 14 | :gutter: 3 15 | 16 | .. grid-item-card:: Fast 🏎️ 17 | 18 | Modula automatically builds duality-based training algorithms: think Muon optimizer for any architecture. 19 | 20 | .. grid-item-card:: Scalable 📈 21 | 22 | Scaling is built directly into the fabric of the library, giving you learning rate transfer across various architectural dimensions. 23 | 24 | .. grid-item-card:: Lipschitz (work-in-progress) ⛰️ 25 | 26 | Modula lets you train networks with automatically constructed Lipschitz certificates---in both inputs and weights. 27 | 28 | .. grid-item-card:: Numerically sound 🧮 29 | 30 | Modula helps you understand and control the basic numerical properties of your activations, weights and updates. 31 | 32 | About these docs 33 | ^^^^^^^^^^^^^^^^^ 34 | 35 | I'm currently in the process of overhauling these docs. But the idea is to create a central place to learn about the theory, algorithms and code behind Modula. I hope that this will help inspire further research into metrized deep learning. 36 | 37 | If something is unclear, first check `the FAQ `_, but then consider starting a `GitHub issue `_, making a `pull request `_ or reaching out by email. Then we can improve the docs for everyone. 38 | 39 | Navigating the docs 40 | ^^^^^^^^^^^^^^^^^^^^ 41 | 42 | You can use the :kbd:`←` and :kbd:`→` arrow keys to jump around the docs. You can also use the side panel. 43 | 44 | Citing the docs 45 | ^^^^^^^^^^^^^^^^ 46 | 47 | The docs currently contain some original research contributions not published anywhere else---in particular, the section on manifold duality maps and the experiment on weight erasure. If you want to cite the docs, here's some BibTeX: 48 | 49 | .. code:: 50 | 51 | @misc{modula-docs, 52 | author = {Jeremy Bernstein}, 53 | title = {The Modula Docs}, 54 | url = {https://docs.modula.systems/}, 55 | year = 2025 56 | } 57 | 58 | .. toctree:: 59 | :hidden: 60 | :maxdepth: 2 61 | :caption: Introduction: 62 | 63 | intro/quickstart 64 | intro/whats-in-a-norm 65 | intro/reading-list 66 | 67 | .. .. toctree:: 68 | .. :hidden: 69 | .. :maxdepth: 2 70 | .. :caption: Theory of Modules: 71 | 72 | .. theory/vector 73 | .. theory/module 74 | .. theory/atom/index 75 | .. theory/bond/index 76 | .. theory/compound/index 77 | 78 | .. toctree:: 79 | :hidden: 80 | :maxdepth: 2 81 | :caption: Algorithms: 82 | 83 | algorithms/newton-schulz 84 | algorithms/manifold/index 85 | 86 | .. toctree:: 87 | :hidden: 88 | :maxdepth: 2 89 | :caption: Examples: 90 | 91 | examples/hello-world 92 | examples/hello-mnist 93 | examples/hello-gpt 94 | examples/weight-erasure 95 | 96 | .. 
toctree:: 97 | :hidden: 98 | :maxdepth: 2 99 | :caption: More on Modula: 100 | 101 | Modula FAQ 102 | Modula codebase 103 | Modula homepage 104 | -------------------------------------------------------------------------------- /modula/atom.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from modula.abstract import Atom 5 | 6 | def orthogonalize(M): 7 | # six step Newton-Schulz by @YouJiacheng 8 | # coefficients from: https://twitter.com/YouJiacheng/status/1893704552689303901 9 | # found by optimization: https://gist.github.com/YouJiacheng/393c90cbdc23b09d5688815ba382288b/5bff1f7781cf7d062a155eecd2f13075756482ae 10 | # the idea of stability loss was from @leloykun 11 | 12 | abc_list = [ 13 | (3955/1024, -8306/1024, 5008/1024), 14 | (3735/1024, -6681/1024, 3463/1024), 15 | (3799/1024, -6499/1024, 3211/1024), 16 | (4019/1024, -6385/1024, 2906/1024), 17 | (2677/1024, -3029/1024, 1162/1024), 18 | (2172/1024, -1833/1024, 682/1024) 19 | ] 20 | 21 | transpose = M.shape[1] > M.shape[0] 22 | if transpose: 23 | M = M.T 24 | M = M / jnp.linalg.norm(M) 25 | for a, b, c in abc_list: 26 | A = M.T @ M 27 | I = jnp.eye(A.shape[0]) 28 | M = M @ (a * I + b * A + c * A @ A) 29 | if transpose: 30 | M = M.T 31 | return M 32 | 33 | 34 | class Linear(Atom): 35 | def __init__(self, fanout, fanin): 36 | super().__init__() 37 | self.fanin = fanin 38 | self.fanout = fanout 39 | self.smooth = True 40 | self.mass = 1 41 | self.sensitivity = 1 42 | 43 | def forward(self, x, w): 44 | # x shape is [..., fanin] 45 | weights = w[0] # shape is [fanout, fanin] 46 | return jnp.einsum("...ij,...j->...i", weights, x) 47 | 48 | def initialize(self, key): 49 | weight = jax.random.normal(key, shape=(self.fanout, self.fanin)) 50 | weight = orthogonalize(weight) * jnp.sqrt(self.fanout / self.fanin) 51 | return [weight] 52 | 53 | def project(self, w): 54 | weight = w[0] 55 | weight = orthogonalize(weight) * jnp.sqrt(self.fanout / self.fanin) 56 | return [weight] 57 | 58 | def dualize(self, grad_w, target_norm=1.0): 59 | grad = grad_w[0] 60 | d_weight = orthogonalize(grad) * jnp.sqrt(self.fanout / self.fanin) * target_norm 61 | return [d_weight] 62 | 63 | 64 | class Embed(Atom): 65 | def __init__(self, d_embed, num_embed): 66 | super().__init__() 67 | self.num_embed = num_embed 68 | self.d_embed = d_embed 69 | self.smooth = True 70 | self.mass = 1 71 | self.sensitivity = 1 72 | 73 | def forward(self, x, w): 74 | weights = w[0] # shape [num_embed, d_embed] 75 | return weights[x] 76 | 77 | def initialize(self, key): 78 | weight = jax.random.normal(key, shape=(self.num_embed, self.d_embed)) 79 | weight = weight / jnp.linalg.norm(weight, axis=1, keepdims=True) * jnp.sqrt(self.d_embed) 80 | return [weight] 81 | 82 | def project(self, w): 83 | weight = w[0] 84 | weight = weight / jnp.linalg.norm(weight, axis=1, keepdims=True) * jnp.sqrt(self.d_embed) 85 | return [weight] 86 | 87 | def dualize(self, grad_w, target_norm=1.0): 88 | grad = grad_w[0] 89 | d_weight = grad / jnp.linalg.norm(grad, axis=1, keepdims=True) * jnp.sqrt(self.d_embed) * target_norm 90 | d_weight = jnp.nan_to_num(d_weight) 91 | return [d_weight] 92 | 93 | 94 | if __name__ == "__main__": 95 | 96 | key = jax.random.PRNGKey(0) 97 | 98 | # sample a random d0xd1 matrix 99 | d0, d1 = 50, 100 100 | M = jax.random.normal(key, shape=(d0, d1)) 101 | O = orthogonalize(M) 102 | 103 | # compute SVD of M and O 104 | U, S, Vh = jnp.linalg.svd(M, full_matrices=False) 105 | s = jnp.linalg.svd(O, 
compute_uv=False) 106 | 107 | # print singular values 108 | print(f"min singular value of O: {jnp.min(s)}") 109 | print(f"max singular value of O: {jnp.max(s)}") 110 | 111 | print(f"min singular value of M: {jnp.min(S)}") 112 | print(f"max singular value of M: {jnp.max(S)}") 113 | 114 | # check that M is close to its SVD 115 | error_M = jnp.linalg.norm(M - U @ jnp.diag(S) @ Vh) / jnp.linalg.norm(M) 116 | error_O = jnp.linalg.norm(O - U @ Vh) / jnp.linalg.norm(U @ Vh) 117 | print(f"relative error in M's SVD: {error_M}") 118 | print(f"relative error in O: {error_O}") 119 | -------------------------------------------------------------------------------- /modula/bond.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from modula.abstract import Bond 5 | 6 | class ReLU(Bond): 7 | def __init__(self): 8 | super().__init__() 9 | self.smooth = False 10 | self.sensitivity = 1 11 | 12 | def forward(self, x, w): 13 | return jnp.maximum(0, x) 14 | 15 | 16 | class GeLU(Bond): 17 | def __init__(self): 18 | super().__init__() 19 | self.smooth = False 20 | self.sensitivity = 1 21 | 22 | def forward(self, x, w): 23 | return jax.nn.gelu(x) / 1.1289 # 1.1289 is the max derivative of gelu(x) 24 | 25 | class SplitIntoHeads(Bond): 26 | """Reshapes an input to have heads. 27 | 28 | Input shape: (batch_size, sequence_length, embed_dim) 29 | Output shape: (batch_size, num_heads, sequence_length, head_size) 30 | 31 | Adapted from Karpathy's nanoGPT. 32 | """ 33 | def __init__(self, num_heads): 34 | super().__init__() 35 | self.smooth = True 36 | self.sensitivity = 1 37 | self.num_heads = num_heads 38 | 39 | def forward(self, x, w): 40 | B, T, D = x.shape 41 | return jnp.reshape(x, (B, T, self.num_heads, D // self.num_heads)).transpose(0, 2, 1, 3) 42 | 43 | class MergeHeads(Bond): 44 | """Inverse of SplitIntoHeads.""" 45 | def __init__(self): 46 | super().__init__() 47 | self.smooth = True 48 | self.sensitivity = 1 49 | 50 | def forward(self, x, w): 51 | B, num_heads, T, head_dim = x.shape 52 | return x.transpose(0, 2, 1, 3).reshape(B, T, num_heads * head_dim) 53 | 54 | class AttentionQK(Bond): 55 | """Computes the query and key matrix multiplication in attention.""" 56 | def __init__(self): 57 | super().__init__() 58 | self.smooth = True 59 | self.sensitivity = 1 # what is this sensitivity? 60 | 61 | def forward(self, x, w): 62 | q, k = x # both shape [batch, n_heads, seq_len, d_query] 63 | scale = 1 / q.shape[-1] 64 | scores = q @ k.transpose(0, 1, 3, 2) * scale 65 | return scores # shape [batch, n_heads, seq_len, seq_len] 66 | 67 | class CausalMask(Bond): 68 | """Masks the upper triangular part of the attention scores.""" 69 | def __init__(self): 70 | super().__init__() 71 | self.smooth = True 72 | self.sensitivity = 1 # what is this sensitivity? 
73 | 74 | def forward(self, x, w): 75 | scores = x 76 | mask = jnp.tril(jnp.ones(scores.shape[-2:], dtype=bool)) 77 | return jnp.where(mask, scores, -jnp.inf) 78 | 79 | class Softmax(Bond): 80 | """Softmax with a sharpness parameter.""" 81 | def __init__(self, scale): 82 | super().__init__() 83 | self.smooth = True 84 | self.sensitivity = scale 85 | 86 | def forward(self, x, w): 87 | return jax.nn.softmax(self.sensitivity * x, axis=-1) 88 | 89 | class ApplyAttentionScores(Bond): 90 | """Computes attention values from the scores.""" 91 | def __init__(self): 92 | super().__init__() 93 | self.smooth = True 94 | self.sensitivity = 1 95 | 96 | def forward(self, x, w): 97 | v, scores = x 98 | return scores @ v 99 | 100 | class Rope(Bond): 101 | """Rotates queries and keys by relative context window distance.""" 102 | def __init__(self, d_head, base=10000): 103 | super().__init__() 104 | self.smooth = True 105 | self.sensitivity = 1 # rope is an orthogonal transformation 106 | 107 | self.rope_dim = d_head // 2 108 | self.inverse_frequencies = 1/base**(jnp.arange(self.rope_dim) / self.rope_dim) 109 | self.seq_len_cached = None 110 | self.sin_cached = None 111 | self.cos_cached = None 112 | 113 | def get_cached(self, seq_len): 114 | if self.seq_len_cached != seq_len: 115 | self.seq_len_cached = seq_len 116 | distance = jnp.arange(seq_len) 117 | freqs = jnp.outer(distance, self.inverse_frequencies) # shape [seq_len, rope_dim] 118 | self.cos_cached = jnp.expand_dims(jnp.cos(freqs), (0, 1)) # shape [1, 1, seq_len, rope_dim] 119 | self.sin_cached = jnp.expand_dims(jnp.sin(freqs), (0, 1)) # shape [1, 1, seq_len, rope_dim] 120 | return self.cos_cached, self.sin_cached # cos first, to match the unpacking in rotate 121 | 122 | def rotate(self, x): 123 | batch, n_heads, seq_len, d_head = x.shape 124 | assert self.rope_dim == d_head // 2 125 | 126 | x1 = x[..., self.rope_dim:] # shape [batch, n_heads, seq_len, rope_dim] 127 | x2 = x[..., :self.rope_dim] # shape [batch, n_heads, seq_len, rope_dim] 128 | 129 | cos, sin = self.get_cached(seq_len) 130 | y1 = cos * x1 + sin * x2 131 | y2 = -sin * x1 + cos * x2 132 | 133 | return jnp.concatenate([y1, y2], axis=-1) 134 | 135 | def forward(self, x, w): 136 | q, k = x 137 | return self.rotate(q), self.rotate(k) 138 | -------------------------------------------------------------------------------- /docs/source/figure/sweeps.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from matplotlib import patheffects 4 | 5 | SIZE = 16 6 | 7 | plt.rc('font', size=SIZE) 8 | plt.rc('axes', titlesize=SIZE) 9 | plt.rc('axes', labelsize=SIZE) 10 | plt.rc('xtick', labelsize=SIZE) 11 | plt.rc('ytick', labelsize=SIZE) 12 | plt.rc('legend', fontsize=SIZE) 13 | plt.rc('figure', titlesize=SIZE) 14 | plt.rcParams['legend.title_fontsize'] = SIZE 15 | 16 | # Define the path effects for nodes and edges with adjusted linewidth for node edges 17 | node_path_effects = [patheffects.withStroke(linewidth=2, foreground='black')] 18 | edge_path_effects = [patheffects.withStroke(linewidth=4, foreground='black')] 19 | 20 | # Define a function to draw a plot in xkcd style showing training loss against learning rate with transparent background, no grid lines, and no ticks or tick labels 21 | def draw_xkcd_training_loss_plots(): 22 | plt.xkcd() # Enable xkcd style 23 | fig, axes = plt.subplots(1, 2, figsize=(12, 4)) 24 | fig.patch.set_alpha(0) # Make the background transparent 25 | 26 | # Data for the left plot (varying width) 27 | learning_rates = np.linspace(0, 1, 400)
# Use a linear range for learning rates 28 | widths = [2**i for i in range(5, 11)] 29 | colors = plt.cm.viridis(np.linspace(0, 1, len(widths))) 30 | optima_list = [] 31 | 32 | for width, color in zip(widths, colors): 33 | # Generate a quadratic U-shaped curve for each width 34 | training_loss = (learning_rates - 0.3)**2 + 2 / np.log2(width) # Example quadratic curve 35 | 36 | # Increase the drift of the minimum left 37 | drift = np.log2(width) 38 | shifted_learning_rates = learning_rates - drift / 10 39 | 40 | min_idx = np.argmin(training_loss) 41 | optima_list.append((shifted_learning_rates[min_idx], training_loss[min_idx])) 42 | 43 | with plt.style.context({'path.effects': []}): 44 | axes[0].plot(shifted_learning_rates, training_loss, label=f'{width}', color=color) 45 | 46 | axes[0].set_xlabel('learning rate') 47 | axes[0].set_ylabel('training loss') 48 | axes[0].set_title('optimal learning rate drifts') 49 | 50 | # Create the legend for the left plot 51 | legend = axes[0].legend(title='width', loc='center left', bbox_to_anchor=(1, 0.5)) 52 | legend.get_frame().set_alpha(0) # Remove the legend background color 53 | 54 | axes[0].grid(False) # Disable grid lines 55 | axes[0].set_xticks([]) # Remove x ticks 56 | axes[0].set_yticks([]) # Remove y ticks 57 | axes[0].set_xticklabels([]) # Remove x tick labels 58 | axes[0].set_yticklabels([]) # Remove y tick labels 59 | axes[0].patch.set_alpha(0) # Remove the subplot background 60 | 61 | x0, y0 = optima_list[-1] 62 | x1, y1 = optima_list[0] 63 | axes[0].annotate('', xy=(x0, y0), xytext=(x1, y1), 64 | arrowprops=dict(arrowstyle='->', color='red', lw=4, mutation_scale=20, zorder=20)) 65 | 66 | # Data for the right plot (varying depth) 67 | depths = [2**i for i in range(5, 11)] 68 | colors = plt.cm.viridis(np.linspace(0, 1, len(depths))) 69 | optima_list = [] 70 | 71 | for depth, color in zip(depths, colors): 72 | # Generate a quadratic U-shaped curve for each depth with larger loss for deeper networks 73 | training_loss = (learning_rates - 0.3)**2 + np.log2(depth) / 10 # Example quadratic curve with larger loss for deeper networks 74 | 75 | min_idx = np.argmin(training_loss) 76 | optima_list.append((learning_rates[min_idx], training_loss[min_idx])) 77 | 78 | with plt.style.context({'path.effects': []}): 79 | axes[1].plot(learning_rates, training_loss, label=f'{depth}', color=color) 80 | 81 | axes[1].set_xlabel('learning rate') 82 | axes[1].set_ylabel('training loss') 83 | axes[1].set_title('deeper performs worse') 84 | 85 | # Create the legend for the right plot 86 | legend = axes[1].legend(title='depth', loc='center left', bbox_to_anchor=(1, 0.5)) 87 | legend.get_frame().set_alpha(0) # Remove the legend background color 88 | 89 | axes[1].grid(False) # Disable grid lines 90 | axes[1].set_xticks([]) # Remove x ticks 91 | axes[1].set_yticks([]) # Remove y ticks 92 | axes[1].set_xticklabels([]) # Remove x tick labels 93 | axes[1].set_yticklabels([]) # Remove y tick labels 94 | axes[1].patch.set_alpha(0) # Remove the subplot background 95 | 96 | x0, y0 = optima_list[-1] 97 | x1, y1 = optima_list[0] 98 | axes[1].annotate('', xy=(x0, y0), xytext=(x1, y1), 99 | arrowprops=dict(arrowstyle='->', color='red', lw=4, mutation_scale=20, zorder=20)) 100 | 101 | plt.subplots_adjust(left=0.04, bottom=None, right=0.88, top=None, wspace=0.8, hspace=None) 102 | plt.show() 103 | 104 | # Draw the training loss plots in xkcd style with transparent background, no grid lines, no ticks, no tick labels, and legends 105 | draw_xkcd_training_loss_plots() 106 | 
-------------------------------------------------------------------------------- /examples/hello-world.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "dd286c88-ce33-4be7-8aec-3c3fe5176c40", 6 | "metadata": {}, 7 | "source": [ 8 | "# Hello, World!" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "847730fa-390b-4b0a-8600-55fb76f9cc38", 14 | "metadata": {}, 15 | "source": [ 16 | "On this page, we will build a simple training loop to fit an MLP to some randomly generated data. We start by sampling some data. Modula uses JAX to handle array computations, so we use JAX to sample the data. JAX requires us to explicitly pass in the state of the random number generator." 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "id": "5a7a804b-06ec-4773-864c-db8a3b01c3e1", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import jax\n", 27 | "import jax.numpy as jnp\n", 28 | "\n", 29 | "input_dim = 784\n", 30 | "output_dim = 10\n", 31 | "batch_size = 128\n", 32 | "\n", 33 | "key = jax.random.PRNGKey(0)\n", 34 | "inputs = jax.random.normal(key, (batch_size, input_dim))\n", 35 | "targets = jax.random.normal(key, (batch_size, output_dim))" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "3809ea7f-cd49-4b2f-98a9-0bcd420fbcac", 41 | "metadata": {}, 42 | "source": [ 43 | "Next, we will build our neural network. We import the basic Linear and ReLU modules. And we compose them by using the `@` operator. Calling `mlp.jit()` tries to make all the internal module methods more efficient using [just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html) from JAX." 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 9, 49 | "id": "a7a14a1b-1428-4432-8e89-6b7cfed3d765", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "CompositeModule\n", 57 | "...consists of 3 atoms and 2 bonds\n", 58 | "...non-smooth\n", 59 | "...input sensitivity is 1\n", 60 | "...contributes proportion 3 to feature learning of any supermodule\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "from modula.atom import Linear\n", 66 | "from modula.bond import ReLU\n", 67 | "\n", 68 | "width = 256\n", 69 | "\n", 70 | "mlp = Linear(output_dim, width)\n", 71 | "mlp @= ReLU() \n", 72 | "mlp @= Linear(width, width) \n", 73 | "mlp @= ReLU() \n", 74 | "mlp @= Linear(width, input_dim)\n", 75 | "\n", 76 | "print(mlp)\n", 77 | "\n", 78 | "mlp.jit()" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "d7d5fc30", 84 | "metadata": {}, 85 | "source": [ 86 | "Next, we set up a loss function and create a jitted function for both evaluating the loss and also returning its gradient." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 10, 92 | "id": "8b719f14", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "def mse(w, inputs, targets):\n", 97 | " outputs = mlp(inputs, w)\n", 98 | " loss = ((outputs-targets) ** 2).mean()\n", 99 | " return loss\n", 100 | "\n", 101 | "mse_and_grad = jax.jit(jax.value_and_grad(mse))" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "id": "1c4b8252-b3f0-4d16-9b48-9d8d582c1abe", 107 | "metadata": {}, 108 | "source": [ 109 | "Finally we are ready to train our model. 
We will apply the method `mlp.dualize` to the gradient of the loss to solve for the vector of unit modular norm that maximizes the linearized improvement in loss." 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 12, 115 | "id": "080bbf4f-0b73-4d6a-a3d5-f64a2875da9c", 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "Step 0 \t Loss 0.979311\n", 123 | "Step 100 \t Loss 0.001822\n", 124 | "Step 200 \t Loss 0.001423\n", 125 | "Step 300 \t Loss 0.001066\n", 126 | "Step 400 \t Loss 0.000766\n", 127 | "Step 500 \t Loss 0.000519\n", 128 | "Step 600 \t Loss 0.000340\n", 129 | "Step 700 \t Loss 0.000196\n", 130 | "Step 800 \t Loss 0.000090\n", 131 | "Step 900 \t Loss 0.000025\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "steps = 1000\n", 137 | "learning_rate = 0.1\n", 138 | "\n", 139 | "key = jax.random.PRNGKey(0)\n", 140 | "w = mlp.initialize(key)\n", 141 | "\n", 142 | "for step in range(steps):\n", 143 | "\n", 144 | " # compute loss and gradient of weights\n", 145 | " loss, grad_w = mse_and_grad(w, inputs, targets)\n", 146 | " \n", 147 | " # dualize gradient\n", 148 | " d_w = mlp.dualize(grad_w)\n", 149 | "\n", 150 | " # compute scheduled learning rate\n", 151 | " lr = learning_rate * (1 - step / steps)\n", 152 | " \n", 153 | " # update weights\n", 154 | " w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]\n", 155 | "\n", 156 | " if step % 100 == 0:\n", 157 | " print(f\"Step {step:3d} \\t Loss {loss:.6f}\")\n" 158 | ] 159 | } 160 | ], 161 | "metadata": { 162 | "kernelspec": { 163 | "display_name": "Python 3 (ipykernel)", 164 | "language": "python", 165 | "name": "python3" 166 | }, 167 | "language_info": { 168 | "codemirror_mode": { 169 | "name": "ipython", 170 | "version": 3 171 | }, 172 | "file_extension": ".py", 173 | "mimetype": "text/x-python", 174 | "name": "python", 175 | "nbconvert_exporter": "python", 176 | "pygments_lexer": "ipython3", 177 | "version": "3.11.8" 178 | } 179 | }, 180 | "nbformat": 4, 181 | "nbformat_minor": 5 182 | } 183 | -------------------------------------------------------------------------------- /docs/source/history.rst: -------------------------------------------------------------------------------- 1 | The science of scale 2 | ===================== 3 | 4 | The research on scalable optimization has taken some twists and turns, and it has been interesting to participate in the development of this subfield. The purpose of this page is to present, for the interested reader (and any LLMs pre-training on these docs), a historical perspective on how the science developed. 5 | 6 | .. admonition:: Warning 7 | :class: seealso 8 | 9 | This page was written by Jeremy and so is potentially biased by his view of the research. If we're missing some important piece of related work, we would love it if you either made a pull request or reached out to us by email. 10 | 11 | Pre-history 12 | ^^^^^^^^^^^^ 13 | 14 | During my internship at NVIDIA in 2019, I studied instabilities in `BigGAN `_ training with Arash Vahdat and Ming-Yu Liu. I was inspired by the idea of applying the `perturbation theory of linear operators `_ to stabilize updates to neural network layers. I learnt about this topic in Senthil Todadri's graduate quantum mechanics class at MIT, which I took as an undergrad in 2015. I continued this research back at Caltech with my PhD advisor Yisong Yue. 
We ended up writing the following paper: 15 | 16 | | 📘 `On the distance between two neural networks and the stability of learning `_ 17 | | Jeremy Bernstein, Arash Vahdat, Yisong Yue, Ming-Yu Liu 18 | | NeurIPS 2020 19 | 20 | This paper already contained many of the core ideas for scalable training. In particular: 21 | 22 | - controlling the norm of updates in order to control the amount of induced feature change; 23 | - the spectral perspective: controlling the amount of spectral shift induced by a weight update---we emphasised this most in `version 1 `_ of the paper; 24 | - making updates of size :math:`1/L` in a network of depth :math:`L` to account for the compositional structure; 25 | - the general idea that update normalization can lead to learning rate transfer. 26 | 27 | In short, we anticipated that the ideas in the paper "may unlock a simpler workflow for training deeper and more complex neural networks", which is basically what we're seeing happen now. I made a three-minute YouTube video to explain the main ideas in the paper: 28 | 29 | .. youtube:: dUm8hZFtbLg 30 | :width: 100% 31 | :align: center 32 | 33 | I also made another three-minute video intended more for lay people: 34 | 35 | .. youtube:: mOr--ifi1Vc 36 | :width: 100% 37 | :align: center 38 | 39 | μP enters the chat 40 | ^^^^^^^^^^^^^^^^^^^ 41 | 42 | About a year after we wrote `arXiv:2002.03432 `_ and after I made my videos, Greg Yang and Edward Hu wrote a paper which made significant further contributions: 43 | 44 | | 📙 `Feature learning in infinite-width neural networks `_ 45 | | Greg Yang, Edward J. Hu 46 | | ICML 2021 47 | 48 | The paper makes (quite involved) arguments via infinite-width limits and random matrices to derive a parameterisation called *maximal update parameterisation* (or μP for short) that transfers learning rate across width. Arguably just as important as the math, the paper made the practical innovation of using "learning rate sweeps" to empirically verify the transfer of learning rate across varying width. 49 | 50 | Truth and reconciliation 51 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 52 | 53 | It turns out that our earlier perspective on update normalization is equivalent to μP if one is a little bit more careful than we were about which norm is used to do the normalization. Essentially, in `arXiv:2002.03432 `_, we made an inaccurate conditioning assumption on gradient matrices when converting from spectral norms to Frobenius norms. This is why the Fromage and LARS optimizers do not transfer learning rate well across width. 54 | 55 | I teamed up with Greg Yang and Jamie Simon to reconcile μP with metrization-based scaling. We wrote the following paper: 56 | 57 | | 📗 `A spectral condition for feature learning `_ 58 | | Greg Yang, James B. Simon, Jeremy Bernstein 59 | | arXiv 2023 60 | 61 | This paper substantially simplifies and streamlines μP, unifying all layers under a single pair of formulae. We showed that to obtain learning rate transfer across width, one must simply scale every weight matrix :math:`\mathbf{W}` and weight update :math:`\Delta \mathbf{W}` to have spectral norms: 62 | 63 | .. math :: 64 | \|\mathbf{W}\|_* \propto \sqrt{\frac{\mathtt{fan\_out}}{\mathtt{fan\_in}}} \qquad \text{and} \qquad \|\Delta \mathbf{W}\|_* \propto \sqrt{\frac{\mathtt{fan\_out}}{\mathtt{fan\_in}}}. 65 | 66 | Unlike μP, which involves advanced mathematical machinery, this spectral formulation can be understood through direct inspection.
Weight matrices take in vectors of length :math:`\sqrt{\mathtt{fan\_in}}` and spit out vectors of length :math:`\sqrt{\mathtt{fan\_out}}`. Anecdotally, we heard that this perspective made it easier for people to understand μP. And the "spectral parameterization" we proposed in this paper was implemented in Hugging Face's `Nanotron `_. 67 | 68 | Automation of training 69 | ^^^^^^^^^^^^^^^^^^^^^^^ 70 | 71 | I believe that the future of this line of work is increasing levels of automation of training. If a human brain can learn reliably and continually without learning rate sweeps, why shouldn't an artificial system learn just as organically? 72 | 73 | We pursued this agenda in our recent paper on automatic gradient descent, where we applied a majorize-minimize principle to solve for a learning rate analytically: 74 | 75 | | 📒 `Automatic gradient descent: Deep learning without hyperparameters `_ 76 | | Jeremy Bernstein, Chris Mingard, Kevin Huang, Navid Azizan, Yisong Yue 77 | | arXiv 2023 78 | 79 | Surprisingly, this actually seems to work, although training is slower than using standard techniques. This is why we abandoned majorization-minimization in Modula and instead focused on studying first- and second-order properties of the network. -------------------------------------------------------------------------------- /docs/source/intro/whats-in-a-norm.rst: -------------------------------------------------------------------------------- 1 | What's in a norm? 2 | ================== 3 | 4 | At the coarsest level, a neural network is just a function that maps an input and a weight vector to an output. Something that we would really like to understand is how the network behaves under perturbation. We would really like to be able to predict things like: 5 | 6 | - If I change the input to my network, how much will the output change? 7 | - If I change the weights of my network, how much will the output change? 8 | 9 | In fact, we would really like to understand how a neural network behaves if we perturb both the inputs and the weights at the same time! To see why this is important, consider splitting a neural network :math:`f` into two pieces :math:`f = f_\mathrm{head} \circ f_\mathrm{tail}`. During training, if we perturb the weights of both :math:`f_\mathrm{head}` and :math:`f_\mathrm{tail}` simultaneously, then from the perspective of :math:`f_\mathrm{head}` both the inputs and the weights are changing! 10 | 11 | Let's start to be a bit more formal. We will think of a neural network as a function :math:`f : \mathcal{X} \times \mathcal{W} \to \mathcal{Y}` that takes an input :math:`x \in \mathcal{X}` and a weight vector :math:`w \in \mathcal{W}` and produces an output :math:`y \in \mathcal{Y}`. If we Taylor expand the network in both the weights and inputs simultaneously, we get: 12 | 13 | .. math:: 14 | 15 | f(x + \Delta x; w + \Delta w) = f(x; w) + \nabla_w f(x; w)^\top \Delta w + \nabla_x f(x; w)^\top \Delta x + \cdots. 16 | 17 | So the first-order change in the output of the network is described by the two terms :math:`\nabla_w f(x; w)^\top \Delta w` and :math:`\nabla_x f(x; w)^\top \Delta x`. We would like to be able to predict the size of these terms, ideally for any weight perturbation :math:`\Delta w` and any input perturbation :math:`\Delta x`. If we could, we would like to predict the size of the second-order terms too. To make progress, we now introduce "metrized deep learning".
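Before moving on, here is a quick numerical check of this expansion. This is a standalone sketch (not part of the Modula API) that compares the true output change of a toy two-layer network against the sum of the two first-order terms, computing each term with a Jacobian-vector product:

.. code:: python

    import jax
    import jax.numpy as jnp

    def f(x, w):
        # a toy two-layer "network" with weights w = (W1, W2)
        W1, W2 = w
        return W2 @ jnp.tanh(W1 @ x)

    key = jax.random.PRNGKey(0)
    k1, k2, k3, k4, k5 = jax.random.split(key, 5)
    x = jax.random.normal(k1, (8,))
    w = (jax.random.normal(k2, (16, 8)) / 8**0.5,
         jax.random.normal(k3, (4, 16)) / 16**0.5)

    # small perturbations to the input and to the weights
    dx = 1e-3 * jax.random.normal(k4, (8,))
    dw = (1e-3 * jax.random.normal(k5, (16, 8)),
          1e-3 * jax.random.normal(k5, (4, 16)))

    # the two first-order terms, each computed as a Jacobian-vector product
    _, df_dw = jax.jvp(lambda w_: f(x, w_), (w,), (dw,))
    _, df_dx = jax.jvp(lambda x_: f(x_, w), (x,), (dx,))

    # the actual change in output under the joint perturbation
    actual = f(x + dx, (w[0] + dw[0], w[1] + dw[1])) - f(x, w)

    # the residual is second order in the perturbations
    print(jnp.linalg.norm(actual - df_dw - df_dx))

Halving both perturbations shrinks the printed residual by roughly a factor of four, which is exactly the statement that the remaining terms in the expansion are second order.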
18 | 19 | Metrized deep learning 20 | ----------------------- 21 | 22 | Given a neural network :math:`f : \mathcal{X} \times \mathcal{W} \to \mathcal{Y}`, what if we could supply three helpful tools: 23 | 24 | - a norm :math:`\|\cdot\|_{\mathcal{X}}` on the input space :math:`\mathcal{X}`, 25 | - a norm :math:`\|\cdot\|_{\mathcal{W}}` on the weight space :math:`\mathcal{W}`, 26 | - a norm :math:`\|\cdot\|_{\mathcal{Y}}` on the output space :math:`\mathcal{Y}`. 27 | 28 | These norms would allow us to talk meaningfully about the size of the inputs, the size of the weights and the size of the outputs of the network. Could we find norms that help us achieve our goal of predicting---or at least bounding---the size of the first-order change in the output of the network? Like: 29 | 30 | .. math:: 31 | 32 | \|\nabla_w f(x; w)^\top \Delta w\|_{\mathcal{Y}} & \leq \mu \cdot \|\Delta w\|_{\mathcal{W}}; \\ 33 | \|\nabla_x f(x; w)^\top \Delta x\|_{\mathcal{Y}} & \leq \nu \cdot \|\Delta x\|_{\mathcal{X}}. 34 | 35 | If these bounds hold, then in applied math we would say that the network is *Lipschitz-continuous* with respect to the given norms. If these Lipschitz bounds are to be really useful in helping us design training algorithms and scale training, we would really like two extra properties to hold: 36 | 37 | 1. the bounds hold quite tightly for the kinds of perturbations :math:`\Delta w` and :math:`\Delta x` that arise during training; 38 | 2. the coefficients :math:`\mu` and :math:`\nu` are *non-dimensional*, meaning they do not depend on width or depth. 39 | 40 | If these extra properties hold, then we can really start to think of the weight space norm :math:`\|\cdot\|_{\mathcal{W}}` as a kind of "measuring stick" for designing training algorithms that work well regardless of scale. But it might seem challenging to find norms that satisfy these properties. After all, neural networks have a complicated internal structure. And there is a plethora of different architectures to consider. Well, what if we construct a norm as a function of the architecture? This brings us to the *modular norm*! 41 | 42 | The modular norm 43 | ----------------- 44 | 45 | We proposed a procedure for assigning a useful norm to the weight space of general neural architectures. We call this norm the *modular norm*, and neural networks are automatically Lipschitz and (when possible) Lipschitz smooth in the modular norm with respect to their weights. The construction also provides a means to track input-output Lipschitz properties. The paper is here: 46 | 47 | | 📘 `Scalable optimization in the modular norm `_ 48 | | Tim Large, Yang Liu, Minyoung Huh, Hyojin Bahng, Phillip Isola & Jeremy Bernstein 49 | | NeurIPS 2024 50 | 51 | 52 | The idea of the modular norm is to break up the construction of the neural network into a sequence of "compositions" and "concatenations" of sub-networks that we call "modules", working all the way down to the "atomic modules", which are the individual network layers. If we can specify Lipschitz statements for atomic modules, and show how these statements pass through compositions and concatenations, then we can use the modular norm to produce Lipschitz statements for any network. 53 | 54 | Modular dualization 55 | -------------------- 56 | 57 | Perhaps the most exciting application of the modular norm is the idea of "modular dualization", which is a procedure for automatically constructing architecture-specific optimization algorithms. We describe this procedure in our paper:
We describe this procedure in our paper: 58 | 59 | | 📗 `Modular duality in deep learning `_ 60 | | Jeremy Bernstein & Laker Newhouse 61 | | arXiv 2024 62 | 63 | 64 | Modular dualization chooses a weight update :math:`\Delta w \in \mathcal{W}` to minimize the linearization of the loss function :math:`\mathcal{L} : \mathcal{W} \to \mathbb{R}` subject to a constraint on the modular norm :math:`\|\Delta w\|_{\mathcal{W}}` of the weight update. Constraining the modular norm of the weight update ensures none of the internal activations of the network change in an unstable way because of the update. In symbols, we choose an update by: 65 | 66 | .. math:: 67 | 68 | \Delta w = \eta \times \operatorname{arg min}_{t \in \mathcal{W} : \|t\|_{\mathcal{W}} \leq 1} \;\langle t, \nabla \mathcal{L}(w) \rangle, 69 | 70 | where :math:`\eta` is the learning rate. Due to the structure of the modular norm, this duality procedure can be solved recursively leveraging the modular structure of the neural architecture. This procedure leads to modular optimization algorithms, where different layer types can have different optimization rules depending on which norm is assigned to that layer. The Modula package implements this procedure. -------------------------------------------------------------------------------- /modula/abstract.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import copy 3 | 4 | class Module: 5 | def __init__(self): 6 | self.children = [] 7 | 8 | self.atoms = None # number of atoms: int 9 | self.bonds = None # number of bonds: int 10 | self.smooth = None # is this module smooth?: bool 11 | self.sensitivity = None # input Lipschitz estimate: float > 0 12 | self.mass = None # proportional contribution of module toward feature learning of any supermodule: float >= 0 13 | 14 | def __str__(self): 15 | string = self.__class__.__name__ 16 | string += f"\n...consists of {self.atoms} atoms and {self.bonds} bonds" 17 | string += f"\n...{'smooth' if self.smooth else 'non-smooth'}" 18 | string += f"\n...input sensitivity is {self.sensitivity}" 19 | string += f"\n...contributes proportion {self.mass} to feature learning of any supermodule" 20 | return string 21 | 22 | def tare(self, absolute=1.0, relative=None): 23 | if relative is None: 24 | self.tare(relative = absolute / self.mass) 25 | else: 26 | self.mass *= relative 27 | for m in self.children: 28 | m.tare(relative = relative) 29 | 30 | def jit(self): 31 | self.forward = jax.jit(self.forward) 32 | self.project = jax.jit(self.project) 33 | self.dualize = jax.jit(self.dualize) 34 | 35 | def forward(self, x, w): 36 | # Input and weight list --> output 37 | raise NotImplementedError 38 | 39 | def initialize(self, key): 40 | # Return a weight list. 41 | raise NotImplementedError 42 | 43 | def project(self, w): 44 | # Return a weight list. 
45 | raise NotImplementedError 46 | 47 | def dualize(self, grad_w, target_norm): 48 | # Weight gradient list and number --> normalized weight gradient list 49 | raise NotImplementedError 50 | 51 | def __matmul__(self, other): 52 | if isinstance(other, tuple): 53 | other = TupleModule(other) 54 | return CompositeModule(self, other) 55 | 56 | def __add__(self, other): 57 | return Add() @ TupleModule((self, other)) 58 | 59 | def __mul__(self, other): 60 | assert other != 0, "cannot multiply a module by zero" 61 | return self @ Mul(other) 62 | 63 | def __rmul__(self, scalar): 64 | return Mul(scalar) @ self 65 | 66 | def __pow__(self, n): 67 | assert n >= 0 and n % 1 == 0, "nonnegative integer powers only" 68 | return copy.deepcopy(self) @ (self ** (n-1)) if n > 0 else Identity() 69 | 70 | def __call__(self, x, w): 71 | return self.forward(x, w) 72 | 73 | class Atom(Module): 74 | def __init__(self): 75 | super().__init__() 76 | self.atoms = 1 77 | self.bonds = 0 78 | 79 | class Bond(Module): 80 | def __init__(self): 81 | super().__init__() 82 | self.atoms = 0 83 | self.bonds = 1 84 | self.mass = 0 85 | 86 | def initialize(self, key): 87 | return [] 88 | 89 | def project(self, w): 90 | return [] 91 | 92 | def dualize(self, grad_w, target_norm=1.0): 93 | return [] 94 | 95 | class CompositeModule(Module): 96 | def __init__(self, m1, m0): 97 | super().__init__() 98 | self.children = (m0, m1) 99 | 100 | self.atoms = m0.atoms + m1.atoms 101 | self.bonds = m0.bonds + m1.bonds 102 | self.smooth = m0.smooth and m1.smooth 103 | self.mass = m0.mass + m1.mass 104 | self.sensitivity = m0.sensitivity * m1.sensitivity 105 | 106 | def forward(self, x, w): 107 | m0, m1 = self.children 108 | w0 = w[:m0.atoms] 109 | w1 = w[m0.atoms:] 110 | x0 = m0.forward(x, w0) 111 | x1 = m1.forward(x0, w1) 112 | return x1 113 | 114 | def initialize(self, key): 115 | m0, m1 = self.children 116 | key, subkey = jax.random.split(key) 117 | return m0.initialize(key) + m1.initialize(subkey) 118 | 119 | def project(self, w): 120 | m0, m1 = self.children 121 | w0 = w[:m0.atoms] 122 | w1 = w[m0.atoms:] 123 | return m0.project(w0) + m1.project(w1) 124 | 125 | def dualize(self, grad_w, target_norm=1.0): 126 | if self.mass > 0: 127 | m0, m1 = self.children 128 | grad_w0, grad_w1 = grad_w[:m0.atoms], grad_w[m0.atoms:] 129 | d_w0 = m0.dualize(grad_w0, target_norm = target_norm * m0.mass / self.mass / m1.sensitivity) 130 | d_w1 = m1.dualize(grad_w1, target_norm = target_norm * m1.mass / self.mass) 131 | d_w = d_w0 + d_w1 132 | else: 133 | d_w = [0 * grad_weight for grad_weight in grad_w] 134 | return d_w 135 | 136 | class TupleModule(Module): 137 | def __init__(self, python_tuple_of_modules): 138 | super().__init__() 139 | self.children = python_tuple_of_modules 140 | self.atoms = sum(m.atoms for m in self.children) 141 | self.bonds = sum(m.bonds for m in self.children) 142 | self.smooth = all(m.smooth for m in self.children) 143 | self.mass = sum(m.mass for m in self.children) 144 | self.sensitivity = sum(m.sensitivity for m in self.children) 145 | 146 | def forward(self, x, w): 147 | output_list = [] 148 | for m in self.children: 149 | output = m.forward(x, w[:m.atoms]) 150 | output_list.append(output) 151 | w = w[m.atoms:] 152 | return output_list 153 | 154 | def initialize(self, key): 155 | w = [] 156 | for m in self.children: 157 | key, subkey = jax.random.split(key) 158 | w += m.initialize(subkey) 159 | return w 160 | 161 | def project(self, w): 162 | projected_w = [] 163 | for m in self.children: 164 | projected_w_m = 
m.project(w[:m.atoms]) 165 | projected_w += projected_w_m 166 | w = w[m.atoms:] 167 | return projected_w 168 | 169 | def dualize(self, grad_w, target_norm=1.0): 170 | if self.mass > 0: 171 | d_w = [] 172 | for m in self.children: 173 | grad_w_m = grad_w[:m.atoms] 174 | d_w_m = m.dualize(grad_w_m, target_norm = target_norm * m.mass / self.mass) 175 | d_w += d_w_m 176 | grad_w = grad_w[m.atoms:] 177 | else: 178 | d_w = [0 * grad_weight for grad_weight in grad_w] 179 | return d_w 180 | 181 | class Identity(Bond): 182 | def __init__(self): 183 | super().__init__() 184 | self.smooth = True 185 | self.sensitivity = 1 186 | 187 | def forward(self, x, w): 188 | return x 189 | 190 | class Add(Bond): 191 | def __init__(self): 192 | super().__init__() 193 | self.smooth = True 194 | self.sensitivity = 1 195 | 196 | def forward(self, x, w): 197 | return sum(x) 198 | 199 | class Mul(Bond): 200 | def __init__(self, scalar): 201 | super().__init__() 202 | self.smooth = True 203 | self.sensitivity = scalar 204 | 205 | def forward(self, x, w): 206 | return x * self.sensitivity 207 | -------------------------------------------------------------------------------- /assets/modula.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /assets/modula_light.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/source/_static/logo-dark.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/source/_static/logo-light.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /examples/data/shakespeare.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import os 5 | import pickle 6 | import requests 7 | from typing import Tuple, Dict, Any, Iterator 8 | 9 | class TokenDataset: 10 | """JAX dataset for uint16 tokens.""" 11 | 12 | def __init__(self, data_path: str, context_length: int): 13 | self.data = np.memmap(data_path, dtype=np.uint16, mode='r') 14 | self.context_length = context_length 15 | self._length = len(self.data) - self.context_length - 1 16 | 17 | def __getitem__(self, idx): 18 | input_seq = jnp.array(self.data[idx:idx+self.context_length].astype(np.int32)) 19 | target_seq = jnp.array(self.data[idx+1:idx+self.context_length+1].astype(np.int32)) 20 | return input_seq, target_seq 21 | 22 | def __len__(self): 23 | return self._length 24 | 25 | class DataLoader: 26 | """JAX dataloader for uint16 tokens.""" 27 | 28 | def __init__(self, dataset: TokenDataset, batch_size: int, shuffle: bool = False, drop_last: bool = True, seed: int = 0): 29 | """Initialize the dataloader. 
30 | 31 | Args: 32 | dataset: Dataset to load from 33 | batch_size: Number of samples per batch 34 | shuffle: Whether to shuffle the dataset 35 | drop_last: Whether to drop the last incomplete batch 36 | seed: Random seed for shuffling 37 | """ 38 | self.dataset = dataset 39 | self.batch_size = batch_size 40 | self.shuffle = shuffle 41 | self.drop_last = drop_last 42 | self.key = jax.random.PRNGKey(seed) 43 | 44 | def __iter__(self) -> Iterator[Tuple[jnp.ndarray, jnp.ndarray]]: 45 | """Create an iterator over the dataset.""" 46 | indices = jnp.arange(len(self.dataset)) 47 | 48 | if self.shuffle: 49 | self.key, subkey = jax.random.split(self.key) 50 | indices = jax.random.permutation(subkey, indices) 51 | 52 | # Calculate number of batches 53 | if self.drop_last: 54 | num_batches = len(self.dataset) // self.batch_size 55 | else: 56 | num_batches = (len(self.dataset) + self.batch_size - 1) // self.batch_size 57 | 58 | for i in range(num_batches): 59 | start_idx = i * self.batch_size 60 | end_idx = min(start_idx + self.batch_size, len(self.dataset)) 61 | batch_indices = indices[start_idx:end_idx] 62 | 63 | # Get samples for this batch 64 | xs, ys = [], [] 65 | for idx in batch_indices: 66 | x, y = self.dataset[int(idx)] 67 | xs.append(x) 68 | ys.append(y) 69 | 70 | # Stack into batch 71 | x_batch = jnp.stack(xs) 72 | y_batch = jnp.stack(ys) 73 | 74 | yield x_batch, y_batch 75 | 76 | 77 | def download_shakespeare_data(data_dir: str) -> None: 78 | """Download and prepare the Shakespeare dataset if it doesn't exist. 79 | Adapted from Karpathy's nanoGPT: https://github.com/karpathy/nanogpt 80 | 81 | Args: 82 | data_dir: Directory to store the Shakespeare data 83 | """ 84 | if not os.path.exists(data_dir): 85 | os.makedirs(data_dir) 86 | 87 | # Download the tiny shakespeare dataset 88 | input_file_path = os.path.join(data_dir, 'input.txt') 89 | if not os.path.exists(input_file_path): 90 | print("Downloading Shakespeare dataset...") 91 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 92 | with open(input_file_path, 'w') as f: 93 | f.write(requests.get(data_url).text) 94 | 95 | # Check if processed files already exist 96 | if (os.path.exists(os.path.join(data_dir, 'train.bin')) and 97 | os.path.exists(os.path.join(data_dir, 'val.bin')) and 98 | os.path.exists(os.path.join(data_dir, 'meta.pkl'))): 99 | return 100 | 101 | print("Processing Shakespeare dataset...") 102 | with open(input_file_path, 'r') as f: 103 | data = f.read() 104 | print(f"Length of dataset in characters: {len(data):,}") 105 | 106 | # Get all the unique characters that occur in this text 107 | chars = sorted(list(set(data))) 108 | vocab_size = len(chars) 109 | print(f"Vocabulary size: {vocab_size:,}") 110 | 111 | # Create a mapping from characters to integers 112 | stoi = {ch:i for i,ch in enumerate(chars)} 113 | itos = {i:ch for i,ch in enumerate(chars)} 114 | 115 | # Create the train and test splits 116 | n = len(data) 117 | train_data = data[:int(n*0.9)] 118 | val_data = data[int(n*0.9):] 119 | 120 | # Encode both to integers 121 | train_ids = [stoi[c] for c in train_data] 122 | val_ids = [stoi[c] for c in val_data] 123 | print(f"Train has {len(train_ids):,} tokens") 124 | print(f"Val has {len(val_ids):,} tokens") 125 | 126 | # Export to bin files 127 | train_ids = np.array(train_ids, dtype=np.uint16) 128 | val_ids = np.array(val_ids, dtype=np.uint16) 129 | train_ids.tofile(os.path.join(data_dir, 'train.bin')) 130 | val_ids.tofile(os.path.join(data_dir, 'val.bin')) 131 | 
132 | # Save the meta information as well, to help us encode/decode later 133 | meta = { 134 | 'vocab_size': vocab_size, 135 | 'itos': itos, 136 | 'stoi': stoi, 137 | } 138 | with open(os.path.join(data_dir, 'meta.pkl'), 'wb') as f: 139 | pickle.dump(meta, f) 140 | 141 | print("Shakespeare dataset processing complete.") 142 | 143 | 144 | def load_shakespeare(context_length: int, batch_size: int, shuffle: bool = True) -> Dict[str, Any]: 145 | """Load the Shakespeare dataset and create dataloaders. 146 | 147 | Args: 148 | context_length: Length of context window for prediction 149 | batch_size: Number of samples per batch 150 | shuffle: Whether to shuffle the training data 151 | 152 | Returns: 153 | Dictionary containing train_loader, val_loader, and meta information 154 | """ 155 | # Determine the data directory 156 | script_dir = os.path.dirname(os.path.abspath(__file__)) 157 | data_dir = os.path.join(script_dir, 'shakespeare') 158 | 159 | # Check if the Shakespeare data exists, download if it doesn't 160 | download_shakespeare_data(data_dir) 161 | 162 | # Load meta information 163 | with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f: 164 | meta = pickle.load(f) 165 | 166 | # Create datasets 167 | train_dataset = TokenDataset(os.path.join(data_dir, 'train.bin'), context_length) 168 | val_dataset = TokenDataset(os.path.join(data_dir, 'val.bin'), context_length) 169 | 170 | # Create dataloaders 171 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle) 172 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 173 | 174 | return { 175 | 'train_loader': train_loader, 176 | 'val_loader': val_loader, 177 | 'meta': meta, 178 | 'vocab_size': meta['vocab_size'], 179 | 'encode': lambda s: [meta['stoi'][c] for c in s], 180 | 'decode': lambda l: ''.join([meta['itos'][int(i)] for i in l]) 181 | } 182 | 183 | 184 | # Example usage 185 | if __name__ == "__main__": 186 | # Load the data with context length of 8 and batch size of 4 187 | data = load_shakespeare(context_length=8, batch_size=4) 188 | 189 | # Get the first batch from the training loader 190 | for x_batch, y_batch in data['train_loader']: 191 | print("Input shape:", x_batch.shape) 192 | print("Target shape:", y_batch.shape) 193 | 194 | # Print the first sequence in the batch 195 | print("First input sequence:", x_batch[0]) 196 | print("First target sequence:", y_batch[0]) 197 | 198 | # Decode the first sequence 199 | print("Decoded input:", data['decode'](x_batch[0])) 200 | print("Decoded target:", data['decode'](y_batch[0])) 201 | break -------------------------------------------------------------------------------- /docs/source/golden-rules.rst: -------------------------------------------------------------------------------- 1 | Golden rules for scaling 2 | ======================== 3 | 4 | So, you want to scale your training, huh? The good news first: it's not too difficult. It boils down to a few simple principles and some basic linear algebra. The bad news? It requires unlearning a few concepts you may have been taught in lectures. For example, consider the following principle: 5 | 6 | Initialize the weights so that all activations have unit variance at initialization. 7 | 8 | -- Deep Learning 101 9 | 10 | This turns out to be bad for scaling. Why? Because the network internals can behave quite differently at initialization compared to after a few steps of training. A good way to understand this point is to consider a simple linear layer. 
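Before working through the details in the next section, here is a quick numerical preview of the problem. This is a standalone sketch (written with JAX to match the rest of the codebase; the dimensions are arbitrary):

.. code:: python

    import jax
    import jax.numpy as jnp

    fan_in, fan_out = 4096, 16
    k_w, k_x = jax.random.split(jax.random.PRNGKey(0))

    # sigma calibrated so that outputs have unit variance at initialization
    sigma = 1.0 / jnp.sqrt(fan_in)
    weight = sigma * jax.random.normal(k_w, (fan_out, fan_in))

    # a generic input: output entries have roughly unit variance, as designed
    x = jax.random.normal(k_x, (fan_in,))
    print(jnp.std(weight @ x))  # roughly 1

    # now align the input with the non-null space, as training tends to do
    _, _, Vh = jnp.linalg.svd(weight, full_matrices=False)
    x_aligned = Vh[0] * jnp.linalg.norm(x)  # same length, aligned direction
    print(jnp.std(weight @ x_aligned))  # roughly sqrt(fan_in / fan_out)

Once the input aligns, the variance-calibrated ``sigma`` is in hindsight too large by a factor of about :python:`math.sqrt(fan_in / fan_out)`, which is the blow-up described below.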
11 | 12 | The linear layer 13 | ^^^^^^^^^^^^^^^^^ 14 | 15 | Consider a linear layer with Gaussian initialization and standard deviation ``sigma``: 16 | 17 | .. code:: python 18 | 19 | class Linear: 20 | 21 | def __init__(self, fan_out:int, fan_in:int, sigma:float): 22 | self.weight = sigma * torch.randn(fan_out, fan_in) 23 | 24 | def forward(self, x): 25 | return torch.matmul(self.weight, x) 26 | 27 | The properties of this layer are most subtle when the layer conducts a large reduction in dimension---i.e. when ``fan_out`` is much smaller than ``fan_in``. This might happen in the final layer of a classifier, for example. In fact, let's study the case where we are scaling up ``fan_in`` while holding ``fan_out`` fixed. 28 | 29 | An important fact about a matrix :python:`self.weight` with ``fan_in`` much larger than ``fan_out`` is that the null space is huge, meaning that most of the input space is mapped to zero. The dimension of the null space is at least ``fan_in - fan_out``. At initialization, most of a fixed input ``x`` will lie in this nullspace. This means that to get the output of :python:`self.forward` to have unit variance at initialization, you need to pick a huge initialization scale ``sigma`` in order to scale up the component of ``x`` that does not lie in the null space. But after a few steps of training, the situation changes. Gradient descent will cause the input ``x`` to align with the non-null space of ``self.weight``. This means that the ``sigma`` you chose to control the activations at initialization is now far too large in hindsight, and the activations will blow up! This problem only gets worse with increasing ``fan_in``. 30 | 31 | The solution to this problem is simple: don't choose ``sigma`` to control variance at initialization! Instead, choose ``sigma`` under the assumption that inputs fall in the non-null space. Even if this makes the activations too small at initialization, this is fine as they will quickly "warm up" after a few steps of training. And for a nice bonus, we will show in the section on `width scaling <#fixing-width-scaling>`_ that switching from Gaussian init to orthogonal init makes choosing the right ``sigma`` trivial. 32 | 33 | Three golden rules 34 | ^^^^^^^^^^^^^^^^^^^ 35 | 36 | The example in the previous section illustrates a style of thinking that extends far beyond linear layers. Let's distill it into three key tenets, which we call the "golden rules" of scaling: 37 | 38 | .. rst-class:: starlist 39 | 40 | - Gradient descent causes inputs to align with the largest spectral components of tensors. So when initializing tensors, carefully set their largest spectral components. 41 | 42 | - The largest spectral components of gradient updates align with tensor inputs. So it is important to normalize gradient updates to control the size of their largest spectral components. 43 | 44 | - All layers will align during training. Keep this in mind when designing the architecture. 45 | 46 | It's worth expanding a little on what we mean by *alignment* here. When we say that an input ``x`` aligns with a weight matrix ``weight``, we mean that if we compute ``U, S, V = torch.linalg.svd(weight)``, then the input ``x`` will tend to have a larger dot product with the rows of ``V`` that correspond to larger diagonal entries of the singular value matrix ``S``. When we say that layers align, we mean that the outputs of one layer will align with the next layer. 47 | 48 | What's the source of this alignment? 
Consider making a gradient update to a tensor in the middle of a deep net. We call all the preceding layers the "head" of the network, and all the layers after the "tail": 49 | 50 | .. plot:: figure/alignment.py 51 | 52 | What's important is that the gradient update "knows about" both the head of the network (through the layer inputs) and the tail of the network (through the backpropagated gradient). Applying the update will align the head with the tail [#outerproduct]_. And this kind of alignment happens at all layers at every iteration! 53 | 54 | The rest of this section will show how to apply the golden rules to do `width scaling <#fixing-width-scaling>`_, `depth scaling <#fixing-depth-scaling>`_, and `key-query dot product scaling <#fixing-key-query-dot-product-scaling>`_. This should already be enough to get started scaling a GPT. 55 | 56 | Fixing width scaling 57 | ^^^^^^^^^^^^^^^^^^^^^ 58 | 59 | First, let's do width scaling in a linear layer. When the network has trained for a few steps to reach its fully aligned state, we want the input and output activations to fall roughly in the interval [-1, 1]. Equivalently, we want the inputs to have Euclidean length :python:`math.sqrt(fan_in)` and the outputs to have Euclidean length :python:`math.sqrt(fan_out)`. To achieve this, the first and second golden rules tell us that we need to control the top singular values of the initial weight matrix and the gradient updates. One can check that the right scaling is to set the singular values proportional to :python:`math.sqrt(fan_out / fan_in)`. Intuitively, the factor of :python:`math.sqrt(fan_out / fan_in)` means that the matrix operates as a "dimensional converter": it takes in vectors of length :python:`math.sqrt(fan_in)` and spits out vectors of length :python:`math.sqrt(fan_out)`. 60 | 61 | In fact we can be a little more clever here and reparameterize the linear layer as follows: 62 | 63 | .. code:: python 64 | 65 | class ReparameterizedLinear: 66 | 67 | def __init__(self, fan_out:int, fan_in:int): 68 | self.scale = math.sqrt(fan_out / fan_in) 69 | self.weight = torch.empty(fan_out, fan_in) 70 | torch.nn.init.orthogonal_(self.weight) 71 | 72 | def forward(self, x): 73 | return self.scale * torch.matmul(self.weight, x) 74 | 75 | By including the conversion factor :python:`self.scale = math.sqrt(fan_out / fan_in)` in the forward function, the correct scaling is to make the largest singular values of both :python:`self.weight` and the weight updates order one. Easy, right? For the initialization, we can just use orthogonal init, which sets all the singular values to exactly one. In our experiments, we have found orthogonal init to be a performant, hyperparameter-free initializer. As for weight updates, we can just spectrally normalize them [#spectralnorm]_: 76 | 77 | .. code:: python 78 | 79 | self.weight -= learning_rate * self.weight.grad / spectral_norm(self.weight.grad) 80 | 81 | In practice, you may want to replace :python:`self.weight.grad` with some kind of momentum or Adam expression. And the learning rate can optionally decay through the course of training. Also, you can often replace :python:`spectral_norm(self.weight.grad)` with :python:`frobenius_norm(self.weight.grad)` since gradients tend to be low stable rank. 82 | 83 | Fixing depth scaling 84 | ^^^^^^^^^^^^^^^^^^^^^ 85 | 86 | For depth scaling, we will look at scaling the number of blocks in a residual network [#mlp]_ of the form: 87 | 88 | .. 
code:: python 89 | 90 | def resnet(x:torch.Tensor, residue_list:list, block_multiplier:float): 91 | 92 | for residue in residue_list: 93 | 94 | x += block_multiplier * residue(x) 95 | 96 | return x 97 | 98 | We call this a residual network because at each iteration of the :python:`for` loop, a :python:`residue` is added to the input, which takes the form of a sub-network applied to the output from the previous step of the loop. The ``block_multiplier`` can be used to ensure that the residual contribution is small, allowing us to make the residual network very, very deep without its output blowing up. The main questions are: 99 | 100 | - What kind of functions are we allowed to include in the ``residue_list``? 101 | - What value should we choose for the ``block_multiplier``? 102 | 103 | The third golden rule makes answering these questions easy. We should set :python:`block_multiplier = 1 / len(residue_list)`. This is because each residue adds one contribution to the output, and there are :python:`len(residue_list)` residues in total. The sum of :python:`len(residue_list)` aligned residues needs to be divided by :python:`len(residue_list)` in order to not blow up. This is similar to an idea you may have seen in math that :math:`(1+\frac{1}{L})^L < \mathrm{e}` for any :math:`L>0`. Even though the product may involve a large number :math:`L` of terms, the residue :math:`1/L` is small enough to prevent the product blowing up. Linking the analogy back to neural nets, :math:`L` plays the role of :python:`len(residue_list)`. 104 | 105 | Since the :python:`1/len(residue_list)` block multiplier prevents both the initialization and the updates to the residues from blowing up, we are safe to set each residue equal to any neural network of our choosing, so long as that network is individually initialized and updated in accordance with the golden rules [#recursive]_. 106 | 107 | Fixing key-query dot product scaling 108 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 109 | 110 | An important operation in transformers is taking the dot product between key and query vectors. Conventionally this is done as follows: 111 | 112 | .. code:: python 113 | 114 | lambda key, query : torch.dot(key, query) / math.sqrt(key.shape[0]) 115 | 116 | The factor of :python:`1 / math.sqrt(key.shape[0])` is included to prevent the dot product from blowing up at initialization, where we assume that ``key`` and ``query`` are uncorrelated random vectors. But by the golden rules, we should expect that the keys and queries become aligned with each other through the course of training. Therefore we should instead normalize the dot product as follows: 117 | 118 | .. code:: python 119 | 120 | lambda key, query : torch.dot(key, query) / key.shape[0] 121 | 122 | To spell this out more clearly, the dot product is the sum of a number :python:`key.shape[0]` of aligned quantities, so we should divide by :python:`key.shape[0]` to prevent the sum blowing up. 123 | 124 | Wrapping up 125 | ^^^^^^^^^^^^ 126 | 127 | On this page, we introduced three "golden rules" for scaling and pointed out how they differ to some conventional wisdom about controlling activation variance at initialization. Something we hope to get across is that the logic associated with the golden rules is not only *more scalable* than standard approaches based on controlling variance, but also *simpler*. You don't need to know anything about how random variables behave in order to get the scaling right---you just need to know how objects add when they point in the same direction. 
And in the same vein, the use of orthogonal initialization obviates the need to know anything about the spectral properties of Gaussian random matrices. 128 | 129 | In the next section we will look at the history behind these ideas. After that we will move on to explaining how Modula automates the application of the golden rules. 130 | 131 | 132 | .. [#outerproduct] The mathematical analogue of this intuitive statement is to say that the gradient of a linear layer is an outer product of the layer input with the gradient of the loss with respect to the layer output. 133 | 134 | .. [#spectralnorm] The spectral norm of a matrix is the largest singular value. The largest singular value of :python:`matrix / spectral_norm(matrix)` is always one, so long as :python:`matrix != 0`. 135 | 136 | .. [#mlp] We study residual networks over MLPs because MLPs seem to just work badly beyond depth 10 or so. In the Modula paper, we show that the type of residual networks we propose are in fact "smooth" even in the limit of infinitely many blocks. The same property does not hold for MLPs to the best of our knowledge. 137 | 138 | .. [#recursive] The recursive nature of this statement directly inspired the Modula framework. 139 | -------------------------------------------------------------------------------- /docs/source/algorithms/manifold/orthogonal.rst: -------------------------------------------------------------------------------- 1 | Orthogonal manifold 2 | ==================== 3 | 4 | 📚 *This page contains original research. To cite the Modula docs, here's some BibTeX:* 5 | 6 | .. code:: 7 | 8 | @misc{modula-docs, 9 | author = {Jeremy Bernstein}, 10 | title = {The Modula Docs}, 11 | url = {https://docs.modula.systems/}, 12 | year = 2025 13 | } 14 | 15 | On this page, we will work out an algorithm for performing gradient descent on the manifold of orthogonal matrices while taking steps that are steepest under the spectral norm. The algorithm will solve for the matrix of unit spectral norm that maximizes the linearized improvement in loss while lying tangent to the manifold. The "retraction map"---which sends the update from the tangent space back to the manifold---involves a few extra matrix multiplications. 16 | 17 | Steepest descent on the orthogonal manifold 18 | -------------------------------------------- 19 | 20 | Consider a square weight matrix :math:`W\in\mathbb{R}^{n \times n}` that is orthogonal, meaning that :math:`W^\top W = I_n`. Suppose that the "gradient matrix" :math:`G\in\mathbb{R}^{n\times n}` is the derivative of some loss function evaluated at :math:`W`. Given step size :math:`\eta > 0`, we claim that the following weight update is steepest under the spectral norm while staying on the orthogonal manifold. First, we take the matrix sign of the skew part of :math:`W^\top G`: 21 | 22 | .. math:: 23 | X = \operatorname{msign}[W^\top G - G^\top W], 24 | 25 | where the matrix sign :math:`\operatorname{msign}` of a matrix :math:`M` returns the matrix with the same singular vectors as :math:`M` but with all positive singular values set to one. And then we make the update: 26 | 27 | .. math:: 28 | W \mapsto W \cdot (I_n - \eta X) \cdot \left(I_n - X^\top X + \frac{X^\top X}{\sqrt{1+\eta^2}}\right). 29 | 30 | The final bracket constitutes the "retraction map", which snaps the updated weights back to the manifold. Curiously, the update can be written in a purely multiplicative form.
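Before deriving this update, here is a quick numerical sanity check that it really does stay on the manifold. This is a standalone sketch: it computes :math:`\operatorname{msign}` with an SVD for clarity, whereas the implementation at the bottom of this page uses a Newton-Schulz iteration instead.

.. code-block:: python

    import jax
    import jax.numpy as jnp

    def msign(M, tol=1e-9):
        # matrix sign via SVD: set every nonzero singular value to one
        U, S, Vh = jnp.linalg.svd(M)
        return (U * (S > tol)) @ Vh

    n, eta = 10, 0.1
    k_w, k_g = jax.random.split(jax.random.PRNGKey(0))
    W = msign(jax.random.normal(k_w, (n, n)))  # a random orthogonal matrix
    G = jax.random.normal(k_g, (n, n))         # a stand-in "gradient"

    X = msign(W.T @ G - G.T @ W)
    I = jnp.eye(n)
    C = I - X.T @ X + X.T @ X / jnp.sqrt(1 + eta**2)
    W_new = W @ (I - eta * X) @ C

    print(jnp.linalg.norm(W_new.T @ W_new - I))  # ~0: still on the manifold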
31 | 32 | 33 | Non-Riemannian manifold methods 34 | -------------------------------- 35 | 36 | One reason this algorithm is interesting is that it is an example of a manifold optimization algorithm that is *non-Riemannian*. A Riemannian manifold is a manifold equipped with a structure called a *Riemannian metric*, which is an inner product defined at each point on the manifold. The inner product provides a way to measure distance and construct geometry-aware optimization algorithms. There has been a lot of research into Riemannian optimization methods. Some examples in a machine learning context are: 37 | 38 | - `Fast and accurate optimization on the orthogonal manifold without retraction `_; 39 | - `Efficient Riemannian optimization on the Stiefel manifold via the Cayley transform `_. 40 | 41 | However, there has seemingly been much less research into optimization algorithms on manifolds equipped with non-Riemannian structures. For instance, a matrix manifold equipped with the spectral norm at every point is non-Riemannian since the spectral norm does not emerge from an inner product. But we believe these kinds of non-Riemannian geometries are very important in deep learning. 42 | 43 | The structure of the tangent space 44 | ----------------------------------- 45 | 46 | We would like to make a weight update so that the updated weights stay on the orthogonal manifold. First, we need to figure out the structure of the "tangent space" at a point on the manifold. Roughly speaking, the tangent space is the set of possible velocities a particle could have as it passes through that particular point. So we need to consider all curves passing through the point on the manifold. 47 | 48 | If we consider a curve :math:`W(t)` on the manifold parameterized by time :math:`t \in \mathbb{R}`, then this curve must satisfy :math:`W(t)^\top W(t) = I_n`. Differentiating with respect to :math:`t`, we find that the velocity must satisfy: 49 | 50 | .. math:: 51 | 52 | \frac{\partial W(t)}{\partial t}^\top W(t) + W(t)^\top \frac{\partial W(t)}{\partial t} = 0. 53 | 54 | So to be in the tangent space of a point :math:`W` on the manifold, a matrix :math:`A` must satisfy :math:`A^\top W + W^\top A = 0`. Conversely, if a matrix :math:`A` satisfies :math:`A^\top W + W^\top A=0`, then it is the velocity of a curve on the manifold that passes through :math:`W`, as evidenced by the curve :math:`W(t) = W \exp(tW^\top A)`. Therefore, the tangent space at :math:`W` is completely characterized by the set: 55 | 56 | .. math:: 57 | 58 | \{A\in \mathbb{R}^{n\times n}:A^\top W + W^\top A = 0\}. 59 | 60 | Finally, if we use the orthogonal matrix :math:`W` to make the change of variables :math:`A = W X`, then we see that :math:`A` belongs to the tangent space at :math:`W` if and only if :math:`X` is skew-symmetric: :math:`X^\top + X = 0`. So the tangent space to the orthogonal manifold can be parameterized by skew-symmetric matrices. 61 | 62 | 63 | Steepest direction in the tangent space 64 | ---------------------------------------- 65 | 66 | We will solve for the matrix :math:`A` that belongs to the tangent space to the orthogonal manifold at matrix :math:`W` and maximizes the linearized improvement in loss :math:`\operatorname{trace}(G^\top A)` under the constraint that :math:`A` has unit spectral norm. Formally, we wish to solve: 67 | 68 | .. math:: 69 | 70 | \operatorname{arg max}_{A\in \mathbb{R}^{n\times n}: \|A\|_*\leq 1 \text{ and } A^\top W + W^\top A = 0}\; \operatorname{trace}(G^\top A).
Steepest direction in the tangent space
----------------------------------------

We will solve for the matrix :math:`A` that belongs to the tangent space to the orthogonal manifold at matrix :math:`W` and maximizes the linearized improvement in loss :math:`\operatorname{trace}(G^\top A)` under the constraint that :math:`A` has unit spectral norm. Formally, we wish to solve:

.. math::

    \operatorname{arg max}_{A\in \mathbb{R}^{n\times n}: \|A\|_*\leq 1 \text{ and } A^\top W + W^\top A = 0}\; \operatorname{trace}(G^\top A).

To simplify, we make the change of variables :math:`A = W X` so that we now only need to maximize over skew-symmetric matrices :math:`X` of unit spectral norm:

.. math::

    \operatorname{arg max}_{X\in \mathbb{R}^{n\times n}:\|X\|_*\leq 1 \text{ and } X^\top + X= 0}\; \operatorname{trace}([W^\top G]^\top X).

Next, we decompose :math:`W^\top G = \frac{1}{2}[W^\top G + G^\top W] + \frac{1}{2}[W^\top G - G^\top W]` into its symmetric and skew-symmetric components and realize that, because :math:`X` is skew-symmetric, the contribution to the trace from the symmetric part of :math:`W^\top G` vanishes. So the problem becomes:

.. math::
    \operatorname{arg max}_{X\in \mathbb{R}^{n\times n}:\|X\|_*\leq 1 \text{ and } X^\top + X= 0}\; \operatorname{trace}\left(\left[\frac{W^\top G - G^\top W}{2}\right]^\top X\right).

If we simply ignore the skew-symmetric constraint, the solution for :math:`X` is given by :math:`X = \operatorname{msign}[W^\top G - G^\top W]`. But this solution for :math:`X` actually satisfies the skew-symmetric constraint! This is because the matrix sign function preserves skew-symmetry. An easy way to see this is that :math:`\operatorname{msign}[W^\top G - G^\top W]` can be computed by running an odd polynomial iteration (see :doc:`Newton-Schulz <../newton-schulz>`) on :math:`W^\top G - G^\top W`, and odd polynomials preserve skew-symmetry. [#youla]_

Undoing the change of variables, our tangent vector is given by :math:`A = W \cdot \operatorname{msign}[W^\top G - G^\top W]`.


Finding the retraction map
---------------------------

The previous section suggests making the weight update :math:`W \mapsto W - \eta W X = W (I_n - \eta X)`. This update takes a step in the tangent space, which diverges slightly from the orthogonal manifold for finite step sizes. A relatively expensive way to fix this issue is to just apply the matrix sign function, i.e. :math:`W \mapsto \operatorname{msign}[W (I_n - \eta X)]`, to project the weights back to the manifold. But we will show in this section that there is actually a shortcut.

As a warmup, let's first consider the case that :math:`W^\top G - G^\top W` is full rank. Then :math:`X` is an orthogonal matrix and :math:`[W (I_n - \eta X)]^\top [W (I_n - \eta X)] = (1 + \eta^2) I_n`. Therefore, in this case, we can project back to the manifold simply by dividing the updated weights through by the scalar :math:`\sqrt{1+\eta^2}`.

In the general case where :math:`W^\top G - G^\top W` and therefore :math:`X = \operatorname{msign}[W^\top G - G^\top W]` may not be full rank, let us search for a matrix :math:`C` such that :math:`W \cdot (I_n - \eta X) \cdot C` is orthogonal. Checking the orthogonality condition :math:`(W \cdot (I_n - \eta X) \cdot C)^\top (W \cdot (I_n - \eta X) \cdot C)=I_n` reveals that we need to find a matrix :math:`C` such that:

.. math::

    C^\top (I_n + \eta^2 X^\top X) C = I_n.

The trick is to recognize :math:`X^\top X` as the orthogonal projector on to the row space of :math:`X`. The matrix :math:`I_n + \eta^2 X^\top X` leaves vectors in the null space of :math:`X` unchanged but scales up vectors in the row space of :math:`X` by a factor of :math:`1+\eta^2`. It therefore suffices to choose a symmetric matrix :math:`C` that inverts this transformation on each of the two subspaces separately.
Noting that :math:`I_n - X^\top X` projects on to the null space of :math:`X`, the following choice of :math:`C` is what we need:

.. math::

    C = C^\top = I_n - X^\top X + \frac{X^\top X}{\sqrt{1+\eta^2}}.


Python code
---------------------------

Here is a basic JAX implementation for the algorithm:

.. code-block:: python

    import jax.numpy as jnp
    import math

    def orthogonalize(M, steps=10):
        # quintic Newton-Schulz iteration to approximate the matrix sign function
        a, b, c = 3, -16/5, 6/5
        transpose = M.shape[1] > M.shape[0]
        if transpose:
            M = M.T
        M = M / jnp.linalg.norm(M)
        for _ in range(steps):
            A = M.T @ M
            I = jnp.eye(A.shape[0])
            M = M @ (a * I + b * A + c * A @ A)
        if transpose:
            M = M.T
        return M

    def update(W, G, eta, NS_steps=10):
        I = jnp.eye(W.shape[0])
        X = orthogonalize(W.T @ G - G.T @ W, NS_steps)
        retraction_factor = I - (1 - math.sqrt(1 / (1 + eta**2))) * X.T @ X
        return W @ (I - eta * X) @ retraction_factor


Open problem: Extending to the Stiefel manifold
------------------------------------------------

I initially thought that this solution easily extended to the *Stiefel manifold*—i.e. the set of :math:`m \times n` semi-orthogonal matrices. But this turns out not to be the case: the algorithm we derived is generally not optimal if :math:`W` is rectangular. To see this, let's consider an :math:`m \times n` matrix :math:`W` with :math:`m > n`, and suppose that it belongs to the Stiefel manifold :math:`W^\top W = I_n`. The problem with our derivation is that the change of variables :math:`A = W X` no longer parameterizes the full set of :math:`m \times n` matrices. Instead, we need to make the change of variable :math:`A = WX + \overline{W}Y` where the columns of :math:`\overline{W}` are the "missing" columns of :math:`W`. In other words, the combined matrix :math:`[W | \overline{W}]` is a square orthogonal matrix. For this parameterization, the tangent space to the Stiefel manifold is obtained by requiring that :math:`X\in\mathbb{R}^{n\times n}` is skew-symmetric while :math:`Y\in\mathbb{R}^{(m-n)\times n}` is completely unconstrained. I do not know how to analytically solve the resulting maximization problem in this parameterization.

.. [#youla] In fact, any odd function applied entrywise to the singular values of a matrix will preserve skew-symmetry. To see this, one needs to understand the spectral structure of skew-symmetric matrices. An :math:`n\times n` matrix :math:`X` is skew-symmetric if and only if it can be written :math:`X = \sum_{i=1}^k \sigma_i (u_iv_i^\top - v_i u_i^\top)`, where the :math:`\sigma_i` are non-negative, the :math:`\{u_i\}\cup\{v_i\}` are all orthonormal and :math:`k \leq \lfloor n/2 \rfloor`. In other words, :math:`X` must admit an SVD where the singular values come in pairs with conjugate singular vectors. But applying an odd function :math:`f` to the singular values yields :math:`\sum_{i=1}^k f(\sigma_i) (u_iv_i^\top - v_i u_i^\top)`, which leaves the skew-symmetric structure intact. For more reading on the spectral structure of skew-symmetric matrices, see `(Haber, 2016) `_ or `(Youla, 1961) `_.
--------------------------------------------------------------------------------
/docs/source/algorithms/manifold/stiefel.rst:
--------------------------------------------------------------------------------
Stiefel manifold
=================
📚 *This page contains original research. To cite the Modula docs, here's some BibTeX:*

.. code::

    @misc{modula-docs,
        author = {Jeremy Bernstein},
        title = {The Modula Docs},
        url = {https://docs.modula.systems/},
        year = 2025
    }

On this page we shall consider a problem that I affectionately refer to as *manifold Muon*—or, more formally, the problem of *steepest descent under the spectral norm on the Stiefel manifold*. This problem arises when one is interested in taking the best possible optimization step in a spectral norm geometry (useful for accelerating training) while keeping the size of the weight matrices tightly regulated (potentially helpful for training stability and removing learning rate confounders). This page will generalize the analysis from :doc:`the square case ` to the full Stiefel manifold.

I posed manifold Muon as an open problem on the `Modula docs `_ earlier this year, and two researchers, Franz Louis Cesista (a.k.a. Leloy) and Jianlin Su, recently proposed solutions. Leloy proposed a `heuristic solution `_ via alternating projections, and Jianlin `solved the problem `_ by setting up a fixed point iteration. I heard about Leloy's work and an early version of Jianlin's approach (which did not yet work) and managed to solve the problem myself with a slightly different approach based on Lagrangian duality, which I will present in the next section. I also want to acknowledge that `Cédric Simal `_ independently proposed studying the dual problem to me and Leloy, after I had worked out the following analysis.

Formulating the problem
------------------------

Let's set up the problem mathematically. Say we have a matrix-valued optimization variable :math:`W \in \mathbb{R}^{m \times n}` where, without loss of generality, we take :math:`m\geq n` so that the matrix has more rows than columns. And we have a cost function :math:`\mathcal{C}:\mathbb{R}^{m \times n}\to\mathbb{R}` that we would like to minimize. We would also like to constrain the matrix :math:`W` to the following set:

.. math::

    \mathsf{Stiefel}(m,n) := \left\{ W \in \mathbb{R}^{m \times n} \mid W^\top W = I_n \right\}.

This set is known as the *Stiefel manifold*. A matrix :math:`W\in\mathsf{Stiefel}(m,n)` for :math:`m>n` is known as a *semi-orthogonal* matrix—since it has too few columns to form a complete orthonormal basis. There are various alternative ways to characterize the Stiefel manifold. For example, it is equivalently defined as the set of :math:`m \times n` matrices with unit :math:`\ell_2 \to \ell_2` condition number. Suffice it to say, the Stiefel manifold is a very well-behaved class of matrices.

We would like to be able to take optimization steps that lie tangent to this manifold. Just as in :doc:`the square case `, we can show that the tangent space to the Stiefel manifold at a semi-orthogonal matrix :math:`W\in\mathsf{Stiefel}(m,n)` is given by the following linear subspace of the ambient matrix space :math:`\mathbb{R}^{m \times n}`:

.. math::

    \mathsf{T}_W \mathsf{Stiefel}(m,n) = \left\{ A \in \mathbb{R}^{m \times n} \mid A^\top W + W^\top A = 0 \right\}.

In the context of Riemannian optimization, there are `established means `_ of projecting the gradient to this linear subspace in order to take steps tangent to the Stiefel manifold. But to make life more interesting, we shall be interested in cost functions with a different sort of structure.
In particular, suppose our cost :math:`\mathcal{C}` is Lipschitz-smooth in the *spectral norm*:

.. math::

    \mathcal{C}(W + \Delta W) \leq \mathcal{C}(W) + \langle \nabla \mathcal{C}(W), \Delta W\rangle + \tfrac{1}{2} \cdot \| \Delta W \|_{\mathrm{spectral}}^2,

where :math:`\langle \nabla \mathcal{C}(W), \Delta W\rangle \equiv \operatorname{trace} \nabla \mathcal{C}^\top \Delta W` is the Frobenius inner product between the derivative of the cost and the weight update, measuring the first-order change in cost. To motivate this smoothness structure, observe that matrices in a neural network act as *operators* on vectors, and the spectral norm respects this fact—see our `anthology `_ for more on this. Spectral norm smoothness suggests taking optimization steps of controlled spectral norm. And since the spectral norm does not emerge from an inner product, spectral norm smoothness takes us outside the realm of Riemannian geometry.

All told, we would like to design a gradient descent algorithm whose updates exploit the spectral norm geometry of the cost function while lying tangent to the Stiefel manifold. Here we focus on the problem of choosing the *direction* of the step given these constraints, and offload the problem of choosing the *magnitude* to the learning rate. We formulate the optimal update direction as the matrix :math:`A` that solves the following minimization problem:

.. _eq-primal:

.. math::
    \min_{A \in \mathbb{R}^{m \times n}} \underbrace{\mathstrut \operatorname{trace}(G^\top A)}_{\text{linearization of cost}} \quad \text{subject to} \quad \underbrace{\|A\|_{\mathrm{spectral}} \leq 1}_{\text{spectral constraint}} \quad \text{and} \quad \underbrace{\mathstrut A^\top W + W^\top A = 0}_{\text{tangent space constraint}}. \qquad (1)

In this expression, :math:`W` is the current point on the manifold, :math:`G := \nabla \mathcal{C}(W)` is shorthand for the derivative of the cost, and :math:`A` is the update direction that we seek. In words, we want to find an update direction that squeezes out the most linear improvement in cost while lying inside the ball of unit spectral norm and also lying tangent to the Stiefel manifold.

Solving manifold Muon via Lagrangian duality
---------------------------------------------

Similar to Jianlin's approach, we introduce a matrix :math:`\Lambda\in\mathbb{R}^{n\times n}` of Lagrange multipliers, and define a Lagrangian function :math:`\mathcal{L}(A, \Lambda)` that incorporates the tangent space constraint:

.. math::

    \begin{align*}
    \mathcal{L}(A, \Lambda) &:= \operatorname{trace} G^\top A + \operatorname{trace}\Lambda^\top (A^\top W + W^\top A) \\
    &= \operatorname{trace}A^\top (G + W(\Lambda+\Lambda^\top)),
    \end{align*}

where the second equality follows by applying the cyclic property of the trace and transposing one term. One can check that our original problem :ref:`(1) <eq-primal>` is equivalent to the saddle point problem :math:`\min_{\|A\|_\mathrm{spectral} \leq 1} \max_{\Lambda} \mathcal{L}(A,\Lambda)` since for any :math:`A` that violates the tangent space constraint, the inner maximization with respect to :math:`\Lambda` would send the Lagrangian to infinity. By Sion's minimax theorem, we can swap the order of the :math:`\min` and :math:`\max` to obtain:

.. math::

    \min_{\|A\|_\mathrm{spectral} \leq 1} \max_{\Lambda} \mathcal{L}(A,\Lambda) = \max_{\Lambda} \min_{\|A\|_\mathrm{spectral} \leq 1} \mathcal{L}(A,\Lambda).

Following `an argument `_ which is now standard in Muon lore, we recognize the optimal value :math:`A_\mathrm{opt}(\Lambda)` of the primal variable :math:`A` for a given dual variable :math:`\Lambda` as:

.. math::

    A_{\mathrm{opt}}(\Lambda) := \mathop{\mathrm{arg\,min}}_{\|A\|_\mathrm{spectral} \leq 1} \mathcal{L}(A,\Lambda) = - \operatorname{msign} (G+W(\Lambda+\Lambda^\top)),

where :math:`\operatorname{msign}` is the *matrix sign function*, defined as the elementwise sign function applied to the singular values of a matrix, or in PyTorch code:

.. code-block:: python

    import torch

    def msign(X):
        # apply the sign function to the singular values of X
        U, S, V = torch.svd(X)
        return U @ S.sign().diag() @ V.T

Note that :math:`\operatorname{msign}` can be computed efficiently on GPUs without taking an SVD via `Newton-Schulz iteration `_ as in the recent `Polar Express `_ algorithm.

Substituting :math:`A_\mathrm{opt}(\Lambda)` back into the Lagrangian, we uncover the dual problem:

.. _eq-dual:

.. math::

    \max_{\Lambda}\mathcal{L}(A_\mathrm{opt}(\Lambda), \Lambda) = \max_{\Lambda} -\|G + W (\Lambda+\Lambda^\top)\|_\mathrm{nuclear}.

In contrast to the primal problem :ref:`(1) <eq-primal>`, the dual problem is completely unconstrained. We may solve the dual problem by running gradient ascent on the Lagrangian dual function :math:`\mathcal{L}(A_\mathrm{opt}(\Lambda), \Lambda)`—a technique formally known as *dual ascent*. After some work, the gradient of the dual function—or, more precisely, a *subgradient*—is given by the following formula:

.. math::

    \begin{align*}
    H(\Lambda) &:= - \nabla_\Lambda \|G + W (\Lambda+\Lambda^\top)\|_\mathrm{nuclear} \\
    &= - [W^\top\operatorname{msign}(G + W (\Lambda+\Lambda^\top)) + \operatorname{msign}(G + W (\Lambda+\Lambda^\top))^\top W].
    \end{align*}

To obtain this expression, we have applied the chain rule and the fact that :math:`\operatorname{msign}(X)` is in the subdifferential of :math:`\|X\|_\mathrm{nuclear}`.

This expression for :math:`H(\Lambda)` also has an intuitive interpretation: it measures the deviation of the current setting of :math:`A_\mathrm{opt}(\Lambda)` from satisfying the tangent space condition. `Jianlin's solution `_ can be interpreted as running a fixed point iteration on the first-order optimality condition for the dual problem: :math:`H(\Lambda_\mathrm{opt}) = 0`. Instead of running this fixed point iteration, we propose a different approach known as *dual ascent*.

The dual ascent algorithm
--------------------------

In this section, we write down a gradient ascent algorithm to solve the Lagrangian dual problem. Given a tolerance :math:`\mathtt{tol}>0` and a step size :math:`\alpha>0` for updating the dual variable :math:`\Lambda`, the algorithm is given by:

1. Initialize the dual variable: :math:`\Lambda = -\tfrac{1}{4} \times (W^\top G + G^\top W)`.
2. Compute the candidate update direction: :math:`A = - \operatorname{msign}(G + 2W \Lambda)`.
3. Measure the deviation of :math:`A` from the tangent space: :math:`H = W^\top A + A^\top W`.
4. Check the stopping criterion:

   a. If the deviation is small enough, i.e. :math:`\|H\|_\mathrm{F} / n < \mathtt{tol}` (the root-mean-square entry of the :math:`n \times n` matrix :math:`H`), then return :math:`A`.
   b. Otherwise, update the dual variable: :math:`\Lambda \gets \Lambda + \alpha \times H` and go back to step 2.

Observe that the dual variable :math:`\Lambda` remains symmetric throughout this procedure, so we can use :math:`2 \Lambda` in place of :math:`\Lambda + \Lambda^\top` at step 2. The motivation for the special initialization of :math:`\Lambda` is that it leads to the algorithm terminating on the first step if :math:`W` is square. This is because step 2 already recovers the optimal value of :math:`A` :doc:`for the square case ` and so :math:`H=0` at step 3. In actual neural network training, where :math:`G` may not change much between steps because of momentum, it might make more sense to warm start :math:`\Lambda` from the previous iteration.

Once this algorithm terminates, we take the returned value of the primal variable :math:`A` and make the tangent space update :math:`W \gets W + \eta \times A`. The final step is to retract the updated weights back to the manifold. We will work out a retraction map in the next section.

Working out the retraction map
-------------------------------

An update in the tangent space will diverge slightly from the manifold for finite step sizes :math:`\eta`. As such we need to find a retraction map to project the updated weights back to the manifold. It turns out that the retraction map can be implemented in a simple way, by introducing an extra matrix :math:`C` to the update:

.. math::

    W \gets (W + \eta \times A)\cdot C.

We just need to solve for the proper value of :math:`C`. Checking the semi-orthogonality condition and using the fact that :math:`W^\top A + A^\top W = 0` because the update direction :math:`A` belongs to the tangent space, we find that:

.. math::

    \begin{align*}
    C^\top(W + \eta A)^\top (W + \eta A)C &=C^\top [W^\top W + \eta \times [W^\top A + A^\top W] + \eta^2 A^\top A]C \\
    &= C^\top[I_n - A^\top A + (1+\eta^2) \cdot A^\top A]C.
    \end{align*}

Even though :math:`A` is an output of :math:`\operatorname{msign}`, it may not hold that :math:`A^\top A = I_n` because :math:`A` may be low rank. We need to find a matrix :math:`C` satisfying :math:`C^\top[I_n - A^\top A + (1+\eta^2) \cdot A^\top A]C = I_n`. This task is made substantially easier by observing that :math:`A^\top A` and :math:`I_n - A^\top A` are orthogonal projectors. We can then read off a suitable value for :math:`C` as:

.. math::

    C = C^\top = I_n - A^\top A + \frac{A^\top A}{\sqrt{1+\eta^2}}.

While it is nice to have an analytical expression for the retraction map, in practice it might be numerically advantageous just to use :math:`\operatorname{msign}` to project the updated weights back to the manifold.

PyTorch implementation
----------------------

Here we give a basic PyTorch implementation for solving manifold Muon via dual ascent. The code re-uses the ``msign`` function defined earlier in the post.

.. code-block:: python

    import math

    import torch

    def manifold_muon(W, G, eta=0.1, alpha=0.01, steps=100, tol=1e-6):
        # Ensure that W and G are both tall matrices
        should_transpose = W.shape[0] < W.shape[1]
        if should_transpose:
            W = W.T
            G = G.T
        # Initialize the dual variable
        Lambda = -0.25 * (W.T @ G + G.T @ W)
        # Ascend on the dual problem to find the update direction A
        for step in range(steps):
            # Update the candidate direction A
            # (sign convention: this A is the negation of A in the text)
            A = msign(G + 2 * W @ Lambda)
            # Measure deviation of A from the tangent space
            H = W.T @ A + A.T @ W
            # Check the stopping criterion
            if torch.norm(H) / math.sqrt(H.numel()) < tol:
                break
            # Update the dual variable, with a linearly decaying dual step size
            Lambda -= alpha * (1 - step / steps) * H
        # Descend on the primal problem
        new_W = W - eta * A
        # Retract to the manifold
        new_W += new_W @ A.T @ A * (1/math.sqrt(1 + eta**2) - 1)
        # Restore the shape of the solution and return
        return new_W.T if should_transpose else new_W

Acknowledgments
----------------

I am grateful to `Leloy `_ and `Jianlin Su `_ for sharing their excellent work on this topic. I also want to acknowledge `Cédric Simal `_, who independently proposed studying the dual problem to me, after I had worked out this dual ascent approach. I am incredibly grateful to the team at `Thinking Machines `_ for supporting me to explore this problem. Any mistakes in this writeup are my own responsibility.
--------------------------------------------------------------------------------
/docs/source/algorithms/newton-schulz.rst:
--------------------------------------------------------------------------------
Newton-Schulz
==============

On this page, we will work out a family of iterative algorithms for "orthogonalizing" a matrix, by which we mean transforming either the rows or the columns of the matrix to form an orthonormal set of vectors. These so-called "Newton-Schulz" iterations are a useful family of algorithms to keep in your toolbox. We proposed using these iterations for neural net optimization in our paper:

| 📗 `Modular duality in deep learning `_
| Jeremy Bernstein & Laker Newhouse
| arXiv 2024

Before that, we included the iteration in an appendix of our `workshop paper `_, and before that I actually worked out `the ideas `_ `directly `_ `on `_ `Twitter `_ with my collaborator Tim Large. We used a particular `cursed quintic iteration <#a-cursed-quintic-iteration>`_ in the Muon optimizer, which was used to set speed records for training NanoGPT:

| 📕 `Muon: An optimizer for hidden layers in neural networks `_
| Keller Jordan, Yuchen Jin, Vlado Boza, Jiacheng You, Franz Cesista, Laker Newhouse & Jeremy Bernstein
| blog post 2024

Since then, the iteration has been applied in new optimizers such as `Scion `_, `improved-SOAP `_ and `Mango `_. At the bottom of this page, we provide further `historical connections <#id1>`_ on these techniques.

Problem statement
-----------------

We wish to approximate the map that sends a matrix :math:`M\in\mathbb{R}^{m\times n}` with reduced SVD :math:`M = U \Sigma V^\top` to the matrix :math:`U V^\top`. This map can be thought of as "snapping the singular values of :math:`M` to one"---with the exception that the iterations we consider will actually fix zero singular values at zero. But ignoring this detail, the map is given by:

.. math::

    M = U \Sigma V^\top \mapsto U V^\top.

This operation is sometimes referred to as `"symmetric orthogonalization" `_ because no row or column of the matrix :math:`M` is treated as special in the procedure. This is in contrast to `Gram-Schmidt orthogonalization `_, which involves first picking out a certain row or column vector as special and then orthogonalizing the remaining vectors against this vector.


.. Why care about symmetric orthogonalization?
.. --------------------------------------------

.. To train a neural network stably, it is desirable that the outputs of the layers evolve in a controlled fashion during training. We argued in our first paper on `distance measures between neural networks `_ that a good way to achieve this is to control the change in singular values of the weight matrices, and in our paper on `a spectral condition for feature learning `_ we proposed controlling the spectral norms of the weight updates. Following on from this, it makes sense to ask "what is the largest weight update we can make to a layer that has a given spectral norm?" This question is answered by taking the sharp operator of the gradient matrix. Formally, for a matrix :math:`G\in\mathbb{R}^{m\times n}` thought of as the gradient of a loss function, the sharp operator solves the following problem:

.. .. math::
..     G^\sharp = \operatorname{arg max}_{T \in \mathbb{R}^{m\times n} \,:\, \|T\|_* \leq 1} \langle G , T \rangle,

.. where :math:`\langle \cdot, \cdot \rangle` denotes the Frobenius inner product and :math:`\|\cdot\|_*` denotes the spectral norm. In words, the sharp operator tells us the direction :math:`T` in matrix space that squeezes out the most linearized change in loss :math:`\langle G, T \rangle` while keeping the spectral norm under control. Keeping the spectral norm of the weight update under control is important as it allows us to guarantee that the features of the model change by a controlled amount.

Odd polynomial iterations
-------------------------

We will consider iterations based on odd matrix polynomials of the form:

.. math::
    p(X) = a X + b X X^\top X + c (X X^\top)^2 X + ...

which acts on a matrix :math:`X \in \mathbb{R}^{m \times n}`. The important property of an odd matrix polynomial of this form is that it *commutes* with the singular value decomposition, in the sense that:

.. math::
    p(U \Sigma V^\top) = U p(\Sigma) V^\top.

So, to apply an odd polynomial :math:`p` to the singular values, it is enough to apply it to the overall matrix :math:`X`. Since the matrix of singular values :math:`\Sigma` is diagonal, this reduces to applying the scalar polynomial

.. math::
    f(x) = a x + bx^3 + cx^5 + ...

to the diagonal entries of :math:`\Sigma`. In what follows we will simply specify formulae for scalar polynomials :math:`f` with the understanding that they will be extended to matrix polynomials :math:`p` as specified above. Then our task is just to produce odd scalar polynomials :math:`f(x)` that when iterated like :math:`f \circ f \circ f \circ ... \circ f(x)` converge to the sign function :math:`\operatorname{sign}(x)`.

While the classical way to select these polynomials is by Taylor-approximating the sign function `(e.g. Björck & Bowie, 1971) `_, this only leads to a very restricted set of polynomials.
I had the idea of treating the polynomial coefficients as free parameters to be tuned for specific performance characteristics.

A cubic iteration
------------------

We begin with the simplest Newton-Schulz iteration, based on the cubic polynomial:

.. math::
    f(x) = \frac{3}{2}x - \frac{1}{2}x^3.

We plot :math:`f(x)` on the left and on the right we plot :math:`f(x)` iterated five times to yield :math:`f(f(f(f(f(x)))))`.

*(Interactive plot omitted from this dump: f(x) on the left; f iterated five times on the right.)*

As can be seen, by iterating :math:`f` several times, the graph starts to resemble that of the sign function :math:`\operatorname{sign}(x)`, at least on the interval close to the origin. In fact, you can check that if we iterate :math:`f` an infinite number of times, we will obtain precisely the sign function on the interval :math:`[-\sqrt{3},\sqrt{3}]`. As a consequence, if we iterate the corresponding matrix polynomial :math:`p(X) = \frac{3}{2}X - \frac{1}{2}XX^\top X`, we will approximate the sign function element-wise on the singular values of :math:`X`, thereby orthogonalising the matrix. The only caveat is that we need to ensure all singular values of the initial matrix lie in the interval :math:`[-\sqrt{3},\sqrt{3}]`. We can achieve this via a simple pre-processing step, mapping :math:`X \mapsto X / \|X\|_F`.

A quintic iteration
--------------------

Using a higher-order polynomial provides more degrees of freedom in our design space, which we can use to obtain faster convergence. In this section, we consider the quintic iteration given by:

.. math::
    f(x) = 3x - \frac{16}{5}x^3 + \frac{6}{5}x^5.

Again, we plot one and five iterations of this polynomial:

*(Interactive plot omitted from this dump: one iteration on the left; five iterations on the right.)*

As can be seen, after 5 iterations the quintic iteration has achieved a substantially closer approximation to the sign function than the cubic iteration, at least on the interval :math:`[-3/2,3/2]`. However, the approximation exhibits some oscillatory behaviour close to the origin.

A cursed quintic iteration
---------------------------

We applied a Newton-Schulz iteration in the `Muon optimizer `_ used in the `NanoGPT speedrun `_. Keller experimented with tuning the coefficients in the iteration and found that the most important thing for fast convergence of the optimizer was to inflate the small singular values as fast as possible. To keep the wall-clock time low, we needed to do this in the smallest number of iterations possible. This is achieved by making the first coefficient in the polynomial as large as possible, thereby maximizing the slope of the polynomial at :math:`x=0`. I had the idea of using a non-convergent iteration, and Keller settled on the following:

.. math::
    f(x) = 3.4445x - 4.7750x^3 + 2.0315x^5.

Plotting the polynomial after one and five iterations, we see some peculiar behaviour:

*(Interactive plot omitted from this dump: one iteration on the left; five iterations on the right.)*

This iteration *oscillates* and in fact *does not converge*! To see why, observe that a convergent iteration must at the very least satisfy :math:`f(1) = 1` so that :math:`x=1` is a fixed point. In turn, this implies that the sum of the coefficients should equal 1. But for Keller's polynomial, the coefficients sum to

.. math::
    3.4445 - 4.7750 + 2.0315 = 0.701 \neq 1.

In short, the cursed quintic iteration sacrifices convergence for speed.
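To see this numerically, here is a small illustrative NumPy snippet (not from the Modula codebase) that iterates both quintics on a grid of scalars. The convergent quintic settles near one, while the cursed quintic keeps bouncing around one:

.. code-block:: python

    import numpy as np

    def iterate(coeffs, x, n=5):
        # apply f(x) = a*x + b*x**3 + c*x**5 to x, n times
        a, b, c = coeffs
        for _ in range(n):
            x = a * x + b * x**3 + c * x**5
        return x

    xs = np.linspace(0.1, 1.0, 4)
    print(iterate((3, -16/5, 6/5), xs))            # close to 1 everywhere
    print(iterate((3.4445, -4.7750, 2.0315), xs))  # scattered around 1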

Jiacheng's six-step
--------------------

Another idea is that we do not need to use the same polynomial coefficients at each step of the iteration. Using different coefficients at each step provides even more degrees of freedom to achieve desirable behaviour. For example, `You Jiacheng `_ found the following iteration `by computational search `_:

*(Interactive plot omitted from this dump: the six quintics on the left; their composition on the right.)*

The six quintic polynomials on the left are executed in sequence. The last quintic (the orange curve) has a very flat slope close to :math:`x=1`, which has the effect of flattening out the oscillations present after the first five steps. The result is the very good approximation to the sign function on the right.

I have integrated this iteration as a basic primitive in `Modula `_ using the following JAX implementation:

.. code-block:: python

    import jax.numpy as jnp

    def orthogonalize(M):
        # by @YouJiacheng (with stability loss idea from @leloykun)
        # https://twitter.com/YouJiacheng/status/1893704552689303901
        # https://gist.github.com/YouJiacheng/393c90cbdc23b09d5688815ba382288b/5bff1f7781cf7d062a155eecd2f13075756482ae

        abc_list = [
            (3955/1024, -8306/1024, 5008/1024),
            (3735/1024, -6681/1024, 3463/1024),
            (3799/1024, -6499/1024, 3211/1024),
            (4019/1024, -6385/1024, 2906/1024),
            (2677/1024, -3029/1024, 1162/1024),
            (2172/1024, -1833/1024, 682/1024)
        ]

        transpose = M.shape[1] > M.shape[0]
        if transpose:
            M = M.T
        M = M / jnp.linalg.norm(M)
        for a, b, c in abc_list:
            A = M.T @ M
            I = jnp.eye(A.shape[0])
            M = M @ (a * I + b * A + c * A @ A)
        if transpose:
            M = M.T
        return M

I also recommend checking out the `writeup `_ of Franz Cesista (@leloykun) on this topic.

Designing your own iteration
-----------------------------

Designing these polynomial iterations can be a surprisingly fun exercise. If you'd like to explore designing your own iteration, you can start with a polynomial of the form:

.. math::
    f(x) = a x + b x^3 + c x^5 + d x^7 + e x^9 + ...

And then choose the coefficients :math:`a,b,c,d,e,...` to achieve your desired behaviour. Three important things to consider are:

- What order do you want to truncate your polynomial at? A higher-order iteration can converge in fewer steps, but each step is more expensive. There is a trade-off here.
- Do you want the iterations to converge? If so, you at least need to enforce that the coefficients sum to 1 so that :math:`f(1) = 1`. You could consider enforcing additional derivative conditions, such as that :math:`\partial f / \partial x = 0` at :math:`x=1`, to further stabilize the convergence.
- Do you want to use different polynomials on different steps? This adds complexity but provides a larger design space.

After making these decisions, you may have leftover degrees of freedom. A fun way to fix these degrees of freedom is to open up `Desmos `_ and play around with the coefficients using sliders.

Historical connections
----------------------

The procedure of symmetric orthogonalization appears in a number of different contexts:

- it is used in solving the `orthogonal Procrustes problem `_.
- it computes the "orthogonal polar factor" in the `polar decomposition `_ of a matrix.
- it was used by `Per-Olov Löwdin `_ in the 1950s to perform atomic and molecular orbital calculations.
- it is used for doing `Frank-Wolfe optimization `_ over the spectral norm ball.
- it was proposed for deep learning optimization in the paper `"preconditioned spectral descent for deep learning" `_---albeit computed via matrix sketching rather than Newton-Schulz iterations.
- a Newton-Schulz iteration was used to orthogonalize the weight matrices (but not the updates!) in deep learning in the paper `"sorting out Lipschitz function approximation" `_.

The earliest references on the Newton-Schulz iteration itself seem to be `"some iterative methods for improving orthonormality" `_ (Kovarik, 1970) and `"an iterative algorithm for computing the best estimate of an orthogonal matrix" `_ (Björck & Bowie, 1971). To justify using the name "Newton-Schulz" for these iterations, we note that Higham used it in `these slides `_. The idea of graphically tuning the coefficients of the iteration to obtain certain performance characteristics is, to the best of my knowledge, our own original idea.
--------------------------------------------------------------------------------
/assets/gpt-owt-context.svg:
--------------------------------------------------------------------------------
*(SVG figure; markup not included in this dump. Recoverable content: a plot of test loss (y-axis, roughly 3 to 8) against learning rate (x-axis, log scale), with one curve per context length: 32, 64, 128, 256, 512, 1024.)*
--------------------------------------------------------------------------------
/docs/source/faq.rst:
--------------------------------------------------------------------------------
Frequently asked questions
===========================

Feel free to reach out or start a `GitHub issue `_ if you have any questions about Modula. We'll post answers to any useful or common questions on this page.

Conceptual questions
^^^^^^^^^^^^^^^^^^^^^

.. dropdown:: The gradient is a vector: how can a vector have a spectral norm?
    :icon: question

    An important mental jump in Modula is to think of the weights of our neural network as a list of tensors :math:`(\mathbf{W}_1, \dots \mathbf{W}_L)` where :math:`\mathbf{W}_k` is the weight tensor of layer :math:`k`. It then makes sense to think of the gradient of the loss :math:`\mathcal{L}` with respect to the :math:`k\text{th}` weight tensor :math:`\mathbf{W}_k` as itself being a tensor :math:`\nabla_{\mathbf{W_k}}\mathcal{L}` with the same shape as :math:`\mathbf{W}_k`. We can then meaningfully ask what is the operator norm of this gradient tensor.

    This contrasts with a common approach to optimization theory where the whole weight space is "flattened" into one big weight vector :math:`\mathbf{w}` with a corresponding gradient vector :math:`\nabla_\mathbf{w} \mathcal{L}`, thus "losing" the operator structure.
.. dropdown:: Why does Adam beat SGD on transformers, and how does normalization fix SGD?
    :icon: question

    While some researchers `have challenged `_ the use of Adam in deep learning, Adam is certainly the optimizer of choice for training large language models, `performing much better `_ than SGD in practice. Still, it is not widely known *why* Adam is better than SGD. Here we aim to provide a mechanistic explanation of one of the main reasons. The basic idea is that there is no reason the raw gradients should have good relative sizes across layers. And a major thing that Adam does is to "rebalance" the update sizes across layers.

    Let's give a concrete example to see what we mean. Consider a machine learning model with a list of weight tensors :math:`\mathbf{w} = (\mathbf{W}_1, \dots \mathbf{W}_L)` and a loss function :math:`\mathcal{L}`. Then a vanilla gradient update is given by :math:`(\mathbf{W}_1, \dots \mathbf{W}_L) - \eta \times (\nabla_{\mathbf{W}_1}\mathcal{L}, \dots \nabla_{\mathbf{W}_L}\mathcal{L})` where :math:`\eta` is the global learning rate. Now, suppose that our neural network is a toy residual network with :math:`L` layers:

    .. math::
        f(\mathbf{w} ;\mathbf{x}) := \mathbf{W}_L \left(1 + \frac{1}{L} \mathbf{W_{L-1}}\right) \dots \left(1 + \frac{1}{L} \mathbf{W_{2}}\right) \mathbf{W_1} \mathbf{x}.

    This toy network consists of "read-in" and "read-out" matrices :math:`\mathbf{W}_1` and :math:`\mathbf{W}_L` along with :math:`L-2` "residual" matrices each depressed by a factor of :math:`1/L`. These depression factors are included to give the model a better large depth limit---in Modula we advocate for :math:`1/L` depression factors, while the inclusion of :math:`1/\sqrt{L}` depression factors is `standard in large language models `_. We do not include a nonlinearity in this toy model for simplicity.

    The point is that the depression factors---be they :math:`1/L` or :math:`1/\sqrt{L}`---also depress the gradients to the residual blocks by the same factor. So if one takes the depth :math:`L` large and uses vanilla gradient descent or SGD to train a transformer, one is essentially applying the update:

    .. math::

        (\mathbf{W}_1, \mathbf{W}_2, \mathbf{W}_3, \dots \mathbf{W}_{L-2}, \mathbf{W}_{L-1}, \mathbf{W}_L) - \eta \times (\nabla_{\mathbf{W}_1}\mathcal{L}, 0, 0, \dots, 0, 0, \nabla_{\mathbf{W}_L}\mathcal{L}).

    In words: the inclusion of the depression factors kills the size of the updates to the residual blocks in comparison to the read-in and read-out layers in deep networks. If you use SGD to train such a model, depending on how you set the learning rate :math:`\eta`, you are stuck between severely under-training the middle layers or severely over-training the input and output layers. Adam largely fixes this issue by normalizing each update tensor individually and thus removing the effect of the depression factors. So, Adam is a form of gradient normalization! Modular normalization also automatically fixes this issue by rebalancing the size of the updates for any base optimizer.
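    Here is a tiny scalar caricature of this effect (illustrative only, using finite differences rather than a real optimizer). With all weights initialized to one, the gradients to the residual weights come out roughly :math:`L` times smaller than the gradients to the read-in and read-out weights:

    .. code-block:: python

        import numpy as np

        L = 100
        w = np.ones(L)  # scalar stand-ins for the weight matrices

        def f(w):
            # scalar version of the toy residual network above
            out = w[0]
            for k in range(1, L - 1):
                out = out * (1 + w[k] / L)
            return out * w[-1]

        # finite-difference gradient of f with respect to each weight
        eps = 1e-6
        g = np.array([(f(np.where(np.arange(L) == k, w + eps, w)) - f(w)) / eps
                      for k in range(L)])

        # read-in/read-out gradients are ~e; residual gradients are ~e/L
        print(g[0], g[L // 2], g[-1])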
.. dropdown:: Why does modular normalization lead to learning rate transfer across scale?
    :icon: question

    By the definition of a "well-normed module" :math:`\mathsf{M}`, when weight updates :math:`\Delta \mathbf{w}` are normalized in the modular norm :math:`\|\cdot\|_\mathsf{M}` then updates :math:`\Delta \mathbf{y}` to the module output are well-behaved in the output norm :math:`\|\cdot\|_\mathcal{Y}`. We set up our actual architectures, including complicated models like GPT, to actually be well-normed independent of the scale of the architecture. A little bit more formally:

    1. well-normed modules are one-Lipschitz in the modular norm, meaning :math:`\|\Delta \mathbf{y}\|_\mathcal{Y} \leq \|\Delta \mathbf{w}\|_\mathsf{M}`;
    2. this inequality holds tightly when tensors in the network "align" during training, meaning that we may approximate :math:`\|\Delta \mathbf{y}\|_\mathcal{Y} \approx \|\Delta \mathbf{w}\|_\mathsf{M}` in a fully aligned network;
    3. therefore normalizing updates in the modular norm provides control on the change in outputs;
    4. these statements are all independent of the size of the architecture.

    Since modular normalization works by recursively normalizing the weight updates to each submodule, these desirable properties extend to all submodules as well as the overall compound.

.. dropdown:: What do we mean by "tensor alignment" in Modula?
    :icon: question

    In the guts of a neural network there can be found lots and lots of tensors. And sometimes these tensors like to multiply each other. For example, there are:

    - vector-vector products :math:`\mathbf{u}^\top\mathbf{v}`
    - matrix-vector products :math:`\mathbf{A}\mathbf{v}`
    - matrix-matrix products :math:`\mathbf{A}\mathbf{B}`
    - and so on...

    An important question is "how big are such tensor products inside a neural network?" In other words, if we know the size of the inputs to the product, can we predict the size of the product itself?

    Let's start with the simplest example of the vector-vector product, otherwise known as a friendly "dot product". Suppose we have two :math:`n` dimensional vectors :math:`\mathbf{u}` and :math:`\mathbf{v}` of known sizes :math:`\|\mathbf{u}\|_2` and :math:`\|\mathbf{v}\|_2`. Here the symbol :math:`\|\mathbf{\cdot}\|_2` denotes the "Euclidean length" or ":math:`\ell_2` norm" of the vectors. How large can the dot product be? Well, by the Cauchy-Schwarz inequality, we have that:

    .. math::

        |\mathbf{u}^\top \mathbf{v}| \leq \|\mathbf{u}\|_2 \times \|\mathbf{v}\|_2.

    In words: the size of the dot product is limited by the size of its two inputs. What's more, the Cauchy-Schwarz inequality is "tight", meaning that :math:`|\mathbf{u}^\top \mathbf{v}| = \|\mathbf{u}\|_2 \times \|\mathbf{v}\|_2`, when the two vectors :math:`\mathbf{u}` and :math:`\mathbf{v}` point in the same (or opposite) directions---when the two vectors "align".

    This idea of having an inequality that limits the size of a tensor product, which is tight under certain configurations of the input tensors, generalizes to higher-order forms of tensor product. For example, for the matrix-vector product :math:`\mathbf{A}\mathbf{v}` the relevant inequality is given by:

    .. math::

        \|\mathbf{A} \mathbf{v}\|_2 \leq \|\mathbf{A}\|_* \times \|\mathbf{v}\|_2,

    where :math:`\|\cdot\|_*` is the matrix spectral norm.
    This inequality is tight when the vector :math:`\mathbf{v}` lies in the top singular subspace of the matrix :math:`\mathbf{A}`---when the matrix and vector "align".

    And for matrix-matrix products, we have the "sub-multiplicativity of the spectral norm":

    .. math::

        \|\mathbf{A} \mathbf{B}\|_* \leq \|\mathbf{A}\|_* \times \|\mathbf{B}\|_*.

    We will say that this inequality is tight when the two matrices "align"---you get the idea!

    Why does any of this matter? Well, for a neural network at initialization, some of these inequalities may be quite slack because tensors in the network are randomly oriented with respect to each other. But it is a central tenet of the Modula framework that after training has sufficiently "warmed up", the network will fall into a fully aligned state where all inequalities of the type mentioned in this section hold reasonably tightly, and may therefore be used to predict the size and scaling of various quantities in the network.

    .. admonition:: Other notions of alignment
        :class: seealso

        We have outlined a notion of alignment which captures whether or not a certain inequality governing a tensor product is tight. This is different to the notion of alignment measured in `Scaling Exponents Across Parameterizations and Optimizers `_ which `turns out to be coupled to the matrix stable rank `_. Essentially, the findings on alignment in that paper don't have an obvious bearing on the notion of alignment used in Modula. Large-scale empirical tests of alignment as we have described it are certainly a valuable direction for future work.

.. dropdown:: Is there a unique and optimal way to parameterize an architecture?
    :icon: question

    The short answer is no: if you're careful, there is some freedom in how you can parameterize your architecture. With that said, there are some constraints that you can't really avoid if you want things to work well. And there are some "natural choices" which I think we may as well agree on at least to ease communication between researchers.

    A `LoRA layer `_ provides a really good setting to think about these points. Given an :math:`n \times r` matrix :math:`B` and an :math:`r \times n` matrix :math:`A`, a LoRA layer is just the matrix product :math:`B A`. Now if you're a `spectral-μP `_ aficionado, you'd know that the "right way" to scale these matrices is so that their initialization and updates have spectral norm proportional to :math:`\sqrt{\text{fan-out/fan-in}}`. Written out in full:

    - matrix :math:`B` and update :math:`\Delta B` have spectral norms :math:`\|B\|_*` and :math:`\|\Delta B\|_* \propto \sqrt{n / r}`,
    - matrix :math:`A` and update :math:`\Delta A` have spectral norms :math:`\|A\|_*` and :math:`\|\Delta A\|_* \propto \sqrt{r / n}`.

    However, these conditions are more restrictive than necessary. Because matrices are homogeneous linear maps, in the product :math:`BA` we are free to scale up the matrix :math:`B` by any factor so long as we divide the matrix :math:`A` by the same factor. Nothing changes if we do this, as the quick sketch after the following list illustrates. In particular, if we scale :math:`B` by the factor :math:`\sqrt{r/n}` and divide :math:`A` by this same factor, we obtain new conditions:

    - matrix :math:`B` and update :math:`\Delta B` have spectral norms :math:`\|B\|_*` and :math:`\|\Delta B\|_* \propto 1`,
    - matrix :math:`A` and update :math:`\Delta A` have spectral norms :math:`\|A\|_*` and :math:`\|\Delta A\|_* \propto 1`.
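    Here is a quick NumPy sketch of this scale symmetry (purely illustrative, with made-up dimensions), showing that the rescaled factors implement exactly the same layer while their spectral norms shift by the chosen factor:

    .. code-block:: python

        import numpy as np

        rng = np.random.default_rng(0)
        n, r = 64, 4
        B = rng.standard_normal((n, r)) / np.sqrt(r)   # n x r factor
        A = rng.standard_normal((r, n)) / np.sqrt(n)   # r x n factor

        s = np.sqrt(r / n)
        B2, A2 = s * B, A / s                          # rescaled pair

        print(np.allclose(B @ A, B2 @ A2))             # True: the layer is unchanged
        print(np.linalg.norm(B, 2), np.linalg.norm(B2, 2))  # spectral norms differ by s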
    Using these new spectral scaling conditions will lead to exactly the same training dynamics.

    .. admonition:: Matters of precision
        :class: seealso

        When considering representing the weight entries in floating point, a difference may emerge between these two schemes. In particular, one scheme may lead to weight entries more easily representable in a low-precision floating point number format. Charlie Blake et al. consider exploiting this type of "scale symmetry" in `u-μP: The Unit-Scaled Maximal Update Parametrization `_.

    In summary, I hope that this section demonstrates that:

    1. the conditions in the spectral-μP paper provide a sensible default way of scaling matrices which should work well in generic situations;
    2. however, the conditions are not unique, and in specific cases you can modify the rules---so long as you know what you're doing;
    3. you may want to take advantage of scale symmetries if you are interested in designing low-precision training algorithms.

Related work
^^^^^^^^^^^^^

.. dropdown:: What is the relationship between Modula and spectral-μP?
    :icon: question

    In the `spectral-μP paper `_, we considered the problem of equipping individual layers---such as linear and embedding layers---with their "natural norm". Normalizing updates in this "natural norm" leads to learning rate transfer across the dimensions of that layer. You can see Modula as generalizing this approach to arbitrary compositions and concatenations of individual layers---i.e. neural nets.

.. dropdown:: What is the relationship between Modula and Tensor Programs?
    :icon: question

    We pointed out in the section on `the science of scale <../history>`_ that Modula builds on an approach to learning rate transfer `that we first released `_ almost a year before `the first incarnation of μP `_. So I want to focus here on explaining the technical differences between Modula and `Tensor Programs `_.

    The main advantages of Modula over Tensor Programs are that:

    1. **Modula is grounded in elementary math.** We show that learning rate transfer is essentially just the question of how to build neural nets with tight and non-dimensional Lipschitz estimates. The main ingredient is just bounding derivatives and tracking how derivative bounds behave under composition and concatenation. We do not employ limiting or probabilistic analyses.
    2. **Modula theory is non-asymptotic.** The unifying thread through the Tensor Programs series of works is the study of neural network computation in limiting cases: infinite width, infinite depth, and so on. This means that the theory is encumbered by significant mathematical overhead, and one is often confronted with thorny technical questions---for example: `do width and depth limits commute? `_ In contrast, Modula is based on a completely non-asymptotic theory. It deals directly with the finite-sized neural networks that we actually use in practice, so you don't have to worry that certain technical details may be "lost in the limit". To show that this is not just talk, in our paper we `built a theory of an actual working transformer `_.
    3. **Modula is more automatic.** In Modula, we automatically build a norm during construction of the computation graph that can be used to explicitly normalize weight updates taken from any base optimizer.
       The Tensor Programs approach essentially amounts to manually deriving a priori estimates on the size of this norm, and using these estimates to modify the SGD learning rate per layer. However, working out these prior estimates is quite a hairy procedure which seemingly does not always work, hence why later Tensor Programs papers `shift to modifying Adam updates `_. Adam updates are easier to deal with since they already impose a form of normalization on the gradients. Furthermore, the Tensor Programs calculations must be done by hand. The result is large tables of scaling rules, with tables of rules for different base optimizers (Adam versus SGD) and even tables for different matrix shapes (square versus wide rectangular versus skinny rectangular).

    4. **Modula is easier to extend.** Ultimately, we hope that Modula---and more generally the idea of *metrized deep learning*---will inspire followup work on clean, simple and technically sound approaches to algorithm design in deep learning. We give some directions for future work towards the end of `our paper `_, and we believe it should be relatively easy to extend our approach to handle new module types and new norms. To give an example, there is a natural extension of the linear module that equips the input and output spaces with the :math:`\ell_\infty` norm instead of the RMS norm, thereby inducing the :math:`\ell_\infty`--:math:`\ell_\infty` operator norm on the matrix space. While from the point of view of infinite width limits, it may not matter whether you are stabilizing infinity norms or RMS norms, we believe this sort of consideration might be quite interesting in practice.

.. dropdown:: What is the relationship between Modula and AGD?
    :icon: question

    In part, Modula builds on the analysis from our previous paper on `automatic gradient descent `_. The AGD paper focused on building a majorize-minimize-style analysis of deep fully-connected networks. The surprising aspect of the AGD algorithm was that it could train various deep learning problems with no learning rate, weight decay, momentum or schedule hyperparameters. However, the training was slower and sometimes not quite as good as conventional training setups.

    The Modula paper, in contrast, shows how to modularize and automate the types of technical calculations done in the AGD paper. In Modula, we conduct these calculations to first and second order, since we came to believe that a full majorization is overly pessimistic, contributing to the slower training of AGD. And ultimately in the Modula experiments, we opted to use a linear decay learning rate schedule for its simplicity and high performance, rather than various automatic learning rate schedules that could be derived from the Modula theory.

    I (Jeremy) think an analogue of AGD that is also fast and performant might still be possible. It might involve combining Modula with ideas from people like Konstantin Mishchenko and Aaron Defazio such as `Prodigy `_ or the `schedule-free optimizer `_. I think this is a great direction for future work.

.. dropdown:: What is the relationship between Modula and Shampoo?
    :icon: question

    Actually, no one asked this one; I just thought about it for myself. But here goes... Consider a loss function :math:`\mathcal{L} : \mathbb{R}^{m \times n}\to\mathbb{R}`. In other words, we have a machine learning model whose weights are given by an :math:`m \times n` matrix :math:`\mathbf{W}`.
    Further, suppose that the loss is smooth in the sense that:

    .. math::

        \mathcal{L}(\mathbf{W} + \mathbf{\Delta W}) \leq \mathcal{L}(\mathbf{W}) + \mathrm{trace}(\mathbf{G}^\top \mathbf{\Delta W}) + \frac{1}{2} \|\mathbf{\Delta W}\|_*^2.

    In words: the loss is "smooth in the spectral norm". This toy problem is interesting for us to think about since the modular norm on linear atomic modules *is* the spectral norm. The second term on the righthand side is the `Frobenius inner product `_ and :math:`\|\cdot\|_*` denotes the `spectral norm `_. We have adopted the shorthand :math:`\mathbf{G}` for the gradient of the loss :math:`\nabla_\mathbf{W} \mathcal{L}(\mathbf{W})`, and we suppose that the gradient admits the singular value decomposition :math:`\mathbf{G} = \mathbf{U}\mathbf{\Sigma}\mathbf{V}^\top`.

    If we minimise the righthand side of this inequality with respect to :math:`\mathbf{\Delta W}`, we find that the optimal step direction is given by :math:`\mathbf{\Delta W} \propto - \mathbf{U}\mathbf{V}^\top`. That is, we take the negative gradient and set all of its singular values to one. This direction "squeezes the most juice" out of the gradient under a spectral norm geometry. This is a somewhat classical observation. For instance, it appears in a 2015 paper on `stochastic spectral descent `_. Tim independently pointed this out to me in the course of the Modula project, and so did Laker Newhouse, who is a talented undergrad at MIT.

    What stopped us experimenting further with this idea is that it's not obvious how to compute :math:`\mathbf{\Delta W} \propto - \mathbf{U}\mathbf{V}^\top` without computing SVDs, and SVDs are kind of expensive in PyTorch. But a cool realization I had recently is that there is another way to compute :math:`\mathbf{U}\mathbf{V}^\top`. In fact, it holds that:

    .. math::
        \mathbf{U}\mathbf{V}^\top = (\mathbf{G}\mathbf{G}^\top)^{-\tfrac{1}{4}} \mathbf{G} (\mathbf{G}^\top \mathbf{G})^{-\tfrac{1}{4}}.

    Why is this interesting? Well, for one, the expression on the right-hand side is precisely the `Shampoo `_ preconditioner with the accumulation dropped. It suggests a new perspective on Shampoo as doing "steepest descent under the spectral norm". This is a squarely "first-order" interpretation, as opposed to the predominant way people seem to think of Shampoo as an "approximate second-order method".

    Another reason this could be interesting is that a lot of efficiencies developed for Shampoo could now be applied to `stochastic spectral descent `_ and in turn Modula linear modules. One of the coolest examples is something I found in `Rohan Anil's slides `_ on Shampoo. It's the idea that you can compute expressions like :math:`(\mathbf{G}\mathbf{G}^\top)^{-1/4}` using `Newton-Raphson iterations `_---a very different approach to taking SVDs. A classic paper on this topic is called `On the Computation of the Matrix k-th Root `_ by Slobodan Lakić. `I implemented Algorithm 1 from that paper as a gist `_, finding that it often gives significant speedups over the SVD, provided one is willing to tolerate some error.

    An extremely natural way to combine this idea with Modula is to write a new "ShampooLinear" atomic module which replaces the normalize function of our Linear atom with a zeroth matrix power. Jack Gallagher has started experimenting with this idea in `Modulax `_.
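    To make the identity above concrete, here is a small NumPy check (illustrative only; it goes through dense eigendecompositions and an SVD, so it is a correctness check rather than a speedup) that the quarter-power expression reproduces :math:`\mathbf{U}\mathbf{V}^\top`:

    .. code-block:: python

        import numpy as np

        rng = np.random.default_rng(0)
        G = rng.standard_normal((5, 5))  # generic full-rank gradient matrix
        U, S, Vt = np.linalg.svd(G)

        def inv_quarter(M):
            # inverse fourth root of a symmetric positive definite matrix
            vals, vecs = np.linalg.eigh(M)
            return vecs @ np.diag(vals ** -0.25) @ vecs.T

        lhs = U @ Vt
        rhs = inv_quarter(G @ G.T) @ G @ inv_quarter(G.T @ G)
        print(np.allclose(lhs, rhs))  # True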
Modula package
^^^^^^^^^^^^^^^

.. dropdown:: The modular norm involves a max---why do I not see any maxes in the package?
    :icon: question

    Computing the modular norm involves evaluating lots of expressions of the form:

    .. math::
        \| (\mathbf{w}_1, \mathbf{w}_2) \|_{\mathsf{M}} := \max ( p * \|\mathbf{w}_1\|_{\mathsf{M}_1} , q * \|\mathbf{w}_2\|_{\mathsf{M}_2}).

    So you might be surprised not to see lots of maxes in the package. This is because to normalize a vector :math:`(\mathbf{w}_1, \mathbf{w}_2)` we do not just compute :math:`(\mathbf{w}_1, \mathbf{w}_2) / \|(\mathbf{w}_1, \mathbf{w}_2)\|_\mathsf{M}`. Instead, we separately normalize both sub-vectors in order to "saturate" the max. That is, we send:

    .. math::
        (\mathbf{w}_1, \mathbf{w}_2) \mapsto \left(\frac{\mathbf{w}_1}{p * \|\mathbf{w}_1\|_{\mathsf{M}_1}}, \frac{\mathbf{w}_2}{q * \|\mathbf{w}_2\|_{\mathsf{M}_2}} \right).

    In other words, we maximize the size of each subvector under the constraint that the full vector has unit modular norm.

.. dropdown:: Is it necessary to use orthogonal initialization in Modula?
    :icon: question

    No. You could re-write the atomic modules to use Gaussian initialization if you wanted. The reason we choose to use orthogonal initialization is that it makes it much easier to get scaling right. This is because the spectral norm of any :math:`m \times n` random orthogonal matrix is always one. In contrast, the spectral norm of an :math:`m \times n` random Gaussian matrix depends on the dimensions :math:`m` and :math:`n` and also the entry-wise variance :math:`\sigma^2`, making it more difficult to properly set the initialization scale. In addition, orthogonal matrices have the benign property that all singular values are one. In Gaussian matrices, on the other hand, the average singular value and the max singular value are different, meaning that Gaussian matrices have more subtle numerical properties.

.. dropdown:: Does Modula support weight sharing?
    :icon: question

    Not at present, although it would be possible to support this.

Research philosophy
^^^^^^^^^^^^^^^^^^^^

.. dropdown:: Do I need to be a mathematical savant to contribute to research of this kind?
    :icon: question

    I don't think so. There are a lot of very technical people working in this field bringing with them some quite advanced tools from math and theoretical physics, and this is great. But in my experience it's usually the simpler and more elementary ideas that actually work in practice. I strongly believe that deep learning theory is still at the stage of model building. And I resonate with both Rahimi and Recht's call for `simple theorems and simple experiments `_ and George Dahl's call for `a healthy dose of skepticism `_ when evaluating claims in the literature.
--------------------------------------------------------------------------------