├── mamba_blocks.py
├── cifar_10.py
├── shakespeare.py
└── README.md
/mamba_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from mamba_ssm import Mamba
5 | 
6 | 
7 | class MambaBlock(nn.Module):
8 |     def __init__(self, embed_dim, dropout_level=0):
9 |         super().__init__()
10 | 
11 |         self.mamba = Mamba(d_model=embed_dim, d_state=16, d_conv=4, expand=2)
12 |         self.norm = nn.LayerNorm(embed_dim)
13 |         self.dropout = nn.Dropout(dropout_level)
14 | 
15 |     def forward(self, x):
16 |         x = self.norm(self.mamba(x) + x)
17 |         return self.dropout(x)
18 | 
19 | 
20 | class MambaTower(nn.Module):
21 |     def __init__(self, embed_dim, n_layers, seq_len=None, global_pool=False, dropout=0):
22 |         super().__init__()
23 |         self.blocks = nn.Sequential(*[MambaBlock(embed_dim, dropout_level=dropout) for _ in range(n_layers)])
24 |         self.global_pool = global_pool  # mean-pool over the sequence for classification or other supervised learning
25 | 
26 |     def forward(self, x):
27 |         # for input (bs, n, d) returns (bs, n, d), or (bs, d) if global_pool
28 |         out = self.blocks(x) if not self.global_pool else torch.mean(self.blocks(x), 1)
29 |         return out
30 | 
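31 | 
32 | # Example usage (illustrative addition, not part of the original file): a quick shape
33 | # check for the comment above. Note that mamba_ssm's selective-scan kernels require CUDA.
34 | if __name__ == "__main__":
35 |     x = torch.randn(2, 64, 256, device="cuda")               # (bs, n, d)
36 |     tower = MambaTower(embed_dim=256, n_layers=2).to("cuda")
37 |     print(tower(x).shape)                                    # torch.Size([2, 64, 256])
38 |     pooled = MambaTower(256, 2, global_pool=True).to("cuda")
39 |     print(pooled(x).shape)                                   # torch.Size([2, 256])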
--------------------------------------------------------------------------------
/cifar_10.py:
--------------------------------------------------------------------------------
1 |
2 | #!pip install torchview torchmetrics einops wandb causal-conv1d==1.0.2 mamba-ssm
3 |
4 | # !git clone https://github.com/apapiu/generic_transformer.git
5 | # import sys
6 | # sys.path.append('generic_transformer/')
7 | from trainer import Trainer
8 | #from transformer_blocks import EncoderBlock, Tower, MLPSepConv, MLP
9 |
10 |
11 | from mamba_ssm import Mamba
12 | from tqdm import tqdm
13 | import os
14 | import torch
15 | import torch.nn as nn
16 | import wandb
17 | from einops import rearrange
18 | from einops.layers.torch import Rearrange
19 | import torchmetrics
20 | import torchvision
21 | import torchvision.transforms as transforms
22 | from mamba_blocks import MambaTower, MambaBlock
23 |
24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25 |
26 | class ImgClassifier(nn.Module):
27 | def __init__(self, patch_size=4, img_size=32, n_channels=3, embed_dim=256, n_layers=6, dropout=0):
28 | super().__init__()
29 |
30 | self.patch_size = patch_size
31 | self.img_size = img_size
32 |         self.n_channels = n_channels
33 | self.embed_dim = embed_dim
34 | self.dropout = dropout
35 | self.n_layers = n_layers
36 |         seq_len = (self.img_size // self.patch_size) ** 2
37 | patch_dim = self.n_channels*self.patch_size*self.patch_size
38 |
39 | self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)',
40 | p1=self.patch_size, p2=self.patch_size)
41 |
42 | self.func = nn.Sequential(self.rearrange,
43 | nn.LayerNorm(patch_dim),
44 | nn.Linear(patch_dim, embed_dim),
45 | nn.LayerNorm(embed_dim),
46 | MambaTower(embed_dim, n_layers, seq_len=seq_len, global_pool=True, dropout=dropout),
47 | nn.Linear(embed_dim, 10))
48 |
49 | def forward(self, x):
50 |
51 | return self.func(x)
52 |
53 |
54 | transform = transforms.Compose([
55 | transforms.ToTensor()
56 | ])
57 |
58 | train_aug_transform = transforms.Compose([
59 | transforms.RandomHorizontalFlip(),
60 | transforms.RandomCrop(32, padding=4),
61 | transforms.RandomRotation(15),
62 | transforms.ToTensor(),
63 | transforms.RandomErasing(p=0.2, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
64 | transforms.RandomPerspective(distortion_scale=0.5, p=0.5)
65 | ])
66 |
67 |
68 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
69 | download=True, transform=train_aug_transform)
70 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
71 | shuffle=True, num_workers=4, pin_memory=True)
72 |
73 | testset = torchvision.datasets.CIFAR10(root='./data', train=False,
74 | download=True, transform=transform)
75 | test_loader = torch.utils.data.DataLoader(testset, batch_size=256,
76 | shuffle=False, num_workers=4, pin_memory=True)
77 |
78 | #os.environ["WANDB_API_KEY"]='your_wandb_key'
79 | wandb.login()  # "!wandb login" only works in a notebook; alternatively set WANDB_API_KEY above
80 |
81 | block = 'mamba'
82 | n_layers = 6
83 | patch_size = 4
84 | img_size = 32
85 | embed_dim = 256
86 | dropout = 0.1
87 | n_channels = 3
88 | 
89 |
90 | weight_decay = 1e-5
91 | lr = 0.0003
92 | T_max = 30000
93 | n_epochs = 100
94 |
95 | config = {k: v for k, v in locals().items() if k in ['block', 'n_layers', 'patch_size', 'img_size', 'embed_dim',
96 |                                                       'dropout', 'n_layers',
97 | 'lr', 'T_max', 'weight_decay']}
98 |
99 | model = ImgClassifier(patch_size, img_size, n_channels, embed_dim, n_layers, dropout)
100 | model = model.to(device)
101 | loss = nn.CrossEntropyLoss()
102 | val_metric = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
103 | trainer = Trainer(model, loss, lr=lr, T_max=T_max, weight_decay=weight_decay, wandb_log=True)
104 |
105 | wandb.init(
106 | project="cifar10_classification",
107 | config = config)
108 |
109 | #note this function does not work well with the mamba block:
110 | trainer.plot_architecture(train_loader, depth=6)
111 |
112 | wandb.save('model_graph_new.png')
113 | print(f'Num params {sum(p.numel() for p in model.parameters())}')
114 | trainer.train_loop(train_loader, test_loader, n_epochs=n_epochs, val_metric=val_metric)
115 | wandb.finish()
116 |
--------------------------------------------------------------------------------
/shakespeare.py:
--------------------------------------------------------------------------------
1 | from mamba_blocks import MambaTower, MambaBlock
2 | import torch.nn as nn
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader, IterableDataset
7 | from itertools import cycle
8 | import numpy as np
9 | from tqdm import tqdm
10 | import lightning as L
11 | from lightning.pytorch.loggers import WandbLogger
12 | from torch.optim.lr_scheduler import CosineAnnealingLR
13 |
14 | from lightning.pytorch.callbacks import ModelCheckpoint
15 | import wandb
16 |
17 | class MambaGPT(nn.Module):
18 | def __init__(self, embed_dim, seq_len, n_layers, dropout):
19 | super().__init__()
20 |
21 | self.embed = nn.Embedding(vocab_size, embed_dim)
22 |         self.tower = MambaTower(embed_dim, n_layers, seq_len=seq_len, global_pool=False, dropout=dropout)
23 | self.out_proj = nn.Sequential(nn.LayerNorm(embed_dim),
24 | nn.Linear(embed_dim, vocab_size))
25 |
26 | def forward(self, x):
27 | x = self.tower(self.embed(x))
28 | return self.out_proj(x)
29 |
30 | # class GPTmodel(nn.Module):
31 | # def __init__(self, embed_dim, seq_len, n_layers, dropout):
32 | # super().__init__()
33 | # self.embed = nn.Embedding(vocab_size, embed_dim)
34 |
35 | # self.tower = Tower(embed_dim, seq_len, n_layers, use_pos_embeddings=True,
36 | # dropout=dropout,
37 | # n_heads=4, n_class=1, mlp_multiplier=2,
38 | # is_causal=True, global_pool=False,
39 | # block_class=EncoderBlock, mlp_class=MLP)
40 |
41 | # self.out_proj = nn.Sequential(nn.LayerNorm(embed_dim),
42 | # nn.Linear(embed_dim, vocab_size))
43 |
44 | # def forward(self, x):
45 | # x = self.tower(self.embed(x))
46 |
47 | # return self.out_proj(x)
48 |
49 |
50 | class SequenceGenerator(IterableDataset):
51 | def __init__(self, token_ids, seq_length, batch_size):
52 | self.token_ids = torch.tensor(token_ids)
53 | self.seq_length = seq_length
54 | self.batch_size = batch_size
55 | self.n_tokens = len(token_ids)
56 | self.indices = torch.arange(0, self.n_tokens - seq_length)
57 |
58 | def __iter__(self):
59 | self.indices = self.indices[torch.randperm(len(self.indices))]
60 | for i in range(0, len(self.indices), self.batch_size):
61 | batch_indices = self.indices[i:i+self.batch_size]
62 | X_batch = self.token_ids[batch_indices[:, None] + torch.arange(self.seq_length)]
63 | y_batch = self.token_ids[batch_indices[:, None] + torch.arange(1, self.seq_length + 1)]
64 | yield X_batch, y_batch
65 |
66 | class TrainerClass(L.LightningModule):
67 | def __init__(self, vocab_size, embed_dim, seq_length, n_heads, attention_layers, dropout, mlp_multiplier, lr, epsilon, max_steps):
68 |         super().__init__()
69 |
70 | #self.model = GPTmodel(embed_dim, seq_len=seq_length, n_layers=attention_layers, dropout=dropout)
71 | self.model = MambaGPT(embed_dim, seq_len=seq_length, n_layers=attention_layers, dropout=dropout)
72 |
73 | self.max_steps = max_steps
74 | self.loss_fn = nn.CrossEntropyLoss()
75 | self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=epsilon, weight_decay=1e-5)
76 |
77 | self.batch_val_losses = []
78 |
79 | def forward(self, x):
80 | return self.model(x)
81 |
82 | def training_step(self, batch, batch_idx):
83 | x, y = batch
84 | y_pred = self.forward(x)
85 | loss = self.loss_fn(y_pred.view(-1, y_pred.size(-1)), y.view(-1))
86 |
87 | if self.global_step % save_every_n_iterations == 0 and self.global_step>0:
88 | print('saving_model')
89 | checkpoint_path = f"model_checkpoint_{self.global_step}.pth"
90 | torch.save(self.state_dict(), checkpoint_path)
91 | wandb.save(checkpoint_path)
92 |
93 | wandb.log({"train_loss": loss}, step=self.global_step)
94 | return loss
95 |
96 | def validation_step(self, batch, batch_idx):
97 | x, y = batch
98 | y_pred = self.forward(x)
99 | loss = self.loss_fn(y_pred.view(-1, y_pred.size(-1)), y.view(-1))
100 | self.batch_val_losses.append(loss.item())
101 | return loss
102 |
103 | def on_validation_epoch_end(self):
104 |
105 | val_loss = np.array(self.batch_val_losses).mean()
106 | self.batch_val_losses = []
107 |
108 | wandb.log({"val_loss": val_loss}, step=self.global_step)
109 | wandb.log({"learning_rate": self.optimizer.param_groups[0]['lr']}, step=self.global_step)
110 |
111 | query = """A fox was in the forest and"""
112 | example = encode(query)
113 |         gen = generate_text(example, self.model, nchar=196, k=5, one_char_at_a_time=False, end_on_zero=False)
114 | #text_table.add_data(self.global_step, gen)
115 |
116 | def configure_optimizers(self):
117 |
118 | scheduler = {
119 |             'scheduler': CosineAnnealingLR(self.optimizer, T_max=self.max_steps, eta_min=self.optimizer.param_groups[0]['lr']/10),
120 | 'interval': 'step',
121 | 'frequency': 1,
122 | 'strict': True,
123 | }
124 | return {'optimizer': self.optimizer, 'lr_scheduler': scheduler}
125 |
126 |
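127 | 
128 | # --- Illustrative addition, not part of the original training script ---
129 | # `generate_text`, `encode`/`decode`, `vocab_size`, and `save_every_n_iterations` are
130 | # defined elsewhere in the original setup (notebook / data prep). A minimal top-k
131 | # sampling sketch that fits the call in on_validation_epoch_end might look like the
132 | # function below; it is an assumption, not the author's implementation, and it ignores
133 | # the `one_char_at_a_time` flag.
134 | @torch.no_grad()
135 | def generate_text(prompt_ids, model, nchar=196, k=5, one_char_at_a_time=False, end_on_zero=False):
136 |     model.eval()
137 |     ids = torch.tensor(prompt_ids, device=next(model.parameters()).device)[None, :]
138 |     for _ in range(nchar):
139 |         logits = model(ids)[:, -1, :]                        # logits for the next token
140 |         topk_vals, topk_idx = torch.topk(logits, k)
141 |         next_id = topk_idx.gather(-1, torch.multinomial(F.softmax(topk_vals, dim=-1), 1))
142 |         if end_on_zero and next_id.item() == 0:
143 |             break
144 |         ids = torch.cat([ids, next_id], dim=1)
145 |     return decode(ids[0].tolist())                           # decode: ids -> string (assumed helper)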
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mamba Small Benchmarks:
2 |
3 | Exploring the [Mamba codebase](https://github.com/state-spaces/mamba) on small example datasets (CIFAR-10, Shakespeare character-level, etc.).
4 |
5 |
6 | See the paper below:
7 |
8 | Mamba: Linear-Time Sequence Modeling with Selective State Spaces
9 | Albert Gu*, Tri Dao*
10 | Paper: https://arxiv.org/abs/2312.00752
11 |
12 | **Note**: I am not by any means an expert at any of this. Currently this is just a first pass at getting something up and running. There are most likely ways to improve both the architecture and the speed of the Mamba code.
13 |
14 | #### TLDR of first impressions:
15 |
16 | - **CIFAR-10 Classification**: The Mamba-based model slightly outperforms the Transformer ViT-like model (85% vs. 84% accuracy) on CIFAR-10 for models with similar # of params. However, the Mamba model is about 2x slower to train despite faster learning in terms of iterations.
17 |
18 | - **Shakespeare Character-Level Model**: Mamba shows quicker convergence and a slightly better validation loss (1.463 vs. 1.4697 for the nanoGPT character-level example). However, it's more prone to overfitting, particularly in configurations without dropout.
19 |
20 | ## Stacking Mamba Layers:
21 | The Mamba architecture is a sequence-to-sequence model based on a state space model architecture. Based on my basic understanding of the original paper and the GitHub repository, the code below is a reasonable (although likely not optimal) way to utilize the Mamba architecture. The concept is simple: stack several Mamba layers with normalization and optionally dropout. There's no need to add positional encoding or masking.
22 |
23 |
24 | It's also worth noting that one can incorporate the Mamba layer into other architectures, for example, replacing self-attention or the FFN in a transformer with Mamba (see "Mamba Architecture: Interleaving Blocks" on [page 31](https://arxiv.org/pdf/2312.00752.pdf)). A rough sketch of this idea follows the code block below.
25 |
26 |
27 |
28 | ```python
29 | import torch.nn as nn
30 | from mamba_ssm import Mamba
31 |
32 | class MambaBlock(nn.Module):
33 | def __init__(self, embed_dim, dropout_level=0):
34 | super().__init__()
35 |
36 | self.mamba = Mamba(d_model=embed_dim, d_state=16, d_conv=4, expand=2)
37 | self.norm = nn.LayerNorm(embed_dim)
38 | self.dropout = nn.Dropout(dropout_level)
39 |
40 | def forward(self, x):
41 | x = self.norm(self.mamba(x) + x)
42 | return self.dropout(x)
43 |
44 |
45 | class MambaTower(nn.Module):
46 | def __init__(self, embed_dim, n_layers, seq_len=None, global_pool=False):
47 | super().__init__()
48 | self.blocks = nn.Sequential(*[MambaBlock(embed_dim) for _ in range(n_layers)])
49 | self.global_pool = global_pool #for classification or other supervised learning.
50 |
51 | def forward(self, x):
52 |         # for input (bs, n, d) it returns either (bs, n, d) or (bs, d) if global_pool
53 |         out = self.blocks(x) if not self.global_pool else self.blocks(x).mean(1)
54 | return out
55 | ```
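
To make the interleaving idea above concrete, here is a rough sketch (my own, not from the paper or this repo) of a pre-norm transformer-style block where Mamba takes the place of the self-attention sub-layer and a small MLP plays the FFN role:

```python
import torch.nn as nn
from mamba_ssm import Mamba

class MambaFFNBlock(nn.Module):
    """Transformer-style block with Mamba replacing self-attention (illustrative sketch)."""
    def __init__(self, embed_dim, mlp_multiplier=2, dropout_level=0):
        super().__init__()
        self.mamba = Mamba(d_model=embed_dim, d_state=16, d_conv=4, expand=2)
        self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_multiplier * embed_dim),
                                 nn.GELU(),
                                 nn.Linear(mlp_multiplier * embed_dim, embed_dim),
                                 nn.Dropout(dropout_level))
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.mamba(self.norm1(x))   # Mamba in place of the attention sub-layer
        x = x + self.mlp(self.norm2(x))     # standard FFN sub-layer
        return x
```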
56 |
57 | ## CIFAR-10 Classification:
58 | 
59 | We'll use the MambaTower class above as the backbone of a vision model on patchified CIFAR-10 images.
60 |
61 | ### Setup:
62 |
63 | We compare the model above with a Transformer ViT-like model based on the same patches.
64 | Both models have the following config:
65 |
66 | - embed_dim = 256
67 | - 6 layers
68 | - the Transformer model has an FFN dim of 2*embed_dim (512) to maintain similar # of parameters between the two models.
69 | - patch size of 4 (so 64 patches of dimension 48) and various basic augmentation techniques (shown in the snippet below).
70 |
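The training-time augmentation pipeline referenced in the last bullet, taken from `cifar_10.py` in this repo:

```python
import torchvision.transforms as transforms

train_aug_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    transforms.RandomPerspective(distortion_scale=0.5, p=0.5)
])
```
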
71 | Here's the code for the setup - it's fairly straightforward (to get a ViT-like model, I replace the MambaTower with a stack of Transformer encoder blocks):
72 |
73 | ```python
74 | class ImgClassifier(nn.Module):
75 | def __init__(self, patch_size=4, img_size=32, n_channels=3, embed_dim=256, n_layers=6, dropout=0):
76 | super().__init__()
77 |
78 |         seq_len = (img_size // patch_size) ** 2
79 |         patch_dim = n_channels * patch_size * patch_size
80 | 
81 |         self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)',
82 |                                    p1=patch_size, p2=patch_size)
83 |
84 | self.func = nn.Sequential(self.rearrange,
85 | nn.LayerNorm(patch_dim),
86 | nn.Linear(patch_dim, embed_dim),
87 | nn.LayerNorm(embed_dim),
88 | MambaTower(embed_dim, n_layers, seq_len=seq_len, global_pool=True),
89 | nn.Linear(embed_dim, 10))
90 |
91 | def forward(self, x):
92 | return self.func(x)
93 | ```
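
For the Transformer baseline, the MambaTower above is swapped for a stack of transformer encoder blocks plus learned positional embeddings (the repo's own EncoderBlock/Tower classes, commented out in `cifar_10.py`, are what was actually benchmarked). As a rough, hypothetical stand-in built only from PyTorch's built-in layers, the ViT-like backbone could look something like this:

```python
import torch
import torch.nn as nn

class ViTTower(nn.Module):
    """Hypothetical stand-in for MambaTower: positional embeddings + encoder stack + mean pool."""
    def __init__(self, embed_dim=256, n_layers=6, seq_len=64, n_heads=4, dropout=0.1):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=n_heads,
                                           dim_feedforward=2 * embed_dim, dropout=dropout,
                                           batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x):                   # x: (bs, n, d)
        x = self.encoder(x + self.pos_embed)
        return x.mean(dim=1)                # (bs, d), matching global_pool=True
```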
94 |
95 |
96 | ### Results:
97 | The two models perform comparably, with the Mamba-based model having a slight edge (85% accuracy vs. 84% accuracy on the CIFAR-10 test set). While the Mamba model learns "faster" in terms of iterations, it's about twice as slow to train (note that I am using the simple Mamba class - their LLM [example](https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py) looks more optimized but harder to read).
98 |
99 | Either way, 85% accuracy on CIFAR-10 straight out of the box with no convolutions is not bad at all - I was pretty impressed.
100 |
101 |
102 |
103 | https://api.wandb.ai/links/apapiu/00tsl03a
104 |
105 |
106 |
107 | ## Shakespeare Character-Level Language Model:
108 | 
109 | The paper has quite a few examples showing that Mamba matches or beats the best transformer recipes out there. Still, I wanted to try it on a small dataset, so I ran it on the Shakespeare character-level dataset. I use the split and data setup found in the [nano-gpt](https://github.com/karpathy/nanoGPT/tree/master/data/shakespeare_char) repo (a minimal sketch of that setup is included after the model code below).
110 |
111 | Model setup: the embedding dimension is 256 with a 256-token context window, and the transformer has an FFN dim of 512. Both models have roughly 2 million parameters. The code is again very simple:
112 |
113 | ```python
114 | class MambaGPT(nn.Module):
115 | def __init__(self, embed_dim, seq_len, n_layers, dropout):
116 | super().__init__()
117 |
118 | self.embed = nn.Embedding(vocab_size, embed_dim)
119 | self.tower = MambaTower(embed_dim, n_layers, seq_len=seq_len, global_pool=False)
120 | self.out_proj = nn.Sequential(nn.LayerNorm(embed_dim),
121 | nn.Linear(embed_dim, vocab_size))
122 |
123 | def forward(self, x):
124 | x = self.tower(self.embed(x))
125 | return self.out_proj(x)
126 | ```
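
The data pipeline isn't shown above. A minimal sketch of the nanoGPT-style character-level setup it assumes: build a character vocabulary from `input.txt`, encode the text to ids, and do a 90/10 train/val split. The exact paths and helper names here are assumptions, not the repo's code:

```python
text = open("input.txt").read()              # tiny-shakespeare text file from the nanoGPT repo

chars = sorted(set(text))
vocab_size = len(chars)                      # 65 characters for tiny-shakespeare
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

token_ids = encode(text)
n = int(0.9 * len(token_ids))                # nanoGPT uses a 90/10 train/val split
train_ids, val_ids = token_ids[:n], token_ids[n:]
```

These ids are then fed to the `SequenceGenerator` iterable dataset in `shakespeare.py`, which yields shuffled `(X, y)` batches of length-`seq_len` sequences shifted by one character.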
127 |
128 | Results: The Mamba model does seem to converge faster (although it's also more prone to severe overfitting; see below). Mamba got a val loss of 1.463, lower than the example in [nano-gpt](https://github.com/karpathy/nanoGPT/tree/master#:~:text=validation%20loss%20is-,1.4697,-.%20Based%20on%20the), which gets 1.4697.
129 |
130 |
131 |
132 | ### Overfitting:
133 | It looks like the Mamba model is more likely to overfit and completely memorize the training data, especially without dropout. See below for a model with embed_dim = 512 and no dropout. I will need to explore this more; this is also likely not an issue when training on larger datasets.
134 |
135 |
136 |
137 | ### Future ideas:
138 |
139 | - Explore scaling in terms of epoch time vs. sequence length on mamba vs. transformer
140 | - Use it for autoregressive pixel generation
141 | - Use it in a diffusion-like model.
142 |
143 |
144 |
--------------------------------------------------------------------------------