├── efficientnet_pytorch ├── __init__.py ├── model.py └── utils.py ├── partyloss ├── focal_loss.py └── metrics.py ├── GetImages.py ├── datasets └── dataset.py └── README.md /efficientnet_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.4" 2 | from .model import EfficientNet 3 | from .utils import ( 4 | GlobalParams, 5 | BlockArgs, 6 | BlockDecoder, 7 | efficientnet, 8 | get_model_params, 9 | ) 10 | 11 | -------------------------------------------------------------------------------- /partyloss/focal_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 18-6-7 上午10:11 4 | 5 | @author: ronghuaiyang 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class FocalLoss(nn.Module): 13 | 14 | def __init__(self, gamma=0, eps=1e-7): 15 | super(FocalLoss, self).__init__() 16 | self.gamma = gamma 17 | self.eps = eps 18 | self.ce = torch.nn.CrossEntropyLoss() 19 | 20 | def forward(self, input, target): 21 | logp = self.ce(input, target) 22 | p = torch.exp(-logp) 23 | loss = (1 - p) ** self.gamma * logp 24 | return loss.mean() -------------------------------------------------------------------------------- /GetImages.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from torch.utils.data import Dataset, ConcatDataset, DataLoader 3 | from torchvision import transforms as trans 4 | from torchvision.datasets import ImageFolder 5 | from PIL import Image, ImageFile 6 | ImageFile.LOAD_TRUNCATED_IMAGES = True 7 | import numpy as np 8 | import cv2 9 | import bcolz 10 | import pickle 11 | import torch 12 | import mxnet as mx 13 | from tqdm import tqdm 14 | import os 15 | 16 | 17 | 18 | def load_mx_rec(rec_path,savefold): 19 | save_path = savefold+'/imgs' 20 | if not os.path.exists(save_path): 21 | os.makedirs(save_path) 22 | imgrec = 
mx.recordio.MXIndexedRecordIO(str(rec_path+'/'+'train.idx'), str(rec_path+'/'+'train.rec'), 'r') 23 | img_info = imgrec.read_idx(0) 24 | header,_ = mx.recordio.unpack(img_info) 25 | max_idx = int(header.label[0]) 26 | for idx in tqdm(range(1,max_idx)): 27 | img_info = imgrec.read_idx(idx) 28 | header, img = mx.recordio.unpack_img(img_info) 29 | label = int(header.label) 30 | img = Image.fromarray(img) 31 | label_path = save_path+'/'+str(label) 32 | if not os.path.exists(label_path): 33 | os.makedirs(label_path) 34 | img.save(label_path+'/'+'{}.jpg'.format(idx), quality=100) 35 | 36 | def load_bin(path, savepath, image_size=[112,112]): 37 | if not os.path.exists(savepath): 38 | os.makedirs(savepath) 39 | bins, issame_list = pickle.load(open(path, 'rb'), encoding='bytes') 40 | for i in range(len(bins)): 41 | _bin = bins[i] 42 | img = mx.image.imdecode(_bin).asnumpy() 43 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 44 | img = Image.fromarray(img.astype(np.uint8)) 45 | img.save(savepath+'/'+'{}.jpg'.format(i), quality=100) 46 | np.save(savepath.split('/')[-1]+'_list', np.array(issame_list)) 47 | 48 | if __name__ == '__main__': 49 | mainfold = './faces_emore' 50 | bin_files = ['agedb_30', 'cfp_fp', 'lfw', 'calfw', 'cfp_ff', 'cplfw', 'vgg2_fp'] 51 | savefold = './faces_emore_imgs' 52 | load_mx_rec(mainfold,savefold) 53 | for i in range(len(bin_files)): 54 | load_bin(mainfold+'/'+bin_files[i]+'.bin', savepath = savefold+'/'+bin_files[i], image_size=[112,112]) 55 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | 10 | class SampleProperty(object): 11 | def __init__(self, row): 12 | self._sample = row.strip().split(' ') 13 | @property 14 | def path(self): 15 | return self._sample[0] 16 | 17 | @property 18 | 
def label(self): 19 | return int(self._sample[1]) 20 | 21 | class FaceImageDataset(Dataset): 22 | def __init__(self, data_root, list_file_root, modality, transform=None): 23 | self.data_root = data_root 24 | self.list_file_root = list_file_root 25 | self.modality = modality 26 | self.transform = transform 27 | self.Sample_List = [SampleProperty(x) for x in open(self.list_file_root)] 28 | 29 | def __getitem__(self, idx): 30 | img_path = self.Sample_List[idx].path 31 | label = self.Sample_List[idx].label 32 | 33 | if self.modality == 'RGB': 34 | image = Image.open(os.path.join(self.data_root, img_path)).convert('RGB') 35 | if self.modality == 'Gray': 36 | image = Image.open(os.path.join(self.data_root, img_path)).convert('L') 37 | if self.transform is not None: 38 | image = self.transform(image)###C H W 39 | label=torch.tensor(label) 40 | 41 | return image,label 42 | 43 | def __len__(self): 44 | return len(self.Sample_List) 45 | 46 | 47 | 48 | 49 | class FaceImagePiarDataset(Dataset): 50 | def __init__(self, data_root, list_file_root, modality, transform=None): 51 | self.data_root = data_root 52 | self.list_file_root = list_file_root 53 | self.modality = modality 54 | self.transform = transform 55 | self.Sample_List = [line.strip().split(' ') for line in open(self.list_file_root)] 56 | 57 | def __getitem__(self, idx): 58 | img_path1 = self.Sample_List[idx][0] 59 | img_path2 = self.Sample_List[idx][1] 60 | label = int(self.Sample_List[idx][2]) 61 | 62 | if self.modality == 'RGB': 63 | image1 = Image.open(os.path.join(self.data_root, img_path1)).convert('RGB') 64 | image2 = Image.open(os.path.join(self.data_root, img_path2)).convert('RGB') 65 | if self.modality == 'Gray': 66 | image1 = Image.open(os.path.join(self.data_root, img_path1)).convert('L') 67 | image2 = Image.open(os.path.join(self.data_root, img_path2)).convert('L') 68 | if self.transform is not None: 69 | image1 = self.transform(image1)###C H W 70 | image2 = self.transform(image2)###C H W 71 | 
label=torch.tensor(label) 72 | 73 | return image1,image2,label 74 | 75 | def __len__(self): 76 | return len(self.Sample_List) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Insightface_EfficientNet 2 | Pytorch implements the Deep Face Recognition part of Insightface([github](https://github.com/deepinsight/insightface)) with a backbone of EfficientNet([github](https://github.com/lukemelas/EfficientNet-PyTorch)). 3 | # About EfficientNet 4 | Official explanation: EfficientNets are a family of image classification models, which achieve state-of-the-art accuracy, yet being an order-of-magnitude smaller and faster than previous models. We develop EfficientNets based on AutoML and Compound Scaling. In particular, we first use [AutoML Mobile framework](https://ai.googleblog.com/2018/08/mnasnet-towards-automating-design-of.html) to develop a mobile-size baseline network, named as EfficientNet-B0; Then, we use the compound scaling method to scale up this baseline to obtain EfficientNet-B1 to B7. 5 | 6 | 7 | 8 | 11 | 14 | 15 |
9 | 10 | 12 | 13 |
Details about the EfficientNet models are below:

| *Name* |*# Params*|*Top-1 Acc.*|
|:-----------------:|:--------:|:----------:|
| `efficientnet-b0` | 5.3M | 76.3 |
| `efficientnet-b1` | 7.8M | 78.8 |
| `efficientnet-b2` | 9.2M | 79.8 |
| `efficientnet-b3` | 12M | 81.1 |
| `efficientnet-b4` | 19M | 82.6 |
| `efficientnet-b5` | 30M | 83.3 |
| `efficientnet-b6` | 43M | 84.0 |
| `efficientnet-b7` | 66M | 84.4 |

# Data Preparation for face recognition
Download the training data [MS1M](https://github.com/deepinsight/insightface/wiki/Dataset-Zoo); faces are detected by MTCNN and resized to 112x112. To convert the `.bin` or `.rec` files into images (`.jpg`), run `python GetImages.py` from the directory that contains your `faces_emore` data folder. Note that MXNet must be installed.
# Training strategies and results
Verification accuracy (as a fraction) on each benchmark:

a. EfficientNet (b0, 5.3M params) with batch size 80 + ArcFace (s=64, m=0.5) + focal loss (gamma=2)

| LFW | CFP-FF | CFP-FP | AgeDB-30 | calfw | cplfw | vgg2_fp |
| ------ | --------- | --------- | ----------- | -------- | -------- | ---------- |
| 0.9955 | 0.9940 | 0.9347 | 0.9545 | 0.9532 | 0.8973 | 0.9320 |

b. EfficientNet (b7, 66M params) with batch size 80 + ArcFace (s=64, m=0.5) + focal loss (gamma=2)
These results come from only 20 training epochs. The pretrained model can be downloaded [here](https://pan.baidu.com/s/1nhrVz33Bc09E0UNhhMzb1Q) (Baidu Drive, code: wkd2) or [here](https://drive.google.com/file/d/1CiveiSBjmKc5__uBrBpJ2orYkg8ZG2CZ/view?usp=sharing) (Google Drive).

| LFW | CFP-FF | CFP-FP | AgeDB-30 | calfw | cplfw | vgg2_fp |
| ------ | --------- | --------- | ----------- | -------- | -------- | ---------- |
| 0.9973 | 0.9967 | 0.9620 | 0.9705 | 0.9553 | 0.9105 | 0.9428 |

c. Pretrained models and results for b1, b2, ..., b6 are coming soon.
# PS
If you have questions, post them as GitHub issues.
46 | -------------------------------------------------------------------------------- /partyloss/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import Parameter 7 | import math 8 | 9 | 10 | class ArcMarginProduct(nn.Module): 11 | r"""Implement of large margin arc distance: : 12 | Args: 13 | in_features: size of each input sample 14 | out_features: size of each output sample 15 | s: norm of input feature 16 | m: margin 17 | 18 | cos(theta + m) 19 | """ 20 | def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False): 21 | super(ArcMarginProduct, self).__init__() 22 | self.in_features = in_features 23 | self.out_features = out_features 24 | self.s = s 25 | self.m = m 26 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 27 | nn.init.xavier_uniform_(self.weight) 28 | 29 | self.easy_margin = easy_margin 30 | self.cos_m = math.cos(m) 31 | self.sin_m = math.sin(m) 32 | self.th = math.cos(math.pi - m) 33 | self.mm = math.sin(math.pi - m) * m 34 | 35 | def forward(self, input, label): 36 | # --------------------------- cos(theta) & phi(theta) --------------------------- 37 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 38 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 39 | phi = cosine * self.cos_m - sine * self.sin_m 40 | if self.easy_margin: 41 | phi = torch.where(cosine > 0, phi, cosine) 42 | else: 43 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 44 | # --------------------------- convert label to one-hot --------------------------- 45 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 46 | one_hot = torch.zeros(cosine.size(), device='cuda') 47 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 48 | # -------------torch.where(out_i = {x_i 
if condition_i else y_i) ------------- 49 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 50 | output *= self.s 51 | # print(output) 52 | 53 | return output 54 | 55 | 56 | class AddMarginProduct(nn.Module): 57 | r"""Implement of large margin cosine distance: : 58 | Args: 59 | in_features: size of each input sample 60 | out_features: size of each output sample 61 | s: norm of input feature 62 | m: margin 63 | cos(theta) - m 64 | """ 65 | 66 | def __init__(self, in_features, out_features, s=30.0, m=0.40): 67 | super(AddMarginProduct, self).__init__() 68 | self.in_features = in_features 69 | self.out_features = out_features 70 | self.s = s 71 | self.m = m 72 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 73 | nn.init.xavier_uniform_(self.weight) 74 | 75 | def forward(self, input, label): 76 | # --------------------------- cos(theta) & phi(theta) --------------------------- 77 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 78 | phi = cosine - self.m 79 | # --------------------------- convert label to one-hot --------------------------- 80 | one_hot = torch.zeros(cosine.size(), device='cuda') 81 | # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot 82 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 83 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 84 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 85 | output *= self.s 86 | # print(output) 87 | 88 | return output 89 | 90 | def __repr__(self): 91 | return self.__class__.__name__ + '(' \ 92 | + 'in_features=' + str(self.in_features) \ 93 | + ', out_features=' + str(self.out_features) \ 94 | + ', s=' + str(self.s) \ 95 | + ', m=' + str(self.m) + ')' 96 | 97 | 98 | class SphereProduct(nn.Module): 99 | r"""Implement of large margin cosine distance: : 100 | Args: 101 | in_features: size of each input sample 
102 | out_features: size of each output sample 103 | m: margin 104 | cos(m*theta) 105 | """ 106 | def __init__(self, in_features, out_features, m=4): 107 | super(SphereProduct, self).__init__() 108 | self.in_features = in_features 109 | self.out_features = out_features 110 | self.m = m 111 | self.base = 1000.0 112 | self.gamma = 0.12 113 | self.power = 1 114 | self.LambdaMin = 5.0 115 | self.iter = 0 116 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 117 | nn.init.xavier_uniform(self.weight) 118 | 119 | # duplication formula 120 | self.mlambda = [ 121 | lambda x: x ** 0, 122 | lambda x: x ** 1, 123 | lambda x: 2 * x ** 2 - 1, 124 | lambda x: 4 * x ** 3 - 3 * x, 125 | lambda x: 8 * x ** 4 - 8 * x ** 2 + 1, 126 | lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x 127 | ] 128 | 129 | def forward(self, input, label): 130 | # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power)) 131 | self.iter += 1 132 | self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power)) 133 | 134 | # --------------------------- cos(theta) & phi(theta) --------------------------- 135 | cos_theta = F.linear(F.normalize(input), F.normalize(self.weight)) 136 | cos_theta = cos_theta.clamp(-1, 1) 137 | cos_m_theta = self.mlambda[self.m](cos_theta) 138 | theta = cos_theta.data.acos() 139 | k = (self.m * theta / 3.14159265).floor() 140 | phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k 141 | NormOfFeature = torch.norm(input, 2, 1) 142 | 143 | # --------------------------- convert label to one-hot --------------------------- 144 | one_hot = torch.zeros(cos_theta.size()) 145 | one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot 146 | one_hot.scatter_(1, label.view(-1, 1), 1) 147 | 148 | # --------------------------- Calculate output --------------------------- 149 | output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta 150 | output *= NormOfFeature.view(-1, 1) 151 | 152 | return output 153 | 154 | def __repr__(self): 
155 | return self.__class__.__name__ + '(' \ 156 | + 'in_features=' + str(self.in_features) \ 157 | + ', out_features=' + str(self.out_features) \ 158 | + ', m=' + str(self.m) + ')' 159 | -------------------------------------------------------------------------------- /efficientnet_pytorch/model.py: -------------------------------------------------------------------------------- 1 | """model.py - Model and module class for EfficientNet. 2 | They are built to mirror those in the official TensorFlow implementation. 3 | """ 4 | 5 | # Author: lukemelas (github username) 6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch 7 | # With adjustments and added comments by workingcoder (github username). 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from .utils import ( 13 | round_filters, 14 | round_repeats, 15 | drop_connect, 16 | get_same_padding_conv2d, 17 | get_model_params, 18 | efficientnet_params, 19 | load_pretrained_weights, 20 | Swish, 21 | MemoryEfficientSwish, 22 | calculate_output_image_size 23 | ) 24 | 25 | class MBConvBlock(nn.Module): 26 | """Mobile Inverted Residual Bottleneck Block. 27 | 28 | Args: 29 | block_args (namedtuple): BlockArgs, defined in utils.py. 30 | global_params (namedtuple): GlobalParam, defined in utils.py. 31 | image_size (tuple or list): [image_height, image_width]. 
32 | 33 | References: 34 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) 35 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) 36 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) 37 | """ 38 | 39 | def __init__(self, block_args, global_params, image_size=None): 40 | super().__init__() 41 | self._block_args = block_args 42 | self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow 43 | self._bn_eps = global_params.batch_norm_epsilon 44 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 45 | self.id_skip = block_args.id_skip # whether to use skip connection and drop connect 46 | 47 | # Expansion phase (Inverted Bottleneck) 48 | inp = self._block_args.input_filters # number of input channels 49 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 50 | if self._block_args.expand_ratio != 1: 51 | Conv2d = get_same_padding_conv2d(image_size=image_size) 52 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 53 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 54 | # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size 55 | 56 | # Depthwise convolution phase 57 | k = self._block_args.kernel_size 58 | s = self._block_args.stride 59 | Conv2d = get_same_padding_conv2d(image_size=image_size) 60 | self._depthwise_conv = Conv2d( 61 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 62 | kernel_size=k, stride=s, bias=False) 63 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 64 | image_size = calculate_output_image_size(image_size, s) 65 | 66 | # Squeeze and Excitation layer, if desired 67 | if self.has_se: 68 | Conv2d = get_same_padding_conv2d(image_size=(1,1)) 69 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 70 
| self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 71 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 72 | 73 | # Pointwise convolution phase 74 | final_oup = self._block_args.output_filters 75 | Conv2d = get_same_padding_conv2d(image_size=image_size) 76 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 77 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 78 | self._swish = MemoryEfficientSwish() 79 | 80 | def forward(self, inputs, drop_connect_rate=None): 81 | """MBConvBlock's forward function. 82 | 83 | Args: 84 | inputs (tensor): Input tensor. 85 | drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). 86 | 87 | Returns: 88 | Output of this block after processing. 89 | """ 90 | 91 | # Expansion and Depthwise Convolution 92 | x = inputs 93 | if self._block_args.expand_ratio != 1: 94 | x = self._expand_conv(inputs) 95 | x = self._bn0(x) 96 | x = self._swish(x) 97 | 98 | x = self._depthwise_conv(x) 99 | x = self._bn1(x) 100 | x = self._swish(x) 101 | 102 | # Squeeze and Excitation 103 | if self.has_se: 104 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 105 | x_squeezed = self._se_reduce(x_squeezed) 106 | x_squeezed = self._swish(x_squeezed) 107 | x_squeezed = self._se_expand(x_squeezed) 108 | x = torch.sigmoid(x_squeezed) * x 109 | 110 | # Pointwise Convolution 111 | x = self._project_conv(x) 112 | x = self._bn2(x) 113 | 114 | # Skip connection and drop connect 115 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 116 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 117 | # The combination of skip connection and drop connect brings about stochastic depth. 
118 | if drop_connect_rate: 119 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 120 | x = x + inputs # skip connection 121 | return x 122 | 123 | def set_swish(self, memory_efficient=True): 124 | """Sets swish function as memory efficient (for training) or standard (for export). 125 | 126 | Args: 127 | memory_efficient (bool): Whether to use memory-efficient version of swish. 128 | """ 129 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 130 | 131 | 132 | class EfficientNet(nn.Module): 133 | """EfficientNet model. 134 | Most easily loaded with the .from_name or .from_pretrained methods. 135 | 136 | Args: 137 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. 138 | global_params (namedtuple): A set of GlobalParams shared between blocks. 139 | 140 | References: 141 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet) 142 | 143 | Example: 144 | >>> import torch 145 | >>> from efficientnet.model import EfficientNet 146 | >>> inputs = torch.rand(1, 3, 224, 224) 147 | >>> model = EfficientNet.from_pretrained('efficientnet-b0') 148 | >>> model.eval() 149 | >>> outputs = model(inputs) 150 | """ 151 | 152 | def __init__(self, blocks_args=None, global_params=None): 153 | super().__init__() 154 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 155 | assert len(blocks_args) > 0, 'block args must be greater than 0' 156 | self._global_params = global_params 157 | self._blocks_args = blocks_args 158 | 159 | # Batch norm parameters 160 | bn_mom = 1 - self._global_params.batch_norm_momentum 161 | bn_eps = self._global_params.batch_norm_epsilon 162 | 163 | # Get stem static or dynamic convolution depending on image size 164 | image_size = global_params.image_size 165 | Conv2d = get_same_padding_conv2d(image_size=image_size) 166 | 167 | # Stem 168 | in_channels = 3 # rgb 169 | out_channels = round_filters(32, self._global_params) # number of output channels 170 | self._conv_stem = 
Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 171 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 172 | image_size = calculate_output_image_size(image_size, 2) 173 | 174 | # Build blocks 175 | self._blocks = nn.ModuleList([]) 176 | for block_args in self._blocks_args: 177 | 178 | # Update block input and output filters based on depth multiplier. 179 | block_args = block_args._replace( 180 | input_filters=round_filters(block_args.input_filters, self._global_params), 181 | output_filters=round_filters(block_args.output_filters, self._global_params), 182 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 183 | ) 184 | 185 | # The first block needs to take care of stride and filter size increase. 186 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 187 | image_size = calculate_output_image_size(image_size, block_args.stride) 188 | if block_args.num_repeat > 1: # modify block_args to keep same output size 189 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 190 | for _ in range(block_args.num_repeat - 1): 191 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 192 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 193 | 194 | # Head 195 | in_channels = block_args.output_filters # output of final block 196 | out_channels = round_filters(1280, self._global_params) 197 | Conv2d = get_same_padding_conv2d(image_size=image_size) 198 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 199 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 200 | 201 | # Final linear layer 202 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 203 | self._dropout = nn.Dropout(self._global_params.dropout_rate) 204 | self._fc = nn.Linear(out_channels, self._global_params.num_classes) 205 | self._swish = 
MemoryEfficientSwish() 206 | 207 | def set_swish(self, memory_efficient=True): 208 | """Sets swish function as memory efficient (for training) or standard (for export). 209 | 210 | Args: 211 | memory_efficient (bool): Whether to use memory-efficient version of swish. 212 | 213 | """ 214 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 215 | for block in self._blocks: 216 | block.set_swish(memory_efficient) 217 | 218 | 219 | def extract_features(self, inputs): 220 | """use convolution layer to extract feature . 221 | 222 | Args: 223 | inputs (tensor): Input tensor. 224 | 225 | Returns: 226 | Output of the final convolution 227 | layer in the efficientnet model. 228 | """ 229 | # Stem 230 | x = self._swish(self._bn0(self._conv_stem(inputs))) 231 | 232 | # Blocks 233 | for idx, block in enumerate(self._blocks): 234 | drop_connect_rate = self._global_params.drop_connect_rate 235 | if drop_connect_rate: 236 | drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate 237 | x = block(x, drop_connect_rate=drop_connect_rate) 238 | 239 | # Head 240 | x = self._swish(self._bn1(self._conv_head(x))) 241 | 242 | return x 243 | 244 | def forward(self, inputs): 245 | """EfficientNet's forward function. 246 | Calls extract_features to extract features, applies final linear layer, and returns logits. 247 | 248 | Args: 249 | inputs (tensor): Input tensor. 250 | 251 | Returns: 252 | Output of this model after processing. 253 | """ 254 | bs = inputs.size(0) 255 | 256 | # Convolution layers 257 | x = self.extract_features(inputs) 258 | 259 | # Pooling and final linear layer 260 | x = self._avg_pooling(x) 261 | x = x.view(bs, -1) 262 | x = self._dropout(x) 263 | x = self._fc(x) 264 | 265 | 266 | return x 267 | 268 | @classmethod 269 | def from_name(cls, model_name, in_channels=3, **override_params): 270 | """create an efficientnet model according to name. 271 | 272 | Args: 273 | model_name (str): Name for efficientnet. 
274 | in_channels (int): Input data's channel number. 275 | override_params (other key word params): 276 | Params to override model's global_params. 277 | Optional key: 278 | 'width_coefficient', 'depth_coefficient', 279 | 'image_size', 'dropout_rate', 280 | 'num_classes', 'batch_norm_momentum', 281 | 'batch_norm_epsilon', 'drop_connect_rate', 282 | 'depth_divisor', 'min_depth' 283 | 284 | Returns: 285 | An efficientnet model. 286 | """ 287 | cls._check_model_name_is_valid(model_name) 288 | blocks_args, global_params = get_model_params(model_name, override_params) 289 | model = cls(blocks_args, global_params) 290 | model._change_in_channels(in_channels) 291 | return model 292 | 293 | @classmethod 294 | def from_pretrained(cls, model_name, weights_path=None, advprop=False, 295 | in_channels=3, num_classes=1000, **override_params): 296 | """create an efficientnet model according to name. 297 | 298 | Args: 299 | model_name (str): Name for efficientnet. 300 | weights_path (None or str): 301 | str: path to pretrained weights file on the local disk. 302 | None: use pretrained weights downloaded from the Internet. 303 | advprop (bool): 304 | Whether to load pretrained weights 305 | trained with advprop (valid when weights_path is None). 306 | in_channels (int): Input data's channel number. 307 | num_classes (int): 308 | Number of categories for classification. 309 | It controls the output size for final linear layer. 310 | override_params (other key word params): 311 | Params to override model's global_params. 312 | Optional key: 313 | 'width_coefficient', 'depth_coefficient', 314 | 'image_size', 'dropout_rate', 315 | 'num_classes', 'batch_norm_momentum', 316 | 'batch_norm_epsilon', 'drop_connect_rate', 317 | 'depth_divisor', 'min_depth' 318 | 319 | Returns: 320 | A pretrained efficientnet model. 
321 | """ 322 | model = cls.from_name(model_name, num_classes = num_classes, **override_params) 323 | load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop) 324 | model._change_in_channels(in_channels) 325 | return model 326 | 327 | @classmethod 328 | def get_image_size(cls, model_name): 329 | """Get the input image size for a given efficientnet model. 330 | 331 | Args: 332 | model_name (str): Name for efficientnet. 333 | 334 | Returns: 335 | Input image size (resolution). 336 | """ 337 | cls._check_model_name_is_valid(model_name) 338 | _, _, res, _ = efficientnet_params(model_name) 339 | return res 340 | 341 | @classmethod 342 | def _check_model_name_is_valid(cls, model_name): 343 | """Validates model name. 344 | 345 | Args: 346 | model_name (str): Name for efficientnet. 347 | 348 | Returns: 349 | bool: Is a valid name or not. 350 | """ 351 | valid_models = ['efficientnet-b'+str(i) for i in range(9)] 352 | 353 | # Support the construction of 'efficientnet-l2' without pretrained weights 354 | valid_models += ['efficientnet-l2'] 355 | 356 | if model_name not in valid_models: 357 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) 358 | 359 | def _change_in_channels(self, in_channels): 360 | """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. 361 | 362 | Args: 363 | in_channels (int): Input data's channel number. 
364 | """ 365 | if in_channels != 3: 366 | Conv2d = get_same_padding_conv2d(image_size = self._global_params.image_size) 367 | out_channels = round_filters(32, self._global_params) 368 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 369 | -------------------------------------------------------------------------------- /efficientnet_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | """utils.py - Helper functions for building the model and for loading model parameters. 2 | These helper functions are built to mirror those in the official TensorFlow implementation. 3 | """ 4 | 5 | # Author: lukemelas (github username) 6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch 7 | # With adjustments and added comments by workingcoder (github username). 8 | 9 | import re 10 | import math 11 | import collections 12 | from functools import partial 13 | import torch 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torch.utils import model_zoo 17 | 18 | 19 | ################################################################################ 20 | ### Help functions for model architecture 21 | ################################################################################ 22 | 23 | # GlobalParams and BlockArgs: Two namedtuples 24 | # Swish and MemoryEfficientSwish: Two implementations of the method 25 | # round_filters and round_repeats: 26 | # Functions to calculate params for scaling model width and depth ! ! ! 27 | # get_width_and_height_from_size and calculate_output_image_size 28 | # drop_connect: A structural design 29 | # get_same_padding_conv2d: 30 | # Conv2dDynamicSamePadding 31 | # Conv2dStaticSamePadding 32 | # get_same_padding_maxPool2d: 33 | # MaxPool2dDynamicSamePadding 34 | # MaxPool2dStaticSamePadding 35 | # It's an additional function, not used in EfficientNet, 36 | # but can be used in other model (such as EfficientDet). 
# Identity: An implementation of identity mapping

# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple('GlobalParams', [
    'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
    'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
    'drop_connect_rate', 'depth_divisor', 'min_depth'])

# Parameters for an individual model block
BlockArgs = collections.namedtuple('BlockArgs', [
    'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
    'input_filters', 'output_filters', 'se_ratio', 'id_skip'])

# Every field of GlobalParams and BlockArgs defaults to None, so both can be
# constructed from any subset of keyword arguments.
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)


# An ordinary implementation of the Swish activation: x * sigmoid(x)
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


# A memory-efficient implementation of Swish: only the input tensor is saved
# for backward, and the gradient is computed analytically in `backward`.
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor; `saved_variables` is a
        # deprecated alias.
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    """Module wrapper that applies SwishImplementation."""

    def forward(self, x):
        return SwishImplementation.apply(x)


def round_filters(filters, global_params):
    """Calculate and round number of filters based on width multiplier.
    Use width_coefficient, depth_divisor and min_depth of global_params.

    Args:
        filters (int): Filters number to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new_filters: New filters number after calculating (a multiple of
        depth_divisor, never rounded down by more than 10%).
    """
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    # TODO: modify the params names.
    #       maybe the names (width_divisor,min_width)
    #       are more suitable than (depth_divisor,min_depth).
    divisor = global_params.depth_divisor
    min_depth = global_params.min_depth
    filters *= multiplier
    min_depth = min_depth or divisor  # pay attention to this line when using min_depth
    # follow the formula transferred from official TensorFlow implementation
    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, global_params):
    """Calculate module's repeat number of a block based on depth multiplier.
    Use depth_coefficient of global_params.

    Args:
        repeats (int): num_repeat to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new repeat: New repeat number after calculating (rounded up).
    """
    multiplier = global_params.depth_coefficient
    if not multiplier:
        return repeats
    # follow the formula transferred from official TensorFlow implementation
    return int(math.ceil(multiplier * repeats))


def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples of the batch during training.

    Args:
        inputs (tensor: BCWH): Input of this structure. The mask has shape
            [batch, 1, 1, 1], so a 4D input is assumed.
        p (float: 0.0~1.0): Probability of dropping each sample.
        training (bool): The running mode; when False the input is returned as-is.

    Returns:
        output: Surviving samples scaled by 1 / (1 - p), dropped samples zeroed.
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'

    if not training:
        return inputs

    if p == 1:
        # Every sample is dropped; short-circuit to avoid dividing by keep_prob == 0,
        # which would otherwise turn the output into NaN.
        return torch.zeros_like(inputs)

    batch_size = inputs.shape[0]
    keep_prob = 1 - p

    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)

    # Rescale survivors so the expected output equals the input.
    return inputs / keep_prob * binary_tensor


def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: A tuple or list (H,W); an int is duplicated into (x, x).

    Raises:
        TypeError: If x is not an int, tuple or list.
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError('size must be an int, list or tuple, got {}'.format(type(x).__name__))


def calculate_output_image_size(input_image_size, stride):
    """Calculates the output image size when using Conv2dSamePadding with a stride.
    Necessary for static padding. Thanks to mannatsingh for pointing this out.

    Args:
        input_image_size (int, tuple or list): Size of input image.
        stride (int, tuple or list): Conv2d operation's stride; only the first
            element of a sequence is used for both dimensions.

    Returns:
        output_image_size: A list [H,W], or None if input_image_size is None.
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    stride = stride if isinstance(stride, int) else stride[0]
    image_height = int(math.ceil(image_height / stride))
    image_width = int(math.ceil(image_width / stride))
    return [image_height, image_width]


