114 | {%- if docsearch or hasdoc('search') %}
115 |
116 |
130 |
159 |
160 |
161 | {%- include "searchbox.html" %}
162 |
163 |
164 |
184 |
185 | {%- endif %}
186 |
187 | {%- block extra_header_link_icons %}
188 |
220 | {%- endblock extra_header_link_icons %}
221 |
222 | {%- endblock header_right %}
--------------------------------------------------------------------------------
/linear_operator_learning/nn/modules/resnet.py:
--------------------------------------------------------------------------------
1 | """Resnet Module."""
2 |
3 | from typing import Any, Callable, List, Optional, Type, Union
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 |
# Public API of this module: the generic ResNet class and its factory helpers.
__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
]
10 |
11 |
def conv3x3(
    in_planes: int,
    out_planes: int,
    stride: int = 1,
    groups: int = 1,
    dilation: int = 1,
    padding_mode: str = "zeros",
) -> nn.Conv2d:
    """Build a bias-free 3x3 convolution whose padding equals its dilation.

    Using ``padding=dilation`` keeps the spatial size unchanged when
    ``stride == 1``, whatever the dilation rate.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of blocked connections from input to output channels.
        dilation: Spacing between kernel elements; also used as the padding.
        padding_mode: Padding mode forwarded to ``nn.Conv2d``.

    Returns:
        The configured ``nn.Conv2d`` layer (``bias=False``).
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
        padding_mode=padding_mode,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
32 |
33 |
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free pointwise (1x1) convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` layer (``bias=False``).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer
37 |
38 |
class BasicBlock(nn.Module):
    """Two-convolution residual block with ``expansion = 1``.

    Main branch: ``conv3x3 -> norm -> ReLU -> conv3x3 -> norm``. Its output is
    added to the input (passed through ``downsample`` when one is given) and the
    sum goes through a final ReLU. The output channel count equals ``planes``.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        padding_mode: str = "zeros",
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Number of output channels of both convolutions.
            stride: Stride of the first 3x3 convolution.
            downsample: Optional module applied to the input on the shortcut path.
            groups: Must be 1.
            base_width: Must be 64.
            dilation: Must be 1.
            padding_mode: Padding mode for the 3x3 convolutions.
            norm_layer: Factory called with a channel count to build each norm
                layer; defaults to ``nn.BatchNorm2d``.

        Raises:
            ValueError: If ``groups != 1`` or ``base_width != 64``.
            NotImplementedError: If ``dilation > 1``.
        """
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Submodules are created in a fixed order: parameter initialisation
        # consumes the RNG in this order and the attribute names are the
        # state_dict keys. With stride != 1, conv1 and `downsample` both
        # reduce the spatial size.
        self.conv1 = conv3x3(inplanes, planes, stride, padding_mode=padding_mode)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, padding_mode=padding_mode)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual block to ``x``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        residual += shortcut
        return self.relu(residual)
87 |
88 |
class Bottleneck(nn.Module):
    """Three-convolution bottleneck residual block with ``expansion = 4``.

    Main branch: ``1x1 -> norm -> ReLU -> 3x3 -> norm -> ReLU -> 1x1 -> norm``.
    The final 1x1 convolution widens the output to ``planes * expansion``
    channels. The branch output is added to the (optionally downsampled) input
    and passed through a final ReLU.

    Like torchvision, this block puts the downsampling stride on the 3x3
    convolution (``conv2``). The original "Deep residual learning for image
    recognition" paper (https://arxiv.org/abs/1512.03385) puts it on the first
    1x1 convolution (``conv1``). This variant is known as ResNet V1.5 and is
    reported to improve accuracy, see
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        padding_mode: str = "zeros",
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Base channel count; the output has ``planes * expansion`` channels.
            stride: Stride of the 3x3 convolution.
            downsample: Optional module applied to the input on the shortcut path.
            groups: Groups of the 3x3 convolution.
            base_width: Per-group width; the inner width is
                ``int(planes * base_width / 64) * groups``.
            dilation: Dilation of the 3x3 convolution.
            padding_mode: Padding mode for the 3x3 convolution.
            norm_layer: Factory called with a channel count to build each norm
                layer; defaults to ``nn.BatchNorm2d``.
        """
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.0)) * groups
        out_channels = planes * self.expansion
        # Submodules are created in a fixed order: parameter initialisation
        # consumes the RNG in this order and the attribute names are the
        # state_dict keys. With stride != 1, conv2 and `downsample` both
        # reduce the spatial size.
        self.conv1 = conv1x1(inplanes, inner)
        self.bn1 = make_norm(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation, padding_mode=padding_mode)
        self.bn2 = make_norm(inner)
        self.conv3 = conv1x1(inner, out_channels)
        self.bn3 = make_norm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the bottleneck block to ``x``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        y += shortcut
        return self.relu(y)
146 |
147 |
class ResNet(nn.Module):
    """ResNet model from :footcite:t:`he2016deep`.

    Architecture: a 7x7 stride-2 stem convolution, normalization, ReLU and a
    3x3 stride-2 max-pool; four residual stages with 64, 128, 256 and 512 base
    planes; global average pooling; and a bias-free linear layer mapping
    ``512 * block.expansion`` features to ``num_features`` outputs.

    Args:
        block (Type[Union[BasicBlock, Bottleneck]]): Residual block type used in
            every stage.
        layers (List[int]): Number of residual blocks in each of the four stages
            (only the first four entries are read).
        channels_in (int): Number of input channels.
        num_features (int): Output dimension of the final linear layer.
        zero_init_residual (bool): If True, zero-initialize the weight of the last
            normalization layer of every residual branch, so each branch initially
            outputs zeros.
        groups (int): Number of groups, forwarded to every block.
        width_per_group (int): Width per group, forwarded to every block as
            ``base_width``.
        replace_stride_with_dilation (Optional[List[bool]]): Three flags, one per
            stage 2-4; a True flag makes that stage use dilation instead of a
            stride of 2. Defaults to ``[False, False, False]``.
        padding_mode (str): Padding mode for the stem and 3x3 convolutional layers.
        norm_layer (Optional[Callable[..., nn.Module]]): Factory called with a
            channel count to build each normalization layer. Defaults to
            ``nn.BatchNorm2d``. Non-affine norm layers (``affine=False``) are
            supported.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have exactly
            three entries.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        channels_in: int = 3,
        num_features: int = 1024,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        padding_mode: str = "zeros",
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running channel count and dilation; updated by _make_layer as stages are built.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            channels_in,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            padding_mode=padding_mode,
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], padding_mode=padding_mode)
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            padding_mode=padding_mode,
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            padding_mode=padding_mode,
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            padding_mode=padding_mode,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_features, bias=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False have weight/bias set to None;
                # initializing them unconditionally would crash.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
        padding_mode: str = "zeros",
    ) -> nn.Sequential:
        """Build one residual stage of ``blocks`` blocks with ``planes`` base planes.

        Only the first block applies ``stride``. When ``dilate`` is True, the
        stride is replaced by multiplying ``self.dilation`` by ``stride``. The
        first block then keeps the previous dilation and the later blocks use
        the new one. A ``conv1x1`` + norm projection shortcut is added to the
        first block when the stride is not 1 or the channel count changes.

        Side effects: updates ``self.inplanes`` to ``planes * block.expansion``
        and, when ``dilate`` is True, ``self.dilation``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                padding_mode,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                    padding_mode=padding_mode,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # Body of ``forward``; kept as a separate method following torchvision's
        # ResNet layout.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Collapse everything after the leading (batch) dimension before the linear head.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the ResNet model.

        Args:
            x: Batched input images, presumably of shape
                ``(N, channels_in, H, W)``. The flatten from dim 1 assumes a
                leading batch dimension.

        Returns:
            Tensor of shape ``(N, num_features)``.
        """
        return self._forward_impl(x)
323 |
324 |
def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    **kwargs: Any,
) -> ResNet:
    """Instantiate a :class:`ResNet` with the given block type and stage depths.

    Args:
        block: Residual block class used in every stage.
        layers: Number of blocks in each of the four stages.
        **kwargs: Forwarded unchanged to the :class:`ResNet` constructor.

    Returns:
        The newly constructed model.
    """
    return ResNet(block, layers, **kwargs)
333 |
334 |
335 | def resnet18(*args, **kwargs: Any) -> ResNet:
336 | """ResNet-18 from `Deep Residual Learning for Image Recognition