├── README.md
├── efficientnet
│   ├── __init__.py
│   ├── model.py
│   └── utils.py
└── efficientnet_sample.py

/README.md:
--------------------------------------------------------------------------------
# EfficientNet-Pytorch
A demo for training your own dataset with EfficientNet.
Thanks to [A PyTorch implementation of EfficientNet](https://github.com/lukemelas/EfficientNet-PyTorch); this repo simply demonstrates how to train EfficientNet-PyTorch on your own dataset.
## Step 1: Prepare your own classification dataset
---
The data directory should look like:
```
-dataset\
  -model\
  -train\
    -1\
    -2\
    ...
  -test\
    -1\
    -2\
    ...
```

## Step 2: Train and test

```python efficientnet_sample.py```

```--data-dir``` : (str) Path of the ```/dataset``` folder. Default: ```None```

```--num-epochs``` : (int) Number of training epochs. Default: ```40```

```--batch-size``` : (int) Batch size. Default: ```4```

```--img-size``` : (int) Size to which input images are resized. Pass a single int on the command line (e.g. ```--img-size 896```). Default: ```[1024,1024]```

```--class-num``` : (int) Number of classes in the dataset. Default: ```3```

```--weights-loc``` : (str) Path of the weights to load. If ```None```, pretrained weights are downloaded and loaded automatically. Default: ```None``` Example: ```"...//weights.pth//"```

```--lr``` : (float) Learning rate. Default: ```0.01```

```--net-name``` : (str) Which EfficientNet variant to use; also determines which pretrained weights are downloaded. Default: ```efficientnet-b3```

```--resume-epoch``` : (int) Epoch to resume training from. Default: ```0```

```--momentum``` : (float) SGD momentum. Default: ```0.9```

Example usage: ```python ".\efficientnet_sample.py" --data-dir "D:\\ml_data\\dataset" --num-epochs 80 --batch-size 4 --img-size 896 --class-num 3 --weights-loc "D:\\ML\\efficientnet-b3-birads.pth" --lr 0.01 --net-name "efficientnet-b3" --resume-epoch 40```


The pretrained models are available on the [releases page](https://github.com/lukemelas/EfficientNet-PyTorch/releases).
You can download them into the ```eff_weights``` folder.

## Step 3: Get the results

The final results and the best model are saved under ```dataset/model/```. The saved model can be reloaded for inference, as sketched below.
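
The training script saves the best model with ```torch.save(model, path)```, i.e. as a whole module, so it can be reloaded directly with ```torch.load```. A minimal inference sketch (paths, image size, and file names here are illustrative; the ```efficientnet``` package must be importable for unpickling):

```python
import torch
from PIL import Image
from torchvision import transforms

# Load the whole saved module (map to CPU in case it was saved on GPU).
model = torch.load('dataset/model/efficientnet-b3.pth', map_location='cpu')
model.eval()

# Same preprocessing as the 'test' transforms in efficientnet_sample.py.
preprocess = transforms.Compose([
    transforms.Resize(896),
    transforms.CenterCrop(896),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = preprocess(Image.open('some_image.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    pred = model(img).argmax(dim=1)
print('predicted class index:', pred.item())
```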
26 | """ 27 | 28 | def __init__(self, block_args, global_params): 29 | super().__init__() 30 | self._block_args = block_args 31 | self._bn_mom = 1 - global_params.batch_norm_momentum 32 | self._bn_eps = global_params.batch_norm_epsilon 33 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 34 | self.id_skip = block_args.id_skip # skip connection and drop connect 35 | 36 | # Get static or dynamic convolution depending on image size 37 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 38 | 39 | # Expansion phase 40 | inp = self._block_args.input_filters # number of input channels 41 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 42 | if self._block_args.expand_ratio != 1: 43 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 44 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 45 | 46 | # Depthwise convolution phase 47 | k = self._block_args.kernel_size 48 | s = self._block_args.stride 49 | self._depthwise_conv = Conv2d( 50 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 51 | kernel_size=k, stride=s, bias=False) 52 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 53 | 54 | # Squeeze and Excitation layer, if desired 55 | if self.has_se: 56 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 57 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 58 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 59 | 60 | # Output phase 61 | final_oup = self._block_args.output_filters 62 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 63 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 64 | self._swish = MemoryEfficientSwish() 65 | 66 | def forward(self, inputs, drop_connect_rate=None): 67 | """ 68 | :param inputs: input tensor 69 | :param drop_connect_rate: drop connect rate (float, between 0 and 1) 70 | :return: output of block 71 | """ 72 | 73 | # Expansion and Depthwise Convolution 74 | x = inputs 75 | if self._block_args.expand_ratio != 1: 76 | x = self._swish(self._bn0(self._expand_conv(inputs))) 77 | x = self._swish(self._bn1(self._depthwise_conv(x))) 78 | 79 | # Squeeze and Excitation 80 | if self.has_se: 81 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 82 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) 83 | x = torch.sigmoid(x_squeezed) * x 84 | 85 | x = self._bn2(self._project_conv(x)) 86 | 87 | # Skip connection and drop connect 88 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 89 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 90 | if drop_connect_rate: 91 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 92 | x = x + inputs # skip connection 93 | return x 94 | 95 | def set_swish(self, memory_efficient=True): 96 | """Sets swish function as memory efficient (for training) or standard (for export)""" 97 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 98 | 99 | 100 | class EfficientNet(nn.Module): 101 | """ 102 | An EfficientNet model. 


class EfficientNet(nn.Module):
    """
    An EfficientNet model.
    Most easily loaded with the .from_name or .from_pretrained methods.
    Args:
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks
    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'blocks_args must be non-empty'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase;
            # the remaining num_repeat - 1 copies use stride 1 and matched filters.
            self._blocks.append(MBConvBlock(block_args, self._global_params))
            if block_args.num_repeat > 1:
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params))

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """ Returns output of the final convolution layer """

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            # Drop connect rate scales linearly with depth, from 0 up to drop_connect_rate.
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs):
        """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
        bs = inputs.size(0)
        # Convolution layers
        x = self.extract_features(inputs)

        # Pooling and final linear layer
        x = self._avg_pooling(x)
        x = x.view(bs, -1)
        x = self._dropout(x)
        x = self._fc(x)
        return x
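
    # Editor's note: extract_features returns the pre-pooling feature map, e.g.
    #   model = EfficientNet.from_name('efficientnet-b0')
    #   feats = model.extract_features(torch.randn(1, 3, 224, 224))
    #   # feats.shape == (1, 1280, 7, 7)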

    @classmethod
    def from_name(cls, model_name, override_params=None):
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        return cls(blocks_args, global_params)

    @classmethod
    def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3):
        model = cls.from_name(model_name, override_params={'num_classes': num_classes})
        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
        if in_channels != 3:
            # Replace the stem convolution for non-RGB inputs.
            Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """ Validates model name.
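        Valid names are 'efficientnet-b0' through 'efficientnet-b8'.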
""" 222 | valid_models = ['efficientnet-b' + str(i) for i in range(9)] 223 | if model_name not in valid_models: 224 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) -------------------------------------------------------------------------------- /efficientnet/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains helper functions for building the model and for loading model parameters. 3 | These helper functions are built to mirror those in the official TensorFlow implementation. 4 | """ 5 | 6 | import re 7 | import math 8 | import collections 9 | from functools import partial 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | from torch.utils import model_zoo 14 | 15 | ######################################################################## 16 | ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### 17 | ######################################################################## 18 | 19 | 20 | # Parameters for the entire model (stem, all blocks, and head) 21 | GlobalParams = collections.namedtuple('GlobalParams', [ 22 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 23 | 'num_classes', 'width_coefficient', 'depth_coefficient', 24 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) 25 | 26 | # Parameters for an individual model block 27 | BlockArgs = collections.namedtuple('BlockArgs', [ 28 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 29 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) 30 | 31 | # Change namedtuple defaults 32 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 33 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 34 | 35 | 36 | class SwishImplementation(torch.autograd.Function): 37 | @staticmethod 38 | def forward(ctx, i): 39 | result = i * torch.sigmoid(i) 40 | ctx.save_for_backward(i) 41 | return result 42 | 43 | @staticmethod 44 | def backward(ctx, grad_output): 45 | i = ctx.saved_variables[0] 46 | sigmoid_i = torch.sigmoid(i) 47 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 48 | 49 | 50 | class MemoryEfficientSwish(nn.Module): 51 | def forward(self, x): 52 | return SwishImplementation.apply(x) 53 | 54 | class Swish(nn.Module): 55 | def forward(self, x): 56 | return x * torch.sigmoid(x) 57 | 58 | 59 | def round_filters(filters, global_params): 60 | """ Calculate and round number of filters based on depth multiplier. """ 61 | multiplier = global_params.width_coefficient 62 | if not multiplier: 63 | return filters 64 | divisor = global_params.depth_divisor 65 | min_depth = global_params.min_depth 66 | filters *= multiplier 67 | min_depth = min_depth or divisor 68 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 69 | if new_filters < 0.9 * filters: # prevent rounding by more than 10% 70 | new_filters += divisor 71 | return int(new_filters) 72 | 73 | 74 | def round_repeats(repeats, global_params): 75 | """ Round number of filters based on depth multiplier. """ 76 | multiplier = global_params.depth_coefficient 77 | if not multiplier: 78 | return repeats 79 | return int(math.ceil(multiplier * repeats)) 80 | 81 | 82 | def drop_connect(inputs, p, training): 83 | """ Drop connect. 
""" 84 | if not training: return inputs 85 | batch_size = inputs.shape[0] 86 | keep_prob = 1 - p 87 | random_tensor = keep_prob 88 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) 89 | binary_tensor = torch.floor(random_tensor) 90 | output = inputs / keep_prob * binary_tensor 91 | return output 92 | 93 | 94 | def get_same_padding_conv2d(image_size=None): 95 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. 96 | Static padding is necessary for ONNX exporting of models. """ 97 | if image_size is None: 98 | return Conv2dDynamicSamePadding 99 | else: 100 | return partial(Conv2dStaticSamePadding, image_size=image_size) 101 | 102 | 103 | class Conv2dDynamicSamePadding(nn.Conv2d): 104 | """ 2D Convolutions like TensorFlow, for a dynamic image size """ 105 | 106 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): 107 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 108 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 109 | 110 | def forward(self, x): 111 | ih, iw = x.size()[-2:] 112 | kh, kw = self.weight.size()[-2:] 113 | sh, sw = self.stride 114 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 115 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 116 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 117 | if pad_h > 0 or pad_w > 0: 118 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 119 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 120 | 121 | 122 | class Conv2dStaticSamePadding(nn.Conv2d): 123 | """ 2D Convolutions like TensorFlow, for a fixed image size""" 124 | 125 | def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): 126 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 127 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 128 | 129 | # Calculate padding based on image size and save it 130 | assert image_size is not None 131 | ih, iw = image_size if type(image_size) == list else [image_size, image_size] 132 | kh, kw = self.weight.size()[-2:] 133 | sh, sw = self.stride 134 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 135 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 136 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 137 | if pad_h > 0 or pad_w > 0: 138 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) 139 | else: 140 | self.static_padding = Identity() 141 | 142 | def forward(self, x): 143 | x = self.static_padding(x) 144 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 145 | return x 146 | 147 | 148 | class Identity(nn.Module): 149 | def __init__(self, ): 150 | super(Identity, self).__init__() 151 | 152 | def forward(self, input): 153 | return input 154 | 155 | 156 | ######################################################################## 157 | ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ############## 158 | ######################################################################## 159 | 160 | 161 | def efficientnet_params(model_name): 162 | """ Map EfficientNet model name to parameter coefficients. 
""" 163 | params_dict = { 164 | # Coefficients: width,depth,res,dropout 165 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 166 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 167 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 168 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 169 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 170 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 171 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 172 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5), 173 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5), 174 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5), 175 | } 176 | return params_dict[model_name] 177 | 178 | 179 | class BlockDecoder(object): 180 | """ Block Decoder for readability, straight from the official TensorFlow repository """ 181 | 182 | @staticmethod 183 | def _decode_block_string(block_string): 184 | """ Gets a block through a string notation of arguments. """ 185 | assert isinstance(block_string, str) 186 | 187 | ops = block_string.split('_') 188 | options = {} 189 | for op in ops: 190 | splits = re.split(r'(\d.*)', op) 191 | if len(splits) >= 2: 192 | key, value = splits[:2] 193 | options[key] = value 194 | 195 | # Check stride 196 | assert (('s' in options and len(options['s']) == 1) or 197 | (len(options['s']) == 2 and options['s'][0] == options['s'][1])) 198 | 199 | return BlockArgs( 200 | kernel_size=int(options['k']), 201 | num_repeat=int(options['r']), 202 | input_filters=int(options['i']), 203 | output_filters=int(options['o']), 204 | expand_ratio=int(options['e']), 205 | id_skip=('noskip' not in block_string), 206 | se_ratio=float(options['se']) if 'se' in options else None, 207 | stride=[int(options['s'][0])]) 208 | 209 | @staticmethod 210 | def _encode_block_string(block): 211 | """Encodes a block to a string.""" 212 | args = [ 213 | 'r%d' % block.num_repeat, 214 | 'k%d' % block.kernel_size, 215 | 's%d%d' % (block.strides[0], block.strides[1]), 216 | 'e%s' % block.expand_ratio, 217 | 'i%d' % block.input_filters, 218 | 'o%d' % block.output_filters 219 | ] 220 | if 0 < block.se_ratio <= 1: 221 | args.append('se%s' % block.se_ratio) 222 | if block.id_skip is False: 223 | args.append('noskip') 224 | return '_'.join(args) 225 | 226 | @staticmethod 227 | def decode(string_list): 228 | """ 229 | Decodes a list of string notations to specify blocks inside the network. 230 | :param string_list: a list of strings, each string is a notation of block 231 | :return: a list of BlockArgs namedtuples of block args 232 | """ 233 | assert isinstance(string_list, list) 234 | blocks_args = [] 235 | for block_string in string_list: 236 | blocks_args.append(BlockDecoder._decode_block_string(block_string)) 237 | return blocks_args 238 | 239 | @staticmethod 240 | def encode(blocks_args): 241 | """ 242 | Encodes a list of BlockArgs to a list of strings. 243 | :param blocks_args: a list of BlockArgs namedtuples of block args 244 | :return: a list of strings, each string is a notation of block 245 | """ 246 | block_strings = [] 247 | for block in blocks_args: 248 | block_strings.append(BlockDecoder._encode_block_string(block)) 249 | return block_strings 250 | 251 | 252 | def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, 253 | drop_connect_rate=0.2, image_size=None, num_classes=1000): 254 | """ Creates a efficientnet model. 
""" 255 | 256 | blocks_args = [ 257 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 258 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 259 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 260 | 'r1_k3_s11_e6_i192_o320_se0.25', 261 | ] 262 | blocks_args = BlockDecoder.decode(blocks_args) 263 | 264 | global_params = GlobalParams( 265 | batch_norm_momentum=0.99, 266 | batch_norm_epsilon=1e-3, 267 | dropout_rate=dropout_rate, 268 | drop_connect_rate=drop_connect_rate, 269 | # data_format='channels_last', # removed, this is always true in PyTorch 270 | num_classes=num_classes, 271 | width_coefficient=width_coefficient, 272 | depth_coefficient=depth_coefficient, 273 | depth_divisor=8, 274 | min_depth=None, 275 | image_size=image_size, 276 | ) 277 | 278 | return blocks_args, global_params 279 | 280 | 281 | def get_model_params(model_name, override_params): 282 | """ Get the block args and global params for a given model """ 283 | if model_name.startswith('efficientnet'): 284 | w, d, s, p = efficientnet_params(model_name) 285 | # note: all models have drop connect rate = 0.2 286 | blocks_args, global_params = efficientnet( 287 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) 288 | else: 289 | raise NotImplementedError('model name is not pre-defined: %s' % model_name) 290 | if override_params: 291 | # ValueError will be raised here if override_params has fields not included in global_params. 292 | global_params = global_params._replace(**override_params) 293 | return blocks_args, global_params 294 | 295 | 296 | url_map = { 297 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', 298 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', 299 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', 300 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', 301 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', 302 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', 303 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', 304 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', 305 | } 306 | 307 | 308 | url_map_advprop = { 309 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', 310 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', 311 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', 312 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', 313 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', 314 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', 315 | 'efficientnet-b6': 


def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
    """ Loads pretrained weights, and downloads if loading for the first time. """
    # AutoAugment or Advprop (different preprocessing)
    url_map_ = url_map_advprop if advprop else url_map
    state_dict = model_zoo.load_url(url_map_[model_name])
    if load_fc:
        model.load_state_dict(state_dict)
    else:
        state_dict.pop('_fc.weight')
        state_dict.pop('_fc.bias')
        res = model.load_state_dict(state_dict, strict=False)
        assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
    print('Loaded pretrained weights for {}'.format(model_name))
--------------------------------------------------------------------------------
/efficientnet_sample.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division

import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, models, transforms
import time
import os
from efficientnet.model import EfficientNet

import argparse

# some parameters
use_gpu = torch.cuda.is_available()
print('CUDA available:', use_gpu)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Module-level defaults; overridden by the command-line arguments below.
data_dir = ''
num_epochs = 40
batch_size = 2
input_size = 4
class_num = 3
weights_loc = ""
lr = 0.01
net_name = 'efficientnet-b3'
epoch_to_resume_from = 0
momentum = 0.9


def loaddata(data_dir, batch_size, set_name, shuffle):
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in [set_name]}
    # use num_workers=0 on CPU, 1 otherwise
    dataset_loaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                      batch_size=batch_size,
                                                      shuffle=shuffle, num_workers=1) for x in [set_name]}
    data_set_sizes = len(image_datasets[set_name])
    return dataset_loaders, data_set_sizes
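

# Note (editor's addition): datasets.ImageFolder assigns label indices from the
# sorted class-folder names, so class folders 1/, 2/, 3/ map to labels 0, 1, 2.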


def train_model(model_ft, criterion, optimizer, lr_scheduler, num_epochs=50):

    train_loss = []
    since = time.time()
    best_model_wts = copy.deepcopy(model_ft.state_dict())  # copy, so later updates don't overwrite it
    best_acc = 0.0
    model_ft.train(True)

    for epoch in range(epoch_to_resume_from, num_epochs):

        dset_loaders, dset_sizes = loaddata(data_dir=data_dir, batch_size=batch_size, set_name='train', shuffle=True)
        print('Data Size', dset_sizes)
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        optimizer = lr_scheduler(optimizer, epoch)

        running_loss = 0.0
        running_corrects = 0
        count = 0

        for data in dset_loaders['train']:
            inputs, labels = data
            labels = torch.squeeze(labels.type(torch.LongTensor))
            if use_gpu:
                inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
            else:
                inputs, labels = Variable(inputs), Variable(labels)

            outputs = model_ft(inputs)

            if count % 500 == 0:
                print(outputs)
                print(labels)

            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs.data, 1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            count += 1
            if count % 30 == 0 or outputs.size()[0] < batch_size:
                print('Epoch:{}: loss:{:.3f}'.format(epoch, loss.item()))
                train_loss.append(loss.item())

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dset_sizes
        epoch_acc = running_corrects.double() / dset_sizes

        print('Loss: {:.4f} Acc: {:.4f}'.format(
            epoch_loss, epoch_acc))

        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model_ft.state_dict())
        if epoch_acc > 0.999:
            break

    # save best model
    save_dir = data_dir + '/model'
    model_ft.load_state_dict(best_model_wts)
    model_out_path = save_dir + "/" + net_name + '.pth'
    torch.save(model_ft, model_out_path)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    return train_loss, best_model_wts


def test_model(model, criterion):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    cont = 0
    outPre = []
    outLabel = []
    dset_loaders, dset_sizes = loaddata(data_dir=data_dir, batch_size=batch_size, set_name='test', shuffle=False)
    for data in dset_loaders['test']:
        inputs, labels = data
        labels = torch.squeeze(labels.type(torch.LongTensor))
        if use_gpu:
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        else:
            inputs, labels = Variable(inputs), Variable(labels)
        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)
        loss = criterion(outputs, labels)
        if cont == 0:
            outPre = outputs.data.cpu()
            outLabel = labels.data.cpu()
        else:
            outPre = torch.cat((outPre, outputs.data.cpu()), 0)
            outLabel = torch.cat((outLabel, labels.data.cpu()), 0)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        cont += 1
    print('Loss: {:.4f} Acc: {:.4f}'.format(running_loss / dset_sizes,
                                            running_corrects.double() / dset_sizes))


def exp_lr_scheduler(optimizer, epoch, init_lr=0.01, lr_decay_epoch=10):
    """Decay the learning rate by a factor of 0.8 every lr_decay_epoch epochs."""
    # model_out_path = "./model/W_epoch_{}.pth".format(epoch)
    # torch.save(model_W, model_out_path)
    lr = init_lr * (0.8 ** (epoch // lr_decay_epoch))
    print('LR is set to {}'.format(lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer
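

# For example, with init_lr=0.01 and lr_decay_epoch=10 the schedule is:
#   epochs 0-9 -> lr 0.01, epochs 10-19 -> lr 0.008, epochs 20-29 -> lr 0.0064, ...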


def run():
    # Local file names of the pretrained checkpoints (informational only;
    # EfficientNet.from_pretrained downloads them automatically via url_map).
    pth_map = {
        'efficientnet-b0': 'efficientnet-b0-355c32eb.pth',
        'efficientnet-b1': 'efficientnet-b1-f1951068.pth',
        'efficientnet-b2': 'efficientnet-b2-8bb594d6.pth',
        'efficientnet-b3': 'efficientnet-b3-5fb5a3c3.pth',
        'efficientnet-b4': 'efficientnet-b4-6ed6700e.pth',
        'efficientnet-b5': 'efficientnet-b5-b6417697.pth',
        'efficientnet-b6': 'efficientnet-b6-c76e70fd.pth',
        'efficientnet-b7': 'efficientnet-b7-dcc49843.pth',
    }

    # train
    if weights_loc:
        model_ft = torch.load(weights_loc)
    else:
        model_ft = EfficientNet.from_pretrained(net_name)

    # Modify the fully connected layer
    num_ftrs = model_ft._fc.in_features
    model_ft._fc = nn.Linear(num_ftrs, class_num)

    criterion = nn.CrossEntropyLoss()

    if use_gpu:
        model_ft = model_ft.cuda()
        criterion = criterion.cuda()

    optimizer = optim.SGD(model_ft.parameters(), lr=lr,
                          momentum=momentum, weight_decay=0.0004)

    train_loss, best_model_wts = train_model(model_ft, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epochs)

    # test
    print('-' * 10)
    print('Test Accuracy:')

    model_ft.load_state_dict(best_model_wts)

    criterion = nn.CrossEntropyLoss()
    if use_gpu:
        criterion = criterion.cuda()

    test_model(model_ft, criterion)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('--data-dir', type=str, default=None, help='path of /dataset/')
    parser.add_argument('--num-epochs', type=int, default=40)
    parser.add_argument('--batch-size', type=int, default=4, help='total batch size for all GPUs')
    parser.add_argument('--img-size', type=int, default=[1024, 1024], help='img size')
    parser.add_argument('--class-num', type=int, default=3, help='class num')

    parser.add_argument('--weights-loc', type=str, default=None, help='path of weights (if going to be loaded)')

    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--net-name", type=str, default="efficientnet-b3", help="efficientnet type")

    parser.add_argument('--resume-epoch', type=int, default=0, help='what epoch to start from')

    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    opt = parser.parse_args()

    data_dir = opt.data_dir
    num_epochs = opt.num_epochs
    batch_size = opt.batch_size
    input_size = opt.img_size
    class_num = opt.class_num

    weights_loc = opt.weights_loc

    lr = opt.lr
    net_name = opt.net_name

    epoch_to_resume_from = opt.resume_epoch

    momentum = opt.momentum

    print("data dir: ", data_dir, ", num epochs: ", num_epochs, ", batch size: ", batch_size,
          ", img size: ", input_size, ", num of classes: ", class_num, ", .pth weights file location: ", weights_loc,
          ", learning rate: ", lr, ", net name: ", net_name, ", epoch to resume from: ", epoch_to_resume_from,
          ", momentum: ", momentum)

    run()
--------------------------------------------------------------------------------