# Note:
# The following 'SamePadding' functions make output size equal ceil(input size/stride).
# Only when stride equals 1, can the output size be the same as input size.
# Don't be confused by their function names!
def get_same_padding_conv2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        Conv2dDynamicSamePadding, or Conv2dStaticSamePadding with image_size
        pre-bound via functools.partial.
    """
    if image_size is None:
        return Conv2dDynamicSamePadding
    return partial(Conv2dStaticSamePadding, image_size=image_size)


def _same_padding(input_size, kernel_size, stride, dilation):
    """Compute TensorFlow-style 'SAME' padding for a 2D window operation.

    For one spatial dimension with input size i, stride s, kernel k,
    dilation d and total padding p, the output size is
        o = floor((i + p - ((k - 1) * d + 1)) / s + 1).
    Requiring o == ceil(i / s) gives
        p = max((o - 1) * s + (k - 1) * d + 1 - i, 0).
    When p is odd, the extra pixel goes to the right/bottom.

    Args:
        input_size, kernel_size, stride, dilation: (height, width) pairs.

    Returns:
        list: [left, right, top, bottom] padding, the order F.pad and
        nn.ZeroPad2d use for the last two dimensions.
    """
    pads = []
    for size, k, s, d in zip(input_size, kernel_size, stride, dilation):
        out = math.ceil(size / s)
        pads.append(max((out - 1) * s + (k - 1) * d + 1 - size, 0))
    pad_h, pad_w = pads
    return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]


class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, for a dynamic image size.
    The padding is computed from the actual input size in every forward call.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        # The built-in padding is 0; SAME padding is applied explicitly in forward.
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

    def forward(self, x):
        pads = _same_padding(x.size()[-2:], self.weight.size()[-2:], self.stride, self.dilation)
        if any(pads):
            x = F.pad(x, pads)
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
    The padding module is built in the constructor, then applied in forward.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Padding depends only on the fixed image size, so compute it once here.
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        pads = _same_padding((ih, iw), self.weight.size()[-2:], self.stride, self.dilation)
        if any(pads):
            self.static_padding = nn.ZeroPad2d(tuple(pads))
        else:
            self.static_padding = Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


def get_same_padding_maxPool2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        MaxPool2dDynamicSamePadding, or MaxPool2dStaticSamePadding with
        image_size pre-bound via functools.partial.
    """
    if image_size is None:
        return MaxPool2dDynamicSamePadding
    return partial(MaxPool2dStaticSamePadding, image_size=image_size)


