# Repository layout: ├── huffman_encode.py ├── net ├── quantization.py ├── models.py ├── prune.py └── huffmancoding.py ├── weight_share.py ├── README.md ├── .gitignore ├── util.py └── pruning.py

# ---------------------------------------------------------------------------
# /huffman_encode.py
# ---------------------------------------------------------------------------
import argparse

import torch

from net.huffmancoding import huffman_encode_model
import util

parser = argparse.ArgumentParser(description='Huffman encode a quantized model')
parser.add_argument('model', type=str, help='saved quantized model')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA')
args = parser.parse_args()

use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else 'cpu')

# map_location lets a model that was saved on a GPU be loaded when CUDA is
# unavailable or disabled with --no-cuda (`device` was previously unused).
model = torch.load(args.model, map_location=device)
huffman_encode_model(model)

# ---------------------------------------------------------------------------
# /net/quantization.py
# ---------------------------------------------------------------------------
import torch
import numpy as np
from sklearn.cluster import KMeans
from scipy.sparse import csc_matrix, csr_matrix


def apply_weight_sharing(model, bits=5):
    """
    Applies weight sharing (k-means quantization) to the given model, in place.

    For every direct child module that has a 2-D ``weight``, the non-zero
    entries (the ``data`` array of a CSR/CSC representation) are clustered
    into ``2**bits`` groups, and each non-zero weight is replaced by its
    cluster centroid. Zero (pruned) entries stay zero.

    Children without a 2-D weight are skipped, because CSR/CSC matrices are
    2-D only. Layers with fewer than ``2**bits`` non-zero weights are also
    skipped: they already use fewer than ``2**bits`` distinct values, and
    k-means would reject them.

    Args:
        model: torch module whose direct children are the layers to quantize.
        bits (int): bits per shared weight; ``2**bits`` clusters are used.
    """
    n_clusters = 2**bits
    for module in model.children():
        weight_param = getattr(module, 'weight', None)
        if weight_param is None or weight_param.dim() != 2:
            continue
        dev = weight_param.device
        weight = weight_param.data.cpu().numpy()
        shape = weight.shape
        # Use the sparse format whose index-pointer array is shorter.
        mat = csr_matrix(weight) if shape[0] < shape[1] else csc_matrix(weight)
        if mat.data.size < n_clusters:
            continue
        # Linearly spaced initial centroids over [min, max] of the non-zero weights.
        space = np.linspace(mat.data.min(), mat.data.max(), num=n_clusters)
        # `precompute_distances` and `algorithm="full"` were removed from newer
        # scikit-learn releases; the default algorithm is also exact k-means.
        kmeans = KMeans(n_clusters=n_clusters, init=space.reshape(-1, 1), n_init=1)
        kmeans.fit(mat.data.reshape(-1, 1))
        new_weight = kmeans.cluster_centers_[kmeans.labels_].reshape(-1)
        # Keep the original dtype so the layer's parameter dtype does not change.
        mat.data = new_weight.astype(mat.data.dtype, copy=False)
        module.weight.data = torch.from_numpy(mat.toarray()).to(dev)
# ---------------------------------------------------------------------------
# /weight_share.py
# ---------------------------------------------------------------------------
import argparse
import os

import torch

from net.models import LeNet
from net.quantization import apply_weight_sharing
import util

parser = argparse.ArgumentParser(description='This program quantizes weight by using weight sharing')
parser.add_argument('model', type=str, help='path to saved pruned model')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--output', default='saves/model_after_weight_sharing.ptmodel', type=str,
                    help='path to model output')
args = parser.parse_args()

use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else 'cpu')


# Define the model.
# map_location puts the model on the same device that util.test() sends the
# test batches to, and lets GPU-saved models load when CUDA is unavailable.
model = torch.load(args.model, map_location=device)
print('accuracy before weight sharing')
util.test(model, use_cuda)

# Weight sharing
apply_weight_sharing(model)
print('accuracy after weight sharing')
util.test(model, use_cuda)

