├── lambda_layer.py
├── README.md
└── lambda_resnet.py
/lambda_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
class LambdaLayer(nn.Module):
    """Lambda layer applying content and position lambdas to a 2D feature map.

    One 1x1 convolution produces queries, keys and values. Keys are
    softmax-normalised over all spatial positions and contracted with the
    values into a single content lambda per sample. A Conv3d with an
    ``r x r`` spatial kernel over the values produces a position lambda at
    every location. Both lambdas are applied to the queries and summed.

    Args:
        d: Number of input channels; the output has the same channel count.
        dk: Query/key depth.
        du: Intra-depth dimension shared by keys and values.
        Nh: Number of query heads; must divide ``d`` (value depth is ``d // Nh``).
        m: Accepted for interface compatibility; not used by this implementation.
        r: Side length of the local context window. Must be a positive odd
            integer so that ``(r - 1) // 2`` padding preserves the spatial size.
        stride: 1 or 2. With 2, the output is average-pooled (kernel 3,
            stride 2, padding 1) after the lambda computation.

    Raises:
        AssertionError: If ``d`` is not divisible by ``Nh`` or ``stride`` is
            not 1 or 2.
        ValueError: If ``r`` is not a positive odd integer.
    """

    def __init__(self, d, dk=16, du=1, Nh=4, m=None, r=23, stride=1):
        super(LambdaLayer, self).__init__()
        self.d = d
        self.dk = dk
        self.du = du
        self.Nh = Nh
        assert d % Nh == 0, 'd should be divided by Nh'
        dv = d // Nh
        self.dv = dv
        assert stride in [1, 2]
        self.stride = stride
        # An even r makes the padded convolution output smaller than the
        # input, which would only surface later as a view-size error in
        # forward(); reject it here with a clear message instead.
        if r < 1 or r % 2 == 0:
            raise ValueError(
                'r should be a positive odd integer, got {}'.format(r))

        # Single projection producing queries, keys and values at once.
        self.conv_qkv = nn.Conv2d(d, Nh * dk + dk * du + dv * du, 1, bias=False)
        self.norm_q = nn.BatchNorm2d(Nh * dk)
        self.norm_v = nn.BatchNorm2d(dv * du)
        self.softmax = nn.Softmax(dim=-1)
        # The kernel is 1 along the value-depth axis and r x r spatially.
        self.lambda_conv = nn.Conv3d(du, dk, (1, r, r),
                                     padding=(0, (r - 1) // 2, (r - 1) // 2))

        if self.stride > 1:
            self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Apply the lambda layer to ``x`` of shape ``(N, C, H, W)``.

        Returns a tensor of shape ``(N, C, H, W)``, or its average-pooled
        version when ``stride`` is 2.
        """
        N, C, H, W = x.shape

        qkv = self.conv_qkv(x)
        q, k, v = torch.split(
            qkv,
            [self.Nh * self.dk, self.dk * self.du, self.dv * self.du],
            dim=1,
        )
        q = self.norm_q(q).view(N, self.Nh, self.dk, H * W)
        v = self.norm_v(v).view(N, self.du, self.dv, H * W)
        # Softmax runs over the last axis, i.e. over the H*W positions.
        k = self.softmax(k.view(N, self.du, self.dk, H * W))

        # Content lambda: keys and values contracted over u and all positions.
        lambda_c = torch.einsum('bukm,buvm->bkv', k, v)
        yc = torch.einsum('bhkm,bkv->bhvm', q, lambda_c)
        # Position lambda: one (dk, dv) map per spatial location.
        lambda_p = self.lambda_conv(
            v.view(N, self.du, self.dv, H, W)
        ).view(N, self.dk, self.dv, H * W)
        yp = torch.einsum('bhkm,bkvm->bhvm', q, lambda_p)
        # (N, Nh, dv, H*W) -> (N, C, H, W); valid because C == Nh * dv.
        out = (yc + yp).reshape(N, C, H, W)

        if self.stride > 1:
            out = self.avgpool(out)

        return out
46 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # lambda.pytorch
2 |
3 | **[NEW!]** Check out our latest work [involution](https://github.com/d-li14/involution) in CVPR'21 that bridges convolution and self-attention operators.
4 |
5 | ---
6 |
7 | PyTorch implementation of [LambdaNetworks: Modeling long-range Interactions without Attention](https://openreview.net/forum?id=xTJEN-ggl1b).
8 |
9 | Lambda Networks apply the associative law of matrix multiplication to reverse the computation order of self-attention, achieving linear computational complexity with respect to content interactions.
10 |
11 | Similar techniques have been used previously in [A2-Net](https://arxiv.org/abs/1810.11579) and [CGNL](https://arxiv.org/abs/1810.13125). Check out a collection of self-attention modules in another repository [dot-product-attention](https://github.com/d-li14/dot-product-attention).
12 |
13 | ## Training Configuration
14 | ✓ SGD optimizer, initial learning rate 0.1, momentum 0.9, weight decay 0.0001
15 |
16 | ✓ 130 epochs, batch size 256, 8x Tesla V100 GPUs, cosine LR decay
17 |
18 | ✓ label smoothing 0.1
19 |
20 | ## Pre-trained checkpoints
21 | | Architecture | Parameters | FLOPs | Top-1 / Top-5 Acc. (%) | Download |
22 | | :----------------------: | :--------: | :---: | :------------------------: | :------: |
23 | | Lambda-ResNet-50 | 14.995M | 6.576G | 78.208 / 93.820 | [model](https://hkustconnect-my.sharepoint.com/:u:/g/personal/dlibh_connect_ust_hk/EUZkICtpXitIq6PGa6h6m_YBnFXCiCYTSuqoIUqiR33C5A?e=mhgEbC) / [log](https://hkustconnect-my.sharepoint.com/:t:/g/personal/dlibh_connect_ust_hk/EQuZ1itCS2dFpN2MBVepL5YBQe9N-ZUv6y4vNdO5uiVFig?e=dX7Id1) |
24 |
25 | ## Citation
26 | If you find this repository useful in your research, please cite
27 | ```bibtex
28 | @InProceedings{Li_2021_CVPR,
29 | author = {Li, Duo and Hu, Jie and Wang, Changhu and Li, Xiangtai and She, Qi and Zhu, Lei and Zhang, Tong and Chen, Qifeng},
30 | title = {Involution: Inverting the Inherence of Convolution for Visual Recognition},
31 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
32 | month = {June},
33 | year = {2021}
34 | }
35 | ```
36 | ```bibtex
37 | @inproceedings{
38 | bello2021lambdanetworks,
39 | title={LambdaNetworks: Modeling long-range Interactions without Attention},
40 | author={Irwan Bello},
41 | booktitle={International Conference on Learning Representations},
42 | year={2021},
43 | }
44 | ```
45 |
--------------------------------------------------------------------------------
/lambda_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .lambda_layer import LambdaLayer
4 |
5 |
# Public API of this module: the model factory functions.
__all__ = [
    'lambda_resnet26',
    'lambda_resnet38',
    'lambda_resnet50',
    'lambda_resnet101',
    'lambda_resnet152',
]
7 |
8 |
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
13 |
14 |
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
18 |
19 |
class Bottleneck(nn.Module):
    """Residual bottleneck whose middle layer is a ``LambdaLayer``.

    The residual branch is 1x1 conv -> LambdaLayer -> 1x1 conv, each followed
    by a normalisation layer, with ReLU after the first two. ``stride`` is
    handed to the LambdaLayer, and the optional ``downsample`` module is
    applied to the shortcut before the two branches are added.

    Args:
        inplanes: Channels of the block input.
        planes: Base channel count; the block outputs ``planes * expansion``.
        stride: Stride forwarded to the LambdaLayer.
        downsample: Optional module applied to the input to form the shortcut.
        groups: Multiplier used when computing the inner width.
        base_width: Inner width is ``int(planes * (base_width / 64.)) * groups``.
        dilation: Accepted for signature compatibility; not used in this block.
        norm_layer: Normalisation layer factory; defaults to ``nn.BatchNorm2d``.
        size: Forwarded to the LambdaLayer as its ``m`` argument.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, size=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion

        # Channel reduction.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        # The lambda layer sits where a standard bottleneck has its 3x3
        # convolution and also receives the block's stride.
        self.conv2 = LambdaLayer(width, m=size, stride=stride)
        self.bn2 = make_norm(width)
        # Channel expansion.
        self.conv3 = conv1x1(width, out_channels)
        self.bn3 = make_norm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)
68 |
69 |
class LambdaResNet(nn.Module):
    """ResNet-style classifier built from four stages of ``block`` modules.

    The network is a 7x7 stride-2 stem convolution with norm, ReLU and max
    pooling, then four stages with base widths 64/128/256/512, global average
    pooling and a linear classifier.

    Args:
        block: Block class; must expose ``expansion`` and accept the
            arguments passed in ``_make_layer``.
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear layer.
        zero_init_residual: If True, set ``bn3.weight`` of every ``Bottleneck``
            to zero after the generic initialisation.
        groups: Stored and forwarded to every block.
        width_per_group: Stored as ``base_width`` and forwarded to every block.
        replace_stride_with_dilation: ``None`` or three flags, one for each of
            stages 2-4, forwarded as ``dilate`` to ``_make_layer``.
        norm_layer: Normalisation layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have
            exactly three elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__()
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # One flag per strided stage: keep the stride, do not dilate.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (attribute name, planes, stride, dilate flag, size hint) per stage.
        stage_plan = (
            ('layer1', 64, 1, False, 56),
            ('layer2', 128, 2, replace_stride_with_dilation[0], 28),
            ('layer3', 256, 2, replace_stride_with_dilation[1], 14),
            ('layer4', 512, 2, replace_stride_with_dilation[2], 7),
        )
        for index, (attr, planes, stride, dilate, size) in enumerate(stage_plan):
            stage = self._make_layer(block, planes, layers[index],
                                     stride=stride, dilate=dilate, size=size)
            setattr(self, attr, stage)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Generic initialisation of 2D convolutions and normalisation layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zeroing the last norm scale of each residual branch makes the branch
        # output zero at initialisation.
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, size=None):
        """Assemble one stage of ``blocks`` consecutive ``block`` modules.

        Only the first block receives ``stride`` and the shortcut projection.
        ``self.inplanes`` and, when ``dilate`` is set, ``self.dilation`` are
        updated as a side effect.
        """
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation.
            self.dilation *= stride
            stride = 1

        out_planes = planes * block.expansion
        if stride != 1 or self.inplanes != out_planes:
            # Project the shortcut when resolution or channel count changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )
        else:
            downsample = None

        stage = [block(self.inplanes, planes, stride, downsample, self.groups,
                       self.base_width, previous_dilation, norm_layer, size)]
        self.inplanes = out_planes
        stage.extend(
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer, size=size)
            for _ in range(1, blocks)
        )

        return nn.Sequential(*stage)

    def _forward_impl(self, x):
        """Run stem, the four stages, pooling and the classifier on ``x``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        return self._forward_impl(x)
167 |
168 |
def lambda_resnet26(**kwargs):
    """Construct a Lambda-ResNet-26.

    Instantiates ``LambdaResNet`` with ``Bottleneck`` blocks and per-stage
    block counts ``[2, 2, 2, 2]``.

    Args:
        **kwargs: Extra keyword arguments forwarded verbatim to the
            ``LambdaResNet`` constructor.
    """
    stage_depths = [2, 2, 2, 2]
    return LambdaResNet(Bottleneck, stage_depths, **kwargs)
178 |
def lambda_resnet38(**kwargs):
    """Construct a Lambda-ResNet-38.

    Instantiates ``LambdaResNet`` with ``Bottleneck`` blocks and per-stage
    block counts ``[2, 3, 5, 2]``.

    Args:
        **kwargs: Extra keyword arguments forwarded verbatim to the
            ``LambdaResNet`` constructor.
    """
    stage_depths = [2, 3, 5, 2]
    return LambdaResNet(Bottleneck, stage_depths, **kwargs)
188 |
def lambda_resnet50(**kwargs):
    """Construct a Lambda-ResNet-50.

    Instantiates ``LambdaResNet`` with ``Bottleneck`` blocks and per-stage
    block counts ``[3, 4, 6, 3]``.

    Args:
        **kwargs: Extra keyword arguments forwarded verbatim to the
            ``LambdaResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return LambdaResNet(Bottleneck, stage_depths, **kwargs)
198 |
def lambda_resnet101(**kwargs):
    """Construct a Lambda-ResNet-101.

    Instantiates ``LambdaResNet`` with ``Bottleneck`` blocks and per-stage
    block counts ``[3, 4, 23, 3]``.

    Args:
        **kwargs: Extra keyword arguments forwarded verbatim to the
            ``LambdaResNet`` constructor.
    """
    stage_depths = [3, 4, 23, 3]
    return LambdaResNet(Bottleneck, stage_depths, **kwargs)
208 |
def lambda_resnet152(**kwargs):
    """Construct a Lambda-ResNet-152.

    Instantiates ``LambdaResNet`` with ``Bottleneck`` blocks and per-stage
    block counts ``[3, 8, 36, 3]``.

    Args:
        **kwargs: Extra keyword arguments forwarded verbatim to the
            ``LambdaResNet`` constructor.
    """
    stage_depths = [3, 8, 36, 3]
    return LambdaResNet(Bottleneck, stage_depths, **kwargs)
218 |
219 |
--------------------------------------------------------------------------------