├── requirements.txt
├── .gitignore
├── torch_ard
├── __init__.py
└── torch_ard.py
├── setup.py
├── LICENSE
├── examples
├── boston
│ ├── boston_baseline.py
│ └── boston_ard.py
├── cifar
│ ├── cifar_baseline.py
│ └── cifar_ard.py
├── models.py
└── mnist
│ ├── mnist_baseline.py
│ └── mnist_ard.py
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.1.0
2 | torchvision>=0.2.1
3 | scikit-learn>=0.19.2
4 | pandas
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | examples/*/data/**
3 | examples/*/checkpoint/*
4 | __pycache__
5 | dist
6 | .idea
7 |
--------------------------------------------------------------------------------
/torch_ard/__init__.py:
--------------------------------------------------------------------------------
"""Public API of the ``torch_ard`` package.

Re-exports the ARD layers, the regularisation helpers and the ELBO loss
implemented in :mod:`torch_ard.torch_ard`.
"""
# Package metadata.  The author dunder was previously misspelled as
# ``_author__`` and therefore never exposed under the conventional name.
__author__ = 'Artem Ryzhikov'
__version__ = '0.2.4'
__all__ = ['LinearARD', 'Conv2dARD', 'get_ard_reg', 'get_dropped_params_ratio', 'ELBOLoss']

from .torch_ard import LinearARD, Conv2dARD, get_ard_reg, get_dropped_params_ratio, ELBOLoss
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | __author__ = 'Artem Ryzhikov'
2 |
3 | from setuptools import setup
4 |
5 |
6 |
# Read the long description through a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
with open('README.md', encoding='utf-8') as readme_file:
    long_description = readme_file.read()

# Package metadata for the ``pytorch_ard`` distribution; only the
# ``torch_ard`` package is shipped (the examples are not installed).
setup(
    name="pytorch_ard",
    version='0.2.4',
    description="Make your PyTorch faster",
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/HolyBayes/pytorch_ard',
    author='Artem Ryzhikov',

    packages=['torch_ard'],

    classifiers=[
        'Intended Audience :: Science/Research',
        # Trove classifiers are matched exactly, so no trailing whitespace.
        'Programming Language :: Python :: 3',
    ],
    keywords='pytorch, bayesian neural networks, ard, deep learning, neural networks, machine learning',
    install_requires=[
        'torch>=1.1.0',
        'torchvision>=0.2.1',
        'scikit-learn>=0.19.2',
        'pandas',
    ],
)
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Artem Ryzhikov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/examples/boston/boston_baseline.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | from models import DenseModel
4 | from sklearn.datasets import load_boston
5 | from sklearn.model_selection import train_test_split
6 | import pandas as pd
7 | from torch import nn
8 | import torch
9 | import numpy as np
10 | from tqdm import tqdm, trange
11 |
# Baseline experiment: fit a plain dense network to the Boston housing data
# with full-batch Adam and report train/test MSE in the progress bar.

# NOTE: ``load_boston`` was removed in scikit-learn 1.2; running this example
# requires an older scikit-learn release that still ships the dataset.
boston = load_boston()
df = pd.DataFrame(boston.data, columns=boston.feature_names)
df['PRICE'] = boston.target
# Use the keyword form: the positional ``axis`` argument of DataFrame.drop
# is rejected by recent pandas releases.
X, y = df.drop(columns='PRICE'), df['PRICE']

train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.8)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Convert every split to a float32 tensor living on the selected device.
train_X, test_X, train_y, test_y = [
    torch.from_numpy(np.array(x)).float().to(device)
    for x in (train_X, test_X, train_y, test_y)
]

model = DenseModel(input_shape=train_X.shape[1], output_shape=1,
                   activation=nn.functional.relu).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
# The scheduler is created but never stepped (see the commented call below).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min')
criterion = nn.MSELoss()

n_epoches = 100000
debug_frequency = 100  # not used by the loop below

pbar = trange(n_epoches, leave=True, position=0)
for epoch in pbar:
    opt.zero_grad()
    # squeeze(-1) removes only the single output-feature axis so predictions
    # line up with the 1-D targets even for a one-sample split.
    preds = model(train_X).squeeze(-1)
    loss = criterion(preds, train_y)
    loss.backward()
    # scheduler.step(loss)
    opt.step()
    # The training loss is the value just optimised; no need to recompute it.
    loss_train = loss.item()
    # Evaluation needs no gradients, so skip building the autograd graph.
    with torch.no_grad():
        loss_test = criterion(model(test_X).squeeze(-1), test_y).item()
    pbar.set_description('MSE (train): %.3f\tMSE (test): %.3f' %
                         (loss_train, loss_test))
    # No explicit pbar.update(): iterating over the bar already advances it,
    # and an extra call would count every epoch twice.
47 |
--------------------------------------------------------------------------------
/examples/boston/boston_ard.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | from models import DenseModelARD
4 | from sklearn.datasets import load_boston
5 | from sklearn.model_selection import train_test_split
6 | import pandas as pd
7 | from torch import nn
8 | import torch.nn.functional as F
9 | import torch
10 | import numpy as np
11 | from torch_ard import get_ard_reg, get_dropped_params_ratio, ELBOLoss
12 | from tqdm import trange, tqdm
13 |
# ARD experiment: fit a dense network built from ARD layers to the Boston
# housing data and report train/test MSE, the ARD regulariser and the share
# of dropped parameters in the progress bar.

# NOTE: ``load_boston`` was removed in scikit-learn 1.2; running this example
# requires an older scikit-learn release that still ships the dataset.
boston = load_boston()
df = pd.DataFrame(boston.data, columns=boston.feature_names)
df['PRICE'] = boston.target
# Use the keyword form: the positional ``axis`` argument of DataFrame.drop
# is rejected by recent pandas releases.
X, y = df.drop(columns='PRICE'), df['PRICE']

train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.8)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu')
# Convert every split to a float32 tensor living on the selected device.
train_X, test_X, train_y, test_y = [
    torch.from_numpy(np.array(x)).float().to(device)
    for x in (train_X, test_X, train_y, test_y)
]

model = DenseModelARD(input_shape=train_X.shape[1], output_shape=1,
                      activation=nn.functional.relu).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# The scheduler is created but never stepped.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min')
criterion = ELBOLoss(model, F.mse_loss).to(device)

n_epoches = 100000
debug_frequency = 100  # not used by the loop below


def get_kl_weight(epoch):
    """Return the KL weight for ``epoch``.

    The weight grows linearly from 0 and is capped at 1, which it reaches
    after half of ``n_epoches``.
    """
    return min(1, 2 * epoch / n_epoches)


pbar = trange(n_epoches, leave=True, position=0)
for epoch in pbar:
    kl_weight = get_kl_weight(epoch)
    # Evaluation below switches to eval mode, so re-enable train mode here.
    model.train()
    opt.zero_grad()
    # squeeze(-1) removes only the single output-feature axis so predictions
    # line up with the 1-D targets even for a one-sample split.
    preds = model(train_X).squeeze(-1)
    # NOTE(review): the third argument (1) is presumably the scale of the
    # data term -- confirm against ELBOLoss.
    loss = criterion(preds, train_y, 1, kl_weight)
    loss.backward()
    opt.step()
    # Reporting needs no gradients; a KL weight of 0 leaves the data term only.
    with torch.no_grad():
        loss_train = criterion(preds, train_y, 1, 0).item()
        # Evaluate in eval mode, as the other ARD example scripts do.
        model.eval()
        loss_test = criterion(model(test_X).squeeze(-1), test_y, 1, 0).item()
    pbar.set_description('MSE (train): %.3f\tMSE (test): %.3f\tReg: %.3f\tDropout rate: %f%%' % (
        loss_train, loss_test, get_ard_reg(model).item(), 100 * get_dropped_params_ratio(model)))
    # No explicit pbar.update(): iterating over the bar already advances it,
    # and an extra call would count every epoch twice.
54 |
--------------------------------------------------------------------------------
/examples/cifar/cifar_baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 |
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import numpy as np
10 |
11 | import os
12 | import sys
13 | sys.path.append('../')
14 |
15 | from models import LeNet
16 |
17 |
# Experiment setup: device selection, CIFAR-10 data pipeline, model
# construction, optional checkpoint resume, loss and optimiser.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_file = 'checkpoint/ckpt_baseline.t7'

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
model = LeNet(3, len(classes)).to(device)

if os.path.isfile(ckpt_file):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine can also be resumed on CPU.
    checkpoint = torch.load(ckpt_file, map_location=device)
    model.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3,
                      momentum=0.9)
# The scheduler is created but never stepped by the training loop.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
66 |
67 |
68 | # Training
def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Prints the epoch number, then the mean batch loss and the training
    accuracy (in percent) accumulated over the pass.
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    batch_losses = []
    n_correct, n_seen = 0, 0
    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Book-keeping for the end-of-epoch summary.
        batch_losses.append(loss.item())
        predicted = outputs.max(1)[1]  # index of the highest score per sample
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    print('Train loss: %.2f' % np.mean(batch_losses))
    print('Train accuracy: %.2f%%' % (n_correct * 100.0 / n_seen))
90 |
91 |
def test(epoch):
    """Evaluate on ``testloader`` and checkpoint the model on improvement.

    Prints the mean batch loss and the accuracy (in percent).  When the
    accuracy exceeds the module-level ``best_acc``, the model weights, the
    accuracy and ``epoch`` are saved to ``ckpt_file`` and ``best_acc`` is
    updated.
    """
    global best_acc
    model.eval()
    test_loss = []
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss.append(loss.item())
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100. * correct / total
    print('Test loss: %.2f' % np.mean(test_loss))
    print('Test accuracy: %.2f%%' % acc)
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # Create the checkpoint's own parent directory; exist_ok avoids the
        # check-then-create race of isdir() followed by mkdir().
        ckpt_dir = os.path.dirname(ckpt_file)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state, ckpt_file)
        best_acc = acc
124 |
125 |
# Run 200 epochs counted from the (possibly restored) start epoch; every
# training pass is followed by an evaluation pass.
final_epoch = start_epoch + 200
for epoch in range(start_epoch, final_epoch):
    train(epoch)
    test(epoch)
129 |
--------------------------------------------------------------------------------
/examples/models.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../')
3 | import torch_ard as nn_ard
4 | from torch import nn
5 | import torch.nn.functional as F
6 | import torch
7 |
class DenseModelARD(nn.Module):
    """Dense network with one hidden layer built from ``nn_ard.LinearARD``.

    Args:
        input_shape: number of input features.
        output_shape: number of output features.
        hidden_size: width of the hidden layer.
        activation: optional callable applied to the output of the last
            layer; ``None`` leaves the output unchanged.
    """

    def __init__(self, input_shape, output_shape, hidden_size=150, activation=None):
        super(DenseModelARD, self).__init__()
        self.l1 = nn_ard.LinearARD(input_shape, hidden_size)
        self.l2 = nn_ard.LinearARD(hidden_size, output_shape)
        self.activation = activation
        self._init_weights()

    def forward(self, input):
        """Return ``activation(l2(tanh(l1(input))))`` computed on the model's device."""
        x = input.to(self.device)
        # torch.tanh replaces the deprecated nn.functional.tanh alias.
        x = torch.tanh(self.l1(x))
        x = self.l2(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

    def _init_weights(self):
        """Xavier-initialise the ``weight`` of every child layer that has one."""
        # NOTE(review): the gain is the ReLU gain although the hidden
        # non-linearity is tanh -- confirm this is intended.
        gain = nn.init.calculate_gain('relu')
        for layer in self.children():
            if hasattr(layer, 'weight'):
                # In-place variant; the un-suffixed name is deprecated.
                nn.init.xavier_uniform_(layer.weight, gain=gain)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device
31 |
32 |
class DenseModel(nn.Module):
    """Dense network with one hidden layer built from ``nn.Linear``.

    Args:
        input_shape: number of input features.
        output_shape: number of output features.
        hidden_size: width of the hidden layer.
        activation: optional callable applied to the output of the last
            layer; ``None`` leaves the output unchanged.
    """

    def __init__(self, input_shape, output_shape, hidden_size=150, activation=None):
        super(DenseModel, self).__init__()
        self.l1 = nn.Linear(input_shape, hidden_size)
        self.l2 = nn.Linear(hidden_size, output_shape)
        self.activation = activation
        self._init_weights()

    def forward(self, input):
        """Return ``activation(l2(tanh(l1(input))))`` computed on the model's device."""
        x = input.to(self.device)
        # torch.tanh replaces the deprecated nn.functional.tanh alias.
        x = torch.tanh(self.l1(x))
        x = self.l2(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

    def _init_weights(self):
        """Xavier-initialise the ``weight`` of every child layer that has one."""
        # NOTE(review): the gain is the ReLU gain although the hidden
        # non-linearity is tanh -- confirm this is intended.
        gain = nn.init.calculate_gain('relu')
        for layer in self.children():
            if hasattr(layer, 'weight'):
                # In-place variant; the un-suffixed name is deprecated.
                nn.init.xavier_uniform_(layer.weight, gain=gain)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device
56 |
class LeNet(nn.Module):
    """LeNet-style CNN: two 5x5 convolutions followed by two dense layers.

    Args:
        input_shape: number of input channels.
        output_shape: number of output units of the last dense layer.

    The first dense layer expects ``50*5*5`` features, which is what the two
    conv + 2x2 max-pool stages produce for 32x32 inputs.
    """

    def __init__(self, input_shape, output_shape):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(input_shape, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.l1 = nn.Linear(50*5*5, 500)
        self.l2 = nn.Linear(500, output_shape)
        self._init_weights()

    def forward(self, x):
        """Return the raw (un-normalised) outputs of the last dense layer."""
        out = F.relu(self.conv1(x.to(self.device)))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        # Flatten everything except the batch axis for the dense layers.
        out = out.view(out.shape[0], -1)
        out = F.relu(self.l1(out))
        return self.l2(out)
        # return F.log_softmax(self.l2(out), dim=1)

    def _init_weights(self):
        """Xavier-initialise the ``weight`` of every child layer that has one."""
        gain = nn.init.calculate_gain('relu')
        for layer in self.children():
            if hasattr(layer, 'weight'):
                # In-place variant; the un-suffixed name is deprecated.
                nn.init.xavier_uniform_(layer.weight, gain=gain)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device
83 |
class LeNetARD(nn.Module):
    """LeNet-style CNN assembled from ``nn_ard`` ARD layers.

    Mirrors ``LeNet`` layer for layer, using ``Conv2dARD`` and ``LinearARD``
    in place of ``nn.Conv2d`` and ``nn.Linear``.

    Args:
        input_shape: number of input channels.
        output_shape: number of output units of the last dense layer.
    """

    def __init__(self, input_shape, output_shape):
        super(LeNetARD, self).__init__()
        self.conv1 = nn_ard.Conv2dARD(input_shape, 20, 5)
        self.conv2 = nn_ard.Conv2dARD(20, 50, 5)
        self.l1 = nn_ard.LinearARD(50*5*5, 500)
        self.l2 = nn_ard.LinearARD(500, output_shape)
        self._init_weights()

    def forward(self, input):
        """Return the raw (un-normalised) outputs of the last dense layer."""
        out = F.relu(self.conv1(input.to(self.device)))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        # Flatten everything except the batch axis for the dense layers.
        out = out.view(out.shape[0], -1)
        out = F.relu(self.l1(out))
        return self.l2(out)
        # return F.log_softmax(self.l2(out), dim=1)

    def _init_weights(self):
        """Xavier-initialise the ``weight`` of every child layer that has one."""
        gain = nn.init.calculate_gain('relu')
        for layer in self.children():
            if hasattr(layer, 'weight'):
                # In-place variant; the un-suffixed name is deprecated.
                nn.init.xavier_uniform_(layer.weight, gain=gain)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device
110 |
111 |
class LeNet_MNIST(LeNet):
    """``LeNet`` whose first dense layer takes ``50*4*4`` input features."""

    def __init__(self, input_shape, output_shape):
        super().__init__(input_shape, output_shape)
        # Swap in the smaller dense layer, then rerun the parent's weight
        # initialisation so the replacement layer is initialised as well.
        self.l1 = nn.Linear(in_features=50 * 4 * 4, out_features=500)
        super()._init_weights()
117 |
class LeNetARD_MNIST(LeNetARD):
    """``LeNetARD`` whose first dense layer takes ``50*4*4`` input features."""

    def __init__(self, input_shape, output_shape):
        super().__init__(input_shape, output_shape)
        # Swap in the smaller ARD dense layer, then rerun the parent's weight
        # initialisation so the replacement layer is initialised as well.
        first_dense_inputs = 50 * 4 * 4
        self.l1 = nn_ard.LinearARD(first_dense_inputs, 500)
        super()._init_weights()
123 |
--------------------------------------------------------------------------------
/examples/mnist/mnist_baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 |
7 | from torchvision import datasets
8 | import torchvision.transforms as transforms
9 | import numpy as np
10 |
11 | import os
12 | import sys
13 | sys.path.append('../')
14 | import time
15 |
16 | from models import LeNet_MNIST
17 |
18 |
# Experiment setup: device selection, MNIST data pipeline, model
# construction, optional checkpoint resume, loss and optimiser.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_file = 'checkpoint/ckpt_baseline.t7'

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
# NOTE(review): transform_train / transform_test are never passed to the
# datasets below, which build their own single-channel transforms --
# confirm they can be removed.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.MNIST('./data', train=True, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))
                          ]))
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True)

testset = datasets.MNIST('./data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
]))
testloader = torch.utils.data.DataLoader(
    testset, batch_size=1000, shuffle=True)


n_classes = 10

# Model
print('==> Building model..')
model = LeNet_MNIST(1, n_classes).to(device)

if device.type == 'cuda':
    # Wrapping in DataParallel prefixes every state-dict key with "module.".
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True

if os.path.isfile(ckpt_file):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine does not fail to deserialise on CPU.
    checkpoint = torch.load(ckpt_file, map_location=device)
    model.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
# The scheduler is created but never stepped by the training loop.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
75 |
76 | # Training
77 |
78 |
def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Prints the epoch number, then the mean batch loss and the training
    accuracy (in percent) accumulated over the pass.
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    batch_losses = []
    n_correct, n_seen = 0, 0
    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Book-keeping for the end-of-epoch summary.
        batch_losses.append(loss.item())
        predicted = outputs.max(1)[1]  # index of the highest score per sample
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    print('Train loss: %.3f' % np.mean(batch_losses))
    print('Train accuracy: %.3f%%' % (n_correct * 100.0 / n_seen))
100 |
101 |
def test(epoch):
    """Evaluate on ``testloader`` and checkpoint the model on improvement.

    Prints the mean batch loss, the accuracy (in percent) and the total time
    spent in the forward passes.  When the accuracy exceeds the module-level
    ``best_acc``, the model weights, the accuracy and ``epoch`` are saved to
    ``ckpt_file`` and ``best_acc`` is updated.
    """
    global best_acc
    model.eval()
    test_loss = []
    correct = 0
    total = 0
    inference_time_seconds = 0
    # CUDA kernels are launched asynchronously, so the device has to be
    # synchronised around the forward pass for the wall-clock measurement
    # to cover the actual computation.
    on_cuda = device.type == 'cuda'
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            if on_cuda:
                torch.cuda.synchronize()
            start_ts = time.perf_counter()
            outputs = model(inputs)
            if on_cuda:
                torch.cuda.synchronize()
            inference_time_seconds += time.perf_counter() - start_ts
            loss = criterion(outputs, targets)

            test_loss.append(loss.item())
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100. * correct / total
    print('Test loss: %.3f' % np.mean(test_loss))
    print('Test accuracy: %.3f%%' % acc)
    print('Inference time: %.2f seconds' % inference_time_seconds)
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # Create the checkpoint's own parent directory; exist_ok avoids the
        # check-then-create race of isdir() followed by mkdir().
        ckpt_dir = os.path.dirname(ckpt_file)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state, ckpt_file)
        best_acc = acc
138 |
139 |
# Run 100 epochs counted from the (possibly restored) start epoch; every
# training pass is followed by an evaluation pass.
final_epoch = start_epoch + 100
for epoch in range(start_epoch, final_epoch):
    train(epoch)
    test(epoch)
143 |
--------------------------------------------------------------------------------
/examples/cifar/cifar_ard.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 |
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import numpy as np
10 |
11 | import os
12 | import sys
13 | sys.path.append('../')
14 |
15 | from models import LeNetARD
16 | from torch_ard import get_ard_reg, get_dropped_params_ratio, ELBOLoss
17 |
18 |
# Experiment setup: device selection, CIFAR-10 data pipeline, ARD model
# construction, optional warm start from an ARD or baseline checkpoint,
# ELBO loss, optimiser and KL-weight schedule.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_baseline_file = 'checkpoint/ckpt_baseline.t7'
ckpt_file = 'checkpoint/ckpt_ard.t7'

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
reg_factor = 1e-5  # not used anywhere in this script

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
model = LeNetARD(3, len(classes)).to(device)


# Prefer an existing ARD checkpoint; otherwise warm-start from the baseline
# checkpoint.  map_location remaps the stored tensors onto the current
# device, so a checkpoint written on a GPU machine can be loaded on CPU.
if os.path.isfile(ckpt_file):
    state_dict = model.state_dict()
    checkpoint = torch.load(ckpt_file, map_location=device)
    state_dict.update(checkpoint['net'])
    model.load_state_dict(state_dict)
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
elif os.path.isfile(ckpt_baseline_file):
    state_dict = model.state_dict()
    checkpoint = torch.load(ckpt_baseline_file, map_location=device)
    state_dict.update(checkpoint['net'])
    # strict=False tolerates keys that differ between the two models.
    model.load_state_dict(state_dict, strict=False)
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = ELBOLoss(model, F.cross_entropy).to(device)
optimizer = optim.SGD(model.parameters(), lr=1e-3,
                      momentum=0.9)
# The scheduler is created but never stepped by the training loop.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
n_epoches = 200


def get_kl_weight(epoch):
    """Return the KL weight for ``epoch``: ``1e-4 * epoch / n_epoches``, capped at 1."""
    return min(1, 1e-4 * epoch / n_epoches)
81 |
82 |
83 | # Training
def train(epoch):
    """Run one optimisation pass over ``trainloader`` with the ELBO loss.

    The KL weight is looked up once per epoch via ``get_kl_weight``.  Prints
    the epoch number, then the mean batch loss and the training accuracy (in
    percent) accumulated over the pass.
    """
    print('\nEpoch: %d' % epoch)
    kl_weight = get_kl_weight(epoch)
    model.train()
    batch_losses = []
    n_correct, n_seen = 0, 0
    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets, 1, kl_weight)
        loss.backward()
        optimizer.step()

        # Book-keeping for the end-of-epoch summary.
        batch_losses.append(loss.item())
        predicted = outputs.max(1)[1]  # index of the highest score per sample
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    print('Train loss: %.2f' % np.mean(batch_losses))
    print('Train accuracy: %.2f%%' % (n_correct * 100.0 / n_seen))
107 |
108 |
def test(epoch):
    """Evaluate on ``testloader`` and checkpoint the model on improvement.

    Prints the mean batch loss (computed with a KL weight of 0), the accuracy
    (in percent) and the share of dropped parameters.  When the accuracy
    exceeds the module-level ``best_acc``, the model weights, the accuracy
    and ``epoch`` are saved to ``ckpt_file`` and ``best_acc`` is updated.
    """
    global best_acc
    model.eval()
    test_loss = []
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets, 1, 0)

            test_loss.append(loss.item())
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100. * correct / total
    print('Test loss: %.2f' % np.mean(test_loss))
    print('Test accuracy: %.2f%%' % acc)
    print('Compression: %.2f%%' % (100. * get_dropped_params_ratio(model)))
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # Create the checkpoint's own parent directory; exist_ok avoids the
        # check-then-create race of isdir() followed by mkdir().
        ckpt_dir = os.path.dirname(ckpt_file)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state, ckpt_file)
        best_acc = acc
142 |
143 |
# Run n_epoches epochs counted from the (possibly restored) start epoch;
# every training pass is followed by an evaluation pass.
final_epoch = start_epoch + n_epoches
for epoch in range(start_epoch, final_epoch):
    train(epoch)
    test(epoch)
147 |
--------------------------------------------------------------------------------
/examples/mnist/mnist_ard.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 |
7 | from torchvision import datasets
8 | import torchvision.transforms as transforms
9 | import numpy as np
10 |
11 | import os
12 | import sys
13 | sys.path.append('../')
14 | import time
15 |
16 | from models import LeNetARD_MNIST
17 | from torch_ard import get_ard_reg, get_dropped_params_ratio, ELBOLoss
18 |
19 |
# Experiment setup: device selection, MNIST data pipeline, ARD model
# construction, optional warm start from an ARD or baseline checkpoint,
# ELBO loss, optimiser and KL-weight schedule.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_baseline_file = 'checkpoint/ckpt_baseline.t7'
ckpt_file = 'checkpoint/ckpt_ard.t7'

best_acc = 0  # best test accuracy
best_compression = 0
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
reg_factor = 1e-5  # not used anywhere in this script

# Data
print('==> Preparing data..')
# NOTE(review): transform_train / transform_test are never passed to the
# datasets below, which build their own single-channel transforms --
# confirm they can be removed.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.MNIST('./data', train=True, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))
                          ]))
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True)

testset = datasets.MNIST('./data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
]))
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True)

n_classes = 10

# Model
print('==> Building model..')
model = LeNetARD_MNIST(1, n_classes).to(device)


# Prefer an existing ARD checkpoint; otherwise warm-start from the baseline
# checkpoint.  map_location remaps the stored tensors onto the current
# device, so a checkpoint written on a GPU machine can be loaded on CPU.
if os.path.isfile(ckpt_file):
    state_dict = model.state_dict()
    checkpoint = torch.load(ckpt_file, map_location=device)
    state_dict.update(checkpoint['net'])
    model.load_state_dict(state_dict)
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    # Resume the save criterion as well; otherwise the first evaluation after
    # a restart overwrites the checkpoint regardless of its compression.
    best_compression = checkpoint.get('compression', best_compression)
elif os.path.isfile(ckpt_baseline_file):
    state_dict = model.state_dict()
    checkpoint = torch.load(ckpt_baseline_file, map_location=device)
    # The baseline script wraps its model in DataParallel on CUDA, which
    # prefixes every key with "module.".  Without stripping the prefix,
    # strict=False would silently ignore all baseline weights.
    prefix = 'module.'
    baseline_weights = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint['net'].items()
    }
    state_dict.update(baseline_weights)
    model.load_state_dict(state_dict, strict=False)
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']


criterion = ELBOLoss(model, F.cross_entropy).to(device)
optimizer = optim.SGD(model.parameters(), lr=1e-3,
                      momentum=0.9)
# The scheduler is created but never stepped by the training loop.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
n_epoches = 100


def get_kl_weight(epoch):
    """Return the KL weight for ``epoch``: ``1e-2 * epoch / n_epoches``, capped at 1."""
    return min(1, 1e-2 * epoch / n_epoches)
88 |
89 | # Training
90 |
91 |
def train(epoch):
    """Run one optimisation pass over ``trainloader`` with the ELBO loss.

    The KL weight is looked up once per epoch via ``get_kl_weight``.  Prints
    the epoch number, then the mean batch loss and the training accuracy (in
    percent) accumulated over the pass.
    """
    print('\nEpoch: %d' % epoch)
    kl_weight = get_kl_weight(epoch)
    model.train()
    batch_losses = []
    n_correct, n_seen = 0, 0
    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets, 1, kl_weight)
        loss.backward()
        optimizer.step()

        # Book-keeping for the end-of-epoch summary.
        batch_losses.append(loss.item())
        predicted = outputs.max(1)[1]  # index of the highest score per sample
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    print('Train loss: %.3f' % np.mean(batch_losses))
    print('Train accuracy: %.3f%%' % (n_correct * 100.0 / n_seen))
114 |
115 |
def test(epoch):
    """Evaluate on ``testloader`` and checkpoint the model when compression improves.

    Prints the mean batch loss (computed with a KL weight of 0), the accuracy
    (in percent), the share of dropped parameters and the total time spent in
    the forward passes.  When the compression exceeds the module-level
    ``best_compression``, the model weights, accuracy, ``epoch`` and
    compression are saved to ``ckpt_file`` and ``best_compression`` is
    updated.
    """
    global best_acc
    global best_compression
    model.eval()
    test_loss = []
    correct = 0
    total = 0
    inference_time_seconds = 0
    # CUDA kernels are launched asynchronously, so the device has to be
    # synchronised around the forward pass for the wall-clock measurement
    # to cover the actual computation.
    on_cuda = device.type == 'cuda'
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            if on_cuda:
                torch.cuda.synchronize()
            start_ts = time.perf_counter()
            outputs = model(inputs)
            if on_cuda:
                torch.cuda.synchronize()
            inference_time_seconds += time.perf_counter() - start_ts
            # Same arguments as every other evaluation call in the examples:
            # a KL weight of 0 reports the data term only.
            loss = criterion(outputs, targets, 1, 0)

            test_loss.append(loss.item())
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100. * correct / total
    compression = 100. * get_dropped_params_ratio(model)
    print('Test loss: %.3f' % np.mean(test_loss))
    print('Test accuracy: %.3f%%' % acc)
    print('Compression: %.2f%%' % compression)
    print('Inference time: %.2f seconds' % inference_time_seconds)
    # if acc > best_acc:
    if compression > best_compression:
        print('Saving..')
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'compression': compression
        }
        # Create the checkpoint's own parent directory; exist_ok avoids the
        # check-then-create race of isdir() followed by mkdir().
        ckpt_dir = os.path.dirname(ckpt_file)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state, ckpt_file)
        # best_acc = acc
        best_compression = compression
158 |
159 |
# Each iteration evaluates the current weights first (so the loaded
# checkpoint is measured before any further training) and then trains for
# one epoch.
for epoch in range(start_epoch, start_epoch + n_epoches):
    test(epoch)
    train(epoch)
# With the test-before-train ordering the weights produced by the last
# training epoch would otherwise never be evaluated or checkpointed.
test(start_epoch + n_epoches)
163 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Variational Dropout Sparsifies NN (Pytorch)
2 | [License: MIT](LICENSE)
3 | [PyPI version](https://badge.fury.io/py/pytorch-ard)
4 |
5 |
6 | Make your neural network 300 times faster!
7 |
8 | Pytorch implementation of Variational Dropout Sparsifies Deep Neural Networks ([arxiv:1701.05369](https://arxiv.org/abs/1701.05369)).
9 |
10 | ## Description
11 | The approach proposed in the paper makes it possible to train sparse convolutional and dense deep models without significant loss of quality. The Additive Noise Reparameterization
12 | and the Local Reparameterization Trick described in the paper help to eliminate the restriction on the weights' prior ($\alpha \leq 1$) and achieve an Automatic Relevance Determination (ARD) effect on (typically most of) the network's parameters. According to the original paper, the authors reduced the number of parameters by up to 280 times on LeNet architectures and by up to 68 times on VGG-like networks with a negligible decrease in accuracy. The experiments with the Boston dataset in this repository confirm this: 99% of the parameters of a simple dense model were dropped using the paper's ARD prior without any significant loss in MSE. Moreover, this technique significantly reduces overfitting and removes the need to worry about the model's complexity: all redundant parameters are dropped automatically. Finally, you can achieve any degree of regularization by varying the regularization factor (see the ***reg_factor*** variable in the [boston_ard.py](examples/boston/boston_ard.py) and [cifar_ard.py](examples/cifar/cifar_ard.py) scripts).
13 |
14 | ## Usage
15 |
16 | ```python
17 | import torch_ard as nn_ard
18 | from torch import nn
19 | import torch.nn.functional as F
20 |
21 | input_size, hidden_size, output_size = 60, 150, 1
22 |
23 | model = nn.Sequential(
24 | nn_ard.LinearARD(input_size, hidden_size),
25 | nn.ReLU(),
26 | nn_ard.LinearARD(hidden_size, output_size)
27 | )
28 |
29 |
30 | criterion = nn_ard.ELBOLoss(model, F.cross_entropy)
31 | print('Sparsification ratio: %.3f%%' % (100.*nn_ard.get_dropped_params_ratio(model)))
32 |
33 | # test stage
34 | model.eval() # Needed for speed-up
35 | model(input)
36 | ```
37 |
38 | ## Installation
39 |
40 | ```
41 | pip install git+https://github.com/HolyBayes/pytorch_ard
42 | ```
43 |
44 | ## Experiments
45 |
46 | All experiments are located in the [examples](examples/) folder, which contains comparisons of the baseline and the implemented ARD models.
47 |
48 | ### Boston dataset
49 |
50 | Two scripts were used in the experiment: [boston_baseline.py](examples/boston/boston_baseline.py) and [boston_ard.py](examples/boston/boston_ard.py). The training procedure for each experiment was **100000 epochs, Adam(lr=1e-3)**. The baseline model was a dense neural network with a single hidden layer of size 150.
51 |
52 | | | Baseline (nn.Linear) | LinearARD, no reg | LinearARD, reg=0.0001 | LinearARD, reg=0.001 | LinearARD, reg=0.1 | LinearARD, reg=1 |
53 | |----------------|----------|-------------|-----------------|----------------|--------------|------------|
54 | | MSE (train) | 1.751 | 1.626 | 1.587 | 1.962 | 17.167 | 33.682 |
55 | | MSE (test) | 22.580 | 16.229 | 15.957 | 8.416 | 25.695 | 30.231 |
56 | | Compression, % | 0 | 0.38 | 52.95 | 64.19 | 97.29 | 99.29 |
57 |
58 | The table above shows that any degree of compression can be achieved by varying the regularization factor (for example, ~99.29% of connections can be dropped when reg_factor=1 is used). Moreover, training with LinearARD layers and a suitable regularization factor (like reg=0.001 in the table above) not only significantly reduces the number of model parameters (>64% of parameters can be dropped after training), but also significantly improves quality on the test set by reducing overfitting.
59 |
60 | ## Tips
61 |
62 | 1. Despite the high performance of the implemented layers in "end-to-end" mode, the authors recommend using them to fine-tune models that were pretrained without the ARD prior. The best performance is achieved this way. Moreover, it is faster: although these layers converge at a comparable speed, each training epoch takes more time (approx. twice as long, because \*ARD implementations have ~2 times more parameters). This is easy to explain: applying the ARD prior at early stages can drop useful connections with non-obvious dependencies.
63 | 2. Model sparsification gives almost no speed-up until you convert the model to a sparse one! (*TODO*)
64 |
65 |
66 | ## Requirements
67 | * **PyTorch** >= 0.4.0
68 | * **SkLearn** >= 0.19.1
69 | * **Pandas** >= 0.23.3
70 | * **Numpy** >= 1.14.5
71 |
72 | ## Authors
73 |
74 | ```
75 | @article{molchanov2017variational,
76 | title={Variational Dropout Sparsifies Deep Neural Networks},
77 | author={Molchanov, Dmitry and Ashukha, Arsenii and Vetrov, Dmitry},
78 | journal={arXiv preprint arXiv:1701.05369},
79 | year={2017}
80 | }
81 | ```
82 | [Original implementation](https://github.com/ars-ashuha/variational-dropout-sparsifies-dnn) (Theano/Lasagne)
83 |
84 | ## Citation
85 |
86 | ```
87 | @misc{pytorch_ard,
88 | author = {Artem Ryzhikov},
89 | title = {HolyBayes/pytorch_ard},
90 | url = {https://github.com/HolyBayes/pytorch_ard},
91 | year = {2018}
92 | }
93 | ```
94 |
95 | ## Contacts
96 |
97 | Artem Ryzhikov, LAMBDA laboratory, Higher School of Economics, Yandex School of Data Analysis
98 |
99 | **E-mail:** artemryzhikoff@yandex.ru
100 |
101 | **Linkedin:** https://www.linkedin.com/in/artem-ryzhikov-2b6308103/
102 |
103 | **Link:** https://www.hse.ru/org/persons/190912317
104 |
--------------------------------------------------------------------------------
/torch_ard/torch_ard.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Parameter
4 | import torch.nn.functional as F
5 | from functools import reduce
6 | import operator
7 |
8 |
class LinearARD(nn.Module):
    """
    Dense layer implementation with weights ARD-prior (arxiv:1701.05369)

    :param in_features: size of each input sample
    :param out_features: size of each output sample
    :param bias: if False, the layer has no additive bias
    :param thresh: weights whose log alpha is >= thresh are zeroed in eval mode
    :param ard_init: initial value of every element of ``log_sigma2``
    """

    def __init__(self, in_features, out_features, bias=True, thresh=3, ard_init=-10):
        super(LinearARD, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        self.thresh = thresh
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.ard_init = ard_init
        self.log_sigma2 = Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def forward(self, input):
        """
        Training mode: sample noisy pre-activations (mean plus std * N(0, 1) noise).
        Eval mode: deterministic linear map with the clipped weights.
        """
        if self.training:
            W_mu = F.linear(input, self.weight)
            # Per-output variance: input^2 @ (alpha * weight^2)^T
            alpha = torch.exp(self.log_alpha)
            W_var = F.linear(input.pow(2), alpha * self.weight.pow(2))
            W_std = torch.sqrt(W_var + 1e-15)

            output = W_mu + W_std * torch.randn_like(W_std)
            # `if self.bias:` would evaluate the tensor's truth value, which
            # raises for more than one element; test for presence instead.
            if self.bias is not None:
                output = output + self.bias
            return output
        # F.linear accepts bias=None, so layers built with bias=False work too.
        return F.linear(input, self.weights_clipped, self.bias)

    @property
    def weights_clipped(self):
        """Weights with every entry selected by the clip mask replaced by zero."""
        clip_mask = self.get_clip_mask()
        return torch.where(clip_mask, torch.zeros_like(self.weight), self.weight)

    def reset_parameters(self):
        """Initialise weight ~ N(0, 0.02), zero the bias, fill log_sigma2 with ard_init."""
        self.weight.data.normal_(0, 0.02)
        if self.bias is not None:
            self.bias.data.zero_()
        self.log_sigma2.data.fill_(self.ard_init)

    def get_clip_mask(self):
        """Boolean mask of weights whose log alpha is >= ``thresh``."""
        return torch.ge(self.log_alpha, self.thresh)

    def get_reg(self, **kwargs):
        """
        Get weights regularization (KL(q(w)||p(w)) approximation)

        :returns scalar tensor, summed over all weights of the layer
        """
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        C = -k1
        log_alpha = self.log_alpha
        mdkl = k1 * torch.sigmoid(k2 + k3 * log_alpha) - \
            0.5 * torch.log1p(torch.exp(-log_alpha)) + C
        return -torch.sum(mdkl)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

    def get_dropped_params_cnt(self):
        """
        Get number of dropped weights (with log alpha greater than or equal to "thresh" parameter)

        :returns number of dropped weights, as a numpy value
        """
        return self.get_clip_mask().sum().cpu().numpy()

    @property
    def log_alpha(self):
        """log alpha = log_sigma2 - 2 * log(|weight| + 1e-15), clamped to [-10, 10]."""
        log_alpha = self.log_sigma2 - 2 * \
            torch.log(torch.abs(self.weight) + 1e-15)
        return torch.clamp(log_alpha, -10, 10)
85 |
86 |
class Conv2dARD(nn.Conv2d):
    """
    2D convolution with an ARD prior on its kernel weights (arxiv:1701.05369).

    The layer is always built without a bias term. ``thresh`` is the log alpha
    level at or above which a weight is zeroed in eval mode; ``ard_init`` is the
    initial value of every element of ``log_sigma2``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, ard_init=-10, thresh=3):
        # Bias is disabled on purpose: the original author notes that
        # training goes to nan if bias = True.
        super(Conv2dARD, self).__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=False,
        )
        self.bias = None
        self.thresh = thresh
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ard_init = ard_init
        initial_log_sigma2 = ard_init * torch.ones_like(self.weight)
        self.log_sigma2 = Parameter(initial_log_sigma2)

    def forward(self, input):
        """
        Eval mode: plain convolution with the clipped kernel.
        Training mode: convolution mean plus std-scaled standard normal noise.
        """
        if not self.training:
            return F.conv2d(input, self.weights_clipped,
                            self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

        kernel = self.weight
        mean = F.conv2d(input, kernel, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

        # Variance of the pre-activations: conv(input^2, alpha * kernel^2)
        variance_kernel = torch.exp(self.log_alpha) * kernel * kernel
        variance = F.conv2d(input * input, variance_kernel, self.bias,
                            self.stride, self.padding, self.dilation,
                            self.groups)
        std = torch.sqrt(1e-15 + variance)

        noise = torch.normal(torch.zeros_like(mean), torch.ones_like(mean))
        return mean + std * noise

    @property
    def weights_clipped(self):
        """Kernel with every entry selected by the clip mask replaced by zero."""
        zeros = torch.zeros_like(self.weight)
        return torch.where(self.get_clip_mask(), zeros, self.weight)

    def get_clip_mask(self):
        """Boolean mask of weights whose log alpha is >= ``thresh``."""
        return torch.ge(self.log_alpha, self.thresh)

    def get_reg(self, **kwargs):
        """
        Weights regularization term (KL(q(w)||p(w)) approximation), summed
        over the whole kernel.
        """
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        offset = -k1
        la = self.log_alpha
        sigmoid_part = k1 * torch.sigmoid(k2 + k3 * la)
        log_part = 0.5 * torch.log1p(torch.exp(-la))
        neg_kl = sigmoid_part - log_part + offset
        return -torch.sum(neg_kl)

    def extra_repr(self):
        has_bias = self.bias is not None
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_channels, self.out_channels, has_bias
        )

    def get_dropped_params_cnt(self):
        """
        Count the weights selected by the clip mask (log alpha >= ``thresh``).

        :returns the count as a numpy value
        """
        dropped = self.get_clip_mask().sum()
        return dropped.cpu().numpy()

    @property
    def log_alpha(self):
        """log alpha = log_sigma2 - 2 * log(|weight| + 1e-15), clamped to [-8, 8]."""
        log_abs_weight = torch.log(torch.abs(self.weight) + 1e-15)
        return torch.clamp(self.log_sigma2 - 2 * log_abs_weight, -8, 8)
162 |
163 |
class ELBOLoss(nn.Module):
    """
    ELBO-style criterion: weighted data loss plus weighted ARD regularization.

    :param net: model whose ARD layers contribute the regularization term
    :param loss_fn: callable ``loss_fn(input, target)`` giving the data loss
    """

    def __init__(self, net, loss_fn):
        super(ELBOLoss, self).__init__()
        self.net = net
        self.loss_fn = loss_fn

    def forward(self, input, target, loss_weight=1., kl_weight=1.):
        """
        :param input: predictions passed to ``loss_fn``
        :param target: targets passed to ``loss_fn``; must not require grad
        :param loss_weight: multiplier of the data loss
        :param kl_weight: multiplier of the ARD regularization of ``net``
        :returns ``loss_weight * loss_fn(input, target) + kl_weight * reg``
        """
        assert not target.requires_grad
        # Estimate ELBO
        data_term = loss_weight * self.loss_fn(input, target)
        kl_term = kl_weight * get_ard_reg(self.net)
        return data_term + kl_term
175 |
176 |
def get_ard_reg(module):
    """
    Sum the ARD regularization of every LinearARD / Conv2dARD inside ``module``.

    :param module: model to evaluate ard regularization for
    :return: total regularization for module (0 when nothing contributes)
    """
    if isinstance(module, (LinearARD, Conv2dARD)):
        return module.get_reg()
    if not hasattr(module, 'children'):
        return 0
    total = 0
    for submodule in module.children():
        total = total + get_ard_reg(submodule)
    return total
188 |
189 |
def _get_dropped_params_cnt(module):
    """
    Recursively total the dropped-weight counts reported under ``module``.

    An object exposing ``get_dropped_params_cnt`` reports for itself; otherwise
    its ``children()`` are visited. Anything else contributes 0.
    """
    if hasattr(module, 'get_dropped_params_cnt'):
        return module.get_dropped_params_cnt()
    if not hasattr(module, 'children'):
        return 0
    dropped = 0
    for child in module.children():
        dropped = dropped + _get_dropped_params_cnt(child)
    return dropped
196 |
197 |
def _get_params_cnt(module):
    """
    Recursively count weights under ``module``.

    An ARD layer contributes the product of its ``weight`` shape. Any other
    object with ``children()`` contributes the sum over its children. An object
    without ``children`` falls back to the element count of its parameters.
    """
    if isinstance(module, (LinearARD, Conv2dARD)):
        count = 1
        for dim in module.weight.shape:
            count = count * dim
        return count
    if hasattr(module, 'children'):
        total = 0
        for submodule in module.children():
            total = total + _get_params_cnt(submodule)
        return total
    return sum(p.numel() for p in module.parameters())
205 |
206 |
def get_dropped_params_ratio(model):
    """
    Fraction of dropped weights in ``model``.

    :returns dropped-weight count divided by the weight count, as a float
    """
    dropped = _get_dropped_params_cnt(model)
    total = _get_params_cnt(model)
    return dropped * 1.0 / total
209 |
--------------------------------------------------------------------------------