├── LICENSE
├── README.md
├── ResNet50.ipynb
├── frelu.py
├── frelu_resnet50.pth
├── main.py
└── resnet_frelu.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 nekitmm
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FunnelAct Pytorch
2 | Pytorch implementation of Funnel Activation (FReLU): https://arxiv.org/pdf/2007.11824.pdf
3 |
4 | Validation results are listed below:
5 |
6 | | Model | Activation | Err@1 | Err@5 |
7 | | :---------------------- | :--------: | :------: | :------: |
8 | | ResNet50 | FReLU | **22.40** | **6.164** |
9 |
10 | Note that resnet_frelu.py lets you build ResNet18, ResNet34, ResNet50, ResNet101 and ResNet152,
11 | but the weights in this repo are only available for ResNet50 and I never tried to train the other models,
12 | so no guarantees there!
13 |
14 | The code in this repo is based on the PyTorch ImageNet example:
15 |
16 | https://github.com/pytorch/examples/tree/master/imagenet
17 |
18 | and the original implementation of Funnel Activation in MegEngine:
19 |
20 | https://github.com/megvii-model/FunnelAct
21 |
22 | Enjoy!
--------------------------------------------------------------------------------
/ResNet50.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torchvision.transforms as transforms\n",
11 | "import torchvision.datasets as datasets\n",
12 | "import numpy as np\n",
13 | "from matplotlib import pyplot as plt\n",
14 | "import resnet_frelu as resnet\n",
15 | "import os\n",
16 | "from main import AverageMeter, ProgressMeter, accuracy, train, validate\n",
17 | "import time"
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {},
23 | "source": [
24 | "Validate current set of weights"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "model = torch.load('frelu_resnet50.pth')"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "data = 'C://ImageNet/'\n",
43 | "traindir = os.path.join(data, 'train')\n",
44 | "valdir = os.path.join(data, 'val')\n",
45 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
46 | "\n",
47 | "total_steps = 300000\n",
48 | "learning_rate = 0.001\n",
49 | "\n",
50 | "criterion = torch.nn.CrossEntropyLoss().cuda()\n",
51 | "\n",
52 | "train_dataset = datasets.ImageFolder(traindir,\n",
53 | " transforms.Compose([\n",
54 | " transforms.RandomResizedCrop(224),\n",
55 | " transforms.RandomHorizontalFlip(),\n",
56 | " transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),\n",
57 | " transforms.ToTensor(),\n",
58 | " normalize\n",
59 | " ]))"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 4,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "val_loader = torch.utils.data.DataLoader(\n",
69 | " datasets.ImageFolder(valdir, transforms.Compose([\n",
70 | " transforms.Resize(256),\n",
71 | " transforms.CenterCrop(224),\n",
72 | " transforms.ToTensor(),\n",
73 | " normalize\n",
74 | " ])),\n",
75 | " batch_size=100, shuffle=True,\n",
76 | " num_workers=4, pin_memory=True)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 5,
82 | "metadata": {},
83 | "outputs": [
84 | {
85 | "name": "stdout",
86 | "output_type": "stream",
87 | "text": [
88 | "Test: [ 0/500]\tTime 5.682 ( 5.682)\tLoss 7.2411e-01 (7.2411e-01)\tAcc@1 85.00 ( 85.00)\tAcc@5 97.00 ( 97.00)\n",
89 | "Test: [100/500]\tTime 0.553 ( 0.602)\tLoss 8.2409e-01 (9.4904e-01)\tAcc@1 78.00 ( 77.50)\tAcc@5 95.00 ( 93.64)\n",
90 | "Test: [200/500]\tTime 0.559 ( 0.579)\tLoss 7.7373e-01 (9.3647e-01)\tAcc@1 84.00 ( 77.70)\tAcc@5 94.00 ( 93.80)\n",
91 | "Test: [300/500]\tTime 0.555 ( 0.573)\tLoss 1.3577e+00 (9.4489e-01)\tAcc@1 74.00 ( 77.42)\tAcc@5 90.00 ( 93.77)\n",
92 | "Test: [400/500]\tTime 0.559 ( 0.571)\tLoss 9.4712e-01 (9.3480e-01)\tAcc@1 81.00 ( 77.57)\tAcc@5 95.00 ( 93.87)\n",
93 | " * Acc@1 77.606 Acc@5 93.836\n"
94 | ]
95 | },
96 | {
97 | "data": {
98 | "text/plain": [
99 | "tensor(77.6060, device='cuda:0')"
100 | ]
101 | },
102 | "execution_count": 5,
103 | "metadata": {},
104 | "output_type": "execute_result"
105 | }
106 | ],
107 | "source": [
108 | "model.cuda()\n",
109 | "validate(val_loader, model, criterion, {})"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "Create and train new ResNet with FReLU activations (primitive example)"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 3,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "model = resnet.resnet101()"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "model"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 5,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 | "optimizer = torch.optim.SGD(\n",
144 | " model.parameters(),\n",
145 | " lr=learning_rate / 10,\n",
146 | " momentum=0.9,\n",
147 | " weight_decay=1e-4,\n",
148 | " )"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 6,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "model.train()\n",
158 | "train_sampler = None\n",
159 | "train_loader = torch.utils.data.DataLoader(\n",
160 | " train_dataset, batch_size=4, shuffle=(train_sampler is None),\n",
161 | " num_workers=1, pin_memory=True, sampler=train_sampler)"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {
168 | "scrolled": true
169 | },
170 | "outputs": [],
171 | "source": [
172 | "for e in range(10):\n",
173 | " train(train_loader, model, criterion, optimizer, e)\n",
174 | " torch.save(model, 'FResNet50_' + str(e) + '.pth')"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "metadata": {},
181 | "outputs": [],
182 | "source": []
183 | }
184 | ],
185 | "metadata": {
186 | "kernelspec": {
187 | "display_name": "Python 3",
188 | "language": "python",
189 | "name": "python3"
190 | },
191 | "language_info": {
192 | "codemirror_mode": {
193 | "name": "ipython",
194 | "version": 3
195 | },
196 | "file_extension": ".py",
197 | "mimetype": "text/x-python",
198 | "name": "python",
199 | "nbconvert_exporter": "python",
200 | "pygments_lexer": "ipython3",
201 | "version": "3.8.3"
202 | }
203 | },
204 | "nbformat": 4,
205 | "nbformat_minor": 4
206 | }
207 |
--------------------------------------------------------------------------------
/frelu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
class FReLU(nn.Module):
    r"""Funnel activation: element-wise ``max(x, T(x))``.

    ``T(x)`` (the "funnel condition") is a 3x3 convolution with one group per
    channel, followed by batch normalisation. Because the convolution uses
    stride 1 and padding 1, ``T(x)`` has the same shape as ``x``.

    Args:
        in_channels: number of channels of the input; also used as the number
            of output channels and of convolution groups.
    """

    def __init__(self, in_channels):
        super().__init__()
        # groups == in_channels: every channel gets its own 3x3 filter.
        self.conv_frelu = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=in_channels,
        )
        self.bn_frelu = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        """Return the element-wise maximum of ``x`` and its funnel condition."""
        funnel = self.bn_frelu(self.conv_frelu(x))
        return torch.max(x, funnel)
