├── efficientnet_pytorch
│   ├── __init__.py
│   ├── model.py
│   └── utils.py
├── partyloss
│   ├── focal_loss.py
│   └── metrics.py
├── GetImages.py
├── datasets
│   └── dataset.py
└── README.md
/efficientnet_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.4"
2 | from .model import EfficientNet
3 | from .utils import (
4 | GlobalParams,
5 | BlockArgs,
6 | BlockDecoder,
7 | efficientnet,
8 | get_model_params,
9 | )
10 |
11 |
--------------------------------------------------------------------------------
/partyloss/focal_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on 18-6-7 at 10:11 AM
4 |
5 | @author: ronghuaiyang
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class FocalLoss(nn.Module):
13 |
14 | def __init__(self, gamma=0, eps=1e-7):
15 | super(FocalLoss, self).__init__()
16 | self.gamma = gamma
17 |         self.eps = eps  # note: eps is not used in forward
18 | self.ce = torch.nn.CrossEntropyLoss()
19 |
20 | def forward(self, input, target):
21 |         logp = self.ce(input, target)   # mean cross entropy over the batch
22 |         p = torch.exp(-logp)
23 |         loss = (1 - p) ** self.gamma * logp   # focal modulation applied to the mean CE
24 | return loss.mean()
--------------------------------------------------------------------------------
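
Usage note (not part of the repository): the README trains with this focal loss at gamma=2 on top of the margin-product logits from `partyloss/metrics.py`. Because `nn.CrossEntropyLoss` here returns the batch mean, the `(1 - p) ** gamma` factor modulates the mean cross entropy rather than each sample. A minimal sketch with made-up shapes:

```python
import torch
from partyloss.focal_loss import FocalLoss

criterion = FocalLoss(gamma=2)                      # gamma=2 as in the README's training setup

logits = torch.randn(8, 1000, requires_grad=True)   # e.g. output of ArcMarginProduct: (batch, num_classes)
labels = torch.randint(0, 1000, (8,))               # ground-truth class indices

loss = criterion(logits, labels)                    # scalar: (1 - exp(-CE_mean)) ** gamma * CE_mean
loss.backward()
```
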
/GetImages.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from torch.utils.data import Dataset, ConcatDataset, DataLoader
3 | from torchvision import transforms as trans
4 | from torchvision.datasets import ImageFolder
5 | from PIL import Image, ImageFile
6 | ImageFile.LOAD_TRUNCATED_IMAGES = True
7 | import numpy as np
8 | import cv2
9 | import bcolz
10 | import pickle
11 | import torch
12 | import mxnet as mx
13 | from tqdm import tqdm
14 | import os
15 |
16 |
17 |
18 | def load_mx_rec(rec_path,savefold):
19 | save_path = savefold+'/imgs'
20 | if not os.path.exists(save_path):
21 | os.makedirs(save_path)
22 | imgrec = mx.recordio.MXIndexedRecordIO(str(rec_path+'/'+'train.idx'), str(rec_path+'/'+'train.rec'), 'r')
23 | img_info = imgrec.read_idx(0)
24 | header,_ = mx.recordio.unpack(img_info)
25 | max_idx = int(header.label[0])
26 | for idx in tqdm(range(1,max_idx)):
27 | img_info = imgrec.read_idx(idx)
28 | header, img = mx.recordio.unpack_img(img_info)
29 | label = int(header.label)
30 | img = Image.fromarray(img)
31 | label_path = save_path+'/'+str(label)
32 | if not os.path.exists(label_path):
33 | os.makedirs(label_path)
34 | img.save(label_path+'/'+'{}.jpg'.format(idx), quality=100)
35 |
36 | def load_bin(path, savepath, image_size=[112,112]):
37 | if not os.path.exists(savepath):
38 | os.makedirs(savepath)
39 | bins, issame_list = pickle.load(open(path, 'rb'), encoding='bytes')
40 | for i in range(len(bins)):
41 | _bin = bins[i]
42 | img = mx.image.imdecode(_bin).asnumpy()
43 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
44 | img = Image.fromarray(img.astype(np.uint8))
45 | img.save(savepath+'/'+'{}.jpg'.format(i), quality=100)
46 |     np.save(savepath.split('/')[-1]+'_list', np.array(issame_list))  # e.g. 'lfw_list.npy', saved in the current working directory
47 |
48 | if __name__ == '__main__':
49 | mainfold = './faces_emore'
50 | bin_files = ['agedb_30', 'cfp_fp', 'lfw', 'calfw', 'cfp_ff', 'cplfw', 'vgg2_fp']
51 | savefold = './faces_emore_imgs'
52 | load_mx_rec(mainfold,savefold)
53 | for i in range(len(bin_files)):
54 | load_bin(mainfold+'/'+bin_files[i]+'.bin', savepath = savefold+'/'+bin_files[i], image_size=[112,112])
55 |
--------------------------------------------------------------------------------
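
For reference, running the `__main__` block above under a `faces_emore` download produces roughly the following layout (a sketch inferred from the code; exact identity folders and file names depend on your data):

```
faces_emore_imgs/
├── imgs/<label>/<record_idx>.jpg       # written by load_mx_rec from train.idx / train.rec
├── agedb_30/<i>.jpg                    # written by load_bin, one jpg per entry in agedb_30.bin
├── cfp_fp/  cfp_ff/  lfw/  calfw/  cplfw/  vgg2_fp/   # same for the other verification bins
agedb_30_list.npy, cfp_fp_list.npy, ... # issame labels, saved to the current working directory
```
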
/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 |
9 |
10 | class SampleProperty(object):
11 | def __init__(self, row):
12 | self._sample = row.strip().split(' ')
13 | @property
14 | def path(self):
15 | return self._sample[0]
16 |
17 | @property
18 | def label(self):
19 | return int(self._sample[1])
20 |
21 | class FaceImageDataset(Dataset):
22 | def __init__(self, data_root, list_file_root, modality, transform=None):
23 | self.data_root = data_root
24 | self.list_file_root = list_file_root
25 | self.modality = modality
26 | self.transform = transform
27 | self.Sample_List = [SampleProperty(x) for x in open(self.list_file_root)]
28 |
29 | def __getitem__(self, idx):
30 | img_path = self.Sample_List[idx].path
31 | label = self.Sample_List[idx].label
32 |
33 | if self.modality == 'RGB':
34 | image = Image.open(os.path.join(self.data_root, img_path)).convert('RGB')
35 | if self.modality == 'Gray':
36 | image = Image.open(os.path.join(self.data_root, img_path)).convert('L')
37 | if self.transform is not None:
38 |             image = self.transform(image)  # C x H x W
39 |         label = torch.tensor(label)
40 |
41 |         return image, label
42 |
43 | def __len__(self):
44 | return len(self.Sample_List)
45 |
46 |
47 |
48 |
49 | class FaceImagePiarDataset(Dataset):  # image-pair dataset for verification lists: 'path1 path2 label' per line
50 | def __init__(self, data_root, list_file_root, modality, transform=None):
51 | self.data_root = data_root
52 | self.list_file_root = list_file_root
53 | self.modality = modality
54 | self.transform = transform
55 | self.Sample_List = [line.strip().split(' ') for line in open(self.list_file_root)]
56 |
57 | def __getitem__(self, idx):
58 | img_path1 = self.Sample_List[idx][0]
59 | img_path2 = self.Sample_List[idx][1]
60 | label = int(self.Sample_List[idx][2])
61 |
62 | if self.modality == 'RGB':
63 | image1 = Image.open(os.path.join(self.data_root, img_path1)).convert('RGB')
64 | image2 = Image.open(os.path.join(self.data_root, img_path2)).convert('RGB')
65 | if self.modality == 'Gray':
66 | image1 = Image.open(os.path.join(self.data_root, img_path1)).convert('L')
67 | image2 = Image.open(os.path.join(self.data_root, img_path2)).convert('L')
68 | if self.transform is not None:
69 |             image1 = self.transform(image1)  # C x H x W
70 |             image2 = self.transform(image2)  # C x H x W
71 |         label = torch.tensor(label)
72 |
73 |         return image1, image2, label
74 |
75 | def __len__(self):
76 | return len(self.Sample_List)
--------------------------------------------------------------------------------
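
A hedged usage sketch (not part of the repository): `FaceImageDataset` expects a list file with one `relative/path.jpg label` entry per space-separated line (see `SampleProperty` above). The paths and file names below are placeholders:

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.dataset import FaceImageDataset

transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),                          # C x H x W, as noted in __getitem__
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

train_set = FaceImageDataset(
    data_root='./faces_emore_imgs/imgs',            # hypothetical path
    list_file_root='./train_list.txt',              # lines like: 1234/5678.jpg 1234
    modality='RGB',
    transform=transform,
)
train_loader = DataLoader(train_set, batch_size=80, shuffle=True, num_workers=4)

images, labels = next(iter(train_loader))           # images: (80, 3, 112, 112), labels: (80,)
```
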
/README.md:
--------------------------------------------------------------------------------
1 | # Insightface_EfficientNet
2 | A PyTorch implementation of the deep face recognition part of InsightFace ([github](https://github.com/deepinsight/insightface)) with an EfficientNet ([github](https://github.com/lukemelas/EfficientNet-PyTorch)) backbone.
3 | # About EfficientNet
4 | Official explanation: EfficientNets are a family of image classification models which achieve state-of-the-art accuracy while being an order of magnitude smaller and faster than previous models. We develop EfficientNets based on AutoML and Compound Scaling. In particular, we first use the [AutoML Mobile framework](https://ai.googleblog.com/2018/08/mnasnet-towards-automating-design-of.html) to develop a mobile-size baseline network, named EfficientNet-B0; then we use the compound scaling method to scale up this baseline to obtain EfficientNet-B1 to B7.
5 |
17 | Details about the EfficientNet models are below:
18 | | *Name* |*# Params*|*Top-1 Acc.*|
19 | |:-----------------:|:--------:|:----------:|
20 | | `efficientnet-b0` | 5.3M | 76.3 |
21 | | `efficientnet-b1` | 7.8M | 78.8 |
22 | | `efficientnet-b2` | 9.2M | 79.8 |
23 | | `efficientnet-b3` | 12M | 81.1 |
24 | | `efficientnet-b4` | 19M | 82.6 |
25 | | `efficientnet-b5` | 30M | 83.3 |
26 | | `efficientnet-b6` | 43M | 84.0 |
27 | | `efficientnet-b7` | 66M | 84.4 |
28 |
29 | # Data Preparation for face recognition
30 | Download the training data [MS1M](https://github.com/deepinsight/insightface/wiki/Dataset-Zoo); faces are detected by MTCNN and resized to 112x112. If you need to transfer the `.bin` or `.rec` files into images (`.jpg`), run the script `python GetImages.py` under your data folder; note that MXNet must be installed.
31 | # Training strategies and results
32 | a. EfficientNet-b0 (5.3M params) with batch size 80 + ArcFace (s=64, m=0.5) + focal loss (gamma=2)
33 | | LFW(%) | CFP-FF(%) | CFP-FP(%) | AgeDB-30(%) | CALFW(%) | CPLFW(%) | VGG2-FP(%) |
34 | | ------ | --------- | --------- | ----------- | -------- | -------- | ---------- |
35 | | 0.9955 | 0.9940 | 0.9347 | 0.9545 | 0.9532 | 0.8973 | 0.9320 |
36 |
37 | b. EfficientNet-b7 (66M params) with batch size 80 + ArcFace (s=64, m=0.5) + focal loss (gamma=2)
38 | These results are from only 20 training epochs; the pretrained model can be downloaded from [here](https://pan.baidu.com/s/1nhrVz33Bc09E0UNhhMzb1Q) (Baidu drive, code: wkd2) or [here](https://drive.google.com/file/d/1CiveiSBjmKc5__uBrBpJ2orYkg8ZG2CZ/view?usp=sharing) (Google drive).
39 | | LFW(%) | CFP-FF(%) | CFP-FP(%) | AgeDB-30(%) | CALFW(%) | CPLFW(%) | VGG2-FP(%) |
40 | | ------ | --------- | --------- | ----------- | -------- | -------- | ---------- |
41 | | 0.9973 | 0.9967 | 0.9620 | 0.9705 | 0.9553 | 0.9105 | 0.9428 |
42 |
43 | c. Other pretrained models (b1, b2, ..., b6) and their results are being updated...
44 | # PS
45 | If you have questions, post them as GitHub issues.
46 |
--------------------------------------------------------------------------------
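
The training script itself is not included in this listing; the following is a minimal sketch of how strategy (a) above can be assembled from the pieces in this repository. The 512-d embedding size, the identity count, the optimizer, and the use of `num_classes` as the embedding dimension are assumptions, not the repository's actual settings:

```python
import torch
from efficientnet_pytorch import EfficientNet
from partyloss.metrics import ArcMarginProduct
from partyloss.focal_loss import FocalLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_ids = 85742          # assumed identity count for MS1M; use the count from your own list file

# EfficientNet-b0 backbone; num_classes is repurposed here as the embedding size (an assumption)
backbone = EfficientNet.from_name('efficientnet-b0', num_classes=512, image_size=112).to(device)
margin = ArcMarginProduct(512, num_ids, s=64.0, m=0.5).to(device)
criterion = FocalLoss(gamma=2)
optimizer = torch.optim.SGD(
    list(backbone.parameters()) + list(margin.parameters()),
    lr=0.1, momentum=0.9, weight_decay=5e-4)

# stand-in loader; in practice use FaceImageDataset as sketched under datasets/dataset.py
train_loader = [(torch.randn(8, 3, 112, 112), torch.randint(0, num_ids, (8,)))]

for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    embeddings = backbone(images)        # (batch, 512)
    logits = margin(embeddings, labels)  # ArcFace logits, already scaled by s
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
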
/partyloss/metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import Parameter
7 | import math
8 |
9 |
10 | class ArcMarginProduct(nn.Module):
11 |     r"""Implementation of large margin arc distance (ArcFace):
12 | Args:
13 | in_features: size of each input sample
14 | out_features: size of each output sample
15 | s: norm of input feature
16 | m: margin
17 |
18 | cos(theta + m)
19 | """
20 | def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
21 | super(ArcMarginProduct, self).__init__()
22 | self.in_features = in_features
23 | self.out_features = out_features
24 | self.s = s
25 | self.m = m
26 | self.weight = Parameter(torch.FloatTensor(out_features, in_features))
27 | nn.init.xavier_uniform_(self.weight)
28 |
29 | self.easy_margin = easy_margin
30 | self.cos_m = math.cos(m)
31 | self.sin_m = math.sin(m)
32 | self.th = math.cos(math.pi - m)
33 | self.mm = math.sin(math.pi - m) * m
34 |
35 | def forward(self, input, label):
36 | # --------------------------- cos(theta) & phi(theta) ---------------------------
37 | cosine = F.linear(F.normalize(input), F.normalize(self.weight))
38 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
39 | phi = cosine * self.cos_m - sine * self.sin_m
40 | if self.easy_margin:
41 | phi = torch.where(cosine > 0, phi, cosine)
42 | else:
43 | phi = torch.where(cosine > self.th, phi, cosine - self.mm)
44 | # --------------------------- convert label to one-hot ---------------------------
45 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
46 |         one_hot = torch.zeros(cosine.size(), device=cosine.device)
47 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
48 | # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
49 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
50 | output *= self.s
51 | # print(output)
52 |
53 | return output
54 |
55 |
56 | class AddMarginProduct(nn.Module):
57 |     r"""Implementation of large margin cosine distance (CosFace):
58 | Args:
59 | in_features: size of each input sample
60 | out_features: size of each output sample
61 | s: norm of input feature
62 | m: margin
63 | cos(theta) - m
64 | """
65 |
66 | def __init__(self, in_features, out_features, s=30.0, m=0.40):
67 | super(AddMarginProduct, self).__init__()
68 | self.in_features = in_features
69 | self.out_features = out_features
70 | self.s = s
71 | self.m = m
72 | self.weight = Parameter(torch.FloatTensor(out_features, in_features))
73 | nn.init.xavier_uniform_(self.weight)
74 |
75 | def forward(self, input, label):
76 | # --------------------------- cos(theta) & phi(theta) ---------------------------
77 | cosine = F.linear(F.normalize(input), F.normalize(self.weight))
78 | phi = cosine - self.m
79 | # --------------------------- convert label to one-hot ---------------------------
80 |         one_hot = torch.zeros(cosine.size(), device=cosine.device)
81 | # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
82 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
83 | # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
84 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
85 | output *= self.s
86 | # print(output)
87 |
88 | return output
89 |
90 | def __repr__(self):
91 | return self.__class__.__name__ + '(' \
92 | + 'in_features=' + str(self.in_features) \
93 | + ', out_features=' + str(self.out_features) \
94 | + ', s=' + str(self.s) \
95 | + ', m=' + str(self.m) + ')'
96 |
97 |
98 | class SphereProduct(nn.Module):
99 |     r"""Implementation of multiplicative angular margin (SphereFace):
100 | Args:
101 | in_features: size of each input sample
102 | out_features: size of each output sample
103 | m: margin
104 | cos(m*theta)
105 | """
106 | def __init__(self, in_features, out_features, m=4):
107 | super(SphereProduct, self).__init__()
108 | self.in_features = in_features
109 | self.out_features = out_features
110 | self.m = m
111 | self.base = 1000.0
112 | self.gamma = 0.12
113 | self.power = 1
114 | self.LambdaMin = 5.0
115 | self.iter = 0
116 | self.weight = Parameter(torch.FloatTensor(out_features, in_features))
117 |         nn.init.xavier_uniform_(self.weight)
118 |
119 | # duplication formula
120 | self.mlambda = [
121 | lambda x: x ** 0,
122 | lambda x: x ** 1,
123 | lambda x: 2 * x ** 2 - 1,
124 | lambda x: 4 * x ** 3 - 3 * x,
125 | lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
126 | lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
127 | ]
128 |
129 | def forward(self, input, label):
130 | # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power))
131 | self.iter += 1
132 | self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power))
133 |
134 | # --------------------------- cos(theta) & phi(theta) ---------------------------
135 | cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
136 | cos_theta = cos_theta.clamp(-1, 1)
137 | cos_m_theta = self.mlambda[self.m](cos_theta)
138 | theta = cos_theta.data.acos()
139 |         k = (self.m * theta / math.pi).floor()
140 | phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
141 | NormOfFeature = torch.norm(input, 2, 1)
142 |
143 | # --------------------------- convert label to one-hot ---------------------------
144 | one_hot = torch.zeros(cos_theta.size())
145 | one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot
146 |         one_hot.scatter_(1, label.view(-1, 1).long(), 1)
147 |
148 | # --------------------------- Calculate output ---------------------------
149 | output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta
150 | output *= NormOfFeature.view(-1, 1)
151 |
152 | return output
153 |
154 | def __repr__(self):
155 | return self.__class__.__name__ + '(' \
156 | + 'in_features=' + str(self.in_features) \
157 | + ', out_features=' + str(self.out_features) \
158 | + ', m=' + str(self.m) + ')'
159 |
--------------------------------------------------------------------------------
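
A small sanity check (not from the repository): for the target class, `ArcMarginProduct` outputs `s * cos(theta + m)` (unless `cos(theta)` falls below the `cos(pi - m)` threshold handled in `forward`), while non-target classes keep `s * cos(theta)`. This assumes the one-hot mask in `forward` is created on the same device as the input, as in the code above:

```python
import torch
import torch.nn.functional as F
from partyloss.metrics import ArcMarginProduct

torch.manual_seed(0)
margin = ArcMarginProduct(in_features=4, out_features=3, s=64.0, m=0.5)

x = torch.randn(2, 4)
labels = torch.tensor([0, 2])
out = margin(x, labels)                               # (2, 3) scaled logits

# recompute cos(theta) by hand and compare the target-class column of sample 0
cosine = F.linear(F.normalize(x), F.normalize(margin.weight))
theta = torch.acos(cosine.clamp(-1, 1))
expected = 64.0 * torch.cos(theta[0, 0] + 0.5)        # s * cos(theta + m) for the labelled class
print(out[0, 0].item(), expected.item())              # should agree closely
```
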
/efficientnet_pytorch/model.py:
--------------------------------------------------------------------------------
1 | """model.py - Model and module class for EfficientNet.
2 | They are built to mirror those in the official TensorFlow implementation.
3 | """
4 |
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from .utils import (
13 | round_filters,
14 | round_repeats,
15 | drop_connect,
16 | get_same_padding_conv2d,
17 | get_model_params,
18 | efficientnet_params,
19 | load_pretrained_weights,
20 | Swish,
21 | MemoryEfficientSwish,
22 | calculate_output_image_size
23 | )
24 |
25 | class MBConvBlock(nn.Module):
26 | """Mobile Inverted Residual Bottleneck Block.
27 |
28 | Args:
29 | block_args (namedtuple): BlockArgs, defined in utils.py.
30 | global_params (namedtuple): GlobalParam, defined in utils.py.
31 | image_size (tuple or list): [image_height, image_width].
32 |
33 | References:
34 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
35 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
36 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
37 | """
38 |
39 | def __init__(self, block_args, global_params, image_size=None):
40 | super().__init__()
41 | self._block_args = block_args
42 | self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
43 | self._bn_eps = global_params.batch_norm_epsilon
44 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
45 | self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
46 |
47 | # Expansion phase (Inverted Bottleneck)
48 | inp = self._block_args.input_filters # number of input channels
49 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
50 | if self._block_args.expand_ratio != 1:
51 | Conv2d = get_same_padding_conv2d(image_size=image_size)
52 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
53 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
54 | # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
55 |
56 | # Depthwise convolution phase
57 | k = self._block_args.kernel_size
58 | s = self._block_args.stride
59 | Conv2d = get_same_padding_conv2d(image_size=image_size)
60 | self._depthwise_conv = Conv2d(
61 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
62 | kernel_size=k, stride=s, bias=False)
63 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
64 | image_size = calculate_output_image_size(image_size, s)
65 |
66 | # Squeeze and Excitation layer, if desired
67 | if self.has_se:
68 | Conv2d = get_same_padding_conv2d(image_size=(1,1))
69 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
70 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
71 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
72 |
73 | # Pointwise convolution phase
74 | final_oup = self._block_args.output_filters
75 | Conv2d = get_same_padding_conv2d(image_size=image_size)
76 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
77 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
78 | self._swish = MemoryEfficientSwish()
79 |
80 | def forward(self, inputs, drop_connect_rate=None):
81 | """MBConvBlock's forward function.
82 |
83 | Args:
84 | inputs (tensor): Input tensor.
85 |             drop_connect_rate (float): Drop connect rate, between 0 and 1.
86 |
87 | Returns:
88 | Output of this block after processing.
89 | """
90 |
91 | # Expansion and Depthwise Convolution
92 | x = inputs
93 | if self._block_args.expand_ratio != 1:
94 | x = self._expand_conv(inputs)
95 | x = self._bn0(x)
96 | x = self._swish(x)
97 |
98 | x = self._depthwise_conv(x)
99 | x = self._bn1(x)
100 | x = self._swish(x)
101 |
102 | # Squeeze and Excitation
103 | if self.has_se:
104 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
105 | x_squeezed = self._se_reduce(x_squeezed)
106 | x_squeezed = self._swish(x_squeezed)
107 | x_squeezed = self._se_expand(x_squeezed)
108 | x = torch.sigmoid(x_squeezed) * x
109 |
110 | # Pointwise Convolution
111 | x = self._project_conv(x)
112 | x = self._bn2(x)
113 |
114 | # Skip connection and drop connect
115 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
116 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
117 | # The combination of skip connection and drop connect brings about stochastic depth.
118 | if drop_connect_rate:
119 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
120 | x = x + inputs # skip connection
121 | return x
122 |
123 | def set_swish(self, memory_efficient=True):
124 | """Sets swish function as memory efficient (for training) or standard (for export).
125 |
126 | Args:
127 | memory_efficient (bool): Whether to use memory-efficient version of swish.
128 | """
129 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
130 |
131 |
132 | class EfficientNet(nn.Module):
133 | """EfficientNet model.
134 | Most easily loaded with the .from_name or .from_pretrained methods.
135 |
136 | Args:
137 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
138 | global_params (namedtuple): A set of GlobalParams shared between blocks.
139 |
140 | References:
141 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
142 |
143 | Example:
144 | >>> import torch
145 |         >>> from efficientnet_pytorch import EfficientNet
146 | >>> inputs = torch.rand(1, 3, 224, 224)
147 | >>> model = EfficientNet.from_pretrained('efficientnet-b0')
148 | >>> model.eval()
149 | >>> outputs = model(inputs)
150 | """
151 |
152 | def __init__(self, blocks_args=None, global_params=None):
153 | super().__init__()
154 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
155 | assert len(blocks_args) > 0, 'block args must be greater than 0'
156 | self._global_params = global_params
157 | self._blocks_args = blocks_args
158 |
159 | # Batch norm parameters
160 | bn_mom = 1 - self._global_params.batch_norm_momentum
161 | bn_eps = self._global_params.batch_norm_epsilon
162 |
163 | # Get stem static or dynamic convolution depending on image size
164 | image_size = global_params.image_size
165 | Conv2d = get_same_padding_conv2d(image_size=image_size)
166 |
167 | # Stem
168 | in_channels = 3 # rgb
169 | out_channels = round_filters(32, self._global_params) # number of output channels
170 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
171 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
172 | image_size = calculate_output_image_size(image_size, 2)
173 |
174 | # Build blocks
175 | self._blocks = nn.ModuleList([])
176 | for block_args in self._blocks_args:
177 |
178 | # Update block input and output filters based on depth multiplier.
179 | block_args = block_args._replace(
180 | input_filters=round_filters(block_args.input_filters, self._global_params),
181 | output_filters=round_filters(block_args.output_filters, self._global_params),
182 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
183 | )
184 |
185 | # The first block needs to take care of stride and filter size increase.
186 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
187 | image_size = calculate_output_image_size(image_size, block_args.stride)
188 | if block_args.num_repeat > 1: # modify block_args to keep same output size
189 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
190 | for _ in range(block_args.num_repeat - 1):
191 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
192 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
193 |
194 | # Head
195 | in_channels = block_args.output_filters # output of final block
196 | out_channels = round_filters(1280, self._global_params)
197 | Conv2d = get_same_padding_conv2d(image_size=image_size)
198 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
199 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
200 |
201 | # Final linear layer
202 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
203 | self._dropout = nn.Dropout(self._global_params.dropout_rate)
204 | self._fc = nn.Linear(out_channels, self._global_params.num_classes)
205 | self._swish = MemoryEfficientSwish()
206 |
207 | def set_swish(self, memory_efficient=True):
208 | """Sets swish function as memory efficient (for training) or standard (for export).
209 |
210 | Args:
211 | memory_efficient (bool): Whether to use memory-efficient version of swish.
212 |
213 | """
214 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
215 | for block in self._blocks:
216 | block.set_swish(memory_efficient)
217 |
218 |
219 | def extract_features(self, inputs):
220 |         """Use convolution layers to extract features.
221 |
222 | Args:
223 | inputs (tensor): Input tensor.
224 |
225 | Returns:
226 | Output of the final convolution
227 | layer in the efficientnet model.
228 | """
229 | # Stem
230 | x = self._swish(self._bn0(self._conv_stem(inputs)))
231 |
232 | # Blocks
233 | for idx, block in enumerate(self._blocks):
234 | drop_connect_rate = self._global_params.drop_connect_rate
235 | if drop_connect_rate:
236 | drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
237 | x = block(x, drop_connect_rate=drop_connect_rate)
238 |
239 | # Head
240 | x = self._swish(self._bn1(self._conv_head(x)))
241 |
242 | return x
243 |
244 | def forward(self, inputs):
245 | """EfficientNet's forward function.
246 | Calls extract_features to extract features, applies final linear layer, and returns logits.
247 |
248 | Args:
249 | inputs (tensor): Input tensor.
250 |
251 | Returns:
252 | Output of this model after processing.
253 | """
254 | bs = inputs.size(0)
255 |
256 | # Convolution layers
257 | x = self.extract_features(inputs)
258 |
259 | # Pooling and final linear layer
260 | x = self._avg_pooling(x)
261 | x = x.view(bs, -1)
262 | x = self._dropout(x)
263 | x = self._fc(x)
264 |
265 |
266 | return x
267 |
268 | @classmethod
269 | def from_name(cls, model_name, in_channels=3, **override_params):
270 | """create an efficientnet model according to name.
271 |
272 | Args:
273 | model_name (str): Name for efficientnet.
274 | in_channels (int): Input data's channel number.
275 | override_params (other key word params):
276 | Params to override model's global_params.
277 | Optional key:
278 | 'width_coefficient', 'depth_coefficient',
279 | 'image_size', 'dropout_rate',
280 | 'num_classes', 'batch_norm_momentum',
281 | 'batch_norm_epsilon', 'drop_connect_rate',
282 | 'depth_divisor', 'min_depth'
283 |
284 | Returns:
285 | An efficientnet model.
286 | """
287 | cls._check_model_name_is_valid(model_name)
288 | blocks_args, global_params = get_model_params(model_name, override_params)
289 | model = cls(blocks_args, global_params)
290 | model._change_in_channels(in_channels)
291 | return model
292 |
293 | @classmethod
294 | def from_pretrained(cls, model_name, weights_path=None, advprop=False,
295 | in_channels=3, num_classes=1000, **override_params):
296 | """create an efficientnet model according to name.
297 |
298 | Args:
299 | model_name (str): Name for efficientnet.
300 | weights_path (None or str):
301 | str: path to pretrained weights file on the local disk.
302 | None: use pretrained weights downloaded from the Internet.
303 | advprop (bool):
304 | Whether to load pretrained weights
305 | trained with advprop (valid when weights_path is None).
306 | in_channels (int): Input data's channel number.
307 | num_classes (int):
308 | Number of categories for classification.
309 | It controls the output size for final linear layer.
310 | override_params (other key word params):
311 | Params to override model's global_params.
312 | Optional key:
313 | 'width_coefficient', 'depth_coefficient',
314 | 'image_size', 'dropout_rate',
315 | 'num_classes', 'batch_norm_momentum',
316 | 'batch_norm_epsilon', 'drop_connect_rate',
317 | 'depth_divisor', 'min_depth'
318 |
319 | Returns:
320 | A pretrained efficientnet model.
321 | """
322 | model = cls.from_name(model_name, num_classes = num_classes, **override_params)
323 | load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
324 | model._change_in_channels(in_channels)
325 | return model
326 |
327 | @classmethod
328 | def get_image_size(cls, model_name):
329 | """Get the input image size for a given efficientnet model.
330 |
331 | Args:
332 | model_name (str): Name for efficientnet.
333 |
334 | Returns:
335 | Input image size (resolution).
336 | """
337 | cls._check_model_name_is_valid(model_name)
338 | _, _, res, _ = efficientnet_params(model_name)
339 | return res
340 |
341 | @classmethod
342 | def _check_model_name_is_valid(cls, model_name):
343 | """Validates model name.
344 |
345 | Args:
346 | model_name (str): Name for efficientnet.
347 |
348 | Returns:
349 | bool: Is a valid name or not.
350 | """
351 | valid_models = ['efficientnet-b'+str(i) for i in range(9)]
352 |
353 | # Support the construction of 'efficientnet-l2' without pretrained weights
354 | valid_models += ['efficientnet-l2']
355 |
356 | if model_name not in valid_models:
357 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
358 |
359 | def _change_in_channels(self, in_channels):
360 |         """Adjust model's first convolution layer to in_channels, if in_channels does not equal 3.
361 |
362 | Args:
363 | in_channels (int): Input data's channel number.
364 | """
365 | if in_channels != 3:
366 | Conv2d = get_same_padding_conv2d(image_size = self._global_params.image_size)
367 | out_channels = round_filters(32, self._global_params)
368 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
369 |
--------------------------------------------------------------------------------
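
As a quick reference (not part of the repository), `extract_features` returns the final convolutional feature map before pooling and the fully connected layer; a shape check for a 112x112 face crop, assuming the README's input size:

```python
import torch
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_name('efficientnet-b0', image_size=112)
model.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 112, 112)
    features = model.extract_features(x)    # final conv feature map, before pooling and _fc
    logits = model(x)                       # pooling + dropout + fc on top of the same features

print(features.shape)   # torch.Size([1, 1280, 4, 4]): 112 is halved by the stem and four stride-2 stages
print(logits.shape)     # torch.Size([1, 1000]) with the default num_classes
```
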
/efficientnet_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | """utils.py - Helper functions for building the model and for loading model parameters.
2 | These helper functions are built to mirror those in the official TensorFlow implementation.
3 | """
4 |
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 |
9 | import re
10 | import math
11 | import collections
12 | from functools import partial
13 | import torch
14 | from torch import nn
15 | from torch.nn import functional as F
16 | from torch.utils import model_zoo
17 |
18 |
19 | ################################################################################
20 | ### Helper functions for model architecture
21 | ################################################################################
22 |
23 | # GlobalParams and BlockArgs: Two namedtuples
24 | # Swish and MemoryEfficientSwish: Two implementations of the method
25 | # round_filters and round_repeats:
26 | # Functions to calculate params for scaling model width and depth ! ! !
27 | # get_width_and_height_from_size and calculate_output_image_size
28 | # drop_connect: A structural design
29 | # get_same_padding_conv2d:
30 | # Conv2dDynamicSamePadding
31 | # Conv2dStaticSamePadding
32 | # get_same_padding_maxPool2d:
33 | # MaxPool2dDynamicSamePadding
34 | # MaxPool2dStaticSamePadding
35 | # It's an additional function, not used in EfficientNet,
36 | # but can be used in other model (such as EfficientDet).
37 | # Identity: An implementation of identical mapping
38 |
39 | # Parameters for the entire model (stem, all blocks, and head)
40 | GlobalParams = collections.namedtuple('GlobalParams', [
41 | 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
42 | 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
43 | 'drop_connect_rate', 'depth_divisor', 'min_depth'])
44 |
45 | # Parameters for an individual model block
46 | BlockArgs = collections.namedtuple('BlockArgs', [
47 | 'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
48 | 'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
49 |
50 | # Set GlobalParams and BlockArgs's defaults
51 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
52 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
53 |
54 |
55 | # An ordinary implementation of Swish function
56 | class Swish(nn.Module):
57 | def forward(self, x):
58 | return x * torch.sigmoid(x)
59 |
60 |
61 | # A memory-efficient implementation of Swish function
62 | class SwishImplementation(torch.autograd.Function):
63 | @staticmethod
64 | def forward(ctx, i):
65 | result = i * torch.sigmoid(i)
66 | ctx.save_for_backward(i)
67 | return result
68 |
69 | @staticmethod
70 | def backward(ctx, grad_output):
71 |         i = ctx.saved_tensors[0]
72 | sigmoid_i = torch.sigmoid(i)
73 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
74 |
75 | class MemoryEfficientSwish(nn.Module):
76 | def forward(self, x):
77 | return SwishImplementation.apply(x)
78 |
79 |
80 | def round_filters(filters, global_params):
81 | """Calculate and round number of filters based on width multiplier.
82 | Use width_coefficient, depth_divisor and min_depth of global_params.
83 |
84 | Args:
85 | filters (int): Filters number to be calculated.
86 | global_params (namedtuple): Global params of the model.
87 |
88 | Returns:
89 | new_filters: New filters number after calculating.
90 | """
91 | multiplier = global_params.width_coefficient
92 | if not multiplier:
93 | return filters
94 | # TODO: modify the params names.
95 | # maybe the names (width_divisor,min_width)
96 | # are more suitable than (depth_divisor,min_depth).
97 | divisor = global_params.depth_divisor
98 | min_depth = global_params.min_depth
99 | filters *= multiplier
100 | min_depth = min_depth or divisor # pay attention to this line when using min_depth
101 | # follow the formula transferred from official TensorFlow implementation
102 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
103 | if new_filters < 0.9 * filters: # prevent rounding by more than 10%
104 | new_filters += divisor
105 | return int(new_filters)
106 |
107 |
108 | def round_repeats(repeats, global_params):
109 | """Calculate module's repeat number of a block based on depth multiplier.
110 | Use depth_coefficient of global_params.
111 |
112 | Args:
113 | repeats (int): num_repeat to be calculated.
114 | global_params (namedtuple): Global params of the model.
115 |
116 | Returns:
117 | new repeat: New repeat number after calculating.
118 | """
119 | multiplier = global_params.depth_coefficient
120 | if not multiplier:
121 | return repeats
122 | # follow the formula transferred from official TensorFlow implementation
123 | return int(math.ceil(multiplier * repeats))
124 |
125 |
126 | def drop_connect(inputs, p, training):
127 | """Drop connect.
128 |
129 | Args:
130 |         inputs (tensor: BCHW): Input of this structure.
131 | p (float: 0.0~1.0): Probability of drop connection.
132 | training (bool): The running mode.
133 |
134 | Returns:
135 | output: Output after drop connection.
136 | """
137 | assert p >= 0 and p <= 1, 'p must be in range of [0,1]'
138 |
139 | if not training:
140 | return inputs
141 |
142 | batch_size = inputs.shape[0]
143 | keep_prob = 1 - p
144 |
145 | # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
146 | random_tensor = keep_prob
147 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
148 | binary_tensor = torch.floor(random_tensor)
149 |
150 | output = inputs / keep_prob * binary_tensor
151 | return output
152 |
153 |
154 | def get_width_and_height_from_size(x):
155 | """Obtain height and width from x.
156 |
157 | Args:
158 | x (int, tuple or list): Data size.
159 |
160 | Returns:
161 | size: A tuple or list (H,W).
162 | """
163 | if isinstance(x, int):
164 | return x, x
165 | if isinstance(x, list) or isinstance(x, tuple):
166 | return x
167 | else:
168 | raise TypeError()
169 |
170 |
171 | def calculate_output_image_size(input_image_size, stride):
172 | """Calculates the output image size when using Conv2dSamePadding with a stride.
173 | Necessary for static padding. Thanks to mannatsingh for pointing this out.
174 |
175 | Args:
176 | input_image_size (int, tuple or list): Size of input image.
177 | stride (int, tuple or list): Conv2d operation's stride.
178 |
179 | Returns:
180 | output_image_size: A list [H,W].
181 | """
182 | if input_image_size is None:
183 | return None
184 | image_height, image_width = get_width_and_height_from_size(input_image_size)
185 | stride = stride if isinstance(stride, int) else stride[0]
186 | image_height = int(math.ceil(image_height / stride))
187 | image_width = int(math.ceil(image_width / stride))
188 | return [image_height, image_width]
189 |
190 |
191 | # Note:
192 | # The following 'SamePadding' functions make output size equal ceil(input size/stride).
193 | # Only when stride equals 1, can the output size be the same as input size.
194 | # Don't be confused by their function names ! ! !
195 |
196 | def get_same_padding_conv2d(image_size=None):
197 | """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
198 | Static padding is necessary for ONNX exporting of models.
199 |
200 | Args:
201 | image_size (int or tuple): Size of the image.
202 |
203 | Returns:
204 | Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
205 | """
206 | if image_size is None:
207 | return Conv2dDynamicSamePadding
208 | else:
209 | return partial(Conv2dStaticSamePadding, image_size=image_size)
210 |
211 |
212 | class Conv2dDynamicSamePadding(nn.Conv2d):
213 | """2D Convolutions like TensorFlow, for a dynamic image size.
214 | The padding is operated in forward function by calculating dynamically.
215 | """
216 |
217 | # Tips for 'SAME' mode padding.
218 | # Given the following:
219 | # i: width or height
220 | # s: stride
221 | # k: kernel size
222 | # d: dilation
223 | # p: padding
224 | # Output after Conv2d:
225 | # o = floor((i+p-((k-1)*d+1))/s+1)
226 | # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
227 | # => p = (i-1)*s+((k-1)*d+1)-i
228 |
229 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
230 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
231 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
232 |
233 | def forward(self, x):
234 | ih, iw = x.size()[-2:]
235 | kh, kw = self.weight.size()[-2:]
236 | sh, sw = self.stride
237 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! !
238 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
239 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
240 | if pad_h > 0 or pad_w > 0:
241 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
242 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
243 |
244 |
245 | class Conv2dStaticSamePadding(nn.Conv2d):
246 | """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
247 |     The padding module is calculated in the construction function, then used in forward.
248 | """
249 |
250 | # With the same calculation as Conv2dDynamicSamePadding
251 |
252 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
253 | super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
254 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
255 |
256 | # Calculate padding based on image size and save it
257 | assert image_size is not None
258 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
259 | kh, kw = self.weight.size()[-2:]
260 | sh, sw = self.stride
261 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
262 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
263 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
264 | if pad_h > 0 or pad_w > 0:
265 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
266 | else:
267 | self.static_padding = Identity()
268 |
269 | def forward(self, x):
270 | x = self.static_padding(x)
271 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
272 | return x
273 |
274 |
275 | def get_same_padding_maxPool2d(image_size=None):
276 | """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
277 | Static padding is necessary for ONNX exporting of models.
278 |
279 | Args:
280 | image_size (int or tuple): Size of the image.
281 |
282 | Returns:
283 | MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
284 | """
285 | if image_size is None:
286 | return MaxPool2dDynamicSamePadding
287 | else:
288 | return partial(MaxPool2dStaticSamePadding, image_size=image_size)
289 |
290 |
291 | class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
292 | """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
293 | The padding is operated in forward function by calculating dynamically.
294 | """
295 |
296 | def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
297 | super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
298 | self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
299 | self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
300 | self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
301 |
302 | def forward(self, x):
303 | ih, iw = x.size()[-2:]
304 | kh, kw = self.kernel_size
305 | sh, sw = self.stride
306 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
307 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
308 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
309 | if pad_h > 0 or pad_w > 0:
310 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
311 | return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
312 | self.dilation, self.ceil_mode, self.return_indices)
313 |
314 | class MaxPool2dStaticSamePadding(nn.MaxPool2d):
315 | """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
316 |     The padding module is calculated in the construction function, then used in forward.
317 | """
318 |
319 | def __init__(self, kernel_size, stride, image_size=None, **kwargs):
320 | super().__init__(kernel_size, stride, **kwargs)
321 | self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
322 | self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
323 | self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
324 |
325 | # Calculate padding based on image size and save it
326 | assert image_size is not None
327 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
328 | kh, kw = self.kernel_size
329 | sh, sw = self.stride
330 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
331 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
332 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
333 | if pad_h > 0 or pad_w > 0:
334 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
335 | else:
336 | self.static_padding = Identity()
337 |
338 | def forward(self, x):
339 | x = self.static_padding(x)
340 | x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
341 | self.dilation, self.ceil_mode, self.return_indices)
342 | return x
343 |
344 | class Identity(nn.Module):
345 | """Identity mapping.
346 | Send input to output directly.
347 | """
348 |
349 | def __init__(self):
350 | super(Identity, self).__init__()
351 |
352 | def forward(self, input):
353 | return input
354 |
355 |
356 | ################################################################################
357 | ### Helper functions for loading model params
358 | ################################################################################
359 |
360 | # BlockDecoder: A Class for encoding and decoding BlockArgs
361 | # efficientnet_params: A function to query compound coefficient
362 | # get_model_params and efficientnet:
363 | # Functions to get BlockArgs and GlobalParams for efficientnet
364 | # url_map and url_map_advprop: Dicts of url_map for pretrained weights
365 | # load_pretrained_weights: A function to load pretrained weights
366 |
367 | class BlockDecoder(object):
368 | """Block Decoder for readability,
369 | straight from the official TensorFlow repository.
370 | """
371 |
372 | @staticmethod
373 | def _decode_block_string(block_string):
374 | """Get a block through a string notation of arguments.
375 |
376 | Args:
377 | block_string (str): A string notation of arguments.
378 | Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
379 |
380 | Returns:
381 | BlockArgs: The namedtuple defined at the top of this file.
382 | """
383 | assert isinstance(block_string, str)
384 |
385 | ops = block_string.split('_')
386 | options = {}
387 | for op in ops:
388 | splits = re.split(r'(\d.*)', op)
389 | if len(splits) >= 2:
390 | key, value = splits[:2]
391 | options[key] = value
392 |
393 | # Check stride
394 | assert (('s' in options and len(options['s']) == 1) or
395 | (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
396 |
397 | return BlockArgs(
398 | num_repeat=int(options['r']),
399 | kernel_size=int(options['k']),
400 | stride=[int(options['s'][0])],
401 | expand_ratio=int(options['e']),
402 | input_filters=int(options['i']),
403 | output_filters=int(options['o']),
404 | se_ratio=float(options['se']) if 'se' in options else None,
405 | id_skip=('noskip' not in block_string))
406 |
407 | @staticmethod
408 | def _encode_block_string(block):
409 | """Encode a block to a string.
410 |
411 | Args:
412 | block (namedtuple): A BlockArgs type argument.
413 |
414 | Returns:
415 | block_string: A String form of BlockArgs.
416 | """
417 | args = [
418 | 'r%d' % block.num_repeat,
419 | 'k%d' % block.kernel_size,
420 |         's%d%d' % (block.stride[0], block.stride[0]),  # stride is stored as a one-element list by _decode_block_string
421 | 'e%s' % block.expand_ratio,
422 | 'i%d' % block.input_filters,
423 | 'o%d' % block.output_filters
424 | ]
425 |         if block.se_ratio is not None and 0 < block.se_ratio <= 1:
426 | args.append('se%s' % block.se_ratio)
427 | if block.id_skip is False:
428 | args.append('noskip')
429 | return '_'.join(args)
430 |
431 | @staticmethod
432 | def decode(string_list):
433 | """Decode a list of string notations to specify blocks inside the network.
434 |
435 | Args:
436 | string_list (list[str]): A list of strings, each string is a notation of block.
437 |
438 | Returns:
439 | blocks_args: A list of BlockArgs namedtuples of block args.
440 | """
441 | assert isinstance(string_list, list)
442 | blocks_args = []
443 | for block_string in string_list:
444 | blocks_args.append(BlockDecoder._decode_block_string(block_string))
445 | return blocks_args
446 |
447 | @staticmethod
448 | def encode(blocks_args):
449 | """Encode a list of BlockArgs to a list of strings.
450 |
451 | Args:
452 | blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
453 |
454 | Returns:
455 | block_strings: A list of strings, each string is a notation of block.
456 | """
457 | block_strings = []
458 | for block in blocks_args:
459 | block_strings.append(BlockDecoder._encode_block_string(block))
460 | return block_strings
461 |
462 |
463 | def efficientnet_params(model_name):
464 | """Map EfficientNet model name to parameter coefficients.
465 |
466 | Args:
467 | model_name (str): Model name to be queried.
468 |
469 | Returns:
470 | params_dict[model_name]: A (width,depth,res,dropout) tuple.
471 | """
472 | params_dict = {
473 | # Coefficients: width,depth,res,dropout
474 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
475 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
476 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
477 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
478 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
479 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
480 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
481 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
482 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
483 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
484 | }
485 | return params_dict[model_name]
486 |
487 |
488 | def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
489 | dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000):
490 | """Create BlockArgs and GlobalParams for efficientnet model.
491 |
492 | Args:
493 | width_coefficient (float)
494 | depth_coefficient (float)
495 | image_size (int)
496 | dropout_rate (float)
497 | drop_connect_rate (float)
498 | num_classes (int)
499 |
500 | Meaning as the name suggests.
501 |
502 | Returns:
503 | blocks_args, global_params.
504 | """
505 |
506 | # Blocks args for the whole model(efficientnet-b0 by default)
507 | # It will be modified in the construction of EfficientNet Class according to model
508 | blocks_args = [
509 | 'r1_k3_s11_e1_i32_o16_se0.25',
510 | 'r2_k3_s22_e6_i16_o24_se0.25',
511 | 'r2_k5_s22_e6_i24_o40_se0.25',
512 | 'r3_k3_s22_e6_i40_o80_se0.25',
513 | 'r3_k5_s11_e6_i80_o112_se0.25',
514 | 'r4_k5_s22_e6_i112_o192_se0.25',
515 | 'r1_k3_s11_e6_i192_o320_se0.25',
516 | ]
517 | blocks_args = BlockDecoder.decode(blocks_args)
518 |
519 | global_params = GlobalParams(
520 | width_coefficient=width_coefficient,
521 | depth_coefficient=depth_coefficient,
522 | image_size=image_size,
523 | dropout_rate=dropout_rate,
524 |
525 | num_classes=num_classes,
526 | batch_norm_momentum=0.99,
527 | batch_norm_epsilon=1e-3,
528 | drop_connect_rate=drop_connect_rate,
529 | depth_divisor=8,
530 | min_depth=None,
531 | )
532 |
533 | return blocks_args, global_params
534 |
535 |
536 | def get_model_params(model_name, override_params):
537 | """Get the block args and global params for a given model name.
538 |
539 | Args:
540 | model_name (str): Model's name.
541 | override_params (dict): A dict to modify global_params.
542 |
543 | Returns:
544 | blocks_args, global_params
545 | """
546 | if model_name.startswith('efficientnet'):
547 | w, d, s, p = efficientnet_params(model_name)
548 | # note: all models have drop connect rate = 0.2
549 | blocks_args, global_params = efficientnet(
550 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
551 | else:
552 | raise NotImplementedError('model name is not pre-defined: %s' % model_name)
553 | if override_params:
554 | # ValueError will be raised here if override_params has fields not included in global_params.
555 | global_params = global_params._replace(**override_params)
556 | return blocks_args, global_params
557 |
558 |
559 | # train with Standard methods
560 | # check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
561 | url_map = {
562 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
563 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
564 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
565 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
566 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
567 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
568 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
569 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
570 | }
571 |
572 | # train with Adversarial Examples(AdvProp)
573 | # check more details in paper(Adversarial Examples Improve Image Recognition)
574 | url_map_advprop = {
575 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
576 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
577 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
578 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
579 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
580 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
581 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
582 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
583 | 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
584 | }
585 |
586 | # TODO: add the pretrained weights url map of 'efficientnet-l2'
587 |
588 |
589 | def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False):
590 | """Loads pretrained weights from weights path or download using url.
591 |
592 | Args:
593 | model (Module): The whole model of efficientnet.
594 | model_name (str): Model name of efficientnet.
595 | weights_path (None or str):
596 | str: path to pretrained weights file on the local disk.
597 | None: use pretrained weights downloaded from the Internet.
598 | load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
599 | advprop (bool): Whether to load pretrained weights
600 | trained with advprop (valid when weights_path is None).
601 | """
602 | if isinstance(weights_path,str):
603 | state_dict = torch.load(weights_path)
604 | else:
605 | # AutoAugment or Advprop (different preprocessing)
606 | url_map_ = url_map_advprop if advprop else url_map
607 | state_dict = model_zoo.load_url(url_map_[model_name])
608 |
609 | if load_fc:
610 | ret = model.load_state_dict(state_dict, strict=False)
611 | # assert not ret.missing_keys, f'Missing keys when loading pretrained weights: {ret.missing_keys}'
612 | else:
613 | state_dict.pop('_fc.weight')
614 | state_dict.pop('_fc.bias')
615 | ret = model.load_state_dict(state_dict, strict=False)
616 | # assert set(ret.missing_keys) == set(
617 | # ['_fc.weight', '_fc.bias']), f'Missing keys when loading pretrained weights: {ret.missing_keys}'
618 | # assert not ret.unexpected_keys, f'Missing keys when loading pretrained weights: {ret.unexpected_keys}'
619 |
620 | print('Loaded pretrained weights for {}'.format(model_name))
621 |
--------------------------------------------------------------------------------
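
A worked example of the 'SAME' padding rule described in the comments above (a sketch, not part of the repository): with i=112, k=3, s=2, d=1 we get o = ceil(112/2) = 56 and p = (56-1)*2 + ((3-1)*1+1) - 112 = 1, i.e. a single extra row/column of zeros keeps the output at exactly ceil(input/stride):

```python
import torch
from efficientnet_pytorch.utils import Conv2dDynamicSamePadding, Conv2dStaticSamePadding

x = torch.rand(1, 3, 112, 112)

dynamic = Conv2dDynamicSamePadding(3, 8, kernel_size=3, stride=2, bias=False)
static = Conv2dStaticSamePadding(3, 8, kernel_size=3, stride=2, image_size=112, bias=False)

print(dynamic(x).shape)   # torch.Size([1, 8, 56, 56]) == ceil(112 / 2)
print(static(x).shape)    # torch.Size([1, 8, 56, 56])
```
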