├── utils
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ ├── loaders.cpython-37.pyc
│ └── maskedLayers.cpython-37.pyc
├── loaders.py
├── avg_speed_calc.py
├── maskedLayers.py
└── plots.ipynb
├── .gitignore
├── knowledge_distillation
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ └── student.cpython-37.pyc
├── student.py
├── train_student.py
└── distillation.py
├── models
├── exp1.pth
├── exp2.pth
├── exp3.pth
├── best_acc.pth
├── best_acc_v1.4.pth
├── best_acc_student.pth
├── best_acc_student_v1.4.pth
├── best_acc_student_with_distillation.pth
└── best_acc_student_with_distillation_v1.4.pth
├── README.md
├── imgs
├── pruningResults.png
└── Architecture-of-LeNet-5.png
├── metrics
├── __pycache__
│ ├── flops.cpython-37.pyc
│ ├── memory.cpython-37.pyc
│ ├── size.cpython-37.pyc
│ ├── utils.cpython-37.pyc
│ ├── __init__.cpython-37.pyc
│ ├── accuracy.cpython-37.pyc
│ ├── modules.cpython-37.pyc
│ ├── maskedLayers.cpython-37.pyc
│ ├── abstract_flops.cpython-37.pyc
│ └── global_sparsity.cpython-37.pyc
├── global_sparsity.py
├── __init__.py
├── size.py
├── flops.py
├── abstract_flops.py
└── utils.py
├── .vscode
└── settings.json
├── .pre-commit-config.yaml
├── .github
└── workflows
│ └── ci.yml
├── experiments.py
├── lenet_pytorch.py
├── pruning_loop.py
└── Results.md
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | Articles/
2 | .vscode/
--------------------------------------------------------------------------------
/knowledge_distillation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/exp1.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/exp1.pth
--------------------------------------------------------------------------------
/models/exp2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/exp2.pth
--------------------------------------------------------------------------------
/models/exp3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/exp3.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pruningExperiments
2 | Repository to perform simple pruning experiments on neural networks
3 |
--------------------------------------------------------------------------------
/models/best_acc.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc.pth
--------------------------------------------------------------------------------
/imgs/pruningResults.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/imgs/pruningResults.png
--------------------------------------------------------------------------------
/models/best_acc_v1.4.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_v1.4.pth
--------------------------------------------------------------------------------
/models/best_acc_student.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_student.pth
--------------------------------------------------------------------------------
/imgs/Architecture-of-LeNet-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/imgs/Architecture-of-LeNet-5.png
--------------------------------------------------------------------------------
/models/best_acc_student_v1.4.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_student_v1.4.pth
--------------------------------------------------------------------------------
/metrics/__pycache__/flops.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/flops.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/memory.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/memory.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/size.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/size.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/loaders.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/utils/__pycache__/loaders.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/accuracy.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/accuracy.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/modules.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/modules.cpython-37.pyc
--------------------------------------------------------------------------------
/models/best_acc_student_with_distillation.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_student_with_distillation.pth
--------------------------------------------------------------------------------
/utils/__pycache__/maskedLayers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/utils/__pycache__/maskedLayers.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/maskedLayers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/maskedLayers.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/abstract_flops.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/abstract_flops.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/global_sparsity.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/metrics/__pycache__/global_sparsity.cpython-37.pyc
--------------------------------------------------------------------------------
/models/best_acc_student_with_distillation_v1.4.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/models/best_acc_student_with_distillation_v1.4.pth
--------------------------------------------------------------------------------
/knowledge_distillation/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/knowledge_distillation/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/knowledge_distillation/__pycache__/student.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/olegpolivin/pruningExperiments/HEAD/knowledge_distillation/__pycache__/student.cpython-37.pyc
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.pythonPath": "/home/userlocal/miniconda3/envs/torch/bin/python",
3 | "python.linting.enabled": true,
4 | "python.linting.flake8Enabled": true,
5 | "python.linting.pycodestyleEnabled": false
6 | }
7 |
--------------------------------------------------------------------------------
/metrics/global_sparsity.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def calculate_global_sparsity(model):
    """Compute the global weight sparsity and compression ratio of a model.

    Only the ``weight`` tensors of the ``conv1``, ``conv2``, ``fc1`` and
    ``fc2`` sub-modules are inspected; biases are not counted.

    Arguments:
        model -- Object exposing ``conv1``, ``conv2``, ``fc1`` and ``fc2``,
            each with a ``weight`` tensor

    Returns:
        float -- Percentage (0-100) of inspected weights that are exactly 0
        float -- Compression ratio ``100 / (100 - sparsity)``; ``inf`` when
            every inspected weight is zero
    """
    weights = (
        model.conv1.weight,
        model.conv2.weight,
        model.fc1.weight,
        model.fc2.weight,
    )
    zero_weights = sum(int(torch.sum(w == 0)) for w in weights)
    total_weights = sum(w.nelement() for w in weights)

    global_sparsity = 100.0 * float(zero_weights) / float(total_weights)

    # A fully pruned model used to raise ZeroDivisionError here; report an
    # unbounded compression ratio instead.
    if zero_weights == total_weights:
        global_compression = float("inf")
    else:
        global_compression = 100 / (100 - global_sparsity)

    return global_sparsity, global_compression
24 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
def nonzero(tensor):
    """Count the entries of an array that are not equal to zero

    Arguments:
        tensor {numpy.ndarray} -- Array whose entries are inspected

    Returns:
        int -- How many entries differ from 0 (returned as a numpy scalar)
    """
    is_nonzero = tensor != 0.0
    return np.sum(is_nonzero)
15 |
16 |
# Number of bits assigned to one element of each torch dtype.
# Both spellings of a dtype are listed (e.g. torch.float32 / torch.float).
# https://pytorch.org/docs/stable/tensor_attributes.html
dtype2bits = {
    torch.float32: 32,
    torch.float: 32,
    torch.float64: 64,
    torch.double: 64,
    torch.float16: 16,
    torch.half: 16,
    torch.uint8: 8,
    torch.int8: 8,
    torch.int16: 16,
    torch.short: 16,
    torch.int32: 32,
    torch.int: 32,
    torch.int64: 64,
    torch.long: 64,
    # NOTE(review): booleans are counted as a single bit here -- confirm this
    # is the intended accounting for reported sizes.
    torch.bool: 1,
}
35 |
36 |
37 | from .flops import flops
38 | from .size import model_size
39 |
--------------------------------------------------------------------------------
/metrics/size.py:
--------------------------------------------------------------------------------
1 | """Model size metrics
2 | """
3 |
4 | import numpy as np
5 |
6 | from . import dtype2bits, nonzero
7 |
8 |
def model_size(model, as_bits=False):
    """Count all parameters of a model and how many of them are nonzero

    Arguments:
        model {torch.nn.Module} -- Network whose parameters are counted

    Keyword Arguments:
        as_bits {bool} -- If True, multiply every count by the bit width of
            the parameter's dtype (default: {False})

    Returns:
        int -- Total count over everything in model.parameters()
        int -- Out of that total, the count belonging to nonzero values
    """
    total_params, nonzero_params = 0, 0
    for param in model.parameters():
        count = np.prod(param.shape)
        nonzero_count = nonzero(param.detach().cpu().numpy())
        if as_bits:
            width = dtype2bits[param.dtype]
            count = count * width
            nonzero_count = nonzero_count * width
        total_params += count
        nonzero_params += nonzero_count
    return int(total_params), int(nonzero_params)
35 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # I took this configuration from https://github.com/ternaus/iglovikov_helper_functions/blob/master/.pre-commit-config.yaml
2 |
3 | exclude: _pb2\.py$
4 | repos:
5 | - repo: https://github.com/pre-commit/mirrors-isort
6 | rev: f0001b2 # Use the revision sha / tag you want to point at
7 | hooks:
8 | - id: isort
9 | args: ["--profile", "black"]
10 | - repo: https://github.com/psf/black
11 | rev: 20.8b1
12 | hooks:
13 | - id: black
14 | - repo: https://github.com/asottile/yesqa
15 | rev: v1.1.0
16 | hooks:
17 | - id: yesqa
18 | additional_dependencies:
19 | - flake8-bugbear==20.1.4
20 | - flake8-builtins==1.5.2
21 | - flake8-comprehensions==3.2.2
22 | - flake8-tidy-imports==4.1.0
23 | - flake8==3.7.9
24 | - repo: https://github.com/pre-commit/pre-commit-hooks
25 | rev: v2.3.0
26 | hooks:
27 | - id: check-docstring-first
28 | - id: check-merge-conflict
29 | - id: check-yaml
30 | - id: debug-statements
31 | - id: end-of-file-fixer
32 | - id: trailing-whitespace
33 |
--------------------------------------------------------------------------------
/knowledge_distillation/student.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
class LeNet(nn.Module):
    """LeNet-style classifier returning raw class scores (no softmax).

    Takes 1-channel images; the flatten step expects 320 features per
    sample, which is what a 28x28 input produces after both conv/pool stages.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Return a (N, 10) tensor of unnormalised class scores."""
        # Stage 1: convolution -> 2x2 max-pool -> ReLU
        features = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Stage 2: convolution -> channel dropout -> 2x2 max-pool -> ReLU
        features = self.conv2_drop(self.conv2(features))
        features = F.relu(F.max_pool2d(features, 2))
        # Classifier head; dropout is only active in training mode
        flat = features.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)