# Save the new model. Create the directory that --output actually points to
# (previously 'saves/' was always created, even for other output paths).
output_dir = os.path.dirname(args.output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
torch.save(model, args.output)

# ---------------------------------------------------------------------------
# /net/models.py
# ---------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

from .prune import PruningModule, MaskedLinear


class LeNet(PruningModule):
    """LeNet-300-100: a 784-300-100-10 fully connected network.

    Args:
        mask (bool): if True, the layers are MaskedLinear, so the pruning
            methods of PruningModule can be applied; otherwise nn.Linear.
    """

    def __init__(self, mask=False):
        super(LeNet, self).__init__()
        linear = MaskedLinear if mask else nn.Linear
        self.fc1 = linear(784, 300)
        self.fc2 = linear(300, 100)
        self.fc3 = linear(100, 10)

    def forward(self, x):
        """Flatten `x` to (N, 784) and return (N, 10) log-probabilities."""
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

class LeNet_5(PruningModule):
    """LeNet-5: three 5x5 convolutions followed by two fully connected layers.

    With 32x32 single-channel inputs, conv3 produces a 1x1 map, so
    ``x.view(-1, 120)`` in ``forward`` flattens one sample per row. With
    28x28 inputs, conv3 receives a 4x4 map, which is smaller than its 5x5
    kernel.

    Args:
        mask (bool): if True, the fully connected layers are MaskedLinear
            (prunable); otherwise nn.Linear. Conv layers are never masked.
    """

    def __init__(self, mask=False):
        super(LeNet_5, self).__init__()
        # Fix: the bare name `Linear` was undefined (NameError when mask=False).
        linear = MaskedLinear if mask else nn.Linear
        self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5))
        self.conv2 = nn.Conv2d(6, 16, kernel_size=(5, 5))
        self.conv3 = nn.Conv2d(16, 120, kernel_size=(5, 5))
        self.fc1 = linear(120, 84)
        self.fc2 = linear(84, 10)

    def forward(self, x):
        """Return (N, 10) log-probabilities for a batch of images `x`."""
        # Conv1
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=(2, 2), stride=2)

        # Conv2
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=(2, 2), stride=2)

        # Conv3
        x = self.conv3(x)
        x = F.relu(x)

        # Fully-connected
        x = x.view(-1, 120)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)

        return x

# -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Compression-PyTorch 2 | PyTorch implementation of 'Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding' by Song Han, Huizi Mao, William J.
Dally 3 | 4 | This repository implements the three core methods of the Deep Compression paper: 5 | - Pruning 6 | - Weight sharing 7 | - Huffman Encoding 8 | 9 | ## Requirements 10 | The following packages are required for this project: 11 | - Python3.6+ 12 | - tqdm 13 | - numpy 14 | - pytorch, torchvision 15 | - scipy 16 | - scikit-learn 17 | 18 | or just use Docker: 19 | ``` bash 20 | $ docker pull tonyapplekim/deepcompressionpytorch 21 | ``` 22 | 23 | ## Usage 24 | ### Pruning 25 | ``` bash 26 | $ python pruning.py 27 | ``` 28 | This command 29 | - trains the LeNet-300-100 model on the MNIST dataset 30 | - prunes weights that have low absolute values 31 | - retrains the model on the MNIST dataset 32 | - prints non-zero statistics for the weights of each layer 33 | 34 | You can control other values such as 35 | - random seed 36 | - epochs 37 | - sensitivity 38 | - batch size 39 | - learning rate 40 | - and others 41 | For more, type `python pruning.py --help` 42 | 43 | ### Weight sharing 44 | ``` bash 45 | $ python weight_share.py saves/model_after_retraining.ptmodel 46 | ``` 47 | This command 48 | * Applies the K-means clustering algorithm to the data portion of the CSC or CSR matrix representation of each weight 49 | * As a result, every non-zero weight is clustered into (2**bits) groups. 50 | (Default is 32 groups - using 5 bits) 51 | - By default, the modified model is saved to 52 | `saves/model_after_weight_sharing.ptmodel` 53 | 54 | ### Huffman coding 55 | ``` bash 56 | $ python huffman_encode.py saves/model_after_weight_sharing.ptmodel 57 | ``` 58 | This command 59 | - Applies the Huffman coding algorithm to each of the weights in the network 60 | - Saves each weight to the `encodings/` folder 61 | - Prints compression statistics (original vs. compressed size) 62 | 63 | 64 | 65 | ## Note 66 | Note that I didn’t apply pruning, weight sharing, or Huffman coding to bias values. It may be better to apply them to the biases as well; I haven’t tried this yet.
67 | 68 | Note that this work was done when I was employed at http://nota.ai 69 | 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/python 2 | data/ 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs 
# (.gitignore, continued) documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 110 | ### Python Patch ### 111 | .venv/ 112 | 113 | ### Python.VirtualEnv Stack ### 114 | # Virtualenv 115 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 116 | [Bb]in 117 | [Ii]nclude 118 | [Ll]ib 119 | [Ll]ib64 120 | [Ll]ocal 121 | [Ss]cripts 122 | pyvenv.cfg 123 | pip-selfcheck.json 124 | 125 | 126 | # End of https://www.gitignore.io/api/python 127 |

