├── .gitignore
├── LICENSE
├── README.md
├── models
│   ├── base.py
│   └── resnet.py
└── train.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Yi Jiang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
PyTorch implementation of 'Deep Networks with Stochastic Depth' ([https://arxiv.org/pdf/1603.09382.pdf](https://arxiv.org/pdf/1603.09382.pdf)).

## Training

Trains a ResNet-110 (three segments of 18 `BasicBlock`s) with stochastic depth on CIFAR-10 and saves the best checkpoint to `resnet110.t7`:

```
python train.py
```
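
## Usage

The model can also be instantiated directly. A minimal sketch (the batch size and random input below are illustrative only):

```
import torch
from models import resnet, base

# ResNet-110 for CIFAR-10: 3 segments x 18 BasicBlocks
net = resnet.FlatResNet32(base.BasicBlock, [18, 18, 18], num_classes=10)
x = torch.randn(4, 3, 32, 32)        # batch of CIFAR-sized inputs
logits = net(x)                      # deterministic forward pass
logits_sd = net(x, stochastic=True)  # forward pass with stochastic depth
print(logits.shape)                  # torch.Size([4, 10])
```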
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class Identity(nn.Module):
    """Pass-through module."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Flatten(nn.Module):
    """Flatten every dimension except the batch dimension."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, track_running_stats=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, track_running_stats=False)

    def forward(self, x):
        # The skip connection is added by the caller (FlatResNet.forward),
        # so this block only computes the residual branch.
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # As with BasicBlock, the skip connection is added by the caller.
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        return out


class DownsampleB(nn.Module):
    """Parameter-free shortcut: average-pool spatially, then zero-pad channels."""

    def __init__(self, nIn, nOut, stride):
        super(DownsampleB, self).__init__()
        self.avg = nn.AvgPool2d(stride)
        self.expand_ratio = nOut // nIn

    def forward(self, x):
        x = self.avg(x)
        return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1)
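
# Illustrative shape check (an added sketch, not part of the original file):
# DownsampleB halves the spatial resolution and fills the new channels with
# zeros, giving a parameter-free shortcut when a block changes dimensions.
if __name__ == "__main__":
    ds = DownsampleB(16, 32, stride=2)
    x = torch.randn(2, 16, 32, 32)
    y = ds(x)
    print(y.shape)                       # torch.Size([2, 32, 16, 16])
    print(y[:, 16:].abs().sum().item())  # 0.0 -- padded channels are all zero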
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
import math

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from models import base

np.random.seed(2 ** 10)


class FlatResNet(nn.Module):
    """ResNet whose residual blocks are held in a flat, indexable structure so
    that individual blocks can be skipped during training (stochastic depth)."""

    def seed(self, x):
        # x = self.relu(self.bn1(self.conv1(x))) -- CIFAR
        # x = self.maxpool(self.relu(self.bn1(self.conv1(x)))) -- ImageNet
        raise NotImplementedError

    def forward(self, x, stochastic=False):
        x = self.seed(x)

        if not stochastic:
            for segment, num_blocks in enumerate(self.layer_config):
                for b in range(num_blocks):
                    residual = self.ds[segment](x) if b == 0 else x
                    x = F.relu(residual + self.blocks[segment][b](x))
        else:
            # Linear-decay survival probabilities: block i survives with
            # probability 1 - i * step, falling from 1.0 toward 0.5.
            step = 0.5 / sum(self.layer_config)
            p = 1.0
            for segment, num_blocks in enumerate(self.layer_config):
                for b in range(num_blocks):
                    action = np.random.choice([1, 0], p=[p, 1 - p])
                    p = p - step
                    residual = self.ds[segment](x) if b == 0 else x
                    if action == 0:
                        # Drop this block for the whole mini-batch:
                        # keep only the shortcut.
                        x = residual
                    else:
                        x = F.relu(residual + self.blocks[segment][b](x))

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class FlatResNet32(FlatResNet):
    """FlatResNet for 32x32 inputs (CIFAR)."""

    def __init__(self, block, layers, num_classes=10):
        super().__init__()

        self.inplanes = 16
        self.conv1 = base.conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16, track_running_stats=False)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        strides = [1, 2, 2]
        filt_sizes = [16, 32, 64]
        self.blocks, self.ds = [], []
        for filt_size, num_blocks, stride in zip(filt_sizes, layers, strides):
            blocks, ds = self._make_layer(block, filt_size, num_blocks, stride=stride)
            self.blocks.append(nn.ModuleList(blocks))
            self.ds.append(ds)

        self.blocks = nn.ModuleList(self.blocks)
        self.ds = nn.ModuleList(self.ds)
        self.fc = nn.Linear(64 * block.expansion, num_classes)
        self.fc_dim = 64 * block.expansion

        self.layer_config = layers

        # He initialization for convolutions, constant init for batch norms
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def seed(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        return x

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = nn.Sequential()
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = base.DownsampleB(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1))

        return layers, downsample

# TODO: FlatResNet224 for ImageNet
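
# A survival-schedule sketch (an added example, not part of the original file):
# for the ResNet-110 configuration used in train.py ([18, 18, 18], i.e. 54
# blocks), forward(..., stochastic=True) keeps block i with probability
# 1 - i * 0.5 / 54, decaying linearly from 1.0 toward the paper's p_L = 0.5.
if __name__ == "__main__":
    layer_config = [18, 18, 18]
    num_blocks = sum(layer_config)
    step = 0.5 / num_blocks
    survival = [1.0 - i * step for i in range(num_blocks)]
    print(survival[0], survival[-1])  # 1.0 0.509...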
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import argparse

import numpy as np
import tensorboard_logger
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as D
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import tqdm

from models import resnet, base

cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Dynamic ResNet Training')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--max_epochs', type=int, default=500, help='total epochs to run')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_devices = torch.cuda.device_count()


def train(epoch):
    rnet.train()

    total = 0
    correct = 0
    train_loss = 0
    total_batch = 0

    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)):
        inputs, targets = inputs.to(device), targets.to(device)

        # stochastic=True enables stochastic depth during training
        probs = rnet(inputs, True)
        optimizer.zero_grad()
        loss = criterion(probs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = probs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        total_batch += 1

    print('E:%d Train Loss: %.3f Train Acc: %.3f LR %f'
          % (epoch,
             train_loss / total_batch,
             correct / total,
             optimizer.param_groups[0]['lr']))

    tensorboard_logger.log_value('train_acc', correct / total, epoch)
    tensorboard_logger.log_value('train_loss', train_loss / total_batch, epoch)


def test(epoch):
    global best_test_acc
    rnet.eval()

    total = 0
    correct = 0
    test_loss = 0
    total_batch = 0

    with torch.no_grad():  # no gradients needed during evaluation
        for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):
            inputs, targets = inputs.to(device), targets.to(device)

            probs = rnet(inputs)  # deterministic forward pass at test time
            loss = criterion(probs, targets)

            test_loss += loss.item()
            _, predicted = probs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            total_batch += 1

    print('E:%d Test Loss: %.3f Test Acc: %.3f'
          % (epoch, test_loss / total_batch, correct / total))

    # save the best model
    acc = 100. * correct / total
    if acc > best_test_acc:
        best_test_acc = acc
        print('saving best model...')
        state = {
            'net': rnet.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(state, 'resnet110.t7')
    tensorboard_logger.log_value('test_acc', acc, epoch)
    tensorboard_logger.log_value('test_loss', test_loss / total_batch, epoch)


def adjust_learning_rate(epoch, stage=(250, 375)):
    # decay the LR by 10x at each milestone in `stage`
    order = np.sum(epoch >= np.array(stage))
    lr = args.lr * (0.1 ** order)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def get_transforms():
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    return train_tf, test_tf


# dataset and dataloaders
train_tf, test_tf = get_transforms()
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_tf)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_tf)
trainloader = D.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
testloader = D.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
best_test_acc = 0.0

# ResNet-110: 54 BasicBlocks (2 conv layers each) plus the stem conv and final fc
rnet = resnet.FlatResNet32(base.BasicBlock, [18, 18, 18], num_classes=10)
rnet.to(device)
if num_devices > 1:
    print('parallelizing over multiple GPUs...')
    rnet = nn.DataParallel(rnet)

start_epoch = 0

if args.resume:
    assert os.path.isfile('resnet110.t7'), 'Error: no checkpoint found!'
    ckpt = torch.load('resnet110.t7')
    rnet.load_state_dict(ckpt['net'])
    best_test_acc = ckpt['acc']
    start_epoch = ckpt['epoch'] + 1  # resume from the epoch after the checkpoint
else:
    # He initialization for the conv layers
    for module in rnet.modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

# loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(rnet.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)

# logger
tensorboard_logger.configure('./log/run1')

for epoch in range(start_epoch, args.max_epochs):
    adjust_learning_rate(epoch)  # set the LR for this epoch before training
    train(epoch)
    test(epoch)
--------------------------------------------------------------------------------