22 |
23 |
class LeNetStudent(nn.Module):
    """Deliberately weak student network: one convolution plus one linear layer.

    Design notes kept from the experiments:
      * kernel_size=2 was too strong, so the convolution uses kernel_size=4.
      * A BatchNorm variant was too strong as well (0.973 accuracy), so it
        is switched off.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4)
        self.flatten = nn.Flatten()
        # 1875 input features = 3 channels * 25 * 25 for a 28x28 input image
        self.fc2 = nn.Linear(in_features=1875, out_features=10)

    def forward(self, x):
        """Return a (N, 10) tensor of unnormalised class scores."""
        # No normalisation, non-linearity or pooling between the two layers.
        flat_features = self.flatten(self.conv1(x))
        return self.fc2(flat_features)
43 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | # This workflow sets up Python, installs the formatting and lint tools, and checks code formatting with black
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: CI
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 | jobs:
12 | build:
13 | runs-on: ubuntu-latest
14 | strategy:
15 | matrix:
16 | python-version: [3.7]
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 | - name: Cache pip
24 | uses: actions/cache@v1
25 | with:
26 | path: ~/.cache/pip # This path is specific to Ubuntu
27 | # Look to see if there is a cache hit for the corresponding requirements file
28 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
29 | restore-keys: |
30 | ${{ runner.os }}-pip-
31 | ${{ runner.os }}-
32 | # You can test your matrix by printing the current Python version
33 | - name: Display Python version
34 | run: python -c "import sys; print(sys.version)"
35 | - name: Install dependencies
36 | run: |
37 | python -m pip install --upgrade pip
38 | pip install black flake8 isort pylint
39 | - name: Run black
40 | run:
41 | black --check .
42 |
--------------------------------------------------------------------------------
/utils/loaders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 |
def get_loaders(batch_size_train, batch_size_test):
    """Build the MNIST train and test data loaders

    The dataset is downloaded to ``~/.cache/database/`` when missing. Images
    are converted to tensors and normalised with mean 0.1307 and std 0.3081.
    Only the train loader shuffles its samples.

    :param batch_size_train: Batch size of the train loader
    :param batch_size_test: Batch size of the test loader

    :return: Tuple ``(train_loader, test_loader)``
    """

    def mnist_split(train):
        # Each split gets its own (identical) preprocessing pipeline.
        preprocessing = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        return torchvision.datasets.MNIST(
            "~/.cache/database/",
            train=train,
            download=True,
            transform=preprocessing,
        )

    train_loader = torch.utils.data.DataLoader(
        mnist_split(True), batch_size=batch_size_train, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        mnist_split(False), batch_size=batch_size_test, shuffle=False
    )
    return train_loader, test_loader
47 |
--------------------------------------------------------------------------------
/utils/avg_speed_calc.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from loaders import get_loaders
8 |
# Batch sizes handed to get_loaders() for the timing run below.
batch_size_train = 2048
batch_size_test = 2048
11 |
12 |
class LeNet(nn.Module):
    """LeNet-style classifier returning per-class log-probabilities.

    Takes 1-channel images; the flatten step expects 320 features per
    sample, which is what a 28x28 input produces after both conv/pool stages.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Return a (N, 10) tensor of log-softmax scores."""
        # Stage 1: convolution -> 2x2 max-pool -> ReLU
        features = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Stage 2: convolution -> channel dropout -> 2x2 max-pool -> ReLU
        features = self.conv2_drop(self.conv2(features))
        features = F.relu(F.max_pool2d(features, 2))
        # Classifier head; dropout is only active in training mode
        flat = features.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, 1)
30 |
31 |
def go_through_data(net, data_loader, device):
    """Run one inference-only pass of ``net`` over every batch of a loader.

    The outputs are discarded; the function exists so that the caller can
    time a full pass.

    Arguments:
        net -- Module to evaluate; it is switched to eval mode
        data_loader -- Iterable yielding ``(inputs, targets)`` batches
        device -- Device (string or torch.device) the batches are moved to

    Returns:
        int -- Always 1
    """
    net.eval()
    with torch.no_grad():
        for x, t in data_loader:
            # Call the module rather than .forward() so registered hooks run.
            net(x.to(device))
            # Kept from the original loop: the label transfer is part of the
            # work being timed.
            t.to(device)

    # CUDA kernels are launched asynchronously; wait for them so that a
    # caller timing this function measures the completed work.
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()
    return 1
40 |
41 |
# Use the GPU when one is present; otherwise fall back to the CPU so the
# benchmark can still run.
device = "cuda" if torch.cuda.is_available() else "cpu"
train_loader, _ = get_loaders(batch_size_train, batch_size_test)
net = LeNet().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
net.load_state_dict(torch.load("models/exp1.pth", map_location=device))

# Number of full passes over the training set that are timed.
n_passes = 20

# perf_counter is a monotonic clock intended for measuring elapsed time.
t0 = time.perf_counter()
for _ in range(n_passes):
    go_through_data(net, train_loader, device)

total_time = time.perf_counter() - t0
# Prints the total time and the average time per pass.
print(total_time, total_time / n_passes)
# Reference output of an earlier run: 46.204389810562134 9.240877962112426
# (that ratio corresponds to 5 passes, not the 20 timed above)
54 |
--------------------------------------------------------------------------------
/experiments.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import torch.nn.utils.prune as prune
3 |
4 | from pruning_loop import PruningExperiment
5 |
# Which of the experiments defined below is run (1, 2 or 3).
experiment_number = 3

# Every strategy entry is a tuple
#   (layer name, pruning function, parameter name, amount, extra kwargs)
# that is handed to PruningExperiment unchanged.

# Experiment 1: Random weights pruning
# Change amount = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
amount = 0.7
kwargs = {}
pruning_strategy_1 = [
    ("fc1", prune.random_unstructured, "weight", amount, kwargs),
    ("fc2", prune.random_unstructured, "weight", amount, kwargs),
    ("conv1", prune.random_unstructured, "weight", amount, kwargs),
    ("conv2", prune.random_unstructured, "weight", amount, kwargs),
]

# Experiment 2: Pruning based on norm
amount = 0.7
kwargs = {}
pruning_strategy_2 = [
    ("fc1", prune.l1_unstructured, "weight", amount, kwargs),
    ("fc2", prune.l1_unstructured, "weight", amount, kwargs),
    ("conv1", prune.l1_unstructured, "weight", amount, kwargs),
    ("conv2", prune.l1_unstructured, "weight", amount, kwargs),
]

# Experiment 3: Structural pruning with L1 norm
# (fc2 is not part of this strategy)
amount = 0.7
kwargs = {"n": 1, "dim": 0}

pruning_strategy_3 = [
    ("fc1", prune.ln_structured, "weight", amount, kwargs),
    ("conv1", prune.ln_structured, "weight", amount, kwargs),
    ("conv2", prune.ln_structured, "weight", amount, kwargs),
]

# Experiment number -> (pruning strategy, name the model is saved under)
experiments = {
    1: (pruning_strategy_1, "exp1"),
    2: (pruning_strategy_2, "exp2"),
    3: (pruning_strategy_3, "exp3"),
}

# Fail with a clear message instead of a NameError on ``pe`` further down.
if experiment_number not in experiments:
    raise ValueError(
        f"Unknown experiment_number {experiment_number!r}; "
        f"expected one of {sorted(experiments)}"
    )

selected_strategy, selected_name = experiments[experiment_number]
pe = PruningExperiment(
    pruning_strategy=selected_strategy,
    epochs_prune_finetune=3,
    epochs_finetune=4,
    save_model=selected_name,
)


print(pe.run())
66 |
--------------------------------------------------------------------------------
/knowledge_distillation/train_student.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from student import LeNetStudent
8 |
# To make relative imports work
# See https://stackoverflow.com/questions/16981921/relative-imports-in-python-3
#
# Resolve the directory that contains this script and append its parent
# directory to sys.path, so that the ``utils`` package imported right after
# this block can be found.

PACKAGE_PARENT = ".."
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))
)
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
17 |
18 | from utils.loaders import get_loaders
19 |
20 |
def train(net, loss_fn, optimizer, data_loader, device):
    """Train ``net`` for one pass over ``data_loader``.

    Arguments:
        net -- Module to optimise; it is switched to training mode
        loss_fn -- Callable ``loss_fn(outputs, targets)`` returning the loss
        optimizer -- Optimizer that updates the parameters of ``net``
        data_loader -- Iterable yielding ``(inputs, targets)`` batches
        device -- Device every batch is moved to before the forward pass
    """
    net.train()
    for x, t in data_loader:
        optimizer.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        outputs = net(x.to(device))
        loss = loss_fn(outputs, t.to(device))
        loss.backward()
        optimizer.step()
30 |
31 |
32 | def test(net, data_loader, device):
33 | top1 = 0
34 | correct_samples = 0
35 | total_samples = 0
36 | net.eval()
37 | with torch.no_grad():
38 | for (idx, (x, t)) in enumerate(data_loader):
39 | x = net.forward(x.to(device))
40 | t = t.to(device)
41 | _, indices = torch.max(x, 1)
42 | correct_samples += torch.sum(indices == t)
43 | total_samples += t.shape[0]
44 |
45 | top1 = float(correct_samples) / total_samples
46 | return top1
47 |
48 |
if __name__ == "__main__":
    # Train the student network from scratch and save the weights of the
    # epoch with the best test accuracy.
    import copy

    # Use the GPU when one is present; otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    net = LeNetStudent().to(device)

    batch_size_train = 512
    batch_size_test = 1024
    nb_epoch = 60

    train_loader, test_loader = get_loaders(batch_size_train, batch_size_test)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

    # Keep a *copy* of the best weights. Storing a reference to ``net``
    # would track the live model, so the last epoch would be saved instead
    # of the best one. Starting from the initial weights also guarantees
    # there is something to save if no epoch beats an accuracy of 0.
    best_state = copy.deepcopy(net.state_dict())
    best_acc = 0

    for epoch in range(nb_epoch):
        train(net, loss_fn, optimizer, train_loader, device)
        test_top1 = test(net, test_loader, device)
        print(f"Epoch {epoch}. Top1 {test_top1:.4f}")
        if test_top1 > best_acc:
            best_state = copy.deepcopy(net.state_dict())
            best_acc = test_top1

    torch.save(best_state, "models/best_acc_student.pth")
76 |
--------------------------------------------------------------------------------
/metrics/flops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch import nn
3 |
4 | from utils.maskedLayers import Conv2dMasked, LinearMasked
5 |
6 | from . import nonzero
7 | from .abstract_flops import conv2d_flops, dense_flops
8 | from .utils import get_activations
9 |
10 |
def _conv2d_flops(module, activation):
    # Adapter that feeds a conv module's settings to the abstract FLOP count.

    # Only the trailing spatial dims of the activation are needed: the batch
    # dim is dropped, and the channel dim is dropped because it has to match
    # module.in_channels anyway.
    spatial_shape = activation.shape[2:]
    # TODO Add support for dilation and padding size
    # NOTE(review): ``padding_mode`` is forwarded as ``padding``, while the
    # conv2d_flops docstring lists 'same'/'valid' -- confirm it is accepted.
    conv_spec = dict(
        in_channels=module.in_channels,
        out_channels=module.out_channels,
        input_shape=spatial_shape,
        kernel_shape=module.kernel_size,
        padding=module.padding_mode,
        strides=module.stride,
        dilation=module.dilation,
    )
    return conv2d_flops(**conv_spec)
27 |
28 |
def _linear_flops(module, activation):
    # Adapter that feeds a linear module's sizes to the abstract FLOP count.
    # ``activation`` is unused; it is accepted so that every FLOP helper
    # shares the same (module, activation) signature.
    n_inputs = module.in_features
    n_outputs = module.out_features
    return dense_flops(n_inputs, n_outputs)
32 |
33 |
def flops(model, input):
    """Estimate the multiply-add FLOPs of a model for a given input

    Only Conv2d, Linear, Conv2dMasked and LinearMasked modules contribute;
    every other module is ignored.

    Arguments:
        model {torch.nn.Module} -- Module to compute flops for
        input {torch.Tensor} -- Input tensor used to collect activations

    Returns:
        tuple:
        - int - Total number of FLOPs
        - Number of FLOPs attributed to nonzero weights (each module's
          FLOPs scaled by its fraction of nonzero weights)
    """
    FLOP_fn = {
        nn.Conv2d: _conv2d_flops,
        nn.Linear: _linear_flops,
        Conv2dMasked: _conv2d_flops,
        LinearMasked: _linear_flops,
    }

    activations = get_activations(model, input)
    total_flops = 0
    nonzero_flops = 0

    # Each entry maps a module to a pair; only its first element is used.
    for module, (activation, _) in activations.items():
        counter = FLOP_fn.get(module.__class__)
        if counter is None:
            continue
        weights = module.weight.detach().cpu().numpy().copy()
        module_flops = counter(module, activation)
        total_flops += module_flops
        # For our operations, all weights are symmetric, so the nonzero share
        # is estimated with a simple rule of three.
        nonzero_flops += (
            module_flops * nonzero(weights).sum() / np.prod(weights.shape)
        )

    return total_flops, nonzero_flops
67 |
--------------------------------------------------------------------------------
/utils/maskedLayers.py:
--------------------------------------------------------------------------------
1 | # Taken from here
2 | # https://github.com/wanglouis49/pytorch-weights_pruning/blob/master/pruning/layers.py
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 |
8 |
class LinearMasked(nn.Linear):
    """Linear layer whose weights can be multiplied by a fixed mask.

    Until ``set_mask`` is called the layer behaves like ``nn.Linear``.
    The mask is a plain attribute: it is not part of ``state_dict`` and does
    not follow later ``.to(device)`` calls on the module.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(LinearMasked, self).__init__(in_features, out_features, bias)
        self.mask_flag = False
        # Defined up front so get_mask() works before any mask is set.
        self.mask = None

    def set_mask(self, mask):
        """Store ``mask`` and zero the weights it masks out.

        Arguments:
            mask -- Tensor that multiplies ``self.weight`` element-wise
        """
        # A detached tensor replaces the deprecated Variable(..., volatile=)
        # wrapper; the mask is also moved to the device of the weights.
        self.mask = mask.detach().to(self.weight.device)
        self.weight.data = self.weight.data * self.mask
        self.mask_flag = True

    def get_mask(self):
        """Print whether a mask is set and return it (None if unset)."""
        print(self.mask_flag)
        return self.mask

    def forward(self, x):
        """Apply the linear map, using the masked weights when a mask is set."""
        weight = self.weight * self.mask if self.mask_flag else self.weight
        return F.linear(x, weight, self.bias)
