├── .gitignore ├── LICENSE ├── README.md ├── cond_mobilenetv2.py ├── condconv.py └── fig └── condconv_layer.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Duo Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # condconv.pytorch 2 | 3 | PyTorch implementation of Conditional Convolution in [CondConv: Conditionally Parameterized Convolutions for Efficient Inference](https://arxiv.org/abs/1904.04971). 4 | 5 |

6 | 7 | * The CondConv layer and a CondConv-equipped MobileNetV2 are supported (releasing pre-trained models is not currently planned). 8 | * Dynamic batch inference is supported (implemented via grouped convolution). 9 | 10 | # Citation 11 | 12 | ```bibtex 13 | @incollection{NIPS2019_8412, 14 | title = {CondConv: Conditionally Parameterized Convolutions for Efficient Inference}, 15 | author = {Yang, Brandon and Bender, Gabriel and Le, Quoc V and Ngiam, Jiquan}, 16 | booktitle = {Advances in Neural Information Processing Systems 32}, 17 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett}, 18 | pages = {1307--1318}, 19 | year = {2019}, 20 | publisher = {Curran Associates, Inc.}, 21 | url = {http://papers.nips.cc/paper/8412-condconv-conditionally-parameterized-convolutions-for-efficient-inference.pdf} 22 | } 23 | ``` 24 | 25 | Note that there are similar works, such as [DY-CNN](https://openaccess.thecvf.com/content_CVPR_2020/html/Chen_Dynamic_Convolution_Attention_Over_Convolution_Kernels_CVPR_2020_paper.html) (accepted at CVPR'20) by Microsoft and [DyNet](https://openreview.net/forum?id=SyeZIkrKwS) (rejected by ICLR'20) by Huawei: 26 | ```bibtex 27 | @InProceedings{Chen_2020_CVPR, 28 | author = {Chen, Yinpeng and Dai, Xiyang and Liu, Mengchen and Chen, Dongdong and Yuan, Lu and Liu, Zicheng}, 29 | title = {Dynamic Convolution: Attention Over Convolution Kernels}, 30 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 31 | month = {June}, 32 | year = {2020} 33 | } 34 | ``` 35 | ```bibtex 36 | @misc{ 37 | zhang2020dynet, 38 | title={DyNet: Dynamic Convolution for Accelerating Convolution Neural Networks}, 39 | author={Kane Zhang and Jian Zhang and Qiang Wang and Zhao Zhong}, 40 | year={2020}, 41 | url={https://openreview.net/forum?id=SyeZIkrKwS} 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /cond_mobilenetv2.py:
-------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import math 4 | import functools 5 | from condconv import CondConv2d, route_func 6 | 7 | __all__ = ['cond_mobilenetv2'] 8 | 9 | 10 | def _make_divisible(v, divisor, min_value=None): 11 | """ 12 | This function is taken from the original tf repo. 13 | It ensures that all layers have a channel number that is divisible by 8 14 | It can be seen here: 15 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 16 | :param v: 17 | :param divisor: 18 | :param min_value: 19 | :return: 20 | """ 21 | if min_value is None: 22 | min_value = divisor 23 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 24 | # Make sure that round down does not go down by more than 10%. 25 | if new_v < 0.9 * v: 26 | new_v += divisor 27 | return new_v 28 | 29 | 30 | def conv_3x3_bn(inp, oup, stride): 31 | return nn.Sequential( 32 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 33 | nn.BatchNorm2d(oup), 34 | nn.ReLU6(inplace=True) 35 | ) 36 | 37 | 38 | def conv_1x1_bn(inp, oup): 39 | return nn.Sequential( 40 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 41 | nn.BatchNorm2d(oup), 42 | nn.ReLU6(inplace=True) 43 | ) 44 | 45 | 46 | class InvertedResidual(nn.Module): 47 | def __init__(self, inp, oup, stride, expand_ratio, num_experts=None): 48 | super(InvertedResidual, self).__init__() 49 | assert stride in [1, 2] 50 | 51 | hidden_dim = round(inp * expand_ratio) 52 | self.identity = stride == 1 and inp == oup 53 | self.expand_ratio = expand_ratio 54 | self.cond = num_experts is not None 55 | Conv2d = functools.partial(CondConv2d, num_experts=num_experts) if num_experts else nn.Conv2d 56 | 57 | if expand_ratio != 1: 58 | self.pw = Conv2d(inp, hidden_dim, 1, 1, 0, bias=False) 59 | self.bn_pw = nn.BatchNorm2d(hidden_dim) 60 | self.dw = Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False) 61 | self.bn_dw = 
nn.BatchNorm2d(hidden_dim) 62 | self.pw_linear = Conv2d(hidden_dim, oup, 1, 1, 0, bias=False) 63 | self.bn_pw_linear = nn.BatchNorm2d(oup) 64 | self.relu = nn.ReLU6(inplace=True) 65 | 66 | if num_experts: 67 | self.route = route_func(inp, num_experts) 68 | 69 | def forward(self, x): 70 | identity = x 71 | if self.cond: 72 | routing_weight = self.route(x) 73 | if self.expand_ratio != 1: 74 | x = self.relu(self.bn_pw(self.pw(x, routing_weight))) 75 | x = self.relu(self.bn_dw(self.dw(x, routing_weight))) 76 | x = self.bn_pw_linear(self.pw_linear(x, routing_weight)) 77 | else: 78 | if self.expand_ratio != 1: 79 | x = self.relu(self.bn_pw(self.pw(x))) 80 | x = self.relu(self.bn_dw(self.dw(x))) 81 | x = self.bn_pw_linear(self.pw_linear(x)) 82 | 83 | if self.identity: 84 | return x + identity 85 | else: 86 | return x 87 | 88 | 89 | class CondMobileNetV2(nn.Module): 90 | def __init__(self, num_classes=1000, width_mult=1., num_experts=8): 91 | super(CondMobileNetV2, self).__init__() 92 | # setting of inverted residual blocks 93 | self.cfgs = [ 94 | # t, c, n, s 95 | [1, 16, 1, 1], 96 | [6, 24, 2, 2], 97 | [6, 32, 3, 2], 98 | [6, 64, 4, 2], 99 | [6, 96, 3, 1], 100 | [6, 160, 3, 2], 101 | [6, 320, 1, 1], 102 | ] 103 | 104 | # building first layer 105 | input_channel = _make_divisible(32 * width_mult, 8) 106 | layers = [conv_3x3_bn(3, input_channel, 2)] 107 | # building inverted residual blocks 108 | block = InvertedResidual 109 | self.num_experts = None 110 | for j, (t, c, n, s) in enumerate(self.cfgs): 111 | output_channel = _make_divisible(c * width_mult, 8) 112 | for i in range(n): 113 | layers.append(block(input_channel, output_channel, s if i == 0 else 1, t, self.num_experts)) 114 | input_channel = output_channel 115 | if j == 4 and i == 0: # CondConv layers in the final 6 inverted residual blocks 116 | self.num_experts = num_experts 117 | self.features = nn.Sequential(*layers) 118 | # building last several layers 119 | output_channel = _make_divisible(1280 * width_mult, 
8) if width_mult > 1.0 else 1280 120 | self.conv = conv_1x1_bn(input_channel, output_channel) 121 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 122 | self.classifier_route = route_func(output_channel, num_experts) 123 | self.classifier = CondConv2d(output_channel, num_classes, kernel_size=1, bias=False, num_experts=num_experts) 124 | 125 | self._initialize_weights() 126 | 127 | def forward(self, x): 128 | x = self.features(x) 129 | x = self.conv(x) 130 | x = self.avgpool(x) 131 | routing_weight = self.classifier_route(x) 132 | x = self.classifier(x, routing_weight) 133 | x = x.squeeze_() 134 | return x 135 | 136 | def _initialize_weights(self): 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 140 | m.weight.data.normal_(0, math.sqrt(2. / n)) 141 | if m.bias is not None: 142 | m.bias.data.zero_() 143 | elif isinstance(m, nn.BatchNorm2d): 144 | m.weight.data.fill_(1) 145 | m.bias.data.zero_() 146 | elif isinstance(m, nn.Linear): 147 | m.weight.data.normal_(0, 0.01) 148 | m.bias.data.zero_() 149 | 150 | def cond_mobilenetv2(**kwargs): 151 | """ 152 | Constructs a CondConv-based MobileNet V2 model 153 | """ 154 | return CondMobileNetV2(**kwargs) 155 | 156 | -------------------------------------------------------------------------------- /condconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class route_func(nn.Module): 8 | r"""CondConv: Conditionally Parameterized Convolutions for Efficient Inference 9 | https://papers.nips.cc/paper/8412-condconv-conditionally-parameterized-convolutions-for-efficient-inference.pdf 10 | 11 | Args: 12 | c_in (int): Number of channels in the input image 13 | num_experts (int): Number of experts for mixture. 
Default: 1 14 | """ 15 | 16 | def __init__(self, c_in, num_experts): 17 | super(route_func, self).__init__() 18 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=1) 19 | self.fc = nn.Linear(c_in, num_experts) 20 | self.sigmoid = nn.Sigmoid() 21 | 22 | def forward(self, x): 23 | x = self.avgpool(x) 24 | x = x.view(x.size(0), -1) 25 | x = self.fc(x) 26 | x = self.sigmoid(x) 27 | return x 28 | 29 | 30 | class CondConv2d(nn.Module): 31 | r"""CondConv: Conditionally Parameterized Convolutions for Efficient Inference 32 | https://papers.nips.cc/paper/8412-condconv-conditionally-parameterized-convolutions-for-efficient-inference.pdf 33 | 34 | Args: 35 | in_channels (int): Number of channels in the input image 36 | out_channels (int): Number of channels produced by the convolution 37 | kernel_size (int or tuple): Size of the convolving kernel 38 | stride (int or tuple, optional): Stride of the convolution. Default: 1 39 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 40 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 41 | groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 42 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` 43 | num_experts (int): Number of experts for mixture. 
Default: 1 44 | 45 | """ 46 | 47 | def __init__(self, in_channels, out_channels, kernel_size, 48 | stride=1, padding=0, dilation=1, groups=1, bias=True, 49 | num_experts=1): 50 | super(CondConv2d, self).__init__() 51 | 52 | self.in_channels = in_channels 53 | self.out_channels = out_channels 54 | self.kernel_size = kernel_size 55 | self.stride = stride 56 | self.padding = padding 57 | self.dilation = dilation 58 | self.groups = groups 59 | self.num_experts = num_experts 60 | 61 | self.weight = nn.Parameter( 62 | torch.Tensor(num_experts, out_channels, in_channels // groups, kernel_size, kernel_size)) 63 | if bias: 64 | self.bias = nn.Parameter(torch.Tensor(num_experts, out_channels)) 65 | else: 66 | self.register_parameter('bias', None) 67 | 68 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 69 | if self.bias is not None: 70 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 71 | bound = 1 / math.sqrt(fan_in) 72 | nn.init.uniform_(self.bias, -bound, bound) 73 | 74 | def forward(self, x, routing_weight): 75 | b, c_in, h, w = x.size() 76 | k, c_out, c_in, kh, kw = self.weight.size() 77 | x = x.view(1, -1, h, w) 78 | weight = self.weight.view(k, -1) 79 | combined_weight = torch.mm(routing_weight, weight).view(-1, c_in, kh, kw) 80 | if self.bias is not None: 81 | combined_bias = torch.mm(routing_weight, self.bias).view(-1) 82 | output = F.conv2d( 83 | x, weight=combined_weight, bias=combined_bias, stride=self.stride, padding=self.padding, 84 | dilation=self.dilation, groups=self.groups * b) 85 | else: 86 | output = F.conv2d( 87 | x, weight=combined_weight, bias=None, stride=self.stride, padding=self.padding, 88 | dilation=self.dilation, groups=self.groups * b) 89 | 90 | output = output.view(b, c_out, output.size(-2), output.size(-1)) 91 | return output 92 | -------------------------------------------------------------------------------- /fig/condconv_layer.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/d-li14/condconv.pytorch/5b93eb4796e38bd911fcd2278762aeaabdfec21b/fig/condconv_layer.png --------------------------------------------------------------------------------