# ---------------------------------------------------------------------------
# /util.py
# ---------------------------------------------------------------------------
import os
import torch
import math
import numpy as np
from torch.nn import Parameter
from torch.nn.modules.module import Module
import torch.nn.functional as F
from torchvision import datasets, transforms


def log(filename, content):
    """Append `content` plus a trailing newline to the text file `filename`."""
    with open(filename, 'a') as f:
        content += "\n"
        f.write(content)


def print_model_parameters(model, with_values=False):
    """Print name, shape and dtype of every parameter (and its values if requested)."""
    print(f"{'Param name':20} {'Shape':30} {'Type':15}")
    print('-'*70)
    for name, param in model.named_parameters():
        print(f'{name:20} {str(param.shape):30} {str(param.dtype):15}')
        if with_values:
            print(param)


def print_nonzeros(model):
    """Print per-parameter and overall non-zero statistics, ignoring mask parameters."""
    nonzero = total = 0
    for name, p in model.named_parameters():
        if 'mask' in name:
            continue
        tensor = p.data.cpu().numpy()
        nz_count = np.count_nonzero(tensor)
        total_params = np.prod(tensor.shape)
        nonzero += nz_count
        total += total_params
        print(f'{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({100 * nz_count / total_params:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}')
    # Guard the ratios: a fully pruned (or parameterless) model used to raise ZeroDivisionError.
    compression = total / nonzero if nonzero else float('inf')
    pruned_pct = 100 * (total - nonzero) / total if total else 0.0
    print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {compression:10.2f}x ({pruned_pct:6.2f}% pruned)')