29 |
30 |
class Conv2dMasked(nn.Conv2d):
    """Conv2d layer whose weights can be multiplied by a fixed mask.

    Until ``set_mask`` is called the layer behaves like a plain convolution.
    The mask is a plain attribute: it is not part of ``state_dict`` and does
    not follow later ``.to(device)`` calls on the module.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(Conv2dMasked, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.mask_flag = False
        # Defined up front so get_mask() works before any mask is set.
        self.mask = None

    def set_mask(self, mask):
        """Store ``mask`` and zero the weights it masks out.

        Arguments:
            mask -- Tensor that multiplies ``self.weight`` element-wise
        """
        # A detached tensor replaces the deprecated Variable(..., volatile=)
        # wrapper; the mask is also moved to the device of the weights.
        self.mask = mask.detach().to(self.weight.device)
        self.weight.data = self.weight.data * self.mask
        self.mask_flag = True

    def get_mask(self):
        """Print whether a mask is set and return it (None if unset)."""
        print(self.mask_flag)
        return self.mask

    def forward(self, x):
        """Convolve ``x``, using the masked weights when a mask is set."""
        weight = self.weight * self.mask if self.mask_flag else self.weight
        return F.conv2d(
            x,
            weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
86 |
--------------------------------------------------------------------------------
/lenet_pytorch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | import torchvision
8 |
9 |
class LeNet(nn.Module):
    """LeNet-style classifier returning per-class log-probabilities.

    Takes 1-channel images; the flatten step expects 320 features per
    sample, which is what a 28x28 input produces after both conv/pool stages.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Return a (N, 10) tensor of log-softmax scores."""
        # Stage 1: convolution -> 2x2 max-pool -> ReLU
        features = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Stage 2: convolution -> channel dropout -> 2x2 max-pool -> ReLU
        features = self.conv2_drop(self.conv2(features))
        features = F.relu(F.max_pool2d(features, 2))
        # Classifier head; dropout is only active in training mode
        flat = features.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, 1)
27 |
28 |
def train(net, optimizer, data_loader, device):
    """Train ``net`` for one pass over ``data_loader`` with the NLL loss.

    ``net`` is expected to output log-probabilities, since the loss is
    ``F.nll_loss``.

    Arguments:
        net -- Module to optimise; it is switched to training mode
        optimizer -- Optimizer that updates the parameters of ``net``
        data_loader -- Iterable yielding ``(inputs, targets)`` batches
        device -- Device every batch is moved to before the forward pass
    """
    net.train()
    for x, t in data_loader:
        optimizer.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        log_probs = net(x.to(device))
        loss = F.nll_loss(log_probs, t.to(device))
        loss.backward()
        optimizer.step()
38 |
39 |
40 | def test(net, data_loader, device):
41 | top1 = 0 # TODO compute top1
42 | correct_samples = 0
43 | total_samples = 0
44 | net.eval()
45 | with torch.no_grad():
46 | for (idx, (x, t)) in enumerate(data_loader):
47 | x = net.forward(x.to(device))
48 | t = t.to(device)
49 | _, indices = torch.max(x, 1)
50 | correct_samples += torch.sum(indices == t)
51 | total_samples += t.shape[0]
52 |
53 | top1 = float(correct_samples) / total_samples
54 | return top1
55 |
56 |
if __name__ == "__main__":
    # Train LeNet on MNIST and save the weights of the epoch with the best
    # test accuracy.
    import copy

    nb_epoch = 80
    batch_size_train = 1024
    batch_size_test = 5120
    # Use the GPU when one is present; otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Same preprocessing for both splits: tensor conversion + normalisation.
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.cache/database/",
            train=True,
            download=True,
            transform=transform,
        ),
        batch_size=batch_size_train,
        shuffle=True,
    )

    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.cache/database/",
            train=False,
            download=True,
            transform=transform,
        ),
        batch_size=batch_size_test,
        shuffle=False,
    )
    net = LeNet().to(device)
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

    # Keep a *copy* of the best weights. Storing a reference to ``net``
    # would track the live model, so the last epoch would be saved instead
    # of the best one. Starting from the initial weights also guarantees
    # there is something to save if no epoch beats an accuracy of 0.
    best_state = copy.deepcopy(net.state_dict())
    best_acc = 0

    for epoch in range(nb_epoch):
        train(net, optimizer, train_loader, device)
        test_top1 = test(net, test_loader, device)
        print(f"Epoch {epoch}. Top1 {test_top1:.4f}")
        if test_top1 > best_acc:
            best_state = copy.deepcopy(net.state_dict())
            best_acc = test_top1

    torch.save(best_state, "models/best_acc.pth")
109 |
--------------------------------------------------------------------------------
/metrics/abstract_flops.py:
--------------------------------------------------------------------------------
1 | """Module for computing FLOPs from specification, not from Torch objects
2 | """
3 | import numpy as np
4 |
5 |
def dense_flops(in_neurons, out_neurons):
    """Return the multiply-add count of a Dense (Linear) layer.

    Args:
        in_neurons (int): Number of input features.
        out_neurons (int): Number of output features.

    Returns:
        int: ``in_neurons * out_neurons``.
    """
    multiply_adds = in_neurons * out_neurons
    return multiply_adds
9 |
10 |
def conv2d_flops(
    in_channels,
    out_channels,
    input_shape,
    kernel_shape,
    padding="same",
    strides=1,
    dilation=1,
):
    """Return the multiply-add count of a direct 2-D convolution.

    The count is the work for one output position
    (``in_channels * out_channels * prod(kernel_shape)``) multiplied by the
    number of output positions, which is derived from the padding mode and
    the strides.

    Args:
        in_channels (int): Number of input channels; must be positive.
        out_channels (int): Number of output channels; must be positive.
        input_shape (int, int): Spatial shape of the input.
        kernel_shape (int, int): Spatial shape of the kernel.
        padding ({'same', 'valid', 'zeros'}): Padding mode, case-insensitive.
            ``'zeros'`` is handled exactly like ``'same'``.
        strides (int) or (int, int): Spatial stride; a single number is used
            for both axes.
        dilation (int or iterable of int): Only 1 (or all ones) is supported.

    Returns:
        int: Number of multiply-adds (no FFT, no Winograd, etc).

    >>> c_in, c_out = 10, 10
    >>> in_shape = (4, 5)
    >>> filt_shape = (3, 2)
    >>> # valid padding
    >>> ret = conv2d_flops(c_in, c_out, in_shape, filt_shape, padding='valid')
    >>> ret == int(c_in * c_out * np.prod(filt_shape) * (2 * 4))
    True
    >>> # same padding, no stride
    >>> ret = conv2d_flops(c_in, c_out, in_shape, filt_shape, padding='same')
    >>> ret == int(c_in * c_out * np.prod(filt_shape) * np.prod(in_shape))
    True
    >>> # valid padding, stride > 1
    >>> ret = conv2d_flops(c_in, c_out, in_shape, filt_shape,
    ...                    padding='valid', strides=(1, 2))
    >>> ret == int(c_in * c_out * np.prod(filt_shape) * (2 * 2))
    True
    >>> # same padding, stride > 1
    >>> ret = conv2d_flops(c_in, c_out, in_shape, filt_shape,
    ...                    padding='same', strides=2)
    >>> ret == int(c_in * c_out * np.prod(filt_shape) * (2 * 3))
    True
    """
    # Check and normalise the arguments.
    assert in_channels > 0
    assert out_channels > 0
    assert len(input_shape) == 2
    assert len(kernel_shape) == 2
    padding = padding.lower()
    allowed_paddings = ("same", "valid", "zeros")
    assert padding in allowed_paddings, "Padding must be one of same|valid|zeros"
    try:
        strides = tuple(strides)
    except TypeError:
        # A scalar stride applies to both spatial axes.
        strides = (strides, strides)
    dilation_ok = dilation == 1 or all(d == 1 for d in dilation)
    assert dilation_ok, "Dilation > 1 is not supported"

    # Output spatial extent, following the TF rules described at
    # https://stackoverflow.com/a/37674568
    if padding in ["same", "zeros"]:
        row_span = float(input_shape[0])
        col_span = float(input_shape[1])
    else:  # padding == 'valid'
        row_span = input_shape[0] - kernel_shape[0] + 1
        col_span = input_shape[1] - kernel_shape[1] + 1
    n_rows = np.ceil(row_span / strides[0])
    n_cols = np.ceil(col_span / strides[1])
    n_positions = int(np.prod((int(n_rows), int(n_cols))))

    # Multiply-adds needed for a single output position.
    per_position = in_channels * out_channels * int(np.prod(kernel_shape))

    return per_position * n_positions
90 |
91 |
if __name__ == "__main__":
    # Running this module as a script executes the doctest examples
    # embedded in the docstrings above.
    from doctest import testmod

    testmod()
96 |
--------------------------------------------------------------------------------
/knowledge_distillation/distillation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from student import LeNet, LeNetStudent
8 |
9 | PACKAGE_PARENT = ".."
10 | SCRIPT_DIR = os.path.dirname(
11 | os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))
12 | )
13 | sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
14 |
15 | from metrics import flops, model_size
16 | from utils.loaders import get_loaders
17 |
18 |
def train(teacher, student, loss_fn, optimizer, data_loader, device):
    """Run one distillation pass of ``student`` over ``data_loader``.

    For every batch the teacher and the student are evaluated on the same
    inputs and ``loss_fn(student_out, teacher_out, targets)`` is minimised
    through ``optimizer``.

    Args:
        teacher: Frozen reference model; put in evaluation mode.
        student: Model being trained; put in training mode.
        loss_fn: Callable taking ``(student_out, teacher_out, targets)`` and
            returning a scalar loss.
        optimizer: Optimizer used for the update (presumably built over the
            student's parameters -- verify against the caller).
        data_loader: Iterable yielding ``(inputs, targets)`` pairs.
        device: Device the inputs and targets are moved to.
    """
    teacher.train(False)
    student.train()
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        inputs = inputs.to(device)
        targets = targets.to(device)
        # The teacher only supplies targets: run it without autograd so no
        # graph is kept for it and no gradients pile up on its parameters.
        with torch.no_grad():
            teacher_out = teacher.forward(inputs)
        student_out = student.forward(inputs)
        loss = loss_fn(student_out, teacher_out, targets)
        loss.backward()
        optimizer.step()
31 |
32 |
33 | def test(net, data_loader, device):
34 | top1 = 0
35 | correct_samples = 0
36 | total_samples = 0
37 | net.train(False)
38 | net.eval()
39 | with torch.no_grad():
40 | for (idx, (x, t)) in enumerate(data_loader):
41 | x = net.forward(x.to(device))
42 | t = t.to(device)
43 | _, indices = torch.max(x, 1)
44 | correct_samples += torch.sum(indices == t)
45 | total_samples += t.shape[0]
46 |
47 | top1 = float(correct_samples) / total_samples
48 | return top1
49 |
50 |
def calculate_prune_metrics(net, test_loader, device):
    """Return ``(FLOPS, size, size_nz)`` for ``net``.

    Args:
        net: Model to measure.
        test_loader: Loader whose first batch is used as the sample input
            for the FLOPs computation.
        device: Device the sample batch is moved to.

    Returns:
        tuple: The result of ``flops(net, batch)`` followed by the two
        values returned by ``model_size(net)`` (presumably the total and
        the non-zero parameter counts -- confirm in ``metrics``).
    """
    sample_batch, _ = next(iter(test_loader))
    sample_batch = sample_batch.to(device)

    size, size_nz = model_size(net)
    return flops(net, sample_batch), size, size_nz
60 |
61 |
def cross_entropy_with_soft_targets(pred, soft_targets):
    """Cross-entropy between logits and a soft target distribution.

    Args:
        pred: Tensor of unnormalised scores; ``log_softmax`` is taken over
            dimension 1.
        soft_targets: Tensor of target weights with the same shape as
            ``pred``.

    Returns:
        Scalar tensor: ``-sum(soft_targets * log_softmax(pred))`` over
        dimension 1, averaged over the remaining entries.
    """
    log_probs = F.log_softmax(pred, dim=1)
    return torch.mean(torch.sum(-soft_targets * log_probs, 1))


class CrossEntropyLossTemperature_withSoftTargets(torch.nn.Module):
    """Distillation loss: tempered soft-target term plus hard-label term.

    ``loss = T**2 * CE(student / T, softmax(teacher / T)) + CE(student, labels)``

    The same temperature ``T`` softens both the student scores and the
    teacher scores. The soft term is scaled by ``T**2``.
    """

    def __init__(self, temperature, reduction="mean"):
        """
        Args:
            temperature: Softening temperature ``T``.
            reduction: ``reduction`` argument forwarded to
                ``F.cross_entropy`` for the hard-label term. The soft term
                is always averaged.
        """
        super().__init__()
        self.T = temperature
        self.reduction = reduction

    def forward(self, input, soft_targets, hard_targets):
        """Compute the combined distillation loss.

        Args:
            input: Student scores; softmax is taken over dimension 1.
            soft_targets: Teacher scores (logits or log-probabilities); they
                are turned into a distribution here.
            hard_targets: Class indices for the hard-label term.

        Returns:
            Tensor: soft-target loss plus hard-label loss.
        """
        tempered_student = input / self.T
        # The teacher distribution must be softened with the same T as the
        # student, otherwise the two terms compare different distributions
        # whenever T != 1. For T == 1 this is unchanged.
        tempered_teacher = F.softmax(soft_targets / self.T, dim=1)
        loss_1 = self.T ** 2 * cross_entropy_with_soft_targets(
            tempered_student, tempered_teacher
        )

        loss_2 = F.cross_entropy(
            input,
            hard_targets,
            weight=None,
            ignore_index=-100,
            reduction=self.reduction,
        )

        return loss_1 + loss_2
93 |
94 |
if __name__ == "__main__":
    # Continue training the converged student with the teacher's outputs as
    # soft targets, and save the student weights of the best epoch.
    import copy

    # Use the GPU when there is one, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    teacher = LeNet().to(device)
    student = LeNetStudent().to(device)

    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    teacher.load_state_dict(
        torch.load("models/best_acc.pth", map_location=device)
    )
    student.load_state_dict(
        torch.load("models/best_acc_student.pth", map_location=device)
    )

    batch_size_train = 1024
    batch_size_test = 1024
    nb_epoch = 40

    train_loader, test_loader = get_loaders(batch_size_train, batch_size_test)

    print(calculate_prune_metrics(teacher, test_loader, device))
    print(calculate_prune_metrics(student, test_loader, device))

    loss_fn = CrossEntropyLossTemperature_withSoftTargets(1)
    # optimizer = torch.optim.Adam(student.parameters(), lr=0.0001, weight_decay=0.00001)
    optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.2)

    # Snapshot of the best student weights. A reference to `student` would
    # keep changing as training continues, so the saved "best" model would
    # really be the last one.
    best_state = None
    best_acc = 0

    teacher.train(False)
    for epoch in range(nb_epoch):
        train(teacher, student, loss_fn, optimizer, train_loader, device)
        test_top1 = test(student, test_loader, device)
        print(f"Epoch {epoch}. Top1 {test_top1:.4f}")
        if test_top1 > best_acc:
            # state_dict() returns references to the live tensors.
            best_state = copy.deepcopy(student.state_dict())
            best_acc = test_top1

    if best_state is None:
        # No epoch improved on the initial accuracy of 0; save what we have.
        best_state = student.state_dict()
    torch.save(best_state, "models/best_acc_student_with_distillation.pth")
130 |
--------------------------------------------------------------------------------
/metrics/utils.py:
--------------------------------------------------------------------------------
1 | """Auxiliary utils for implementing pruning strategies
2 | """
3 |
4 | from collections import OrderedDict, defaultdict
5 |
6 | import torch
7 | from torch import nn
8 |
9 |
def hook_applyfn(hook, model, forward=False, backward=False):
    """Build a function for ``model.apply`` that attaches ``hook`` to modules.

    The returned function registers ``hook`` on every module it is called
    with, except ``nn.Sequential``/``nn.ModuleList``/``nn.ModuleDict``
    containers and ``model`` itself. Nothing is registered until the
    returned function is actually applied.

    Arguments:
        hook {callable} -- Hook passed to ``register_forward_hook`` or
            ``register_backward_hook``
        model {torch.nn.Module} -- Root module, excluded from registration

    Keyword Arguments:
        forward {bool} -- Register as a forward hook (default: {False})
        backward {bool} -- Register as a backward hook (default: {False})

    Exactly one of ``forward`` and ``backward`` must be True.

    Returns:
        (callable, list) -- The registration function and the list that
        collects the handles returned by each registration; call
        ``remove()`` on them to detach the hooks
    """
    assert forward ^ backward, "Either forward or backward must be True"
    container_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
    hooks = []

    def register_hook(module):
        # Containers and the root model are skipped.
        if isinstance(module, container_types) or module == model:
            return
        if forward:
            hooks.append(module.register_forward_hook(hook))
        if backward:
            hooks.append(module.register_backward_hook(hook))

    return register_hook, hooks
42 |
43 |
def get_params(model, recurse=False):
    """Return the parameters of ``model`` as NumPy arrays keyed by name.

    Arguments:
        model {torch.nn.Module} -- Network to extract the parameters from

    Keyword Arguments:
        recurse {bool} -- Forwarded to ``named_parameters``: whether to
            include the parameters of child modules (default: {False})

    Returns:
        Dict(str:numpy.ndarray) -- Parameter name mapped to a detached CPU
        copy of the parameter values
    """
    params = {}
    for name, tensor in model.named_parameters(recurse=recurse):
        # copy() detaches the array from the tensor's storage.
        params[name] = tensor.detach().cpu().numpy().copy()
    return params
62 |
63 |
def get_activations(model, input):
    """Record the input and output of each hooked module for one forward pass.

    Forward hooks are attached through ``hook_applyfn`` (which skips the
    container modules and ``model`` itself), ``model(input)`` is run without
    gradient tracking, and the hooks are removed again.

    Arguments:
        model {torch.nn.Module} -- Network to inspect
        input -- Value passed to ``model``

    Returns:
        OrderedDict -- Module mapped to a ``(input, output)`` pair of NumPy
        copies, in the order the modules were executed. ``nn.ReLU`` modules
        are not recorded.
    """
    activations = OrderedDict()

    def store_activations(module, input, output):
        if isinstance(module, nn.ReLU):
            # TODO ResNet18 implementation reuses a
            # single ReLU layer?
            return
        assert module not in activations, f"{module} already in activations"
        # TODO [0] means first input, not all models have a single input
        activations[module] = (
            input[0].detach().cpu().numpy().copy(),
            output.detach().cpu().numpy().copy(),
        )

    fn, hooks = hook_applyfn(store_activations, model, forward=True)
    try:
        model.apply(fn)
        with torch.no_grad():
            model(input)
    finally:
        # Always detach the hooks, even when the forward pass raises;
        # otherwise they would keep firing on later calls of the model.
        for h in hooks:
            h.remove()

    return activations
89 |
90 |
def get_gradients(model, inputs, outputs):
    """Placeholder for activation gradients; not implemented, returns None.

    TODO implement using model.register_backward_hook().

    Implementation notes:
      * It is harder than it seems: grad_input also contains the gradients
        with respect to the weights, and so far the order seems to be
        (bias, input, weight), which is confusing.
      * A lot of the time the output activation we are looking for is the
        one after the ReLU, and F.ReLU (or any functional call) will not be
        seen by the forward or backward hook. Discussion here:
        https://discuss.pytorch.org/t/how-to-register-hook-function-for-functional-form/25775
      * The best way seems to be monkey patching F.ReLU and other
        functional ops. That would also help figuring out how to compute a
        module graph.
    """
    return None
104 |
105 |
def get_param_gradients(model, inputs, outputs, loss_func=None, by_module=True):
    """Return the loss gradients of the model parameters for one batch.

    The model is put in training mode, ``loss_func(model(inputs), outputs)``
    is back-propagated and the parameter gradients are copied to NumPy.
    Afterwards the gradients are cleared and the original training/eval
    mode is restored.

    Arguments:
        model {torch.nn.Module} -- Network to differentiate
        inputs -- Batch passed to ``model``
        outputs -- Targets passed to ``loss_func``

    Keyword Arguments:
        loss_func {callable} -- Loss taking ``(prediction, outputs)``;
            ``nn.CrossEntropyLoss()`` when None (default: {None})
        by_module {bool} -- Group the gradients per module instead of using
            the fully qualified parameter names (default: {True})

    Returns:
        ``by_module=True``: defaultdict mapping module -> OrderedDict of
        parameter name -> gradient array.
        ``by_module=False``: OrderedDict of qualified parameter name ->
        gradient array.
        Parameters without a gradient are omitted.
    """
    if loss_func is None:
        loss_func = nn.CrossEntropyLoss()

    training = model.training
    model.train()
    try:
        # Gradients accumulate: clear anything left by the caller so the
        # result reflects this batch only.
        model.zero_grad()
        pred = model(inputs)
        loss = loss_func(pred, outputs)
        loss.backward()

        if by_module:
            gradients = defaultdict(OrderedDict)
            for module in model.modules():
                assert module not in gradients
                for name, param in module.named_parameters(recurse=False):
                    if param.requires_grad and param.grad is not None:
                        gradients[module][name] = (
                            param.grad.detach().cpu().numpy().copy()
                        )
        else:
            gradients = OrderedDict()
            for name, param in model.named_parameters():
                assert name not in gradients
                if param.requires_grad and param.grad is not None:
                    gradients[name] = param.grad.detach().cpu().numpy().copy()
    finally:
        # Leave the model as it was found, even if the pass above failed.
        model.zero_grad()
        model.train(training)

    return gradients
138 |
139 |
def fraction_to_keep(compression, model, prunable_modules):
    """Return fraction of params to keep to achieve desired compression ratio

    Compression = total / ( fraction * prunable + (total-prunable))
    Using algebra, fraction is equal to
    fraction = total/prunable * (1/compression - 1) + 1

    Arguments:
        compression {float} -- Desired overall compression
        model {torch.nn.Module} -- Full model for which to compute the fraction
        prunable_modules {List(torch.nn.Module)} --
            Modules that can be pruned in the model.

    Returns:
        {float} -- Fraction of prunable parameters to keep
            to achieve desired compression

    Raises:
        AssertionError -- If the resulting fraction is not in (0, 1], i.e.
            the compression cannot be reached by pruning only the given
            modules
    """
    # `metrics` is the top-level package here, so the import must be
    # relative to this package; `..metrics` points above the top level and
    # fails. Kept local, presumably to avoid an import cycle with the
    # package __init__ -- TODO confirm.
    from . import model_size

    total_size, _ = model_size(model)
    prunable_size = sum(model_size(m)[0] for m in prunable_modules)
    nonprunable_size = total_size - prunable_size
    fraction = 1 / prunable_size * (total_size / compression - nonprunable_size)
    assert 0 < fraction <= 1, (
        f"Cannot compress to {1/compression} model "
        f"with {nonprunable_size/total_size} "
        "fraction of unprunable parameters"
    )
    return fraction
169 |
--------------------------------------------------------------------------------
/pruning_loop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.utils.prune as prune
5 | import torch.optim as optim
6 |
7 | from utils.loaders import get_loaders
8 | from utils.maskedLayers import Conv2dMasked, LinearMasked
9 | from metrics import flops, model_size
10 |
11 |
class LeNet(nn.Module):
    """LeNet-style classifier built from maskable conv and linear layers.

    Two 5x5 convolutions (1 -> 10 -> 20 channels), each followed by 2x2
    max-pooling and ReLU, then two fully connected layers (320 -> 50 -> 10).
    The output is a log-softmax over dimension 1.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2dMasked(1, 10, kernel_size=5)
        self.conv2 = Conv2dMasked(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = LinearMasked(320, 50)
        self.fc2 = LinearMasked(50, 10)

    def forward(self, x):
        """Return log-probabilities for the batch ``x``."""
        # First conv stage: convolution -> 2x2 max-pool -> ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Second conv stage, with channel dropout before pooling.
        dropped = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(dropped, 2))
        # Flatten to the 320 features expected by fc1.
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, 1)
29 |
30 |
class PruningExperiment:
    """Prune -> fine-tune -> evaluate experiment on the pre-trained LeNet.

    The experiment loads ``models/best_acc.pth``, repeats a
    prune / fine-tune cycle and reports the accuracy together with the
    FLOPs and size compression of the resulting network.
    """

    def __init__(
        self,
        pruning_strategy=None,
        batch_size_train=512,
        batch_size_test=1024,
        epochs_prune_finetune=3,
        epochs_finetune=5,
        device="cuda",
        save_model=None,
    ):
        """Initialize experiment
        :param pruning_strategy: Re-iterable (e.g. list) of
            ``(modulename, strategy, name, amount, kwargs)`` tuples, or None
            to skip pruning. ``strategy`` is called as
            ``strategy(module, name=name, amount=amount, **kwargs)``;
            ``kwargs`` may be None or a dict (e.g. ``{"n": 1, "dim": 0}``).
        :param batch_size_train: Batch size for train
        :param batch_size_test: Batch size for test
        :param epochs_prune_finetune: Number of prune -> finetune cycles
        :param epochs_finetune: Number of epochs to finetune pruned model
            in every cycle
        :param device: Device 'cpu' or 'cuda' where calculations are performed
        :param save_model: File name (without extension) under ``models/``
            where the final weights are saved, or None to skip saving
        """
        self.pruning_strategy = pruning_strategy
        self.device = device
        self.batch_size_train = batch_size_train
        self.batch_size_test = batch_size_test
        self.epochs_prune_finetune = epochs_prune_finetune
        self.epochs_finetune = epochs_finetune
        self.save_model = save_model

    def load_model(self):
        """Load LeNet model.
        All experiments will be performed on a trained model
        from the original script.
        """
        net = LeNet().to(self.device)
        # map_location lets a checkpoint saved on a GPU load on a CPU host.
        net.load_state_dict(
            torch.load("models/best_acc.pth", map_location=self.device)
        )
        return net

    def prune_model(self, net, pruning_strategy):
        """Apply every entry of ``pruning_strategy`` to ``net``.

        For each entry the strategy is run on ``getattr(net, modulename)``
        and the ``weight_mask`` of its result is installed on the module
        with ``set_mask``.

        :return: The same ``net`` object, pruned in place
        """
        for modulename, strategy, name, amount, kwargs in pruning_strategy:
            module = getattr(net, modulename)
            # Strategies without extra options come with kwargs=None; they
            # must be called without `n`/`dim` rather than with undefined
            # (or stale) values.
            extra = dict(kwargs) if kwargs else {}
            mask = strategy(module, name=name, amount=amount, **extra)
            module.set_mask(mask.weight_mask)

        return net

    def train(self, net, optimizer, data_loader, device):
        """Train ``net`` for one pass over ``data_loader`` with NLL loss."""
        net.train()
        for x, t in data_loader:
            optimizer.zero_grad()
            x = net.forward(x.to(device))
            t = t.to(device)
            loss = F.nll_loss(x, t)
            loss.backward()
            optimizer.step()

    def test(self, net, data_loader, device):
        """Return the top-1 accuracy of ``net`` over ``data_loader``."""
        correct_samples = 0
        total_samples = 0
        net.eval()
        with torch.no_grad():
            for x, t in data_loader:
                x = net.forward(x.to(device))
                t = t.to(device)
                _, indices = torch.max(x, 1)
                correct_samples += int(torch.sum(indices == t))
                total_samples += t.shape[0]
        return float(correct_samples) / total_samples

    def calculate_prune_metrics(self, net, test_loader, device):
        """Return ``(FLOPS, compression_ratio)`` for ``net``.

        ``FLOPS`` is the result of ``flops(net, batch)`` on the first test
        batch; ``compression_ratio`` is ``size / size_nz`` from
        ``model_size(net)``.
        """
        x, _ = next(iter(test_loader))
        x = x.to(device)

        size, size_nz = model_size(net)

        FLOPS = flops(net, x)
        compression_ratio = size / size_nz

        return FLOPS, compression_ratio

    def run(self):
        """
        Main function to run pruning -> finetuning -> evaluation

        :return: ``(top1, FLOPS[0] / FLOPS[1], compression_ratio)``:
            accuracy after the last cycle, the ratio of the two values
            returned by ``flops`` (presumably total over non-zero -- confirm
            in ``metrics``) and the size compression ratio
        """
        pruning_strategy = self.pruning_strategy
        device = self.device

        net = self.load_model()

        optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

        train_loader, test_loader = get_loaders(
            self.batch_size_train, self.batch_size_test
        )

        print("Pruning cycle")
        print("=======================")
        for epoch in range(self.epochs_prune_finetune):
            if pruning_strategy is not None:
                net = self.prune_model(net, pruning_strategy)
            test_top1 = self.test(net, test_loader, device)
            print(f"After pruning: Epoch {epoch}. Top1 {test_top1}.")
            for finetune_epoch in range(self.epochs_finetune):
                self.train(net, optimizer, train_loader, device)
                test_top1 = self.test(net, test_loader, device)
                print(
                    f"\tAfter finetuning: Epoch {finetune_epoch}. "
                    f"Top1 {test_top1:.4f}"
                )

        # Make the pruning permanent; nothing to remove when the run was
        # started without a strategy.
        if pruning_strategy is not None:
            for modulename, _strategy, name, _amount, _kwargs in pruning_strategy:
                module = getattr(net, modulename)
                prune.remove(module, name)
        test_top1 = self.test(net, test_loader, device)

        FLOPS, compression_ratio = self.calculate_prune_metrics(
            net, test_loader, device
        )

        if self.save_model is not None:
            torch.save(net.state_dict(), f"models/{self.save_model}.pth")

        return test_top1, FLOPS[0] / FLOPS[1], compression_ratio
