├── LICENSE ├── README.md ├── ResNet50.ipynb ├── frelu.py ├── frelu_resnet50.pth ├── main.py └── resnet_frelu.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 nekitmm
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FunnelAct PyTorch
2 | PyTorch implementation of Funnel Activation (FReLU): https://arxiv.org/pdf/2007.11824.pdf
3 |
4 | FReLU replaces ReLU's elementwise max(x, 0) with max(x, T(x)), where T(x) is a k×k depthwise convolution followed by batch normalization (see frelu.py).
5 |
6 | Validation results are listed below:
7 |
8 | | Model | Activation | Err@1 | Err@5 |
9 | | :---------------------- | :--------: | :------: | :------: |
10 | | ResNet50 | FReLU | **22.40** | **6.164** |
11 |
12 | Note that resnet_frelu.py provides ResNet18, ResNet34, ResNet50, ResNet101 and ResNet152,
13 | but the weights in this repo are only available for ResNet50 and I never tried to train the other models,
14 | so no guarantees there!
15 |
16 | The code in this repo is based on the PyTorch ImageNet example:
17 |
18 | https://github.com/pytorch/examples/tree/master/imagenet
19 |
20 | and the original implementation of Funnel Activation in MegEngine:
21 |
22 | https://github.com/megvii-model/FunnelAct
23 |
24 | Enjoy!
--------------------------------------------------------------------------------
/ResNet50.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import torch\n",
10 |     "import torchvision.transforms as transforms\n",
11 |     "import torchvision.datasets as datasets\n",
12 |     "import numpy as np\n",
13 |     "from matplotlib import pyplot as plt\n",
14 |     "import resnet_frelu as resnet\n",
15 |     "import os\n",
16 |     "from main import AverageMeter, ProgressMeter, accuracy, train, validate\n",
17 |     "import time"
18 |    ]
19 |   },
20 |   {
21 |    "cell_type": "markdown",
22 |    "metadata": {},
23 |    "source": [
24 |     "

Validate current set of weights

" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "model = torch.load('frelu_resnet50.pth')" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "data = 'C://ImageNet/'\n", 43 | "traindir = os.path.join(data, 'train')\n", 44 | "valdir = os.path.join(data, 'val')\n", 45 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 46 | "\n", 47 | "total_steps = 300000\n", 48 | "learning_rate = 0.001\n", 49 | "\n", 50 | "criterion = torch.nn.CrossEntropyLoss().cuda()\n", 51 | "\n", 52 | "train_dataset = datasets.ImageFolder(traindir,\n", 53 | " transforms.Compose([\n", 54 | " transforms.RandomResizedCrop(224),\n", 55 | " transforms.RandomHorizontalFlip(),\n", 56 | " transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),\n", 57 | " transforms.ToTensor(),\n", 58 | " normalize\n", 59 | " ]))" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "val_loader = torch.utils.data.DataLoader(\n", 69 | " datasets.ImageFolder(valdir, transforms.Compose([\n", 70 | " transforms.Resize(256),\n", 71 | " transforms.CenterCrop(224),\n", 72 | " transforms.ToTensor(),\n", 73 | " normalize\n", 74 | " ])),\n", 75 | " batch_size=100, shuffle=True,\n", 76 | " num_workers=4, pin_memory=True)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 5, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "Test: [ 0/500]\tTime 5.682 ( 5.682)\tLoss 7.2411e-01 (7.2411e-01)\tAcc@1 85.00 ( 85.00)\tAcc@5 97.00 ( 97.00)\n", 89 | "Test: [100/500]\tTime 0.553 ( 0.602)\tLoss 8.2409e-01 (9.4904e-01)\tAcc@1 78.00 ( 77.50)\tAcc@5 95.00 ( 93.64)\n", 90 | "Test: [200/500]\tTime 0.559 ( 0.579)\tLoss 7.7373e-01 (9.3647e-01)\tAcc@1 84.00 ( 77.70)\tAcc@5 94.00 ( 93.80)\n", 91 | "Test: [300/500]\tTime 0.555 ( 0.573)\tLoss 1.3577e+00 (9.4489e-01)\tAcc@1 74.00 ( 77.42)\tAcc@5 90.00 ( 93.77)\n", 92 | "Test: [400/500]\tTime 0.559 ( 0.571)\tLoss 9.4712e-01 (9.3480e-01)\tAcc@1 81.00 ( 77.57)\tAcc@5 95.00 ( 93.87)\n", 93 | " * Acc@1 77.606 Acc@5 93.836\n" 94 | ] 95 | }, 96 | { 97 | "data": { 98 | "text/plain": [ 99 | "tensor(77.6060, device='cuda:0')" 100 | ] 101 | }, 102 | "execution_count": 5, 103 | "metadata": {}, 104 | "output_type": "execute_result" 105 | } 106 | ], 107 | "source": [ 108 | "model.cuda()\n", 109 | "validate(val_loader, model, criterion, {})" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "

Create and train new ResNet with FReLU activations (primitive example)

" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 3, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "model = resnet.resnet101()" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "model" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 5, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "optimizer = torch.optim.SGD(\n", 144 | " model.parameters(),\n", 145 | " lr=learning_rate / 10,\n", 146 | " momentum=0.9,\n", 147 | " weight_decay=1e-4,\n", 148 | " )" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 6, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "model.train()\n", 158 | "train_sampler = None\n", 159 | "train_loader = torch.utils.data.DataLoader(\n", 160 | " train_dataset, batch_size=4, shuffle=(train_sampler is None),\n", 161 | " num_workers=1, pin_memory=True, sampler=train_sampler)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": { 168 | "scrolled": true 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "for e in range(10):\n", 173 | " train(train_loader, model, criterion, optimizer, e)\n", 174 | " torch.save(model, 'FResNet50_' + str(e) + '.pth')" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "Python 3", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.8.3" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 4 206 | } 207 | -------------------------------------------------------------------------------- /frelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FReLU(nn.Module): 5 | r""" FReLU formulation. The funnel condition has a window size of kxk. 
(k=3 by default)
6 |     """
7 |     def __init__(self, in_channels):
8 |         super().__init__()
9 |         self.conv_frelu = nn.Conv2d(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1, groups = in_channels)
10 |         self.bn_frelu = nn.BatchNorm2d(in_channels)
11 |
12 |     def forward(self, x):
13 |         y = self.conv_frelu(x)   # funnel condition T(x): k x k depthwise convolution
14 |         y = self.bn_frelu(y)     # followed by batch normalization
15 |         x = torch.max(x, y)      # FReLU(x) = max(x, T(x))
16 |         return x
--------------------------------------------------------------------------------
/frelu_resnet50.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nekitmm/FunnelAct_Pytorch/791b22ca72bccf7781bd7ee0b5ad7d2690951d41/frelu_resnet50.pth
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.distributed as dist
13 | import torch.optim
14 | import torch.multiprocessing as mp
15 | import torch.utils.data
16 | import torch.utils.data.distributed
17 | import torchvision.transforms as transforms
18 | import torchvision.datasets as datasets
19 | import torchvision.models as models
20 |
21 |
22 | def validate(val_loader, model, criterion, args):
23 |     batch_time = AverageMeter('Time', ':6.3f')
24 |     losses = AverageMeter('Loss', ':.4e')
25 |     top1 = AverageMeter('Acc@1', ':6.2f')
26 |     top5 = AverageMeter('Acc@5', ':6.2f')
27 |     progress = ProgressMeter(
28 |         len(val_loader),
29 |         [batch_time, losses, top1, top5],
30 |         prefix='Test: ')
31 |
32 |     # switch to evaluate mode
33 |     model.eval()
34 |
35 |     with torch.no_grad():
36 |         end = time.time()
37 |         for i, (images, target) in enumerate(val_loader):
38 |
39 |             # move the batch to the GPU when one is available
40 |             if torch.cuda.is_available():
41 |                 images = images.cuda(non_blocking=True)
42 |                 target = target.cuda(non_blocking=True)
43 |
44 |             # compute output
45 |             output = model(images)
46 |
47 |             #plt.hist(output.cpu().detach().numpy().ravel(), bins = 50)
48 |             #plt.show()
49 |
50 |             loss = criterion(output, target)
51 |
52 |             # measure accuracy and record loss
53 |             acc1, acc5 = accuracy(output, target, topk=(1, 5))
54 |             losses.update(loss.item(), images.size(0))
55 |             top1.update(acc1[0], images.size(0))
56 |             top5.update(acc5[0], images.size(0))
57 |
58 |             # measure elapsed time
59 |             batch_time.update(time.time() - end)
60 |             end = time.time()
61 |
62 |             if i % 100 == 0:
63 |                 progress.display(i)
64 |
65 |         # TODO: this should also be done with the ProgressMeter
66 |         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
67 |               .format(top1=top1, top5=top5))
68 |
69 |     return top1.avg
70 |
71 | def train(train_loader, model, criterion, optimizer, epoch, step = 250000):
72 |     batch_time = AverageMeter('Time', ':6.3f')
73 |     data_time = AverageMeter('Data', ':6.3f')
74 |     losses = AverageMeter('Loss', ':.4e')
75 |     top1 = AverageMeter('Acc@1', ':6.2f')
76 |     top5 = AverageMeter('Acc@5', ':6.2f')
77 |     progress = ProgressMeter(
78 |         len(train_loader),
79 |         [batch_time, data_time, losses, top1, top5],
80 |         prefix="Epoch: [{}]".format(epoch))
81 |
82 |     # switch to train mode
83 |     model.train()
84 |     model.cuda()
85 |
86 |     end = time.time()
87 |     for i, (images, target) in enumerate(train_loader):
88 |         # measure data loading time
89 |         data_time.update(time.time() - end)
90 |
91 |         # move the batch to the GPU when one is available
92 |         if torch.cuda.is_available():
93 |             images = images.cuda(non_blocking=True)
94 |             target = target.cuda(non_blocking=True)
95 |
96 |         # compute output
97 |         output = model(images)
98 |         loss = criterion(output, target)
99 |
100 |         # measure accuracy and record loss
101 |         acc1, acc5 = accuracy(output, target, topk=(1, 5))
102 |         losses.update(loss.item(), images.size(0))
103 |         top1.update(acc1[0], images.size(0))
104 |         top5.update(acc5[0], images.size(0))
105 |
106 |         # compute gradient and do SGD step
107 |         optimizer.zero_grad()
108 |         loss.backward()
109 |         optimizer.step()
110 |
111 |         # measure elapsed time
112 |         batch_time.update(time.time() - end)
113 |         end = time.time()
114 |
115 |         if i % 100 == 0:
116 |             progress.display(i)
117 |             for param_group in optimizer.param_groups:
118 |                 print("Current lr:", param_group["lr"])
119 |         step += 1
120 |
121 |
122 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
123 |     torch.save(state, filename)
124 |     if is_best:
125 |         shutil.copyfile(filename, 'model_best.pth.tar')
126 |
127 |
128 | class AverageMeter(object):
129 |     """Computes and stores the average and current value"""
130 |     def __init__(self, name, fmt=':f'):
131 |         self.name = name
132 |         self.fmt = fmt
133 |         self.reset()
134 |
135 |     def reset(self):
136 |         self.val = 0
137 |         self.avg = 0
138 |         self.sum = 0
139 |         self.count = 0
140 |
141 |     def update(self, val, n=1):
142 |         self.val = val
143 |         self.sum += val * n
144 |         self.count += n
145 |         self.avg = self.sum / self.count
146 |
147 |     def __str__(self):
148 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
149 |         return fmtstr.format(**self.__dict__)
150 |
151 |
152 | class ProgressMeter(object):
153 |     def __init__(self, num_batches, meters, prefix=""):
154 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
155 |         self.meters = meters
156 |         self.prefix = prefix
157 |
158 |     def display(self, batch):
159 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
160 |         entries += [str(meter) for meter in self.meters]
161 |         print('\t'.join(entries))
162 |
163 |     def _get_batch_fmtstr(self, num_batches):
164 |         num_digits = len(str(num_batches // 1))
165 |         fmt = '{:' + str(num_digits) + 'd}'
166 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
167 |
168 | def accuracy(output, target, topk=(1,)):
169 |     """Computes the accuracy over the k top predictions for the specified values of k"""
170 |     with torch.no_grad():
171 |         maxk = max(topk)
172 |         batch_size = target.size(0)
173 |
174 |         _, pred = output.topk(maxk, 1, True, True)
175 |         pred = pred.t()
176 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
177 |
178 |         res = []
179 |         for k in topk:
180 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape() also works on non-contiguous slices
181 |             res.append(correct_k.mul_(100.0 / batch_size))
182 |         return res
--------------------------------------------------------------------------------
/resnet_frelu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from frelu import FReLU
4 |
5 | try:
6 |     from torch.hub import load_state_dict_from_url
7 | except ImportError:
8 |     from torch.utils.model_zoo import load_url as load_state_dict_from_url
9 |
10 |
11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
12 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
13 |            'wide_resnet50_2', 'wide_resnet101_2']
14 |
15 | model_urls = {
16 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 |     'resnet50': 
'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 25 | } 26 | 27 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, bias = False): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=dilation, groups=groups, bias=bias, dilation=dilation) 31 | 32 | 33 | def conv1x1(in_planes, out_planes, stride=1, bias=False): 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias = bias) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 42 | base_width=64, dilation=1, norm_layer=None): 43 | super(BasicBlock, self).__init__() 44 | if norm_layer is None: 45 | norm_layer = nn.BatchNorm2d 46 | if groups != 1 or base_width != 64: 47 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 48 | if dilation > 1: 49 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 50 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 51 | self.conv1 = conv3x3(inplanes, planes, stride) 52 | self.bn1 = norm_layer(planes) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.conv2 = conv3x3(planes, planes) 55 | self.bn2 = norm_layer(planes) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | identity = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | identity = self.downsample(x) 71 | 72 | out += identity 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 80 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 81 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 82 | # This variant is also known as ResNet V1.5 and improves accuracy according to 83 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
84 | expansion = 4 85 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 86 | base_width=64, dilation=1, norm_layer=None): 87 | super(Bottleneck, self).__init__() 88 | if norm_layer is None: 89 | norm_layer = nn.BatchNorm2d 90 | width = int(planes * (base_width / 64.)) * groups 91 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 92 | self.conv1 = conv1x1(inplanes, width, bias = True) 93 | self.bn1 = norm_layer(width) 94 | self.conv2 = conv3x3(width, width, stride, groups, dilation, bias = True) 95 | self.bn2 = norm_layer(width) 96 | self.conv3 = conv1x1(width, planes * self.expansion, bias = True) 97 | self.bn3 = norm_layer(planes * self.expansion) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.downsample = downsample 100 | self.stride = stride 101 | 102 | def forward(self, x): 103 | identity = x 104 | 105 | out = self.conv1(x) 106 | out = self.bn1(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv2(out) 110 | out = self.bn2(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv3(out) 114 | out = self.bn3(out) 115 | 116 | if self.downsample is not None: 117 | identity = self.downsample(x) 118 | 119 | out += identity 120 | out = self.relu(out) 121 | 122 | return out 123 | 124 | class Bottleneck_FReLU(nn.Module): 125 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 126 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 127 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 128 | # This variant is also known as ResNet V1.5 and improves accuracy according to 129 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
130 |     expansion = 4
131 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
132 |                  base_width=64, dilation=1, norm_layer=None):
133 |         super(Bottleneck_FReLU, self).__init__()
134 |         if norm_layer is None:
135 |             norm_layer = nn.BatchNorm2d
136 |         width = int(planes * (base_width / 64.)) * groups
137 |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
138 |         # bias should be enabled to be in agreement with original implementation
139 |         self.conv1 = conv1x1(inplanes, width, bias = True)
140 |         self.bn1 = norm_layer(width)
141 |         self.frelu1 = FReLU(width)
142 |         self.conv2 = conv3x3(width, width, stride, groups, dilation, bias = True)
143 |         self.bn2 = norm_layer(width)
144 |         self.frelu2 = FReLU(width)
145 |         self.conv3 = conv1x1(width, planes * self.expansion, bias = True)
146 |         self.bn3 = norm_layer(planes * self.expansion)
147 |         self.frelu3 = FReLU(planes * self.expansion)
148 |         self.downsample = downsample
149 |         # the downsample branch (1x1 conv + norm), when present, is built in ResNet._make_layer
150 |         self.stride = stride
151 |
152 |     def forward(self, x):
153 |         identity = x
154 |
155 |         out = self.conv1(x)
156 |         out = self.bn1(out)
157 |         out = self.frelu1(out)
158 |
159 |         out = self.conv2(out)
160 |         out = self.bn2(out)
161 |         out = self.frelu2(out)
162 |
163 |         out = self.conv3(out)
164 |         out = self.bn3(out)
165 |
166 |         if self.downsample is not None:
167 |             identity = self.downsample(x)
168 |
169 |         out += identity
170 |         out = self.frelu3(out)
171 |
172 |         return out
173 |
174 |
175 | class ResNet(nn.Module):
176 |
177 |     def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
178 |                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
179 |                  norm_layer=None):
180 |         super(ResNet, self).__init__()
181 |         if norm_layer is None:
182 |             norm_layer = nn.BatchNorm2d
183 |         self._norm_layer = norm_layer
184 |
185 |         self.inplanes = 64
186 |         self.dilation = 1
187 |         if replace_stride_with_dilation is None:
188 |             # each element in the tuple indicates if we should replace
189 |             # the 2x2 stride with a dilated convolution instead
190 |             replace_stride_with_dilation = [False, False, False]
191 |         if len(replace_stride_with_dilation) != 3:
192 |             raise ValueError("replace_stride_with_dilation should be None "
193 |                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
194 |         self.groups = groups
195 |         self.base_width = width_per_group
196 |         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
197 |                                bias=True)
198 |         self.bn1 = norm_layer(self.inplanes)
199 |         self.relu = nn.ReLU(inplace=True)
200 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
201 |         self.layer1 = self._make_layer(block, 64, layers[0])
202 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
203 |                                        dilate=replace_stride_with_dilation[0])
204 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
205 |                                        dilate=replace_stride_with_dilation[1])
206 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
207 |                                        dilate=replace_stride_with_dilation[2])
208 |         self.avgpool = nn.AvgPool2d((7, 7))
209 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
210 |         self.dropout = nn.Dropout(0.2)
211 |
212 |         for m in self.modules():
213 |             if isinstance(m, nn.Conv2d):
214 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
215 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
216 |                 nn.init.constant_(m.weight, 1)
217 |                 nn.init.constant_(m.bias, 0)
218 |
219 |         # Zero-initialize the last BN in each residual branch,
220 |         # so that the 
residual branch starts with zeros, and each residual block behaves like an identity.
221 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
222 |         if zero_init_residual:
223 |             for m in self.modules():
224 |                 if isinstance(m, Bottleneck):
225 |                     nn.init.constant_(m.bn3.weight, 0)
226 |                 elif isinstance(m, BasicBlock):
227 |                     nn.init.constant_(m.bn2.weight, 0)
228 |
229 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
230 |         norm_layer = self._norm_layer
231 |         downsample = None
232 |         previous_dilation = self.dilation
233 |         if dilate:
234 |             self.dilation *= stride
235 |             stride = 1
236 |         if stride != 1 or self.inplanes != planes * block.expansion:
237 |             downsample = nn.Sequential(
238 |                 conv1x1(self.inplanes, planes * block.expansion, stride, bias = True),
239 |                 norm_layer(planes * block.expansion),
240 |             )
241 |
242 |         layers = []
243 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
244 |                             self.base_width, previous_dilation, norm_layer))
245 |         self.inplanes = planes * block.expansion
246 |         for _ in range(1, blocks):
247 |             layers.append(block(self.inplanes, planes, groups=self.groups,
248 |                                 base_width=self.base_width, dilation=self.dilation,
249 |                                 norm_layer=norm_layer))
250 |
251 |         return nn.Sequential(*layers)
252 |
253 |     def _forward_impl(self, x):
254 |         # See note [TorchScript super()]
255 |         x = self.conv1(x)
256 |         x = self.bn1(x)
257 |         x = self.relu(x)
258 |         x = self.maxpool(x)
259 |
260 |         x = self.layer1(x)
261 |         x = self.layer2(x)
262 |         x = self.layer3(x)
263 |         x = self.layer4(x)
264 |
265 |         x = self.avgpool(x)  # fixed 7x7 window (not adaptive), so inputs are expected to be 224x224
266 |         x = self.dropout(x)
267 |         x = torch.flatten(x, 1)
268 |         x = self.fc(x)
269 |
270 |         return x
271 |
272 |     def forward(self, x):
273 |         return self._forward_impl(x)
274 |
275 |
276 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
277 |     model = ResNet(block, layers, **kwargs)
278 |     if pretrained:
279 |         state_dict = load_state_dict_from_url(model_urls[arch],
280 |                                               progress=progress)
281 |         model.load_state_dict(state_dict)
282 |     return model
283 |
284 |
285 | def resnet18(pretrained=False, progress=True, **kwargs):
286 |     r"""ResNet-18 model from
287 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
288 |     Args:
289 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
290 |         progress (bool): If True, displays a progress bar of the download to stderr
291 |     """
292 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
293 |                    **kwargs)
294 |
295 |
296 | def resnet34(pretrained=False, progress=True, **kwargs):
297 |     r"""ResNet-34 model from
298 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
299 |     Args:
300 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
301 |         progress (bool): If True, displays a progress bar of the download to stderr
302 |     """
303 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
304 |                    **kwargs)
305 |
306 |
307 | def resnet50(pretrained=False, progress=True, **kwargs):
308 |     r"""ResNet-50 model from
309 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
310 |     Args:
311 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
312 |         progress (bool): If True, displays a progress bar of the download to stderr
313 |     """
314 |     return _resnet('resnet50', Bottleneck_FReLU, [3, 4, 6, 3], pretrained, progress,
315 |                    **kwargs)
316 |
317 |
318 | def resnet101(pretrained=False, progress=True, **kwargs):
319 |     r"""ResNet-101 model from
320 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
321 |     Args:
322 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
323 |         progress (bool): If True, displays a progress bar of the download to stderr
324 |     """
325 |     return _resnet('resnet101', Bottleneck_FReLU, [3, 4, 23, 3], pretrained, progress,
326 |                    **kwargs)
327 |
328 |
329 | def resnet152(pretrained=False, progress=True, **kwargs):
330 |     r"""ResNet-152 model from
331 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
332 |     Args:
333 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
334 |         progress (bool): If True, displays a progress bar of the download to stderr
335 |     """
336 |     return _resnet('resnet152', Bottleneck_FReLU, [3, 8, 36, 3], pretrained, progress,
337 |                    **kwargs)
338 |
339 |
340 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
341 |     r"""ResNeXt-50 32x4d model from
342 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
343 |     Args:
344 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
345 |         progress (bool): If True, displays a progress bar of the download to stderr
346 |     """
347 |     kwargs['groups'] = 32
348 |     kwargs['width_per_group'] = 4
349 |     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
350 |                    pretrained, progress, **kwargs)
351 |
352 |
353 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
354 |     r"""ResNeXt-101 32x8d model from
355 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
356 |     Args:
357 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
358 |         progress (bool): If True, displays a progress bar of the download to stderr
359 |     """
360 |     kwargs['groups'] = 32
361 |     kwargs['width_per_group'] = 8
362 |     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
363 |                    pretrained, progress, **kwargs)
364 |
365 |
366 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
367 |     r"""Wide ResNet-50-2 model from
368 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
369 |     The model is the same as ResNet except for the bottleneck number of channels
370 |     which is twice larger in every block. The number of channels in outer 1x1
371 |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
372 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
373 |     Args:
374 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
375 |         progress (bool): If True, displays a progress bar of the download to stderr
376 |     """
377 |     kwargs['width_per_group'] = 64 * 2
378 |     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
379 |                    pretrained, progress, **kwargs)
380 |
381 |
382 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
383 |     r"""Wide ResNet-101-2 model from
384 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
385 |     The model is the same as ResNet except for the bottleneck number of channels
386 |     which is twice larger in every block. The number of channels in outer 1x1
387 |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
388 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
389 |     Args:
390 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
391 |         progress (bool): If True, displays a progress bar of the download to stderr
392 |     """
393 |     kwargs['width_per_group'] = 64 * 2
394 |     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
395 |                    pretrained, progress, **kwargs)
--------------------------------------------------------------------------------