# ├── lambda_layer.py ├── README.md └── lambda_resnet.py
# /lambda_layer.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn


class LambdaLayer(nn.Module):
    """Lambda layer with a content lambda and local (convolutional) position lambdas.

    Queries interact with their context through two linear functions
    ("lambdas"). The first is a content lambda, a single ``(dk, dv)`` matrix
    per example, built from softmax-normalised keys and the values. The second
    is a set of position lambdas, one ``(dk, dv)`` matrix per spatial position,
    produced by a 3-D convolution with a ``(1, r, r)`` kernel over the values.

    Args:
        d: Number of input channels; also the number of output channels.
        dk: Query/key depth per head.
        du: Intra-depth of keys and values.
        Nh: Number of query heads. ``d`` must be divisible by ``Nh``; the
            value depth is ``dv = d // Nh``.
        m: Accepted for interface compatibility; not used by this layer.
        r: Side length of the local position-lambda kernel. Must be a positive
            odd integer so the convolution preserves the spatial size.
        stride: 1 or 2. With 2, the output is downsampled by a 3x3 average
            pool with stride 2 and padding 1.

    Raises:
        AssertionError: if ``d`` is not divisible by ``Nh`` or ``stride`` is
            not 1 or 2.
        ValueError: if ``r`` is not a positive odd integer.
    """

    def __init__(self, d, dk=16, du=1, Nh=4, m=None, r=23, stride=1):
        super().__init__()
        self.d = d
        self.dk = dk
        self.du = du
        self.Nh = Nh
        assert d % Nh == 0, 'd should be divisible by Nh'
        dv = d // Nh
        self.dv = dv
        assert stride in [1, 2]
        self.stride = stride
        # With an even kernel, padding (r - 1) // 2 shrinks H and W by one, and
        # the reshape of the position lambdas in forward() would then fail.
        if r < 1 or r % 2 == 0:
            raise ValueError('r should be a positive odd integer, got {}'.format(r))

        # One 1x1 conv produces queries, keys and values; forward() splits
        # them apart along the channel axis.
        self.conv_qkv = nn.Conv2d(d, Nh * dk + dk * du + dv * du, 1, bias=False)
        self.norm_q = nn.BatchNorm2d(Nh * dk)
        self.norm_v = nn.BatchNorm2d(dv * du)
        self.softmax = nn.Softmax(dim=-1)
        # Input channels are du and the depth axis is dv. The (1, r, r) kernel
        # therefore mixes only spatial neighbours within each value channel.
        self.lambda_conv = nn.Conv3d(du, dk, (1, r, r), padding=(0, (r - 1) // 2, (r - 1) // 2))

        if self.stride > 1:
            self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Apply the lambda layer.

        Args:
            x: Tensor of shape ``(N, d, H, W)``.

        Returns:
            Tensor of shape ``(N, d, H, W)``, or ``(N, d, ceil(H/2), ceil(W/2))``
            when ``stride == 2``.
        """
        N, C, H, W = x.shape
        M = H * W  # number of context positions

        qkv = self.conv_qkv(x)
        q, k, v = torch.split(qkv, [self.Nh * self.dk, self.dk * self.du, self.dv * self.du], dim=1)
        q = self.norm_q(q).view(N, self.Nh, self.dk, M)
        v = self.norm_v(v).view(N, self.du, self.dv, M)
        # Keys are normalised with a softmax over the context positions.
        k = self.softmax(k.view(N, self.du, self.dk, M))

        # Content lambda: one (dk, dv) matrix per example, shared by every query.
        lambda_c = torch.einsum('bukm,buvm->bkv', k, v)
        yc = torch.einsum('bhkm,bkv->bhvm', q, lambda_c)
        # Position lambdas: one (dk, dv) matrix per query position.
        lambda_p = self.lambda_conv(v.view(N, self.du, self.dv, H, W)).view(N, self.dk, self.dv, M)
        yp = torch.einsum('bhkm,bkvm->bhvm', q, lambda_p)
        # Heads are concatenated along channels: Nh * dv == d == C.
        out = (yc + yp).reshape(N, C, H, W)

        if self.stride > 1:
            out = self.avgpool(out)

        return out

# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# 1 | # lambda.pytorch 2 | 3 | **[NEW!]** Check out our latest work
[involution](https://github.com/d-li14/involution) (CVPR'21), which bridges convolution and self-attention operators. 4 | 5 | --- 6 | 7 | PyTorch implementation of [LambdaNetworks: Modeling long-range Interactions without Attention](https://openreview.net/forum?id=xTJEN-ggl1b). 8 | 9 | Lambda Networks apply the associative law of matrix multiplication to reverse the computation order of self-attention, achieving linear computational complexity with respect to content interactions. 10 | 11 | Similar techniques have been used previously in [A2-Net](https://arxiv.org/abs/1810.11579) and [CGNL](https://arxiv.org/abs/1810.13125). Check out a collection of self-attention modules in another repository, [dot-product-attention](https://github.com/d-li14/dot-product-attention). 12 | 13 | ## Training Configuration 14 | ✓ SGD optimizer, initial learning rate 0.1, momentum 0.9, weight decay 0.0001 15 | 16 | ✓ 130 epochs, batch size 256, 8x Tesla V100 GPUs, cosine LR decay 17 | 18 | ✓ label smoothing 0.1 19 | 20 | ## Pre-trained checkpoints 21 | | Architecture | Parameters | FLOPs | Top-1 / Top-5 Acc.
# (%) | Download |
# 22 | | :----------------------: | :--------: | :---: | :------------------------: | :------: |
# 23 | | Lambda-ResNet-50 | 14.995M | 6.576G | 78.208 / 93.820 | [model](https://hkustconnect-my.sharepoint.com/:u:/g/personal/dlibh_connect_ust_hk/EUZkICtpXitIq6PGa6h6m_YBnFXCiCYTSuqoIUqiR33C5A?e=mhgEbC) | [log](https://hkustconnect-my.sharepoint.com/:t:/g/personal/dlibh_connect_ust_hk/EQuZ1itCS2dFpN2MBVepL5YBQe9N-ZUv6y4vNdO5uiVFig?e=dX7Id1) |
# 24 |
# 25 | ## Citation
# 26 | If you find this repository useful in your research, please cite
# 27 | ```bibtex
# 28 | @InProceedings{Li_2021_CVPR,
# 29 | author = {Li, Duo and Hu, Jie and Wang, Changhu and Li, Xiangtai and She, Qi and Zhu, Lei and Zhang, Tong and Chen, Qifeng},
# 30 | title = {Involution: Inverting the Inherence of Convolution for Visual Recognition},
# 31 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
# 32 | month = {June},
# 33 | year = {2021}
# 34 | }
# 35 | ```
# 36 | ```bibtex
# 37 | @inproceedings{
# 38 | bello2021lambdanetworks,
# 39 | title={LambdaNetworks: Modeling long-range Interactions without Attention},
# 40 | author={Irwan Bello},
# 41 | booktitle={International Conference on Learning Representations},
# 42 | year={2021},
# 43 | }
# 44 | ```
# 45 |
# --------------------------------------------------------------------------------
# /lambda_resnet.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from .lambda_layer import LambdaLayer


__all__ = ['lambda_resnet26', 'lambda_resnet38', 'lambda_resnet50', 'lambda_resnet101', 'lambda_resnet152']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding equal to the dilation (no bias)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (no bias)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
    """ResNet bottleneck block whose 3x3 convolution is replaced by a LambdaLayer.

    Residual branch: 1x1 conv -> BN -> ReLU -> LambdaLayer -> BN -> ReLU ->
    1x1 conv -> BN. It is added to the (optionally downsampled) identity and
    passed through a final ReLU.

    Bottleneck in torchvision places the stride for downsampling at the 3x3
    convolution (conv2), while the original implementation places it at the
    first 1x1 convolution (conv1), according to "Deep residual learning for
    image recognition" https://arxiv.org/abs/1512.03385. This variant is also
    known as ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    Here the stride is applied by the LambdaLayer that occupies the conv2 slot.

    Args:
        inplanes: Input channels.
        planes: Base bottleneck width; the block outputs ``planes * expansion``
            channels.
        stride: Spatial stride, passed to the LambdaLayer.
        downsample: Optional module applied to the identity branch so its
            shape matches the residual branch.
        groups: Multiplies the bottleneck width; not forwarded to the LambdaLayer.
        base_width: Base width used to compute the bottleneck width.
        dilation: Kept for a torchvision-compatible signature; not used.
        norm_layer: Normalisation layer factory (default ``nn.BatchNorm2d``).
        size: Forwarded to the LambdaLayer as its ``m`` argument.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, size=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        # LambdaLayer takes the place of the standard 3x3 convolution.
        self.conv2 = LambdaLayer(width, m=size, stride=stride)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + identity)``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class LambdaResNet(nn.Module):
    """ResNet backbone built from lambda-layer bottleneck blocks.

    Args:
        block: Residual block class (``Bottleneck``).
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear classifier.
        zero_init_residual: If True (the default here), zero-initialise the
            last BN weight of every Bottleneck so its residual branch starts
            at zero.
        groups: Forwarded to every block.
        width_per_group: Forwarded to every block as ``base_width``.
        replace_stride_with_dilation: None or three booleans, one per stage 2-4.
            A True entry sets that stage's stride to 1 and multiplies the
            dilation handed to its blocks instead. Because Bottleneck does not
            pass dilation to LambdaLayer, this only removes the downsampling.
        norm_layer: Normalisation layer factory (default ``nn.BatchNorm2d``).

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.

    The per-stage ``size`` values (56, 28, 14, 7) are forwarded to the blocks;
    presumably they are the feature-map sides for a 224x224 input.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], size=56)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0], size=28)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1], size=14)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2], size=7)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Only nn.Conv2d is matched below. The nn.Conv3d inside each
        # LambdaLayer keeps PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, size=None):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and a downsample projection.
        The projection is created when the stride or channel count changes.
        With ``dilate`` the stride is folded into ``self.dilation`` and the
        stage keeps its input resolution.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer, size))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, size=size))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Stem -> four stages -> global average pool -> linear classifier."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def lambda_resnet26(**kwargs):
    """Lambda-ResNet-26: bottleneck blocks per stage [2, 2, 2, 2].

    Based on "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/abs/1512.03385>, with every 3x3 convolution replaced
    by a LambdaLayer. No pretrained weights are loaded.

    Args:
        **kwargs: Forwarded to ``LambdaResNet`` (e.g. ``num_classes``,
            ``zero_init_residual``).
    """
    return LambdaResNet(Bottleneck, [2, 2, 2, 2], **kwargs)


def lambda_resnet38(**kwargs):
    """Lambda-ResNet-38: bottleneck blocks per stage [2, 3, 5, 2].

    Based on "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/abs/1512.03385>, with every 3x3 convolution replaced
    by a LambdaLayer. No pretrained weights are loaded.

    Args:
        **kwargs: Forwarded to ``LambdaResNet`` (e.g. ``num_classes``,
            ``zero_init_residual``).
    """
    return LambdaResNet(Bottleneck, [2, 3, 5, 2], **kwargs)


def lambda_resnet50(**kwargs):
    """Lambda-ResNet-50: bottleneck blocks per stage [3, 4, 6, 3].

    Based on "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/abs/1512.03385>, with every 3x3 convolution replaced
    by a LambdaLayer. No pretrained weights are loaded.

    Args:
        **kwargs: Forwarded to ``LambdaResNet`` (e.g. ``num_classes``,
            ``zero_init_residual``).
    """
    return LambdaResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def lambda_resnet101(**kwargs):
    """Lambda-ResNet-101: bottleneck blocks per stage [3, 4, 23, 3].

    Based on "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/abs/1512.03385>, with every 3x3 convolution replaced
    by a LambdaLayer. No pretrained weights are loaded.

    Args:
        **kwargs: Forwarded to ``LambdaResNet`` (e.g. ``num_classes``,
            ``zero_init_residual``).
    """
    return LambdaResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def lambda_resnet152(**kwargs):
    """Lambda-ResNet-152: bottleneck blocks per stage [3, 8, 36, 3].

    Based on "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/abs/1512.03385>, with every 3x3 convolution replaced
    by a LambdaLayer. No pretrained weights are loaded.

    Args:
        **kwargs: Forwarded to ``LambdaResNet`` (e.g. ``num_classes``,
            ``zero_init_residual``).
    """
    return LambdaResNet(Bottleneck, [3, 8, 36, 3], **kwargs)

# --------------------------------------------------------------------------------