170 |
--------------------------------------------------------------------------------
/Results.md:
--------------------------------------------------------------------------------
1 | # Experiments in Neural Network pruning
2 |
3 | ### Prepared by Oleg Polivin, 26 November 2020
4 | ---
5 |
6 | Let's define metrics that we will use to evaluate the effectiveness of pruning. We will look at categorical accuracy to estimate the quality of a neural network.[1](#myfootnote1) Accuracy in the experiments is reported based on the test set, not the one that has been used for training the neural network.
7 |
8 |
9 | Much of this work is based on the paper [What is the State of Neural Network Pruning?](https://arxiv.org/abs/2003.03033)
10 |
11 | In order to estimate the effectiveness at pruning we will take into account:
12 |
13 | 1. Acceleration of inference on the test set.
14 | - Compare the number of multiply-adds operations (FLOPs) to perform inference.
15 | - Additionally, I compute average time of running the original/pruned model on data.
16 |
17 | 2. Model size reduction/ weights compression.
18 | - Here we will compare total number of non-zero parameters.
19 |
20 | ## Experiment setting
21 |
22 | Given the code for LeNet model in PyTorch, let's calculate the metrics defined above. Canonical LeNet-5 architecture is below:
23 |
24 | 
25 |
26 | Architecture in the original paper [Gradient-Based Learning Applied to Document Recognition](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf) is a bit different from the code given (for example, there was no ``Dropout``, ``hyperbolic tangent`` was used as an activation, number of filters is different, etc), but the idea is the same. I will organize experiments as follows:
27 |
28 | 1. Train the model using the original script (``lenet_pytorch.py``, although I made some modifications there).
29 | 2. Perform evaluation of the model using the metrics defined above.
30 | 3. Save the trained model.
31 |
32 | I will perform experiments on pruning using the saved model.
33 |
34 | 4. In order to perform pruning experiments I added:
35 | - ``metrics/``
36 | - ``experiments.py`` (this is the main script that produces results).
37 | - ``pruning_loop.py`` implements the experiment.
38 | - ``utils``
39 |
40 | - ``avg_speed_calc.py`` to calculate average inference time on train data
41 | - ``loaders.py`` to create train/test loaders
42 | - ``maskedLayers.py`` wrappers for Linear and Conv2d PyTorch modules.
43 | - ``plots.ipynb`` Jupyter notebook to produce the plots below.
44 |
45 | ## Pruning setup
46 |
47 | As suggested in the [What is the State of Neural Network Pruning?](https://arxiv.org/abs/2003.03033) paper many pruning methods are described by the following algorithm:
48 |
49 | 1. A neural network (NN) is trained until convergence (80 epochs now).
50 | 2. ```
51 | for i in 1 to K do
52 | prune NN
53 | finetune NN
54 | end for
55 |
56 | It means that the neural network is pruned several times. In my version, a weight once set as zero will always stay zero. The weights that were pruned are not retrained. Note also that finetuning means that there are several epochs of training happening.
57 |
58 | In order to fix the pruning setup, in all the experiments number of prune-finetune epochs is equal to 3 (it is ``K`` above), and number of finetuning epochs is equal to 4. The categorical accuracy and model's speed-ups and compression is reported after pruning-finetuning is finished.
59 |
60 | ## Results
61 |
62 | ### Baseline
63 |
64 | The LeNet model as defined in the code was trained for ``80`` epochs, and the best model chosen by categorical accuracy was saved. The highest categorical accuracy was reached on epoch ``78`` and equals ``0.9809``. Our objective is to stop when the model converges, to be sure that we prune a converged model. There are ``932500`` multiply-add operations (FLOPs), and over 20 runs through the train data (``60000`` samples), the average time is ``9.1961866`` seconds.
65 |
66 | ### Experiments
67 |
68 | #### Experiment 1: Unstructured pruning of random weights
69 |
70 | **Setting:** Prune fully-connected layers (``fc1``, ``fc2``) and both convolutional layers (``conv1``, ``conv2``). Increase pruning from 10% to 70% (step = 10%). The pruning percentage is given for each layer. Roughly it corresponds to compressing the model up to 36 times.
71 |
72 | #### Experiment 2: Unstructured pruning of the smallest weights (based on the L1 norm)
73 |
74 | **Setting:** Same as in experiment 1. Notice the change that now pruning is not random. Here I assign 0's to the smallest weights.
75 |
76 | #### Experiment 3: Structured pruning (based on the L1 norm)
77 |
78 | **Setting:** Here I use structured pruning. In PyTorch one can use ``prune.ln_structured`` for that. It is possible to pass a dimension (``dim``) to specify which channel should be dropped. For fully-connected layers such as ``fc1`` or ``fc2`` -> ``dim=0`` corresponds to "switching off" output neurons (``50`` for ``fc1`` and ``10`` for ``fc2``). Therefore, it does not really make sense to switch off neurons in the classification layer ``fc2``. For convolutional layers like ``conv1`` or ``conv2`` -> ``dim=0`` corresponds to removing the output channels of the layers (``10`` for ``conv1`` and ``20`` for ``conv2``). That's why I will only prune the ``fc1``, ``conv1`` and ``conv2`` layers, again going from pruning 10% of each layer's channels up to 70%. For instance, for the fully-connected layer ``fc1`` it means zeroing 5 up to 35 neurons out of 50. For the ``conv1`` layer it means zeroing out all the connections corresponding to 1 up to 7 channels.
79 |
80 | Below I present results of my pruning experiments:
81 |
82 | 
83 |
84 | And I confirm that using average time of running a model during inference, there is no real change in terms of time for pruned or non-pruned models.
85 |
86 | ## Conclusions and caveats
87 |
88 | Here are my thoughts on the results above and some caveats.
89 |
90 | If we take the results at face value, we conclude that better results are obtained when we do unstructured pruning of the smallest weights based on L1 norm. In reality however (more on that below) unstructured pruning makes weights sparse, but since sparse operations are not supported in PyTorch yet, it does not bring real gains in terms of model size or speed of inference. However, we can think of such results as some evidence that a smaller architecture with a lower number of weights might be beneficial.
91 |
92 | Below are further caveats:
93 |
94 | ### Unstructured pruning
95 | 1. We are looking at FLOPs to estimate a speed-up of a pruned neural network. We look at the number of non-null parameters to estimate compression. It gives us an impression that by doing pruning we gain a significant speed-up and memory gain.
96 |
97 | 2. However, people report that when looking at actual time that it takes to make a prediction there is no gain in speed-up. I tested it with the model before pruning and after pruning (random weights), and this is true. There is no speedup in terms of average time of running inference. Also, saved PyTorch models (``.pth``) have the same size.
98 |
99 | 3. Additionally, there is no saving in memory, because all those zero elements still have to be saved.
100 |
101 | 4. To my understanding one needs to change the architecture of the neural network according to the zeroed weights in order to really have gains in speed and memory.
102 |
103 | 5. There is a different way which is to use sparse matrices and operations in PyTorch. But this functionality is in beta. See the discussion here [How to improve inference time of pruned model using torch.nn.utils.prune](https://discuss.pytorch.org/t/how-to-improve-inference-time-of-pruned-model-using-torch-nn-utils-prune/78633/4)
104 |
105 | 6. So, if we do unstructured pruning and we want to make use of sparse operations, we will have to write inference code that takes sparse matrices into account. Here is an example of a paper whose authors obtained large speed-ups, but only once they introduced operations with sparse matrices on an FPGA: [How Can We Be So Dense? The Benefits of Using Highly Sparse Representations](https://arxiv.org/abs/1903.11257)
106 |
107 | What's said above is more relevant to unstructured pruning of weights.
108 |
109 | ### Structured pruning
110 |
111 | One can have speed-ups when using structured pruning, that is, for example, dropping some channels. The price for that would be a drop in accuracy, but at least this really works for better model size and speed-ups.
112 |
113 | ## Additional chapter: Knowledge distillation
114 |
115 | Knowledge distillation is the idea proposed by Geoffrey Hinton, Oriol Vinyals and Jeff Dean to transfer knowledge from a huge trained model to a simple and lightweight one. It is not pruning strictly speaking, but has the same objective: simplify the original neural network without sacrificing much quality.
116 |
117 | It works the following way:
118 |
119 | - Train a comprehensive large network which has a good accuracy [Teacher Network]
120 | - Train a small network until convergence [Student Network]. There will be trade-offs between accuracy that you reach with a simpler model and the level of compression.
121 | - Distill the knowledge from the Teacher Network by training the Student Network using the outputs of the Teacher Network.
122 | - See that original accuracy of the trained and converged student network is increased!
123 |
124 | I provide the code to do it in the ``knowledge_distillation`` folder. Run
125 |
126 | ```
127 | python knowledge_distillation/train_student.py
128 | ```
129 | to train the student network. It has a simplified architecture relative to the original ``LeNet`` neural network. For example, when the trained student network is saved it takes 1.16 times less memory on disk (from 90 kBs to 77 kBs, even twice less if saved in ``PyTorch v1.4``). I ran training for 60 epochs and the best accuracy was reached on epoch 47, and it equals ``0.9260``. Thus we can say that the model has converged.
130 |
131 | Run
132 |
133 | ```
134 | python knowledge_distillation/distillation.py
135 | ```
136 | to do additional training of the converged student neural network distilling teacher network.
137 |
138 | Here are the results:
139 | - ``FLOPS`` compression coefficient is 42 (the student model is 42 times smaller in terms of FLOPS, down to 21840 multiply-add operations from 932500).
140 | - ``Model size`` compression coefficient is 3 (the student model is 3 times smaller in terms of size)
141 | - ``Accuracy`` of the retrained student model is ``0.9276``, which is a tiny bit better than the original student network.
142 |
143 | I would say that knowledge distillation is definitely worth a try as a method to perform model compression.
144 |
145 | ## Bibliography with comments
146 |
147 | 1. The code to calculate FLOPs is taken from [ShrinkBench repo](https://github.com/JJGO/shrinkbench) written by the authors of the [What is the State of Neural Network Pruning?](https://arxiv.org/abs/2003.03033) paper. The authors are Davis Blalock, Jose Javier Gonzalez Ortiz, Jonathan Frankle and John Guttag. They created this code to allow researchers to compare pruning algorithms: that is, compare compression rates, speed-ups and quality of the model after pruning among others. I copy their way to measure ``FLOPs`` and ``model size`` which is located in the ``metrics`` folder. It is necessary to say that I made some minor modifications to the code, and all errors remain mine and should not be attributed to the author's code. It is also important to add that I also take the logic of evaluating pruned models from this paper. All in all, this is the main source of inspiration for my research.
148 |
149 | 2. The next important source is this [Neural Network Pruning PyTorch Implementation](https://github.com/wanglouis49/pytorch-weights_pruning) by Luyu Wang and Gavin Ding. I copy their code for implementing the high-level idea of doing pruning:
150 | - Write wrappers on PyTorch Linear and Conv2d layers.
151 | - Binary mask is multiplied by actual layer weights
152 | - "Multiplying the mask is a differentiable operation and the backward pass is handed by automatic differentiation"
153 |
154 | 3. Next, I make use of the [PyTorch Pruning Tutorial](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html). It is different from the implementations above. My implementation mixes the code of the above two implementations with PyTorch way.
155 |
156 | Sources on knowledge distillation:
157 |
158 | 4. [Dark knowledge](https://www.ttic.edu/dl/dark14.pdf)
159 |
160 | 5. [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)
161 |
162 | 6. Open Data Science community (``ods.ai``) is my source of inspiration with brilliant people sharing their ideas on many aspects of Data Science.
163 |
164 | ## Footnotes
165 | 1: Indeed, at the extreme we can just predict a constant. Accuracy will be low, but pruning will be very effective: there will be no parameters at all in the neural network.
166 |
--------------------------------------------------------------------------------
/utils/plots.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 2,
4 | "metadata": {
5 | "language_info": {
6 | "name": "python",
7 | "codemirror_mode": {
8 | "name": "ipython",
9 | "version": 3
10 | },
11 | "version": "3.7.9-final"
12 | },
13 | "orig_nbformat": 2,
14 | "file_extension": ".py",
15 | "mimetype": "text/x-python",
16 | "name": "python",
17 | "npconvert_exporter": "python",
18 | "pygments_lexer": "ipython3",
19 | "version": 3,
20 | "kernelspec": {
21 | "name": "python37964bitpytorch17conda9210703cbb104b038c93c3861334f328",
22 | "display_name": "Python 3.7.9 64-bit ('pytorch17': conda)"
23 | }
24 | },
25 | "cells": [
26 | {
27 | "cell_type": "code",
28 | "execution_count": 1,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "from matplotlib import pyplot as plt\n",
33 | "%matplotlib inline"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "experiments_results = {\n",
43 | " \"UnstructRandomPrun\":\n",
44 | " {\n",
45 | " \"accuracy\": [0.9809, 0.9744, 0.9665, 0.9524, 0.9194, 0.8529, 0.4458, 0.1365],\n",
46 | " \"compression_flops\": [1, 1.3706, 1.9531, 2.9137, 4.6296, 8.0134, 15.625, 36.75],\n",
47 | " \"compression_size\": [1, 1.3695, 1.9455, 2.8927, 4.5614, 7.7750, 14.7368, 32.2124]\n",
48 | " },\n",
49 | " \"UnstructPrunL1Norm\":\n",
50 | " {\n",
51 | " \"accuracy\": [0.9809, 0.9824, 0.9818, 0.9781, 0.9698, 0.9543, 0.8555, 0.7565],\n",
52 | " \"compression_flops\": [1, 1.3706, 1.9531, 2.9137, 4.6296, 8.0134, 15.625, 36.75],\n",
53 | " \"compression_size\": [1, 1.3695, 1.9455, 2.8927, 4.5614, 7.7750, 14.7368, 32.2124]\n",
54 | " },\n",
55 | " \"StructuredPrunL1Norm\":\n",
56 | " {\n",
57 | " \"accuracy\": [0.9809, 0.9781, 0.9667, 0.9477, 0.9228, 0.753, 0.5095, 0.1135],\n",
58 | " \"compression_flops\": [1, 1.4268, 1.9976, 2.94, 4.98, 7.15, 16.34, 25.33],\n",
59 | " \"compression_size\": [1, 1.3561, 1.8934, 2.78, 4.23, 6.06, 11.97, 18.83]\n",
60 | " }\n",
61 | "}"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 3,
67 | "metadata": {},
68 | "outputs": [
69 | {
70 | "data": {
71 | "image/png": "\n",
72 | "image/svg+xml": "\n\n\n\n",
73 | "text/plain": ""
74 | },
75 | "metadata": {
76 | "needs_background": "light"
77 | },
78 | "output_type": "display_data"
79 | }
80 | ],
81 | "source": [
82 | "plt.figure(figsize=(15,5))\n",
83 | "plt.subplot(1,2,1)\n",
84 | "marker = ['o', 's', 'd']\n",
85 | "for i, (key, results) in enumerate(experiments_results.items()):\n",
86 | " compression_flops = results[\"compression_flops\"]\n",
87 | " accuracy = results[\"accuracy\"]\n",
88 | " lbl = key\n",
89 | " plt.plot(compression_flops, accuracy, marker=marker[i], label = lbl)\n",
90 | "plt.xlabel('Compression (FLOPS)')\n",
91 | "plt.ylabel('Accuracy')\n",
92 | "plt.title('LeNet-5 on MNIST: Accuracy vs theoretical speedup')\n",
93 | "plt.grid(True)\n",
94 | "plt.legend()\n",
95 | "\n",
96 | "plt.subplot(1,2,2)\n",
97 | "\n",
98 | "for i, (key, results) in enumerate(experiments_results.items()):\n",
99 | " compression_size = results[\"compression_size\"]\n",
100 | " accuracy = results[\"accuracy\"]\n",
101 | " lbl = key\n",
102 | " plt.plot(compression_size, accuracy, marker=marker[i], label = lbl)\n",
103 | "plt.xlabel('Compression (Model size)')\n",
104 | "plt.ylabel('Accuracy')\n",
105 | "plt.title('LeNet-5 on MNIST: Accuracy vs model compression')\n",
106 | "plt.grid(True)\n",
107 | "plt.legend()\n",
108 | "\n",
109 | "plt.show()"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": []
118 | }
119 | ]
120 | }
121 |
--------------------------------------------------------------------------------