├── .gitignore ├── README.md └── image-classification └── models ├── __init__.py ├── imagenet ├── __init__.py └── resnet.py └── layers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # tmp dirs and files 2 | checkpoint 3 | checkpoints 4 | data 5 | cifar-debug.py 6 | test.eps 7 | dev 8 | monitor.py 9 | exp 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | env/ 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *,cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # IPython Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | venv/ 93 | ENV/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Context-Gated Convolution [[arXiv]](https://arxiv.org/abs/1910.05577) 2 | ## To appear in ECCV 2020 3 | ## A sample implementation of CGC+ ResNet 50 is now provided! 4 | 5 | ## Dependencies 6 | 7 | The code is built with the following libraries: 8 | 9 | - [PyTorch](https://pytorch.org/) 1.0 or higher 10 | 11 | 12 | ## The training and testing scripts are coming soon! 13 | 14 | ``` 15 | @misc{lin2019contextgated, 16 | title={Context-Gated Convolution}, 17 | author={Xudong Lin and Lin Ma and Wei Liu and Shih-Fu Chang}, 18 | year={2019}, 19 | eprint={1910.05577}, 20 | archivePrefix={arXiv}, 21 | primaryClass={cs.CV} 22 | } 23 | ``` 24 | 25 | -------------------------------------------------------------------------------- /image-classification/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XudongLinthu/context-gated-convolution/920557efd10f29172275d3687040124440043fa0/image-classification/models/__init__.py -------------------------------------------------------------------------------- /image-classification/models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .resnet import * -------------------------------------------------------------------------------- /image-classification/models/imagenet/resnet.py: -------------------------------------------------------------------------------- 1 | # Code for "Context-Gated Convolution" 2 | # ECCV 2020 3 | # Xudong Lin*, Lin Ma, Wei Liu, Shih-Fu Chang 4 | # {xudong.lin, shih.fu.chang}@columbia.edu, forest.linma@gmail.com, wl2223@columbia.edu 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .. 
import layers as L  # second half of the relative import whose ``from ..`` ends the previous line