def test(model, use_cuda=True):
    """Evaluate `model` on the MNIST test set; print loss/accuracy and return accuracy (%).

    Batches are sent to CUDA when `use_cuda` is True, otherwise to the CPU.
    The model itself is not moved, so the caller must place it on that device.
    """
    kwargs = {'num_workers': 5, 'pin_memory': True} if use_cuda else {}
    device = torch.device("cuda" if use_cuda else 'cpu')
    test_loader = torch.utils.data.DataLoader(
        # download=True lets scripts run standalone on a fresh checkout; it is a
        # no-op when the dataset is already present under 'data'.
        datasets.MNIST('data', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=False, **kwargs)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return accuracy

# ---------------------------------------------------------------------------
# /net/prune.py
# ---------------------------------------------------------------------------
import math

import numpy as np
import torch
from torch.nn import Parameter
from torch.nn.modules.module import Module
import torch.nn.functional as F


class PruningModule(Module):
    """Base model class providing magnitude-based pruning.

    Only submodules named 'fc1', 'fc2' or 'fc3' are pruned, and they must
    provide a ``prune(threshold)`` method (e.g. MaskedLinear).
    """

    def prune_by_percentile(self, q=5.0, **kwargs):
        """
        Note:
             The pruning percentile is based on all layer's parameters concatenated
        Args:
            q (float): percentile in float
            **kwargs: may contain `cuda` (currently unused)
        """
        # Calculate percentile value over all currently non-zero weights
        alive_parameters = []
        for name, p in self.named_parameters():
            # We do not prune bias term
            if 'bias' in name or 'mask' in name:
                continue
            tensor = p.data.cpu().numpy()
            alive = tensor[np.nonzero(tensor)]  # flattened array of nonzero values
            alive_parameters.append(alive)

        all_alives = np.concatenate(alive_parameters)
        percentile_value = np.percentile(abs(all_alives), q)
        print(f'Pruning with threshold : {percentile_value}')

        # Prune the weights and mask
        # Note that module here is the layer
        # ex) fc1, fc2, fc3
        for name, module in self.named_modules():
            if name in ['fc1', 'fc2', 'fc3']:
                module.prune(threshold=percentile_value)

    def prune_by_std(self, s=0.25):
        """
        Note that `s` is a quality parameter / sensitivity value according to the paper.
        According to Song Han's previous paper (Learning both Weights and Connections for Efficient Neural Networks),
        'The pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layer’s weights'

        I tried multiple values and empirically, 0.25 matches the paper's compression rate and number of parameters.
        Note : In the paper, the authors used different sensitivity values for different layers.
        """
        for name, module in self.named_modules():
            if name in ['fc1', 'fc2', 'fc3']:
                threshold = np.std(module.weight.data.cpu().numpy()) * s
                print(f'Pruning with threshold : {threshold} for layer {name}')
                module.prune(threshold)


class MaskedLinear(Module):
    r"""Applies a masked linear transformation to the incoming data: :math:`y = (A * M)x + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            (out_features x in_features)
        bias:   the learnable bias of the module of shape (out_features)
        mask: the unlearnable mask for the weight.
            It has the same shape as weight (out_features x in_features)

    """
    def __init__(self, in_features, out_features, bias=True):
        super(MaskedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        # Initialize the mask with 1; registered as a Parameter (so it is saved
        # with the model) but excluded from gradient computation.
        self.mask = Parameter(torch.ones([out_features, in_features]), requires_grad=False)
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and bias uniformly in [-1/sqrt(in_features), 1/sqrt(in_features)]."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        return F.linear(input, self.weight * self.mask, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_features=' + str(self.in_features) \
            + ', out_features=' + str(self.out_features) \
            + ', bias=' + str(self.bias is not None) + ')'

    def prune(self, threshold):
        """Zero every weight with absolute value below `threshold` and clear it in the mask.

        Entries already masked stay masked: np.where keeps the previous mask
        value wherever the weight is not below the threshold.
        """
        weight_dev = self.weight.device
        mask_dev = self.mask.device
        # Convert Tensors to numpy and calculate
        tensor = self.weight.data.cpu().numpy()
        mask = self.mask.data.cpu().numpy()
        new_mask = np.where(abs(tensor) < threshold, 0, mask)
        # Apply new weight and mask
        self.weight.data = torch.from_numpy(tensor * new_mask).to(weight_dev)
        self.mask.data = torch.from_numpy(new_mask).to(mask_dev)


# ---------------------------------------------------------------------------
# /pruning.py
# ---------------------------------------------------------------------------
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm

from net.models import LeNet
from net.quantization import apply_weight_sharing
import util

os.makedirs('saves', exist_ok=True)

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST pruning from deep compression paper')
parser.add_argument('--batch-size', type=int, default=50, metavar='N',
                    help='input batch size for training (default: 50)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--log', type=str, default='log.txt',
                    help='log file name')
parser.add_argument('--sensitivity', type=float, default=2,
                    help="sensitivity value that is multiplied to layer's std in order to get threshold value")
args = parser.parse_args()

# Control Seed
torch.manual_seed(args.seed)

# Select Device
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else 'cpu')
if use_cuda:
    print("Using CUDA!")
    torch.cuda.manual_seed(args.seed)
else:
    print('Not using CUDA!!!')

# Loader
kwargs = {'num_workers': 5, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=False, **kwargs)


# Define which model to use
model = LeNet(mask=True).to(device)

print(model)
util.print_model_parameters(model)

# NOTE : `weight_decay` term denotes L2 regularization loss term
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0001)
initial_optimizer_state_dict = optimizer.state_dict()


def train(epochs):
    """Train the global `model` for `epochs` epochs on `train_loader`.

    Before each optimizer step, the gradient is zeroed at every parameter entry
    whose value is exactly zero (the pruned connections). Mask parameters are
    skipped.
    """
    model.train()
    for epoch in range(epochs):
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
        for batch_idx, (data, target) in pbar:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()

            # zero-out all the gradients corresponding to the pruned connections.
            # Done in place on the parameter's device: same result as the former
            # numpy np.where round trip, without a GPU->CPU->GPU copy per batch.
            for name, p in model.named_parameters():
                if 'mask' in name:
                    continue
                p.grad.data.masked_fill_(p.data == 0, 0)

            optimizer.step()
            if batch_idx % args.log_interval == 0:
                done = batch_idx * len(data)
                percentage = 100. * batch_idx / len(train_loader)
                pbar.set_description(f'Train Epoch: {epoch} [{done:5}/{len(train_loader.dataset)} ({percentage:3.0f}%)]  Loss: {loss.item():.6f}')