--------------------------------------------------------------------------------
/frelu_resnet50.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nekitmm/FunnelAct_Pytorch/791b22ca72bccf7781bd7ee0b5ad7d2690951d41/frelu_resnet50.pth
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.distributed as dist
13 | import torch.optim
14 | import torch.multiprocessing as mp
15 | import torch.utils.data
16 | import torch.utils.data.distributed
17 | import torchvision.transforms as transforms
18 | import torchvision.datasets as datasets
19 | import torchvision.models as models
20 |
21 |
def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` on ``val_loader`` and return the mean top-1 accuracy.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Progress is printed every 100 batches and a summary line at the end.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches; must
            support ``len()``.
        model: the network to evaluate.
        criterion: loss callable taking ``(output, target)``.
        args: unused; kept so existing callers do not break.

    Returns:
        ``top1.avg`` -- the running average of the per-batch top-1 accuracy
        values, in whatever type ``accuracy`` produced them.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    # Images and targets are moved together: previously only the target move
    # was guarded, so a CPU-only host crashed on ``images.cuda()``.
    use_cuda = torch.cuda.is_available()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if use_cuda:
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg
70 |
def train(train_loader, model, criterion, optimizer, epoch, step = 250000):
    """Run one training epoch over ``train_loader``.

    The model is put in train mode (and moved to the GPU when one is
    available). For every batch the loss is computed, back-propagated and the
    optimizer stepped. Every 100 batches the progress meters and the learning
    rate of each optimizer parameter group are printed.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches; must
            support ``len()``.
        model: the network to train.
        criterion: loss callable taking ``(output, target)``.
        optimizer: optimizer exposing ``zero_grad``, ``step`` and
            ``param_groups``.
        epoch: epoch index, used only in the progress prefix.
        step: accepted for backward compatibility; its value is not used.

    Returns:
        None.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Model, images and targets are moved together so that a CPU-only host
    # does not crash on an unconditional ``.cuda()`` call.
    use_cuda = torch.cuda.is_available()

    # switch to train mode
    model.train()
    if use_cuda:
        model.cuda()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100 == 0:
            progress.display(i)
            for param_group in optimizer.param_groups:
                print("Current lr:", param_group["lr"])
120 |
121 |
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',
                    best_filename='model_best.pth.tar'):
    """Serialise ``state`` to ``filename`` and optionally keep a "best" copy.

    Args:
        state: object passed to ``torch.save``.
        is_best: when true, the file just written is also copied to
            ``best_filename``.
        filename: destination path of the checkpoint.
        best_filename: destination path of the best-model copy. Defaults to
            the previously hard-coded ``'model_best.pth.tar'``.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
126 |
127 |
class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric.

    Attributes set by ``reset``/``update``: ``val`` (last value), ``sum``
    (weighted sum), ``count`` (total weight) and ``avg`` (``sum / count``).

    Args:
        name: label printed in front of the values.
        fmt: format spec (including the leading colon, e.g. ``':6.2f'``)
            applied to both ``val`` and ``avg`` in ``__str__``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (only zero-weight updates so far) would divide
        # by zero; keep the previous average in that case.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:6.2f} ({avg:6.2f})' and fills it from the
        # instance attributes.
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
150 |
151 |
class ProgressMeter(object):
    """Print a tab-separated progress line: prefix, batch counter, meters.

    Args:
        num_batches: total number of batches; fixes the counter width and the
            denominator shown in the ``[current/total]`` field.
        meters: objects whose ``str()`` is appended to each line.
        prefix: text placed directly before the counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print one progress line for batch index ``batch``."""
        columns = [self.prefix + self.batch_fmtstr.format(batch)]
        columns.extend(str(meter) for meter in self.meters)
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/500]'`` for ``num_batches``."""
        # The counter is padded to as many digits as the total has.
        width = len(str(num_batches // 1))
        counter = '{:' + str(width) + 'd}'
        total = counter.format(num_batches)
        return '[' + counter + '/' + total + ']'
167 |
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: score tensor; the top-k search runs along dimension 1.
        target: tensor of class indices; ``target.size(0)`` is taken as the
            batch size.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``, each
        holding the percentage of samples whose target is among the ``k``
        highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk best classes per sample, transposed so that
        # row r holds every sample's (r+1)-th best guess.
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` derives from a transposed tensor and may be
            # non-contiguous, in which case ``view(-1)`` raises; ``reshape``
            # copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
