├── src ├── hyperspherical_descent.py ├── manifold_muon.py ├── msign.py └── main.py ├── README.md ├── LICENSE └── cover.svg /src/hyperspherical_descent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | @torch.no_grad() 4 | def hyperspherical_descent(W, G, eta=0.1): 5 | w = W.flatten() 6 | g = G.flatten() 7 | # Compute update direction 8 | a = g - w * torch.dot(w, g) 9 | a /= (a.norm() + 1e-12) 10 | # Apply update 11 | new_w = w - eta * a 12 | # Retract to the manifold 13 | new_w /= new_w.norm() 14 | # Restore the shape of the solution and return 15 | return new_w.reshape(W.shape) 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modular Manifolds 2 | 3 | ![](cover.svg) 4 | 5 | Supporting code for the blog post on modular manifolds. 6 | 7 | Find the blog post at https://thinkingmachines.ai/blog/modular-manifolds. 8 | 9 | ### Code structure 10 | 11 | ```text 12 | src/ 13 | ├── main.py # Entry point: training loop and CLI 14 | ├── msign.py # Matrix sign function via Polar-Express 15 | ├── manifold_muon.py # Manifold Muon update rule 16 | └── hyperspherical_descent.py # Hyperspherical descent update rule 17 | ``` 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Thinking Machines Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/manifold_muon.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from msign import msign 4 | 5 | @torch.no_grad() 6 | def manifold_muon(W, G, eta=0.1, alpha=0.01, steps=100, tol=1e-6): 7 | # Ensure that W and G are both tall matrices 8 | should_tranpose = W.shape[0] < W.shape[1] 9 | if should_tranpose: 10 | W = W.T 11 | G = G.T 12 | # Initialize the dual variable 13 | Lambda = -0.25 * (W.T @ G + G.T @ W) 14 | # Ascend on the dual problem to find the update direction A 15 | for step in range(steps): 16 | # Update the candidate direction A 17 | A = msign(G + 2 * W @ Lambda) 18 | # Measure deviation of A from the tangent space: 19 | H = W.T @ A + A.T @ W 20 | # Check the stopping criterion 21 | if torch.norm(H) / math.sqrt(H.numel()) < tol: 22 | break 23 | # Update the dual variable 24 | Lambda -= alpha * (1 - step / steps) * H 25 | # Descend on the primal problem 26 | new_W = W - eta * A 27 | # Retract to the manifold 28 | new_W = msign(new_W) 29 | # Restore the shape of the solution and return 30 | return new_W.T if should_tranpose else new_W 31 | -------------------------------------------------------------------------------- /src/msign.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ABC_LIST: list[tuple[float, float, float]] = [ 4 | (8.28721201814563, -23.595886519098837, 17.300387312530933), 5 | (4.107059111542203, -2.9478499167379106, 0.5448431082926601), 6 | (3.9486908534822946, -2.908902115962949, 0.5518191394370137), 7 | (3.3184196573706015, -2.488488024314874, 0.51004894012372), 8 | (2.300652019954817, -1.6689039845747493, 0.4188073119525673), 9 | (1.891301407787398, -1.2679958271945868, 0.37680408948524835), 10 | (1.8750014808534479, -1.2500016453999487, 0.3750001645474248), 11 | (1.875, -1.25, 0.375), 12 | ] 13 | 14 | # safety factor for numerical stability (but exclude last polynomial) 15 | ABC_LIST_STABLE: list[tuple[float, float, float]] = [ 16 | (a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in ABC_LIST[:-1] 17 | ] + [ABC_LIST[-1]] 18 | 19 | 20 | @torch.no_grad() 21 | def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor: 22 | """ 23 | Polar Express algorithm for the matrix sign function: 24 | https://arxiv.org/abs/2505.16932 25 | """ 26 | assert G.ndim >= 2 27 | should_transpose: bool = G.size(-2) > G.size(-1) 28 | 29 | x = G.bfloat16() 30 | if should_transpose: 31 | x = x.mT 32 | 33 | x /= x.norm(dim=(-2, -1), keepdim=True) * 1.01 34 | for step in range(steps): 35 | a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1] 36 | s = x @ x.mT 37 | # goal is to compute x = a x + b S x + c S^2 x 38 | # we can break this up into: x = (a I + (b I + c S) S) x 39 | y = c * s 40 | y.diagonal(dim1=-2, dim2=-1).add_(b) 41 | y = y @ s 42 | y.diagonal(dim1=-2, dim2=-1).add_(a) 43 | x = y @ x 44 | 45 | if should_transpose: 46 | x = x.mT 47 | x = torch.nan_to_num(x) 48 | return x.float() -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | from hyperspherical_descent import hyperspherical_descent 11 | from manifold_muon import manifold_muon 12 | from torch.optim import AdamW 13 | from torch.utils.data import DataLoader 14 | 15 | transform = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)) 18 | ]) 19 | 20 | train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, transform=transform, download=True) 21 | test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, transform=transform, download=True) 22 | 23 | train_loader = DataLoader(dataset=train_dataset, batch_size=1024, shuffle=True) 24 | test_loader = DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False) 25 | 26 | 27 | class MLP(nn.Module): 28 | def __init__(self): 29 | super(MLP, self).__init__() 30 | self.fc1 = nn.Linear(32 * 32 * 3, 128, bias=False) 31 | self.fc2 = nn.Linear(128, 64, bias=False) 32 | self.fc3 = nn.Linear(64, 10, bias=False) 33 | 34 | def forward(self, x): 35 | x = x.view(-1, 32 * 32 * 3) 36 | x = torch.relu(self.fc1(x)) 37 | x = torch.relu(self.fc2(x)) 38 | x = self.fc3(x) 39 | return x 40 | 41 | 42 | def train(epochs, initial_lr, update, wd): 43 | model = MLP().cuda() 44 | criterion = nn.CrossEntropyLoss() 45 | 46 | if update == AdamW: 47 | optimizer = AdamW(model.parameters(), lr=initial_lr, weight_decay=wd) 48 | else: 49 | assert update in [manifold_muon, hyperspherical_descent] 50 | optimizer = None 51 | 52 | steps = epochs * len(train_loader) 53 | step = 0 54 | 55 | if optimizer is None: 56 | # Project the weights to the manifold 57 | for p in model.parameters(): 58 | p.data = update(p.data, torch.zeros_like(p.data), eta=0) 59 | 60 | epoch_losses = [] 61 | epoch_times = [] 62 | 63 | for epoch in range(epochs): 64 | start_time = time.time() 65 | running_loss = 0.0 66 | for i, (images, labels) in enumerate(train_loader): 67 | images = images.cuda() 68 | labels = labels.cuda() 69 | 70 | # Forward pass 71 | outputs = model(images) 72 | loss = criterion(outputs, labels) 73 | 74 | # Backward and optimize 75 | model.zero_grad() 76 | loss.backward() 77 | lr = initial_lr * (1 - step / steps) 78 | with torch.no_grad(): 79 | if optimizer is None: 80 | for p in model.parameters(): 81 | p.data = update(p, p.grad, eta=lr) 82 | else: 83 | for param_group in optimizer.param_groups: 84 | param_group["lr"] = lr 85 | optimizer.step() 86 | step += 1 87 | 88 | running_loss += loss.item() 89 | if (i+1) % 100 == 0: 90 | print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}") 91 | 92 | end_time = time.time() 93 | epoch_loss = running_loss / len(train_loader) 94 | epoch_time = end_time - start_time 95 | epoch_losses.append(epoch_loss) 96 | epoch_times.append(epoch_time) 97 | print(f"Epoch {epoch+1}, Loss: {epoch_loss}, Time: {epoch_time:.4f} seconds") 98 | return model, epoch_losses, epoch_times 99 | 100 | 101 | def eval(model): 102 | # Test the model 103 | model.eval() 104 | with torch.no_grad(): 105 | accs = [] 106 | for dataloader in [test_loader, train_loader]: 107 | correct = 0 108 | total = 0 109 | for images, labels in dataloader: 110 | images = images.cuda() 111 | labels = labels.cuda() 112 | outputs = model(images) 113 | _, predicted = torch.max(outputs.data, 1) 114 | total += labels.size(0) 115 | correct += (predicted == labels).sum().item() 116 | accs.append(100 * correct / total) 117 | 118 | print(f"Accuracy of the network on the {len(test_loader.dataset)} test images: {accs[0]} %") 119 | print(f"Accuracy of the network on the {len(train_loader.dataset)} train images: {accs[1]} %") 120 | return accs 121 | 122 | def weight_stats(model): 123 | singular_values = [] 124 | norms = [] 125 | for p in model.parameters(): 126 | u,s,v = torch.svd(p) 127 | singular_values.append(s) 128 | norms.append(p.norm()) 129 | return singular_values, norms 130 | 131 | 132 | if __name__ == "__main__": 133 | parser = argparse.ArgumentParser(description="Train a model on CIFAR-10.") 134 | parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to train for.") 135 | parser.add_argument("--lr", type=float, default=0.1, help="Initial learning rate.") 136 | parser.add_argument("--update", type=str, default="manifold_muon", choices=["manifold_muon", "hyperspherical_descent", "adam"], help="Update rule to use.") 137 | parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator.") 138 | parser.add_argument("--wd", type=float, default=0.0, help="Weight decay for AdamW.") 139 | args = parser.parse_args() 140 | 141 | # determinism flags 142 | torch.manual_seed(args.seed) 143 | torch.cuda.manual_seed_all(args.seed) 144 | torch.backends.cudnn.deterministic = True 145 | torch.backends.cudnn.benchmark = False 146 | 147 | update_rules = { 148 | "manifold_muon": manifold_muon, 149 | "hyperspherical_descent": hyperspherical_descent, 150 | "adam": AdamW 151 | } 152 | 153 | update = update_rules[args.update] 154 | 155 | print(f"Training with: {args.update}") 156 | print(f"Epochs: {args.epochs} --- LR: {args.lr}", f"--- WD: {args.wd}" if args.update == "adam" else "") 157 | 158 | model, epoch_losses, epoch_times = train( 159 | epochs=args.epochs, 160 | initial_lr=args.lr, 161 | update=update, 162 | wd=args.wd 163 | ) 164 | test_acc, train_acc = eval(model) 165 | singular_values, norms = weight_stats(model) 166 | 167 | results = { 168 | "epochs": args.epochs, 169 | "lr": args.lr, 170 | "seed": args.seed, 171 | "wd": args.wd, 172 | "update": args.update, 173 | "epoch_losses": epoch_losses, 174 | "epoch_times": epoch_times, 175 | "test_acc": test_acc, 176 | "train_acc": train_acc, 177 | "singular_values": singular_values, 178 | "norms": norms 179 | } 180 | 181 | filename = f"update-{args.update}-lr-{args.lr}-wd-{args.wd}-seed-{args.seed}.pkl" 182 | os.makedirs("results", exist_ok=True) 183 | 184 | print(f"Saving results to {os.path.join("results", filename)}") 185 | with open(os.path.join("results", filename), "wb") as f: 186 | pickle.dump(results, f) 187 | print(f"Results saved to {os.path.join("results", filename)}") 188 | -------------------------------------------------------------------------------- /cover.svg: -------------------------------------------------------------------------------- 1 | 77 2 | 777777 3 | 777777777777 777777777777777777 7777777777777777777777777 777777777777777777777777777777 77777777777777777777777777777777777777 77777777777777777777777777777777777777777777 177777777777777777777777777777777777777777777772 445377777777777777777777777777777777777777775800 444445377777777777777777777777777777777715000000 444444445277777777777777777777777777716800000000 444444444444217777777777777777777736000000000000 444444444444444517777777777777739000000000000000 444444444444444444517777777728000000000000000000 444444444444444444444537748000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 44444444444444444444460000000000000000000000 44444444444444444460000000000000000000 444444444444446000000000000000 44444444444460000000000000 4 | 44444444446000000000 5 | 444446000000 6 | 446800 7 | 40444477 8 | 777777 9 | 777777777777 777777777777777777 7777777777777777777777777 777777777777777777777777777777 77777777777777777777777777777777777777 77777777777777777777777777777777777777777777 177777777777777777777777777777777777777777777772 445377777777777777777777777777777777777777775800 444445377777777777777777777777777777777715000000 444444445277777777777777777777777777716800000000 444444444444217777777777777777777736000000000000 444444444444444517777777777777739000000000000000 444444444444444444517777777728000000000000000000 444444444444444444444537748000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 44444444444444444444460000000000000000000000 44444444444444444460000000000000000000 444444444444446000000000000000 44444444444460000000000000 10 | 44444444446000000000 11 | 444446000000 12 | 446800 13 | 40444477 14 | 777777 15 | 777777777777 777777777777777777 7777777777777777777777777 777777777777777777777777777777 77777777777777777777777777777777777777 77777777777777777777777777777777777777777777 177777777777777777777777777777777777777777777772 445377777777777777777777777777777777777777775800 444445377777777777777777777777777777777715000000 444444445277777777777777777777777777716800000000 444444444444217777777777777777777736000000000000 444444444444444517777777777777739000000000000000 444444444444444444517777777728000000000000000000 444444444444444444444537748000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 444444444444444444444446000000000000000000000000 44444444444444444444460000000000000000000000 44444444444444444460000000000000000000 444444444444446000000000000000 44444444444460000000000000 16 | 44444444446000000000 17 | 444446000000 18 | 446800 19 | 404444 20 | --------------------------------------------------------------------------------