def test():
    """Evaluate the global `model` on `test_loader`; print and return accuracy (%)."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return accuracy


# Initial training
print("--- Initial training ---")
train(args.epochs)
accuracy = test()
util.log(args.log, f"initial_accuracy {accuracy}")
torch.save(model, "saves/initial_model.ptmodel")
print("--- Before pruning ---")
util.print_nonzeros(model)

# Pruning
model.prune_by_std(args.sensitivity)
accuracy = test()
util.log(args.log, f"accuracy_after_pruning {accuracy}")
print("--- After pruning ---")
util.print_nonzeros(model)

# Retrain
print("--- Retraining ---")
optimizer.load_state_dict(initial_optimizer_state_dict)  # Reset the optimizer
train(args.epochs)
torch.save(model, "saves/model_after_retraining.ptmodel")
accuracy = test()
util.log(args.log, f"accuracy_after_retraining {accuracy}")

print("--- After Retraining ---")
util.print_nonzeros(model)

# ---------------------------------------------------------------------------
# /net/huffmancoding.py
# ---------------------------------------------------------------------------
import os
from collections import defaultdict, namedtuple
from heapq import heappush, heappop, heapify
import struct
from pathlib import Path

import torch
import numpy as np
from scipy.sparse import csr_matrix, csc_matrix

# Huffman tree node: leaves carry a `value`; internal nodes have value None.
Node = namedtuple('Node', 'freq value left right')
# heapq compares nodes with `<`; ordering is by frequency only.
Node.__lt__ = lambda x, y: x.freq < y.freq

-------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict, namedtuple 3 | from heapq import heappush, heappop, heapify 4 | import struct 5 | from pathlib import Path 6 | 7 | import torch 8 | import numpy as np 9 | from scipy.sparse import csr_matrix, csc_matrix 10 | 11 | Node = namedtuple('Node', 'freq value left right') 12 | Node.__lt__ = lambda x, y: x.freq < y.freq 13 | 14 | def huffman_encode(arr, prefix, save_dir='./'): 15 | """ 16 | Encodes numpy array 'arr' and saves to `save_dir` 17 | The names of binary files are prefixed with `prefix` 18 | returns the number of bytes for the tree and the data after the compression 19 | """ 20 | # Infer dtype 21 | dtype = str(arr.dtype) 22 | 23 | # Calculate frequency in arr 24 | freq_map = defaultdict(int) 25 | convert_map = {'float32':float, 'int32':int} 26 | for value in np.nditer(arr): 27 | value = convert_map[dtype](value) 28 | freq_map[value] += 1 29 | 30 | # Make heap 31 | heap = [Node(frequency, value, None, None) for value, frequency in freq_map.items()] 32 | heapify(heap) 33 | 34 | # Merge nodes 35 | while(len(heap) > 1): 36 | node1 = heappop(heap) 37 | node2 = heappop(heap) 38 | merged = Node(node1.freq + node2.freq, None, node1, node2) 39 | heappush(heap, merged) 40 | 41 | # Generate code value mapping 42 | value2code = {} 43 | 44 | def generate_code(node, code): 45 | if node is None: 46 | return 47 | if node.value is not None: 48 | value2code[node.value] = code 49 | return 50 | generate_code(node.left, code + '0') 51 | generate_code(node.right, code + '1') 52 | 53 | root = heappop(heap) 54 | generate_code(root, '') 55 | 56 | # Path to save location 57 | directory = Path(save_dir) 58 | 59 | # Dump data 60 | data_encoding = ''.join(value2code[convert_map[dtype](value)] for value in np.nditer(arr)) 61 | datasize = dump(data_encoding, directory/f'{prefix}.bin') 62 | 63 | # Dump codebook (huffman tree) 64 | codebook_encoding = 
encode_huffman_tree(root, dtype) 65 | treesize = dump(codebook_encoding, directory/f'{prefix}_codebook.bin') 66 | 67 | return treesize, datasize 68 | 69 | 70 | def huffman_decode(directory, prefix, dtype): 71 | """ 72 | Decodes binary files from directory 73 | """ 74 | directory = Path(directory) 75 | 76 | # Read the codebook 77 | codebook_encoding = load(directory/f'{prefix}_codebook.bin') 78 | root = decode_huffman_tree(codebook_encoding, dtype) 79 | 80 | # Read the data 81 | data_encoding = load(directory/f'{prefix}.bin') 82 | 83 | # Decode 84 | data = [] 85 | ptr = root 86 | for bit in data_encoding: 87 | ptr = ptr.left if bit == '0' else ptr.right 88 | if ptr.value is not None: # Leaf node 89 | data.append(ptr.value) 90 | ptr = root 91 | 92 | return np.array(data, dtype=dtype) 93 | 94 | 95 | # Logics to encode / decode huffman tree 96 | # Referenced the idea from https://stackoverflow.com/questions/759707/efficient-way-of-storing-huffman-tree 97 | def encode_huffman_tree(root, dtype): 98 | """ 99 | Encodes a huffman tree to string of '0's and '1's 100 | """ 101 | converter = {'float32':float2bitstr, 'int32':int2bitstr} 102 | code_list = [] 103 | def encode_node(node): 104 | if node.value is not None: # node is leaf node 105 | code_list.append('1') 106 | lst = list(converter[dtype](node.value)) 107 | code_list.extend(lst) 108 | else: 109 | code_list.append('0') 110 | encode_node(node.left) 111 | encode_node(node.right) 112 | encode_node(root) 113 | return ''.join(code_list) 114 | 115 | 116 | def decode_huffman_tree(code_str, dtype): 117 | """ 118 | Decodes a string of '0's and '1's and costructs a huffman tree 119 | """ 120 | converter = {'float32':bitstr2float, 'int32':bitstr2int} 121 | idx = 0 122 | def decode_node(): 123 | nonlocal idx 124 | info = code_str[idx] 125 | idx += 1 126 | if info == '1': # Leaf node 127 | value = converter[dtype](code_str[idx:idx+32]) 128 | idx += 32 129 | return Node(0, value, None, None) 130 | else: 131 | left = decode_node() 
132 | right = decode_node() 133 | return Node(0, None, left, right) 134 | 135 | return decode_node() 136 | 137 | 138 | 139 | # My own dump / load logics 140 | def dump(code_str, filename): 141 | """ 142 | code_str : string of either '0' and '1' characters 143 | this function dumps to a file 144 | returns how many bytes are written 145 | """ 146 | # Make header (1 byte) and add padding to the end 147 | # Files need to be byte aligned. 148 | # Therefore we add 1 byte as a header which indicates how many bits are padded to the end 149 | # This introduces minimum of 8 bits, maximum of 15 bits overhead 150 | num_of_padding = -len(code_str) % 8 151 | header = f"{num_of_padding:08b}" 152 | code_str = header + code_str + '0' * num_of_padding 153 | 154 | # Convert string to integers and to real bytes 155 | byte_arr = bytearray(int(code_str[i:i+8], 2) for i in range(0, len(code_str), 8)) 156 | 157 | # Dump to a file 158 | with open(filename, 'wb') as f: 159 | f.write(byte_arr) 160 | return len(byte_arr) 161 | 162 | 163 | def load(filename): 164 | """ 165 | This function reads a file and makes a string of '0's and '1's 166 | """ 167 | with open(filename, 'rb') as f: 168 | header = f.read(1) 169 | rest = f.read() # bytes 170 | code_str = ''.join(f'{byte:08b}' for byte in rest) 171 | offset = ord(header) 172 | if offset != 0: 173 | code_str = code_str[:-offset] # string of '0's and '1's 174 | return code_str 175 | 176 | 177 | # Helper functions for converting between bit string and (float or int) 178 | def float2bitstr(f): 179 | four_bytes = struct.pack('>f', f) # bytes 180 | return ''.join(f'{byte:08b}' for byte in four_bytes) # string of '0's and '1's 181 | 182 | def bitstr2float(bitstr): 183 | byte_arr = bytearray(int(bitstr[i:i+8], 2) for i in range(0, len(bitstr), 8)) 184 | return struct.unpack('>f', byte_arr)[0] 185 | 186 | def int2bitstr(integer): 187 | four_bytes = struct.pack('>I', integer) # bytes 188 | return ''.join(f'{byte:08b}' for byte in four_bytes) # string of 
'0's and '1's 189 | 190 | def bitstr2int(bitstr): 191 | byte_arr = bytearray(int(bitstr[i:i+8], 2) for i in range(0, len(bitstr), 8)) 192 | return struct.unpack('>I', byte_arr)[0] 193 | 194 | 195 | # Functions for calculating / reconstructing index diff 196 | def calc_index_diff(indptr): 197 | return indptr[1:] - indptr[:-1] 198 | 199 | def reconstruct_indptr(diff): 200 | return np.concatenate([[0], np.cumsum(diff)]) 201 | 202 | 203 | # Encode / Decode models 204 | def huffman_encode_model(model, directory='encodings/'): 205 | os.makedirs(directory, exist_ok=True) 206 | original_total = 0 207 | compressed_total = 0 208 | print(f"{'Layer':<15} | {'original':>10} {'compressed':>10} {'improvement':>11} {'percent':>7}") 209 | print('-'*70) 210 | for name, param in model.named_parameters(): 211 | if 'mask' in name: 212 | continue 213 | if 'weight' in name: 214 | weight = param.data.cpu().numpy() 215 | shape = weight.shape 216 | form = 'csr' if shape[0] < shape[1] else 'csc' 217 | mat = csr_matrix(weight) if shape[0] < shape[1] else csc_matrix(weight) 218 | 219 | # Encode 220 | t0, d0 = huffman_encode(mat.data, name+f'_{form}_data', directory) 221 | t1, d1 = huffman_encode(mat.indices, name+f'_{form}_indices', directory) 222 | t2, d2 = huffman_encode(calc_index_diff(mat.indptr), name+f'_{form}_indptr', directory) 223 | 224 | # Print statistics 225 | original = mat.data.nbytes + mat.indices.nbytes + mat.indptr.nbytes 226 | compressed = t0 + t1 + t2 + d0 + d1 + d2 227 | 228 | print(f"{name:<15} | {original:10} {compressed:10} {original / compressed:>10.2f}x {100 * compressed / original:>6.2f}%") 229 | else: # bias 230 | # Note that we do not huffman encode bias 231 | bias = param.data.cpu().numpy() 232 | bias.dump(f'{directory}/{name}') 233 | 234 | # Print statistics 235 | original = bias.nbytes 236 | compressed = original 237 | 238 | print(f"{name:<15} | {original:10} {compressed:10} {original / compressed:>10.2f}x {100 * compressed / original:>6.2f}%") 239 | 
original_total += original 240 | compressed_total += compressed 241 | 242 | print('-'*70) 243 | print(f"{'total':15} | {original_total:>10} {compressed_total:>10} {original_total / compressed_total:>10.2f}x {100 * compressed_total / original_total:>6.2f}%") 244 | 245 | 246 | def huffman_decode_model(model, directory='encodings/'): 247 | for name, param in model.named_parameters(): 248 | if 'mask' in name: 249 | continue 250 | if 'weight' in name: 251 | dev = param.device 252 | weight = param.data.cpu().numpy() 253 | shape = weight.shape 254 | form = 'csr' if shape[0] < shape[1] else 'csc' 255 | matrix = csr_matrix if shape[0] < shape[1] else csc_matrix 256 | 257 | # Decode data 258 | data = huffman_decode(directory, name+f'_{form}_data', dtype='float32') 259 | indices = huffman_decode(directory, name+f'_{form}_indices', dtype='int32') 260 | indptr = reconstruct_indptr(huffman_decode(directory, name+f'_{form}_indptr', dtype='int32')) 261 | 262 | # Construct matrix 263 | mat = matrix((data, indices, indptr), shape) 264 | 265 | # Insert to model 266 | param.data = torch.from_numpy(mat.toarray()).to(dev) 267 | else: 268 | dev = param.device 269 | bias = np.load(directory+'/'+name) 270 | param.data = torch.from_numpy(bias).to(dev) 271 | --------------------------------------------------------------------------------