├── .gitignore ├── LICENSE ├── README.md ├── demo.py ├── images ├── forward.png └── memvslayers.png ├── models ├── __init__.py └── densenet.py └── setup.cfg /.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # Environments 84 | .env 85 | .venv 86 | env/ 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # vscode project settings 104 | .vscode 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Geoff Pleiss 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # efficient_densenet_pytorch 2 | A PyTorch >=1.0 implementation of DenseNets, optimized to save GPU memory. 3 | 4 | ## Recent updates 5 | 1. **Now works on PyTorch 1.0!** It uses the checkpointing feature, which makes this code WAY more efficient!!! 6 | 7 | ## Motivation 8 | While DenseNets are fairly easy to implement in deep learning frameworks, most 9 | implementations (such as the [original](https://github.com/liuzhuang13/DenseNet)) tend to be memory-hungry.
10 | In particular, the number of intermediate feature maps generated by batch normalization and concatenation operations 11 | grows quadratically with network depth. 12 | *It is worth emphasizing that this is not a property inherent to DenseNets, but rather to the implementation.* 13 | 14 | This implementation uses a new strategy to reduce the memory consumption of DenseNets. 15 | We use [checkpointing](https://pytorch.org/docs/stable/checkpoint.html?highlight=checkpointing) to compute the Batch Norm and concatenation feature maps. 16 | These intermediate feature maps are discarded during the forward pass and recomputed for the backward pass. 17 | This adds 15-20% time overhead to training, but **reduces feature-map memory consumption from quadratic to linear.** 18 | 19 | This implementation is inspired by this [technical report](https://arxiv.org/pdf/1707.06990.pdf), which outlines a strategy for efficient DenseNets via memory sharing. 20 | 21 | ## Requirements 22 | - PyTorch >=1.0.0 23 | - CUDA 24 | 25 | ## Usage 26 | 27 | **In your existing project:** 28 | There is one model file in the `models` folder. 29 | - `models/densenet.py` is an implementation based on the [torchvision](https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py) and 30 | [project killer](https://github.com/felixgwu/img_classification_pk_pytorch/blob/master/models/densenet.py) implementations. 31 | 32 | If you care about speed and memory is not a concern, pass the `efficient=False` argument into the `DenseNet` constructor. 33 | Otherwise, pass in `efficient=True`. 34 | 35 | **Options:** 36 | - All options are described in [the docstrings of the model files](https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet_efficient.py#L189) 37 | - The depth is controlled by the `block_config` option 38 | - `efficient=True` uses the memory-efficient version 39 | - If you want to use the model for ImageNet, set `small_inputs=False`.
For CIFAR or SVHN, set `small_inputs=True`. 40 | 41 | **Running the demo:** 42 | 43 | The only extra package you need to install is [python-fire](https://github.com/google/python-fire): 44 | ```sh 45 | pip install fire 46 | ``` 47 | 48 | - Single GPU: 49 | 50 | ```sh 51 | CUDA_VISIBLE_DEVICES=0 python demo.py --efficient True --data --save 52 | ``` 53 | 54 | - Multiple GPUs: 55 | 56 | ```sh 57 | CUDA_VISIBLE_DEVICES=0,1,2 python demo.py --efficient True --data --save 58 | ``` 59 | 60 | Options: 61 | - `--depth` (int) - depth of the network (number of convolution layers) (default 100) 62 | - `--growth_rate` (int) - number of features added per DenseNet layer (default 12) 63 | - `--n_epochs` (int) - number of epochs for training (default 300) 64 | - `--batch_size` (int) - size of minibatch (default 64) 65 | - `--seed` (int) - manually set the random seed (default None) 66 | 67 | ## Performance 68 | 69 | A comparison of the two implementations (each is a DenseNet-BC with 100 layers, batch size 64, tested on an NVIDIA Pascal Titan-X): 70 | 71 | | Implementation | Memory consumption (GB/GPU) | Speed (sec/minibatch) | 72 | |----------------|------------------------|------------------------| 73 | | Naive | 2.863 | 0.165 | 74 | | Efficient | 1.605 | 0.207 | 75 | | Efficient (multi-GPU) | 0.985 | - | 76 | 77 | 78 | ## Other efficient implementations 79 | - [LuaTorch](https://github.com/liuzhuang13/DenseNet/tree/master/models) (by Gao Huang) 80 | - [TensorFlow](https://github.com/joeyearsley/efficient_densenet_tensorflow) (by Joe Yearsley) 81 | - [Caffe](https://github.com/Tongcheng/DN_CaffeScript) (by Tongcheng Li) 82 | 83 | ## Reference 84 | 85 | ``` 86 | @article{pleiss2017memory, 87 | title={Memory-Efficient Implementation of DenseNets}, 88 | author={Pleiss, Geoff and Chen, Danlu and Huang, Gao and Li, Tongcheng and van der Maaten, Laurens and Weinberger, Kilian Q}, 89 | journal={arXiv preprint arXiv:1707.06990}, 90 | year={2017} 91 | } 92 | ``` 93 |
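## Memory growth: a back-of-the-envelope sketch

You can check the quadratic-versus-linear claim in the Motivation section with simple channel arithmetic. The sketch below is a simplified model. It is not a measurement and not part of this repo. It assumes a single dense block whose layer `i` concatenates `k0 + i * k` channels (as `_DenseBlock` does, with `k0` input channels and growth rate `k`) and then batch-normalizes them. It counts only those two intermediate maps per layer. It ignores the per-layer outputs, which both implementations keep, as well as bottleneck outputs and gradients. The function names are hypothetical.

```python
# Simplified channel-count model of the concatenation + BatchNorm maps inside
# ONE dense block. Assumption (not from the repo): spatial and batch sizes are
# constant within a block, so counting channels is enough; bottleneck outputs,
# gradients and framework workspace memory are ignored.


def naive_concat_bn_channels(num_layers, k0, k):
    """Channels kept alive for backward by a naive implementation.

    Layer i allocates a concatenated tensor with k0 + i * k channels and a
    BatchNorm output of the same size; all of them are stored, so the total
    grows quadratically with num_layers.
    """
    return sum(2 * (k0 + i * k) for i in range(num_layers))


def checkpointed_concat_bn_channels(num_layers, k0, k):
    """Peak channels held by checkpointed concat + BN maps.

    With checkpointing these maps are freed after each layer and recomputed
    during backward, so roughly one layer's pair is alive at a time: the
    widest pair, which belongs to the last layer. This grows linearly.
    """
    if num_layers <= 0:
        return 0
    return 2 * (k0 + (num_layers - 1) * k)


if __name__ == "__main__":
    # Dense block with k0 = 24 input channels and growth rate k = 12.
    for layers in (8, 16, 32):
        print(layers,
              naive_concat_bn_channels(layers, 24, 12),
              checkpointed_concat_bn_channels(layers, 24, 12))
```

Under this model, going from 16 to 32 layers multiplies the naive count by about 3.7 (3648 to 13440 channels) but the checkpointed peak by less than 2 (408 to 792).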
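## Depth and feature counts: a sketch

The demo derives `block_config` from `--depth` using `(depth - 4) // 6` layers per block. The `DenseNet` constructor then grows and compresses the channel count block by block. The helpers below replay that arithmetic in plain Python so you can check a configuration before building a model. They are illustrative only: the names are hypothetical and not part of `models/`. They assume three dense blocks, `num_init_features = 2 * growth_rate` as passed by `demo.py`, and the constructor's default `compression=0.5`. In this accounting, the `4` covers the first convolution, the two transition convolutions, and the final linear layer. Each bottleneck layer contributes two convolutions (1x1 and 3x3).

```python
def block_config_from_depth(depth):
    """Layers per dense block for a 3-block DenseNet-BC of the given depth.

    Each bottleneck layer has two convolutions, so depth - 4 must split evenly
    into 6 parts for the requested depth to be achieved exactly.
    """
    per_block, remainder = divmod(depth - 4, 6)
    if per_block < 1 or remainder:
        raise ValueError("depth - 4 must be a positive multiple of 6")
    return [per_block] * 3


def classifier_in_features(block_config, growth_rate=12, compression=0.5):
    """Channels entering the final linear layer, following DenseNet.__init__."""
    num_features = 2 * growth_rate  # demo.py passes num_init_features=growth_rate*2
    for i, num_layers in enumerate(block_config):
        num_features += num_layers * growth_rate  # each dense layer adds k maps
        if i != len(block_config) - 1:  # a transition follows every block but the last
            num_features = int(num_features * compression)
    return num_features


if __name__ == "__main__":
    config = block_config_from_depth(100)
    print(config, classifier_in_features(config))
```

For the default depth of 100 this gives `[16, 16, 16]` and 342 input features to the classifier.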
-------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import os 3 | import time 4 | import torch 5 | from torchvision import datasets, transforms 6 | from models import DenseNet 7 | 8 | 9 | class AverageMeter(object): 10 | """ 11 | Computes and stores the average and current value 12 | Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py 13 | """ 14 | def __init__(self): 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def train_epoch(model, loader, optimizer, epoch, n_epochs, print_freq=1): 31 | batch_time = AverageMeter() 32 | losses = AverageMeter() 33 | error = AverageMeter() 34 | 35 | # Model on train mode 36 | model.train() 37 | 38 | end = time.time() 39 | for batch_idx, (input, target) in enumerate(loader): 40 | # Move inputs and targets to the GPU, if available 41 | if torch.cuda.is_available(): 42 | input = input.cuda() 43 | target = target.cuda() 44 | 45 | # compute output 46 | output = model(input) 47 | loss = torch.nn.functional.cross_entropy(output, target) 48 | 49 | # measure accuracy and record loss 50 | batch_size = target.size(0) 51 | _, pred = output.data.cpu().topk(1, dim=1) 52 | error.update(torch.ne(pred.squeeze(), target.cpu()).float().sum().item() / batch_size, batch_size) 53 | losses.update(loss.item(), batch_size) 54 | 55 | # compute gradient and do SGD step 56 | optimizer.zero_grad() 57 | loss.backward() 58 | optimizer.step() 59 | 60 | # measure elapsed time 61 | batch_time.update(time.time() - end) 62 | end = time.time() 63 | 64 | # print stats 65 | if batch_idx % print_freq == 0: 66 | res = '\t'.join([ 67 | 'Epoch: [%d/%d]' % (epoch + 1, n_epochs), 68 | 'Iter: [%d/%d]' % (batch_idx + 1,
len(loader)), 69 | 'Time %.3f (%.3f)' % (batch_time.val, batch_time.avg), 70 | 'Loss %.4f (%.4f)' % (losses.val, losses.avg), 71 | 'Error %.4f (%.4f)' % (error.val, error.avg), 72 | ]) 73 | print(res) 74 | 75 | # Return summary statistics 76 | return batch_time.avg, losses.avg, error.avg 77 | 78 | 79 | def test_epoch(model, loader, print_freq=1, is_test=True): 80 | batch_time = AverageMeter() 81 | losses = AverageMeter() 82 | error = AverageMeter() 83 | 84 | # Model on eval mode 85 | model.eval() 86 | 87 | end = time.time() 88 | with torch.no_grad(): 89 | for batch_idx, (input, target) in enumerate(loader): 90 | # Move inputs and targets to the GPU, if available 91 | if torch.cuda.is_available(): 92 | input = input.cuda() 93 | target = target.cuda() 94 | 95 | # compute output 96 | output = model(input) 97 | loss = torch.nn.functional.cross_entropy(output, target) 98 | 99 | # measure accuracy and record loss 100 | batch_size = target.size(0) 101 | _, pred = output.data.cpu().topk(1, dim=1) 102 | error.update(torch.ne(pred.squeeze(), target.cpu()).float().sum().item() / batch_size, batch_size) 103 | losses.update(loss.item(), batch_size) 104 | 105 | # measure elapsed time 106 | batch_time.update(time.time() - end) 107 | end = time.time() 108 | 109 | # print stats 110 | if batch_idx % print_freq == 0: 111 | res = '\t'.join([ 112 | 'Test' if is_test else 'Valid', 113 | 'Iter: [%d/%d]' % (batch_idx + 1, len(loader)), 114 | 'Time %.3f (%.3f)' % (batch_time.val, batch_time.avg), 115 | 'Loss %.4f (%.4f)' % (losses.val, losses.avg), 116 | 'Error %.4f (%.4f)' % (error.val, error.avg), 117 | ]) 118 | print(res) 119 | 120 | # Return summary statistics 121 | return batch_time.avg, losses.avg, error.avg 122 | 123 | 124 | def train(model, train_set, valid_set, test_set, save, n_epochs=300, 125 | batch_size=64, lr=0.1, wd=0.0001, momentum=0.9, seed=None): 126 | if seed is not None: 127 | torch.manual_seed(seed) 128 | 129 | # Data loaders 130 | train_loader = torch.utils.data.DataLoader(train_set,
batch_size=batch_size, shuffle=True, 131 | pin_memory=(torch.cuda.is_available()), num_workers=0) 132 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, 133 | pin_memory=(torch.cuda.is_available()), num_workers=0) 134 | if valid_set is None: 135 | valid_loader = None 136 | else: 137 | valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, 138 | pin_memory=(torch.cuda.is_available()), num_workers=0) 139 | # Model on cuda 140 | if torch.cuda.is_available(): 141 | model = model.cuda() 142 | 143 | # Wrap model for multi-GPUs, if necessary 144 | model_wrapper = model 145 | if torch.cuda.is_available() and torch.cuda.device_count() > 1: 146 | model_wrapper = torch.nn.DataParallel(model).cuda() 147 | 148 | # Optimizer 149 | optimizer = torch.optim.SGD(model_wrapper.parameters(), lr=lr, momentum=momentum, nesterov=True, weight_decay=wd) 150 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(0.5 * n_epochs), int(0.75 * n_epochs)], 151 | gamma=0.1) 152 | 153 | # Start log 154 | with open(os.path.join(save, 'results.csv'), 'w') as f: 155 | f.write('epoch,train_loss,train_error,valid_loss,valid_error,test_error\n') 156 | 157 | # Train model 158 | best_error = 1 159 | for epoch in range(n_epochs): 160 | _, train_loss, train_error = train_epoch( 161 | model=model_wrapper, 162 | loader=train_loader, 163 | optimizer=optimizer, 164 | epoch=epoch, 165 | n_epochs=n_epochs, 166 | ) 167 | scheduler.step() 168 | _, valid_loss, valid_error = test_epoch( 169 | model=model_wrapper, 170 | loader=valid_loader if valid_loader else test_loader, 171 | is_test=(not valid_loader) 172 | ) 173 | 174 | # Determine if model is the best 175 | if valid_loader: 176 | if valid_error < best_error: 177 | best_error = valid_error 178 | print('New best error: %.4f' % best_error) 179 | torch.save(model.state_dict(), os.path.join(save, 'model.dat')) 180 | else: 181 | torch.save(model.state_dict(), os.path.join(save,
'model.dat')) 182 | 183 | # Log results 184 | with open(os.path.join(save, 'results.csv'), 'a') as f: 185 | f.write('%03d,%0.6f,%0.6f,%0.5f,%0.5f,\n' % ( 186 | (epoch + 1), 187 | train_loss, 188 | train_error, 189 | valid_loss, 190 | valid_error, 191 | )) 192 | 193 | # Final test of model on test set 194 | model.load_state_dict(torch.load(os.path.join(save, 'model.dat'))) 195 | if torch.cuda.is_available() and torch.cuda.device_count() > 1: 196 | model = torch.nn.DataParallel(model).cuda() 197 | test_results = test_epoch( 198 | model=model, 199 | loader=test_loader, 200 | is_test=True 201 | ) 202 | _, _, test_error = test_results 203 | with open(os.path.join(save, 'results.csv'), 'a') as f: 204 | f.write(',,,,,%0.5f\n' % (test_error)) 205 | print('Final test error: %.4f' % test_error) 206 | 207 | 208 | def demo(data, save, depth=100, growth_rate=12, efficient=True, valid_size=5000, 209 | n_epochs=300, batch_size=64, seed=None): 210 | """ 211 | A demo to show off training of efficient DenseNets. 212 | Trains and evaluates a DenseNet-BC on CIFAR-10. 213 | 214 | Args: 215 | data (str) - path to directory where data should be loaded from/downloaded 216 | (required) 217 | save (str) - path to the directory where the model and results are saved (required) 218 | 219 | depth (int) - depth of the network (number of convolution layers) (default 100) 220 | growth_rate (int) - number of features added per DenseNet layer (default 12) 221 | efficient (bool) - use the memory efficient implementation?
(default True) 222 | 223 | valid_size (int) - number of training images held out for validation (default 5000; 0 disables the split) 224 | n_epochs (int) - number of epochs for training (default 300) 225 | batch_size (int) - size of minibatch (default 64) 226 | seed (int) - manually set the random seed (default None) 227 | """ 228 | 229 | # Get densenet configuration 230 | if (depth - 4) % 6: 231 | raise Exception('Invalid depth: depth - 4 must be divisible by 6') 232 | block_config = [(depth - 4) // 6 for _ in range(3)] 233 | 234 | # Data transforms 235 | mean = [0.49139968, 0.48215841, 0.44653091] 236 | stdv = [0.24703223, 0.24348513, 0.26158784] 237 | train_transforms = transforms.Compose([ 238 | transforms.RandomCrop(32, padding=4), 239 | transforms.RandomHorizontalFlip(), 240 | transforms.ToTensor(), 241 | transforms.Normalize(mean=mean, std=stdv), 242 | ]) 243 | test_transforms = transforms.Compose([ 244 | transforms.ToTensor(), 245 | transforms.Normalize(mean=mean, std=stdv), 246 | ]) 247 | 248 | # Datasets 249 | train_set = datasets.CIFAR10(data, train=True, transform=train_transforms, download=True) 250 | test_set = datasets.CIFAR10(data, train=False, transform=test_transforms, download=False) 251 | 252 | if valid_size: 253 | valid_set = datasets.CIFAR10(data, train=True, transform=test_transforms) 254 | indices = torch.randperm(len(train_set)) 255 | train_indices = indices[:len(indices) - valid_size] 256 | valid_indices = indices[len(indices) - valid_size:] 257 | train_set = torch.utils.data.Subset(train_set, train_indices) 258 | valid_set = torch.utils.data.Subset(valid_set, valid_indices) 259 | else: 260 | valid_set = None 261 | 262 | # Models 263 | model = DenseNet( 264 | growth_rate=growth_rate, 265 | block_config=block_config, 266 | num_init_features=growth_rate*2, 267 | num_classes=10, 268 | small_inputs=True, 269 | efficient=efficient, 270 | ) 271 | print(model) 272 | 273 | # Print number of parameters 274 | num_params = sum(p.numel() for p in model.parameters()) 275 | print("Total parameters: ", num_params) 276 | 277 | # Make save directory
278 | if not os.path.exists(save): 279 | os.makedirs(save) 280 | if not os.path.isdir(save): 281 | raise Exception('%s is not a dir' % save) 282 | 283 | # Train the model 284 | train(model=model, train_set=train_set, valid_set=valid_set, test_set=test_set, save=save, 285 | n_epochs=n_epochs, batch_size=batch_size, seed=seed) 286 | print('Done!') 287 | 288 | 289 | """ 290 | A demo to show off training of efficient DenseNets. 291 | Trains and evaluates a DenseNet-BC on CIFAR-10. 292 | 293 | Try out the efficient DenseNet implementation: 294 | python demo.py --efficient True --data --save 295 | 296 | Try out the naive DenseNet implementation: 297 | python demo.py --efficient False --data --save 298 | 299 | Other args: 300 | --depth (int) - depth of the network (number of convolution layers) (default 100) 301 | --growth_rate (int) - number of features added per DenseNet layer (default 12) 302 | --n_epochs (int) - number of epochs for training (default 300) 303 | --batch_size (int) - size of minibatch (default 64) 304 | --seed (int) - manually set the random seed (default None) 305 | """ 306 | if __name__ == '__main__': 307 | fire.Fire(demo) 308 | -------------------------------------------------------------------------------- /images/forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpleiss/efficient_densenet_pytorch/d2d0629dae2912ec4f0444fab5e3df4c375c44fd/images/forward.png -------------------------------------------------------------------------------- /images/memvslayers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpleiss/efficient_densenet_pytorch/d2d0629dae2912ec4f0444fab5e3df4c375c44fd/images/memvslayers.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .densenet
import DenseNet 2 | 3 | 4 | __all__ = [ 5 | 'DenseNet', 6 | ] 7 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | # This implementation is based on the DenseNet-BC implementation in torchvision 2 | # https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.utils.checkpoint as cp 9 | from collections import OrderedDict 10 | 11 | 12 | def _bn_function_factory(norm, relu, conv): 13 | def bn_function(*inputs): 14 | concated_features = torch.cat(inputs, 1) 15 | bottleneck_output = conv(relu(norm(concated_features))) 16 | return bottleneck_output 17 | 18 | return bn_function 19 | 20 | 21 | class _DenseLayer(nn.Module): 22 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False): 23 | super(_DenseLayer, self).__init__() 24 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)) 25 | self.add_module('relu1', nn.ReLU(inplace=True)) 26 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate, 27 | kernel_size=1, stride=1, bias=False)) 28 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)) 29 | self.add_module('relu2', nn.ReLU(inplace=True)) 30 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 31 | kernel_size=3, stride=1, padding=1, bias=False)) 32 | self.drop_rate = drop_rate 33 | self.efficient = efficient 34 | 35 | def forward(self, *prev_features): 36 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 37 | if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 38 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 39 | else: 40 | bottleneck_output = bn_function(*prev_features) 41 | new_features =
self.conv2(self.relu2(self.norm2(bottleneck_output))) 42 | if self.drop_rate > 0: 43 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 44 | return new_features 45 | 46 | 47 | class _Transition(nn.Sequential): 48 | def __init__(self, num_input_features, num_output_features): 49 | super(_Transition, self).__init__() 50 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 51 | self.add_module('relu', nn.ReLU(inplace=True)) 52 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 53 | kernel_size=1, stride=1, bias=False)) 54 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 55 | 56 | 57 | class _DenseBlock(nn.Module): 58 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False): 59 | super(_DenseBlock, self).__init__() 60 | for i in range(num_layers): 61 | layer = _DenseLayer( 62 | num_input_features + i * growth_rate, 63 | growth_rate=growth_rate, 64 | bn_size=bn_size, 65 | drop_rate=drop_rate, 66 | efficient=efficient, 67 | ) 68 | self.add_module('denselayer%d' % (i + 1), layer) 69 | 70 | def forward(self, init_features): 71 | features = [init_features] 72 | for name, layer in self.named_children(): 73 | new_features = layer(*features) 74 | features.append(new_features) 75 | return torch.cat(features, 1) 76 | 77 | 78 | class DenseNet(nn.Module): 79 | r"""DenseNet-BC model class, based on 80 | `"Densely Connected Convolutional Networks" ` 81 | Args: 82 | growth_rate (int) - how many filters to add each layer (`k` in paper) 83 | block_config (list of 3 or 4 ints) - how many layers in each pooling block 84 | num_init_features (int) - the number of filters to learn in the first convolution layer 85 | bn_size (int) - multiplicative factor for the number of bottleneck features 86 | (i.e.
bn_size * k features in the bottleneck layer) 87 | drop_rate (float) - dropout rate after each dense layer 88 | num_classes (int) - number of classification classes 89 | small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger. 90 | efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower. 91 | """ 92 | def __init__(self, growth_rate=12, block_config=(16, 16, 16), compression=0.5, 93 | num_init_features=24, bn_size=4, drop_rate=0, 94 | num_classes=10, small_inputs=True, efficient=False): 95 | 96 | super(DenseNet, self).__init__() 97 | assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1' 98 | 99 | # First convolution 100 | if small_inputs: 101 | self.features = nn.Sequential(OrderedDict([ 102 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)), 103 | ])) 104 | else: 105 | self.features = nn.Sequential(OrderedDict([ 106 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 107 | ])) 108 | self.features.add_module('norm0', nn.BatchNorm2d(num_init_features)) 109 | self.features.add_module('relu0', nn.ReLU(inplace=True)) 110 | self.features.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, 111 | ceil_mode=False)) 112 | 113 | # Each denseblock 114 | num_features = num_init_features 115 | for i, num_layers in enumerate(block_config): 116 | block = _DenseBlock( 117 | num_layers=num_layers, 118 | num_input_features=num_features, 119 | bn_size=bn_size, 120 | growth_rate=growth_rate, 121 | drop_rate=drop_rate, 122 | efficient=efficient, 123 | ) 124 | self.features.add_module('denseblock%d' % (i + 1), block) 125 | num_features = num_features + num_layers * growth_rate 126 | if i != len(block_config) - 1: 127 | trans = _Transition(num_input_features=num_features, 128 | num_output_features=int(num_features * compression)) 129 | self.features.add_module('transition%d' % (i + 1), trans) 
130 | num_features = int(num_features * compression) 131 | 132 | # Final batch norm 133 | self.features.add_module('norm_final', nn.BatchNorm2d(num_features)) 134 | 135 | # Linear layer 136 | self.classifier = nn.Linear(num_features, num_classes) 137 | 138 | # Initialization 139 | for name, param in self.named_parameters(): 140 | if 'conv' in name and 'weight' in name: 141 | n = param.size(0) * param.size(2) * param.size(3) 142 | param.data.normal_().mul_(math.sqrt(2. / n)) 143 | elif 'norm' in name and 'weight' in name: 144 | param.data.fill_(1) 145 | elif 'norm' in name and 'bias' in name: 146 | param.data.fill_(0) 147 | elif 'classifier' in name and 'bias' in name: 148 | param.data.fill_(0) 149 | 150 | def forward(self, x): 151 | features = self.features(x) 152 | out = F.relu(features, inplace=True) 153 | out = F.adaptive_avg_pool2d(out, (1, 1)) 154 | out = torch.flatten(out, 1) 155 | out = self.classifier(out) 156 | return out 157 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [pep8] 5 | max-line-length = 120 6 | 7 | [flake8] 8 | max-line-length = 120 9 | ignore = F403, F405, E128 10 | --------------------------------------------------------------------------------