├── .gitignore
├── LICENSE
├── README.md
├── bllama
    ├── __init__.py
    ├── bitlinear.py
    ├── bllama.py
    ├── config.py
    ├── quantization.py
    ├── transformer.py
    └── utils.py
├── notes.md
└── utils
    ├── callbacks
        └── weight_distribution_callback.py
    ├── datamodules
        └── text_data_module.py
    ├── datasets
        ├── gepeto.py
        └── shakespeare.py
    └── images
        └── bllama_quantization.png

/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | test.ipynb
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Rafael Celente
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## bLLaMa
2 | 
3 | bLLaMa is a [b1.58](https://arxiv.org/pdf/2402.17764v1.pdf) LLaMa model.
4 | 
5 | ### Set up
6 | 
7 | Both the module configuration dataclass and the module itself are contained in `bllama`. By default, the configuration is a 1.7B model, which can be found in `config.py`.
8 | 
9 | ```python
10 | from bllama import bLlamaConfig, bLlama
11 | 
12 | config = bLlamaConfig()
13 | bllm = bLlama(config)
14 | ```
15 | 
16 | #### Training
17 | 
18 | bLLaMa is built as a Lightning module, so you may pass `pl.Trainer`s and `pl.LightningDataModule`s for training tasks. To facilitate this, some example datasets and their corresponding datamodules are given in `utils`.
19 | 
20 | ```python
21 | from transformers import LlamaTokenizer
22 | from utils import ShakespeareDataset, TextDataModule, GepetoDataset
23 | 
24 | tokenizer = LlamaTokenizer.from_pretrained("fxmarty/tiny-llama-fast-tokenizer")
25 | dataset = ShakespeareDataset.from_file("/path/to/shakespeare/input.txt", tokenizer=tokenizer, max_length=1024)
26 | dataloader = TextDataModule(dataset, batch_size=config.batch_size, train_test_split=0.9)
27 | ```
28 | 
29 | To set up training, you may use a `pl.Trainer` or manually set up a training run.
30 | 
31 | ```python
32 | import pytorch_lightning as pl
33 | 
34 | bllm_trainer = pl.Trainer(
35 |     accelerator="gpu",
36 |     max_epochs=1,
37 | )
38 | 
39 | bllm_trainer.fit(bllm, dataloader)
40 | ```
41 | 
42 | #### Inference
43 | 
44 | The BitLinear layers of bLLaMa have two modes: one for training (`fp32`) and one for quantized inference (`int8`). To perform quantized inference, the weights have to be quantized offline. bLLaMa has a built-in method to quantize the BitLinear modules for inference:
45 | 
46 | ![bLLaMa quantization](utils/images/bllama_quantization.png)
47 | 
48 | After quantization, the model can generate text with the `generate` method (an end-to-end sketch is given at the end of this README).
49 | 
50 | ```python
51 | bllm.generate(prompt="In a shocking turn of events,", tokenizer=tokenizer, max_len=200, do_sample=False, top_k=3, repetition_penalty=2)
52 | ```
53 | 
54 | Full-precision inference is also allowed, but the model will warn about any BitLinear layers that are not quantized.
55 | 
56 | ### TODOs
57 | 
58 | - [x] Inference with `int8` BitLinear quantization
59 | - [ ] Custom GEMM for lowbit inference
60 | - [ ] KV Cache
61 | - [ ] Model sharding / Parallel training optimizations
62 | - [ ] Custom model saving for quantized tensors
63 | - [ ] Full 1.7B model training
64 | 
65 | ### Resources and inspiration
66 | 
67 | - [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/pdf/2402.17764v1.pdf)
68 | - [The Era of 1-bit LLMs: Training tips and FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf)
69 | - [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/pdf/2302.13971v1.pdf)
70 | - [Official LLaMa implementation](https://github.com/meta-llama/llama/blob/main/llama/model.py#L80)
71 | - [joey00072's BitNet implementation](https://github.com/joey00072/ohara/blob/master/experiments/bitnet/bitnet.py)
72 | 
73 | ### Notes on training
74 | 
75 | This repo contains only the implementation and training code for bLLaMa. No (relevant) model checkpoints or weights have been produced yet, as that requires significantly more compute than I currently have at my disposal.
76 | 
77 | Nonetheless, small training runs using a 1.7B model were done to assess training performance. The training runs were done with ~15M tokens from [wikitext](https://huggingface.co/datasets/wikitext) and [minipile](https://huggingface.co/datasets/JeanKaddour/minipile).
78 | 
79 | On a single NVIDIA A6000, the batch size was limited to 1 by the VRAM bottleneck. This may indicate memory-usage issues and/or optimization opportunities, as the full `fp32` model alone uses ~41GB for training.
80 | 
81 | The total test training time was 5 hours. Based on this, we can extrapolate that, with the current configuration, reaching a Chinchilla-optimal 1.7B model would take ~472 hours of training on an A6000.
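
### End-to-end quantized inference (sketch)

The snippet below ties the setup, quantization, and generation steps together. It is a minimal sketch based on the methods defined in `bllama/bllama.py` (`quantize_weights_to_ternary` and `generate`); the prompt, tokenizer, and generation parameters are illustrative, and a CUDA device is assumed (the default in `generate`).

```python
from transformers import LlamaTokenizer
from bllama import bLlamaConfig, bLlama

config = bLlamaConfig()
bllm = bLlama(config)
tokenizer = LlamaTokenizer.from_pretrained("fxmarty/tiny-llama-fast-tokenizer")

# Offline-quantize every BitLinear layer to ternary int8 weights.
# With verbose=True this also prints the model size before and after quantization.
bllm.quantize_weights_to_ternary(verbose=True)

# Generation now runs the BitLinear layers in their quantized inference mode.
text = bllm.generate(
    prompt="In a shocking turn of events,",
    tokenizer=tokenizer,
    max_len=200,
    do_sample=False,
    top_k=3,
    repetition_penalty=2,
)
print(text)
```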
-------------------------------------------------------------------------------- /bllama/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import bLlamaConfig 2 | from .bllama import bLlama 3 | from .transformer import Transformer -------------------------------------------------------------------------------- /bllama/bitlinear.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from .utils import RMSNorm 4 | from .quantization import weight_quant, activation_quant, activation_post_quant 5 | from typing import Optional 6 | 7 | class BitLinear(nn.Linear): 8 | def __init__(self, in_features, out_features, bias=False): 9 | super(BitLinear, self).__init__(in_features, out_features, bias) 10 | self.in_features = in_features 11 | self.out_features = out_features 12 | self.rms_norm = RMSNorm(in_features) 13 | self.weight_scale = None 14 | nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu') 15 | 16 | def forward(self, x, inference: Optional[bool]=False): 17 | w = self.weight 18 | if not inference: 19 | x_norm = self.rms_norm(x) 20 | x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach() 21 | w_quant = w + (weight_quant(w) - w).detach() 22 | return F.linear(x_quant, w_quant, self.bias) 23 | else: 24 | # in case of inference, the weights are offline quantized to int8, so we assume w = w_quant 25 | x_norm = self.rms_norm(x) 26 | x_quant, x_scale = activation_post_quant(x_norm) 27 | w_scale = self.weight_scale 28 | # according to the paper, this linear layer may have to be replaced by a gemm_lowbit_kernel, 29 | # but no such kernel is available, nor any directions on how to implement it, so we'll just use linear 30 | return F.linear(x_quant, w.float(), self.bias) / (x_scale * w_scale) -------------------------------------------------------------------------------- /bllama/bllama.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn.functional as F 4 | from .config import bLlamaConfig, trainerConfig 5 | from .transformer import Transformer 6 | from .quantization import quantize_weights_to_int8 7 | from typing import Tuple, Optional 8 | from .bitlinear import BitLinear 9 | from transformers import get_cosine_schedule_with_warmup 10 | from torchmetrics import Metric 11 | from torchmetrics.aggregation import SumMetric 12 | 13 | class bLlama(pl.LightningModule): 14 | def __init__( 15 | self, 16 | config: bLlamaConfig, 17 | trainer_config: Optional[trainerConfig] = trainerConfig(), 18 | metric: Optional[Metric] = None, 19 | log_tokens_progression: Optional[bool] = False, 20 | ): 21 | super().__init__() 22 | self.config = config 23 | self.trainer_config = trainer_config 24 | self.model = Transformer(config) 25 | self.metric = metric 26 | self.metric_name = self.metric.__class__.__name__ if self.metric is not None else None 27 | self.token_aggregator = None 28 | if log_tokens_progression: 29 | self.log_tokens_progression = log_tokens_progression 30 | self.token_aggregator = SumMetric() 31 | 32 | 33 | def quantize_weights_to_ternary(self, verbose=True): 34 | """ 35 | Quantize all BitLinear layers to ternary in int8. 
36 | """ 37 | if verbose: 38 | size_before_quant = self.get_model_size_in_bytes() 39 | print(f'Mode size before quantization: {size_before_quant} MB') 40 | for name, layer in self.model.named_modules(): 41 | if isinstance(layer, BitLinear): 42 | for k, v in layer.state_dict().items(): 43 | if 'weight' in k and 'norm' not in k: 44 | w_quant, scale = quantize_weights_to_int8(v) 45 | layer.weight.requires_grad = False 46 | layer.weight.data = w_quant 47 | layer.weight_scale = scale 48 | if verbose: 49 | size_after_quant = self.get_model_size_in_bytes() 50 | print(f'Mode size after quantization: {size_after_quant} MB') 51 | print(f'Quantization ratio: {size_after_quant / size_before_quant}') 52 | 53 | def get_model_size_in_bytes(self): 54 | param_size = 0 55 | for param in self.model.parameters(): 56 | param_size += param.nelement() * param.element_size() 57 | 58 | size_all_mb = (param_size) / 1024**2 59 | return size_all_mb 60 | 61 | # FIXME: This is overly complicated, but works for now. 62 | def generate( 63 | self, 64 | prompt: str, 65 | tokenizer: callable, 66 | max_len: int = 50, 67 | do_sample: bool = True, 68 | temperature: float = 0.1, 69 | top_k: int = 0, 70 | repetition_penalty : float = 1.0, 71 | num_return_sequences : int = 1, 72 | device : str = "cuda", 73 | ): 74 | assert hasattr(tokenizer, 'encode'), f"Tokenizer {tokenizer.__name__} must have an encode method" 75 | assert hasattr(tokenizer, 'decode'), f"Tokenizer {tokenizer.__name__} must have an decode method" 76 | prompt_tokens = tokenizer.encode(prompt) 77 | self.model = self.model.to(device) 78 | self.model.eval() 79 | 80 | # for comparison, allow inference with unquantized weights 81 | inference_with_quantized = True 82 | unquantized_layers = [name for name, layer in self.model.named_modules() if isinstance(layer, BitLinear) and layer.weight_scale is None] 83 | if len(unquantized_layers): 84 | print(f'WARNING: The following bitlinear layers are not quantized: {unquantized_layers}. 
Inference will be done with unquantized weights.') 85 | inference_with_quantized = False 86 | 87 | def top_k_filtering(logits, top_k=0, filter_value=-float('Inf')): 88 | top_k = min(top_k, logits.size(-1)) 89 | if top_k > 0: 90 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 91 | logits[indices_to_remove] = filter_value 92 | 93 | return logits 94 | 95 | for _ in range(num_return_sequences): 96 | generated = torch.tensor([prompt_tokens]) 97 | generated = generated.to(device) 98 | 99 | for _ in range(max_len): 100 | with torch.no_grad(): 101 | outputs = self.model(generated, inference=inference_with_quantized) 102 | next_token_logits = outputs[:, -1, :] 103 | for token in set(generated[0].tolist()): 104 | next_token_logits[:, token] /= repetition_penalty 105 | next_token_logits = next_token_logits / temperature 106 | filtered_logits = top_k_filtering(next_token_logits, top_k=top_k) 107 | if do_sample: 108 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 109 | else: 110 | next_token = torch.argmax(F.softmax(filtered_logits, dim=-1), dim=-1, keepdims=True) 111 | generated = torch.cat((generated, next_token), dim=-1) 112 | 113 | result = generated[0].tolist() 114 | text = tokenizer.decode(result) 115 | return text 116 | 117 | def training_step( 118 | self, 119 | batch: Tuple[torch.Tensor], 120 | batch_idx: int, 121 | ): 122 | x, y = batch 123 | logits = self.model(x) 124 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1)) 125 | self.log("train_loss", loss, prog_bar=True) 126 | if self.metric is not None: 127 | self.metric(logits, y) 128 | self.log(f"train_{self.metric_name}", self.metric, prog_bar=True) 129 | if self.log_tokens_progression: 130 | self.token_aggregator.update(x.shape[0] * x.shape[1]) 131 | self.log("tokens", self.token_aggregator.compute(), prog_bar=True) 132 | return loss 133 | 134 | def validation_step( 135 | self, 136 | batch: Tuple[torch.Tensor], 137 | batch_idx: int): 138 | x, y = batch 139 | logits = self.model(x) 140 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1)) 141 | self.log("val_loss", loss, prog_bar=True) 142 | if self.metric is not None: 143 | self.metric(logits, y) 144 | self.log(f"val_{self.metric_name}", self.metric, prog_bar=True) 145 | return loss 146 | 147 | def test_step( 148 | self, 149 | batch: Tuple[torch.Tensor], 150 | batch_idx: int 151 | ): 152 | x, y = batch 153 | logits = self.model(x) 154 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1)) 155 | self.log("test_loss", loss, prog_bar=True) 156 | if self.metric is not None: 157 | self.metric(logits, y) 158 | self.log(f"test_{self.metric_name}", self.metric, prog_bar=True) 159 | return loss 160 | 161 | def configure_optimizers(self): 162 | optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=0.01) 163 | cosine_scheduler = get_cosine_schedule_with_warmup( 164 | optimizer, 165 | num_warmup_steps=self.trainer_config.num_warmup_steps, 166 | num_training_steps=self.trainer_config.max_steps 167 | ) 168 | lr_scheduler_config = {'scheduler': cosine_scheduler, 'interval': 'step'} 169 | return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler_config} -------------------------------------------------------------------------------- /bllama/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class bLlamaConfig: 5 | vocab_size: int = 32000 6 | seq_len: int = 1024 7 | 
hidden_size: int = 2048 8 | num_heads: int = 16 9 | num_layers: int = 24 10 | dropout: float = 0.0 11 | bias: bool = False 12 | 13 | @dataclass 14 | class trainerConfig: 15 | max_steps: int = 10000 16 | gpus: int = 1 17 | precision: int = 16 18 | gradient_clip_val: float = 1.0 19 | check_val_every_n_epoch: int = 1 20 | log_every_n_steps: int = 1 21 | limit_val_batches: float = 1.0 22 | limit_train_batches: float = 1.0 23 | num_warmup_steps: int = 1000 -------------------------------------------------------------------------------- /bllama/quantization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def weight_quant(w: torch.Tensor): 4 | """ 5 | Quantize a set of weights to ternary based on the mean of the absolute values described by 6 | https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf 7 | 8 | This doesn't change the type of the weights, but rather the values themselves. 9 | Args: 10 | w (torch.Tensor): weights 11 | Returns: 12 | u (torch.Tensor): quantized weights 13 | """ 14 | scale = 1.0 / w.abs().mean().clamp_(min=1e-5) 15 | u = (w * scale).round().clamp_(-1,1) / scale 16 | return u 17 | 18 | def activation_quant(x: torch.Tensor): 19 | """ 20 | Quantize the activation tensor to scaled 8 bit based on the mean of the absolute values described by 21 | https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf 22 | 23 | This doesn't change the type of the activations, but rather the values themselves. 24 | Args: 25 | x (torch.Tensor): activations 26 | Returns: 27 | u (torch.Tensor): quantized weights 28 | """ 29 | scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) 30 | u = (x * scale).round().clamp_(-128,127) / scale 31 | return u 32 | 33 | def activation_post_quant(x_norm: torch.Tensor): 34 | """ 35 | Quantize the layer-normalized activations to 8 bit based on the mean of the absolute values described by 36 | https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf 37 | 38 | In constrast with the add-on of the paper, this function does not do its own RMSNorm. In turn, the 39 | RMSNorm will be done directly on the BitLinear layer. 40 | 41 | This doesn't change the type of the activations, but rather the values themselves. 42 | Args: 43 | w (torch.Tensor): weights 44 | Returns: 45 | y (torch.Tensor): quantized activations 46 | scale (torch.Tensor): scale factor 47 | """ 48 | scale = 127.0 / x_norm.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) 49 | y = (x_norm * scale).round().clamp_(-128,127) 50 | return y, scale 51 | 52 | def quantize_weights_to_int8(w: torch.Tensor): 53 | """ 54 | Offline quantization of a set of weights to int8 based on the mean of the absolute values. 55 | 56 | This operation casts the weights to int8. 
57 | Args: 58 | w (torch.Tensor): weights 59 | Returns: 60 | w_quant (torch.Tensor): quantized weights 61 | scale (torch.Tensor): scale factor 62 | """ 63 | scale = 1.0 / w.abs().mean().clamp_(min=1e-5) 64 | w_quant = (w * scale).round().clamp_(-1,1).to(torch.int8) 65 | w_quant.requires_grad = False 66 | return w_quant, scale -------------------------------------------------------------------------------- /bllama/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .bitlinear import BitLinear 5 | from .utils import apply_rotary_emb, precompute_freqs_cis, RMSNorm 6 | from .config import bLlamaConfig 7 | 8 | 9 | class Attention(nn.Module): 10 | def __init__(self, config: bLlamaConfig): 11 | super().__init__() 12 | self.config = config 13 | self.head_dim = config.hidden_size // config.num_heads 14 | self.n_heads = config.num_heads 15 | self.hidden_size = config.hidden_size 16 | self.attn_dropout = config.dropout 17 | self.proj_dropout = config.dropout 18 | self.flash_attention = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 19 | # consider self.n_heads == self.n_kv_heads 20 | 21 | 22 | self.wq = BitLinear(self.hidden_size, self.hidden_size, bias=False) 23 | self.wk = BitLinear(self.hidden_size, self.hidden_size, bias=False) 24 | self.wv = BitLinear(self.hidden_size, self.hidden_size, bias=False) 25 | self.attn_drop = nn.Dropout(self.attn_dropout) 26 | self.wo = BitLinear(self.hidden_size, self.hidden_size, bias=False) 27 | self.proj_drop = nn.Dropout(self.proj_dropout) 28 | 29 | def forward( 30 | self, 31 | x: torch.Tensor, 32 | freq_cis: torch.Tensor, 33 | mask: torch.Tensor = None, 34 | inference: bool = False, 35 | ): 36 | 37 | bsz, seqlen, _ = x.shape 38 | xq, xk, xv = self.wq(x, inference), self.wk(x, inference), self.wv(x, inference) 39 | xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) 40 | xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim) 41 | xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim) 42 | 43 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freq_cis) 44 | 45 | keys = xk.transpose(1,2) 46 | values = xv.transpose(1,2) 47 | queries = xq.transpose(1,2) 48 | 49 | if self.flash_attention: 50 | attn_output = F.scaled_dot_product_attention( 51 | queries, 52 | keys, 53 | values, 54 | attn_mask=None, 55 | dropout_p=self.attn_dropout if self.training else 0.0, 56 | is_causal=True, 57 | ) 58 | else: 59 | attn = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5) 60 | attn = attn + mask[:,:, :seqlen, :seqlen] 61 | attn = F.softmax(attn.float(), dim=-1).type_as(queries) 62 | attn_output = torch.matmul(attn, values) 63 | 64 | attn_output = self.attn_drop(attn_output) 65 | attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 66 | return self.wo(attn_output, inference) 67 | 68 | class FeedForward(nn.Module): 69 | def __init__( 70 | self, 71 | dim: int, 72 | hidden_dim: int, 73 | ): 74 | super().__init__() 75 | self.w1 = BitLinear(dim, hidden_dim) 76 | self.w2 = BitLinear(hidden_dim, dim) 77 | self.w3 = BitLinear(dim, hidden_dim) 78 | 79 | def forward(self, x, inference: bool = False): 80 | return self.w2(F.silu(self.w1(x, inference)) * self.w3(x, inference), inference=inference) 81 | 82 | class TransformerBlock(nn.Module): 83 | def __init__( 84 | self, 85 | config: bLlamaConfig, 86 | ): 87 | super().__init__() 88 | self.attn = Attention(config) 89 | self.ff = FeedForward(config.hidden_size, 
config.hidden_size * 4) 90 | # self.attn_norm = RMSNorm(config.hidden_size) 91 | # self.ff_norm = RMSNorm(config.hidden_size) 92 | # BitLinear has built-in RMSNorm, but I'll keep as Identity in case 93 | # we want to make it a choice to use BitLinear or Linear later 94 | self.attn_norm = nn.Identity() 95 | self.ff_norm = nn.Identity() 96 | 97 | def forward( 98 | self, 99 | x: torch.Tensor, 100 | freq_cis: torch.Tensor, 101 | inference: bool = False, 102 | ): 103 | h = x + self.attn(self.attn_norm(x), freq_cis, mask=None, inference=inference) 104 | return h + self.ff(self.ff_norm(h), inference=inference) 105 | 106 | class Transformer(nn.Module): 107 | """ 108 | Transformer module based on the LLaMa architecture. 109 | 110 | Args: 111 | config (bLlamaConfig): Configuration object containing model hyperparameters. 112 | 113 | Attributes: 114 | config (bLlamaConfig): Configuration object containing model hyperparameters. 115 | embed (nn.Embedding): Embedding layer for input tokens. 116 | blocks (nn.ModuleList): List of TransformerBlock instances. 117 | norm (RMSNorm): Normalization layer. 118 | freq_cis (torch.Tensor): Precomputed frequency tensor. 119 | vocab_proj (nn.Linear): Linear projection layer for output vocabulary. 120 | mask (torch.Tensor): Mask for attention mechanism. 121 | """ 122 | 123 | def __init__( 124 | self, 125 | config: bLlamaConfig, 126 | ): 127 | super().__init__() 128 | self.config = config 129 | self.embed = nn.Embedding(config.vocab_size, config.hidden_size) 130 | self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)]) 131 | self.norm = RMSNorm(config.hidden_size) 132 | self.freq_cis = precompute_freqs_cis(config.hidden_size // config.num_heads, config.seq_len * 2) 133 | self.vocab_proj = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 134 | self.embed.weight = self.vocab_proj.weight # tie weights 135 | 136 | if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 137 | mask = torch.full((1,1,config.seq_len,config.seq_len), float('-inf'), device=self.freq_cis.device, dtype=self.freq_cis.dtype) 138 | mask = torch.triu(mask, diagonal=1) 139 | self.register_buffer('mask', mask) 140 | else: 141 | mask = None 142 | 143 | 144 | def forward( 145 | self, 146 | x: torch.Tensor, 147 | inference: bool = False, 148 | ): 149 | """ 150 | Forward pass of the Transformer module. 151 | 152 | Args: 153 | x (torch.Tensor): Input tensor of shape (batch_size, sequence_length). 154 | inference (bool, optional): Flag indicating whether the forward pass is for inference or training. This is needed because the BitLinear layers have different behavior during inference. 155 | 156 | Returns: 157 | torch.Tensor: Output tensor of shape (batch_size, sequence_length, vocab_size). 158 | 159 | """ 160 | bsz, seqlen = x.shape 161 | x = self.embed(x) 162 | freq_cis = self.freq_cis[:seqlen].to(x.device) 163 | for i, blk in enumerate(self.blocks): 164 | x = blk(x, freq_cis, inference=inference) 165 | x = self.norm(x) 166 | return self.vocab_proj(x) 167 | 168 | -------------------------------------------------------------------------------- /bllama/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple 4 | 5 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 6 | """ 7 | From: https://github.com/meta-llama/llama/blob/main/llama/model.py 8 | Reshape frequency tensor for broadcasting it with another tensor. 
9 | 10 | This function reshapes the frequency tensor to have the same shape as the target tensor 'x' 11 | for the purpose of broadcasting the frequency tensor during element-wise operations. 12 | 13 | Args: 14 | freqs_cis (torch.Tensor): Frequency tensor to be reshaped. 15 | x (torch.Tensor): Target tensor for broadcasting compatibility. 16 | 17 | Returns: 18 | torch.Tensor: Reshaped frequency tensor. 19 | 20 | Raises: 21 | AssertionError: If the frequency tensor doesn't match the expected shape. 22 | AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 23 | """ 24 | ndim = x.ndim 25 | assert 0 <= 1 < ndim 26 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 27 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 28 | return freqs_cis.view(*shape) 29 | 30 | def apply_rotary_emb( 31 | xq: torch.Tensor, 32 | xk: torch.Tensor, 33 | freqs_cis: torch.Tensor, 34 | ) -> Tuple[torch.Tensor, torch.Tensor]: 35 | """ 36 | From: https://github.com/meta-llama/llama/blob/main/llama/model.py 37 | Apply rotary embeddings to input tensors using the given frequency tensor. 38 | 39 | This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided 40 | frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor 41 | is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are 42 | returned as real tensors. 43 | 44 | Args: 45 | xq (torch.Tensor): Query tensor to apply rotary embeddings. 46 | xk (torch.Tensor): Key tensor to apply rotary embeddings. 47 | freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. 48 | 49 | Returns: 50 | Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 51 | """ 52 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 53 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 54 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 55 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 56 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 57 | return xq_out.type_as(xq), xk_out.type_as(xk) 58 | 59 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 60 | """ 61 | From: https://github.com/meta-llama/llama/blob/main/llama/model.py 62 | Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 63 | 64 | This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' 65 | and the end index 'end'. The 'theta' parameter scales the frequencies. 66 | The returned tensor contains complex values in complex64 data type. 67 | 68 | Args: 69 | dim (int): Dimension of the frequency tensor. 70 | end (int): End index for precomputing frequencies. 71 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 72 | 73 | Returns: 74 | torch.Tensor: Precomputed frequency tensor with complex exponentials. 
75 | """ 76 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 77 | t = torch.arange(end, device=freqs.device) # type: ignore 78 | freqs = torch.outer(t, freqs).float() # type: ignore 79 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 80 | return freqs_cis 81 | 82 | class RMSNorm(nn.Module): 83 | def __init__(self, hidden_size, eps=1e-6): 84 | super().__init__() 85 | self.weight = nn.Parameter(torch.ones(hidden_size)) 86 | self.variance_epsilon = eps 87 | 88 | def forward(self, x): 89 | variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) 90 | x = x * torch.rsqrt(variance + self.variance_epsilon) 91 | return self.weight * x 92 | -------------------------------------------------------------------------------- /notes.md: -------------------------------------------------------------------------------- 1 | ## BitNet b1.58 2 | 3 | [BitNet b1.58](https://arxiv.org/pdf/2402.17764v1.pdf) is a novel 1-bit LLM variant that introduces ternary {-1,0,1} parameter weights. Although only having 3 possible values for weights, it matches full-precision Transformers of the same size in performance. 4 | 5 | What is incredibily interesting is that this concept allows models to be both performant and cost-effective. 6 | 7 | ### Tips on training BitNet 8 | 9 | - Replace all nn.Linear in attention and SwiGLU with BitLinear 10 | - Remove RMSNorm before attention and SwiGLU because BitLinear has built-in RMSNorm 11 | 12 | 13 | ## A review of LLaMa 14 | 15 | ### Datasets used 16 | 17 | - 67% English CommonCrawl (CCNet pipeline with filtering of low quality content with n-gram model) 18 | - 15% C4 with filtering 19 | - 4.5% GitHub with removed boilerplates 20 | - 4.5% Wikipedia with preprocess to remove hyperlinks, comments and formatting boilerplates. 21 | - 4.5% Gutenberg and Books3 22 | - 2.5% ArXiv with removed first sections and bibliography 23 | - 2% Stack Exchange with removed HTML tags 24 | 25 | ### Notable differences from past architectures 26 | 27 | - RMSNorm instead of LayerNorm and normalization of each sub-layer's input instead of output (already done on GPT-2) 28 | - SwiGLU instead of ReLU (GPT-2 uses GeLU) 29 | - Rotary encodings instead of absolute positional embeddings 30 | 31 | ### Optimizer 32 | 33 | - Trained with AdamW with $\beta_1 = 0.9$, $\beta_2 = 0.95$ 34 | - Cosine learning rate schedule such that the final learning rate is equal to 10% of the maximal learning rate. 35 | - Weight decay of 0.1 and gradient clipping of 1.0. 36 | - 2000 warm up steps. 37 | 38 | 39 | ### Optimizations and efficient implementation 40 | 41 | - Efficient implementation of multi-head attention using the xformers library. -> Doesn't store attention weights and doesn't compute key/query scores that are masked. 42 | - Manual iomplementation of the backward function of the transformer layers. 43 | - Training efficiency of the 65B-parameter model: 380 tokens/sec/GPU on 2048 A100-80GB. -> 21 days to train 1.4T tokens. 
44 | - -------------------------------------------------------------------------------- /utils/callbacks/weight_distribution_callback.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from typing import Optional 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import wandb 7 | 8 | class CalculateDistributionCallback(pl.Callback): 9 | def __init__( 10 | self, 11 | every_n_step: Optional[int] = 1000, 12 | to_wandb: Optional[bool] = True, 13 | save_path: Optional[str] = "./callback.png", 14 | ): 15 | self.every_n_step = every_n_step 16 | self.to_wandb = to_wandb 17 | self.save_path = save_path 18 | 19 | def on_train_batch_end(self, trainer: pl.Trainer, pl_module, outputs, batch, batch_idx): 20 | if trainer.global_step % self.every_n_step == 0 and trainer.global_step != 0: 21 | self._produce_distribution(trainer) 22 | 23 | @torch.no_grad() 24 | def _quantize_weights_to_int8_with_no_clamp(self, w: torch.Tensor): 25 | scale = 1.0 / w.abs().mean().clamp_(min=1e-5) 26 | w_quant = (w * scale).round() 27 | w_quant.requires_grad = False 28 | return w_quant, scale 29 | 30 | @torch.no_grad() 31 | def _produce_distribution( 32 | self, 33 | trainer: pl.Trainer, 34 | ): 35 | model = trainer.model.model 36 | fig, axs = plt.subplots(4,4, figsize=(10,10)) 37 | 38 | i = 0 39 | for k,v in model.state_dict().items(): 40 | if "attn.wq" in k and "weight" in k and "rms_norm" not in k: 41 | w_quant, _ = self._quantize_weights_to_int8_with_no_clamp(v) 42 | line = int(i / 4) 43 | column = i % 4 44 | counts, bins = np.histogram(w_quant.view(1,-1).cpu().numpy(), bins=20) 45 | axs[line, column].stairs(counts, bins, fill=True) 46 | axs[line, column].axvline(x=1, color='r', ls='--') 47 | axs[line, column].axvline(x=-1, color='r', ls='--') 48 | axs[line, column].set_title(f"Quantized wq layer {i}") 49 | i += 1 50 | plt.suptitle(f"Weight distribution of quantized wq layers at step {trainer.global_step}") 51 | if self.to_wandb: 52 | wandb.log({"weight_distribution": wandb.Image(fig)}) 53 | if self.save_path is not None: 54 | plt.savefig(self.save_path) -------------------------------------------------------------------------------- /utils/datamodules/text_data_module.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningDataModule 2 | from torch.utils.data import DataLoader, random_split 3 | import torch 4 | from torch.nn.utils.rnn import pad_sequence 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def collate_batch(batch, max_length=1024): 9 | input_ids, labels = zip(*batch) 10 | 11 | max_len = min(max(len(x) for x in input_ids), max_length) 12 | padded_inputs = pad_sequence(input_ids, batch_first=True, padding_value=0) 13 | padded_labels = pad_sequence(labels, batch_first=True, padding_value=-100) 14 | 15 | return padded_inputs[:max_len], padded_labels[:max_len] 16 | 17 | class TextDataModule(LightningDataModule): 18 | def __init__(self, 19 | dataset: Dataset, 20 | batch_size:int=8, 21 | train_test_split:float=0.8, 22 | seed:int=42 23 | ): 24 | super().__init__() 25 | self.dataset = dataset 26 | self.batch_size = batch_size 27 | self.train_test_split = train_test_split 28 | self.seed = seed 29 | self.dataset_train = None 30 | self.dataset_val = None 31 | self.dataset_predict = None 32 | 33 | def setup(self, stage=None): 34 | if stage == "fit": 35 | train_size = int(len(self.dataset) * self.train_test_split) 36 | val_size = len(self.dataset) - 
train_size 37 | self.dataset_train, self.dataset_val = random_split( 38 | self.dataset, 39 | [train_size, val_size], 40 | generator=torch.Generator().manual_seed(self.seed) 41 | ) 42 | if stage == "predict": 43 | self.dataset_predict = self.dataset 44 | 45 | def train_dataloader(self): 46 | return DataLoader( 47 | self.dataset_train, 48 | batch_size=self.batch_size, 49 | shuffle=True, num_workers=7, 50 | #pin_memory=True, 51 | collate_fn=lambda x: collate_batch(x, max_length=self.dataset.max_length) 52 | ) 53 | 54 | def val_dataloader(self): 55 | return DataLoader( 56 | self.dataset_val, 57 | batch_size=self.batch_size, 58 | shuffle=False, num_workers=7, 59 | #pin_memory=True, 60 | collate_fn=lambda x: collate_batch(x, max_length=self.dataset.max_length) 61 | ) 62 | 63 | def predict_dataloader(self): 64 | return DataLoader( 65 | self.dataset_predict, 66 | batch_size=self.batch_size, 67 | shuffle=False, num_workers=0, 68 | collate_fn=lambda x: collate_batch(x, max_length=self.dataset.max_length) 69 | ) -------------------------------------------------------------------------------- /utils/datasets/gepeto.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | from typing import Optional, Union, List 4 | import pandas as pd 5 | 6 | 7 | class GepetoDataset(Dataset): 8 | """ 9 | Generic dataset class that can be used to load any parquet file or list of parquet files. 10 | """ 11 | def __init__(self, 12 | file_paths: Union[str, List[str]], 13 | tokenizer: object, 14 | max_length: Optional[str]=1024): 15 | self.df = None 16 | self.file_paths = file_paths 17 | self.max_length = max_length 18 | self.tokenizer = tokenizer 19 | 20 | self._setup() 21 | self._tokenize(prune_dataset=True) 22 | 23 | def _setup(self): 24 | for file_path in self.file_paths: 25 | assert file_path.endswith(".parquet"), "File path must be a parquet file" 26 | df = pd.read_parquet(file_path) 27 | assert "text" in df.columns, "Column 'text' not found in dataframe" 28 | df = df[["text"]] 29 | self.df = df if self.df is None else pd.concat([self.df, df]) 30 | self.df.reset_index(drop=True, inplace=True) 31 | 32 | def _tokenize(self, prune_dataset=True): 33 | print('Tokenizing text...') 34 | self.df["tokens"] = self.df["text"].apply(lambda x: self.tokenizer.encode(x, allowed_special={'<|endoftext|>'})) 35 | self.df["len_tokens"] = self.df["tokens"].apply(lambda x: len(x)) 36 | if prune_dataset: 37 | self.df = self.df[self.df["len_tokens"] > 50] 38 | 39 | def __len__(self): 40 | return len(self.df) 41 | 42 | def __getitem__(self, idx): 43 | tokens = self.df.iloc[idx]["tokens"] 44 | tokens = tokens[:self.max_length] 45 | inputs = torch.tensor(tokens[:-1], dtype=torch.long) 46 | labels = torch.tensor(tokens[1:], dtype=torch.long) 47 | 48 | return inputs, labels -------------------------------------------------------------------------------- /utils/datasets/shakespeare.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | from typing import Optional 4 | 5 | class ShakespeareDataset(Dataset): 6 | def __init__(self, entry: str, tokenizer: object, max_length: Optional[str]=1024): 7 | self.entry = entry 8 | self.inputs = [] 9 | self.labels = [] 10 | self.max_length = max_length 11 | self.tokenizer = tokenizer 12 | 13 | @classmethod 14 | def from_file(cls, file_path: str, tokenizer: object, max_length: Optional[int] =1024): 15 | with open(file_path, "r") as 
f: 16 | entry = f.read() 17 | dataset = cls(entry, tokenizer, max_length) 18 | dataset._prepare_text() 19 | return dataset 20 | 21 | def _prepare_text(self): 22 | tokenized_text = self.tokenizer.encode(self.entry) 23 | 24 | for i in range(0, len(tokenized_text), self.max_length): 25 | a = tokenized_text[i:i+self.max_length] 26 | b = tokenized_text[i+1:i+self.max_length+1] 27 | if len(a) != self.max_length: 28 | print(f"{i} this has a different length: {len(a)}, padding") 29 | a = a + [0 for _ in range(self.max_length - len(a))] 30 | b = b + [0 for _ in range(self.max_length - len(b))] 31 | self.inputs.append(a) 32 | self.labels.append(b) 33 | 34 | def __len__(self): 35 | return len(self.inputs) 36 | 37 | def __getitem__(self, idx): 38 | return torch.tensor(self.inputs[idx]), torch.tensor(self.labels[idx]) -------------------------------------------------------------------------------- /utils/images/bllama_quantization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rafacelente/bllama/ebe528a107c1563d716afdb1fb881092fcc4c42b/utils/images/bllama_quantization.png --------------------------------------------------------------------------------
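
The quantization helpers in `bllama/quantization.py` implement two distinct paths: training-time "fake" quantization (ternary values kept in `fp32` for the straight-through estimator) and offline `int8` quantization used by BitLinear at inference time. As a closing illustration — this snippet is not part of the repository, and the tensor shape and assertions are illustrative assumptions — a small sanity-check sketch of how the two behave:

```python
import torch
from bllama.quantization import weight_quant, quantize_weights_to_int8

w = torch.randn(256, 256)

# Training-time fake quantization: ternary values scaled back to fp32,
# so the tensor keeps its dtype and can still flow through the STE trick in BitLinear.
u = weight_quant(w)
scale = 1.0 / w.abs().mean().clamp(min=1e-5)
assert u.dtype == torch.float32
assert set((u * scale).round().unique().tolist()) <= {-1.0, 0.0, 1.0}

# Offline quantization for inference: genuine int8 ternary weights plus a scale,
# which BitLinear later undoes by dividing the linear output by (x_scale * w_scale).
w_q, w_scale = quantize_weights_to_int8(w)
assert w_q.dtype == torch.int8
assert set(w_q.unique().tolist()) <= {-1, 0, 1}
```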