├── tests ├── __init__.py ├── test_revnet.py └── common.py ├── .gitattributes ├── setup.py ├── param_count.py ├── .github └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── LICENSE ├── README.md ├── revnet ├── __init__.py ├── resnet.py └── revnet.py ├── .gitignore └── train_cifar.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.dat filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | setup( 3 | name="revnet", 4 | version="0.2.0", 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /param_count.py: -------------------------------------------------------------------------------- 1 | import revnet 2 | 3 | 4 | def get_param_size(model): 5 | params = 0 6 | for p in model.parameters(): 7 | tmp = 1 8 | for x in p.size(): 9 | tmp *= x 10 | params += tmp 11 | return params 12 | 13 | 14 | print("revnet38: {}".format(get_param_size(revnet.revnet38()))) 15 | print("resnet32: {}".format(get_param_size(revnet.resnet32()))) 16 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 
22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Till Bungert 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # revnet 2 | 3 | [PyTorch](http://pytorch.org/) implementation of [the reversible residual 4 | network](https://arxiv.org/abs/1707.04585). 5 | 6 | 7 | ## Requirements 8 | 9 | The main requirement is obviously [PyTorch](http://pytorch.org/). CUDA is 10 | strongly recommended. 11 | 12 | The training script requires [tqdm](https://pypi.python.org/pypi/tqdm) for the 13 | progress bar. 14 | 15 | The unit tests require the TestCase implemented by the PyTorch project. The 16 | module can be downloaded 17 | [here](https://github.com/pytorch/pytorch/blob/master/test/common.py). 18 | 19 | 20 | ## Note 21 | 22 | The revnet models in this project tend to have exploding gradients. To 23 | counteract this, I used gradient norm clipping. 
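The relevant part of the 24 | training step in `train_cifar.py` clips the gradient norm between the 25 | backward pass and the optimizer step: 26 | 27 | ```python 28 | loss.backward() 29 | 30 | # Free the memory used to store activations 31 | if type(model) is revnet.RevNet: 32 | model.free() 33 | 34 | if clip > 0: 35 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 36 | optimizer.step() 37 | ``` 38 | 39 | 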
For the experiments below you would call the following command: 40 | 41 | ``` 42 | python train_cifar.py --model revnet38 --clip 0.25 43 | ``` 44 | 45 | 46 | ## Results 47 | 48 | ### CIFAR-10 49 | 50 | | Model | Accuracy | Memory Usage | Params | 51 | |----------|----------|--------------|--------| 52 | | resnet32 | 92.02% | 1271 MB | 0.47 M | 53 | | revnet38 | 91.98% | 660 MB | 0.47 M | 54 | -------------------------------------------------------------------------------- /revnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNet 2 | from .revnet import RevNet, RevBlock, RevBlockFunction, possible_downsample 3 | 4 | 5 | def resnet32(): 6 | model = ResNet( 7 | units=[5, 5, 5], 8 | filters=[16, 16, 32, 64], 9 | strides=[1, 2, 2], 10 | classes=10 11 | ) 12 | model.name = "resnet32" 13 | return model 14 | 15 | 16 | def resnet110(): 17 | model = ResNet( 18 | units=[18, 18, 18], 19 | filters=[16, 16, 32, 64], 20 | strides=[1, 2, 2], 21 | classes=10 22 | ) 23 | model.name = "resnet110" 24 | return model 25 | 26 | 27 | def revnet38(): 28 | model = RevNet( 29 | units=[3, 3, 3], 30 | filters=[32, 32, 64, 112], 31 | strides=[1, 2, 2], 32 | classes=10 33 | ) 34 | model.name = "revnet38" 35 | return model 36 | 37 | 38 | def revnet110(): 39 | model = RevNet( 40 | units=[9, 9, 9], 41 | filters=[32, 32, 64, 112], 42 | strides=[1, 2, 2], 43 | classes=10 44 | ) 45 | model.name = "revnet110" 46 | return model 47 | -------------------------------------------------------------------------------- /tests/test_revnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd 3 | from torch.autograd import Variable 4 | 5 | from revnet import RevBlock, RevBlockFunction 6 | 7 | import unittest 8 | 9 | from .common import TestCase 10 | 11 | 12 | class TestRevNet(TestCase): 13 | def setUp(self): 14 | self.x = torch.rand(4, 4, 4, 4) 15 | self.model = RevBlock(4, 4, []) 16 | parameters = list(self.model._parameters.values()) 17 | buffers = list(self.model._buffers.values()) 18 | # self.f_params = [Variable(x) for x in parameters[:8]] 19 | # self.g_params = [Variable(x) for x in parameters[8:16]] 20 | self.f_params = parameters[:8] 21 | self.g_params = parameters[8:16] 22 | self.f_buffs = buffers[:4] 23 | self.g_buffs = buffers[4:8] 24 | self.in_channels = self.model.in_channels 25 | self.out_channels = self.model.out_channels 26 | self.training = self.model.training 27 | self.stride = self.model.stride 28 | self.no_activation = self.model.no_activation 29 | 30 | def test_grad(self): 31 | pass 32 | 33 | def test_recreation(self): 34 | y = RevBlockFunction._forward( 35 | self.x, 36 | self.in_channels, 37 | self.out_channels, 38 | self.training, 39 | self.stride, self.model.padding, self.model.dilation, 40 | self.f_params, self.f_buffs, 41 | self.g_params, self.g_buffs, 42 | no_activation=self.no_activation 43 | ) 44 | 45 | z = RevBlockFunction._backward( 46 | y.data, 47 | self.in_channels, 48 | self.out_channels, 49 | self.f_params, self.f_buffs, 50 | self.g_params, self.g_buffs, 51 | self.training, self.model.padding, self.model.dilation, 52 | self.no_activation 53 | ) 54 | 55 | self.assertEqual(self.x, z) 56 | 57 | 58 | if __name__ == '__main__': 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /revnet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd 
import Variable 6 | 7 | from .revnet import possible_downsample 8 | 9 | CUDA = torch.cuda.is_available() 10 | 11 | 12 | class Block(nn.Module): 13 | def __init__(self, in_channels, out_channels, stride=1, 14 | no_activation=False): 15 | super(Block, self).__init__() 16 | 17 | self.in_channels = in_channels 18 | self.out_channels = out_channels 19 | self.stride = stride 20 | self.no_activation = no_activation 21 | 22 | self.stride = stride 23 | 24 | self.bn1 = nn.BatchNorm2d(in_channels) 25 | 26 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 27 | padding=1, stride=stride) 28 | 29 | self.bn2 = nn.BatchNorm2d(out_channels) 30 | 31 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 32 | padding=1) 33 | 34 | def forward(self, x): 35 | orig_x = x 36 | 37 | out = x 38 | 39 | if not self.no_activation: 40 | out = F.relu(self.bn1(out)) 41 | 42 | out = self.conv1(out) 43 | 44 | out = self.conv2(F.relu(self.bn2(out))) 45 | 46 | out += possible_downsample(orig_x, self.in_channels, 47 | self.out_channels, self.stride) 48 | 49 | return out 50 | 51 | 52 | class Bottleneck(nn.Module): 53 | def __init__(self): 54 | pass 55 | 56 | def forward(self, x): 57 | pass 58 | 59 | 60 | class ResNet(nn.Module): 61 | def __init__(self, 62 | units, 63 | filters, 64 | strides, 65 | classes, 66 | bottleneck=False): 67 | """ 68 | Parameters 69 | ---------- 70 | 71 | units: list-like 72 | Number of residual units in each group 73 | 74 | filters: list-like 75 | Number of filters in each unit including the inputlayer, so it is 76 | one item longer than units 77 | 78 | strides: list-like 79 | Strides to use for the first units in each group, same length as 80 | units 81 | 82 | bottleneck: boolean 83 | Wether to use the bottleneck residual or the basic residual 84 | """ 85 | super(ResNet, self).__init__() 86 | self.name = self.__class__.__name__ 87 | 88 | if bottleneck: 89 | self.Residual = Bottleneck 90 | else: 91 | self.Residual = Block 92 | 93 | self.layers = nn.ModuleList() 94 | 95 | # Input layers 96 | self.layers.append(nn.Conv2d(3, filters[0], 3, padding=1)) 97 | self.layers.append(nn.BatchNorm2d(filters[0])) 98 | self.layers.append(nn.ReLU()) 99 | 100 | for i, group in enumerate(units): 101 | self.layers.append(self.Residual(filters[i], filters[i + 1], 102 | stride=strides[i], 103 | no_activation=True)) 104 | 105 | for unit in range(1, group): 106 | self.layers.append(self.Residual(filters[i + 1], 107 | filters[i + 1])) 108 | 109 | self.bn_last = nn.BatchNorm2d(filters[-1]) 110 | 111 | self.fc = nn.Linear(filters[-1], classes) 112 | 113 | def forward(self, x): 114 | for layer in self.layers: 115 | x = layer(x) 116 | 117 | x = F.relu(self.bn_last(x)) 118 | x = F.avg_pool2d(x, x.size(2)) 119 | x = x.view(x.size(0), -1) 120 | x = self.fc(x) 121 | 122 | return x 123 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /data/ 2 | /experiments/ 3 | 4 | # Created by https://www.gitignore.io/api/c,vim,git,c++,linux,python,windows 5 | 6 | ### C ### 7 | # Prerequisites 8 | *.d 9 | 10 | # Object files 11 | *.o 12 | *.ko 13 | *.obj 14 | *.elf 15 | 16 | # Linker output 17 | *.ilk 18 | *.map 19 | *.exp 20 | 21 | # Precompiled Headers 22 | *.gch 23 | *.pch 24 | 25 | # Libraries 26 | *.lib 27 | *.a 28 | *.la 29 | *.lo 30 | 31 | # Shared objects (inc. 
Windows DLLs) 32 | *.dll 33 | *.so 34 | *.so.* 35 | *.dylib 36 | 37 | # Executables 38 | *.exe 39 | *.out 40 | *.app 41 | *.i*86 42 | *.x86_64 43 | *.hex 44 | 45 | # Debug files 46 | *.dSYM/ 47 | *.su 48 | *.idb 49 | *.pdb 50 | 51 | # Kernel Module Compile Results 52 | *.mod* 53 | *.cmd 54 | .tmp_versions/ 55 | modules.order 56 | Module.symvers 57 | Mkfile.old 58 | dkms.conf 59 | 60 | ### C++ ### 61 | # Prerequisites 62 | 63 | # Compiled Object files 64 | *.slo 65 | 66 | # Precompiled Headers 67 | 68 | # Compiled Dynamic libraries 69 | 70 | # Fortran module files 71 | *.mod 72 | *.smod 73 | 74 | # Compiled Static libraries 75 | *.lai 76 | 77 | # Executables 78 | 79 | ### Git ### 80 | *.orig 81 | 82 | ### Linux ### 83 | *~ 84 | 85 | # temporary files which can be created if a process still has a handle open of a deleted file 86 | .fuse_hidden* 87 | 88 | # KDE directory preferences 89 | .directory 90 | 91 | # Linux trash folder which might appear on any partition or disk 92 | .Trash-* 93 | 94 | # .nfs files are created when an open file is removed but is still being accessed 95 | .nfs* 96 | 97 | ### Python ### 98 | # Byte-compiled / optimized / DLL files 99 | __pycache__/ 100 | *.py[cod] 101 | *$py.class 102 | 103 | # C extensions 104 | 105 | # Distribution / packaging 106 | .Python 107 | env/ 108 | build/ 109 | develop-eggs/ 110 | dist/ 111 | downloads/ 112 | eggs/ 113 | .eggs/ 114 | lib/ 115 | lib64/ 116 | parts/ 117 | sdist/ 118 | var/ 119 | wheels/ 120 | *.egg-info/ 121 | .installed.cfg 122 | *.egg 123 | 124 | # PyInstaller 125 | # Usually these files are written by a python script from a template 126 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 127 | *.manifest 128 | *.spec 129 | 130 | # Installer logs 131 | pip-log.txt 132 | pip-delete-this-directory.txt 133 | 134 | # Unit test / coverage reports 135 | htmlcov/ 136 | .tox/ 137 | .coverage 138 | .coverage.* 139 | .cache 140 | nosetests.xml 141 | coverage.xml 142 | *,cover 143 | .hypothesis/ 144 | 145 | # Translations 146 | *.mo 147 | *.pot 148 | 149 | # Django stuff: 150 | *.log 151 | local_settings.py 152 | 153 | # Flask stuff: 154 | instance/ 155 | .webassets-cache 156 | 157 | # Scrapy stuff: 158 | .scrapy 159 | 160 | # Sphinx documentation 161 | docs/_build/ 162 | 163 | # PyBuilder 164 | target/ 165 | 166 | # Jupyter Notebook 167 | .ipynb_checkpoints 168 | 169 | # pyenv 170 | .python-version 171 | 172 | # celery beat schedule file 173 | celerybeat-schedule 174 | 175 | # SageMath parsed files 176 | *.sage.py 177 | 178 | # dotenv 179 | .env 180 | 181 | # virtualenv 182 | .venv 183 | venv/ 184 | ENV/ 185 | 186 | # Spyder project settings 187 | .spyderproject 188 | .spyproject 189 | 190 | # Rope project settings 191 | .ropeproject 192 | 193 | # mkdocs documentation 194 | /site 195 | 196 | ### Vim ### 197 | # swap 198 | [._]*.s[a-v][a-z] 199 | [._]*.sw[a-p] 200 | [._]s[a-v][a-z] 201 | [._]sw[a-p] 202 | # session 203 | Session.vim 204 | # temporary 205 | .netrwhist 206 | # auto-generated tag files 207 | tags 208 | 209 | ### Windows ### 210 | # Windows thumbnail cache files 211 | Thumbs.db 212 | ehthumbs.db 213 | ehthumbs_vista.db 214 | 215 | # Folder config file 216 | Desktop.ini 217 | 218 | # Recycle Bin used on file shares 219 | $RECYCLE.BIN/ 220 | 221 | # Windows Installer files 222 | *.cab 223 | *.msi 224 | *.msm 225 | *.msp 226 | 227 | # Windows shortcuts 228 | *.lnk 229 | 230 | # End of https://www.gitignore.io/api/c,vim,git,c++,linux,python,windows 231 | 
-------------------------------------------------------------------------------- /train_cifar.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import os 4 | import sys 5 | import argparse 6 | 7 | from tqdm import tqdm 8 | 9 | # import matplotlib.pyplot as plt 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | 17 | from torch.autograd import Variable 18 | from torch.optim.lr_scheduler import StepLR 19 | 20 | import revnet 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--model", metavar="NAME", 24 | help="what model to use") 25 | parser.add_argument("--load", metavar="PATH", 26 | help="load a previous model state") 27 | parser.add_argument("-e", "--evaluate", action="store_true", 28 | help="evaluate model on validation set") 29 | parser.add_argument("--batch-size", default=128, type=int, 30 | help="size of the mini-batches") 31 | parser.add_argument("--epochs", default=200, type=int, 32 | help="number of epochs") 33 | parser.add_argument("--lr", default=0.1, type=float, 34 | help="initial learning rate") 35 | parser.add_argument("--clip", default=0, type=float, 36 | help="maximal gradient norm") 37 | parser.add_argument("--weight-decay", default=1e-4, type=float, 38 | help="weight decay factor") 39 | parser.add_argument("--stats", action="store_true", 40 | help="record and plot some stats") 41 | 42 | 43 | # Check if CUDA is available 44 | CUDA = torch.cuda.is_available() 45 | 46 | best_acc = 0 47 | 48 | 49 | def main(): 50 | global best_acc 51 | 52 | args = parser.parse_args() 53 | 54 | model = getattr(revnet, args.model)() 55 | 56 | exp_id = "cifar_{0}_{1:%Y-%m-%d}_{1:%H-%M-%S}".format(model.name, 57 | datetime.now()) 58 | 59 | path = os.path.join("./experiments/", exp_id, "cmd.sh") 60 | if not os.path.exists(os.path.dirname(path)): 61 | os.makedirs(os.path.dirname(path)) 62 | 63 | with open(path, 'w') as f: 64 | f.write(' '.join(sys.argv)) 65 | 66 | if CUDA: 67 | model.cuda() 68 | 69 | if args.load is not None: 70 | load(model, args.load) 71 | 72 | criterion = nn.CrossEntropyLoss() 73 | 74 | optimizer = optim.SGD(model.parameters(), lr=args.lr*10, 75 | momentum=0.9, weight_decay=args.weight_decay) 76 | 77 | scheduler = StepLR(optimizer, step_size=50, gamma=0.1) 78 | 79 | print("Preparing data...") 80 | 81 | # Load data 82 | transform_train = transforms.Compose([ 83 | transforms.RandomCrop(32, padding=4), 84 | transforms.RandomHorizontalFlip(), 85 | transforms.ToTensor(), 86 | transforms.Normalize((0.4914, 0.4822, 0.4465), 87 | (0.2023, 0.1994, 0.2010)), 88 | ]) 89 | 90 | transform_test = transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize((0.4914, 0.4822, 0.4465), 93 | (0.2023, 0.1994, 0.2010)), 94 | ]) 95 | 96 | trainset = torchvision.datasets.CIFAR10( 97 | root='./data', train=True, 98 | download=True, transform=transform_train 99 | ) 100 | 101 | trainloader = torch.utils.data.DataLoader(trainset, 102 | batch_size=args.batch_size, 103 | shuffle=True, num_workers=2) 104 | 105 | testset = torchvision.datasets.CIFAR10( 106 | root='./data', train=False, 107 | download=True, transform=transform_test 108 | ) 109 | 110 | valloader = torch.utils.data.DataLoader(testset, 111 | batch_size=args.batch_size, 112 | shuffle=False, num_workers=2) 113 | 114 | if args.evaluate: 115 | print("\nEvaluating model...") 116 | acc = validate(model, valloader) 117 | 
print('Accuracy: {}%'.format(acc)) 118 | return 119 | 120 | if args.stats: 121 | losses = [] 122 | taccs = [] 123 | vaccs = [] 124 | 125 | print("\nTraining model...") 126 | for epoch in range(args.epochs): 127 | scheduler.step() 128 | loss, train_acc = train(epoch, model, criterion, optimizer, 129 | trainloader, args.clip) 130 | val_acc = validate(model, valloader) 131 | 132 | if val_acc > best_acc: 133 | best_acc = val_acc 134 | save_checkpoint(model, exp_id) 135 | print('Accuracy: {}%'.format(val_acc)) 136 | 137 | if args.stats: 138 | losses.append(loss) 139 | taccs.append(train_acc) 140 | vaccs.append(val_acc) 141 | 142 | save_checkpoint(model, exp_id) 143 | 144 | if args.stats: 145 | path = os.path.join("./experiments/", exp_id, "stats/{}.dat") 146 | if not os.path.exists(os.path.dirname(path)): 147 | os.makedirs(os.path.dirname(path)) 148 | with open(path.format('loss'), 'w') as f: 149 | for i in losses: 150 | f.write('{}\n'.format(i)) 151 | 152 | with open(path.format('taccs'), 'w') as f: 153 | for i in taccs: 154 | f.write('{}\n'.format(i)) 155 | 156 | with open(path.format('vaccs'), 'w') as f: 157 | for i in vaccs: 158 | f.write('{}\n'.format(i)) 159 | 160 | return model 161 | 162 | 163 | def train(epoch, model, criterion, optimizer, trainloader, clip): 164 | model.train() 165 | train_loss = 0 166 | correct = 0 167 | total = 0 168 | t = tqdm(trainloader, ascii=True, desc='{}'.format(epoch).rjust(3)) 169 | for i, data in enumerate(t): 170 | inputs, labels = data 171 | 172 | if CUDA: 173 | inputs, labels = inputs.cuda(), labels.cuda() 174 | 175 | inputs, labels = Variable(inputs), Variable(labels) 176 | 177 | optimizer.zero_grad() 178 | 179 | outputs = model(inputs) 180 | loss = criterion(outputs, labels) 181 | loss.backward() 182 | 183 | # Free the memory used to store activations 184 | if type(model) is revnet.RevNet: 185 | model.free() 186 | 187 | if clip > 0: 188 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 189 | optimizer.step() 190 | 191 | train_loss += loss.item() 192 | _, predicted = torch.max(outputs.data, 1) 193 | total += labels.size(0) 194 | correct += predicted.eq(labels.data).cpu().sum() 195 | acc = 100 * correct / total 196 | 197 | t.set_postfix(loss='{:.3f}'.format(train_loss/(i+1)).ljust(3), 198 | acc='{:2.1f}%'.format(acc).ljust(6)) 199 | 200 | return train_loss, acc 201 | 202 | 203 | def validate(model, valloader): 204 | correct = 0 205 | total = 0 206 | 207 | model.eval() 208 | 209 | for data in valloader: 210 | images, labels = data 211 | if CUDA: 212 | images, labels = images.cuda(), labels.cuda() 213 | outputs = model(Variable(images)) 214 | 215 | # Free the memory used to store activations 216 | if type(model) is revnet.RevNet: 217 | model.free() 218 | 219 | _, predicted = torch.max(outputs.data, 1) 220 | total += labels.size(0) 221 | correct += (predicted == labels).sum() 222 | 223 | acc = 100 * correct / total 224 | 225 | return acc 226 | 227 | 228 | def load(model, path): 229 | model.load_state_dict(torch.load(path)) 230 | 231 | 232 | def save_checkpoint(model, exp_id): 233 | path = os.path.join( 234 | "experiments", exp_id, "checkpoints", 235 | "cifar_{0}_{1:%Y-%m-%d}_{1:%H-%M-%S}.dat".format(model.name, 236 | datetime.now())) 237 | if not os.path.exists(os.path.dirname(path)): 238 | os.makedirs(os.path.dirname(path)) 239 | torch.save(model.state_dict(), path) 240 | 241 | 242 | if __name__ == "__main__": 243 | main() 244 | -------------------------------------------------------------------------------- /tests/common.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import unittest 5 | import warnings 6 | import contextlib 7 | from functools import wraps 8 | from itertools import product 9 | from copy import deepcopy 10 | import __main__ 11 | import errno 12 | 13 | import torch 14 | import torch.cuda 15 | from torch.autograd import Variable 16 | 17 | 18 | torch.set_default_tensor_type('torch.DoubleTensor') 19 | 20 | SEED = 0 21 | SEED_SET = 0 22 | ACCEPT = False 23 | 24 | 25 | # TODO rename me 26 | def parse_set_seed_once(): 27 | global SEED 28 | global SEED_SET 29 | global ACCEPT 30 | parser = argparse.ArgumentParser(add_help=False) 31 | parser.add_argument('--seed', type=int, default=123) 32 | parser.add_argument('--accept', action='store_true') 33 | args, remaining = parser.parse_known_args() 34 | if SEED_SET == 0: 35 | torch.manual_seed(args.seed) 36 | if torch.cuda.is_available(): 37 | torch.cuda.manual_seed_all(args.seed) 38 | SEED = args.seed 39 | SEED_SET = 1 40 | ACCEPT = args.accept 41 | remaining = [sys.argv[0]] + remaining 42 | return remaining 43 | 44 | 45 | def run_tests(): 46 | remaining = parse_set_seed_once() 47 | unittest.main(argv=remaining) 48 | 49 | 50 | TEST_NUMPY = True 51 | try: 52 | import numpy 53 | except ImportError: 54 | TEST_NUMPY = False 55 | 56 | TEST_SCIPY = True 57 | try: 58 | import scipy 59 | except ImportError: 60 | TEST_SCIPY = False 61 | 62 | 63 | def skipIfNoLapack(fn): 64 | @wraps(fn) 65 | def wrapper(*args, **kwargs): 66 | try: 67 | fn(*args, **kwargs) 68 | except Exception as e: 69 | if 'Lapack library not found' in e.args[0]: 70 | raise unittest.SkipTest('Compiled without Lapack') 71 | raise 72 | return wrapper 73 | 74 | 75 | def suppress_warnings(fn): 76 | def wrapper(*args, **kwargs): 77 | with warnings.catch_warnings(): 78 | warnings.simplefilter("ignore") 79 | fn(*args, **kwargs) 80 | return wrapper 81 | 82 | 83 | def get_cpu_type(t): 84 | assert t.__module__ == 'torch.cuda' 85 | return getattr(torch, t.__class__.__name__) 86 | 87 | 88 | def get_gpu_type(t): 89 | assert t.__module__ == 'torch' 90 | return getattr(torch.cuda, t.__name__) 91 | 92 | 93 | def to_gpu(obj, type_map={}): 94 | if torch.is_tensor(obj): 95 | t = type_map.get(type(obj), get_gpu_type(type(obj))) 96 | return obj.clone().type(t) 97 | elif torch.is_storage(obj): 98 | return obj.new().resize_(obj.size()).copy_(obj) 99 | elif isinstance(obj, Variable): 100 | assert obj.is_leaf 101 | t = type_map.get(type(obj.data), get_gpu_type(type(obj.data))) 102 | return Variable(obj.data.clone().type(t), requires_grad=obj.requires_grad) 103 | elif isinstance(obj, list): 104 | return [to_gpu(o, type_map) for o in obj] 105 | elif isinstance(obj, tuple): 106 | return tuple(to_gpu(o, type_map) for o in obj) 107 | else: 108 | return deepcopy(obj) 109 | 110 | 111 | @contextlib.contextmanager 112 | def freeze_rng_state(): 113 | rng_state = torch.get_rng_state() 114 | if torch.cuda.is_available(): 115 | cuda_rng_state = torch.cuda.get_rng_state() 116 | yield 117 | if torch.cuda.is_available(): 118 | torch.cuda.set_rng_state(cuda_rng_state) 119 | torch.set_rng_state(rng_state) 120 | 121 | 122 | def iter_indices(tensor): 123 | if tensor.dim() == 0: 124 | return range(0) 125 | if tensor.dim() == 1: 126 | return range(tensor.size(0)) 127 | return product(*(range(s) for s in tensor.size())) 128 | 129 | 130 | def is_iterable(obj): 131 | try: 132 | iter(obj) 133 | return True 134 | except: 135 | return False 136 | 137 | 138 | class 
TestCase(unittest.TestCase): 139 | precision = 1e-5 140 | 141 | def setUp(self): 142 | torch.manual_seed(SEED) 143 | if torch.cuda.is_available(): 144 | torch.cuda.manual_seed_all(SEED) 145 | 146 | def assertTensorsSlowEqual(self, x, y, prec=None, message=''): 147 | max_err = 0 148 | self.assertEqual(x.size(), y.size()) 149 | for index in iter_indices(x): 150 | max_err = max(max_err, abs(x[index] - y[index])) 151 | self.assertLessEqual(max_err, prec, message) 152 | 153 | def safeCoalesce(self, t): 154 | tc = t.coalesce() 155 | 156 | value_map = {} 157 | for idx, val in zip(t._indices().t(), t._values()): 158 | idx_tup = tuple(idx) 159 | if idx_tup in value_map: 160 | value_map[idx_tup] += val 161 | else: 162 | value_map[idx_tup] = val.clone() if torch.is_tensor(val) else val 163 | 164 | new_indices = sorted(list(value_map.keys())) 165 | new_values = [value_map[idx] for idx in new_indices] 166 | if t._values().ndimension() < 2: 167 | new_values = t._values().new(new_values) 168 | else: 169 | new_values = torch.stack(new_values) 170 | 171 | new_indices = t._indices().new(new_indices).t() 172 | tg = t.new(new_indices, new_values, t.size()) 173 | 174 | self.assertEqual(tc._indices(), tg._indices()) 175 | self.assertEqual(tc._values(), tg._values()) 176 | 177 | return tg 178 | 179 | def unwrapVariables(self, x, y): 180 | if isinstance(x, Variable) and isinstance(y, Variable): 181 | return x.data, y.data 182 | elif isinstance(x, Variable) or isinstance(y, Variable): 183 | raise AssertionError("cannot compare {} and {}".format(type(x), type(y))) 184 | return x, y 185 | 186 | def assertEqual(self, x, y, prec=None, message=''): 187 | if prec is None: 188 | prec = self.precision 189 | 190 | x, y = self.unwrapVariables(x, y) 191 | 192 | if torch.is_tensor(x) and torch.is_tensor(y): 193 | def assertTensorsEqual(a, b): 194 | super(TestCase, self).assertEqual(a.size(), b.size()) 195 | if a.numel() > 0: 196 | b = b.type_as(a) 197 | b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu() 198 | # check that NaNs are in the same locations 199 | nan_mask = a != a 200 | self.assertTrue(torch.equal(nan_mask, b != b)) 201 | diff = a - b 202 | diff[nan_mask] = 0 203 | if diff.is_signed(): 204 | diff = diff.abs() 205 | max_err = diff.max() 206 | self.assertLessEqual(max_err, prec, message) 207 | self.assertEqual(x.is_sparse, y.is_sparse, message) 208 | if x.is_sparse: 209 | x = self.safeCoalesce(x) 210 | y = self.safeCoalesce(y) 211 | assertTensorsEqual(x._indices(), y._indices()) 212 | assertTensorsEqual(x._values(), y._values()) 213 | else: 214 | assertTensorsEqual(x, y) 215 | elif type(x) == str and type(y) == str: 216 | super(TestCase, self).assertEqual(x, y) 217 | elif type(x) == set and type(y) == set: 218 | super(TestCase, self).assertEqual(x, y) 219 | elif is_iterable(x) and is_iterable(y): 220 | super(TestCase, self).assertEqual(len(x), len(y)) 221 | for x_, y_ in zip(x, y): 222 | self.assertEqual(x_, y_, prec, message) 223 | else: 224 | try: 225 | self.assertLessEqual(abs(x - y), prec, message) 226 | return 227 | except: 228 | pass 229 | super(TestCase, self).assertEqual(x, y, message) 230 | 231 | def assertNotEqual(self, x, y, prec=None, message=''): 232 | if prec is None: 233 | prec = self.precision 234 | 235 | x, y = self.unwrapVariables(x, y) 236 | 237 | if torch.is_tensor(x) and torch.is_tensor(y): 238 | if x.size() != y.size(): 239 | super(TestCase, self).assertNotEqual(x.size(), y.size()) 240 | self.assertGreater(x.numel(), 0) 241 | y = y.type_as(x) 242 | y = y.cuda(device=x.get_device()) if 
x.is_cuda else y.cpu() 243 | nan_mask = x != x 244 | if torch.equal(nan_mask, y != y): 245 | diff = x - y 246 | if diff.is_signed(): 247 | diff = diff.abs() 248 | diff[nan_mask] = 0 249 | max_err = diff.max() 250 | self.assertGreaterEqual(max_err, prec, message) 251 | elif type(x) == str and type(y) == str: 252 | super(TestCase, self).assertNotEqual(x, y) 253 | elif is_iterable(x) and is_iterable(y): 254 | super(TestCase, self).assertNotEqual(x, y) 255 | else: 256 | try: 257 | self.assertGreaterEqual(abs(x - y), prec, message) 258 | return 259 | except: 260 | pass 261 | super(TestCase, self).assertNotEqual(x, y, message) 262 | 263 | def assertObjectIn(self, obj, iterable): 264 | for elem in iterable: 265 | if id(obj) == id(elem): 266 | return 267 | raise AssertionError("object not found in iterable") 268 | 269 | # TODO: Support context manager interface 270 | # NB: The kwargs forwarding to callable robs the 'subname' parameter. 271 | # If you need it, manually apply your callable in a lambda instead. 272 | def assertExpectedRaises(self, exc_type, callable, *args, **kwargs): 273 | subname = None 274 | if 'subname' in kwargs: 275 | subname = kwargs['subname'] 276 | del kwargs['subname'] 277 | try: 278 | callable(*args, **kwargs) 279 | except exc_type as e: 280 | self.assertExpected(str(e), subname) 281 | return 282 | # Don't put this in the try block; the AssertionError will catch it 283 | self.fail(msg="Did not raise when expected to") 284 | 285 | def assertExpected(self, s, subname=None): 286 | """ 287 | Test that a string matches the recorded contents of a file 288 | derived from the name of this test and subname. This file 289 | is placed in the 'expect' directory in the same directory 290 | as the test script. You can automatically update the recorded test 291 | output using --accept. 292 | 293 | If you call this multiple times in a single function, you must 294 | give a unique subname each time. 295 | """ 296 | if not (isinstance(s, str) or (sys.version_info[0] == 2 and isinstance(s, unicode))): 297 | raise TypeError("assertExpected is strings only") 298 | 299 | def remove_prefix(text, prefix): 300 | if text.startswith(prefix): 301 | return text[len(prefix):] 302 | return text 303 | munged_id = remove_prefix(self.id(), "__main__.") 304 | # NB: we take __file__ from __main__, so we place the expect directory 305 | # where the test script lives, NOT where test/common.py lives. 
This 306 | # doesn't matter in PyTorch where all test scripts are in the same 307 | # directory as test/common.py, but it matters in onnx-pytorch 308 | expected_file = os.path.join(os.path.dirname(os.path.realpath(__main__.__file__)), 309 | "expect", 310 | munged_id) 311 | if subname: 312 | expected_file += "-" + subname 313 | expected_file += ".expect" 314 | expected = None 315 | 316 | def accept_output(update_type): 317 | print("Accepting {} for {}:\n\n{}".format(update_type, munged_id, s)) 318 | with open(expected_file, 'w') as f: 319 | f.write(s) 320 | 321 | try: 322 | with open(expected_file) as f: 323 | expected = f.read() 324 | except IOError as e: 325 | if e.errno != errno.ENOENT: 326 | raise 327 | elif ACCEPT: 328 | return accept_output("output") 329 | else: 330 | raise RuntimeError( 331 | ("I got this output for {}:\n\n{}\n\n" 332 | "No expect file exists; to accept the current output, run:\n" 333 | "python {} {} --accept").format(munged_id, s, __main__.__file__, munged_id)) 334 | if ACCEPT: 335 | if expected != s: 336 | return accept_output("updated output") 337 | else: 338 | if hasattr(self, "assertMultiLineEqual"): 339 | # Python 2.7 only 340 | # NB: Python considers lhs "old" and rhs "new". 341 | self.assertMultiLineEqual(expected, s) 342 | else: 343 | self.assertEqual(s, expected) 344 | 345 | if sys.version_info < (3, 2): 346 | # assertRaisesRegexp renamed assertRaisesRegex in 3.2 347 | assertRaisesRegex = unittest.TestCase.assertRaisesRegexp 348 | 349 | 350 | def download_file(url, binary=True): 351 | if sys.version_info < (3,): 352 | from urlparse import urlsplit 353 | import urllib2 354 | request = urllib2 355 | error = urllib2 356 | else: 357 | from urllib.parse import urlsplit 358 | from urllib import request, error 359 | 360 | filename = os.path.basename(urlsplit(url)[2]) 361 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 362 | path = os.path.join(data_dir, filename) 363 | 364 | if os.path.exists(path): 365 | return path 366 | try: 367 | data = request.urlopen(url, timeout=15).read() 368 | with open(path, 'wb' if binary else 'w') as f: 369 | f.write(data) 370 | return path 371 | except error.URLError: 372 | msg = "could not download test file '{}'".format(url) 373 | warnings.warn(msg, RuntimeWarning) 374 | raise unittest.SkipTest(msg) 375 | -------------------------------------------------------------------------------- /revnet/revnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.autograd import Function, Variable 8 | 9 | CUDA = torch.cuda.is_available() 10 | 11 | def size_after_residual(size, out_channels, kernel_size, stride, padding, dilation): 12 | """Calculate the size of the output of the residual function 13 | """ 14 | N, C_in, H_in, W_in = size 15 | 16 | H_out = math.floor( 17 | (H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1 18 | ) 19 | W_out = math.floor( 20 | (W_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1 21 | ) 22 | return N, out_channels, H_out, W_out 23 | 24 | 25 | def possible_downsample(x, in_channels, out_channels, stride=1, padding=1, 26 | dilation=1): 27 | _, _, H_in, W_in = x.size() 28 | 29 | _, _, H_out, W_out = size_after_residual(x.size(), out_channels, 3, stride, padding, dilation) 30 | 31 | # Downsample image 32 | if H_in > H_out or W_in > W_out: 33 | out = F.avg_pool2d(x, 2*dilation+1, stride, padding) 34 | 35 | # Pad with empty channels 36 
| if in_channels < out_channels: 37 | 38 | try: out 39 | except: out = x 40 | 41 | pad = Variable(torch.zeros( 42 | out.size(0), 43 | (out_channels - in_channels) // 2, 44 | out.size(2), out.size(3) 45 | ), requires_grad=True) 46 | 47 | if CUDA: 48 | pad = pad.cuda() 49 | 50 | temp = torch.cat([pad, out], dim=1) 51 | out = torch.cat([temp, pad], dim=1) 52 | 53 | # If we did nothing, add zero tensor, so the output of this function 54 | # depends on the input in the graph 55 | try: out 56 | except: 57 | injection = Variable(torch.zeros_like(x.data), requires_grad=True) 58 | 59 | if CUDA: 60 | injection.cuda() 61 | 62 | out = x + injection 63 | 64 | return out 65 | 66 | 67 | class RevBlockFunction(Function): 68 | @staticmethod 69 | def residual(x, in_channels, out_channels, params, buffers, training, 70 | stride=1, padding=1, dilation=1, no_activation=False): 71 | """Compute a pre-activation residual function. 72 | 73 | Args: 74 | x (Variable): The input variable 75 | in_channels (int): Number of channels of x 76 | out_channels (int): Number of channels of the output 77 | 78 | Returns: 79 | out (Variable): The result of the computation 80 | 81 | """ 82 | out = x 83 | 84 | if not no_activation: 85 | out = F.batch_norm(out, buffers[0], buffers[1], params[0], 86 | params[1], training) 87 | out = F.relu(out) 88 | 89 | out = F.conv2d(out, params[-6], params[-5], stride, padding=padding, 90 | dilation=dilation) 91 | 92 | out = F.batch_norm(out, buffers[-2], buffers[-1], params[-4], 93 | params[-3], training) 94 | out = F.relu(out) 95 | out = F.conv2d(out, params[-2], params[-1], stride=1, padding=1, 96 | dilation=1) 97 | 98 | return out 99 | 100 | @staticmethod 101 | def _forward(x, in_channels, out_channels, training, stride, padding, 102 | dilation, f_params, f_buffs, g_params, g_buffs, 103 | no_activation=False): 104 | 105 | x1, x2 = torch.chunk(x, 2, dim=1) 106 | 107 | with torch.no_grad(): 108 | x1 = Variable(x1.contiguous()) 109 | x2 = Variable(x2.contiguous()) 110 | 111 | if CUDA: 112 | x1.cuda() 113 | x2.cuda() 114 | 115 | x1_ = possible_downsample(x1, in_channels, out_channels, stride, 116 | padding, dilation) 117 | x2_ = possible_downsample(x2, in_channels, out_channels, stride, 118 | padding, dilation) 119 | 120 | f_x2 = RevBlockFunction.residual( 121 | x2, 122 | in_channels, 123 | out_channels, 124 | f_params, 125 | f_buffs, training, 126 | stride=stride, 127 | padding=padding, 128 | dilation=dilation, 129 | no_activation=no_activation 130 | ) 131 | 132 | y1 = f_x2 + x1_ 133 | 134 | g_y1 = RevBlockFunction.residual( 135 | y1, 136 | out_channels, 137 | out_channels, 138 | g_params, 139 | g_buffs, 140 | training 141 | ) 142 | 143 | y2 = g_y1 + x2_ 144 | 145 | y = torch.cat([y1, y2], dim=1) 146 | 147 | del y1, y2 148 | del x1, x2 149 | 150 | return y 151 | 152 | @staticmethod 153 | def _backward(output, in_channels, out_channels, f_params, f_buffs, 154 | g_params, g_buffs, training, padding, dilation, no_activation): 155 | 156 | y1, y2 = torch.chunk(output, 2, dim=1) 157 | with torch.no_grad(): 158 | y1 = Variable(y1.contiguous()) 159 | y2 = Variable(y2.contiguous()) 160 | 161 | x2 = y2 - RevBlockFunction.residual( 162 | y1, 163 | out_channels, 164 | out_channels, 165 | g_params, 166 | g_buffs, 167 | training=training 168 | ) 169 | 170 | x1 = y1 - RevBlockFunction.residual( 171 | x2, 172 | in_channels, 173 | out_channels, 174 | f_params, 175 | f_buffs, 176 | training=training, 177 | padding=padding, 178 | dilation=dilation 179 | ) 180 | 181 | del y1, y2 182 | x1, x2 = x1.data, x2.data 183 
| 184 | x = torch.cat((x1, x2), 1) 185 | return x 186 | 187 | @staticmethod 188 | def _grad(x, dy, in_channels, out_channels, training, stride, padding, 189 | dilation, activations, f_params, f_buffs, g_params, g_buffs, 190 | no_activation=False, storage_hooks=[]): 191 | dy1, dy2 = torch.chunk(dy, 2, dim=1) 192 | 193 | x1, x2 = torch.chunk(x, 2, dim=1) 194 | 195 | with torch.enable_grad(): 196 | x1 = Variable(x1.contiguous(), requires_grad=True) 197 | x2 = Variable(x2.contiguous(), requires_grad=True) 198 | x1.retain_grad() 199 | x2.retain_grad() 200 | 201 | if CUDA: 202 | x1.cuda() 203 | x2.cuda() 204 | 205 | x1_ = possible_downsample(x1, in_channels, out_channels, stride, 206 | padding, dilation) 207 | x2_ = possible_downsample(x2, in_channels, out_channels, stride, 208 | padding, dilation) 209 | 210 | f_x2 = RevBlockFunction.residual( 211 | x2, 212 | in_channels, 213 | out_channels, 214 | f_params, 215 | f_buffs, 216 | training=training, 217 | stride=stride, 218 | padding=padding, 219 | dilation=dilation, 220 | no_activation=no_activation 221 | ) 222 | 223 | y1_ = f_x2 + x1_ 224 | 225 | g_y1 = RevBlockFunction.residual( 226 | y1_, 227 | out_channels, 228 | out_channels, 229 | g_params, 230 | g_buffs, 231 | training=training 232 | ) 233 | 234 | y2_ = g_y1 + x2_ 235 | 236 | dd1 = torch.autograd.grad(y2_, (y1_,) + tuple(g_params), dy2, 237 | retain_graph=True) 238 | dy2_y1 = dd1[0] 239 | dgw = dd1[1:] 240 | dy1_plus = dy2_y1 + dy1 241 | dd2 = torch.autograd.grad(y1_, (x1, x2) + tuple(f_params), dy1_plus, 242 | retain_graph=True) 243 | dfw = dd2[2:] 244 | 245 | dx2 = dd2[1] 246 | dx2 += torch.autograd.grad(x2_, x2, dy2, retain_graph=True)[0] 247 | dx1 = dd2[0] 248 | 249 | for hook in storage_hooks: 250 | x = hook(x) 251 | 252 | activations.append(x) 253 | 254 | y1_.detach_() 255 | y2_.detach_() 256 | del y1_, y2_ 257 | dx = torch.cat((dx1, dx2), 1) 258 | 259 | return dx, dfw, dgw 260 | 261 | @staticmethod 262 | def forward(ctx, x, in_channels, out_channels, training, stride, padding, 263 | dilation, no_activation, activations, storage_hooks, *args): 264 | """Compute forward pass including boilerplate code. 265 | 266 | This should not be called directly, use the apply method of this class. 
267 | 268 | Args: 269 | ctx (Context): Context object, see PyTorch docs 270 | x (Tensor): 4D input tensor 271 | in_channels (int): Number of channels on input 272 | out_channels (int): Number of channels on output 273 | training (bool): Whethere we are training right now 274 | stride (int): Stride to use for convolutions 275 | no_activation (bool): Whether to compute an initial 276 | activation in the residual function 277 | activations (List): Activation stack 278 | storage_hooks (List[Function]): Functions to apply to activations 279 | before storing them 280 | *args: Should contain all the Parameters 281 | of the module 282 | """ 283 | 284 | if not no_activation: 285 | f_params = [Variable(x) for x in args[:8]] 286 | g_params = [Variable(x) for x in args[8:16]] 287 | f_buffs = args[16:20] 288 | g_buffs = args[20:] 289 | else: 290 | f_params = [Variable(x) for x in args[:6]] 291 | g_params = [Variable(x) for x in args[6:14]] 292 | f_buffs = args[14:16] 293 | g_buffs = args[16:] 294 | 295 | if CUDA: 296 | for var in f_params: 297 | var.cuda() 298 | for var in g_params: 299 | var.cuda() 300 | 301 | # if the images get smaller information is lost and we need to save the input 302 | _, _, H_in, W_in = x.size() 303 | _, _, H_out, W_out = size_after_residual(x.size(), out_channels, 3, stride, padding, dilation) 304 | if H_in > H_out or W_in > W_out or no_activation: 305 | activations.append(x) 306 | ctx.load_input = True 307 | else: 308 | ctx.load_input = False 309 | 310 | ctx.save_for_backward(*[x.data for x in f_params], 311 | *[x.data for x in g_params]) 312 | ctx.f_buffs = f_buffs 313 | ctx.g_buffs = g_buffs 314 | ctx.stride = stride 315 | ctx.padding = padding 316 | ctx.dilation = dilation 317 | ctx.training = training 318 | ctx.no_activation = no_activation 319 | ctx.storage_hooks = storage_hooks 320 | ctx.activations = activations 321 | ctx.in_channels = in_channels 322 | ctx.out_channels = out_channels 323 | 324 | y = RevBlockFunction._forward( 325 | x, 326 | in_channels, 327 | out_channels, 328 | training, 329 | stride, 330 | padding, 331 | dilation, 332 | f_params, f_buffs, 333 | g_params, g_buffs, 334 | no_activation=no_activation 335 | ) 336 | 337 | return y.data 338 | 339 | @staticmethod 340 | def backward(ctx, grad_out): 341 | saved_tensors = list(ctx.saved_tensors) 342 | if not ctx.no_activation: 343 | f_params = [Variable(p, requires_grad=True) for p in saved_tensors[:8]] 344 | g_params = [Variable(p, requires_grad=True) for p in saved_tensors[8:16]] 345 | else: 346 | f_params = [Variable(p, requires_grad=True) for p in saved_tensors[:6]] 347 | g_params = [Variable(p, requires_grad=True) for p in saved_tensors[6:14]] 348 | 349 | in_channels = ctx.in_channels 350 | out_channels = ctx.out_channels 351 | 352 | # Load or reconstruct input 353 | if ctx.load_input: 354 | ctx.activations.pop() 355 | x = ctx.activations.pop() 356 | else: 357 | output = ctx.activations.pop() 358 | x = RevBlockFunction._backward( 359 | output, 360 | in_channels, 361 | out_channels, 362 | f_params, ctx.f_buffs, 363 | g_params, ctx.g_buffs, 364 | ctx.training, 365 | ctx.padding, 366 | ctx.dilation, 367 | ctx.no_activation 368 | ) 369 | 370 | dx, dfw, dgw = RevBlockFunction._grad( 371 | x, 372 | grad_out, 373 | in_channels, 374 | out_channels, 375 | ctx.training, 376 | ctx.stride, 377 | ctx.padding, 378 | ctx.dilation, 379 | ctx.activations, 380 | f_params, ctx.f_buffs, 381 | g_params, ctx.g_buffs, 382 | no_activation=ctx.no_activation, 383 | storage_hooks=ctx.storage_hooks 384 | ) 385 | 386 | num_buffs = 2 
if ctx.no_activation else 4 387 | 388 | return ((dx, None, None, None, None, None, None, None, None, None) + tuple(dfw) + 389 | tuple(dgw) + tuple([None]*num_buffs) + tuple([None]*4)) 390 | 391 | 392 | class RevBlock(nn.Module): 393 | def __init__(self, in_channels, out_channels, activations, stride=1, 394 | padding=1, dilation=1, no_activation=False, storage_hooks=[]): 395 | super(RevBlock, self).__init__() 396 | 397 | self.in_channels = in_channels // 2 398 | self.out_channels = out_channels // 2 399 | self.stride = stride 400 | self.padding = padding 401 | self.dilation = dilation 402 | self.no_activation = no_activation 403 | self.activations = activations 404 | self.storage_hooks = storage_hooks 405 | 406 | if not no_activation: 407 | self.register_parameter( 408 | 'f_bw1', 409 | nn.Parameter(torch.Tensor(self.in_channels)) 410 | ) 411 | self.register_parameter( 412 | 'f_bb1', 413 | nn.Parameter(torch.Tensor(self.in_channels)) 414 | ) 415 | 416 | self.register_parameter( 417 | 'f_w1', 418 | nn.Parameter(torch.Tensor( 419 | self.out_channels, 420 | self.in_channels, 421 | 3, 3 422 | )) 423 | ) 424 | self.register_parameter( 425 | 'f_b1', 426 | nn.Parameter(torch.Tensor(self.out_channels)) 427 | ) 428 | self.register_parameter( 429 | 'f_bw2', 430 | nn.Parameter(torch.Tensor(self.out_channels)) 431 | ) 432 | self.register_parameter( 433 | 'f_bb2', 434 | nn.Parameter(torch.Tensor(self.out_channels)) 435 | ) 436 | self.register_parameter( 437 | 'f_w2', 438 | nn.Parameter(torch.Tensor( 439 | self.out_channels, 440 | self.out_channels, 441 | 3, 3 442 | )) 443 | ) 444 | self.register_parameter( 445 | 'f_b2', 446 | nn.Parameter(torch.Tensor(self.out_channels)) 447 | ) 448 | 449 | self.register_parameter( 450 | 'g_bw1', 451 | nn.Parameter(torch.Tensor(self.out_channels)) 452 | ) 453 | self.register_parameter( 454 | 'g_bb1', 455 | nn.Parameter(torch.Tensor(self.out_channels)) 456 | ) 457 | self.register_parameter( 458 | 'g_w1', 459 | nn.Parameter(torch.Tensor( 460 | self.out_channels, 461 | self.out_channels, 462 | 3, 3 463 | )) 464 | ) 465 | self.register_parameter( 466 | 'g_b1', 467 | nn.Parameter(torch.Tensor(self.out_channels)) 468 | ) 469 | self.register_parameter( 470 | 'g_bw2', 471 | nn.Parameter(torch.Tensor(self.out_channels)) 472 | ) 473 | self.register_parameter( 474 | 'g_bb2', 475 | nn.Parameter(torch.Tensor(self.out_channels)) 476 | ) 477 | self.register_parameter( 478 | 'g_w2', 479 | nn.Parameter(torch.Tensor( 480 | self.out_channels, 481 | self.out_channels, 482 | 3, 3 483 | )) 484 | ) 485 | self.register_parameter( 486 | 'g_b2', 487 | nn.Parameter(torch.Tensor(self.out_channels)) 488 | ) 489 | 490 | if not no_activation: 491 | self.register_buffer('f_rm1', torch.zeros(self.in_channels)) 492 | self.register_buffer('f_rv1', torch.ones(self.in_channels)) 493 | self.register_buffer('f_rm2', torch.zeros(self.out_channels)) 494 | self.register_buffer('f_rv2', torch.ones(self.out_channels)) 495 | 496 | self.register_buffer('g_rm1', torch.zeros(self.out_channels)) 497 | self.register_buffer('g_rv1', torch.ones(self.out_channels)) 498 | self.register_buffer('g_rm2', torch.zeros(self.out_channels)) 499 | self.register_buffer('g_rv2', torch.ones(self.out_channels)) 500 | 501 | self.reset_parameters() 502 | 503 | def reset_parameters(self): 504 | f_stdv = 1 / math.sqrt(self.in_channels * 3 * 3) 505 | g_stdv = 1 / math.sqrt(self.out_channels * 3 * 3) 506 | 507 | if not self.no_activation: 508 | self._parameters['f_bw1'].data.uniform_() 509 | self._parameters['f_bb1'].data.zero_() 510 | 
self._parameters['f_w1'].data.uniform_(-f_stdv, f_stdv) 511 | self._parameters['f_b1'].data.uniform_(-f_stdv, f_stdv) 512 | self._parameters['f_w2'].data.uniform_(-g_stdv, g_stdv) 513 | self._parameters['f_b2'].data.uniform_(-g_stdv, g_stdv) 514 | self._parameters['f_bw2'].data.uniform_() 515 | self._parameters['f_bb2'].data.zero_() 516 | 517 | self._parameters['g_w1'].data.uniform_(-g_stdv, g_stdv) 518 | self._parameters['g_b1'].data.uniform_(-g_stdv, g_stdv) 519 | self._parameters['g_w2'].data.uniform_(-g_stdv, g_stdv) 520 | self._parameters['g_b2'].data.uniform_(-g_stdv, g_stdv) 521 | self._parameters['g_bw1'].data.uniform_() 522 | self._parameters['g_bb1'].data.zero_() 523 | self._parameters['g_bw2'].data.uniform_() 524 | self._parameters['g_bb2'].data.zero_() 525 | 526 | if not self.no_activation: 527 | self._buffers['f_rm1'].zero_() 528 | self._buffers['f_rv1'].fill_(1) 529 | self.f_rm2.zero_() 530 | self.f_rv2.fill_(1) 531 | 532 | self.g_rm1.zero_() 533 | self.g_rv1.fill_(1) 534 | self.g_rm2.zero_() 535 | self.g_rv2.fill_(1) 536 | 537 | def forward(self, x): 538 | return RevBlockFunction.apply( 539 | x, 540 | self.in_channels, 541 | self.out_channels, 542 | self.training, 543 | self.stride, 544 | self.padding, 545 | self.dilation, 546 | self.no_activation, 547 | self.activations, 548 | self.storage_hooks, 549 | *self._parameters.values(), 550 | *self._buffers.values(), 551 | ) 552 | 553 | 554 | class RevBottleneck(nn.Module): 555 | # TODO: Implement metaclass and function 556 | pass 557 | 558 | 559 | class RevNet(nn.Module): 560 | def __init__(self, 561 | units, 562 | filters, 563 | strides, 564 | classes, 565 | bottleneck=False): 566 | """ 567 | Args: 568 | units (list-like): Number of residual units in each group 569 | 570 | filters (list-like): Number of filters in each unit including the 571 | inputlayer, so it is one item longer than units 572 | 573 | strides (list-like): Strides to use for the first units in each 574 | group, same length as units 575 | 576 | bottleneck (boolean): Wether to use the bottleneck residual or the 577 | basic residual 578 | """ 579 | super(RevNet, self).__init__() 580 | self.name = self.__class__.__name__ 581 | 582 | self.activations = [] 583 | 584 | if bottleneck: 585 | self.Reversible = RevBottleneck # TODO: Implement RevBottleneck 586 | else: 587 | self.Reversible = RevBlock 588 | 589 | self.layers = nn.ModuleList() 590 | 591 | # Input layer 592 | self.layers.append(nn.Conv2d(3, filters[0], 3, padding=1)) 593 | self.layers.append(nn.BatchNorm2d(filters[0])) 594 | 595 | for i, group_i in enumerate(units): 596 | self.layers.append(self.Reversible( 597 | filters[i], filters[i + 1], 598 | stride=strides[i], 599 | no_activation=True, 600 | activations=self.activations 601 | )) 602 | 603 | for unit in range(1, group_i): 604 | self.layers.append(self.Reversible( 605 | filters[i + 1], 606 | filters[i + 1], 607 | activations=self.activations 608 | )) 609 | 610 | self.fc = nn.Linear(filters[-1], classes) 611 | 612 | def forward(self, x): 613 | for layer in self.layers: 614 | x = layer(x) 615 | 616 | # Save last output for backward 617 | self.activations.append(x.data) 618 | 619 | x = F.avg_pool2d(x, x.size(2)) 620 | x = x.view(x.size(0), -1) 621 | x = self.fc(x) 622 | 623 | return x 624 | 625 | def free(self): 626 | """Clear saved activation residue and thereby free memory.""" 627 | del self.activations[:] 628 | --------------------------------------------------------------------------------
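
For a quick smoke test of the package, the following is a minimal sketch (assuming PyTorch >= 0.4 and a random CIFAR-10-shaped batch in place of the real dataset) that mirrors the training step in train_cifar.py: one forward/backward pass, the model.free() call that clears the stored activation stack once the gradients exist, and the gradient norm clipping recommended in the README.

```python
import torch
import torch.nn as nn

import revnet

model = revnet.revnet38()  # reversible 38-layer model from revnet/__init__.py
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Random stand-in batch with CIFAR-10 shaped inputs (3x32x32) and 10 classes.
inputs = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,), dtype=torch.long)

optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()

# Drop the stored activations now that the gradients have been computed.
model.free()

# Clip the total gradient norm, as recommended in the README.
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
optimizer.step()
```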