├── mamba_blocks.py
├── cifar_10.py
├── shakespeare.py
└── README.md

/mamba_blocks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from mamba_ssm import Mamba


class MambaBlock(nn.Module):
    def __init__(self, embed_dim, dropout_level=0):
        super().__init__()

        self.mamba = Mamba(d_model=embed_dim, d_state=16, d_conv=4, expand=2)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout_level)

    def forward(self, x):
        x = self.norm(self.mamba(x) + x)
        return self.dropout(x)


class MambaTower(nn.Module):
    def __init__(self, embed_dim, n_layers, seq_len=None, global_pool=False, dropout=0):
        super().__init__()
        self.blocks = nn.Sequential(*[MambaBlock(embed_dim, dropout_level=dropout) for _ in range(n_layers)])
        self.global_pool = global_pool  # for classification or other supervised learning.

    def forward(self, x):
        # for input (bs, n, d) returns (bs, n, d), or (bs, d) if global_pool is True
        out = self.blocks(x) if not self.global_pool else torch.mean(self.blocks(x), 1)
        return out
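

# Quick shape check (a sketch, not part of the original module). It assumes a
# CUDA device is available, since mamba_ssm's selective-scan kernels run on GPU.
if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(8, 64, 128, device="cuda")  # (batch, seq_len, embed_dim)
    tower = MambaTower(embed_dim=128, n_layers=2).to("cuda")
    print(tower(x).shape)   # torch.Size([8, 64, 128])
    pooled = MambaTower(embed_dim=128, n_layers=2, global_pool=True).to("cuda")
    print(pooled(x).shape)  # torch.Size([8, 128])
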
--------------------------------------------------------------------------------
/cifar_10.py:
--------------------------------------------------------------------------------

# !pip install torchview torchmetrics einops wandb causal-conv1d==1.0.2 mamba-ssm

# !git clone https://github.com/apapiu/generic_transformer.git
# import sys
# sys.path.append('generic_transformer/')
from trainer import Trainer
# from transformer_blocks import EncoderBlock, Tower, MLPSepConv, MLP


from mamba_ssm import Mamba
from tqdm import tqdm
import os
import torch
import torch.nn as nn
import wandb
from einops import rearrange
from einops.layers.torch import Rearrange
import torchmetrics
import torchvision
import torchvision.transforms as transforms
from mamba_blocks import MambaTower, MambaBlock

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class ImgClassifier(nn.Module):
    def __init__(self, patch_size=4, img_size=32, n_channels=3, embed_dim=256, n_layers=6, dropout=0):
        super().__init__()

        self.patch_size = patch_size
        self.img_size = img_size
        self.n_channels = n_channels
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.n_layers = n_layers
        seq_len = int((self.img_size / self.patch_size) * (self.img_size / self.patch_size))
        patch_dim = self.n_channels * self.patch_size * self.patch_size

        self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)',
                                   p1=self.patch_size, p2=self.patch_size)

        self.func = nn.Sequential(self.rearrange,
                                  nn.LayerNorm(patch_dim),
                                  nn.Linear(patch_dim, embed_dim),
                                  nn.LayerNorm(embed_dim),
                                  MambaTower(embed_dim, n_layers, seq_len=seq_len, global_pool=True, dropout=dropout),
                                  nn.Linear(embed_dim, 10))

    def forward(self, x):
        return self.func(x)


transform = transforms.Compose([
    transforms.ToTensor()
])

train_aug_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    transforms.RandomPerspective(distortion_scale=0.5, p=0.5)
])


trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_aug_transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                           shuffle=True, num_workers=4, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=256,
                                          shuffle=False, num_workers=4, pin_memory=True)

# os.environ["WANDB_API_KEY"] = 'your_wandb_key'
wandb.login()

block = 'mamba'
n_layers = 6
patch_size = 4
img_size = 32
embed_dim = 256
dropout = 0.1
n_channels = 3

weight_decay = 1e-5
lr = 0.0003
T_max = 30000
n_epochs = 100

config = {k: v for k, v in locals().items() if k in ['block', 'n_layers', 'patch_size', 'img_size', 'embed_dim',
                                                     'dropout', 'lr', 'T_max', 'weight_decay']}

model = ImgClassifier(patch_size, img_size, n_channels, embed_dim, n_layers, dropout)
model = model.to(device)
loss = nn.CrossEntropyLoss()
val_metric = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
trainer = Trainer(model, loss, lr=lr, T_max=T_max, weight_decay=weight_decay, wandb_log=True)

wandb.init(
    project="cifar10_classification",
    config=config)

# note: this function does not work well with the mamba block:
trainer.plot_architecture(train_loader, depth=6)

wandb.save('model_graph_new.png')
print(f'Num params {sum(p.numel() for p in model.parameters())}')
trainer.train_loop(train_loader, test_loader, n_epochs=n_epochs, val_metric=val_metric)
wandb.finish()
--------------------------------------------------------------------------------
/shakespeare.py:
--------------------------------------------------------------------------------
from mamba_blocks import MambaTower, MambaBlock
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, IterableDataset
from itertools import cycle
import numpy as np
import wandb
from tqdm import tqdm
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from torch.optim.lr_scheduler import CosineAnnealingLR

from lightning.pytorch.callbacks import ModelCheckpoint


class MambaGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim, seq_len, n_layers, dropout):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.tower = MambaTower(embed_dim, n_layers, seq_len=seq_len, global_pool=False)
        self.out_proj = nn.Sequential(nn.LayerNorm(embed_dim),
                                      nn.Linear(embed_dim, vocab_size))

    def forward(self, x):
        x = self.tower(self.embed(x))
        return self.out_proj(x)


# class GPTmodel(nn.Module):
#     def __init__(self, embed_dim, seq_len, n_layers, dropout):
#         super().__init__()
#         self.embed = nn.Embedding(vocab_size, embed_dim)
#
#         self.tower = Tower(embed_dim, seq_len, n_layers, use_pos_embeddings=True,
#                            dropout=dropout,
#                            n_heads=4, n_class=1, mlp_multiplier=2,
#                            is_causal=True, global_pool=False,
#                            block_class=EncoderBlock, mlp_class=MLP)
#
#         self.out_proj = nn.Sequential(nn.LayerNorm(embed_dim),
#                                       nn.Linear(embed_dim, vocab_size))
#
#     def forward(self, x):
#         x = self.tower(self.embed(x))
#         return self.out_proj(x)
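
# Note: this script assumes `vocab_size`, an `encode` function, and the array of
# training `token_ids` are built elsewhere, following the nanoGPT shakespeare_char
# setup. A minimal character-level version might look roughly like this
# (file path and helper names are placeholders, not the repo's actual prep code):
#
#     with open("input.txt") as f:
#         text = f.read()
#     chars = sorted(set(text))
#     vocab_size = len(chars)
#     stoi = {ch: i for i, ch in enumerate(chars)}
#     itos = {i: ch for ch, i in stoi.items()}
#     encode = lambda s: [stoi[c] for c in s]
#     decode = lambda ids: "".join(itos[i] for i in ids)
#     token_ids = np.array(encode(text))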


class SequenceGenerator(IterableDataset):
    def __init__(self, token_ids, seq_length, batch_size):
        self.token_ids = torch.tensor(token_ids)
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.n_tokens = len(token_ids)
        self.indices = torch.arange(0, self.n_tokens - seq_length)

    def __iter__(self):
        self.indices = self.indices[torch.randperm(len(self.indices))]
        for i in range(0, len(self.indices), self.batch_size):
            batch_indices = self.indices[i:i + self.batch_size]
            X_batch = self.token_ids[batch_indices[:, None] + torch.arange(self.seq_length)]
            y_batch = self.token_ids[batch_indices[:, None] + torch.arange(1, self.seq_length + 1)]
            yield X_batch, y_batch


class TrainerClass(L.LightningModule):
    def __init__(self, vocab_size, embed_dim, seq_length, n_heads, attention_layers, dropout, mlp_multiplier, lr, epsilon, max_steps):
        super().__init__()

        # self.model = GPTmodel(embed_dim, seq_len=seq_length, n_layers=attention_layers, dropout=dropout)
        self.model = MambaGPT(vocab_size, embed_dim, seq_len=seq_length, n_layers=attention_layers, dropout=dropout)

        self.max_steps = max_steps
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=epsilon, weight_decay=1e-5)

        self.batch_val_losses = []

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.forward(x)
        loss = self.loss_fn(y_pred.view(-1, y_pred.size(-1)), y.view(-1))

        # save_every_n_iterations is expected to be defined in the surrounding training config
        if self.global_step % save_every_n_iterations == 0 and self.global_step > 0:
            print('saving_model')
            checkpoint_path = f"model_checkpoint_{self.global_step}.pth"
            torch.save(self.state_dict(), checkpoint_path)
            wandb.save(checkpoint_path)

        wandb.log({"train_loss": loss}, step=self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.forward(x)
        loss = self.loss_fn(y_pred.view(-1, y_pred.size(-1)), y.view(-1))
        self.batch_val_losses.append(loss.item())
        return loss

    def on_validation_epoch_end(self):

        val_loss = np.array(self.batch_val_losses).mean()
        self.batch_val_losses = []

        wandb.log({"val_loss": val_loss}, step=self.global_step)
        wandb.log({"learning_rate": self.optimizer.param_groups[0]['lr']}, step=self.global_step)

        # encode and generate_text are helpers defined elsewhere in the training setup
        query = """A fox was in the forest and"""
        example = encode(query)
        gen = generate_text(example, self.model, nchar=196, k=5, one_char_at_a_time=False, end_on_zero=False)
        # text_table.add_data(self.global_step, gen)

    def configure_optimizers(self):

        scheduler = {
            'scheduler': CosineAnnealingLR(self.optimizer, T_max=self.max_steps, eta_min=self.optimizer.param_groups[0]['lr'] / 10),
            'interval': 'step',
            'frequency': 1,
            'strict': True,
        }
        return {'optimizer': self.optimizer, 'lr_scheduler': scheduler}
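

# A rough sketch of how the pieces above could be wired together; none of this is
# defined in the original script, and the names/hyperparameters below are assumptions:
#
#     seq_length, batch_size, max_steps = 256, 64, 20000
#     train_ds = SequenceGenerator(train_token_ids, seq_length, batch_size)
#     val_ds = SequenceGenerator(val_token_ids, seq_length, batch_size)
#     # batch_size=None because SequenceGenerator already yields full batches
#     train_loader = DataLoader(train_ds, batch_size=None)
#     val_loader = DataLoader(val_ds, batch_size=None)
#
#     module = TrainerClass(vocab_size, embed_dim=256, seq_length=seq_length,
#                           n_heads=4, attention_layers=6, dropout=0.1,
#                           mlp_multiplier=2, lr=3e-4, epsilon=1e-8,
#                           max_steps=max_steps)
#     trainer = L.Trainer(max_steps=max_steps, accelerator="gpu", devices=1)
#     trainer.fit(module, train_loader, val_loader)
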
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Mamba Small Benchmarks:

Exploring the [Mamba codebase](https://github.com/state-spaces/mamba) on small example datasets (CIFAR-10, Shakespeare character-level, etc.).


See the paper below:

Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Albert Gu*, Tri Dao*
Paper: https://arxiv.org/abs/2312.00752

**Note**: I am not by any means an expert at any of this. Currently this is just a first pass at getting something up and running. There most likely are ways to improve both the architecture and the speed of the mamba code.

#### TLDR of first impressions:

- **CIFAR-10 Classification**: The Mamba-based model slightly outperforms the Transformer ViT-like model (85% vs. 84% accuracy) on CIFAR-10 for models with a similar number of parameters. However, the Mamba model is about 2x slower to train, despite learning faster in terms of iterations.

- **Shakespeare Character-Level Model**: Mamba shows quicker convergence and a slightly better validation loss (1.463 vs. 1.4697 for the nanoGPT example). However, it's more prone to overfitting, particularly in configurations without dropout.

## Stacking Mamba Layers:
The Mamba architecture is a sequence-to-sequence model based on a state space model architecture. Based on my basic understanding of the original paper and the GitHub repository, the code below is a reasonable (although likely not optimal) way to utilize the Mamba architecture. The concept is simple: stack several Mamba layers with normalization and, optionally, dropout. There's no need to add positional encoding or masking.


It's also worth noting that one can incorporate the Mamba layer into other architectures, for example, replacing self-attention or the FFN in a transformer with Mamba (see "Mamba Architecture: Interleaving Blocks" on [page 31](https://arxiv.org/pdf/2312.00752.pdf) of the paper); a rough sketch of this idea follows the code block below.



```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba

class MambaBlock(nn.Module):
    def __init__(self, embed_dim, dropout_level=0):
        super().__init__()

        self.mamba = Mamba(d_model=embed_dim, d_state=16, d_conv=4, expand=2)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout_level)

    def forward(self, x):
        x = self.norm(self.mamba(x) + x)
        return self.dropout(x)


class MambaTower(nn.Module):
    def __init__(self, embed_dim, n_layers, seq_len=None, global_pool=False):
        super().__init__()
        self.blocks = nn.Sequential(*[MambaBlock(embed_dim) for _ in range(n_layers)])
        self.global_pool = global_pool  # for classification or other supervised learning.

    def forward(self, x):
        # for input (bs, n, d) returns (bs, n, d), or (bs, d) if global_pool is True
        out = self.blocks(x) if not self.global_pool else torch.mean(self.blocks(x), 1)
        return out
```
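
As a rough illustration of the interleaving idea mentioned above, here is what a transformer-style block with Mamba in place of the self-attention sublayer might look like. This is just a sketch (not code from the paper or from this repo), and the pre-norm layout and layer sizes are arbitrary choices:

```python
import torch.nn as nn
from mamba_ssm import Mamba

class HybridBlock(nn.Module):
    """Transformer-style block with the self-attention sublayer swapped for Mamba."""
    def __init__(self, embed_dim, mlp_multiplier=2, dropout_level=0):
        super().__init__()
        self.mamba = Mamba(d_model=embed_dim, d_state=16, d_conv=4, expand=2)
        self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_multiplier * embed_dim),
                                 nn.GELU(),
                                 nn.Linear(mlp_multiplier * embed_dim, embed_dim),
                                 nn.Dropout(dropout_level))
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.mamba(self.norm1(x))  # Mamba where self-attention would normally go
        x = x + self.mlp(self.norm2(x))    # standard FFN sublayer
        return x
```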

## CIFAR-10 Classification:

We'll use the MambaTower class above as the backbone of a vision model on a patchified version of the CIFAR-10 images.

### Setup:

We compare the model above with a Transformer ViT-like model based on the same patches.
Both models have the following config:

- embed_dim = 256
- 6 layers
- the Transformer model has an FFN dim of 2*embed_dim (512) to maintain a similar number of parameters between the two models.
- patch size of 4 (so 64 patches of dimension 48) and various basic augmentation techniques (see the code).

Here's the code for the setup - it's fairly straightforward (to get a ViT-like model I replace the MambaTower with a stack of Transformer encoders):

```python
class ImgClassifier(nn.Module):
    def __init__(self, patch_size=4, img_size=32, n_channels=3, embed_dim=256, n_layers=6, dropout=0):
        super().__init__()

        seq_len = int((img_size/patch_size)*(img_size/patch_size))
        patch_dim = n_channels*patch_size*patch_size

        self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)',
                                   p1=patch_size, p2=patch_size)

        self.func = nn.Sequential(self.rearrange,
                                  nn.LayerNorm(patch_dim),
                                  nn.Linear(patch_dim, embed_dim),
                                  nn.LayerNorm(embed_dim),
                                  MambaTower(embed_dim, n_layers, seq_len=seq_len, global_pool=True),
                                  nn.Linear(embed_dim, 10))

    def forward(self, x):
        return self.func(x)
```


### Results:
The two models perform comparably, with the Mamba-based model having a slight edge (85% vs. 84% accuracy on the CIFAR-10 test set). While the Mamba model learns "faster" in terms of iterations, it's about twice as slow to train (note that I am using the simple Mamba class - their LLM [example](https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py) looks more optimized but is harder to read).

Either way, 85% accuracy on CIFAR-10 straight out of the box with no convolutions is not bad at all - so I was pretty impressed.

image

https://api.wandb.ai/links/apapiu/00tsl03a



## Shakespeare Char-Level Language Model:

The paper has quite a few examples showing that Mamba matches or beats the best transformer recipes out there. Still, I wanted to try it out on a small dataset, so I decided to use the Shakespeare dataset. I use the split and data setup found in the [nanoGPT](https://github.com/karpathy/nanoGPT/tree/master/data/shakespeare_char) repo.

Model setup: the embed dimension is 256 with a 256-token context window, and the transformer has an FFN dim of 512. Both models have roughly 2 million parameters. The code is again very simple:

```python
class MambaGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim, seq_len, n_layers, dropout):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.tower = MambaTower(embed_dim, n_layers, seq_len=seq_len, global_pool=False)
        self.out_proj = nn.Sequential(nn.LayerNorm(embed_dim),
                                      nn.Linear(embed_dim, vocab_size))

    def forward(self, x):
        x = self.tower(self.embed(x))
        return self.out_proj(x)
```
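
The `generate_text` helper used in `shakespeare.py` isn't shown here, but a minimal top-k character sampler for a model like this could look something like the sketch below (the function name and defaults are made up for illustration):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_chars(model, prompt_ids, n_new_tokens=200, k=5):
    # prompt_ids: 1D LongTensor of character ids for the prompt
    ids = prompt_ids[None, :].to(next(model.parameters()).device)
    for _ in range(n_new_tokens):
        logits = model(ids)[:, -1, :]              # logits for the next character
        topk = torch.topk(logits, k)
        probs = F.softmax(topk.values, dim=-1)
        next_id = topk.indices.gather(-1, torch.multinomial(probs, 1))
        ids = torch.cat([ids, next_id], dim=1)
    return ids[0]                                  # pass through decode() to get text
```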

Results: The Mamba model does seem to converge faster (although it's also more prone to severe overfitting; see below). Mamba got a val loss of 1.463, lower than the example in [nanoGPT](https://github.com/karpathy/nanoGPT/tree/master#:~:text=validation%20loss%20is-,1.4697,-.%20Based%20on%20the), which gets 1.4697.

image

### Overfitting:
It looks like the Mamba model is more likely to overfit and completely memorize the training data - especially without dropout. See below for a model with embed_dim = 512 and no dropout. I will need to explore this more; it is also likely not an issue when training on larger datasets.

image

### Future ideas:

- Explore scaling in terms of epoch time vs. sequence length for Mamba vs. transformer
- Use it for autoregressive pixel generation
- Use it in a diffusion-like model.


--------------------------------------------------------------------------------