├── .flake8
├── .gitignore
├── .travis.yml
├── BENCHMARK.md
├── LICENSE
├── README.md
├── dropblock
│   ├── __init__.py
│   ├── dropblock.py
│   └── scheduler.py
├── examples
│   ├── README.md
│   ├── config.yml
│   ├── requirements.txt
│   └── resnet-cifar10.py
├── requirements.txt
├── setup.py
└── tests
    ├── test_dropblock2d.py
    └── test_dropblock3d.py

/.flake8:
--------------------------------------------------------------------------------
[flake8]
max-line-length = 120
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
language: python
stages:
  - quality-assurance
  - deploy
cache: pip
jobs:
  include:
    - stage: quality-assurance
      python: '3.6'
      install: pip install flake8
      script: flake8

    - stage: quality-assurance
      python: '3.6'
      before_install:
        - pip install pytest==4.3.0 pytest-cov
      install:
        - python setup.py install
      script:
        - pytest --cov dropblock --cov-report term-missing -v tests/
--------------------------------------------------------------------------------
/BENCHMARK.md:
--------------------------------------------------------------------------------
# ResNet-9 CIFAR-10 Benchmark

Results for ResNet9 on CIFAR10, trained on 1x NVIDIA V100 GPU, averaged over 3 runs:

| Model                | Accuracy (%) | Time (s) |
|----------------------|--------------|----------|
| ResNet9              | 81.46        | 271      |
| ResNet9 + DropBlock* | 81.65        | 288      |

`* scheduled DropBlock with block_size=5 and drop_prob increasing from 0.0 to 0.25 over 5000 iterations`
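
For orientation, the schedule in the footnote corresponds to a `LinearScheduler` wrapped around a `DropBlock2D` layer, roughly as sketched below (a sketch only; the exact wiring lives in the example script linked underneath):

```python
from dropblock import DropBlock2D, LinearScheduler

# drop_prob ramps linearly from 0.0 to 0.25 over the first 5000 training
# iterations, then stays at 0.25; step() must be called once per iteration
dropblock = LinearScheduler(
    DropBlock2D(drop_prob=0.25, block_size=5),
    start_value=0.,
    stop_value=0.25,
    nr_steps=5000,
)
```
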
Example available [here](examples/resnet-cifar10.py)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Miguel Varela Ramos

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DropBlock

![build](https://travis-ci.org/miguelvr/dropblock.png?branch=master)
[![Downloads](https://pepy.tech/badge/dropblock)](https://pepy.tech/project/dropblock)

Implementation of [DropBlock: A regularization method for convolutional networks](https://arxiv.org/pdf/1810.12890.pdf)
in PyTorch.

## Abstract

Deep neural networks often work well when they are over-parameterized
and trained with a massive amount of noise and regularization, such as
weight decay and dropout. Although dropout is widely used as a regularization
technique for fully connected layers, it is often less effective for convolutional layers.
This lack of success of dropout for convolutional layers is perhaps due to the fact
that activation units in convolutional layers are spatially correlated, so
information can still flow through convolutional networks despite dropout.
Thus a structured form of dropout is needed to regularize convolutional networks.
In this paper, we introduce DropBlock, a form of structured dropout, where units in a
contiguous region of a feature map are dropped together.
We found that applying DropBlock in skip connections in addition to the
convolution layers increases the accuracy. Also, gradually increasing the number
of dropped units during training leads to better accuracy and more robustness
to hyperparameter choices.
Extensive experiments show that DropBlock works better than dropout in regularizing
convolutional networks. On ImageNet classification, a ResNet-50 architecture with
DropBlock achieves 78.13% accuracy, which is more than a 1.6% improvement over the baseline.
On COCO detection, DropBlock improves the Average Precision of RetinaNet from 36.8% to 38.4%.

## Installation

Install directly from PyPI:

    pip install dropblock

or the bleeding-edge version from GitHub:

    pip install git+https://github.com/miguelvr/dropblock.git#egg=dropblock

**NOTE**: Implementation and tests were done in Python 3.6; if you have problems with other versions of Python, please open an issue.
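
As a quick sanity check after installing, you can confirm the package imports and see what it exposes:

    python -c "import dropblock; print(dropblock.__all__)"

which should print `['DropBlock2D', 'DropBlock3D', 'LinearScheduler']` for the current version.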

## Usage

For 2D inputs (DropBlock2D):

```python
import torch
from dropblock import DropBlock2D

# (bsize, n_feats, height, width)
x = torch.rand(100, 10, 16, 16)

drop_block = DropBlock2D(block_size=3, drop_prob=0.3)
regularized_x = drop_block(x)
```

For 3D inputs (DropBlock3D):

```python
import torch
from dropblock import DropBlock3D

# (bsize, n_feats, depth, height, width)
x = torch.rand(100, 10, 16, 16, 16)

drop_block = DropBlock3D(block_size=3, drop_prob=0.3)
regularized_x = drop_block(x)
```

Scheduled DropBlock:

```python
import torch
from dropblock import DropBlock2D, LinearScheduler

# (bsize, n_feats, height, width)
loader = [torch.rand(20, 10, 16, 16) for _ in range(10)]

drop_block = LinearScheduler(
    DropBlock2D(block_size=3, drop_prob=0.),
    start_value=0.,
    stop_value=0.25,
    nr_steps=5
)

probs = []
for x in loader:
    drop_block.step()
    regularized_x = drop_block(x)
    probs.append(drop_block.dropblock.drop_prob)

print(probs)
```

The drop probabilities will be:
```
>>> [0. , 0.0625, 0.125 , 0.1875, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]
```

The user should include the `step()` call at the start of the batch loop,
or at the start of a model's `forward` call.

Check [examples/resnet-cifar10.py](examples/resnet-cifar10.py) to
see an implementation example.

## Implementation details

We use `drop_prob` instead of `keep_prob` as a matter of preference,
and to keep the argument consistent with PyTorch's dropout.
Everything else should work similarly to what is described in the paper;
the one notable simplification is how the Bernoulli rate `gamma` is
computed, as sketched below.
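
For concreteness, here is the rate this implementation uses next to the one the paper derives (the `gamma_paper` helper is illustrative only, not part of the package):

```python
def gamma_simplified(drop_prob, block_size):
    # what DropBlock2D._compute_gamma computes in this implementation
    return drop_prob / (block_size ** 2)


def gamma_paper(drop_prob, block_size, feat_size):
    # Eq. (1) of the paper: corrects for the valid seed region, so the
    # expected fraction of dropped activations tracks drop_prob more closely
    return (drop_prob / block_size ** 2) * \
           (feat_size ** 2 / (feat_size - block_size + 1) ** 2)


print(gamma_simplified(0.25, 5))   # 0.01
print(gamma_paper(0.25, 5, 16))    # ~0.0178 for a 16x16 feature map
```

The simplified rate is slightly smaller than the paper's, so the effective fraction of dropped units ends up a bit below `drop_prob`.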

## Benchmark

Refer to [BENCHMARK.md](BENCHMARK.md)

## Reference
[Ghiasi et al., 2018] DropBlock: A regularization method for convolutional networks

## TODO
- [x] Scheduled DropBlock
- [x] Get benchmark numbers
- [x] Extend the concept to 3D images
--------------------------------------------------------------------------------
/dropblock/__init__.py:
--------------------------------------------------------------------------------
from .dropblock import DropBlock2D, DropBlock3D
from .scheduler import LinearScheduler

__all__ = ['DropBlock2D', 'DropBlock3D', 'LinearScheduler']
--------------------------------------------------------------------------------
/dropblock/dropblock.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn


class DropBlock2D(nn.Module):
    r"""Randomly zeroes 2D spatial blocks of the input tensor.

    As described in the paper
    `DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole blocks of a feature map discards more semantic
    information than regular dropout does.

    Args:
        drop_prob (float): probability of an element to be dropped.
        block_size (int): size of the block to drop

    Shape:
        - Input: `(N, C, H, W)`
        - Output: `(N, C, H, W)`

    .. _DropBlock: A regularization method for convolutional networks:
       https://arxiv.org/abs/1810.12890

    """

    def __init__(self, drop_prob, block_size):
        super(DropBlock2D, self).__init__()

        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        # shape: (bsize, channels, height, width)

        assert x.dim() == 4, \
            "Expected input with 4 dimensions (bsize, channels, height, width)"

        if not self.training or self.drop_prob == 0.:
            return x
        else:
            # get gamma value
            gamma = self._compute_gamma(x)

            # sample mask
            mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float()

            # place mask on input device
            mask = mask.to(x.device)

            # compute block mask
            block_mask = self._compute_block_mask(mask)

            # apply block mask
            out = x * block_mask[:, None, :, :]

            # scale output
            out = out * block_mask.numel() / block_mask.sum()

            return out

    def _compute_block_mask(self, mask):
        # a stride-1 max pool dilates each sampled seed into a
        # block_size x block_size square of ones
        block_mask = F.max_pool2d(input=mask[:, None, :, :],
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)

        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1]

        # invert into a keep-mask: 1 = keep, 0 = drop
        block_mask = 1 - block_mask.squeeze(1)

        return block_mask

    def _compute_gamma(self, x):
        return self.drop_prob / (self.block_size ** 2)


class DropBlock3D(DropBlock2D):
    r"""Randomly zeroes 3D spatial blocks of the input tensor.

    An extension of the concept described in the paper
    `DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole blocks of a feature map discards more semantic
    information than regular dropout does.

    Args:
        drop_prob (float): probability of an element to be dropped.
        block_size (int): size of the block to drop

    Shape:
        - Input: `(N, C, D, H, W)`
        - Output: `(N, C, D, H, W)`

    .. _DropBlock: A regularization method for convolutional networks:
       https://arxiv.org/abs/1810.12890

    """

    def __init__(self, drop_prob, block_size):
        super(DropBlock3D, self).__init__(drop_prob, block_size)

    def forward(self, x):
        # shape: (bsize, channels, depth, height, width)

        assert x.dim() == 5, \
            "Expected input with 5 dimensions (bsize, channels, depth, height, width)"

        if not self.training or self.drop_prob == 0.:
            return x
        else:
            # get gamma value
            gamma = self._compute_gamma(x)

            # sample mask
            mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float()

            # place mask on input device
            mask = mask.to(x.device)

            # compute block mask
            block_mask = self._compute_block_mask(mask)

            # apply block mask
            out = x * block_mask[:, None, :, :, :]

            # scale output
            out = out * block_mask.numel() / block_mask.sum()

            return out

    def _compute_block_mask(self, mask):
        block_mask = F.max_pool3d(input=mask[:, None, :, :, :],
                                  kernel_size=(self.block_size, self.block_size, self.block_size),
                                  stride=(1, 1, 1),
                                  padding=self.block_size // 2)

        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1, :-1]

        block_mask = 1 - block_mask.squeeze(1)

        return block_mask

    def _compute_gamma(self, x):
        return self.drop_prob / (self.block_size ** 3)
--------------------------------------------------------------------------------
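
To see the block-mask construction above in isolation: a single sampled seed becomes a full `block_size`-sided square of zeros in the keep-mask. A small self-contained illustration, poking at the private helper the same way the unit tests do:

```python
import torch
from dropblock import DropBlock2D

db = DropBlock2D(drop_prob=0.1, block_size=3)

# one sampled seed at position (1, 1) of a single 5x5 map
mask = torch.zeros(1, 5, 5)
mask[0, 1, 1] = 1.

# the seed grows into a 3x3 block of zeros centred on it
print(db._compute_block_mask(mask))
# tensor([[[0., 0., 0., 1., 1.],
#          [0., 0., 0., 1., 1.],
#          [0., 0., 0., 1., 1.],
#          [1., 1., 1., 1., 1.],
#          [1., 1., 1., 1., 1.]]])
```
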
/dropblock/scheduler.py:
--------------------------------------------------------------------------------
import numpy as np
from torch import nn


class LinearScheduler(nn.Module):
    def __init__(self, dropblock, start_value, stop_value, nr_steps):
        super(LinearScheduler, self).__init__()
        self.dropblock = dropblock
        self.i = 0
        self.drop_values = np.linspace(start=start_value, stop=stop_value, num=int(nr_steps))

    def forward(self, x):
        return self.dropblock(x)

    def step(self):
        # advance the schedule; once past nr_steps, drop_prob stays at stop_value
        if self.i < len(self.drop_values):
            self.dropblock.drop_prob = self.drop_values[self.i]

        self.i += 1
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
# ResNet9 on CIFAR-10

## Requirements

```bash
pip install -r requirements.txt
```

## Usage

Run the example on the CPU:

```bash
python resnet-cifar10.py -c config.yml
```

Run the example on the GPU (device 0):

```bash
python resnet-cifar10.py -c config.yml --device 0
```
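
Because the script parses its options with `configargparse`, any value in `config.yml` can also be overridden from the command line (flags take precedence over config-file entries). For example, to enable DropBlock with a target drop probability of 0.25 on GPU 0:

```bash
python resnet-cifar10.py -c config.yml --drop_prob 0.25 --device 0
```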
--------------------------------------------------------------------------------
/examples/config.yml:
--------------------------------------------------------------------------------
root: ./data
workers: 4
bsize: 256
epochs: 50
lr: 0.001
drop_prob: 0.
block_size: 5
--------------------------------------------------------------------------------
/examples/requirements.txt:
--------------------------------------------------------------------------------
pytorch-ignite==0.1.1
tqdm
torchvision
configargparse
--------------------------------------------------------------------------------
/examples/resnet-cifar10.py:
--------------------------------------------------------------------------------
import time
import configargparse
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models.resnet import BasicBlock, ResNet
from ignite.engine import create_supervised_trainer, create_supervised_evaluator, Events
from ignite.metrics import Accuracy
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar

from dropblock import DropBlock2D, LinearScheduler

results = []


class ResNetCustom(ResNet):

    def __init__(self, block, layers, num_classes=1000, drop_prob=0., block_size=5):
        # call nn.Module.__init__ directly (deliberately skipping
        # ResNet.__init__), since all layers are redefined below
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.dropblock = LinearScheduler(
            DropBlock2D(drop_prob=drop_prob, block_size=block_size),
            start_value=0.,
            stop_value=drop_prob,
            nr_steps=5e3
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        self.dropblock.step()  # increment number of iterations

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.dropblock(self.layer1(x))
        x = self.dropblock(self.layer2(x))
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)

        return x


def resnet9(**kwargs):
    return ResNetCustom(BasicBlock, [1, 1, 1, 1], **kwargs)


def logger(engine, model, evaluator, loader, pbar):
    evaluator.run(loader)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    pbar.log_message(
        "Test Results - Avg accuracy: {:.2f}, drop_prob: {:.2f}".format(avg_accuracy,
                                                                        model.dropblock.dropblock.drop_prob)
    )
    results.append(avg_accuracy)


if __name__ == '__main__':
    parser = configargparse.ArgumentParser()

    parser.add_argument('-c', '--config', required=False,
                        is_config_file=True, help='config file')
    parser.add_argument('--root', required=False, type=str, default='./data',
                        help='data root path')
    parser.add_argument('--workers', required=False, type=int, default=4,
                        help='number of workers for data loader')
    parser.add_argument('--bsize', required=False, type=int, default=256,
                        help='batch size')
    parser.add_argument('--epochs', required=False, type=int, default=50,
                        help='number of epochs')
    parser.add_argument('--lr', required=False, type=float, default=0.001,
                        help='learning rate')
    parser.add_argument('--drop_prob', required=False, type=float, default=0.,
                        help='dropblock dropout probability')
    parser.add_argument('--block_size', required=False, type=int, default=5,
                        help='dropblock block size')
    parser.add_argument('--device', required=False, default=None, type=int,
                        help='CUDA device id for GPU training')
    options = parser.parse_args()

    root = options.root
    bsize = options.bsize
    workers = options.workers
    epochs = options.epochs
    lr = options.lr
    drop_prob = options.drop_prob
    block_size = options.block_size
    device = 'cpu' if options.device is None \
        else torch.device('cuda:{}'.format(options.device))

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_set = torchvision.datasets.CIFAR10(root=root, train=True,
                                             download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=bsize,
                                               shuffle=True, num_workers=workers)

    test_set = torchvision.datasets.CIFAR10(root=root, train=False,
                                            download=True, transform=transform)

    test_loader = torch.utils.data.DataLoader(test_set, batch_size=bsize,
                                              shuffle=False, num_workers=workers)

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # define model
    model = resnet9(num_classes=len(classes), drop_prob=drop_prob, block_size=block_size)

    # define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # create ignite engines
    trainer = create_supervised_trainer(model=model,
                                        optimizer=optimizer,
                                        loss_fn=criterion,
                                        device=device)

    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy()},
                                            device=device)

    # ignite handlers
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    pbar = ProgressBar()
    pbar.attach(trainer, ['loss'])

    trainer.add_event_handler(Events.EPOCH_COMPLETED, logger, model, evaluator, test_loader, pbar)

    # start training
    t0 = time.time()
    trainer.run(train_loader, max_epochs=epochs)
    t1 = time.time()
    print('Best Accuracy:', max(results))
    print('Total time:', t1 - t0)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
torch>=0.4.1
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages


with open('requirements.txt', encoding='utf-8') as f:
    required = f.read().splitlines()

with open('README.md', encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='dropblock',
    version='0.3.0',
    packages=find_packages(),
    long_description=long_description,
    long_description_content_type='text/markdown',
    install_requires=required,
    url='https://github.com/miguelvr/dropblock',
    license='MIT',
    author='Miguel Varela Ramos',
    author_email='miguelvramos92@gmail.com',
    description='Implementation of DropBlock: A regularization method for convolutional networks in PyTorch.'
)
--------------------------------------------------------------------------------
/tests/test_dropblock2d.py:
--------------------------------------------------------------------------------
from unittest import mock

import pytest
import torch

from dropblock import DropBlock2D


# noinspection PyCallingNonCallable
def test_block_mask_square_even():
    db = DropBlock2D(block_size=2, drop_prob=0.1)
    mask = torch.tensor([[[1., 0., 0., 0., 0.],
                          [0., 0., 0., 1., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.]]])

    expected = torch.tensor([[[0., 0., 1., 1., 1.],
                              [0., 0., 1., 0., 0.],
                              [1., 1., 1., 0., 0.],
                              [1., 1., 1., 1., 1.],
                              [1., 1., 1., 1., 1.]]])

    block_mask = db._compute_block_mask(mask)
    assert torch.equal(block_mask, expected)


# noinspection PyCallingNonCallable
def test_block_mask_Hw_even():
    db = DropBlock2D(block_size=2, drop_prob=0.1)
    mask = torch.tensor([[[1., 0., 0., 0.],
                          [0., 0., 0., 1.],
                          [0., 0., 0., 0.],
                          [0., 0., 0., 0.],
                          [0., 0., 0., 0.]]])

    expected = torch.tensor([[[0., 0., 1., 1.],
                              [0., 0., 1., 0.],
                              [1., 1., 1., 0.],
                              [1., 1., 1., 1.],
                              [1., 1., 1., 1.]]])

    block_mask = db._compute_block_mask(mask)
    assert torch.equal(block_mask, expected)


# noinspection PyCallingNonCallable
def test_block_mask_hW_even():
    db = DropBlock2D(block_size=2, drop_prob=0.1)
    mask = torch.tensor([[[0., 0., 0., 1., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.]]])

    expected = torch.tensor([[[1., 1., 1., 0., 0.],
                              [1., 1., 1., 0., 0.],
                              [1., 1., 1., 1., 1.],
                              [1., 1., 1., 1., 1.]]])

    block_mask = db._compute_block_mask(mask)
    assert torch.equal(block_mask, expected)


# noinspection PyCallingNonCallable
def test_block_mask_square_odd():
    db = DropBlock2D(block_size=3, drop_prob=0.1)
    mask = torch.tensor([[[1., 0., 0., 0., 0.],
                          [0., 0., 0., 1., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.]]])

    expected = torch.tensor([[[0., 0., 0., 0., 0.],
                              [0., 0., 0., 0., 0.],
                              [1., 1., 0., 0., 0.],
                              [1., 1., 1., 1., 1.],
                              [1., 1., 1., 1., 1.]]])

    block_mask = db._compute_block_mask(mask)
    assert torch.equal(block_mask, expected)


# noinspection PyCallingNonCallable
def test_block_mask_Hw_odd():
    db = DropBlock2D(block_size=3, drop_prob=0.1)
    mask = torch.tensor([[[1., 0., 0., 0.],
                          [0., 0., 0., 1.],
                          [0., 0., 0., 0.],
                          [0., 0., 0., 0.],
                          [0., 0., 0., 0.]]])

    expected = torch.tensor([[[0., 0., 0., 0.],
                              [0., 0., 0., 0.],
                              [1., 1., 0., 0.],
                              [1., 1., 1., 1.],
                              [1., 1., 1., 1.]]])

    block_mask = db._compute_block_mask(mask)
    assert torch.equal(block_mask, expected)


# noinspection PyCallingNonCallable
def test_block_mask_hW_odd():
    db = DropBlock2D(block_size=3, drop_prob=0.1)
    mask = torch.tensor([[[0., 0., 0., 1., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.]]])

    expected = torch.tensor([[[1., 1., 0., 0., 0.],
                              [1., 1., 0., 0., 0.],
                              [1., 1., 1., 1., 1.],
                              [1., 1., 1., 1., 1.]]])

    block_mask = db._compute_block_mask(mask)
    assert torch.equal(block_mask, expected)


# noinspection PyCallingNonCallable
def test_block_mask_overlap():
    db = DropBlock2D(block_size=2, drop_prob=0.1)
    mask = torch.tensor([[[1., 0., 0., 0., 0.],
                          [0., 1., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.]]])

    expected = torch.tensor([[[0., 0., 1., 1., 1.],
                              [0., 0., 0., 1., 1.],
                              [1., 0., 0., 1., 1.],
                              [1., 1., 1., 1., 1.],
                              [1., 1., 1., 1., 1.]]])

    block_mask = db._compute_block_mask(mask)
    assert torch.equal(block_mask, expected)


# noinspection PyCallingNonCallable
def test_forward_pass():
    db = DropBlock2D(block_size=3, drop_prob=0.1)
    block_mask = torch.tensor([[[0., 0., 0., 1., 1., 1., 1.],
                                [0., 0., 0., 0., 0., 0., 1.],
                                [0., 0., 0., 0., 0., 0., 1.],
                                [1., 1., 1., 0., 0., 0., 1.],
                                [1., 1., 1., 1., 1., 1., 1.],
                                [1., 1., 1., 1., 1., 1., 1.],
                                [1., 1., 1., 1., 1., 1., 1.]]])

    db._compute_block_mask = mock.MagicMock(return_value=block_mask)

    x = torch.ones(10, 10, 7, 7)
    h = db(x)

    expected = block_mask * block_mask.numel() / block_mask.sum()
    expected = expected[:, None, :, :].expand_as(x)

    assert tuple(h.shape) == (10, 10, 7, 7)
    assert torch.equal(h, expected)


def test_forward_pass2():

    block_sizes = [2, 3, 4, 5, 6, 7, 8]
    heights = [5, 6, 8, 10, 11, 14, 15]
    widths = [5, 7, 8, 10, 15, 14, 15]

    for block_size, height, width in zip(block_sizes, heights, widths):
        dropout = DropBlock2D(0.1, block_size=block_size)
        input = torch.randn((5, 20, height, width))
        output = dropout(input)
        assert tuple(input.shape) == tuple(output.shape)


def test_large_block_size():
    dropout = DropBlock2D(0.3, block_size=9)
    x = torch.rand(100, 10, 16, 16)
    output = dropout(x)

    assert tuple(x.shape) == tuple(output.shape)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_forward_pass_with_cuda():
    dropout = DropBlock2D(0.3, block_size=5).to('cuda')
    x = torch.rand(100, 10, 16, 16).to('cuda')
    output = dropout(x)

    assert tuple(x.shape) == tuple(output.shape)
--------------------------------------------------------------------------------
/tests/test_dropblock3d.py:
--------------------------------------------------------------------------------
import torch
from dropblock import DropBlock3D
from unittest import mock
import pytest
| [0., 0., 0., 0., 0.], 23 | [0., 0., 0., 0., 0.], 24 | [0., 0., 0., 0., 0.]], 25 | [[0., 0., 0., 0., 0.], 26 | [0., 0., 0., 0., 0.], 27 | [0., 0., 0., 0., 0.], 28 | [0., 0., 0., 0., 0.], 29 | [0., 0., 0., 0., 0.]], 30 | [[0., 0., 0., 0., 0.], 31 | [0., 0., 0., 0., 0.], 32 | [0., 0., 0., 0., 0.], 33 | [0., 0., 0., 0., 0.], 34 | [0., 0., 0., 0., 0.]]]]) 35 | 36 | expected = torch.tensor([[[[1., 1., 1., 1., 1.], 37 | [1., 1., 1., 1., 1.], 38 | [1., 1., 1., 1., 1.], 39 | [1., 1., 1., 1., 1.], 40 | [1., 1., 1., 1., 1.]], 41 | [[0., 0., 1., 1., 1.], 42 | [0., 0., 1., 0., 0.], 43 | [1., 1., 1., 0., 0.], 44 | [1., 1., 1., 1., 1.], 45 | [1., 1., 1., 1., 1.]], 46 | [[0., 0., 1., 1., 1.], 47 | [0., 0., 1., 0., 0.], 48 | [1., 1., 1., 0., 0.], 49 | [1., 1., 1., 1., 1.], 50 | [1., 1., 1., 1., 1.]], 51 | [[1., 1., 1., 1., 1.], 52 | [1., 1., 1., 1., 1.], 53 | [1., 1., 1., 1., 1.], 54 | [1., 1., 1., 1., 1.], 55 | [1., 1., 1., 1., 1.]], 56 | [[1., 1., 1., 1., 1.], 57 | [1., 1., 1., 1., 1.], 58 | [1., 1., 1., 1., 1.], 59 | [1., 1., 1., 1., 1.], 60 | [1., 1., 1., 1., 1.]]]]) 61 | 62 | block_mask = db._compute_block_mask(mask) 63 | assert torch.equal(block_mask, expected) 64 | 65 | 66 | # noinspection PyCallingNonCallable 67 | def test_block_mask_cube_odd(): 68 | db = DropBlock3D(block_size=3, drop_prob=0.1) 69 | mask = torch.tensor([[[[0., 0., 0., 0., 0.], 70 | [0., 0., 0., 0., 0.], 71 | [0., 0., 0., 0., 0.], 72 | [0., 0., 0., 0., 0.], 73 | [0., 0., 0., 0., 0.]], 74 | [[1., 0., 0., 0., 0.], 75 | [0., 0., 0., 1., 0.], 76 | [0., 0., 0., 0., 0.], 77 | [0., 0., 0., 0., 0.], 78 | [0., 0., 0., 0., 0.]], 79 | [[0., 0., 0., 0., 0.], 80 | [0., 0., 0., 0., 0.], 81 | [0., 0., 0., 0., 0.], 82 | [0., 0., 0., 0., 0.], 83 | [0., 0., 0., 0., 0.]], 84 | [[0., 0., 0., 0., 0.], 85 | [0., 0., 0., 0., 0.], 86 | [0., 0., 0., 0., 0.], 87 | [0., 0., 0., 0., 0.], 88 | [0., 0., 0., 0., 0.]], 89 | [[0., 0., 0., 0., 0.], 90 | [0., 0., 0., 0., 0.], 91 | [0., 0., 0., 0., 0.], 92 | [0., 0., 0., 0., 0.], 93 | [0., 0., 0., 0., 0.]]]]) 94 | 95 | expected = torch.tensor([[[[0., 0., 0., 0., 0.], 96 | [0., 0., 0., 0., 0.], 97 | [1., 1., 0., 0., 0.], 98 | [1., 1., 1., 1., 1.], 99 | [1., 1., 1., 1., 1.]], 100 | [[0., 0., 0., 0., 0.], 101 | [0., 0., 0., 0., 0.], 102 | [1., 1., 0., 0., 0.], 103 | [1., 1., 1., 1., 1.], 104 | [1., 1., 1., 1., 1.]], 105 | [[0., 0., 0., 0., 0.], 106 | [0., 0., 0., 0., 0.], 107 | [1., 1., 0., 0., 0.], 108 | [1., 1., 1., 1., 1.], 109 | [1., 1., 1., 1., 1.]], 110 | [[1., 1., 1., 1., 1.], 111 | [1., 1., 1., 1., 1.], 112 | [1., 1., 1., 1., 1.], 113 | [1., 1., 1., 1., 1.], 114 | [1., 1., 1., 1., 1.]], 115 | [[1., 1., 1., 1., 1.], 116 | [1., 1., 1., 1., 1.], 117 | [1., 1., 1., 1., 1.], 118 | [1., 1., 1., 1., 1.], 119 | [1., 1., 1., 1., 1.]]]]) 120 | 121 | block_mask = db._compute_block_mask(mask) 122 | assert torch.equal(block_mask, expected) 123 | 124 | 125 | # noinspection PyCallingNonCallable 126 | def test_forward_pass(): 127 | db = DropBlock3D(block_size=3, drop_prob=0.1) 128 | block_mask = torch.tensor([[[[1., 1., 1., 1., 1., 1., 1.], 129 | [1., 1., 1., 1., 1., 1., 1.], 130 | [1., 1., 1., 1., 1., 1., 1.], 131 | [1., 1., 1., 1., 1., 1., 1.], 132 | [1., 1., 1., 1., 1., 1., 1.], 133 | [1., 1., 1., 1., 1., 1., 1.], 134 | [1., 1., 1., 1., 1., 1., 1.]], 135 | [[0., 0., 0., 1., 1., 1., 1.], 136 | [0., 0., 0., 0., 0., 0., 1.], 137 | [0., 0., 0., 0., 0., 0., 1.], 138 | [1., 1., 1., 0., 0., 0., 1.], 139 | [1., 1., 1., 1., 1., 1., 1.], 140 | [1., 1., 1., 1., 1., 1., 1.], 141 | [1., 1., 1., 1., 1., 1., 1.]], 142 | [[0., 0., 0., 1., 


# noinspection PyCallingNonCallable
def test_forward_pass():
    db = DropBlock3D(block_size=3, drop_prob=0.1)
    block_mask = torch.tensor([[[[1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.]],
                                [[0., 0., 0., 1., 1., 1., 1.],
                                 [0., 0., 0., 0., 0., 0., 1.],
                                 [0., 0., 0., 0., 0., 0., 1.],
                                 [1., 1., 1., 0., 0., 0., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.]],
                                [[0., 0., 0., 1., 1., 1., 1.],
                                 [0., 0., 0., 0., 0., 0., 1.],
                                 [0., 0., 0., 0., 0., 0., 1.],
                                 [1., 1., 1., 0., 0., 0., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.]],
                                [[0., 0., 0., 1., 1., 1., 1.],
                                 [0., 0., 0., 0., 0., 0., 1.],
                                 [0., 0., 0., 0., 0., 0., 1.],
                                 [1., 1., 1., 0., 0., 0., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.]],
                                [[1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.]],
                                [[1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.]],
                                [[1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1., 1., 1.]]]])

    db._compute_block_mask = mock.MagicMock(return_value=block_mask)

    x = torch.ones(10, 10, 7, 7, 7)
    h = db(x)

    expected = block_mask * block_mask.numel() / block_mask.sum()
    expected = expected[:, None, :, :, :].expand_as(x)

    assert tuple(h.shape) == (10, 10, 7, 7, 7)
    assert torch.equal(h, expected)


def test_forward_pass2():
    block_sizes = [2, 3, 4, 5, 6, 7, 8]
    depths = [5, 6, 8, 10, 11, 14, 15]
    heights = [5, 6, 8, 10, 11, 14, 15]
    widths = [5, 7, 8, 10, 15, 14, 15]

    for block_size, depth, height, width in zip(block_sizes, depths, heights, widths):
        dropout = DropBlock3D(0.2, block_size=block_size)
        input = torch.randn((5, 20, depth, height, width))
        output = dropout(input)

        assert tuple(input.shape) == tuple(output.shape)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_forward_pass_with_cuda():
    dropout = DropBlock3D(0.2, block_size=5).to('cuda')
    input = torch.randn((5, 20, 16, 16, 16)).to('cuda')
    output = dropout(input)

    assert tuple(input.shape) == tuple(output.shape)
--------------------------------------------------------------------------------