├── quant_linear.py
├── requirements.txt
├── toy_linear_nn.py
├── toy_vit_cifar100.py
└── toy_vit_f_mnist.py

/quant_linear.py:
--------------------------------------------------------------------------------
import copy
import math
from enum import Enum
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantizationMode(Enum):
    one_bit = 1
    two_bit = 2


class BitNetLinearLayer(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=False,
        quantization_mode: QuantizationMode = QuantizationMode.two_bit,
    ):
        super().__init__()
        self.binary_layer = True
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # `bias` is a bool here, so test it directly: the earlier
        # `bias is not None` check was always True and allocated a bias
        # even when bias=False was requested.
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
        self.quantization_mode = quantization_mode

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def compute_adjustment_factor(self, input_tensor: torch.Tensor):
        # Scale by twice the mean absolute value; the small epsilon guards
        # against division by zero for an all-zero tensor.
        absmean_weight = torch.mean(torch.abs(input_tensor))
        adjustment_factor = 2 * absmean_weight + 2e-4
        return adjustment_factor

    def compute_2bit_quantized_tensor(self, input_tensor: torch.Tensor):
        # Ternary {-1, 0, 1} quantization: round, then clip.
        return torch.clip(input=torch.round(input_tensor), min=-1, max=1)

    def compute_1bit_quantized_tensor(self, input_tensor: torch.Tensor):
        # Binary {-1, 1} quantization (exact zeros map to 0).
        return torch.sign(input_tensor)

    def compute_quantized_tensor(self, input_tensor: torch.Tensor):
        if self.quantization_mode == QuantizationMode.two_bit:
            return self.compute_2bit_quantized_tensor(input_tensor)
        else:
            return self.compute_1bit_quantized_tensor(input_tensor)

    def compute_commitment_loss(
        self, loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss
    ):
        adjustment_factor = self.compute_adjustment_factor(self.weight)
        adjusted_weight = self.weight / adjustment_factor
        # Fixed: this previously called the nonexistent compute_quantized_weight.
        quantized_weight = self.compute_quantized_tensor(adjusted_weight)

        return loss_fn(adjusted_weight, quantized_weight.detach())

    def forward(self, x):
        weight_adjustment_factor = self.compute_adjustment_factor(self.weight)
        adjusted_weight = self.weight / weight_adjustment_factor

        # Absmax activation scaling into the signed 8-bit range. The original
        # divided by a fixed 127.0 and clipped to [-1, 1], which rounds
        # typical unit-scale inputs straight to zero once the quantized
        # tensor is actually used in the matmul below.
        input_adjustment_factor = x.detach().abs().max().clamp(min=1e-4) / 127.0
        adjusted_input = x / input_adjustment_factor

        quantized_weight = self.compute_quantized_tensor(adjusted_weight)
        quantized_input = torch.clip(torch.round(adjusted_input), min=-127, max=127)

        if self.training:
            # Straight-through estimator: the forward value equals the
            # quantized tensor, while the backward pass treats the rounding
            # as the identity.
            quantized_weight = (
                adjusted_weight + (quantized_weight - adjusted_weight).detach()
            )
            quantized_input = (
                adjusted_input + (quantized_input - adjusted_input).detach()
            )

        # Multiply the quantized tensors, then undo both scalings. The
        # original multiplied the *adjusted* (unquantized) tensors, which
        # made the quantization and the STE above dead code.
        output = (
            weight_adjustment_factor
            * input_adjustment_factor
            * (quantized_input @ quantized_weight.t())
        )

        if self.bias is not None:
            output += self.bias
        return output


def create_quantized_copy_of_model(
    input_model: nn.Module, quantization_mode: QuantizationMode
):
    model_copy = copy.deepcopy(input_model)
    hash_table = {n: m for n, m in model_copy.named_modules()}

    for key in list(hash_table.keys()):
        if isinstance(hash_table[key], nn.Linear):
            new_module = BitNetLinearLayer(
                in_features=hash_table[key].in_features,
                out_features=hash_table[key].out_features,
                bias=hash_table[key].bias is not None,
                quantization_mode=quantization_mode,
            )
            # Walk to the parent module (the root has the empty name "")
            # and swap the nn.Linear child for the quantized layer.
            name_chain = key.split(".")
            parent_module_attr_name = ".".join(name_chain[:-1])
            parent_module = hash_table[parent_module_attr_name]
            setattr(parent_module, name_chain[-1], new_module)
    for n, m in model_copy.named_modules():
        assert not isinstance(m, nn.Linear)
    return model_copy
--------------------------------------------------------------------------------
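A minimal usage sketch for the module above. The toy two-layer model and the shapes are illustrative, not part of the repo; it assumes quant_linear.py is on the import path.

import torch
import torch.nn as nn

from quant_linear import (
    BitNetLinearLayer,
    QuantizationMode,
    create_quantized_copy_of_model,
)

# Swap every nn.Linear in a deep copy of the model for a BitNetLinearLayer.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
qmodel = create_quantized_copy_of_model(model, QuantizationMode.two_bit)

out = qmodel(torch.randn(8, 16))  # forward pass through the quantized layers
assert out.shape == (8, 4)

# Optional: the commitment loss pulls the scaled weights toward their
# quantized values and can be added to a training objective.
commit = sum(
    m.compute_commitment_loss()
    for m in qmodel.modules()
    if isinstance(m, BitNetLinearLayer)
)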
/requirements.txt:
--------------------------------------------------------------------------------
aiohttp==3.9.3
aiosignal==1.3.1
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.2.0
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.7
datasets==2.18.0
dill==0.3.8
docker-pycreds==0.4.0
filelock==3.9.0
frozenlist==1.4.1
fsspec==2024.2.0
gitdb==4.0.11
GitPython==3.1.42
huggingface-hub==0.21.4
idna==3.6
Jinja2==3.1.2
lightning==2.2.1
lightning-utilities==0.10.1
MarkupSafe==2.1.3
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.2.1
numpy==1.26.3
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.7.0.84
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-nccl-cu11==2.19.3
nvidia-nvtx-cu11==11.8.86
packaging==23.2
pandas==2.2.1
pillow==10.2.0
protobuf==4.25.3
psutil==5.9.8
pyarrow==15.0.1
pyarrow-hotfix==0.6
python-dateutil==2.9.0.post0
pytorch-lightning==2.2.1
pytz==2024.1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
safetensors==0.4.2
sentry-sdk==1.41.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
sympy==1.12
tokenizers==0.15.2
torch==2.2.1+cu118
torchmetrics==1.3.1
torchvision==0.17.1+cu118
tqdm==4.66.2
transformers==4.38.2
triton==2.2.0
typing_extensions==4.8.0
tzdata==2024.1
urllib3==2.2.1
wandb==0.16.4
xxhash==3.4.1
yarl==1.9.4
--------------------------------------------------------------------------------
/toy_linear_nn.py:
--------------------------------------------------------------------------------
import os

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.loggers import WandbLogger
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from quant_linear import (
    create_quantized_copy_of_model,
    QuantizationMode,
)


class MnistLightning(L.LightningModule):
    def __init__(self, linear_layer_type=nn.Linear, lr=1e-3):
        super().__init__()
        self.linear_layer_type = linear_layer_type
        self.model = nn.Sequential(
            linear_layer_type(28 * 28, 128),
            nn.ReLU(),
            linear_layer_type(128, 10, bias=False),
            nn.ReLU(),
            linear_layer_type(10, 10),
            nn.ReLU(),
            linear_layer_type(10, 10),
        )
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        accuracy = torch.argmax(logits, 1).eq(y).float().mean()
        self.log_dict(
            {"tl": loss.item(), "ta": accuracy.item()},
            on_step=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        # Lightning already runs validation under no_grad, so no explicit
        # torch.no_grad() context is needed here.
        x, y = batch
        x = x.view(x.size(0), -1)
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        accuracy = torch.argmax(logits, 1).eq(y).float().mean()
        self.log_dict(
            {"vl": loss.item(), "va": accuracy.item()}, on_step=True, prog_bar=True
        )

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
        )


# Set up the MNIST data.
dataset_folder = os.path.join(os.getcwd(), "data")

train_dataset = MNIST(dataset_folder, train=True, download=True, transform=ToTensor())
test_dataset = MNIST(dataset_folder, train=False, download=True, transform=ToTensor())

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128)

normal_module = MnistLightning(linear_layer_type=nn.Linear)
one_bit_quantized_module = create_quantized_copy_of_model(
    normal_module, quantization_mode=QuantizationMode.one_bit
)
two_bit_quantized_module = create_quantized_copy_of_model(
    normal_module, quantization_mode=QuantizationMode.two_bit
)

input_val = input("Enter 1 (normal), 2 (one-bit), or 3 (two-bit): ")
if int(input_val) == 1:
    normal_logger = WandbLogger(project="BitNet_v2", name="normal_mnist")
    normal_trainer = L.Trainer(
        max_epochs=10,
        logger=normal_logger,
    )
    normal_trainer.fit(normal_module, train_loader, test_loader)
    normal_logger.finalize(status="success")

if int(input_val) == 2:
    one_bit_logger = WandbLogger(project="BitNet_v2", name="one_bit_mnist")
    one_bit_quantized_module.lr = 1e-4
    one_bit_quant_trainer = L.Trainer(
        max_epochs=10,
        logger=one_bit_logger,
    )
    one_bit_quant_trainer.fit(one_bit_quantized_module, train_loader, test_loader)
    one_bit_logger.finalize(status="success")

if int(input_val) == 3:
    two_bit_logger = WandbLogger(project="BitNet_v2", name="two_bit_mnist_lr=1e-4")
    two_bit_quantized_module.lr = 1e-4
    two_bit_quant_trainer = L.Trainer(
        max_epochs=10,
        logger=two_bit_logger,
    )
    two_bit_quant_trainer.fit(two_bit_quantized_module, train_loader, test_loader)
    two_bit_logger.finalize(status="success")
--------------------------------------------------------------------------------
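The pattern adjusted + (quantized - adjusted).detach() used in BitNetLinearLayer.forward is a straight-through estimator. A standalone sketch of the trick (the example values are arbitrary):

import torch

w = torch.tensor([0.3, -0.8, 0.05], requires_grad=True)
q = torch.clip(torch.round(w), -1, 1)  # ternary values; round() itself has zero gradient
w_ste = w + (q - w).detach()           # forward value equals q exactly

loss = (w_ste ** 2).sum()
loss.backward()
print(w.grad)  # tensor([0., -2., 0.]): the gradient of the loss evaluated at the
               # quantized values, passed straight through to the latent weights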
/toy_vit_cifar100.py:
--------------------------------------------------------------------------------
from datasets import load_dataset
import lightning as L
import torch
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers.models.vit.configuration_vit import ViTConfig
from transformers.models.vit.modeling_vit import ViTForImageClassification

from quant_linear import (
    create_quantized_copy_of_model,
    QuantizationMode,
)

config = ViTConfig(
    hidden_size=128,
    num_hidden_layers=8,
    num_attention_heads=4,
    intermediate_size=256,
    hidden_act="gelu",
    image_size=32,
    patch_size=4,
    num_labels=100,
    num_channels=3,
)


class ViTImageClassifier(L.LightningModule):
    def __init__(self, config: ViTConfig, lr=1e-3):
        super().__init__()
        self.model = ViTForImageClassification(config)
        self.config = config
        self.lr = lr

    def forward(self, batch):
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        output = self(batch)
        loss = output.loss
        argmax = output.logits.argmax(dim=1)
        accuracy = (argmax == batch["labels"]).float().mean()
        self.log_dict(
            {
                "tl": loss.item(),
                "ta": accuracy.item(),
            },
            prog_bar=True,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        # Lightning runs validation under no_grad already.
        output = self(batch)
        loss = output.loss
        argmax = output.logits.argmax(dim=1)
        accuracy = (argmax == batch["labels"]).float().mean()
        self.log_dict(
            {
                "vl": loss.item(),
                "va": accuracy.item(),
            },
            prog_bar=True,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)


dataset = load_dataset("cifar100")

image_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

processed_dataset = dataset.map(
    lambda x: {"pixel_values": image_transforms(x["img"]), "labels": x["fine_label"]}
)
processed_dataset = processed_dataset.remove_columns(["fine_label", "img"])
processed_dataset.set_format("torch", columns=["pixel_values", "labels"])


train_dataloader = DataLoader(processed_dataset["train"], batch_size=128)
eval_dataloader = DataLoader(processed_dataset["test"], batch_size=128)

normal_model = ViTImageClassifier(config)
one_bit_quantized_model = create_quantized_copy_of_model(
    normal_model, quantization_mode=QuantizationMode.one_bit
)
two_bit_quantized_model = create_quantized_copy_of_model(
    normal_model, quantization_mode=QuantizationMode.two_bit
)

choice = input("Enter 1 (normal), 2 (one-bit), or 3 (two-bit): ")
if int(choice) == 1:
    normal_logger = WandbLogger(project="BitNet_v2", name="normal_cifar100")
    normal_trainer = L.Trainer(
        max_epochs=10,
        logger=normal_logger,
    )
    normal_trainer.fit(
        normal_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader,
    )
if int(choice) == 2:
    one_bit_logger = WandbLogger(project="BitNet_v2", name="one_bit_cifar100")
    one_bit_trainer = L.Trainer(
        max_epochs=10,
        logger=one_bit_logger,
    )
    one_bit_quantized_model.lr = 1e-4
    one_bit_trainer.fit(
        one_bit_quantized_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader,
    )

if int(choice) == 3:
    two_bit_logger = WandbLogger(project="BitNet_v2", name="two_bit_cifar100")
    two_bit_trainer = L.Trainer(
        max_epochs=10,
        logger=two_bit_logger,
    )
    two_bit_quantized_model.lr = 1e-4
    two_bit_trainer.fit(
        two_bit_quantized_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader,
    )
--------------------------------------------------------------------------------
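None of the training scripts actually use compute_commitment_loss. A hedged sketch of how it could be folded into a LightningModule training step like the one above; the function name and the commit_coeff=0.1 weighting are illustrative choices, not tuned or taken from the repo:

from quant_linear import BitNetLinearLayer

def training_step_with_commitment(self, batch, batch_idx):
    # Same task loss as ViTImageClassifier.training_step above.
    output = self(batch)
    task_loss = output.loss

    # Hypothetical auxiliary term: average commitment loss over all
    # quantized layers, scaled by an untuned illustrative coefficient.
    commit_coeff = 0.1
    commit_losses = [
        m.compute_commitment_loss()
        for m in self.modules()
        if isinstance(m, BitNetLinearLayer)
    ]
    commit_loss = sum(commit_losses) / max(len(commit_losses), 1)

    return task_loss + commit_coeff * commit_loss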
/toy_vit_f_mnist.py:
--------------------------------------------------------------------------------
from datasets import load_dataset
import lightning as L
import torch
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers.models.vit.configuration_vit import ViTConfig
from transformers.models.vit.modeling_vit import ViTForImageClassification

from quant_linear import (
    create_quantized_copy_of_model,
    QuantizationMode,
)

config = ViTConfig(
    hidden_size=128,
    num_hidden_layers=6,
    num_attention_heads=4,
    intermediate_size=256,
    hidden_act="gelu",
    image_size=28,
    patch_size=4,
    num_labels=10,
    num_channels=1,
)


class ViTImageClassifier(L.LightningModule):
    def __init__(self, config: ViTConfig, lr=1e-3):
        super().__init__()
        self.model = ViTForImageClassification(config)
        self.config = config
        self.lr = lr

    def forward(self, batch):
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        output = self(batch)
        loss = output.loss
        argmax = output.logits.argmax(dim=1)
        accuracy = (argmax == batch["labels"]).float().mean()
        self.log_dict(
            {
                "tl": loss.item(),
                "ta": accuracy.item(),
            },
            prog_bar=True,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        # Lightning runs validation under no_grad already.
        output = self(batch)
        loss = output.loss
        argmax = output.logits.argmax(dim=1)
        accuracy = (argmax == batch["labels"]).float().mean()
        self.log_dict(
            {
                "vl": loss.item(),
                "va": accuracy.item(),
            },
            prog_bar=True,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)


dataset = load_dataset("fashion_mnist")

image_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

processed_dataset = dataset.map(
    lambda x: {"pixel_values": image_transforms(x["image"]), "labels": x["label"]}
)
processed_dataset = processed_dataset.remove_columns(["label", "image"])
processed_dataset.set_format("torch", columns=["pixel_values", "labels"])


train_dataloader = DataLoader(processed_dataset["train"], batch_size=128)
eval_dataloader = DataLoader(processed_dataset["test"], batch_size=128)

normal_model = ViTImageClassifier(config)
one_bit_quantized_model = create_quantized_copy_of_model(
    normal_model, quantization_mode=QuantizationMode.one_bit
)
two_bit_quantized_model = create_quantized_copy_of_model(
    normal_model, quantization_mode=QuantizationMode.two_bit
)

choice = input("Enter 1 (normal), 2 (one-bit), or 3 (two-bit): ")
if int(choice) == 1:
    normal_logger = WandbLogger(project="BitNet", name="normal_f_mnist")
    normal_trainer = L.Trainer(
        max_epochs=10,
        logger=normal_logger,
    )
    normal_trainer.fit(
        normal_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader,
    )
if int(choice) == 2:
    one_bit_logger = WandbLogger(project="BitNet", name="one_bit_f_mnist")
    one_bit_trainer = L.Trainer(
        max_epochs=10,
        logger=one_bit_logger,
    )
    one_bit_quantized_model.lr = 1e-4
    one_bit_trainer.fit(
        one_bit_quantized_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader,
    )
if int(choice) == 3:
    two_bit_logger = WandbLogger(project="BitNet", name="two_bit_f_mnist")
    two_bit_trainer = L.Trainer(
        max_epochs=10,
        logger=two_bit_logger,
    )
    two_bit_quantized_model.lr = 1e-4
    two_bit_trainer.fit(
        two_bit_quantized_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader,
    )
--------------------------------------------------------------------------------
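BitNetLinearLayer stores its weights as ordinary float tensors; bit-packing is out of scope for this repo. A back-of-the-envelope sketch of the potential saving if ternary weights were packed at a hypothetical 2 bits each (the helper name and packing scheme are illustrative assumptions):

from quant_linear import BitNetLinearLayer

def estimate_linear_weight_storage(model):
    # Count only the weights held in BitNetLinearLayer modules.
    n = sum(
        m.weight.numel()
        for m in model.modules()
        if isinstance(m, BitNetLinearLayer)
    )
    fp32_mib = n * 4 / 2**20        # 32-bit floats: 4 bytes per weight
    packed_mib = n * 2 / 8 / 2**20  # hypothetical 2-bit packing
    return n, fp32_mib, packed_mib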