├── requirements.txt ├── .gitignore ├── torch_ard ├── __init__.py └── torch_ard.py ├── setup.py ├── LICENSE ├── examples ├── boston │ ├── boston_baseline.py │ └── boston_ard.py ├── cifar │ ├── cifar_baseline.py │ └── cifar_ard.py ├── models.py └── mnist │ ├── mnist_baseline.py │ └── mnist_ard.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.1.0 2 | torchvision>=0.2.1 3 | scikit-learn>=0.19.2 4 | pandas -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | examples/*/data/** 3 | examples/*/checkpoint/* 4 | __pycache__ 5 | dist 6 | .idea 7 | -------------------------------------------------------------------------------- /torch_ard/__init__.py: -------------------------------------------------------------------------------- 1 | _author__ = 'Artem Ryzhikov' 2 | __version__ = '0.2.4' 3 | __all__ = ['LinearARD', 'Conv2dARD', 'get_ard_reg', 'get_dropped_params_ratio', 'ELBOLoss'] 4 | 5 | from .torch_ard import LinearARD, Conv2dARD, get_ard_reg, get_dropped_params_ratio, ELBOLoss 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Artem Ryzhikov' 2 | 3 | from setuptools import setup 4 | 5 | 6 | 7 | setup( 8 | name="pytorch_ard", 9 | version='0.2.4', 10 | description="Make your PyTorch faster", 11 | long_description=open('README.md', encoding='utf-8').read(), 12 | long_description_content_type='text/markdown', 13 | url='https://github.com/HolyBayes/pytorch_ard', 14 | author='Artem Ryzhikov', 15 | 16 | packages=['torch_ard'], 17 | 18 | classifiers=[ 19 | 'Intended Audience :: Science/Research', 20 | 'Programming Language :: Python :: 3 ', 21 | ], 22 | keywords='pytorch, bayesian neural networks, ard, deep learning, neural 
networks, machine learning', 23 | install_requires=[ 24 | 'torch>=1.1.0', 25 | 'torchvision>=0.2.1', 26 | 'scikit-learn>=0.19.2', 27 | 'pandas' 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Artem Ryzhikov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /examples/boston/boston_baseline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | from models import DenseModel 4 | from sklearn.datasets import load_boston 5 | from sklearn.model_selection import train_test_split 6 | import pandas as pd 7 | from torch import nn 8 | import torch 9 | import numpy as np 10 | from tqdm import tqdm, trange 11 | 12 | boston = load_boston() 13 | df = pd.DataFrame(boston.data, columns=boston.feature_names) 14 | df['PRICE'] = boston.target 15 | X, y = df.drop('PRICE', 1), df['PRICE'] 16 | 17 | train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.8) 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | train_X, test_X, train_y, test_y = \ 21 | [torch.from_numpy(np.array(x)).float().to(device) 22 | for x in [train_X, test_X, train_y, test_y]] 23 | 24 | model = DenseModel(input_shape=train_X.shape[1], output_shape=1, 25 | activation=nn.functional.relu).to(device) 26 | opt = torch.optim.Adam(model.parameters(), lr=1e-2) 27 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min') 28 | criterion = nn.MSELoss() 29 | 30 | n_epoches = 100000 31 | debug_frequency = 100 32 | 33 | pbar = trange(n_epoches, leave=True, position=0) 34 | for epoch in pbar: 35 | opt.zero_grad() 36 | preds = model(train_X).squeeze() 37 | loss = criterion(preds, train_y) 38 | loss.backward() 39 | # scheduler.step(loss) 40 | opt.step() 41 | loss_train = float(criterion(preds, train_y).detach().cpu().numpy()) 42 | preds = model(test_X).squeeze() 43 | loss_test = float(criterion(preds, test_y).detach().cpu().numpy()) 44 | pbar.set_description('MSE (train): %.3f\tMSE (test): %.3f' % 45 | (loss_train, loss_test)) 46 | pbar.update() 47 | -------------------------------------------------------------------------------- /examples/boston/boston_ard.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | from models import DenseModelARD 4 | from sklearn.datasets import load_boston 5 | from sklearn.model_selection import train_test_split 6 | import pandas as pd 7 | from torch import nn 8 | import torch.nn.functional as F 9 | import torch 10 | import numpy as np 11 | from torch_ard import get_ard_reg, get_dropped_params_ratio, ELBOLoss 12 | from tqdm import trange, tqdm 13 | 14 | boston = load_boston() 15 | df = pd.DataFrame(boston.data, columns=boston.feature_names) 16 | df['PRICE'] = boston.target 17 | X, y = df.drop('PRICE', 1), df['PRICE'] 18 | 19 | train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.8) 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | # device = torch.device('cpu') 23 | train_X, test_X, train_y, test_y = \ 24 | [torch.from_numpy(np.array(x)).float().to(device) 25 | for x in [train_X, test_X, train_y, test_y]] 26 | 27 | model = DenseModelARD(input_shape=train_X.shape[1], output_shape=1, 28 | activation=nn.functional.relu).to(device) 29 | opt = torch.optim.Adam(model.parameters(), lr=1e-3) 30 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min') 31 | criterion = ELBOLoss(model, F.mse_loss).to(device) 32 | 33 | n_epoches = 100000 34 | debug_frequency = 100 35 | def get_kl_weight(epoch): return min(1, 2 * epoch / n_epoches) 36 | 37 | 38 | pbar = trange(n_epoches, leave=True, position=0) 39 | for epoch in pbar: 40 | kl_weight = get_kl_weight(epoch) 41 | opt.zero_grad() 42 | preds = model(train_X).squeeze() 43 | loss = criterion(preds, train_y, 1, kl_weight) 44 | loss.backward() 45 | opt.step() 46 | loss_train = float( 47 | criterion(preds, train_y, 1, 0).detach().cpu().numpy()) 48 | preds = model(test_X).squeeze() 49 | loss_test = float( 50 | criterion(preds, test_y, 1, 0).detach().cpu().numpy()) 51 | pbar.set_description('MSE (train): %.3f\tMSE (test): 
%.3f\tReg: %.3f\tDropout rate: %f%%' % ( 52 | loss_train, loss_test, get_ard_reg(model).item(), 100 * get_dropped_params_ratio(model))) 53 | pbar.update() 54 | -------------------------------------------------------------------------------- /examples/cifar/cifar_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | 11 | import os 12 | import sys 13 | sys.path.append('../') 14 | 15 | from models import LeNet 16 | 17 | 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | ckpt_file = 'checkpoint/ckpt_baseline.t7' 21 | 22 | best_acc = 0 # best test accuracy 23 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 24 | 25 | # Data 26 | print('==> Preparing data..') 27 | transform_train = transforms.Compose([ 28 | transforms.RandomCrop(32, padding=4), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor(), 31 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 32 | ]) 33 | 34 | transform_test = transforms.Compose([ 35 | transforms.ToTensor(), 36 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 37 | ]) 38 | 39 | trainset = torchvision.datasets.CIFAR10( 40 | root='./data', train=True, download=True, transform=transform_train) 41 | trainloader = torch.utils.data.DataLoader( 42 | trainset, batch_size=128, shuffle=True, num_workers=2) 43 | 44 | testset = torchvision.datasets.CIFAR10( 45 | root='./data', train=False, download=True, transform=transform_test) 46 | testloader = torch.utils.data.DataLoader( 47 | testset, batch_size=100, shuffle=False, num_workers=2) 48 | 49 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 50 | 'dog', 'frog', 'horse', 'ship', 'truck') 51 | 52 | # Model 
53 | print('==> Building model..') 54 | model = LeNet(3, len(classes)).to(device) 55 | 56 | if os.path.isfile(ckpt_file): 57 | checkpoint = torch.load(ckpt_file) 58 | model.load_state_dict(checkpoint['net']) 59 | best_acc = checkpoint['acc'] 60 | start_epoch = checkpoint['epoch'] 61 | 62 | criterion = nn.CrossEntropyLoss() 63 | optimizer = optim.SGD(model.parameters(), lr=1e-3, 64 | momentum=0.9) 65 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') 66 | 67 | 68 | # Training 69 | def train(epoch): 70 | print('\nEpoch: %d' % epoch) 71 | model.train() 72 | train_loss = [] 73 | correct = 0 74 | total = 0 75 | for batch_idx, (inputs, targets) in enumerate(trainloader): 76 | inputs, targets = inputs.to(device), targets.to(device) 77 | optimizer.zero_grad() 78 | outputs = model(inputs) 79 | loss = criterion(outputs, targets) 80 | loss.backward() 81 | # scheduler.step(loss) 82 | optimizer.step() 83 | 84 | train_loss.append(loss.item()) 85 | _, predicted = outputs.max(1) 86 | total += targets.size(0) 87 | correct += predicted.eq(targets).sum().item() 88 | print('Train loss: %.2f' % np.mean(train_loss)) 89 | print('Train accuracy: %.2f%%' % (correct * 100.0 / total)) 90 | 91 | 92 | def test(epoch): 93 | global best_acc 94 | model.eval() 95 | test_loss = [] 96 | correct = 0 97 | total = 0 98 | with torch.no_grad(): 99 | for batch_idx, (inputs, targets) in enumerate(testloader): 100 | inputs, targets = inputs.to(device), targets.to(device) 101 | outputs = model(inputs) 102 | loss = criterion(outputs, targets) 103 | 104 | test_loss.append(loss.item()) 105 | _, predicted = outputs.max(1) 106 | total += targets.size(0) 107 | correct += predicted.eq(targets).sum().item() 108 | 109 | # Save checkpoint. 110 | acc = 100. 
* correct / total 111 | print('Test loss: %.2f' % np.mean(test_loss)) 112 | print('Test accuracy: %.2f%%' % acc) 113 | if acc > best_acc: 114 | print('Saving..') 115 | state = { 116 | 'net': model.state_dict(), 117 | 'acc': acc, 118 | 'epoch': epoch, 119 | } 120 | if not os.path.isdir('checkpoint'): 121 | os.mkdir('checkpoint') 122 | torch.save(state, ckpt_file) 123 | best_acc = acc 124 | 125 | 126 | for epoch in range(start_epoch, start_epoch + 200): 127 | train(epoch) 128 | test(epoch) 129 | -------------------------------------------------------------------------------- /examples/models.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../') 3 | import torch_ard as nn_ard 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import torch 7 | 8 | class DenseModelARD(nn.Module): 9 | def __init__(self, input_shape, output_shape, hidden_size=150, activation=None): 10 | super(DenseModelARD, self).__init__() 11 | self.l1 = nn_ard.LinearARD(input_shape, hidden_size) 12 | self.l2 = nn_ard.LinearARD(hidden_size, output_shape) 13 | self.activation = activation 14 | self._init_weights() 15 | 16 | def forward(self, input): 17 | x = input.to(self.device) 18 | x = self.l1(x) 19 | x = nn.functional.tanh(x) 20 | x = self.l2(x) 21 | if self.activation: x = self.activation(x) 22 | return x 23 | 24 | def _init_weights(self): 25 | for layer in self.children(): 26 | if hasattr(layer, 'weight'): nn.init.xavier_uniform(layer.weight, gain=nn.init.calculate_gain('relu')) 27 | 28 | @property 29 | def device(self): 30 | return next(self.parameters()).device 31 | 32 | 33 | class DenseModel(nn.Module): 34 | def __init__(self, input_shape, output_shape, hidden_size=150, activation=None): 35 | super(DenseModel, self).__init__() 36 | self.l1 = nn.Linear(input_shape, hidden_size) 37 | self.l2 = nn.Linear(hidden_size, output_shape) 38 | self.activation = activation 39 | self._init_weights() 40 | 41 | def 
forward(self, input): 42 | x = input.to(self.device) 43 | x = self.l1(x) 44 | x = nn.functional.tanh(x) 45 | x = self.l2(x) 46 | if self.activation: x = self.activation(x) 47 | return x 48 | 49 | def _init_weights(self): 50 | for layer in self.children(): 51 | if hasattr(layer, 'weight'): nn.init.xavier_uniform(layer.weight, gain=nn.init.calculate_gain('relu')) 52 | 53 | @property 54 | def device(self): 55 | return next(self.parameters()).device 56 | 57 | class LeNet(nn.Module): 58 | def __init__(self, input_shape, output_shape): 59 | super(LeNet, self).__init__() 60 | self.conv1 = nn.Conv2d(input_shape, 20, 5) 61 | self.conv2 = nn.Conv2d(20, 50, 5) 62 | self.l1 = nn.Linear(50*5*5, 500) 63 | self.l2 = nn.Linear(500, output_shape) 64 | self._init_weights() 65 | 66 | def forward(self, x): 67 | out = F.relu(self.conv1(x.to(self.device))) 68 | out = F.max_pool2d(out, 2) 69 | out = F.relu(self.conv2(out)) 70 | out = F.max_pool2d(out, 2) 71 | out = out.view(out.shape[0], -1) 72 | out = F.relu(self.l1(out)) 73 | return self.l2(out) 74 | # return F.log_softmax(self.l2(out), dim=1) 75 | 76 | def _init_weights(self): 77 | for layer in self.children(): 78 | if hasattr(layer, 'weight'): nn.init.xavier_uniform(layer.weight, gain=nn.init.calculate_gain('relu')) 79 | 80 | @property 81 | def device(self): 82 | return next(self.parameters()).device 83 | 84 | class LeNetARD(nn.Module): 85 | def __init__(self, input_shape, output_shape): 86 | super(LeNetARD, self).__init__() 87 | self.conv1 = nn_ard.Conv2dARD(input_shape, 20, 5) 88 | self.conv2 = nn_ard.Conv2dARD(20, 50, 5) 89 | self.l1 = nn_ard.LinearARD(50*5*5, 500) 90 | self.l2 = nn_ard.LinearARD(500, output_shape) 91 | self._init_weights() 92 | 93 | def forward(self, input): 94 | out = F.relu(self.conv1(input.to(self.device))) 95 | out = F.max_pool2d(out, 2) 96 | out = F.relu(self.conv2(out)) 97 | out = F.max_pool2d(out, 2) 98 | out = out.view(out.shape[0], -1) 99 | out = F.relu(self.l1(out)) 100 | return self.l2(out) 101 | # 
return F.log_softmax(self.l2(out), dim=1) 102 | 103 | def _init_weights(self): 104 | for layer in self.children(): 105 | if hasattr(layer, 'weight'): nn.init.xavier_uniform(layer.weight, gain=nn.init.calculate_gain('relu')) 106 | 107 | @property 108 | def device(self): 109 | return next(self.parameters()).device 110 | 111 | 112 | class LeNet_MNIST(LeNet): 113 | def __init__(self, input_shape, output_shape): 114 | super(LeNet_MNIST, self).__init__(input_shape, output_shape) 115 | self.l1 = nn.Linear(50*4*4, 500) 116 | super(LeNet_MNIST, self)._init_weights() 117 | 118 | class LeNetARD_MNIST(LeNetARD): 119 | def __init__(self, input_shape, output_shape): 120 | super(LeNetARD_MNIST, self).__init__(input_shape, output_shape) 121 | self.l1 = nn_ard.LinearARD(50*4*4, 500) 122 | super(LeNetARD_MNIST, self)._init_weights() 123 | -------------------------------------------------------------------------------- /examples/mnist/mnist_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | 7 | from torchvision import datasets 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | 11 | import os 12 | import sys 13 | sys.path.append('../') 14 | import time 15 | 16 | from models import LeNet_MNIST 17 | 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | ckpt_file = 'checkpoint/ckpt_baseline.t7' 22 | 23 | best_acc = 0 # best test accuracy 24 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 25 | 26 | # Data 27 | print('==> Preparing data..') 28 | transform_train = transforms.Compose([ 29 | transforms.RandomCrop(32, padding=4), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.ToTensor(), 32 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 33 | ]) 34 | 35 | transform_test = 
transforms.Compose([ 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 38 | ]) 39 | 40 | trainset = datasets.MNIST('./data', train=True, download=True, 41 | transform=transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.1307,), (0.3081,)) 44 | ])) 45 | trainloader = torch.utils.data.DataLoader( 46 | trainset, batch_size=128, shuffle=True) 47 | 48 | testset = datasets.MNIST('./data', train=False, transform=transforms.Compose([ 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.1307,), (0.3081,)) 51 | ])) 52 | testloader = torch.utils.data.DataLoader( 53 | testset, batch_size=1000, shuffle=True) 54 | 55 | 56 | n_classes = 10 57 | 58 | # Model 59 | print('==> Building model..') 60 | model = LeNet_MNIST(1, n_classes).to(device) 61 | 62 | if device.type == 'cuda': 63 | model = torch.nn.DataParallel(model) 64 | cudnn.benchmark = True 65 | 66 | if os.path.isfile(ckpt_file): 67 | checkpoint = torch.load(ckpt_file) 68 | model.load_state_dict(checkpoint['net']) 69 | best_acc = checkpoint['acc'] 70 | start_epoch = checkpoint['epoch'] 71 | 72 | criterion = nn.CrossEntropyLoss() 73 | optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9) 74 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') 75 | 76 | # Training 77 | 78 | 79 | def train(epoch): 80 | print('\nEpoch: %d' % epoch) 81 | model.train() 82 | train_loss = [] 83 | correct = 0 84 | total = 0 85 | for batch_idx, (inputs, targets) in enumerate(trainloader): 86 | inputs, targets = inputs.to(device), targets.to(device) 87 | optimizer.zero_grad() 88 | outputs = model(inputs) 89 | loss = criterion(outputs, targets) 90 | loss.backward() 91 | # scheduler.step(loss) 92 | optimizer.step() 93 | 94 | train_loss.append(loss.item()) 95 | _, predicted = outputs.max(1) 96 | total += targets.size(0) 97 | correct += predicted.eq(targets).sum().item() 98 | print('Train loss: %.3f' % np.mean(train_loss)) 99 | print('Train 
accuracy: %.3f%%' % (correct * 100.0 / total)) 100 | 101 | 102 | def test(epoch): 103 | global best_acc 104 | model.eval() 105 | test_loss = [] 106 | correct = 0 107 | total = 0 108 | inference_time_seconds = 0 109 | with torch.no_grad(): 110 | for batch_idx, (inputs, targets) in enumerate(testloader): 111 | inputs, targets = inputs.to(device), targets.to(device) 112 | start_ts = time.time() 113 | outputs = model(inputs) 114 | inference_time_seconds += time.time() - start_ts 115 | loss = criterion(outputs, targets) 116 | 117 | test_loss.append(loss.item()) 118 | _, predicted = outputs.max(1) 119 | total += targets.size(0) 120 | correct += predicted.eq(targets).sum().item() 121 | 122 | # Save checkpoint. 123 | acc = 100. * correct / total 124 | print('Test loss: %.3f' % np.mean(test_loss)) 125 | print('Test accuracy: %.3f%%' % acc) 126 | print('Inference time: %.2f seconds' % inference_time_seconds) 127 | if acc > best_acc: 128 | print('Saving..') 129 | state = { 130 | 'net': model.state_dict(), 131 | 'acc': acc, 132 | 'epoch': epoch, 133 | } 134 | if not os.path.isdir('checkpoint'): 135 | os.mkdir('checkpoint') 136 | torch.save(state, ckpt_file) 137 | best_acc = acc 138 | 139 | 140 | for epoch in range(start_epoch, start_epoch + 100): 141 | train(epoch) 142 | test(epoch) 143 | -------------------------------------------------------------------------------- /examples/cifar/cifar_ard.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | 11 | import os 12 | import sys 13 | sys.path.append('../') 14 | 15 | from models import LeNetARD 16 | from torch_ard import get_ard_reg, get_dropped_params_ratio, ELBOLoss 17 | 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else 
"cpu") 20 | 21 | ckpt_baseline_file = 'checkpoint/ckpt_baseline.t7' 22 | ckpt_file = 'checkpoint/ckpt_ard.t7' 23 | 24 | best_acc = 0 # best test accuracy 25 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 26 | reg_factor = 1e-5 27 | 28 | # Data 29 | print('==> Preparing data..') 30 | transform_train = transforms.Compose([ 31 | transforms.RandomCrop(32, padding=4), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.ToTensor(), 34 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 35 | ]) 36 | 37 | transform_test = transforms.Compose([ 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 40 | ]) 41 | 42 | trainset = torchvision.datasets.CIFAR10( 43 | root='./data', train=True, download=True, transform=transform_train) 44 | trainloader = torch.utils.data.DataLoader( 45 | trainset, batch_size=128, shuffle=True, num_workers=2) 46 | 47 | testset = torchvision.datasets.CIFAR10( 48 | root='./data', train=False, download=True, transform=transform_test) 49 | testloader = torch.utils.data.DataLoader( 50 | testset, batch_size=100, shuffle=False, num_workers=2) 51 | 52 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 53 | 'dog', 'frog', 'horse', 'ship', 'truck') 54 | 55 | # Model 56 | print('==> Building model..') 57 | model = LeNetARD(3, len(classes)).to(device) 58 | 59 | 60 | if os.path.isfile(ckpt_file): 61 | state_dict = model.state_dict() 62 | checkpoint = torch.load(ckpt_file) 63 | state_dict.update(checkpoint['net']) 64 | model.load_state_dict(state_dict) 65 | best_acc = checkpoint['acc'] 66 | start_epoch = checkpoint['epoch'] 67 | elif os.path.isfile(ckpt_baseline_file): 68 | state_dict = model.state_dict() 69 | checkpoint = torch.load(ckpt_baseline_file) 70 | state_dict.update(checkpoint['net']) 71 | model.load_state_dict(state_dict, strict=False) 72 | best_acc = checkpoint['acc'] 73 | start_epoch = checkpoint['epoch'] 74 | 75 | criterion = ELBOLoss(model, 
F.cross_entropy).to(device) 76 | optimizer = optim.SGD(model.parameters(), lr=1e-3, 77 | momentum=0.9) 78 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') 79 | n_epoches = 200 80 | def get_kl_weight(epoch): return min(1, 1e-4 * epoch / n_epoches) 81 | 82 | 83 | # Training 84 | def train(epoch): 85 | print('\nEpoch: %d' % epoch) 86 | kl_weight = get_kl_weight(epoch) 87 | model.train() 88 | train_loss = [] 89 | correct = 0 90 | total = 0 91 | for batch_idx, (inputs, targets) in enumerate(trainloader): 92 | inputs, targets = inputs.to(device), targets.to(device) 93 | optimizer.zero_grad() 94 | outputs = model(inputs) 95 | loss = criterion(outputs, targets, 1, kl_weight) 96 | loss.backward() 97 | 98 | # scheduler.step(loss) 99 | optimizer.step() 100 | 101 | train_loss.append(loss.item()) 102 | _, predicted = outputs.max(1) 103 | total += targets.size(0) 104 | correct += predicted.eq(targets).sum().item() 105 | print('Train loss: %.2f' % np.mean(train_loss)) 106 | print('Train accuracy: %.2f%%' % (correct * 100.0 / total)) 107 | 108 | 109 | def test(epoch): 110 | global best_acc 111 | model.eval() 112 | test_loss = [] 113 | correct = 0 114 | total = 0 115 | with torch.no_grad(): 116 | for batch_idx, (inputs, targets) in enumerate(testloader): 117 | inputs, targets = inputs.to(device), targets.to(device) 118 | outputs = model(inputs) 119 | loss = criterion(outputs, targets, 1, 0) 120 | 121 | test_loss.append(loss.item()) 122 | _, predicted = outputs.max(1) 123 | total += targets.size(0) 124 | correct += predicted.eq(targets).sum().item() 125 | 126 | # Save checkpoint. 127 | acc = 100. * correct / total 128 | print('Test loss: %.2f' % np.mean(test_loss)) 129 | print('Test accuracy: %.2f%%' % acc) 130 | print('Compression: %.2f%%' % (100. 
* get_dropped_params_ratio(model))) 131 | if acc > best_acc: 132 | print('Saving..') 133 | state = { 134 | 'net': model.state_dict(), 135 | 'acc': acc, 136 | 'epoch': epoch, 137 | } 138 | if not os.path.isdir('checkpoint'): 139 | os.mkdir('checkpoint') 140 | torch.save(state, ckpt_file) 141 | best_acc = acc 142 | 143 | 144 | for epoch in range(start_epoch, start_epoch + n_epoches): 145 | train(epoch) 146 | test(epoch) 147 | -------------------------------------------------------------------------------- /examples/mnist/mnist_ard.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | 7 | from torchvision import datasets 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | 11 | import os 12 | import sys 13 | sys.path.append('../') 14 | import time 15 | 16 | from models import LeNetARD_MNIST 17 | from torch_ard import get_ard_reg, get_dropped_params_ratio, ELBOLoss 18 | 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | ckpt_baseline_file = 'checkpoint/ckpt_baseline.t7' 23 | ckpt_file = 'checkpoint/ckpt_ard.t7' 24 | 25 | best_acc = 0 # best test accuracy 26 | best_compression = 0 27 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 28 | reg_factor = 1e-5 29 | 30 | # Data 31 | print('==> Preparing data..') 32 | transform_train = transforms.Compose([ 33 | transforms.RandomCrop(32, padding=4), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 37 | ]) 38 | 39 | transform_test = transforms.Compose([ 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 42 | ]) 43 | 44 | trainset = datasets.MNIST('./data', train=True, download=True, 45 | transform=transforms.Compose([ 
46 | transforms.ToTensor(), 47 | transforms.Normalize((0.1307,), (0.3081,)) 48 | ])) 49 | trainloader = torch.utils.data.DataLoader( 50 | trainset, batch_size=128, shuffle=True) 51 | 52 | testset = datasets.MNIST('./data', train=False, transform=transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.1307,), (0.3081,)) 55 | ])) 56 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True) 57 | 58 | n_classes = 10 59 | 60 | # Model 61 | print('==> Building model..') 62 | model = LeNetARD_MNIST(1, n_classes).to(device) 63 | 64 | 65 | if os.path.isfile(ckpt_file): 66 | state_dict = model.state_dict() 67 | checkpoint = torch.load(ckpt_file) 68 | state_dict.update(checkpoint['net']) 69 | model.load_state_dict(state_dict) 70 | best_acc = checkpoint['acc'] 71 | start_epoch = checkpoint['epoch'] 72 | elif os.path.isfile(ckpt_baseline_file): 73 | state_dict = model.state_dict() 74 | checkpoint = torch.load(ckpt_baseline_file) 75 | state_dict.update(checkpoint['net']) 76 | model.load_state_dict(state_dict, strict=False) 77 | best_acc = checkpoint['acc'] 78 | start_epoch = checkpoint['epoch'] 79 | 80 | 81 | criterion = ELBOLoss(model, F.cross_entropy).to(device) 82 | optimizer = optim.SGD(model.parameters(), lr=1e-3, 83 | momentum=0.9) 84 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') 85 | n_epoches = 100 86 | 87 | def get_kl_weight(epoch): return min(1, 1e-2 * epoch / n_epoches) 88 | 89 | # Training 90 | 91 | 92 | def train(epoch): 93 | print('\nEpoch: %d' % epoch) 94 | kl_weight = get_kl_weight(epoch) 95 | model.train() 96 | train_loss = [] 97 | correct = 0 98 | total = 0 99 | for batch_idx, (inputs, targets) in enumerate(trainloader): 100 | inputs, targets = inputs.to(device), targets.to(device) 101 | optimizer.zero_grad() 102 | outputs = model(inputs) 103 | loss = criterion(outputs, targets, 1, kl_weight) 104 | loss.backward() 105 | # scheduler.step(loss) 106 | optimizer.step() 107 | 108 | 
train_loss.append(loss.item()) 109 | _, predicted = outputs.max(1) 110 | total += targets.size(0) 111 | correct += predicted.eq(targets).sum().item() 112 | print('Train loss: %.3f' % np.mean(train_loss)) 113 | print('Train accuracy: %.3f%%' % (correct * 100.0 / total)) 114 | 115 | 116 | def test(epoch): 117 | global best_acc 118 | global best_compression 119 | model.eval() 120 | test_loss = [] 121 | correct = 0 122 | total = 0 123 | inference_time_seconds = 0 124 | with torch.no_grad(): 125 | for batch_idx, (inputs, targets) in enumerate(testloader): 126 | inputs, targets = inputs.to(device), targets.to(device) 127 | start_ts = time.time() 128 | outputs = model(inputs) 129 | inference_time_seconds += time.time() - start_ts 130 | loss = criterion(outputs, targets) 131 | 132 | test_loss.append(loss.item()) 133 | _, predicted = outputs.max(1) 134 | total += targets.size(0) 135 | correct += predicted.eq(targets).sum().item() 136 | 137 | # Save checkpoint. 138 | acc = 100. * correct / total 139 | compression = 100. 
* get_dropped_params_ratio(model) 140 | print('Test loss: %.3f' % np.mean(test_loss)) 141 | print('Test accuracy: %.3f%%' % acc) 142 | print('Compression: %.2f%%' % compression) 143 | print('Inference time: %.2f seconds' % inference_time_seconds) 144 | # if acc > best_acc: 145 | if compression > best_compression: 146 | print('Saving..') 147 | state = { 148 | 'net': model.state_dict(), 149 | 'acc': acc, 150 | 'epoch': epoch, 151 | 'compression': compression 152 | } 153 | if not os.path.isdir('checkpoint'): 154 | os.mkdir('checkpoint') 155 | torch.save(state, ckpt_file) 156 | # best_acc = acc 157 | best_compression = compression 158 | 159 | 160 | for epoch in range(start_epoch, start_epoch + n_epoches): 161 | test(epoch) 162 | train(epoch) 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational Dropout Sparsifies NN (Pytorch) 2 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](LICENSE) 3 | [![PyPI version](https://badge.fury.io/py/pytorch-ard.svg)](https://badge.fury.io/py/pytorch-ard) 4 | 5 | 6 | Make your neural network 300 times faster! 7 | 8 | Pytorch implementation of Variational Dropout Sparsifies Deep Neural Networks ([arxiv:1701.05369](https://arxiv.org/abs/1701.05369)). 9 | 10 | ## Description 11 | The discovered approach helps to train both convolutional and dense deep sparsified models without significant loss of quality. Additive Noise Reparameterization 12 | and the Local Reparameterization Trick discovered in the paper helps to eliminate weights prior's restrictions () and achieve Automatic Relevance Determination (ARD) effect on (typically most) network's parameters. According to the original paper, authors reduced the number of parameters up to 280 times on LeNet architectures and up to 68 times on VGG-like networks with a negligible decrease of accuracy. 
Experiments with the Boston dataset in this repository show that 99% of a simple dense model's parameters can be dropped using the paper's ARD prior without a significant loss of MSE. This technique also significantly reduces overfitting and frees you from worrying about the model's complexity: all redundant parameters are dropped automatically. Moreover, any degree of regularization can be achieved by varying the regularization factor, i.e. the weight of the KL term (see `get_kl_weight` and the `kl_weight` argument passed to ***ELBOLoss*** in the [boston_ard.py](examples/boston/boston_ard.py) and [cifar_ard.py](examples/cifar/cifar_ard.py) scripts). 13 | 14 | ## Usage 15 | 16 | ```python 17 | import torch_ard as nn_ard 18 | from torch import nn 19 | import torch.nn.functional as F 20 | 21 | input_size, hidden_size, output_size = 60, 150, 1 22 | 23 | model = nn.Sequential( 24 | nn_ard.LinearARD(input_size, hidden_size), 25 | nn.ReLU(), 26 | nn_ard.LinearARD(hidden_size, output_size) 27 | ) 28 | 29 | 30 | criterion = nn_ard.ELBOLoss(model, F.cross_entropy) 31 | print('Sparsification ratio: %.3f%%' % (100.*nn_ard.get_dropped_params_ratio(model))) 32 | 33 | # test stage 34 | model.eval() # Needed for speed-up 35 | model(input) 36 | ``` 37 | 38 | ## Installation 39 | 40 | ``` 41 | pip install git+https://github.com/HolyBayes/pytorch_ard 42 | ``` 43 | 44 | ## Experiments 45 | 46 | All experiments are placed in the [examples](examples/) folder and contain comparisons of the baseline and ARD models. 47 | 48 | ### Boston dataset 49 | 50 | Two scripts were used in the experiment: [boston_baseline.py](examples/boston/boston_baseline.py) and [boston_ard.py](examples/boston/boston_ard.py). The training procedure for each experiment was **100000 epochs with Adam** (lr=1e-3 in boston_ard.py; boston_baseline.py currently uses lr=1e-2). The baseline model was a dense neural network with a single hidden layer of size 150.
51 | 52 | | | Baseline (nn.Linear) | LinearARD, no reg | LinearARD, reg=0.0001 | LinearARD, reg=0.001 | LinearARD, reg=0.1 | LinearARD, reg=1 | 53 | |----------------|----------|-------------|-----------------|----------------|--------------|------------| 54 | | MSE (train) | 1.751 | 1.626 | 1.587 | 1.962 | 17.167 | 33.682 | 55 | | MSE (test) | 22.580 | 16.229 | 15.957 | 8.416 | 25.695 | 30.231 | 56 | | Compression, % | 0 | 0.38 | 52.95 | 64.19 | 97.29 | 99.29 | 57 | 58 | The table above shows that any degree of compression can be achieved by varying the regularization factor (for example, ~99.29% of connections are dropped with reg_factor=1). Moreover, training LinearARD layers with a suitable regularization factor (like reg=0.001 in the table above) not only significantly reduces the number of model parameters (>64% of parameters are dropped after training) but also significantly improves test quality by reducing overfitting. 59 | 60 | ## Tips 61 | 62 | 1. Despite the high performance of the implemented layers in "end-to-end" mode, the authors recommend using them to fine-tune models that were pretrained without the ARD prior; this gives the best performance. It is also faster: although these layers converge at a comparable speed, each training epoch takes about twice as long (the \*ARD implementations have ~2 times more parameters). This is easy to explain: applying the ARD prior in the early stages of training can drop useful connections whose dependencies are not yet apparent. 63 | 2. Sparsification alone gives almost no speed-up until you convert the model to a sparse one!
# NOTE(review): in the collapsed repository dump this span also carried the
# tail of README.md and the header of torch_ard/torch_ard.py; that text is
# kept verbatim below as comments, one dump line per comment line.
# (*TODO*)
# 64 |
# 65 |
# 66 | ## Requirements
# 67 | * **PyTorch** >= 0.4.0
# 68 | * **SkLearn** >= 0.19.1
# 69 | * **Pandas** >= 0.23.3
# 70 | * **Numpy** >= 1.14.5
# 71 |
# 72 | ## Authors
# 73 |
# 74 | ```
# 75 | @article{molchanov2017variational,
# 76 | title={Variational Dropout Sparsifies Deep Neural Networks},
# 77 | author={Molchanov, Dmitry and Ashukha, Arsenii and Vetrov, Dmitry},
# 78 | journal={arXiv preprint arXiv:1701.05369},
# 79 | year={2017}
# 80 | }
# 81 | ```
# 82 | [Original implementation](https://github.com/ars-ashuha/variational-dropout-sparsifies-dnn) (Theano/Lasagne)
# 83 |
# 84 | ## Citation
# 85 |
# 86 | ```
# 87 | @misc{pytorch_ard,
# 88 | author = {Artem Ryzhikov},
# 89 | title = {HolyBayes/pytorch_ard},
# 90 | url = {https://github.com/HolyBayes/pytorch_ard},
# 91 | year = {2018}
# 92 | }
# 93 | ```
# 94 |
# 95 | ## Contacts
# 96 |
# 97 | Artem Ryzhikov, LAMBDA laboratory, Higher School of Economics, Yandex School of Data Analysis
# 98 |
# 99 | **E-mail:** artemryzhikoff@yandex.ru
# 100 |
# 101 | **Linkedin:** https://www.linkedin.com/in/artem-ryzhikov-2b6308103/
# 102 |
# 103 | **Link:** https://www.hse.ru/org/persons/190912317
# 104 |
# --------------------------------------------------------------------------------
# /torch_ard/torch_ard.py:
# --------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from functools import reduce
import operator


class LinearARD(nn.Module):
    """
    Dense layer implementation with weights ARD-prior (arxiv:1701.05369)

    Every weight has a learnable mean (``weight``) and a learnable
    log-variance (``log_sigma2``); ``log_alpha = log_sigma2 - 2 * log|weight|``
    measures how noisy each weight is.  In training mode the output is
    sampled with the local reparameterization trick; in eval mode weights
    with ``log_alpha >= thresh`` are zeroed and the layer is deterministic.

    :param in_features: size of the last input dimension
    :param out_features: size of the last output dimension
    :param bias: if False, no additive bias is created
    :param thresh: ``log_alpha`` value at or above which a weight is dropped
    :param ard_init: initial value of every ``log_sigma2`` entry
    """

    def __init__(self, in_features, out_features, bias=True, thresh=3, ard_init=-10):
        super(LinearARD, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        self.thresh = thresh
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.ard_init = ard_init
        self.log_sigma2 = Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def forward(self, input):
        """
        Apply the layer to ``input``, whose last dimension is ``in_features``.

        Training mode returns one sample of the pre-activations,
        ``x W^T + sqrt(x^2 (alpha * W^2)^T) * eps (+ bias)`` with
        ``eps ~ N(0, 1)``.  Eval mode returns
        ``F.linear(input, weights_clipped, bias)``.
        """
        if self.training:
            mean = F.linear(input, self.weight)
            # Each weight's variance is alpha * w^2, so the output variance is
            # x^2 @ (alpha * w^2)^T; the 1e-15 keeps the sqrt gradient finite.
            alpha_t = torch.exp(self.log_alpha).permute(1, 0)
            var = input.pow(2).matmul(alpha_t * (self.weight.permute(1, 0) ** 2))
            std = torch.sqrt(var + 1e-15)
            epsilon = std.new(std.shape).normal_()
            output = mean + std * epsilon
            # ``if self.bias:`` called bool() on the bias tensor: it raised for
            # out_features > 1 and skipped a zero bias (so it never received a
            # gradient) for out_features == 1.  Test for presence instead.
            if self.bias is not None:
                output = output + self.bias
        else:
            # F.linear accepts bias=None, so bias-free layers work in eval too.
            output = F.linear(input, self.weights_clipped, self.bias)
        return output

    @property
    def weights_clipped(self):
        """``weight`` with every entry selected by ``get_clip_mask()`` set to zero."""
        clip_mask = self.get_clip_mask()
        return torch.where(clip_mask, torch.zeros_like(self.weight), self.weight)

    def reset_parameters(self):
        """Draw weights from N(0, 0.02^2), zero the bias, fill ``log_sigma2`` with ``ard_init``."""
        self.weight.data.normal_(0, 0.02)
        if self.bias is not None:
            self.bias.data.zero_()
        self.log_sigma2.data.fill_(self.ard_init)

    def get_clip_mask(self):
        """Boolean mask of the weights whose ``log_alpha`` is at least ``thresh``."""
        log_alpha = self.log_alpha
        return torch.ge(log_alpha, self.thresh)

    def get_reg(self, **kwargs):
        """
        Get weights regularization (KL(q(w)||p(w)) approximation)

        Uses the approximation from arxiv:1701.05369,
        ``-KL ~= k1 * sigmoid(k2 + k3 * log_alpha) - 0.5 * log(1 + 1/alpha) - k1``,
        summed over all weights and negated.  ``kwargs`` are ignored.
        """
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        C = -k1
        mdkl = k1 * torch.sigmoid(k2 + k3 * self.log_alpha) - \
            0.5 * torch.log1p(torch.exp(-self.log_alpha)) + C
        return -torch.sum(mdkl)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

    def get_dropped_params_cnt(self):
        """
        Get the number of dropped weights (``log_alpha >= thresh``).

        :returns: count of dropped weights as a 0-d numpy array
        """
        return self.get_clip_mask().sum().cpu().numpy()

    @property
    def log_alpha(self):
        """``log_sigma2 - 2 * log(|weight| + 1e-15)``, clamped to [-10, 10]."""
        log_alpha = self.log_sigma2 - 2 * \
            torch.log(torch.abs(self.weight) + 1e-15)
        return torch.clamp(log_alpha, -10, 10)


class Conv2dARD(nn.Conv2d):
    """
    2D convolution with weights ARD-prior (arxiv:1701.05369)

    Always created without bias.  Training mode samples the output with the
    local reparameterization trick; eval mode convolves with
    ``weights_clipped``.

    :param ard_init: initial value of every ``log_sigma2`` entry
    :param thresh: ``log_alpha`` value at or above which a weight is dropped

    The remaining parameters are forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, ard_init=-10, thresh=3):
        bias = False  # Goes to nan if bias = True
        super(Conv2dARD, self).__init__(in_channels, out_channels, kernel_size, stride,
                                        padding, dilation, groups, bias)
        self.bias = None
        self.thresh = thresh
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ard_init = ard_init
        self.log_sigma2 = Parameter(ard_init * torch.ones_like(self.weight))

    def forward(self, input):
        """
        Forward with all regularized connections and random activations (Bayesian mode)
        in training; in eval mode, a deterministic convolution with the clipped weights.
        """
        if not self.training:
            return F.conv2d(input, self.weights_clipped,
                            self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        W = self.weight

        conved_mu = F.conv2d(input, W, self.bias, self.stride,
                             self.padding, self.dilation, self.groups)
        log_alpha = self.log_alpha
        # Output variance is conv(x^2, alpha * W^2); 1e-15 keeps the sqrt gradient finite.
        conved_si = torch.sqrt(1e-15 + F.conv2d(input * input,
                                                torch.exp(log_alpha) * W * W,
                                                self.bias, self.stride,
                                                self.padding, self.dilation, self.groups))
        conved = conved_mu + \
            conved_si * \
            torch.normal(torch.zeros_like(conved_mu),
                         torch.ones_like(conved_mu))
        return conved

    @property
    def weights_clipped(self):
        """``weight`` with every entry selected by ``get_clip_mask()`` set to zero."""
        clip_mask = self.get_clip_mask()
        return torch.where(clip_mask, torch.zeros_like(self.weight), self.weight)

    def get_clip_mask(self):
        """Boolean mask of the weights whose ``log_alpha`` is at least ``thresh``."""
        log_alpha = self.log_alpha
        return torch.ge(log_alpha, self.thresh)

    def get_reg(self, **kwargs):
        """
        Get weights regularization (KL(q(w)||p(w)) approximation)

        Same approximation as ``LinearARD.get_reg``; ``kwargs`` are ignored.
        """
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        C = -k1
        log_alpha = self.log_alpha
        mdkl = k1 * torch.sigmoid(k2 + k3 * log_alpha) - \
            0.5 * torch.log1p(torch.exp(-log_alpha)) + C
        return -torch.sum(mdkl)

    def extra_repr(self):
        # NOTE(review): labels say "features" although the values are channel counts.
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_channels, self.out_channels, self.bias is not None
        )

    def get_dropped_params_cnt(self):
        """
        Get the number of dropped weights (``log_alpha >= thresh``).

        :returns: count of dropped weights as a 0-d numpy array
        """
        return self.get_clip_mask().sum().cpu().numpy()

    @property
    def log_alpha(self):
        """``log_sigma2 - 2 * log(|weight| + 1e-15)``, clamped to [-8, 8]."""
        log_alpha = self.log_sigma2 - 2 * \
            torch.log(torch.abs(self.weight) + 1e-15)
        return torch.clamp(log_alpha, -8, 8)


class ELBOLoss(nn.Module):
    """
    ELBO-style training objective (to be minimized).

    Returns ``loss_weight * loss_fn(input, target) + kl_weight * get_ard_reg(net)``.

    :param net: model whose ARD layers contribute the KL term
    :param loss_fn: data-fit loss, called as ``loss_fn(input, target)``
    """

    def __init__(self, net, loss_fn):
        super(ELBOLoss, self).__init__()
        self.loss_fn = loss_fn
        self.net = net

    def forward(self, input, target, loss_weight=1., kl_weight=1.):
        assert not target.requires_grad
        # Estimate ELBO
        return loss_weight * self.loss_fn(input, target) \
            + kl_weight * get_ard_reg(self.net)


def get_ard_reg(module):
    """
    Sum the ``get_reg()`` terms of every LinearARD / Conv2dARD in ``module``.

    :param module: model to evaluate ARD regularization for
    :return: total regularization (a tensor, or 0 if there are no ARD layers)
    """
    if isinstance(module, (LinearARD, Conv2dARD)):
        return module.get_reg()
    elif hasattr(module, 'children'):
        return sum(get_ard_reg(submodule) for submodule in module.children())
    return 0


def _get_dropped_params_cnt(module):
    """Recursively count dropped weights over every layer exposing ``get_dropped_params_cnt``."""
    if hasattr(module, 'get_dropped_params_cnt'):
        return module.get_dropped_params_cnt()
    elif hasattr(module, 'children'):
        return sum(_get_dropped_params_cnt(submodule) for submodule in module.children())
    return 0


def _get_params_cnt(module):
    """
    Count the weights of every ARD layer in ``module``.

    Non-ARD ``nn.Module`` leaves contribute 0, because every ``nn.Module`` has
    ``children``; the final fallback only applies to objects without it.
    """
    if isinstance(module, (LinearARD, Conv2dARD)):
        return reduce(operator.mul, module.weight.shape, 1)
    elif hasattr(module, 'children'):
        return sum(_get_params_cnt(submodule) for submodule in module.children())
    return sum(p.numel() for p in module.parameters())


def get_dropped_params_ratio(model):
    """
    Fraction of ARD weights in ``model`` that are currently dropped.

    :returns: dropped / total ARD weights, or 0.0 if ``model`` has no ARD weights
    """
    total = _get_params_cnt(model)
    if total == 0:
        # Previously a ZeroDivisionError for models without ARD layers.
        return 0.0
    return _get_dropped_params_cnt(model) * 1.0 / total