--------------------------------------------------------------------------------
/resnet_frelu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from frelu import FReLU
4 |
5 | try:
6 | from torch.hub import load_state_dict_from_url
7 | except ImportError:
8 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
9 |
10 |
# Names exported by ``from resnet_frelu import *``.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

# Checkpoint URL per architecture name; looked up by ``_resnet`` when it is
# called with ``pretrained=True``.
# NOTE(review): these look like the stock torchvision checkpoints, while the
# blocks in this file enable conv biases and add FReLU submodules -- confirm
# that the state dicts actually load before relying on ``pretrained=True``.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
26 |
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, bias = False):
    """Build a 3x3 ``nn.Conv2d`` whose padding equals its dilation.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of convolution groups.
        dilation: kernel dilation; also used as the padding.
        bias: whether the layer has a learnable bias.
    """
    conv_options = {
        'kernel_size': 3,
        'stride': stride,
        'padding': dilation,
        'groups': groups,
        'bias': bias,
        'dilation': dilation,
    }
    return nn.Conv2d(in_planes, out_planes, **conv_options)
31 |
32 |
def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """Build a 1x1 ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        bias: whether the layer has a learnable bias.
    """
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=bias,
    )
    return pointwise
36 |
37 |
class BasicBlock(nn.Module):
    """Two-convolution residual block with ReLU activations.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)`` where the
    shortcut is ``x`` itself, or ``downsample(x)`` when a downsample module
    is given.

    Args:
        inplanes: number of input channels.
        planes: number of output channels (``expansion`` is 1).
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to ``x`` to form the shortcut.
        groups: must be 1.
        base_width: must be 64.
        dilation: must not exceed 1.
        norm_layer: normalisation layer factory; ``nn.BatchNorm2d`` if None.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1, conv1 and the downsample module both reduce the
        # spatial size of their input.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual block to ``x``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
76 |
77 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block with ReLU activations.

    The stride is applied at the 3x3 convolution (``conv2``) rather than at
    the first 1x1 convolution as in "Deep residual learning for image
    recognition" (https://arxiv.org/abs/1512.03385). This variant is also
    known as ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    All three convolutions are created with ``bias=True``.

    Args:
        inplanes: number of input channels.
        planes: base channel count; the block outputs ``planes * expansion``.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to ``x`` to form the shortcut.
        groups: groups of the 3x3 convolution.
        base_width: width per group; the inner width is
            ``int(planes * (base_width / 64.)) * groups``.
        dilation: dilation of the 3x3 convolution.
        norm_layer: normalisation layer factory; ``nn.BatchNorm2d`` if None.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # When stride != 1, conv2 and the downsample module both reduce the
        # spatial size of their input.
        self.conv1 = conv1x1(inplanes, width, bias=True)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation, bias=True)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes, bias=True)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck block to ``x``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
123 |
class Bottleneck_FReLU(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block using FReLU activations.

    Same layout as ``Bottleneck`` but every activation is its own ``FReLU``
    module (``frelu1``/``frelu2``/``frelu3``), each sized to the number of
    channels it receives.

    The stride is applied at the 3x3 convolution (``conv2``) rather than at
    the first 1x1 convolution as in "Deep residual learning for image
    recognition" (https://arxiv.org/abs/1512.03385). This variant is also
    known as ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    Args:
        inplanes: number of input channels.
        planes: base channel count; the block outputs ``planes * expansion``.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to ``x`` to form the shortcut.
        groups: groups of the 3x3 convolution.
        base_width: width per group; the inner width is
            ``int(planes * (base_width / 64.)) * groups``.
        dilation: dilation of the 3x3 convolution.
        norm_layer: normalisation layer factory; ``nn.BatchNorm2d`` if None.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck_FReLU, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        # bias should be enabled to be in agreement with original implementation
        self.conv1 = conv1x1(inplanes, width, bias = True)
        self.bn1 = norm_layer(width)
        self.frelu1 = FReLU(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation, bias = True)
        self.bn2 = norm_layer(width)
        self.frelu2 = FReLU(width)
        self.conv3 = conv1x1(width, planes * self.expansion, bias = True)
        self.bn3 = norm_layer(planes * self.expansion)
        self.frelu3 = FReLU(planes * self.expansion)
        # The debug ``print(self.downsample)`` that used to follow was
        # removed: it wrote to stdout for every block constructed.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the FReLU bottleneck block to ``x``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.frelu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.frelu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        # The final activation is applied after the residual sum.
        out = self.frelu3(out)

        return out
173 |
174 |
class ResNet(nn.Module):
    """ResNet backbone whose first three stages are built from ``block``.

    Layout: 7x7 stride-2 convolution (with bias), batch norm, ReLU, 3x3
    stride-2 max pool, four residual stages, 7x7 average pool, dropout
    (p=0.2), flatten and a linear classifier.

    The last stage uses the plain ``Bottleneck`` whenever ``block`` has the
    same ``expansion`` as ``Bottleneck`` (so ``Bottleneck_FReLU`` networks
    keep ReLU in their last stage); otherwise it uses ``block`` itself.

    Args:
        block: block class for stages 1-3; must expose ``expansion``.
        layers: four integers, the number of blocks per stage.
        num_classes: size of the classifier output.
        zero_init_residual: if True, zero the weight of the last norm layer
            in every residual block.
        groups: convolution groups forwarded to the blocks.
        width_per_group: base width forwarded to the blocks.
        replace_stride_with_dilation: None or three booleans; a True entry
            replaces the stride of stage 2/3/4 with dilation.
        norm_layer: normalisation layer factory; ``nn.BatchNorm2d`` if None.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have exactly
            three elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=True)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        # The last stage used to be hard-coded to Bottleneck while the
        # classifier was sized from ``block.expansion``; with BasicBlock that
        # produced 2048 features for a 512-input ``fc``. Bottleneck is kept
        # whenever the expansions agree, so every configuration that worked
        # before builds exactly the same modules.
        last_block = Bottleneck if block.expansion == Bottleneck.expansion else block
        self.layer4 = self._make_layer(last_block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        # Fixed 7x7 window, not adaptive pooling.
        self.avgpool = nn.AvgPool2d((7, 7))
        self.fc = nn.Linear(512 * last_block.expansion, num_classes)
        self.dropout = nn.Dropout(0.2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # Bottleneck_FReLU is not a subclass of Bottleneck, so it has
                # to be listed explicitly or its blocks are skipped.
                if isinstance(m, (Bottleneck, Bottleneck_FReLU)):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` consecutive ``block`` modules.

        The first block carries the stride (or, when ``dilate`` is true, the
        stride is folded into ``self.dilation``) and gets a 1x1-conv + norm
        downsample shortcut whenever the stride or channel count changes.
        Updates ``self.inplanes`` to the stage's output channel count.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride, bias = True),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # The first block is built with the dilation in effect before this
        # stage; the remaining blocks use the (possibly updated) value.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run stem, four stages, pooling, dropout and the classifier."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Return the classifier output for ``x``."""
        return self._forward_impl(x)
274 |
275 |
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ``ResNet`` and optionally load a downloaded state dict.

    Args:
        arch: key into ``model_urls``; only read when ``pretrained`` is true.
        block: block class passed to ``ResNet``.
        layers: per-stage block counts passed to ``ResNet``.
        pretrained: if true, download ``model_urls[arch]`` and load it into
            the model with ``load_state_dict``.
        progress: forwarded to the download helper.
        **kwargs: extra keyword arguments for ``ResNet``.

    Returns:
        The constructed model.
    """
    # NOTE(review): the URLs look like stock torchvision checkpoints; confirm
    # their keys match the bias-enabled / FReLU layers built in this file.
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(weights)
    return model
283 |
284 |
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-18 layout from "Deep Residual Learning for Image Recognition".

    Uses ``BasicBlock`` with two blocks in each of the four stages.

    Args:
        pretrained (bool): If True, ``_resnet`` downloads and loads the
            'resnet18' entry of ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress,
                   **kwargs)
294 |
295 |
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-34 layout from "Deep Residual Learning for Image Recognition".

    Uses ``BasicBlock`` with stage depths 3, 4, 6 and 3.

    Args:
        pretrained (bool): If True, ``_resnet`` downloads and loads the
            'resnet34' entry of ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress,
                   **kwargs)
305 |
306 |
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-50 layout from "Deep Residual Learning for Image Recognition".

    Uses ``Bottleneck_FReLU`` with stage depths 3, 4, 6 and 3.

    Args:
        pretrained (bool): If True, ``_resnet`` downloads and loads the
            'resnet50' entry of ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck_FReLU, stage_depths, pretrained, progress,
                   **kwargs)
316 |
317 |
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-101 layout from "Deep Residual Learning for Image Recognition".

    Uses ``Bottleneck_FReLU`` with stage depths 3, 4, 23 and 3.

    Args:
        pretrained (bool): If True, ``_resnet`` downloads and loads the
            'resnet101' entry of ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck_FReLU, stage_depths, pretrained, progress,
                   **kwargs)
327 |
328 |
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-152 layout from "Deep Residual Learning for Image Recognition".

    Uses ``Bottleneck_FReLU`` with stage depths 3, 8, 36 and 3.

    Args:
        pretrained (bool): If True, ``_resnet`` downloads and loads the
            'resnet152' entry of ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck_FReLU, stage_depths, pretrained, progress,
                   **kwargs)
338 |
339 |
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNeXt-50 32x4d layout from
    "Aggregated Residual Transformation for Deep Neural Networks".

    Uses the plain ``Bottleneck`` with stage depths 3, 4, 6 and 3, and forces
    ``groups=32`` and ``width_per_group=4`` (overriding any values passed in
    ``kwargs``).

    Args:
        pretrained (bool): If True, ``_resnet`` downloads and loads the
            'resnext50_32x4d' entry of ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(groups=32, width_per_group=4)
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnext50_32x4d', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
351 |
352 |
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNeXt-101 32x8d layout from
    "Aggregated Residual Transformation for Deep Neural Networks".

    Uses the plain ``Bottleneck`` with stage depths 3, 4, 23 and 3, and
    forces ``groups=32`` and ``width_per_group=8`` (overriding any values
    passed in ``kwargs``).

    Args:
        pretrained (bool): If True, ``_resnet`` downloads and loads the
            'resnext101_32x8d' entry of ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(groups=32, width_per_group=8)
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnext101_32x8d', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
364 |
365 |
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Build the Wide ResNet-50-2 layout from "Wide Residual Networks".

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Uses the plain ``Bottleneck`` with stage depths 3, 4, 6 and 3, and forces
    ``width_per_group=128`` (overriding any value passed in ``kwargs``).

    Args:
        pretrained (bool): If True, ``_resnet`` downloads and loads the
            'wide_resnet50_2' entry of ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 6, 3]
    return _resnet('wide_resnet50_2', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
380 |
381 |
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Build the Wide ResNet-101-2 layout from "Wide Residual Networks".

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Uses the plain ``Bottleneck`` with stage depths 3, 4, 23 and 3, and
    forces ``width_per_group=128`` (overriding any value passed in
    ``kwargs``).

    Args:
        pretrained (bool): If True, ``_resnet`` downloads and loads the
            'wide_resnet101_2' entry of ``model_urls``
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 23, 3]
    return _resnet('wide_resnet101_2', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
--------------------------------------------------------------------------------