├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── loaders.cpython-37.pyc │ └── maskedLayers.cpython-37.pyc ├── loaders.py ├── avg_speed_calc.py ├── maskedLayers.py └── plots.ipynb ├── .gitignore ├── knowledge_distillation ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── student.cpython-37.pyc ├── student.py ├── train_student.py └── distillation.py ├── models ├── exp1.pth ├── exp2.pth ├── exp3.pth ├── best_acc.pth ├── best_acc_v1.4.pth ├── best_acc_student.pth ├── best_acc_student_v1.4.pth ├── best_acc_student_with_distillation.pth └── best_acc_student_with_distillation_v1.4.pth ├── README.md ├── imgs ├── pruningResults.png └── Architecture-of-LeNet-5.png ├── metrics ├── __pycache__ │ ├── flops.cpython-37.pyc │ ├── memory.cpython-37.pyc │ ├── size.cpython-37.pyc │ ├── utils.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── accuracy.cpython-37.pyc │ ├── modules.cpython-37.pyc │ ├── maskedLayers.cpython-37.pyc │ ├── abstract_flops.cpython-37.pyc │ └── global_sparsity.cpython-37.pyc ├── global_sparsity.py ├── __init__.py ├── size.py ├── flops.py ├── abstract_flops.py └── utils.py ├── .vscode └── settings.json ├── .pre-commit-config.yaml ├── .github └── workflows │ └── ci.yml ├── experiments.py ├── lenet_pytorch.py ├── pruning_loop.py └── Results.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Articles/ 2 | .vscode/ -------------------------------------------------------------------------------- /knowledge_distillation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/exp1.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/exp1.pth -------------------------------------------------------------------------------- /models/exp2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/exp2.pth -------------------------------------------------------------------------------- /models/exp3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/exp3.pth -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pruningExperiments 2 | Repository to perform simple pruning experiments on neural networks 3 | -------------------------------------------------------------------------------- /models/best_acc.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc.pth -------------------------------------------------------------------------------- /imgs/pruningResults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/imgs/pruningResults.png -------------------------------------------------------------------------------- /models/best_acc_v1.4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_v1.4.pth -------------------------------------------------------------------------------- /models/best_acc_student.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_student.pth -------------------------------------------------------------------------------- /imgs/Architecture-of-LeNet-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/imgs/Architecture-of-LeNet-5.png -------------------------------------------------------------------------------- /models/best_acc_student_v1.4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_student_v1.4.pth -------------------------------------------------------------------------------- /metrics/__pycache__/flops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/flops.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/memory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/memory.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/size.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/size.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loaders.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/utils/__pycache__/loaders.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/accuracy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/accuracy.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/modules.cpython-37.pyc -------------------------------------------------------------------------------- /models/best_acc_student_with_distillation.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_student_with_distillation.pth -------------------------------------------------------------------------------- /utils/__pycache__/maskedLayers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/utils/__pycache__/maskedLayers.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/maskedLayers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/maskedLayers.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/abstract_flops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/abstract_flops.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/global_sparsity.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/global_sparsity.cpython-37.pyc -------------------------------------------------------------------------------- /models/best_acc_student_with_distillation_v1.4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_student_with_distillation_v1.4.pth -------------------------------------------------------------------------------- /knowledge_distillation/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/knowledge_distillation/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /knowledge_distillation/__pycache__/student.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/knowledge_distillation/__pycache__/student.cpython-37.pyc -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/userlocal/miniconda3/envs/torch/bin/python", 3 | "python.linting.enabled": true, 4 | "python.linting.flake8Enabled": true, 5 | "python.linting.pycodestyleEnabled": false 6 | } 7 | -------------------------------------------------------------------------------- /metrics/global_sparsity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calculate_global_sparsity(model): 5 | global_sparsity = ( 6 | 100.0 7 | * float( 8 | torch.sum(model.conv1.weight == 0) 9 | + torch.sum(model.conv2.weight == 0) 10 | + torch.sum(model.fc1.weight == 0) 11 | + torch.sum(model.fc2.weight == 0) 12 | ) 13 | / float( 14 | model.conv1.weight.nelement() 15 | + model.conv2.weight.nelement() 16 | + model.fc1.weight.nelement() 17 | + model.fc2.weight.nelement() 18 | ) 19 | ) 20 | 21 | global_compression = 100 / (100 - global_sparsity) 22 | 23 | return global_sparsity, global_compression 24 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def nonzero(tensor): 6 | """Returns 
absolute number of values different from 0 7 | 8 | Arguments: 9 | tensor {numpy.ndarray} -- Array to compute over 10 | 11 | Returns: 12 | int -- Number of nonzero elements 13 | """ 14 | return np.sum(tensor != 0.0) 15 | 16 | 17 | # https://pytorch.org/docs/stable/tensor_attributes.html 18 | dtype2bits = { 19 | torch.float32: 32, 20 | torch.float: 32, 21 | torch.float64: 64, 22 | torch.double: 64, 23 | torch.float16: 16, 24 | torch.half: 16, 25 | torch.uint8: 8, 26 | torch.int8: 8, 27 | torch.int16: 16, 28 | torch.short: 16, 29 | torch.int32: 32, 30 | torch.int: 32, 31 | torch.int64: 64, 32 | torch.long: 64, 33 | torch.bool: 1, 34 | } 35 | 36 | 37 | from .flops import flops 38 | from .size import model_size 39 | -------------------------------------------------------------------------------- /metrics/size.py: -------------------------------------------------------------------------------- 1 | """Model size metrics 2 | """ 3 | 4 | import numpy as np 5 | 6 | from . import dtype2bits, nonzero 7 | 8 | 9 | def model_size(model, as_bits=False): 10 | """Returns absolute and nonzero model size 11 | 12 | Arguments: 13 | model {torch.nn.Module} -- Network to compute model size over 14 | 15 | Keyword Arguments: 16 | as_bits {bool} -- Whether to account for the size of dtype 17 | 18 | Returns: 19 | int -- Total number of weight & bias params 20 | int -- Out total_params exactly how many are nonzero 21 | """ 22 | 23 | total_params = 0 24 | nonzero_params = 0 25 | for tensor in model.parameters(): 26 | t = np.prod(tensor.shape) 27 | nz = nonzero(tensor.detach().cpu().numpy()) 28 | if as_bits: 29 | bits = dtype2bits[tensor.dtype] 30 | t *= bits 31 | nz *= bits 32 | total_params += t 33 | nonzero_params += nz 34 | return int(total_params), int(nonzero_params) 35 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # I took this configuration 
from https://github.com/ternaus/iglovikov_helper_functions/blob/master/.pre-commit-config.yaml 2 | 3 | exclude: _pb2\.py$ 4 | repos: 5 | - repo: https://github.com/pre-commit/mirrors-isort 6 | rev: f0001b2 # Use the revision sha / tag you want to point at 7 | hooks: 8 | - id: isort 9 | args: ["--profile", "black"] 10 | - repo: https://github.com/psf/black 11 | rev: 20.8b1 12 | hooks: 13 | - id: black 14 | - repo: https://github.com/asottile/yesqa 15 | rev: v1.1.0 16 | hooks: 17 | - id: yesqa 18 | additional_dependencies: 19 | - flake8-bugbear==20.1.4 20 | - flake8-builtins==1.5.2 21 | - flake8-comprehensions==3.2.2 22 | - flake8-tidy-imports==4.1.0 23 | - flake8==3.7.9 24 | - repo: https://github.com/pre-commit/pre-commit-hooks 25 | rev: v2.3.0 26 | hooks: 27 | - id: check-docstring-first 28 | - id: check-merge-conflict 29 | - id: check-yaml 30 | - id: debug-statements 31 | - id: end-of-file-fixer 32 | - id: trailing-whitespace 33 | -------------------------------------------------------------------------------- /knowledge_distillation/student.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 9 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 10 | self.conv2_drop = nn.Dropout2d() 11 | self.fc1 = nn.Linear(320, 50) 12 | self.fc2 = nn.Linear(50, 10) 13 | 14 | def forward(self, x): 15 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 16 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 17 | x = x.view(-1, 320) 18 | x = F.relu(self.fc1(x)) 19 | x = F.dropout(x, training=self.training) 20 | x = self.fc2(x) 21 | return x 22 | 23 | 24 | class LeNetStudent(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | # kernel_size = 2 was too strong, so 4 28 | self.conv1 = nn.Conv2d(1, 3, kernel_size=4) 29 | # Model with batchnorm was too 
strong reaching 0.973 accuracy 30 | # so I switched it off 31 | # self.bn = nn.BatchNorm2d(6) 32 | self.flatten = nn.Flatten() 33 | self.fc2 = nn.Linear(1875, 10) 34 | 35 | def forward(self, x): 36 | x = self.conv1(x) 37 | # x = self.bn(x) 38 | # x = F.relu(x) 39 | # x = F.max_pool2d(x, 2) 40 | x = self.flatten(x) 41 | x = self.fc2(x) 42 | return x 43 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: [3.7] 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Cache pip 24 | uses: actions/cache@v1 25 | with: 26 | path: ~/.cache/pip # This path is specific to Ubuntu 27 | # Look to see if there is a cache hit for the corresponding requirements file 28 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 29 | restore-keys: | 30 | ${{ runner.os }}-pip- 31 | ${{ runner.os }}- 32 | # You can test your matrix by printing the current Python version 33 | - name: Display Python version 34 | run: python -c "import sys; print(sys.version)" 35 | - name: Install dependencies 36 | run: | 37 | python -m pip install --upgrade pip 38 | pip install black flake8 isort pylint 39 | - name: Run black 40 | run: 41 | black --check . 
42 | -------------------------------------------------------------------------------- /utils/loaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | 5 | def get_loaders(batch_size_train, batch_size_test): 6 | """Function to return train and test datasets for MNIST 7 | 8 | :param batch_size_train: Batch size used for train 9 | :param batch_size_test: Batch size used for test 10 | 11 | :return: Data loaders for train and test data 12 | """ 13 | 14 | train_loader = torch.utils.data.DataLoader( 15 | torchvision.datasets.MNIST( 16 | "~/.cache/database/", 17 | train=True, 18 | download=True, 19 | transform=torchvision.transforms.Compose( 20 | [ 21 | torchvision.transforms.ToTensor(), 22 | torchvision.transforms.Normalize((0.1307,), (0.3081,)), 23 | ] 24 | ), 25 | ), 26 | batch_size=batch_size_train, 27 | shuffle=True, 28 | ) 29 | 30 | test_loader = torch.utils.data.DataLoader( 31 | torchvision.datasets.MNIST( 32 | "~/.cache/database/", 33 | train=False, 34 | download=True, 35 | transform=torchvision.transforms.Compose( 36 | [ 37 | torchvision.transforms.ToTensor(), 38 | torchvision.transforms.Normalize((0.1307,), (0.3081,)), 39 | ] 40 | ), 41 | ), 42 | batch_size=batch_size_test, 43 | shuffle=False, 44 | ) 45 | 46 | return train_loader, test_loader 47 | -------------------------------------------------------------------------------- /utils/avg_speed_calc.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from loaders import get_loaders 8 | 9 | batch_size_train = 2048 10 | batch_size_test = 2048 11 | 12 | 13 | class LeNet(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 17 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 18 | self.conv2_drop = nn.Dropout2d() 19 | self.fc1 = nn.Linear(320, 
50) 20 | self.fc2 = nn.Linear(50, 10) 21 | 22 | def forward(self, x): 23 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 24 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 25 | x = x.view(-1, 320) 26 | x = F.relu(self.fc1(x)) 27 | x = F.dropout(x, training=self.training) 28 | x = self.fc2(x) 29 | return F.log_softmax(x, 1) 30 | 31 | 32 | def go_through_data(net, data_loader, device): 33 | 34 | net.eval() 35 | with torch.no_grad(): 36 | for (idx, (x, t)) in enumerate(data_loader): 37 | x = net.forward(x.to(device)) 38 | t = t.to(device) 39 | return 1 40 | 41 | 42 | device = "cuda" 43 | train_loader, _ = get_loaders(batch_size_train, batch_size_test) 44 | net = LeNet().to(device) 45 | net.load_state_dict(torch.load("models/exp1.pth")) 46 | 47 | t0 = time.time() 48 | for i in range(20): 49 | go_through_data(net, train_loader, device) 50 | 51 | total_time = time.time() - t0 52 | print(total_time, total_time / (i + 1)) 53 | # 46.204389810562134 9.240877962112426 54 | -------------------------------------------------------------------------------- /experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch.nn.utils.prune as prune 3 | 4 | from pruning_loop import PruningExperiment 5 | 6 | experiment_number = 3 7 | 8 | # Experiment 1: Random weights pruning 9 | # Change amount = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] 10 | amount = 0.7 11 | kwargs = {} 12 | pruning_strategy_1 = [ 13 | ("fc1", prune.random_unstructured, "weight", amount, kwargs), 14 | ("fc2", prune.random_unstructured, "weight", amount, kwargs), 15 | ("conv1", prune.random_unstructured, "weight", amount, kwargs), 16 | ("conv2", prune.random_unstructured, "weight", amount, kwargs), 17 | ] 18 | 19 | # Experiment 2: Pruning based on norm 20 | amount = 0.7 21 | kwargs = {} 22 | pruning_strategy_2 = [ 23 | ("fc1", prune.l1_unstructured, "weight", amount, kwargs), 24 | ("fc2", prune.l1_unstructured, "weight", amount, kwargs), 
25 | ("conv1", prune.l1_unstructured, "weight", amount, kwargs), 26 | ("conv2", prune.l1_unstructured, "weight", amount, kwargs), 27 | ] 28 | 29 | # Experiment 3: Structural pruning with L1 norm 30 | amount = 0.7 31 | kwargs = {"n": 1, "dim": 0} 32 | 33 | pruning_strategy_3 = [ 34 | ("fc1", prune.ln_structured, "weight", amount, kwargs), 35 | ("conv1", prune.ln_structured, "weight", amount, kwargs), 36 | ("conv2", prune.ln_structured, "weight", amount, kwargs), 37 | ] 38 | 39 | 40 | if experiment_number == 1: 41 | pe = PruningExperiment( 42 | pruning_strategy=pruning_strategy_1, 43 | epochs_prune_finetune=3, 44 | epochs_finetune=4, 45 | save_model="exp1", 46 | ) 47 | 48 | if experiment_number == 2: 49 | pe = PruningExperiment( 50 | pruning_strategy=pruning_strategy_2, 51 | epochs_prune_finetune=3, 52 | epochs_finetune=4, 53 | save_model="exp2", 54 | ) 55 | 56 | if experiment_number == 3: 57 | pe = PruningExperiment( 58 | pruning_strategy=pruning_strategy_3, 59 | epochs_prune_finetune=3, 60 | epochs_finetune=4, 61 | save_model="exp3", 62 | ) 63 | 64 | 65 | print(pe.run()) 66 | -------------------------------------------------------------------------------- /knowledge_distillation/train_student.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from student import LeNetStudent 8 | 9 | # To make relative imports work 10 | # See https://stackoverflow.com/questions/16981921/relative-imports-in-python-3 11 | 12 | PACKAGE_PARENT = ".." 
13 | SCRIPT_DIR = os.path.dirname( 14 | os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))) 15 | ) 16 | sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))) 17 | 18 | from utils.loaders import get_loaders 19 | 20 | 21 | def train(net, loss_fn, optimizer, data_loader, device): 22 | net.train() 23 | for (idx, (x, t)) in enumerate(data_loader): 24 | optimizer.zero_grad() 25 | x = net.forward(x.to(device)) 26 | t = t.to(device) 27 | loss = loss_fn(x, t) 28 | loss.backward() 29 | optimizer.step() 30 | 31 | 32 | def test(net, data_loader, device): 33 | top1 = 0 34 | correct_samples = 0 35 | total_samples = 0 36 | net.eval() 37 | with torch.no_grad(): 38 | for (idx, (x, t)) in enumerate(data_loader): 39 | x = net.forward(x.to(device)) 40 | t = t.to(device) 41 | _, indices = torch.max(x, 1) 42 | correct_samples += torch.sum(indices == t) 43 | total_samples += t.shape[0] 44 | 45 | top1 = float(correct_samples) / total_samples 46 | return top1 47 | 48 | 49 | if __name__ == "__main__": 50 | 51 | device = "cuda" 52 | 53 | net = LeNetStudent().to(device) 54 | 55 | batch_size_train = 512 56 | batch_size_test = 1024 57 | nb_epoch = 60 58 | 59 | train_loader, test_loader = get_loaders(batch_size_train, batch_size_test) 60 | 61 | loss_fn = nn.CrossEntropyLoss() 62 | optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5) 63 | 64 | best_model = None 65 | best_acc = 0 66 | 67 | for epoch in range(nb_epoch): 68 | train(net, loss_fn, optimizer, train_loader, device) 69 | test_top1 = test(net, test_loader, device) 70 | print(f"Epoch {epoch}. 
Top1 {test_top1:.4f}") 71 | if test_top1 > best_acc: 72 | best_model = net 73 | best_acc = test_top1 74 | 75 | torch.save(best_model.state_dict(), "models/best_acc_student.pth") 76 | -------------------------------------------------------------------------------- /metrics/flops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | 4 | from utils.maskedLayers import Conv2dMasked, LinearMasked 5 | 6 | from . import nonzero 7 | from .abstract_flops import conv2d_flops, dense_flops 8 | from .utils import get_activations 9 | 10 | 11 | def _conv2d_flops(module, activation): 12 | # Auxiliary func to use abstract flop computation 13 | 14 | # Drop batch & channels. Channels can be dropped since 15 | # unlike shape they have to match to in_channels 16 | input_shape = activation.shape[2:] 17 | # TODO Add support for dilation and padding size 18 | return conv2d_flops( 19 | in_channels=module.in_channels, 20 | out_channels=module.out_channels, 21 | input_shape=input_shape, 22 | kernel_shape=module.kernel_size, 23 | padding=module.padding_mode, 24 | strides=module.stride, 25 | dilation=module.dilation, 26 | ) 27 | 28 | 29 | def _linear_flops(module, activation): 30 | # Auxiliary func to use abstract flop computation 31 | return dense_flops(module.in_features, module.out_features) 32 | 33 | 34 | def flops(model, input): 35 | """Compute Multiply-add FLOPs estimate from model 36 | 37 | Arguments: 38 | model {torch.nn.Module} -- Module to compute flops for 39 | input {torch.Tensor} -- Input tensor needed for activations 40 | 41 | Returns: 42 | tuple: 43 | - int - Number of total FLOPs 44 | - int - Number of FLOPs related to nonzero parameters 45 | """ 46 | FLOP_fn = { 47 | nn.Conv2d: _conv2d_flops, 48 | nn.Linear: _linear_flops, 49 | Conv2dMasked: _conv2d_flops, 50 | LinearMasked: _linear_flops, 51 | } 52 | 53 | total_flops = nonzero_flops = 0 54 | activations = get_activations(model, input) 55 | 56 | 
# The ones we need for backprop 57 | for m, (act, _) in activations.items(): 58 | if m.__class__ in FLOP_fn: 59 | w = m.weight.detach().cpu().numpy().copy() 60 | module_flops = FLOP_fn[m.__class__](m, act) 61 | total_flops += module_flops 62 | # For our operations, all weights are symmetric so we can just 63 | # do simple rule of three for the estimation 64 | nonzero_flops += module_flops * nonzero(w).sum() / np.prod(w.shape) 65 | 66 | return total_flops, nonzero_flops 67 | -------------------------------------------------------------------------------- /utils/maskedLayers.py: -------------------------------------------------------------------------------- 1 | # Taken from here 2 | # https://github.com/wanglouis49/pytorch-weights_pruning/blob/master/pruning/layers.py 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class LinearMasked(nn.Linear): 10 | def __init__(self, in_features, out_features, bias=True): 11 | super(LinearMasked, self).__init__(in_features, out_features, bias) 12 | self.mask_flag = False 13 | 14 | def set_mask(self, mask): 15 | self.mask = Variable(mask, requires_grad=False, volatile=False) 16 | self.weight.data = self.weight.data * self.mask.data 17 | self.mask_flag = True 18 | 19 | def get_mask(self): 20 | print(self.mask_flag) 21 | return self.mask 22 | 23 | def forward(self, x): 24 | if self.mask_flag: 25 | weight = self.weight * self.mask 26 | return F.linear(x, weight, self.bias) 27 | else: 28 | return F.linear(x, self.weight, self.bias) 29 | 30 | 31 | class Conv2dMasked(nn.Conv2d): 32 | def __init__( 33 | self, 34 | in_channels, 35 | out_channels, 36 | kernel_size, 37 | stride=1, 38 | padding=0, 39 | dilation=1, 40 | groups=1, 41 | bias=True, 42 | ): 43 | super(Conv2dMasked, self).__init__( 44 | in_channels, 45 | out_channels, 46 | kernel_size, 47 | stride, 48 | padding, 49 | dilation, 50 | groups, 51 | bias, 52 | ) 53 | self.mask_flag = False 54 | 55 | def set_mask(self, 
mask): 56 | self.mask = Variable(mask, requires_grad=False, volatile=False) 57 | self.weight.data = self.weight.data * self.mask.data 58 | self.mask_flag = True 59 | 60 | def get_mask(self): 61 | print(self.mask_flag) 62 | return self.mask 63 | 64 | def forward(self, x): 65 | if self.mask_flag: 66 | weight = self.weight * self.mask 67 | return F.conv2d( 68 | x, 69 | weight, 70 | self.bias, 71 | self.stride, 72 | self.padding, 73 | self.dilation, 74 | self.groups, 75 | ) 76 | else: 77 | return F.conv2d( 78 | x, 79 | self.weight, 80 | self.bias, 81 | self.stride, 82 | self.padding, 83 | self.dilation, 84 | self.groups, 85 | ) 86 | -------------------------------------------------------------------------------- /lenet_pytorch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | import torchvision 8 | 9 | 10 | class LeNet(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 14 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 15 | self.conv2_drop = nn.Dropout2d() 16 | self.fc1 = nn.Linear(320, 50) 17 | self.fc2 = nn.Linear(50, 10) 18 | 19 | def forward(self, x): 20 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 21 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 22 | x = x.view(-1, 320) 23 | x = F.relu(self.fc1(x)) 24 | x = F.dropout(x, training=self.training) 25 | x = self.fc2(x) 26 | return F.log_softmax(x, 1) 27 | 28 | 29 | def train(net, optimizer, data_loader, device): 30 | net.train() 31 | for (idx, (x, t)) in enumerate(data_loader): 32 | optimizer.zero_grad() 33 | x = net.forward(x.to(device)) 34 | t = t.to(device) 35 | loss = F.nll_loss(x, t) 36 | loss.backward() 37 | optimizer.step() 38 | 39 | 40 | def test(net, data_loader, device): 41 | top1 = 0 # TODO compute top1 42 | correct_samples = 0 43 | total_samples = 0 44 | 
net.eval() 45 | with torch.no_grad(): 46 | for (idx, (x, t)) in enumerate(data_loader): 47 | x = net.forward(x.to(device)) 48 | t = t.to(device) 49 | _, indices = torch.max(x, 1) 50 | correct_samples += torch.sum(indices == t) 51 | total_samples += t.shape[0] 52 | 53 | top1 = float(correct_samples) / total_samples 54 | return top1 55 | 56 | 57 | if __name__ == "__main__": 58 | nb_epoch = 80 59 | batch_size_train = 1024 60 | batch_size_test = 5120 61 | device = "cuda" # change to 'cpu' if needed 62 | 63 | best_model = None 64 | best_acc = 0 65 | 66 | train_loader = torch.utils.data.DataLoader( 67 | torchvision.datasets.MNIST( 68 | "~/.cache/database/", 69 | train=True, 70 | download=True, 71 | transform=torchvision.transforms.Compose( 72 | [ 73 | torchvision.transforms.ToTensor(), 74 | torchvision.transforms.Normalize((0.1307,), (0.3081,)), 75 | ] 76 | ), 77 | ), 78 | batch_size=batch_size_train, 79 | shuffle=True, 80 | ) 81 | 82 | test_loader = torch.utils.data.DataLoader( 83 | torchvision.datasets.MNIST( 84 | "~/.cache/database/", 85 | train=False, 86 | download=True, 87 | transform=torchvision.transforms.Compose( 88 | [ 89 | torchvision.transforms.ToTensor(), 90 | torchvision.transforms.Normalize((0.1307,), (0.3081,)), 91 | ] 92 | ), 93 | ), 94 | batch_size=batch_size_test, 95 | shuffle=False, 96 | ) 97 | net = LeNet().to(device) 98 | optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5) 99 | 100 | for epoch in range(nb_epoch): 101 | train(net, optimizer, train_loader, device) 102 | test_top1 = test(net, test_loader, device) 103 | print(f"Epoch {epoch}. 
Top1 {test_top1:.4f}") 104 | if test_top1 > best_acc: 105 | best_model = net 106 | best_acc = test_top1 107 | 108 | torch.save(best_model.state_dict(), "models/best_acc.pth") 109 | -------------------------------------------------------------------------------- /metrics/abstract_flops.py: -------------------------------------------------------------------------------- 1 | """Module for computing FLOPs from specification, not from Torch objects 2 | """ 3 | import numpy as np 4 | 5 | 6 | def dense_flops(in_neurons, out_neurons): 7 | """Compute the number of multiply-adds used by a Dense (Linear) layer""" 8 | return in_neurons * out_neurons 9 | 10 | 11 | def conv2d_flops( 12 | in_channels, 13 | out_channels, 14 | input_shape, 15 | kernel_shape, 16 | padding="same", 17 | strides=1, 18 | dilation=1, 19 | ): 20 | """Compute the number of multiply-adds used by a Conv2D layer 21 | Args: 22 | in_channels (int): The number of channels in the layer's input 23 | out_channels (int): The number of channels in the layer's output 24 | input_shape (int, int): The spatial shape of the rank-3 input tensor 25 | kernel_shape (int, int): The spatial shape of the rank-4 kernel 26 | padding ({'same', 'valid'}): The padding used by the convolution 27 | strides (int) or (int, int): The spatial stride of the convolution; 28 | two numbers may be specified if it's different for the x and y axes 29 | dilation (int): Must be 1 for now. 
30 | Returns: 31 | int: The number of multiply-adds a direct convolution would require 32 | (i.e., no FFT, no Winograd, etc) 33 | >>> c_in, c_out = 10, 10 34 | >>> in_shape = (4, 5) 35 | >>> filt_shape = (3, 2) 36 | >>> # valid padding 37 | >>> ret = conv2d_flops(c_in, c_out, in_shape, filt_shape, padding='valid') 38 | >>> ret == int(c_in * c_out * np.prod(filt_shape) * (2 * 4)) 39 | True 40 | >>> # same padding, no stride 41 | >>> ret = conv2d_flops(c_in, c_out, in_shape, filt_shape, padding='same') 42 | >>> ret == int(c_in * c_out * np.prod(filt_shape) * np.prod(in_shape)) 43 | True 44 | >>> # valid padding, stride > 1 45 | >>> ret = conv2d_flops(c_in, c_out, in_shape, filt_shape, \ 46 | padding='valid', strides=(1, 2)) 47 | >>> ret == int(c_in * c_out * np.prod(filt_shape) * (2 * 2)) 48 | True 49 | >>> # same padding, stride > 1 50 | >>> ret = conv2d_flops(c_in, c_out, in_shape, filt_shape, \ 51 | padding='same', strides=2) 52 | >>> ret == int(c_in * c_out * np.prod(filt_shape) * (2 * 3)) 53 | True 54 | """ 55 | # validate + sanitize input 56 | assert in_channels > 0 57 | assert out_channels > 0 58 | assert len(input_shape) == 2 59 | assert len(kernel_shape) == 2 60 | padding = padding.lower() 61 | assert padding in ( 62 | "same", 63 | "valid", 64 | "zeros", 65 | ), "Padding must be one of same|valid|zeros" 66 | try: 67 | strides = tuple(strides) 68 | except TypeError: 69 | # if one number provided, make it a 2-tuple 70 | strides = (strides, strides) 71 | assert dilation == 1 or all( 72 | d == 1 for d in dilation 73 | ), "Dilation > 1 is not supported" 74 | 75 | # compute output spatial shape 76 | # based on TF computations https://stackoverflow.com/a/37674568 77 | if padding in ["same", "zeros"]: 78 | out_nrows = np.ceil(float(input_shape[0]) / strides[0]) 79 | out_ncols = np.ceil(float(input_shape[1]) / strides[1]) 80 | else: # padding == 'valid' 81 | out_nrows = np.ceil((input_shape[0] - kernel_shape[0] + 1) / strides[0]) # noqa 82 | out_ncols = 
np.ceil((input_shape[1] - kernel_shape[1] + 1) / strides[1]) # noqa 83 | output_shape = (int(out_nrows), int(out_ncols)) 84 | 85 | # work to compute one output spatial position 86 | nflops = in_channels * out_channels * int(np.prod(kernel_shape)) 87 | 88 | # total work = work per output position * number of output positions 89 | return nflops * int(np.prod(output_shape)) 90 | 91 | 92 | if __name__ == "__main__": 93 | import doctest 94 | 95 | doctest.testmod() 96 | -------------------------------------------------------------------------------- /knowledge_distillation/distillation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from student import LeNet, LeNetStudent 8 | 9 | PACKAGE_PARENT = ".." 10 | SCRIPT_DIR = os.path.dirname( 11 | os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))) 12 | ) 13 | sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))) 14 | 15 | from metrics import flops, model_size 16 | from utils.loaders import get_loaders 17 | 18 | 19 | def train(teacher, student, loss_fn, optimizer, data_loader, device): 20 | teacher.train(False) 21 | student.train() 22 | for (idx, (x, t)) in enumerate(data_loader): 23 | optimizer.zero_grad() 24 | x = x.to(device) 25 | x_teacher = teacher.forward(x) 26 | x_student = student.forward(x) 27 | t = t.to(device) 28 | loss = loss_fn(x_student, x_teacher, t) 29 | loss.backward() 30 | optimizer.step() 31 | 32 | 33 | def test(net, data_loader, device): 34 | top1 = 0 35 | correct_samples = 0 36 | total_samples = 0 37 | net.train(False) 38 | net.eval() 39 | with torch.no_grad(): 40 | for (idx, (x, t)) in enumerate(data_loader): 41 | x = net.forward(x.to(device)) 42 | t = t.to(device) 43 | _, indices = torch.max(x, 1) 44 | correct_samples += torch.sum(indices == t) 45 | total_samples += t.shape[0] 46 | 47 | top1 = 
float(correct_samples) / total_samples 48 | return top1 49 | 50 | 51 | def calculate_prune_metrics(net, test_loader, device): 52 | 53 | x, _ = next(iter(test_loader)) 54 | x = x.to(device) 55 | 56 | size, size_nz = model_size(net) 57 | 58 | FLOPS = flops(net, x) 59 | return FLOPS, size, size_nz 60 | 61 | 62 | def cross_entropy_with_soft_targets(pred, soft_targets): 63 | logsoftmax = nn.LogSoftmax(dim=1) 64 | return torch.mean(torch.sum(-soft_targets * logsoftmax(pred), 1)) 65 | 66 | 67 | class CrossEntropyLossTemperature_withSoftTargets(torch.nn.Module): 68 | def __init__(self, temperature, reduction="mean"): 69 | super(CrossEntropyLossTemperature_withSoftTargets, self).__init__() 70 | self.T = temperature 71 | self.reduction = reduction 72 | 73 | def forward(self, input, soft_targets, hard_targets): 74 | """ 75 | In the forward function we accept a Tensor of input data and we must 76 | return a Tensor of output data. We can use Modules defined in the 77 | constructor as well as arbitrary operators on Tensors. 
78 | """ 79 | z = input / self.T 80 | loss_1 = self.T ** 2 * cross_entropy_with_soft_targets( 81 | z, F.softmax(soft_targets, 1, _stacklevel=5) 82 | ) 83 | 84 | loss_2 = F.cross_entropy( 85 | input, 86 | hard_targets, 87 | weight=None, 88 | ignore_index=-100, 89 | reduction=self.reduction, 90 | ) 91 | 92 | return loss_1 + loss_2 93 | 94 | 95 | if __name__ == "__main__": 96 | 97 | device = "cuda" 98 | teacher = LeNet().to(device) 99 | student = LeNetStudent().to(device) 100 | 101 | teacher.load_state_dict(torch.load("models/best_acc.pth")) 102 | student.load_state_dict(torch.load("models/best_acc_student.pth")) 103 | 104 | batch_size_train = 1024 105 | batch_size_test = 1024 106 | nb_epoch = 40 107 | 108 | train_loader, test_loader = get_loaders(batch_size_train, batch_size_test) 109 | 110 | print(calculate_prune_metrics(teacher, test_loader, device)) 111 | print(calculate_prune_metrics(student, test_loader, device)) 112 | 113 | loss_fn = CrossEntropyLossTemperature_withSoftTargets(1) 114 | # optimizer = torch.optim.Adam(student.parameters(), lr=0.0001, weight_decay=0.00001) 115 | optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.2) 116 | 117 | best_model = None 118 | best_acc = 0 119 | 120 | teacher.train(False) 121 | for epoch in range(nb_epoch): 122 | train(teacher, student, loss_fn, optimizer, train_loader, device) 123 | test_top1 = test(student, test_loader, device) 124 | print(f"Epoch {epoch}. 
Top1 {test_top1:.4f}") 125 | if test_top1 > best_acc: 126 | best_model = student 127 | best_acc = test_top1 128 | 129 | torch.save(best_model.state_dict(), "models/best_acc_student_with_distillation.pth") 130 | -------------------------------------------------------------------------------- /metrics/utils.py: -------------------------------------------------------------------------------- 1 | """Auxiliary utils for implementing pruning strategies 2 | """ 3 | 4 | from collections import OrderedDict, defaultdict 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | def hook_applyfn(hook, model, forward=False, backward=False): 11 | """ 12 | 13 | [description] 14 | 15 | Arguments: 16 | hook {[type]} -- [description] 17 | model {[type]} -- [description] 18 | 19 | Keyword Arguments: 20 | forward {bool} -- [description] (default: {False}) 21 | backward {bool} -- [description] (default: {False}) 22 | 23 | Returns: 24 | [type] -- [description] 25 | """ 26 | assert forward ^ backward, "Either forward or backward must be True" 27 | hooks = [] 28 | 29 | def register_hook(module): 30 | if ( 31 | not isinstance(module, nn.Sequential) 32 | and not isinstance(module, nn.ModuleList) 33 | and not isinstance(module, nn.ModuleDict) 34 | and not (module == model) 35 | ): 36 | if forward: 37 | hooks.append(module.register_forward_hook(hook)) 38 | if backward: 39 | hooks.append(module.register_backward_hook(hook)) 40 | 41 | return register_hook, hooks 42 | 43 | 44 | def get_params(model, recurse=False): 45 | """Returns dictionary of paramters 46 | 47 | Arguments: 48 | model {torch.nn.Module} -- Network to extract the parameters from 49 | 50 | Keyword Arguments: 51 | recurse {bool} -- Whether to recurse through children modules 52 | 53 | Returns: 54 | Dict(str:numpy.ndarray) -- Dictionary of named parameters their 55 | associated parameter arrays 56 | """ 57 | params = { 58 | k: v.detach().cpu().numpy().copy() 59 | for k, v in model.named_parameters(recurse=recurse) 60 | } 61 | 
return params 62 | 63 | 64 | def get_activations(model, input): 65 | 66 | activations = OrderedDict() 67 | 68 | def store_activations(module, input, output): 69 | if isinstance(module, nn.ReLU): 70 | # TODO ResNet18 implementation reuses a 71 | # single ReLU layer? 72 | return 73 | assert module not in activations, f"{module} already in activations" 74 | # TODO [0] means first input, not all models have a single input 75 | activations[module] = ( 76 | input[0].detach().cpu().numpy().copy(), 77 | output.detach().cpu().numpy().copy(), 78 | ) 79 | 80 | fn, hooks = hook_applyfn(store_activations, model, forward=True) 81 | model.apply(fn) 82 | with torch.no_grad(): 83 | model(input) 84 | 85 | for h in hooks: 86 | h.remove() 87 | 88 | return activations 89 | 90 | 91 | def get_gradients(model, inputs, outputs): 92 | # TODO implement using model.register_backward_hook() 93 | # So it is harder than it seems, the grad_input contains also the gradients 94 | # with respect to the weights and so far order seems to be 95 | # (bias, input, weight) which is confusing. 
96 | # Moreover, a lot of the time the output activation we are 97 | # looking for is the one after the ReLU and F.ReLU (or any functional call) 98 | # will not be called by the forward or backward hook 99 | # Discussion here 100 | # https://discuss.pytorch.org/t/how-to-register-hook-function-for-functional-form/25775 101 | # Best way seems to be monkey patching F.ReLU & other functional ops 102 | # That'll also help figuring out how to compute a module graph 103 | pass 104 | 105 | 106 | def get_param_gradients(model, inputs, outputs, loss_func=None, by_module=True): 107 | 108 | gradients = OrderedDict() 109 | 110 | if loss_func is None: 111 | loss_func = nn.CrossEntropyLoss() 112 | 113 | training = model.training 114 | model.train() 115 | pred = model(inputs) 116 | loss = loss_func(pred, outputs) 117 | loss.backward() 118 | 119 | if by_module: 120 | gradients = defaultdict(OrderedDict) 121 | for module in model.modules(): 122 | assert module not in gradients 123 | for name, param in module.named_parameters(recurse=False): 124 | if param.requires_grad and param.grad is not None: 125 | gradients[module][name] = param.grad.detach().cpu().numpy().copy() 126 | 127 | else: 128 | gradients = OrderedDict() 129 | for name, param in model.named_parameters(): 130 | assert name not in gradients 131 | if param.requires_grad and param.grad is not None: 132 | gradients[name] = param.grad.detach().cpu().numpy().copy() 133 | 134 | model.zero_grad() 135 | model.train(training) 136 | 137 | return gradients 138 | 139 | 140 | def fraction_to_keep(compression, model, prunable_modules): 141 | """Return fraction of params to keep to achieve desired compression ratio 142 | 143 | Compression = total / ( fraction * prunable + (total-prunable)) 144 | Using algrebra fraction is equal to 145 | fraction = total/prunable * (1/compression - 1) + 1 146 | 147 | Arguments: 148 | compression {float} -- Desired overall compression 149 | model {torch.nn.Module} -- Full model for which to compute the 
fraction 150 | prunable_modules {List(torch.nn.Module)} -- 151 | Modules that can be pruned in the model. 152 | 153 | Returns: 154 | {float} -- Fraction of prunable parameters to keep 155 | to achieve desired compression 156 | """ 157 | from ..metrics import model_size 158 | 159 | total_size, _ = model_size(model) 160 | prunable_size = sum([model_size(m)[0] for m in prunable_modules]) 161 | nonprunable_size = total_size - prunable_size 162 | fraction = 1 / prunable_size * (total_size / compression - nonprunable_size) 163 | assert 0 < fraction <= 1, ( 164 | f"Cannot compress to {1/compression} model\ 165 | with {nonprunable_size/total_size}" 166 | + "fraction of unprunable parameters" 167 | ) 168 | return fraction 169 | -------------------------------------------------------------------------------- /pruning_loop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.utils.prune as prune 5 | import torch.optim as optim 6 | 7 | from utils.loaders import get_loaders 8 | from utils.maskedLayers import Conv2dMasked, LinearMasked 9 | from metrics import flops, model_size 10 | 11 | 12 | class LeNet(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | self.conv1 = Conv2dMasked(1, 10, kernel_size=5) 16 | self.conv2 = Conv2dMasked(10, 20, kernel_size=5) 17 | self.conv2_drop = nn.Dropout2d() 18 | self.fc1 = LinearMasked(320, 50) 19 | self.fc2 = LinearMasked(50, 10) 20 | 21 | def forward(self, x): 22 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 23 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 24 | x = x.view(-1, 320) 25 | x = F.relu(self.fc1(x)) 26 | x = F.dropout(x, training=self.training) 27 | x = self.fc2(x) 28 | return F.log_softmax(x, 1) 29 | 30 | 31 | class PruningExperiment: 32 | def __init__( 33 | self, 34 | pruning_strategy=None, 35 | batch_size_train=512, 36 | batch_size_test=1024, 37 | epochs_prune_finetune=3, 
38 | epochs_finetune=5, 39 | device="cuda", 40 | save_model=None, 41 | ): 42 | 43 | """Initialize experiment 44 | :param pruning_strategy: Pruning strategy 45 | :param batch_size_train: Batch size for train 46 | :param batch_size_test: Batch size for test 47 | :param epochs_finetune: Number of epochs to finetune pruned model 48 | :param optimizer: Optimizer to perform gradient descent 49 | :param device: Device 'cpu' or 'cuda' where calculations are performed 50 | 51 | :return: Outcome of pruning strategy: Accuracy and pruning metrics 52 | """ 53 | self.pruning_strategy = pruning_strategy 54 | self.device = device 55 | self.batch_size_train = batch_size_train 56 | self.batch_size_test = batch_size_test 57 | self.epochs_prune_finetune = epochs_prune_finetune 58 | self.epochs_finetune = epochs_finetune 59 | self.save_model = save_model 60 | 61 | def load_model(self): 62 | 63 | """Load LeNet model. 64 | All experiments will be performed on a trained model 65 | from the original script. 66 | """ 67 | net = LeNet().to(self.device) 68 | net.load_state_dict(torch.load("models/best_acc.pth")) 69 | return net 70 | 71 | def prune_model(self, net, pruning_strategy): 72 | 73 | for modulename, strategy, name, amount, kwargs in pruning_strategy: 74 | 75 | module = getattr(net, modulename) 76 | if kwargs: 77 | n = kwargs["n"] 78 | dim = kwargs["dim"] 79 | mask = strategy(module, name=name, amount=amount, n=n, dim=dim) 80 | else: 81 | mask = strategy(module, name=name, amount=amount, n=n, dim=dim) 82 | # print(modulename, mask.weight_mask.shape) 83 | module.set_mask(mask.weight_mask) 84 | 85 | return net 86 | 87 | def train(self, net, optimizer, data_loader, device): 88 | 89 | net.train() 90 | for (idx, (x, t)) in enumerate(data_loader): 91 | optimizer.zero_grad() 92 | x = net.forward(x.to(device)) 93 | t = t.to(device) 94 | loss = F.nll_loss(x, t) 95 | loss.backward() 96 | optimizer.step() 97 | 98 | def test(self, net, data_loader, device): 99 | 100 | correct_samples = 0 101 | 
total_samples = 0 102 | net.eval() 103 | with torch.no_grad(): 104 | for (idx, (x, t)) in enumerate(data_loader): 105 | x = net.forward(x.to(device)) 106 | t = t.to(device) 107 | _, indices = torch.max(x, 1) 108 | correct_samples += torch.sum(indices == t) 109 | total_samples += t.shape[0] 110 | top1 = float(correct_samples) / total_samples 111 | return top1 112 | 113 | def calculate_prune_metrics(self, net, test_loader, device): 114 | 115 | x, _ = next(iter(test_loader)) 116 | x = x.to(device) 117 | 118 | size, size_nz = model_size(net) 119 | 120 | FLOPS = flops(net, x) 121 | compression_ratio = size / size_nz 122 | 123 | return FLOPS, compression_ratio 124 | 125 | def run(self): 126 | """ 127 | Main function to run pruning -> finetuning -> evaluation 128 | """ 129 | pruning_strategy = self.pruning_strategy 130 | batch_size_train = self.batch_size_train 131 | batch_size_test = self.batch_size_test 132 | epochs_prune_finetune = self.epochs_prune_finetune 133 | epochs_finetune = self.epochs_finetune 134 | device = self.device 135 | 136 | net = self.load_model() 137 | 138 | optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5) 139 | 140 | train_loader, test_loader = get_loaders(batch_size_train, batch_size_test) 141 | 142 | print("Pruning cycle") 143 | print("=======================") 144 | for epoch in range(epochs_prune_finetune): 145 | if pruning_strategy is not None: 146 | net = self.prune_model(net, pruning_strategy) 147 | test_top1 = self.test(net, test_loader, device) 148 | print(f"After pruning: Epoch {epoch}. 
Top1 {test_top1}.") 149 | for finetune_epoch in range(epochs_finetune): 150 | self.train(net, optimizer, train_loader, device) 151 | test_top1 = self.test(net, test_loader, device) 152 | print( 153 | f"\tAfter finetuning: Epoch {finetune_epoch}.\ 154 | Top1 {test_top1:.4f}" 155 | ) 156 | 157 | for modulename, strategy, name, amount, kwargs in pruning_strategy: 158 | module = getattr(net, modulename) 159 | prune.remove(module, name) 160 | test_top1 = self.test(net, test_loader, device) 161 | 162 | FLOPS, compression_ratio = self.calculate_prune_metrics( 163 | net, test_loader, device 164 | ) 165 | 166 | if self.save_model is not None: 167 | torch.save(net.state_dict(), f"models/{self.save_model}.pth") 168 | 169 | return test_top1, FLOPS[0] / FLOPS[1], compression_ratio 170 | -------------------------------------------------------------------------------- /Results.md: -------------------------------------------------------------------------------- 1 | # Experiments in Neural Network pruning 2 | 3 | ### Prepared by Oleg Polivin, 26 November 2020 4 | --- 5 | 6 | Let's define metrics that we will use to evaluate the effectiveness of pruning. We will look at categorical accuracy to estimate the quality of a neural network.[1](#myfootnote1) Accuracy in the experiments is reported based on the test set, not the one that has been used for training the neural network. 7 | 8 | 9 | Much of this work is based on the paper [What is the State of Neural Network Pruning?](https://arxiv.org/abs/2003.03033) 10 | 11 | In order to estimate the effectiveness at pruning we will take into account: 12 | 13 | 1. Acceleration of inference on the test set. 14 | - Compare the number of multiply-adds operations (FLOPs) to perform inference. 15 | - Additionally, I compute average time of running the original/pruned model on data. 16 | 17 | 2. Model size reduction/ weights compression. 18 | - Here we will compare total number of non-zero parameters. 
19 | 20 | ## Experiment setting 21 | 22 | Given the code for the LeNet model in PyTorch, let's calculate the metrics defined above. The canonical LeNet-5 architecture is shown below: 23 | 24 | ![Imagine a LeNet-5 architecture](imgs/Architecture-of-LeNet-5.png "LeNet-5 architecture") 25 | 26 | The architecture in the original paper [Gradient-Based Learning Applied to Document Recognition](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf) is a bit different from the code given (for example, there was no ``Dropout``, the ``hyperbolic tangent`` was used as the activation, the number of filters is different, etc.), but the idea is the same. I will organize the experiments as follows: 27 | 28 | 1. Train the model using the original script (``lenet_pytorch.py``, although I made some modifications there). 29 | 2. Evaluate the model using the metrics defined above. 30 | 3. Save the trained model. 31 | 32 | I will perform the pruning experiments on the saved model. 33 | 34 | 4. In order to perform pruning experiments I added: 35 | - ``metrics/`` 36 | - ``experiments.py`` (this is the main script that produces results). 37 | - ``pruning_loop.py`` implements the experiment. 38 | - ``utils`` 39 | 40 | - ``avg_speed_calc.py`` to calculate average inference time on train data 41 | - ``loaders.py`` to create train/test loaders 42 | - ``maskedLayers.py`` wrappers for Linear and Conv2d PyTorch modules. 43 | - ``plots.ipynb`` Jupyter notebook to produce the plots below. 44 | 45 | ## Pruning setup 46 | 47 | As suggested in the [What is the State of Neural Network Pruning?](https://arxiv.org/abs/2003.03033) paper, many pruning methods are described by the following algorithm: 48 | 49 | 1. A neural network (NN) is trained until convergence (80 epochs here; see the Baseline below). 50 | 2. ``` 51 | for i in 1 to K do 52 | prune NN 53 | finetune NN 54 | end for 55 | ``` 56 | It means that the neural network is pruned several times. In my version, a weight once set to zero always stays zero.
The weights that were pruned are not retrained. Note also that each finetuning step consists of several epochs of training. 57 | 58 | To fix the pruning setup, in all the experiments the number of prune-finetune cycles is equal to 3 (it is ``K`` above), and the number of finetuning epochs is equal to 4. The categorical accuracy, the model's speed-up and its compression are reported after pruning-finetuning is finished. 59 | 60 | ## Results 61 | 62 | ### Baseline 63 | 64 | The LeNet model as defined in the code was trained for ``80`` epochs, and the best model chosen by categorical accuracy was saved. The highest categorical accuracy was reached on epoch ``78`` and equals ``0.9809``. Our objective is to stop when the model converges, so that we are sure to prune a converged model. There are ``932500`` multiply-add operations (FLOPs), and over 20 runs through the train data (``60000`` samples), the average time is ``9.1961866`` seconds. 65 | 66 | ### Experiments 67 | 68 | #### Experiment 1: Unstructured pruning of random weights 69 | 70 | **Setting:** Prune the fully-connected layers (``fc1``, ``fc2``) and both convolutional layers (``conv1``, ``conv2``). Increase pruning from 10% to 70% (step = 10%). The pruning percentage is given for each layer. Roughly, this corresponds to compressing the model up to 36 times. 71 | 72 | #### Experiment 2: Unstructured pruning of the smallest weights (based on the L1 norm) 73 | 74 | **Setting:** Same as in experiment 1, except that pruning is no longer random: here I assign 0's to the smallest weights. 75 | 76 | #### Experiment 3: Structured pruning (based on the L1 norm) 77 | 78 | **Setting:** Here I use structured pruning. In PyTorch one can use ``prune.ln_structured`` for that. It is possible to pass a dimension (``dim``) to specify which channel should be dropped. For fully-connected layers such as ``fc1`` or ``fc2``, ``dim=0`` corresponds to "switching off" output neurons (``50`` for ``fc1`` and ``10`` for ``fc2``).
Therefore, it does not really make sense to switch off neurons in the classification layer ``fc2``. For convolutional layers like ``conv1`` or ``conv2``, ``dim=0`` corresponds to removing the output channels of the layers (``10`` for ``conv1`` and ``20`` for ``conv2``). That's why I will only prune the ``fc1``, ``conv1`` and ``conv2`` layers, again going from pruning 10% of each layer's channels up to 70%. For instance, for the fully-connected layer ``fc1`` it means zeroing 5 up to 35 neurons out of 50. For the ``conv1`` layer it means zeroing out all the connections corresponding to 1 up to 7 channels. 79 | 80 | Below I present the results of my pruning experiments: 81 | 82 | ![Pruning results were here](imgs/pruningResults.png "Pruning Results") 83 | 84 | I also confirm that, judging by the average time of running a model during inference, there is no real difference in speed between pruned and non-pruned models. 85 | 86 | ## Conclusions and caveats 87 | 88 | Here are my thoughts on the results above and some caveats. 89 | 90 | If we take the results at face value, we conclude that better results are obtained when we do unstructured pruning of the smallest weights based on the L1 norm. In reality, however (more on that below), unstructured pruning makes weights sparse, but since sparse operations are not yet fully supported in PyTorch, it does not bring real gains in terms of model size or speed of inference. Still, we can think of such results as some evidence that a smaller architecture with fewer weights might be beneficial. 91 | 92 | Below are further caveats: 93 | 94 | ### Unstructured pruning 95 | 1. We are looking at FLOPs to estimate the speed-up of a pruned neural network, and at the number of non-zero parameters to estimate compression. This gives the impression that pruning brings a significant speed-up and memory gain. 96 | 97 | 2. However, people report that when looking at the actual time it takes to make a prediction, there is no speed-up.
I tested it with the model before pruning and after pruning (random weights), and this is true: there is no speed-up in terms of the average time of running inference. Also, the saved PyTorch models (``.pth``) have the same size. 98 | 99 | 3. Additionally, there is no saving in memory, because all those zero elements still have to be stored. 100 | 101 | 4. To my understanding, one needs to change the architecture of the neural network according to the zeroed weights in order to really gain in speed and memory. 102 | 103 | 5. A different approach is to use sparse matrices and sparse operations in PyTorch, but this functionality is in beta. See the discussion here: [How to improve inference time of pruned model using torch.nn.utils.prune](https://discuss.pytorch.org/t/how-to-improve-inference-time-of-pruned-model-using-torch-nn-utils-prune/78633/4) 104 | 105 | 6. So, if we do unstructured pruning and want to make use of sparse operations, we will have to write inference code that takes sparse matrices into account. Here is an example of a paper whose authors obtained large speed-ups, but only once they introduced operations with sparse matrices on an FPGA: [How Can We Be So Dense? The Benefits of Using Highly Sparse Representations](https://arxiv.org/abs/1903.11257) 106 | 107 | What's said above is mostly relevant to unstructured pruning of weights. 108 | 109 | ### Structured pruning 110 | 111 | One can get speed-ups when using structured pruning, for example, by dropping some channels. The price for that is a drop in accuracy, but at least this really yields a smaller model and faster inference. 112 | 113 | ## Additional chapter: Knowledge distillation 114 | 115 | Knowledge distillation is an idea proposed by Geoffrey Hinton, Oriol Vinyals and Jeff Dean to transfer knowledge from a huge trained model to a simple and lightweight one.
Strictly speaking, it is not pruning, but it has the same objective: simplifying the original neural network without sacrificing much quality. 116 | 117 | It works the following way: 118 | 119 | - Train a comprehensive large network which has good accuracy [Teacher Network] 120 | - Train a small network until convergence [Student Network]. There will be a trade-off between the accuracy that you reach with a simpler model and the level of compression. 121 | - Distill the knowledge from the Teacher Network by training the Student Network using the outputs of the Teacher Network. 122 | - See that the original accuracy of the trained and converged student network is increased! 123 | 124 | I provide the code to do it in the ``knowledge_distillation`` folder. Run 125 | 126 | ``` 127 | python knowledge_distillation/train_student.py 128 | ``` 129 | to train the student network. It has a simplified architecture relative to the original ``LeNet`` neural network. For example, when the trained student network is saved, it takes 1.16 times less space on disk (from 90 kB to 77 kB, and even two times less if saved with ``PyTorch v1.4``). I ran training for 60 epochs; the best accuracy was reached on epoch 47 and equals ``0.9260``. Thus we can say that the model has converged. 130 | 131 | Run 132 | 133 | ``` 134 | python knowledge_distillation/distillation.py 135 | ``` 136 | to further train the converged student network by distilling the teacher network. 137 | 138 | Here are the results: 139 | - ``FLOPS`` compression coefficient is 42 (the student model is 42 times smaller in terms of FLOPS, down to 21840 multiply-add operations from 932500). 140 | - ``Model size`` compression coefficient is 3 (the student model is 3 times smaller in terms of size) 141 | - ``Accuracy`` of the retrained student model is ``0.9276``, which is a tiny bit better than the original student network.
142 | 143 | I would say that knowledge distillation is definitely worth a try as a method to perform model compression. 144 | 145 | ## Bibliography with comments 146 | 147 | 1. The code to calculate FLOPs is taken from the [ShrinkBench repo](https://github.com/JJGO/shrinkbench) written by the authors of the [What is the State of Neural Network Pruning?](https://arxiv.org/abs/2003.03033) paper. The authors are Davis Blalock, Jose Javier Gonzalez Ortiz, Jonathan Frankle and John Guttag. They created this code to allow researchers to compare pruning algorithms: that is, to compare compression rates, speed-ups and quality of the model after pruning, among other things. I reuse their way of measuring ``FLOPs`` and ``model size``; this code is located in the ``metrics`` folder. Note that I made some minor modifications to the code, and all errors remain mine and should not be attributed to the authors' code. I also take the logic of evaluating pruned models from this paper. All in all, this is the main source of inspiration for my research. 148 | 149 | 2. The next important source is this [Neural Network Pruning PyTorch Implementation](https://github.com/wanglouis49/pytorch-weights_pruning) by Luyu Wang and Gavin Ding. I copy their code for implementing the high-level idea of doing pruning: 150 | - Write wrappers for PyTorch Linear and Conv2d layers. 151 | - A binary mask is multiplied by the actual layer weights. 152 | - "Multiplying the mask is a differentiable operation and the backward pass is handed by automatic differentiation" 153 | 154 | 3. Next, I make use of the [PyTorch Pruning Tutorial](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html). It is different from the implementations above. My implementation mixes the code of the above two implementations with the PyTorch way. 155 | 156 | Sources on knowledge distillation: 157 | 158 | 4. [Dark knowledge](https://www.ttic.edu/dl/dark14.pdf) 159 | 160 | 5.
[Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531) 161 | 162 | 6. The Open Data Science community (``ods.ai``), where brilliant people share their ideas on many aspects of data science, is a great source of inspiration for me. 163 | 164 | ## Footnotes 165 | 1: Indeed, at the extreme we can just predict a constant. Accuracy will be low, but pruning will be maximally effective: there will be no parameters at all in the neural network. 166 | -------------------------------------------------------------------------------- /utils/plots.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 2, 4 | "metadata": { 5 | "language_info": { 6 | "name": "python", 7 | "codemirror_mode": { 8 | "name": "ipython", 9 | "version": 3 10 | }, 11 | "version": "3.7.9-final" 12 | }, 13 | "orig_nbformat": 2, 14 | "file_extension": ".py", 15 | "mimetype": "text/x-python", 16 | "name": "python", 17 | "npconvert_exporter": "python", 18 | "pygments_lexer": "ipython3", 19 | "version": 3, 20 | "kernelspec": { 21 | "name": "python37964bitpytorch17conda9210703cbb104b038c93c3861334f328", 22 | "display_name": "Python 3.7.9 64-bit ('pytorch17': conda)" 23 | } 24 | }, 25 | "cells": [ 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from matplotlib import pyplot as plt\n", 33 | "%matplotlib inline" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "experiments_results = {\n", 43 | " \"UnstructRandomPrun\":\n", 44 | " {\n", 45 | " \"accuracy\": [0.9809, 0.9744, 0.9665, 0.9524, 0.9194, 0.8529, 0.4458, 0.1365],\n", 46 | " \"compression_flops\": [1, 1.3706, 1.9531, 2.9137, 4.6296, 8.0134, 15.625, 36.75],\n", 47 | " \"compression_size\": [1, 1.3695, 1.9455, 2.8927, 4.5614, 7.7750, 14.7368, 32.2124]\n", 48 | " },\n", 49 | " \"UnstructPrunL1Norm\":\n", 50 | "
{\n", 51 | " \"accuracy\": [0.9809, 0.9824, 0.9818, 0.9781, 0.9698, 0.9543, 0.8555, 0.7565],\n", 52 | " \"compression_flops\": [1, 1.3706, 1.9531, 2.9137, 4.6296, 8.0134, 15.625, 36.75],\n", 53 | " \"compression_size\": [1, 1.3695, 1.9455, 2.8927, 4.5614, 7.7750, 14.7368, 32.2124]\n", 54 | " },\n", 55 | " \"StructuredPrunL1Norm\":\n", 56 | " {\n", 57 | " \"accuracy\": [0.9809, 0.9781, 0.9667, 0.9477, 0.9228, 0.753, 0.5095, 0.1135],\n", 58 | " \"compression_flops\": [1, 1.4268, 1.9976, 2.94, 4.98, 7.15, 16.34, 25.33],\n", 59 | " \"compression_size\": [1, 1.3561, 1.8934, 2.78, 4.23, 6.06, 11.97, 18.83]\n", 60 | " }\n", 61 | "}" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "data": { 71 | "image/png": "\n", 72 | "image/svg+xml": "\n\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", 73 | "text/plain": "
" 74 | }, 75 | "metadata": { 76 | "needs_background": "light" 77 | }, 78 | "output_type": "display_data" 79 | } 80 | ], 81 | "source": [ 82 | "plt.figure(figsize=(15,5))\n", 83 | "plt.subplot(1,2,1)\n", 84 | "marker = ['o', 's', 'd']\n", 85 | "for i, (key, results) in enumerate(experiments_results.items()):\n", 86 | " compression_flops = results[\"compression_flops\"]\n", 87 | " accuracy = results[\"accuracy\"]\n", 88 | " lbl = key\n", 89 | " plt.plot(compression_flops, accuracy, marker=marker[i], label = lbl)\n", 90 | "plt.xlabel('Compression (FLOPS)')\n", 91 | "plt.ylabel('Accuracy')\n", 92 | "plt.title('LeNet-5 on MNIST: Accuracy vs theoretical speedup')\n", 93 | "plt.grid(True)\n", 94 | "plt.legend()\n", 95 | "\n", 96 | "plt.subplot(1,2,2)\n", 97 | "\n", 98 | "for i, (key, results) in enumerate(experiments_results.items()):\n", 99 | " compression_size = results[\"compression_size\"]\n", 100 | " accuracy = results[\"accuracy\"]\n", 101 | " lbl = key\n", 102 | " plt.plot(compression_size, accuracy, marker=marker[i], label = lbl)\n", 103 | "plt.xlabel('Compression (Model size)')\n", 104 | "plt.ylabel('Accuracy')\n", 105 | "plt.title('LeNet-5 on MNIST: Accuracy vs model compression')\n", 106 | "plt.grid(True)\n", 107 | "plt.legend()\n", 108 | "\n", 109 | "plt.show()" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [] 118 | } 119 | ] 120 | } 121 | --------------------------------------------------------------------------------