class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
    The padding is computed from the actual input size in every forward call.

    NOTE(review): padding is filled with zeros, whereas TensorFlow's SAME
    max-pooling ignores padded positions; results can differ for windows whose
    values are all negative.
    """

    def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
        super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
        self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation

    def forward(self, x):
        pads = _same_padding(x.size()[-2:], self.kernel_size, self.stride, self.dilation)
        if any(pads):
            x = F.pad(x, pads)
        return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                            self.dilation, self.ceil_mode, self.return_indices)


class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
    The padding module is built in the constructor, then applied in forward.

    NOTE(review): padding is filled with zeros (nn.ZeroPad2d), see the note on
    MaxPool2dDynamicSamePadding.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
        self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation

        # Padding depends only on the fixed image size, so compute it once here.
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        pads = _same_padding((ih, iw), self.kernel_size, self.stride, self.dilation)
        if any(pads):
            self.static_padding = nn.ZeroPad2d(tuple(pads))
        else:
            self.static_padding = Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                         self.dilation, self.ceil_mode, self.return_indices)
        return x


class Identity(nn.Module):
    """Identity mapping.
    Send input to output directly.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


################################################################################
### Helper functions for loading model params
################################################################################

# BlockDecoder: A Class for encoding and decoding BlockArgs
# efficientnet_params: A function to query compound coefficient
# get_model_params and efficientnet:
#     Functions to get BlockArgs and GlobalParams for efficientnet
# url_map and url_map_advprop: Dicts of url_map for pretrained weights
# load_pretrained_weights: A function to load pretrained weights

