├── .gitignore ├── README.md ├── __pycache__ ├── __init__.cpython-36.pyc ├── net2net.cpython-36.pyc └── tests.cpython-36.pyc ├── examples ├── __init__.py ├── data │ ├── processed │ │ ├── test.pt │ │ └── training.pt │ └── raw │ │ ├── t10k-images-idx3-ubyte │ │ ├── t10k-labels-idx1-ubyte │ │ ├── train-images-idx3-ubyte │ │ └── train-labels-idx1-ubyte ├── plots │ └── cifar │ │ ├── Teacher_accu_plot.png │ │ ├── Teacher_loss_plot.png │ │ ├── Teacher_loss_plot_NormWider_NormDeeper.png.png │ │ ├── Teacher_loss_plot_noisyAndNormWider_NormDeeper.png │ │ ├── Teacher_loss_plot_noisyAndNormWider_noisyAndNormDeeper.png │ │ ├── Wider_Deeper_teacher_accu_plot.png │ │ ├── Wider_Deeper_teacher_loss_plot.png │ │ ├── Wider_teacher_accu_plot.png │ │ └── Wider_teacher_loss_plot.png ├── train_cifar10.py ├── train_mnist.py └── utils.py ├── net2net.py └── tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | .static_storage/ 56 | .media/ 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Net2Net 2 | Net2Net implementation in PyTorch for any possible vision layers (nn.Linear, nn.Conv2d, nn.Conv3d, even the wider operator between nn.ConvX and nn.Linear). 3 | Check out the [paper](https://arxiv.org/abs/1511.05641) for more detail. 4 | 5 | ## Observations: 6 | 7 | - Using BatchNorm between layers improves the effectiveness of Net2Net. Otherwise, the Net2Net approach is not able to get 8 | comparable results to a network trained from scratch. 9 | 10 | - Adding noise to new units and connections leads to better networks. The effect is more evident without a BatchNorm layer. 
11 | 12 | - Normalizing layer weights before any Net2Net operation increases the speed of learning and gives better convergence. Even so, it is worth investigating better normalization methods than the L2 norm. 13 | -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/net2net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/__pycache__/net2net.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/tests.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/__pycache__/tests.cpython-36.pyc -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/__init__.py -------------------------------------------------------------------------------- /examples/data/processed/test.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/data/processed/test.pt -------------------------------------------------------------------------------- /examples/data/processed/training.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/data/processed/training.pt -------------------------------------------------------------------------------- /examples/data/raw/t10k-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/data/raw/t10k-images-idx3-ubyte -------------------------------------------------------------------------------- /examples/data/raw/t10k-labels-idx1-ubyte: -------------------------------------------------------------------------------- 1 | '                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             -------------------------------------------------------------------------------- /examples/data/raw/train-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/data/raw/train-images-idx3-ubyte -------------------------------------------------------------------------------- 
/examples/data/raw/train-labels-idx1-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/data/raw/train-labels-idx1-ubyte -------------------------------------------------------------------------------- /examples/plots/cifar/Teacher_accu_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/plots/cifar/Teacher_accu_plot.png -------------------------------------------------------------------------------- /examples/plots/cifar/Teacher_loss_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/plots/cifar/Teacher_loss_plot.png -------------------------------------------------------------------------------- /examples/plots/cifar/Teacher_loss_plot_NormWider_NormDeeper.png.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/plots/cifar/Teacher_loss_plot_NormWider_NormDeeper.png.png -------------------------------------------------------------------------------- /examples/plots/cifar/Teacher_loss_plot_noisyAndNormWider_NormDeeper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/plots/cifar/Teacher_loss_plot_noisyAndNormWider_NormDeeper.png -------------------------------------------------------------------------------- /examples/plots/cifar/Teacher_loss_plot_noisyAndNormWider_noisyAndNormDeeper.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/plots/cifar/Teacher_loss_plot_noisyAndNormWider_noisyAndNormDeeper.png -------------------------------------------------------------------------------- /examples/plots/cifar/Wider_Deeper_teacher_accu_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/plots/cifar/Wider_Deeper_teacher_accu_plot.png -------------------------------------------------------------------------------- /examples/plots/cifar/Wider_Deeper_teacher_loss_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/plots/cifar/Wider_Deeper_teacher_loss_plot.png -------------------------------------------------------------------------------- /examples/plots/cifar/Wider_teacher_accu_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/plots/cifar/Wider_teacher_accu_plot.png -------------------------------------------------------------------------------- /examples/plots/cifar/Wider_teacher_loss_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erogol/Net2Net/fffc2b66df8a11577518f7f01287abe264ce30de/examples/plots/cifar/Wider_teacher_loss_plot.png -------------------------------------------------------------------------------- /examples/train_cifar10.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | 
from torch.autograd import Variable 9 | import sys 10 | sys.path.append('../') 11 | from net2net import wider, deeper 12 | import copy 13 | import numpy as np 14 | 15 | from utils import NLL_loss_instance 16 | from utils import PlotLearning 17 | 18 | 19 | # Training settings 20 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 21 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 22 | help='input batch size for training (default: 64)') 23 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 24 | help='input batch size for testing (default: 1000)') 25 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 26 | help='number of epochs to train (default: 10)') 27 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 28 | help='learning rate (default: 0.01)') 29 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 30 | help='SGD momentum (default: 0.9)') 31 | parser.add_argument('--no-cuda', action='store_true', default=False, 32 | help='disables CUDA training') 33 | parser.add_argument('--seed', type=int, default=1, metavar='S', 34 | help='random seed (default: 1)') 35 | parser.add_argument('--log-interval', type=int, default=1000, metavar='N', 36 | help='how many batches to wait before logging status') 37 | parser.add_argument('--noise', type=int, default=1, 38 | help='noise or no noise 0-1') 39 | parser.add_argument('--weight_norm', type=int, default=1, 40 | help='norm or no weight norm 0-1') 41 | args = parser.parse_args() 42 | args.cuda = not args.no_cuda and torch.cuda.is_available() 43 | 44 | torch.manual_seed(args.seed) 45 | if args.cuda: 46 | torch.cuda.manual_seed(args.seed) 47 | 48 | train_transform = transforms.Compose( 49 | [ 50 | transforms.ToTensor(), 51 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 52 | 53 | test_transform = transforms.Compose( 54 | [transforms.ToTensor(), 55 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 
0.5, 0.5))]) 56 | 57 | 58 | kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {} 59 | train_loader = torch.utils.data.DataLoader( 60 | datasets.CIFAR10('./data', train=True, download=True, transform=train_transform), 61 | batch_size=args.batch_size, shuffle=True, **kwargs) 62 | test_loader = torch.utils.data.DataLoader( 63 | datasets.CIFAR10('./data', train=False, transform=test_transform), 64 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 65 | 66 | 67 | class Net(nn.Module): 68 | def __init__(self): 69 | super(Net, self).__init__() 70 | self.conv1 = nn.Conv2d(3, 8, 3, padding=1) 71 | self.bn1 = nn.BatchNorm2d(8) 72 | self.pool1 = nn.MaxPool2d(3, 2) 73 | self.conv2 = nn.Conv2d(8, 16, 3, padding=1) 74 | self.bn2 = nn.BatchNorm2d(16) 75 | self.pool2 = nn.MaxPool2d(3, 2) 76 | self.conv3 = nn.Conv2d(16, 32, 3, padding=1) 77 | self.bn3 = nn.BatchNorm2d(32) 78 | self.pool3 = nn.AvgPool2d(5, 1) 79 | self.fc1 = nn.Linear(32 * 3 * 3, 10) 80 | for m in self.modules(): 81 | if isinstance(m, nn.Conv2d): 82 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 83 | m.weight.data.normal_(0, np.sqrt(2. 
/ n)) 84 | m.bias.data.fill_(0.0) 85 | if isinstance(m, nn.Linear): 86 | m.bias.data.fill_(0.0) 87 | 88 | def forward(self, x): 89 | try: 90 | x = self.conv1(x) 91 | x = self.bn1(x) 92 | x = F.relu(x) 93 | x = self.pool1(x) 94 | x = self.conv2(x) 95 | x = self.bn2(x) 96 | x = F.relu(x) 97 | x = self.pool2(x) 98 | x = self.conv3(x) 99 | x = self.bn3(x) 100 | x = F.relu(x) 101 | x = self.pool3(x) 102 | x = x.view(-1, x.size(1) * x.size(2) * x.size(3)) 103 | x = self.fc1(x) 104 | return F.log_softmax(x) 105 | except RuntimeError: 106 | print(x.size()) 107 | 108 | def net2net_wider(self): 109 | self.conv1, self.conv2, _ = wider(self.conv1, self.conv2, 12, 110 | self.bn1, noise=args.noise) 111 | self.conv2, self.conv3, _ = wider(self.conv2, self.conv3, 24, 112 | self.bn2, noise=args.noise) 113 | self.conv3, self.fc1, _ = wider(self.conv3, self.fc1, 48, 114 | self.bn3, noise=args.noise) 115 | print(self) 116 | 117 | def net2net_deeper(self): 118 | s = deeper(self.conv1, nn.ReLU, bnorm_flag=True, weight_norm=args.weight_norm, noise=args.noise) 119 | self.conv1 = s 120 | s = deeper(self.conv2, nn.ReLU, bnorm_flag=True, weight_norm=args.weight_norm, noise=args.noise) 121 | self.conv2 = s 122 | s = deeper(self.conv3, nn.ReLU, bnorm_flag=True, weight_norm=args.weight_norm, noise=args.noise) 123 | self.conv3 = s 124 | print(self) 125 | 126 | def net2net_deeper_nononline(self): 127 | s = deeper(self.conv1, None, bnorm_flag=True, weight_norm=args.weight_norm, noise=args.noise) 128 | self.conv1 = s 129 | s = deeper(self.conv2, None, bnorm_flag=True, weight_norm=args.weight_norm, noise=args.noise) 130 | self.conv2 = s 131 | s = deeper(self.conv3, None, bnorm_flag=True, weight_norm=args.weight_norm, noise=args.noise) 132 | self.conv3 = s 133 | print(self) 134 | 135 | def define_wider(self): 136 | self.conv1 = nn.Conv2d(3, 12, kernel_size=3, padding=1) 137 | self.bn1 = nn.BatchNorm2d(12) 138 | self.conv2 = nn.Conv2d(12, 24, kernel_size=3, padding=1) 139 | self.bn2 = 
nn.BatchNorm2d(24) 140 | self.conv3 = nn.Conv2d(24, 48, kernel_size=3, padding=1) 141 | self.bn3 = nn.BatchNorm2d(48) 142 | self.fc1 = nn.Linear(48*3*3, 10) 143 | 144 | def define_wider_deeper(self): 145 | self.conv1 = nn.Sequential(nn.Conv2d(3, 12, kernel_size=3, padding=1), 146 | nn.BatchNorm2d(12), 147 | nn.ReLU(), 148 | nn.Conv2d(12, 12, kernel_size=3, padding=1)) 149 | self.bn1 = nn.BatchNorm2d(12) 150 | self.conv2 = nn.Sequential(nn.Conv2d(12, 24, kernel_size=3, padding=1), 151 | nn.BatchNorm2d(24), 152 | nn.ReLU(), 153 | nn.Conv2d(24, 24, kernel_size=3, padding=1)) 154 | self.bn2 = nn.BatchNorm2d(24) 155 | self.conv3 = nn.Sequential(nn.Conv2d(24, 48, kernel_size=3, padding=1), 156 | nn.BatchNorm2d(48), 157 | nn.ReLU(), 158 | nn.Conv2d(48, 48, kernel_size=3, padding=1)) 159 | self.bn3 = nn.BatchNorm2d(48) 160 | self.fc1 = nn.Linear(48*3*3, 10) 161 | print(self) 162 | 163 | 164 | def net2net_deeper_recursive(model): 165 | """ 166 | Apply deeper operator recursively any conv layer. 167 | """ 168 | for name, module in model._modules.items(): 169 | if isinstance(module, nn.Conv2d): 170 | s = deeper(module, nn.ReLU, bnorm_flag=False) 171 | model._modules[name] = s 172 | elif isinstance(module, nn.Sequential): 173 | module = net2net_deeper_recursive(module) 174 | model._modules[name] = module 175 | return model 176 | 177 | 178 | def train(epoch): 179 | model.train() 180 | avg_loss = 0 181 | avg_accu = 0 182 | for batch_idx, (data, target) in enumerate(train_loader): 183 | if args.cuda: 184 | data, target = data.cuda(), target.cuda() 185 | data, target = Variable(data), Variable(target) 186 | optimizer.zero_grad() 187 | output = model(data) 188 | loss = criterion(output, target) 189 | loss.backward() 190 | optimizer.step() 191 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 192 | avg_accu += pred.eq(target.data.view_as(pred)).cpu().sum() 193 | avg_loss += loss.data[0] 194 | if batch_idx % args.log_interval == 0: 195 | 
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 196 | epoch, batch_idx * len(data), len(train_loader.dataset), 197 | 100. * batch_idx / len(train_loader), loss.data[0])) 198 | avg_loss /= batch_idx + 1 199 | avg_accu = avg_accu / len(train_loader.dataset) 200 | return avg_accu, avg_loss 201 | 202 | 203 | def test(): 204 | model.eval() 205 | test_loss = 0 206 | correct = 0 207 | for data, target in test_loader: 208 | if args.cuda: 209 | data, target = data.cuda(), target.cuda() 210 | data, target = Variable(data, volatile=True), Variable(target) 211 | output = model(data) 212 | test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss 213 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 214 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 215 | 216 | test_loss /= len(test_loader.dataset) 217 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 218 | test_loss, correct, len(test_loader.dataset), 219 | 100. * correct / len(test_loader.dataset))) 220 | return correct / len(test_loader.dataset), test_loss 221 | 222 | 223 | def run_training(model, run_name, epochs, plot=None): 224 | global optimizer 225 | model.cuda() 226 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 227 | if plot is None: 228 | plot = PlotLearning('./plots/cifar/', 10, prefix=run_name) 229 | for epoch in range(1, epochs + 1): 230 | accu_train, loss_train = train(epoch) 231 | accu_test, loss_test = test() 232 | logs = {} 233 | logs['acc'] = accu_train 234 | logs['val_acc'] = accu_test 235 | logs['loss'] = loss_train 236 | logs['val_loss'] = loss_test 237 | plot.plot(logs) 238 | return plot 239 | 240 | 241 | if __name__ == "__main__": 242 | start_t = time.time() 243 | print("\n\n > Teacher training ... 
") 244 | model = Net() 245 | model.cuda() 246 | criterion = nn.NLLLoss() 247 | plot = run_training(model, 'Teacher_', (args.epochs + 1) // 3) 248 | 249 | # wider student training 250 | print("\n\n > Wider Student training ... ") 251 | model_ = Net() 252 | model_ = copy.deepcopy(model) 253 | 254 | del model 255 | model = model_ 256 | model.net2net_wider() 257 | plot = run_training(model, 'Wider_student_', (args.epochs + 1) // 3, plot) 258 | 259 | # wider + deeper student training 260 | print("\n\n > Wider+Deeper Student training ... ") 261 | model_ = Net() 262 | model_.net2net_wider() 263 | model_ = copy.deepcopy(model) 264 | 265 | del model 266 | model = model_ 267 | model.net2net_deeper_nononline() 268 | run_training(model, 'WiderDeeper_student_', (args.epochs + 1) // 3, plot) 269 | print(" >> Time tkaen by whole net2net training {}".format(time.time() - start_t)) 270 | 271 | # wider teacher training 272 | start_t = time.time() 273 | print("\n\n > Wider teacher training ... ") 274 | model_ = Net() 275 | 276 | del model 277 | model = model_ 278 | model.define_wider() 279 | model.cuda() 280 | run_training(model, 'Wider_teacher_', args.epochs + 1) 281 | print(" >> Time taken {}".format(time.time() - start_t)) 282 | 283 | # wider deeper teacher training 284 | print("\n\n > Wider+Deeper teacher training ... 
") 285 | start_t = time.time() 286 | model_ = Net() 287 | 288 | del model 289 | model = model_ 290 | model.define_wider_deeper() 291 | run_training(model, 'Wider_Deeper_teacher_', args.epochs + 1) 292 | print(" >> Time taken {}".format(time.time() - start_t)) 293 | -------------------------------------------------------------------------------- /examples/train_mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.autograd import Variable 9 | import sys 10 | sys.path.append('../') 11 | from net2net import * 12 | import copy 13 | 14 | 15 | # Training settings 16 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 17 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 18 | help='input batch size for training (default: 64)') 19 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 20 | help='input batch size for testing (default: 1000)') 21 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 22 | help='number of epochs to train (default: 10)') 23 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 24 | help='learning rate (default: 0.01)') 25 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 26 | help='SGD momentum (default: 0.5)') 27 | parser.add_argument('--no-cuda', action='store_true', default=False, 28 | help='disables CUDA training') 29 | parser.add_argument('--seed', type=int, default=1, metavar='S', 30 | help='random seed (default: 1)') 31 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 32 | help='how many batches to wait before logging status') 33 | args = parser.parse_args() 34 | args.cuda = not args.no_cuda and torch.cuda.is_available() 
35 | 36 | torch.manual_seed(args.seed) 37 | if args.cuda: 38 | torch.cuda.manual_seed(args.seed) 39 | 40 | 41 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 42 | train_loader = torch.utils.data.DataLoader( 43 | datasets.MNIST('./data', train=True, download=True, 44 | transform=transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize((0.1307,), (0.3081,)) 47 | ])), 48 | batch_size=args.batch_size, shuffle=True, **kwargs) 49 | test_loader = torch.utils.data.DataLoader( 50 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.1307,), (0.3081,)) 53 | ])), 54 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 55 | 56 | 57 | class Net(nn.Module): 58 | def __init__(self): 59 | super(Net, self).__init__() 60 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 61 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 62 | self.conv2_drop = nn.Dropout2d() 63 | self.fc1 = nn.Linear(320, 50) 64 | self.fc2 = nn.Linear(50, 10) 65 | 66 | def forward(self, x): 67 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 68 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 69 | x = x.view(-1, x.size(1)*x.size(2)*x.size(3)) 70 | x = F.relu(self.fc1(x)) 71 | x = F.dropout(x, training=self.training) 72 | x = self.fc2(x) 73 | return F.log_softmax(x) 74 | 75 | def net2net_wider(self): 76 | self.conv1, self.conv2, _ = wider(self.conv1, self.conv2, 15, noise_var=0.01) 77 | self.conv2, self.fc1, _ = wider(self.conv2, self.fc1, 30, noise_var=0.01) 78 | print(self) 79 | 80 | def net2net_deeper(self): 81 | s = deeper(self.conv1, nn.ReLU, bnorm_flag=False) 82 | self.conv1 = s 83 | s = deeper(self.conv2, nn.ReLU, bnorm_flag=False) 84 | self.conv2 = s 85 | print(self) 86 | 87 | def define_wider(self): 88 | self.conv1 = nn.Conv2d(1, 15, kernel_size=5) 89 | self.conv2 = nn.Conv2d(15, 30, kernel_size=5) 90 | self.fc1 = nn.Linear(480, 50) 91 | 92 | def define_wider_deeper(self): 93 | 
self.conv1 = nn.Sequential(nn.Conv2d(1, 15, kernel_size=5), 94 | nn.ReLU(), 95 | nn.Conv2d(15, 15, kernel_size=5, padding=2)) 96 | self.conv2 = nn.Sequential(nn.Conv2d(15, 30, kernel_size=5), 97 | nn.ReLU(), 98 | nn.Conv2d(30, 30, kernel_size=5, padding=2)) 99 | self.fc1 = nn.Linear(480, 50) 100 | 101 | 102 | model = Net() 103 | if args.cuda: 104 | model.cuda() 105 | 106 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 107 | 108 | 109 | def train(epoch): 110 | model.train() 111 | for batch_idx, (data, target) in enumerate(train_loader): 112 | if args.cuda: 113 | data, target = data.cuda(), target.cuda() 114 | data, target = Variable(data), Variable(target) 115 | optimizer.zero_grad() 116 | output = model(data) 117 | loss = F.nll_loss(output, target) 118 | loss.backward() 119 | optimizer.step() 120 | if batch_idx % args.log_interval == 0: 121 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 122 | epoch, batch_idx * len(data), len(train_loader.dataset), 123 | 100. * batch_idx / len(train_loader), loss.data[0])) 124 | 125 | 126 | def test(): 127 | model.eval() 128 | test_loss = 0 129 | correct = 0 130 | for data, target in test_loader: 131 | if args.cuda: 132 | data, target = data.cuda(), target.cuda() 133 | data, target = Variable(data, volatile=True), Variable(target) 134 | output = model(data) 135 | test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss 136 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 137 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 138 | 139 | test_loss /= len(test_loader.dataset) 140 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 141 | test_loss, correct, len(test_loader.dataset), 142 | 100. * correct / len(test_loader.dataset))) 143 | return 100. * correct / len(test_loader.dataset) 144 | 145 | print("\n\n > Teacher training ... 
") 146 | # treacher training 147 | for epoch in range(1, args.epochs + 1): 148 | train(epoch) 149 | teacher_accu = test() 150 | 151 | 152 | # wider student training 153 | print("\n\n > Wider Student training ... ") 154 | model_ = Net() 155 | model_ = copy.deepcopy(model) 156 | 157 | del model 158 | model = model_ 159 | model.net2net_wider() 160 | model.cuda() 161 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 162 | for epoch in range(1, args.epochs + 1): 163 | train(epoch) 164 | wider_accu = test() 165 | 166 | 167 | # wider + deeper student training 168 | print("\n\n > Wider+Deeper Student training ... ") 169 | model_ = Net() 170 | model_ = copy.deepcopy(model) 171 | 172 | del model 173 | model = model_ 174 | model.net2net_deeper() 175 | model.cuda() 176 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 177 | for epoch in range(1, args.epochs + 1): 178 | train(epoch) 179 | deeper_accu = test() 180 | 181 | 182 | # wider teacher training 183 | print("\n\n > Wider teacher training ... ") 184 | model_ = Net() 185 | 186 | del model 187 | model = model_ 188 | model.define_wider() 189 | model.cuda() 190 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 191 | for epoch in range(1, 2*(args.epochs) + 1): 192 | train(epoch) 193 | wider_teacher_accu = test() 194 | 195 | 196 | # wider deeper teacher training 197 | print("\n\n > Wider+Deeper teacher training ... 
") 198 | model_ = Net() 199 | 200 | del model 201 | model = model_ 202 | model.define_wider_deeper() 203 | model.cuda() 204 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 205 | for epoch in range(1, 3*(args.epochs) + 1): 206 | train(epoch) 207 | wider_deeper_teacher_accu = test() 208 | 209 | 210 | print(" -> Teacher:\t{}".format(teacher_accu)) 211 | print(" -> Wider model:\t{}".format(wider_accu)) 212 | print(" -> Deeper-Wider model:\t{}".format(deeper_accu)) 213 | print(" -> Wider teacher:\t{}".format(wider_teacher_accu)) 214 | print(" -> Deeper-Wider teacher:\t{}".format(wider_deeper_teacher_accu)) 215 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch as th 3 | import numpy as np 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pylab as plt 7 | 8 | 9 | class NLL_loss_instance(th.nn.NLLLoss): 10 | 11 | def __init__(self, ratio): 12 | super(NLL_loss_instance, self).__init__(None, True) 13 | self.ratio = ratio 14 | 15 | def forward(self, x, y, ratio=None): 16 | if ratio is not None: 17 | self.ratio = ratio 18 | num_inst = x.size(0) 19 | num_hns = int(self.ratio * num_inst) 20 | x_ = x.clone() 21 | for idx, label in enumerate(y.data): 22 | x_.data[idx, label] = 0.0 23 | loss_incs = -x_.sum(1) 24 | _, idxs = loss_incs.topk(num_hns) 25 | x_hn = x.index_select(0, idxs) 26 | y_hn = y.index_select(0, idxs) 27 | return th.nn.functional.nll_loss(x_hn, y_hn) 28 | 29 | 30 | class PlotLearning(object): 31 | def __init__(self, save_path, num_classes, prefix=''): 32 | self.accuracy = [] 33 | self.val_accuracy = [] 34 | self.losses = [] 35 | self.val_losses = [] 36 | self.save_path_loss = os.path.join(save_path, prefix+'loss_plot.png') 37 | self.save_path_accu = os.path.join(save_path, prefix+'accu_plot.png') 38 | self.init_loss = -np.log(1.0 / num_classes) 39 | 40 
def _zero_padded(vec, new_size):
    """Return a copy of the 1-D tensor `vec` grown to `new_size`, zero-filled past the old end."""
    grown = vec.new(new_size).zero_()
    grown.narrow(0, 0, vec.size(0)).copy_(vec)
    return grown


def _init_fan(module):
    """Return the fan used to scale the random init of new units in `module`."""
    if module.weight.dim() == 2:
        return module.out_features * module.in_features
    return int(np.prod(module.kernel_size)) * module.out_channels


def _normalize_filters(weight):
    """Scale every slice along dim 0 of `weight` to unit L2 norm, in place."""
    for i in range(weight.size(0)):
        unit = weight.select(0, i)
        norm = float(unit.norm())
        if norm > 0:  # an all-zero filter would otherwise turn into NaNs
            unit.div_(norm)


def wider(m1, m2, new_width, bnorm=None, out_size=None, noise=True,
          random_init=True, weight_norm=True):
    """
    Convert m1 layer to its wider version by adapting the next weight layer
    and a possible batch norm layer in between.
    Args:
        m1 - module to be made wider (Conv1d/2d/3d or Linear).
        m2 - following module to be adapted to m1.
        new_width - new number of output units of m1; must exceed the
            current one.
        bnorm (optional) - batch norm layer, if there is one btw m1 and m2.
        out_size (list, optional) - necessary for m1 == conv3d and
            m2 == linear. Spatial size (3 values) of the feature map that is
            flattened into m2. Used to compute the matching Linear layer size.
        noise (bool, True) - add a slight noise to the weights of m1 to break
            symmetry btw weights.
        random_init (optional, True) - if True, the weights of new units are
            initialized randomly, otherwise they replicate randomly chosen
            existing units (function preserving).
        weight_norm (optional, True) - if True, the filters used as the source
            of replicated units are scaled to unit norm first.
            NOTE(review): the original units keep their scale, so replication
            with weight_norm=True is not function preserving - confirm intent.
    Returns:
        (m1, m2, bnorm), modified in place. Modules other than Conv*/Linear
        are returned untouched.
    """
    m1_name = m1.__class__.__name__
    m2_name = m2.__class__.__name__
    if "Conv" not in m1_name and "Linear" not in m1_name:
        return m1, m2, bnorm

    w1 = m1.weight.data
    w2 = m2.weight.data
    b1 = m1.bias.data if m1.bias is not None else None
    conv_to_linear = "Conv" in m1_name and "Linear" in m2_name

    if conv_to_linear:
        # View the Linear weight as a conv weight so that input units can be
        # addressed per m1 channel.
        assert w2.size(1) % w1.size(0) == 0, "Linear units need to be multiple"
        if w1.dim() == 4:
            factor = int(np.sqrt(w2.size(1) // w1.size(0)))
            w2 = w2.view(w2.size(0), w2.size(1) // factor ** 2, factor, factor)
        elif w1.dim() == 5:
            assert out_size is not None,\
                   "For conv3d -> linear out_size is necessary"
            factor = out_size[0] * out_size[1] * out_size[2]
            w2 = w2.view(w2.size(0), w2.size(1) // factor, out_size[0],
                         out_size[1], out_size[2])
        else:
            w2 = w2.view(w2.size(0), w1.size(0), -1)
    else:
        assert w1.size(0) == w2.size(1), "Module weights are not compatible"
    assert new_width > w1.size(0), "New size should be larger"

    old_width = w1.size(0)

    # Zero-initialised targets: nothing is ever read from unset memory.
    nw1 = w1.new(new_width, *w1.size()[1:]).zero_()
    nw2 = w2.new(w2.size(0), new_width, *w2.size()[2:]).zero_()
    # Views with the widened (input) dimension first; writes go to w2 / nw2.
    w2_in = w2.transpose(0, 1)
    nw2_in = nw2.transpose(0, 1)

    nw1.narrow(0, 0, old_width).copy_(w1)
    nw2_in.narrow(0, 0, old_width).copy_(w2_in)
    nb1 = _zero_padded(b1, new_width) if b1 is not None else None

    bn_old = {}
    if bnorm is not None:
        if getattr(bnorm, 'running_mean', None) is not None:
            bn_old['running_mean'] = bnorm.running_mean
            bn_old['running_var'] = bnorm.running_var
        if bnorm.affine:
            bn_old['weight'] = bnorm.weight.data
            bn_old['bias'] = bnorm.bias.data
    bn_new = dict((key, _zero_padded(vec, new_width))
                  for key, vec in bn_old.items())

    # TEST:normalize weights
    src_w1 = w1
    if weight_norm:
        src_w1 = w1.clone()
        _normalize_filters(src_w1)

    if random_init:
        # The fans are loop invariant; they use the layer sizes before widening.
        std1 = float(np.sqrt(2. / _init_fan(m1)))
        std2 = float(np.sqrt(2. / _init_fan(m2)))

    # select weights randomly; tracking maps an old unit to itself plus all
    # of its replicas
    tracking = dict()
    for i in range(old_width, new_width):
        idx = np.random.randint(0, old_width)
        tracking.setdefault(idx, [idx]).append(i)

        # TEST:random init for new units
        if random_init:
            nw1.select(0, i).normal_(0, std1)
            nw2_in.select(0, i).normal_(0, std2)
        else:
            nw1.select(0, i).copy_(src_w1.select(0, idx))
            nw2_in.select(0, i).copy_(w2_in.select(0, idx))
        if nb1 is not None:
            nb1[i] = b1[idx]
        for key, vec in bn_new.items():
            vec[i] = bn_old[key][idx]

    if not random_init:
        # Every replica of a unit feeds m2 with the same value, so the
        # outgoing weights are shared equally among them.
        for replicas in tracking.values():
            for unit in replicas:
                nw2_in[unit].div_(len(replicas))

    if "Linear" in m1_name:
        m1.out_features = new_width
    else:
        m1.out_channels = new_width

    if noise:
        scale = 5e-2 * float(nw1.std())
        jitter = np.random.normal(scale=scale, size=list(nw1.size()))
        nw1 += th.from_numpy(jitter).type_as(nw1)

    m1.weight.data = nw1
    if nb1 is not None:
        m1.bias.data = nb1

    if conv_to_linear:
        # Flatten back to the (out_features, in_features) Linear layout.
        m2.weight.data = nw2.view(nw2.size(0), -1)
        m2.in_features = m2.weight.size(1)
    else:
        m2.weight.data = nw2
        if "Linear" in m2_name:
            m2.in_features = new_width
        else:
            m2.in_channels = new_width

    if bnorm is not None:
        bnorm.num_features = new_width
        for key, vec in bn_new.items():
            if key in ('weight', 'bias'):
                getattr(bnorm, key).data = vec
            else:
                setattr(bnorm, key, vec)
    return m1, m2, bnorm


# TODO: Consider adding noise to new layer as wider operator.
def deeper(m, nonlin, bnorm_flag=False, weight_norm=True, noise=True):
    """
    Deeper operator adding a new identity-initialized layer on top of the
    given layer.
    Args:
        m (module) - module to add a new layer onto (Conv1d/2d/3d or Linear).
        nonlin (module class or None) - non-linearity to be used before the
            new layer; it is instantiated without arguments.
        bnorm_flag (bool, False) - whether to add a batch normalization btw.
        weight_norm (bool, True) - if True, scale every filter of m to unit
            norm before adding the new layer (conv layers only).
        noise (bool, True) - if True, add noise to the new layer weights
            (conv layers only).
    Returns:
        th.nn.Sequential with the entries 'conv' (m), optionally 'bnorm' and
        'nonlin', and 'conv_new' (the new layer).
    Raises:
        RuntimeError - if m is not a supported module.
    """
    m_name = m.__class__.__name__
    bnorm = None

    if "Linear" in m_name:
        m2 = th.nn.Linear(m.out_features, m.out_features)
        m2.weight.data.copy_(th.eye(m.out_features))
        m2.bias.data.zero_()

        if bnorm_flag:
            bnorm = th.nn.BatchNorm1d(m.out_features)

    elif "Conv" in m_name:
        conv_by_dim = {3: th.nn.Conv1d, 4: th.nn.Conv2d, 5: th.nn.Conv3d}
        bnorm_by_dim = {3: th.nn.BatchNorm1d, 4: th.nn.BatchNorm2d,
                        5: th.nn.BatchNorm3d}
        ndim = m.weight.dim()
        if ndim not in conv_by_dim:
            raise RuntimeError("{} Module not supported".format(m_name))

        ksize = tuple(m.kernel_size)
        assert all(k % 2 == 1 for k in ksize), "Kernel size needs to be odd"

        # For an odd kernel k, padding k // 2 keeps the spatial size, and
        # k // 2 is also the 0-based index of the kernel centre.
        centre = tuple(k // 2 for k in ksize)
        m2 = conv_by_dim[ndim](m.out_channels, m.out_channels,
                               kernel_size=ksize, padding=centre)
        m2.weight.data.zero_()

        if weight_norm:
            _normalize_filters(m.weight.data)

        # Identity mapping: channel i copies channel i through the centre tap.
        for i in range(m.out_channels):
            m2.weight.data[(i, i) + centre] = 1

        if noise:
            scale = 5e-2 * float(m2.weight.data.std())
            jitter = np.random.normal(scale=scale,
                                      size=list(m2.weight.size()))
            m2.weight.data += th.from_numpy(jitter).type_as(m2.weight.data)

        m2.bias.data.zero_()

        if bnorm_flag:
            bnorm = bnorm_by_dim[ndim](m2.out_channels)

    else:
        raise RuntimeError("{} Module not supported".format(m_name))

    if bnorm is not None:
        # Start from the identity transform (in eval mode, up to eps).
        bnorm.weight.data.fill_(1)
        bnorm.bias.data.fill_(0)
        bnorm.running_mean.fill_(0)
        bnorm.running_var.fill_(1)

    s = th.nn.Sequential()
    s.add_module('conv', m)
    if bnorm_flag:
        s.add_module('bnorm', bnorm)
    if nonlin is not None:
        s.add_module('nonlin', nonlin())
    s.add_module('conv_new', m2)

    return s
class TestOperators(unittest.TestCase):
    """Checks for the net2net `wider` and `deeper` operators.

    Each scenario runs a network on a random batch, applies an operator to
    some of its layers and compares the new output with the original one.
    """

    def _create_net(self):
        return Net()

    @staticmethod
    def _output_drift(before, after):
        """Return |sum(before - after)| as a Python float.

        `.item()` is used because indexing a 0-dim tensor with `[0]` raises
        on current PyTorch.
        NOTE(review): the signed sum lets errors of opposite sign cancel; an
        element-wise `abs().max()` would be a stricter check.
        """
        return abs((before - after).sum().item())

    @staticmethod
    def _widen_convs(net, out_size=None, **kwargs):
        """Widen conv1 to 20 and conv2 to 60 units, adapting conv2 and fc1."""
        conv1, conv2, _ = wider(net._modules['conv1'],
                                net._modules['conv2'],
                                20,
                                **kwargs)
        net._modules['conv1'] = conv1
        net._modules['conv2'] = conv2

        conv2, fc1, _ = wider(net._modules['conv2'],
                              net._modules['fc1'],
                              60,
                              out_size=out_size,
                              **kwargs)
        net._modules['conv2'] = conv2
        net._modules['fc1'] = fc1

    def _widen_and_compare(self, net, inp, out_size=None, **kwargs):
        """Return (drift, new_output) after widening `net` in eval mode."""
        net.eval()
        out = net(inp)
        self._widen_convs(net, out_size=out_size, **kwargs)
        net.eval()
        nout = net(inp)
        return self._output_drift(out, nout), nout

    def test_wider(self):
        exact = dict(noise=False, random_init=False, weight_norm=False)

        # Plain replication must keep the function of the 2D net.
        drift, nout = self._widen_and_compare(self._create_net(),
                                              th.rand(32, 1, 28, 28),
                                              **exact)
        self.assertLess(drift, 1e-1)
        self.assertEqual((nout.size(0), nout.size(1)), (32, 10))

        # Testing 3D layers
        drift, nout = self._widen_and_compare(Net3D(),
                                              th.rand(32, 1, 16, 28, 28),
                                              out_size=[1, 4, 4],
                                              **exact)
        self.assertLess(drift, 1e-1)
        self.assertEqual((nout.size(0), nout.size(1)), (32, 10))

        # testing noise: with the default random init the output must change
        drift, nout = self._widen_and_compare(self._create_net(),
                                              th.rand(32, 1, 28, 28),
                                              noise=1)
        self.assertGreater(drift, 1e-1)
        self.assertEqual((nout.size(0), nout.size(1)), (32, 10))

    def test_deeper(self):
        net = self._create_net()
        inp = th.rand(32, 1, 28, 28)

        net.eval()
        out = net(inp)

        for layer in ('conv1', 'conv2', 'fc1'):
            net._modules[layer] = deeper(net._modules[layer], nn.ReLU,
                                         bnorm_flag=True, weight_norm=False,
                                         noise=False)

        net.eval()
        nout = net(inp)

        self.assertLess(self._output_drift(out, nout), 1e-1)

        # test for 3D net
        net = Net3D()
        inp = th.rand(32, 1, 16, 28, 28)

        net.eval()
        out = net(inp)

        net._modules['conv1'] = deeper(net._modules['conv1'], nn.ReLU,
                                       bnorm_flag=False, weight_norm=False,
                                       noise=False)

        net.eval()
        nout = net(inp)

        self.assertLess(
            self._output_drift(out, nout), 1e-1,
            "New layer changes values by {}".format(
                th.abs(out - nout).sum().item()))