__all__ = ['cgc_resnet50']


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 context-gated convolution with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the convolution.

    Returns:
        An ``L.Conv2d`` layer.
    """
    return L.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                    padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 convolution.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the convolution.

    Returns:
        An ``L.Conv2d`` layer with ``kernel_size=1``.
    """
    return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3) with an identity shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input before it is added
            to the residual branch; ``None`` adds the input unchanged.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when a projection module was supplied.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1) with a shortcut.

    Args:
        inplanes: number of input channels.
        planes: width of the 1x1/3x3 convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input before it is added
            to the residual branch; ``None`` adds the input unchanged.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when a projection module was supplied.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet whose stem and residual convolutions are ``L.Conv2d`` layers.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``). It
            must expose an ``expansion`` class attribute and accept
            ``(inplanes, planes, stride, downsample)``.
        layers: sequence of four ints, the number of blocks in each stage.
        num_classes: output size of the final linear layer.
        zero_init_residual: if True, zero the weight of the last BatchNorm of
            every residual block so each block starts as an identity.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                              bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # All four stages use the caller-supplied ``block``. The first two
        # used to be hard-coded to Bottleneck, which built a mixed network
        # whenever another block type was requested.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks into one stage.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1-conv + BatchNorm projection shortcut.
        Updates ``self.inplanes`` to the stage's output channel count.

        Returns:
            An ``nn.Sequential`` holding the stage.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, my=False, ft=False):
        """Run stem, the four stages, global pooling and the classifier.

        Args:
            x: input batch; the stem convolution expects 3 channels.
            my: unused; kept so existing call sites keep working.
            ft: unused; kept so existing call sites keep working.

        Returns:
            The output of the final linear layer, one row per sample.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the pooled features to (batch, features) for the classifier.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def cgc_resnet50(pretrained=False, **kwargs):
    """Constructs a CGC-ResNet-50 model.

    Args:
        pretrained (bool): accepted for API compatibility only. This function
            loads no weights; passing True emits a warning and still returns
            a freshly initialized model.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``,
            ``zero_init_residual``).

    Returns:
        A ``ResNet`` built from ``Bottleneck`` blocks with stage depths
        ``[3, 4, 6, 3]``.
    """
    if pretrained:
        # Local import keeps the module's top-level import list untouched.
        import warnings
        warnings.warn("cgc_resnet50: pretrained weights are not available; "
                      "returning a randomly initialized model.", stacklevel=2)
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model
# --------------------------------------------------------------------------------
# /image-classification/models/layers.py:
# --------------------------------------------------------------------------------
# Code for "Context-Gated Convolution"
# ECCV 2020
# Xudong Lin*, Lin Ma, Wei Liu, Shih-Fu Chang
# {xudong.lin, shih.fu.chang}@columbia.edu, forest.linma@gmail.com, wl2223@columbia.edu

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import numpy as np


class Conv2d(nn.Conv2d):
    """Context-gated 2-D convolution.

    With ``kernel_size == 1`` the layer is a plain ``F.conv2d``. Otherwise the
    kernel is modulated per sample by a sigmoid gate of shape
    ``(batch, out_channels, in_channels, k, k)`` computed from a pooled
    summary of the input, and the gated kernel is applied to the unfolded
    input patches with a batched matrix product.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size; the gating code multiplies and compares it
            as a number, so it must be a single int.
        stride, padding, dilation, groups, bias: forwarded to ``nn.Conv2d``.

    Notes:
        * The gated path does not add ``self.bias``; only the 1x1 path does.
        * NOTE(review): the gated path multiplies the gate with ``self.weight``
          directly, which looks like it assumes ``groups == 1`` -- confirm.
        * NOTE(review): when ``in_channels >= 16`` the grouped channel
          interaction reshapes channels into groups of 16, so ``in_channels``
          presumably has to be a multiple of 16 -- confirm against callers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)
        # for convolutional layers with a kernel size of 1, just use traditional convolution
        self.ind = kernel_size == 1
        self.oc = out_channels
        self.ks = kernel_size

        # the target spatial size of the pooling layer
        ws = kernel_size
        self.avg_pool = nn.AdaptiveAvgPool2d((ws, ws))

        # the dimension of the latent representation
        self.num_lat = int((kernel_size * kernel_size) / 2 + 1)

        # the context encoding module
        self.ce = nn.Linear(ws * ws, self.num_lat, False)
        self.ce_bn = nn.BatchNorm1d(in_channels)
        self.ci_bn2 = nn.BatchNorm1d(in_channels)

        # activation function is relu
        self.act = nn.ReLU(inplace=True)

        # the number of groups in the channel interacting module
        if in_channels // 16:
            self.g = 16
        else:
            self.g = in_channels
        # the channel interacting module
        self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False)
        self.ci_bn = nn.BatchNorm1d(out_channels)

        # the gate decoding module
        # (uses self.num_lat: the bare name ``num_lat`` was never defined and
        # made every construction fail with NameError)
        self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
        self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False)

        # used to prepare the input feature map to patches
        self.unfold = nn.Unfold(kernel_size, dilation, padding, stride)

        # sigmoid function
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Apply the (gated) convolution to ``x``.

        Args:
            x: 4-D input, unpacked as ``(batch, channels, height, width)``.

        Returns:
            A tensor of shape ``(batch, out_channels, out_h, out_w)``.
        """
        # for convolutional layers with a kernel size of 1, just use traditional convolution
        if self.ind:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

        b, c, h, w = x.size()
        weight = self.weight
        # collect global information: pool every channel down to ks x ks
        gl = self.avg_pool(x).view(b, c, -1)
        # context-encoding module
        out = self.ce(gl)
        # use different bn for the following two branches
        ce2 = out
        out = self.ce_bn(out)
        out = self.act(out)
        # gate decoding branch 1 (one k*k map per input channel)
        out = self.gd(out)
        # channel interacting module
        if self.g > 3:
            # grouped linear
            grouped = self.ci_bn2(ce2).view(b, c // self.g, self.g, -1).transpose(2, 3)
            oc = self.ci(self.act(grouped)).transpose(2, 3).contiguous()
        else:
            # linear layer for resnet.conv1
            oc = self.ci(self.act(self.ci_bn2(ce2).transpose(2, 1))).transpose(2, 1).contiguous()
        oc = oc.view(b, self.oc, -1)
        oc = self.ci_bn(oc)
        oc = self.act(oc)
        # gate decoding branch 2 (one k*k map per output channel)
        oc = self.gd2(oc)
        # produce gate: broadcast-add the two branches, then squash to (0, 1)
        out = self.sig(out.view(b, 1, c, self.ks, self.ks)
                       + oc.view(b, self.oc, 1, self.ks, self.ks))
        # unfolding input feature map to patches
        x_un = self.unfold(x)
        # gating
        out = (out * weight.unsqueeze(0)).view(b, self.oc, -1)
        # Output size from the convolution geometry, so rectangular inputs
        # work too (the previous sqrt of the patch count only handled square
        # outputs). nn.Conv2d stores padding/dilation/stride as pairs.
        out_h = (h + 2 * self.padding[0] - self.dilation[0] * (self.ks - 1) - 1) // self.stride[0] + 1
        out_w = (w + 2 * self.padding[1] - self.dilation[1] * (self.ks - 1) - 1) // self.stride[1] + 1
        return torch.matmul(out, x_un).view(b, self.oc, out_h, out_w)