class BlockDecoder(object):
    """Block Decoder for readability,
       straight from the official TensorFlow repository.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file.
        """
        assert isinstance(block_string, str)

        ops = block_string.split('_')
        options = {}
        for op in ops:
            # Split 'k3' into ('k', '3'); options without digits (e.g. 'noskip') are skipped.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: required, and either one digit or two equal digits.
        assert 's' in options and (
            len(options['s']) == 1 or
            (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(options['s'][0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string (inverse of _decode_block_string).

        Args:
            block (namedtuple): A BlockArgs type argument. ``stride`` may be an
                int, a one-element list (as produced by _decode_block_string),
                or an (h, w) pair.

        Returns:
            block_string: A String form of BlockArgs.
        """
        stride = block.stride
        if isinstance(stride, int):
            stride_h = stride_w = stride
        elif len(stride) == 1:
            stride_h = stride_w = stride[0]
        else:
            stride_h, stride_w = stride[0], stride[1]

        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride_h, stride_w),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters,
        ]
        # se_ratio is None when the block has no squeeze-and-excitation option.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        return [BlockDecoder._decode_block_string(block_string) for block_string in string_list]

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]


def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.

    Args:
        model_name (str): Model name to be queried.

    Returns:
        params_dict[model_name]: A (width,depth,res,dropout) tuple.

    Raises:
        KeyError: If model_name is not one of the listed variants.
    """
    params_dict = {
        # Coefficients:   width,depth,res,dropout
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    return params_dict[model_name]


def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000):
    """Create BlockArgs and GlobalParams for efficientnet model.

    Args:
        width_coefficient (float): Channel-count multiplier (see round_filters).
        depth_coefficient (float): Repeat-count multiplier (see round_repeats).
        image_size (int): Input resolution.
        dropout_rate (float): Dropout rate.
        drop_connect_rate (float): Drop-connect rate.
        num_classes (int): Number of output classes.

    Returns:
        blocks_args, global_params.
    """

    # Blocks args for the whole model (efficientnet-b0 by default).
    # They are modified in the construction of the EfficientNet class according to the model.
    blocks_args = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    blocks_args = BlockDecoder.decode(blocks_args)

    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,

        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
    )

    return blocks_args, global_params


def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model name.

    Args:
        model_name (str): Model's name.
        override_params (dict): A dict to modify global_params.

    Returns:
        blocks_args, global_params

    Raises:
        NotImplementedError: If model_name does not start with 'efficientnet'.
        ValueError: If override_params has fields not included in GlobalParams.
    """
    if model_name.startswith('efficientnet'):
        w, d, s, p = efficientnet_params(model_name)
        # note: all models have drop connect rate = 0.2
        blocks_args, global_params = efficientnet(
            width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
    else:
        raise NotImplementedError('model name is not pre-defined: %s' % model_name)
    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params


# train with Standard methods
# check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
url_map = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
}

# train with Adversarial Examples(AdvProp)
# check more details in paper(Adversarial Examples Improve Image Recognition)
url_map_advprop = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
    'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
}

# TODO: add the pretrained weights url map of 'efficientnet-l2'


def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False):
    """Loads pretrained weights from weights path or download using url.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
        advprop (bool): Whether to load pretrained weights
                        trained with advprop (valid when weights_path is None).

    Raises:
        KeyError: If weights must be downloaded but no URL is known for
            model_name (e.g. 'efficientnet-b8' without advprop, or 'efficientnet-l2').
    """
    if isinstance(weights_path, str):
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # hosts; load_state_dict copies values onto the model's own parameters.
        state_dict = torch.load(weights_path, map_location='cpu')
    else:
        # AutoAugment or Advprop (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        if model_name not in url_map_:
            raise KeyError('No pretrained weights URL for {} (advprop={})'.format(model_name, advprop))
        state_dict = model_zoo.load_url(url_map_[model_name], map_location='cpu')

    if not load_fc:
        # Drop the classifier so a head with a different num_classes keeps its own init.
        state_dict.pop('_fc.weight', None)
        state_dict.pop('_fc.bias', None)
    # strict=False tolerates missing/unexpected keys (e.g. the fc layer removed above).
    model.load_state_dict(state_dict, strict=False)

    print('Loaded pretrained weights for {}'.format(model_name))