├── LICENSE ├── README.md ├── cifar100_train.ipynb ├── efficientnet.py ├── efficientnet_v2.py └── imagenet_eval.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 abhuse 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EfficientNetV2 and EfficientNetV1 in PyTorch with pretrained weights 2 | 3 | A single-file implementation of EfficientNetV2 and EfficientNetV1 as introduced in: 4 | [\[Tan & Le 2021\]: EfficientNetV2: Smaller Models and Faster Training](https://arxiv.org/pdf/2104.00298.pdf) 5 | [\[Tan & Le 2019\]: EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) 6 | 7 | ## Pretrained Weights 8 | Original implementations of both [EfficientNetV2](https://github.com/google/automl/tree/master/efficientnetv2) and [EfficientNetV1](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) include pretrained weights in TensorFlow format. 9 | These weights were converted to PyTorch format and are provided in this repository. 10 | 11 | ## Accuracy 12 | 13 | ### EfficientNet V2 14 | | Model | ImageNet 1k Top-1 accuracy, % | 15 | | --- | --- | 16 | | EfficientNetV2-b0 | 77.590% | 17 | | EfficientNetV2-b1 | 78.872% | 18 | | EfficientNetV2-b2 | 79.388% | 19 | | EfficientNetV2-b3 | 82.260% | 20 | | EfficientNetV2-S | 84.282% | 21 | | EfficientNetV2-M | 85.596% | 22 | | EfficientNetV2-L | 86.298% | 23 | | EfficientNetV2-XL | 86.414% | 24 | 25 | ### EfficientNet V1 26 | | Model | ImageNet 1k Top-1 Accuracy, % | 27 | | --- | --- | 28 | | EfficientNet-B0 | 76.43% | 29 | | EfficientNet-B1 | 78.396% | 30 | | EfficientNet-B2 | 79.804% | 31 | | EfficientNet-B3 | 81.542% | 32 | | EfficientNet-B4 | 83.036% | 33 | | EfficientNet-B5 | 83.79% | 34 | | EfficientNet-B6 | 84.136% | 35 | | EfficientNet-B7 | 84.578% | 36 | 37 | ## Usage 38 | 39 | Check out [cifar100_train.ipynb](cifar100_train.ipynb) if you would like to experiment with the models. 40 | To evaluate pretrained models against the ImageNet validation set, run [imagenet_eval.ipynb](imagenet_eval.ipynb). 
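For reference, the Top-1 accuracy reported in the tables above is the percentage of images whose highest-scoring class equals the ground-truth label. The snippet below is a minimal, framework-agnostic sketch of that metric; the helper name `top1_accuracy` and the list-of-lists input format are assumptions made for illustration and are not part of this repository, and the evaluation notebook may compute the metric differently.

```python
def top1_accuracy(scores, labels):
    """Return the percentage of samples whose arg-max score equals the label.

    scores: list of per-class score lists, one per sample (hypothetical format)
    labels: list of integer class indices, one per sample
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("at least one sample is required")
    correct = 0
    for row, label in zip(scores, labels):
        # index of the largest score in this row
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == label:
            correct += 1
    return 100.0 * correct / len(labels)


# three samples, three classes; the second prediction is wrong
scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.2, 0.6]]
labels = [1, 2, 2]
print(round(top1_accuracy(scores, labels), 2))  # prints 66.67
```

With model outputs held as tensors, the same idea is usually expressed as an arg-max over the class dimension followed by a comparison with the labels.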
41 | 42 | ### EfficientNet V2 43 | The example below creates an EfficientNetV2-S model that takes a 3-channel image of shape [224, 224] 44 | as input and outputs a distribution over 50 classes; the model weights are initialized with weights pretrained on the ImageNet 45 | dataset: 46 | ```python 47 | import torch 48 | from efficientnet_v2 import EfficientNetV2 49 | 50 | model = EfficientNetV2('s', 51 | in_channels=3, 52 | n_classes=50, 53 | pretrained=True) 54 | 55 | # x - tensor of shape [batch_size, in_channels, image_height, image_width] 56 | x = torch.randn([10, 3, 224, 224]) 57 | 58 | # to get predictions: 59 | pred = model(x) 60 | print('out shape:', pred.shape) 61 | # >>> out shape: torch.Size([10, 50]) 62 | 63 | # to extract features: 64 | features = model.get_features(x) 65 | for i, feature in enumerate(features): 66 | print('feature %d shape:' % i, feature.shape) 67 | # >>> feature 0 shape: torch.Size([10, 48, 56, 56]) 68 | # >>> feature 1 shape: torch.Size([10, 64, 28, 28]) 69 | # >>> feature 2 shape: torch.Size([10, 160, 14, 14]) 70 | # >>> feature 3 shape: torch.Size([10, 256, 7, 7]) 71 | ``` 72 | 73 | ### EfficientNet (V1, original) 74 | 75 | The example below creates an EfficientNet-B0 model that takes a 3-channel image of shape [224, 224] 76 | as input and outputs a distribution over 50 classes; the model weights are initialized with weights pretrained on the ImageNet 77 | dataset: 78 | 79 | ```python 80 | import torch 81 | from efficientnet import EfficientNet 82 | 83 | model = EfficientNet(b=0, 84 | in_channels=3, 85 | n_classes=50, 86 | in_spatial_shape=(224,224), 87 | pretrained=True 88 | ) 89 | 90 | # x - tensor of shape [batch_size, in_channels, image_height, image_width] 91 | x = torch.randn([10, 3, 224, 224]) 92 | 93 | # to get predictions: 94 | pred = model(x) 95 | print('out shape:', pred.shape) 96 | # >>> out shape: torch.Size([10, 50]) 97 | 98 | # to extract features: 99 | features = model.get_features(x) 100 | for i, feature in enumerate(features): 101 | 
print('feature %d shape:' % i, feature.shape) 102 | # >>> feature 0 shape: torch.Size([10, 16, 112, 112]) 103 | # >>> feature 1 shape: torch.Size([10, 24, 56, 56]) 104 | # >>> feature 2 shape: torch.Size([10, 40, 28, 28]) 105 | # >>> feature 3 shape: torch.Size([10, 80, 14, 14]) 106 | # >>> feature 4 shape: torch.Size([10, 112, 14, 14]) 107 | # >>> feature 5 shape: torch.Size([10, 192, 7, 7]) 108 | # >>> feature 6 shape: torch.Size([10, 320, 7, 7]) 109 | ``` 110 | 111 | 112 | ## Parameters 113 | 114 | 115 | ### EfficientNet V2 116 | * ***model_name***, *(str)* - Model name, one of 'b0', 'b1', 'b2', 'b3', 's', 'm', 'l', 'xl' 117 | * ***in_channels***, *(int)*, *(Default=3)* - Number of channels in input image 118 | * ***n_classes***, *(int)*, *(Default=1000)* - Number of output classes 119 | * ***tf_style_conv***, *(bool)*, *(Default=False)* - Whether to simulate "SAME" padding of Tensorflow's convolution op. Set to *True* when evaluating pretrained models against Imagenet dataset 120 | * ***in_spatial_shape***, *(int or iterable of ints)*, 121 | *(Default=None)* - Spatial dimensionality of input image, tuple 122 | (height, width) or single integer *size* for shape (*size*, *size*). 
123 | It is recommended to specify this parameter only when *tf_style_conv=True* 124 | * ***activation***, *(str)*, *(Default='silu')* - Activation function 125 | * ***activation_kwargs***, *(dict)*, *(Default=None)* - Keyword arguments to pass to activation function 126 | * ***bias***, *(bool)*, 127 | *(Default=False)* - Enable bias in convolution operations 128 | * ***drop_connect_rate***, *(float)*, 129 | *(Default=0.2)* - DropConnect rate, set to 0 to disable DropConnect 130 | * ***dropout_rate***, *(float or None)*, 131 | *(Default=None)* - Dropout rate, set to *None* to use default dropout rate for each model 132 | * ***bn_epsilon***, *(float)*, 133 | *(Default=0.001)* - Batch normalization epsilon 134 | * ***bn_momentum***, *(float)*, 135 | *(Default=0.01)* - Batch normalization momentum 136 | * ***pretrained***, *(bool)*, 137 | *(Default=False)* - Initialize model with weights pretrained on ImageNet dataset 138 | * ***progress***, *(bool)*, 139 | *(Default=False)* - Show progress bar when downloading pretrained weights 140 | 141 | The default parameter values are the ones that were used in the 142 | [original implementation](https://github.com/google/automl/tree/master/efficientnetv2). 143 | 144 | 145 | ### EfficientNet V1 146 | * ***b***, *(int)* - Model index, e.g. 1 for EfficientNet-B1 147 | * ***in_channels***, *(int)*, *(Default=3)* - Number of channels in input image 148 | * ***n_classes***, *(int)*, *(Default=1000)* - Number of output classes 149 | * ***in_spatial_shape***, *(int or iterable of ints)*, 150 | *(Default=None)* - Spatial dimensionality of input image, tuple 151 | (height, width) or single integer *size* for shape (*size*, *size*). 
If None, default image shape will be used for 152 | each model index 153 | * ***activation***, *(callable)*, 154 | *(Default=Swish())* - Activation function 155 | * ***bias***, *(bool)*, 156 | *(Default=False)* - Enable bias in convolution operations 157 | * ***drop_connect_rate***, *(float)*, 158 | *(Default=0.2)* - DropConnect rate, set to 0 to disable DropConnect 159 | * ***dropout_rate***, *(float or None)*, 160 | *(Default=None)* - Dropout rate, set to *None* to use default dropout rate for each model 161 | * ***bn_epsilon***, *(float)*, 162 | *(Default=0.001)* - Batch normalization epsilon 163 | * ***bn_momentum***, *(float)*, 164 | *(Default=0.01)* - Batch normalization momentum 165 | * ***pretrained***, *(bool)*, 166 | *(Default=False)* - Initialize model with weights pretrained on ImageNet dataset 167 | * ***progress***, *(bool)*, 168 | *(Default=False)* - Show progress bar when downloading pretrained weights 169 | 170 | The default parameter values are the ones that were used in the 171 | [original implementation](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet). 172 | 173 | 174 | ## Requirements 175 | 176 | * Python v3.5+ 177 | * PyTorch v1.0+ -------------------------------------------------------------------------------- /cifar100_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## CIFAR-100 training script\n", 8 | "Note: This script merely outlines the procedure for training a PyTorch model. It does not replicate any study and does not attempt to achieve state-of-the-art results in CIFAR-100 classification." 
9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "from copy import deepcopy\n", 18 | "from datetime import datetime\n", 19 | "from os import makedirs\n", 20 | "from os.path import join, isfile, isdir\n", 21 | "\n", 22 | "import psutil\n", 23 | "import torch\n", 24 | "import torch.nn as nn\n", 25 | "import torch.optim as optim\n", 26 | "\n", 27 | "from efficientnet import EfficientNet\n", 28 | "from sklearn.metrics import accuracy_score\n", 29 | "from torch.optim import lr_scheduler\n", 30 | "from torch.utils.data import DataLoader\n", 31 | "from torchvision import datasets, transforms\n", 32 | "\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "%matplotlib inline" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "def train_model(model, dataloader, device, criterion, optimizer):\n", 44 | " model.train()\n", 45 | " for xb, yb in dataloader:\n", 46 | " xb, yb = xb.to(device), yb.to(device)\n", 47 | " optimizer.zero_grad()\n", 48 | " out = model(xb)\n", 49 | " if out.size(1) == 1:\n", 50 | " # regression, squeeze output of shape [N,1] to [N]\n", 51 | " out = torch.squeeze(out, 1)\n", 52 | " loss = criterion(out, yb)\n", 53 | " loss.backward()\n", 54 | " optimizer.step()\n", 55 | "\n", 56 | "\n", 57 | "def eval_model(model, dataloader, device, criterion=None):\n", 58 | " loss_value = []\n", 59 | " y_pred = []\n", 60 | " y_true = []\n", 61 | "\n", 62 | " model.eval()\n", 63 | " with torch.no_grad():\n", 64 | " for xb, yb in dataloader:\n", 65 | " xb, yb = xb.to(device), yb.to(device)\n", 66 | " out = model(xb)\n", 67 | " if out.size(1) == 1:\n", 68 | " # regression, squeeze output of shape [N,1] to [N]\n", 69 | " out = torch.squeeze(out, 1)\n", 70 | "\n", 71 | " if criterion is not None:\n", 72 | " loss = criterion(out, yb)\n", 73 | " loss_value.append(loss.item())\n", 74 | "\n", 75 | " 
y_pred.append(out.detach().cpu())\n", 76 | " y_true.append(yb.detach().cpu())\n", 77 | "\n", 78 | " if criterion is not None:\n", 79 | " loss_value = sum(loss_value) / len(loss_value)\n", 80 | " return torch.cat(y_pred), torch.cat(y_true), loss_value\n", 81 | " else:\n", 82 | " return torch.cat(y_pred), torch.cat(y_true)\n", 83 | "\n", 84 | "\n", 85 | "def run_experiment(dl_train,\n", 86 | " dl_train_val,\n", 87 | " dl_validation,\n", 88 | " model,\n", 89 | " optimizer,\n", 90 | " criterion,\n", 91 | " device,\n", 92 | " max_epoch,\n", 93 | " metric_fn,\n", 94 | " init_epoch=0,\n", 95 | " scheduler=None,\n", 96 | " load_path=None,\n", 97 | " save_path=None,\n", 98 | " early_stopping=None,\n", 99 | " ):\n", 100 | " results = {\n", 101 | " \"train_loss\": [],\n", 102 | " \"valid_loss\": [],\n", 103 | " \"train_met\": [],\n", 104 | " \"valid_met\": [],\n", 105 | " \"state_dict\": None,\n", 106 | " }\n", 107 | "\n", 108 | " best_validation_metric = .0\n", 109 | " model_best_state_dict = None\n", 110 | " no_score_improvement = 0\n", 111 | " experiment_start = datetime.now()\n", 112 | "\n", 113 | " if load_path is not None:\n", 114 | " # load full experiment state to continue experiment\n", 115 | " load_path = join(load_path, \"full_state.pth\")\n", 116 | " if not isfile(load_path):\n", 117 | " raise ValueError(\"Checkpoint file {} does not exist\".format(load_path))\n", 118 | "\n", 119 | " checkpoint = torch.load(load_path)\n", 120 | "\n", 121 | " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", 122 | "\n", 123 | " model_best_state_dict = checkpoint['model_best_state_dict']\n", 124 | " model.load_state_dict(checkpoint['model_curr_state_dict'])\n", 125 | "\n", 126 | " init_epoch = checkpoint['epoch']\n", 127 | " best_validation_metric = checkpoint['best_validation_metric']\n", 128 | "\n", 129 | " if scheduler is not None:\n", 130 | " scheduler.load_state_dict(checkpoint[\"scheduler_state_dict\"])\n", 131 | " print(\"Successfully loaded checkpoint.\")\n", 
132 | "\n", 133 | " s = \"Epoch/Max | Loss: Train / Validation | Metric: Train / Validation | Epoch time\"\n", 134 | " print(s)\n", 135 | "\n", 136 | " if save_path is not None and not isdir(save_path):\n", 137 | " makedirs(save_path)\n", 138 | " for epoch in range(init_epoch, max_epoch):\n", 139 | " now = datetime.now()\n", 140 | " train_model(model=model,\n", 141 | " dataloader=dl_train,\n", 142 | " device=device,\n", 143 | " criterion=criterion,\n", 144 | " optimizer=optimizer, )\n", 145 | "\n", 146 | " # evaluate subset of train set (in eval mode)\n", 147 | " train_val_results = eval_model(model=model,\n", 148 | " dataloader=dl_train_val,\n", 149 | " device=device,\n", 150 | " criterion=criterion, )\n", 151 | " train_y_pred, train_y_true, train_loss = train_val_results\n", 152 | " train_metric = metric_fn(train_y_pred, train_y_true)\n", 153 | " results[\"train_loss\"].append(train_loss)\n", 154 | " results[\"train_met\"].append(train_metric)\n", 155 | "\n", 156 | " # evaluate validation subset\n", 157 | " valid_results = eval_model(model=model,\n", 158 | " dataloader=dl_validation,\n", 159 | " device=device,\n", 160 | " criterion=criterion, )\n", 161 | " valid_y_pred, valid_y_true, valid_loss = valid_results\n", 162 | " validation_metric = metric_fn(valid_y_pred, valid_y_true)\n", 163 | " results[\"valid_loss\"].append(valid_loss)\n", 164 | " results[\"valid_met\"].append(validation_metric)\n", 165 | "\n", 166 | " # check if validation score is improved\n", 167 | " if validation_metric > best_validation_metric:\n", 168 | " model_best_state_dict = deepcopy(model.state_dict())\n", 169 | " best_validation_metric = validation_metric\n", 170 | " # reset early stopping counter\n", 171 | " no_score_improvement = 0\n", 172 | " # save best model weights\n", 173 | " if save_path is not None:\n", 174 | " torch.save(model_best_state_dict, join(save_path, \"best_weights.pth\"))\n", 175 | " else:\n", 176 | " no_score_improvement += 1\n", 177 | " if early_stopping is not None 
and no_score_improvement >= early_stopping:\n", 178 | " print(\"Early stopping at epoch %d\" % epoch)\n", 179 | " break\n", 180 | "\n", 181 | " if scheduler is not None:\n", 182 | " scheduler.step()  # MultiStepLR takes no metric argument (pass validation_metric only to ReduceLROnPlateau)\n", 183 | "\n", 184 | " if save_path is not None:\n", 185 | " # (optional) save model state dict at end of each epoch\n", 186 | " # torch.save(model.state_dict(), join(save_path, \"model_state_{}.pth\".format(epoch)))\n", 187 | "\n", 188 | " # save full experiment state at the end of each epoch\n", 189 | " checkpoint = {\n", 190 | " 'epoch': epoch + 1,\n", 191 | " 'model_curr_state_dict': model.state_dict(),\n", 192 | " 'model_best_state_dict': model_best_state_dict,\n", 193 | " 'optimizer_state_dict': optimizer.state_dict(),\n", 194 | " 'scheduler_state_dict': None if scheduler is None else scheduler.state_dict(),\n", 195 | " 'no_score_improvement': no_score_improvement,\n", 196 | " 'best_validation_metric': best_validation_metric,\n", 197 | " }\n", 198 | " torch.save(checkpoint, join(save_path, \"full_state.pth\"))\n", 199 | "\n", 200 | " s = \"{:>5}/{} | Loss: {:.4f} / {:.4f}\".format(epoch, max_epoch, train_loss, valid_loss)\n", 201 | " s += \" | Metric: {:.4f} / {:.4f}\".format(train_metric, validation_metric)\n", 202 | " s += \" | +{}\".format(datetime.now() - now)\n", 203 | " print(s)\n", 204 | "\n", 205 | " print(\"Experiment time: {}\".format(datetime.now() - experiment_start))\n", 206 | " return results" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 3, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "model_index = 0 # i.e. 
EfficientNet-B{model_index}\n", 216 | "batch_size = 128\n", 217 | "max_epoch = 100\n", 218 | "n_classes = 100\n", 219 | "pretrained=False\n", 220 | "num_workers = psutil.cpu_count()\n", 221 | "img_size = 32\n", 222 | "\n", 223 | "def metric_fn(y_pred, y_true):\n", 224 | " _, y_pred = torch.max(y_pred, 1)\n", 225 | " return accuracy_score(y_pred, y_true)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 4, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "Files already downloaded and verified\n", 238 | "Files already downloaded and verified\n", 239 | "Files already downloaded and verified\n" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "transform_train = transforms.Compose([\n", 245 | " transforms.RandomCrop(img_size, padding=4),\n", 246 | " transforms.RandomHorizontalFlip(),\n", 247 | " transforms.ToTensor(),\n", 248 | " transforms.Normalize(mean=[0.507075, 0.48655024, 0.44091907],\n", 249 | " std=[0.26733398, 0.25643876, 0.2761503]),\n", 250 | "])\n", 251 | "\n", 252 | "transform_validation = transforms.Compose([\n", 253 | " transforms.ToTensor(),\n", 254 | " transforms.Normalize(mean=[0.5070754, 0.48655024, 0.44091907],\n", 255 | " std=[0.26733398, 0.25643876, 0.2761503]),\n", 256 | "])\n", 257 | "\n", 258 | "dataset_train = datasets.CIFAR100(root='./data',\n", 259 | " train=True,\n", 260 | " download=True,\n", 261 | " transform=transform_train,\n", 262 | " )\n", 263 | "dataset_train_val = datasets.CIFAR100(root='./data',\n", 264 | " train=True,\n", 265 | " download=True,\n", 266 | " transform=transform_validation,\n", 267 | " )\n", 268 | "dataset_validation = datasets.CIFAR100(root='./data',\n", 269 | " train=False,\n", 270 | " download=True,\n", 271 | " transform=transform_validation,\n", 272 | " )\n", 273 | "\n", 274 | "dataloader_train = DataLoader(dataset_train,\n", 275 | " batch_size=batch_size,\n", 276 | " shuffle=True,\n", 277 | " 
num_workers=num_workers,\n", 278 | " )\n", 279 | "dataloader_train_val = DataLoader(dataset_train_val,\n", 280 | " batch_size=batch_size,\n", 281 | " shuffle=False,\n", 282 | " num_workers=num_workers,\n", 283 | " )\n", 284 | "dataloader_validation = DataLoader(dataset_validation,\n", 285 | " batch_size=batch_size,\n", 286 | " shuffle=False,\n", 287 | " num_workers=num_workers,\n", 288 | " )\n", 289 | "\n", 290 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 291 | "\n", 292 | "model = EfficientNet(b=model_index,\n", 293 | " in_spatial_shape=img_size,\n", 294 | " n_classes=n_classes,\n", 295 | " pretrained=pretrained,\n", 296 | " )\n", 297 | "model.to(device)\n", 298 | "\n", 299 | "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n", 300 | "criterion = nn.CrossEntropyLoss()\n", 301 | "scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30,60,90], gamma=0.1)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 5, 307 | "metadata": {}, 308 | "outputs": [ 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | "Epoch/Max | Loss: Train / Validation | Metric: Train / Validation | Epoch time\n", 314 | " 0/100 | Loss: 4.6407 / 4.6411 | Metric: 0.0141 / 0.0145 | +0:00:25.398279\n", 315 | " 1/100 | Loss: 3.4596 / 3.4839 | Metric: 0.1654 / 0.1632 | +0:00:25.511929\n", 316 | " 2/100 | Loss: 3.2156 / 3.2652 | Metric: 0.2042 / 0.1926 | +0:00:25.510012\n", 317 | " 3/100 | Loss: 2.9651 / 3.0327 | Metric: 0.2499 / 0.2428 | +0:00:25.776397\n", 318 | " 4/100 | Loss: 2.8429 / 2.9313 | Metric: 0.2798 / 0.2601 | +0:00:25.733490\n", 319 | " 5/100 | Loss: 2.6884 / 2.7874 | Metric: 0.3061 / 0.2876 | +0:00:25.556902\n", 320 | " 6/100 | Loss: 2.6452 / 2.7454 | Metric: 0.3214 / 0.2998 | +0:00:25.513321\n", 321 | " 7/100 | Loss: 2.4652 / 2.6077 | Metric: 0.3521 / 0.3251 | +0:00:25.530684\n", 322 | " 8/100 | Loss: 2.3831 / 2.5385 | Metric: 0.3670 / 0.3369 | 
+0:00:25.537902\n", 323 | " 9/100 | Loss: 2.3387 / 2.5089 | Metric: 0.3823 / 0.3513 | +0:00:25.540766\n", 324 | " 10/100 | Loss: 2.2174 / 2.3946 | Metric: 0.4019 / 0.3645 | +0:00:25.510699\n", 325 | " 11/100 | Loss: 2.1654 / 2.3559 | Metric: 0.4215 / 0.3825 | +0:00:25.508304\n", 326 | " 12/100 | Loss: 2.0757 / 2.2894 | Metric: 0.4363 / 0.3941 | +0:00:25.623292\n", 327 | " 13/100 | Loss: 2.0408 / 2.2736 | Metric: 0.4533 / 0.4051 | +0:00:25.536239\n", 328 | " 14/100 | Loss: 1.9758 / 2.2310 | Metric: 0.4662 / 0.4121 | +0:00:25.539638\n", 329 | " 15/100 | Loss: 1.8560 / 2.1301 | Metric: 0.4842 / 0.4336 | +0:00:25.555950\n", 330 | " 16/100 | Loss: 1.8541 / 2.1502 | Metric: 0.4883 / 0.4267 | +0:00:25.516265\n", 331 | " 17/100 | Loss: 1.7353 / 2.0693 | Metric: 0.5102 / 0.4452 | +0:00:25.490337\n", 332 | " 18/100 | Loss: 1.7392 / 2.0899 | Metric: 0.5121 / 0.4434 | +0:00:25.516422\n", 333 | " 19/100 | Loss: 1.7272 / 2.1080 | Metric: 0.5171 / 0.4391 | +0:00:25.533143\n", 334 | " 20/100 | Loss: 1.7886 / 2.1671 | Metric: 0.5213 / 0.4404 | +0:00:25.480690\n", 335 | " 21/100 | Loss: 1.6183 / 2.0440 | Metric: 0.5479 / 0.4627 | +0:00:25.495024\n", 336 | " 22/100 | Loss: 1.6036 / 2.0365 | Metric: 0.5473 / 0.4624 | +0:00:25.475565\n", 337 | " 23/100 | Loss: 1.6070 / 2.0566 | Metric: 0.5441 / 0.4549 | +0:00:25.492129\n", 338 | " 24/100 | Loss: 1.8741 / 2.3108 | Metric: 0.5567 / 0.4611 | +0:00:25.520776\n", 339 | " 25/100 | Loss: 1.4591 / 1.9652 | Metric: 0.5847 / 0.4805 | +0:00:25.484019\n", 340 | " 26/100 | Loss: 1.4893 / 2.0240 | Metric: 0.5823 / 0.4774 | +0:00:25.546081\n", 341 | " 27/100 | Loss: 1.3791 / 1.9233 | Metric: 0.6107 / 0.4907 | +0:00:25.530611\n", 342 | " 28/100 | Loss: 1.3882 / 1.9783 | Metric: 0.5952 / 0.4772 | +0:00:25.555992\n", 343 | " 29/100 | Loss: 1.4314 / 2.0812 | Metric: 0.6025 / 0.4770 | +0:00:25.513177\n", 344 | " 30/100 | Loss: 1.4443 / 2.0862 | Metric: 0.6109 / 0.4828 | +0:00:25.481391\n", 345 | " 31/100 | Loss: 1.3089 / 1.9888 | Metric: 0.6425 / 0.4958 | 
+0:00:25.498638\n", 346 | " 32/100 | Loss: 1.6406 / 2.3167 | Metric: 0.6044 / 0.4676 | +0:00:25.481035\n", 347 | " 33/100 | Loss: 1.2996 / 2.0643 | Metric: 0.6382 / 0.4974 | +0:00:25.506845\n", 348 | " 34/100 | Loss: 1.1233 / 1.8935 | Metric: 0.6714 / 0.5125 | +0:00:25.511562\n", 349 | " 35/100 | Loss: 1.1297 / 1.9093 | Metric: 0.6669 / 0.5043 | +0:00:25.527171\n", 350 | " 36/100 | Loss: 1.0770 / 1.9064 | Metric: 0.6800 / 0.5069 | +0:00:25.521235\n", 351 | " 37/100 | Loss: 1.1357 / 1.9935 | Metric: 0.6757 / 0.4925 | +0:00:25.507412\n", 352 | " 38/100 | Loss: 1.0077 / 1.8913 | Metric: 0.6930 / 0.5111 | +0:00:25.524217\n", 353 | " 39/100 | Loss: 1.0088 / 1.9418 | Metric: 0.6951 / 0.5034 | +0:00:25.504317\n", 354 | " 40/100 | Loss: 0.9989 / 1.9648 | Metric: 0.6972 / 0.5018 | +0:00:25.521985\n", 355 | " 41/100 | Loss: 0.9202 / 1.8992 | Metric: 0.7220 / 0.5117 | +0:00:25.520785\n", 356 | " 42/100 | Loss: 0.9513 / 1.9598 | Metric: 0.7122 / 0.5062 | +0:00:25.476306\n", 357 | " 43/100 | Loss: 1.1995 / 2.2258 | Metric: 0.7097 / 0.4965 | +0:00:25.480849\n", 358 | " 44/100 | Loss: 0.8955 / 1.9741 | Metric: 0.7274 / 0.5104 | +0:00:25.500969\n", 359 | " 45/100 | Loss: 0.8821 / 1.9969 | Metric: 0.7390 / 0.5098 | +0:00:25.473217\n", 360 | " 46/100 | Loss: 0.8505 / 2.0071 | Metric: 0.7396 / 0.5051 | +0:00:25.514186\n", 361 | " 47/100 | Loss: 0.7809 / 1.9547 | Metric: 0.7605 / 0.5139 | +0:00:25.545920\n", 362 | " 48/100 | Loss: 0.7247 / 1.9603 | Metric: 0.7761 / 0.5194 | +0:00:25.488569\n", 363 | " 49/100 | Loss: 0.7574 / 2.0076 | Metric: 0.7688 / 0.5114 | +0:00:25.497039\n", 364 | " 50/100 | Loss: 1.0998 / 2.3453 | Metric: 0.7268 / 0.4854 | +0:00:26.040123\n", 365 | " 51/100 | Loss: 0.7521 / 2.0480 | Metric: 0.7694 / 0.5084 | +0:00:25.828677\n", 366 | " 52/100 | Loss: 1.5339 / 2.7158 | Metric: 0.6346 / 0.4254 | +0:00:25.815047\n", 367 | " 53/100 | Loss: 0.6663 / 2.0466 | Metric: 0.7918 / 0.5140 | +0:00:25.826483\n", 368 | " 54/100 | Loss: 0.6193 / 2.0372 | Metric: 0.8081 / 0.5146 
| +0:00:25.676460\n", 369 | " 55/100 | Loss: 0.6105 / 2.0812 | Metric: 0.8096 / 0.5125 | +0:00:25.495489\n", 370 | " 56/100 | Loss: 0.6303 / 2.1622 | Metric: 0.8019 / 0.5090 | +0:00:25.558402\n", 371 | " 57/100 | Loss: 0.6239 / 2.1130 | Metric: 0.8040 / 0.5154 | +0:00:25.693560\n", 372 | " 58/100 | Loss: 0.6373 / 2.1264 | Metric: 0.8026 / 0.5037 | +0:00:25.591672\n", 373 | " 59/100 | Loss: 0.5696 / 2.1770 | Metric: 0.8213 / 0.5064 | +0:00:25.515435\n", 374 | " 60/100 | Loss: 0.5271 / 2.1327 | Metric: 0.8331 / 0.5101 | +0:00:25.548699\n", 375 | " 61/100 | Loss: 0.4884 / 2.1617 | Metric: 0.8464 / 0.5099 | +0:00:25.492324\n", 376 | " 62/100 | Loss: 0.5484 / 2.2190 | Metric: 0.8235 / 0.5099 | +0:00:25.512234\n", 377 | " 63/100 | Loss: 0.4893 / 2.1722 | Metric: 0.8443 / 0.5133 | +0:00:25.485472\n", 378 | " 64/100 | Loss: 0.4431 / 2.1613 | Metric: 0.8610 / 0.5143 | +0:00:25.490351\n", 379 | " 65/100 | Loss: 0.6444 / 2.4219 | Metric: 0.7960 / 0.4858 | +0:00:25.497480\n", 380 | " 66/100 | Loss: 0.3970 / 2.2098 | Metric: 0.8751 / 0.5135 | +0:00:25.486181\n", 381 | " 67/100 | Loss: 0.4105 / 2.2758 | Metric: 0.8712 / 0.5100 | +0:00:25.534624\n", 382 | " 68/100 | Loss: 0.3828 / 2.2263 | Metric: 0.8791 / 0.5215 | +0:00:25.548678\n", 383 | " 69/100 | Loss: 0.3694 / 2.2231 | Metric: 0.8835 / 0.5190 | +0:00:25.494058\n", 384 | " 70/100 | Loss: 0.6148 / 2.4313 | Metric: 0.8047 / 0.4892 | +0:00:25.559696\n", 385 | " 71/100 | Loss: 0.3615 / 2.2320 | Metric: 0.8884 / 0.5169 | +0:00:25.486475\n", 386 | " 72/100 | Loss: 0.3077 / 2.2765 | Metric: 0.9041 / 0.5216 | +0:00:25.551599\n", 387 | " 73/100 | Loss: 0.3263 / 2.2852 | Metric: 0.8973 / 0.5174 | +0:00:25.585349\n", 388 | " 74/100 | Loss: 0.3595 / 2.3874 | Metric: 0.8844 / 0.5094 | +0:00:25.507216\n", 389 | " 75/100 | Loss: 0.3023 / 2.3316 | Metric: 0.9051 / 0.5171 | +0:00:25.500517\n", 390 | " 76/100 | Loss: 0.4149 / 2.3949 | Metric: 0.8791 / 0.5073 | +0:00:25.490881\n", 391 | " 77/100 | Loss: 0.2814 / 2.3683 | Metric: 0.9099 / 
0.5177 | +0:00:25.478628\n", 392 | " 78/100 | Loss: 0.2676 / 2.3558 | Metric: 0.9159 / 0.5198 | +0:00:25.454193\n", 393 | " 79/100 | Loss: 0.2627 / 2.3218 | Metric: 0.9180 / 0.5201 | +0:00:25.481055\n", 394 | " 80/100 | Loss: 0.2776 / 2.4139 | Metric: 0.9184 / 0.5189 | +0:00:25.497929\n", 395 | " 81/100 | Loss: 0.2523 / 2.3382 | Metric: 0.9194 / 0.5195 | +0:00:25.479852\n", 396 | " 82/100 | Loss: 0.3507 / 2.5150 | Metric: 0.8849 / 0.5029 | +0:00:25.473072\n", 397 | " 83/100 | Loss: 0.2899 / 2.4763 | Metric: 0.9064 / 0.5065 | +0:00:25.467117\n", 398 | " 84/100 | Loss: 0.5574 / 2.7482 | Metric: 0.8247 / 0.4775 | +0:00:25.505151\n", 399 | " 85/100 | Loss: 0.2233 / 2.4334 | Metric: 0.9305 / 0.5148 | +0:00:25.472094\n", 400 | " 86/100 | Loss: 0.7211 / 2.8693 | Metric: 0.8071 / 0.4584 | +0:00:25.507260\n", 401 | " 87/100 | Loss: 0.2264 / 2.4550 | Metric: 0.9299 / 0.5186 | +0:00:25.500985\n", 402 | " 88/100 | Loss: 0.1885 / 2.4443 | Metric: 0.9414 / 0.5235 | +0:00:25.486188\n", 403 | " 89/100 | Loss: 0.1809 / 2.4752 | Metric: 0.9438 / 0.5163 | +0:00:25.488561\n", 404 | " 90/100 | Loss: 0.1809 / 2.4432 | Metric: 0.9433 / 0.5203 | +0:00:25.483951\n", 405 | " 91/100 | Loss: 0.1726 / 2.4504 | Metric: 0.9472 / 0.5268 | +0:00:25.510927\n", 406 | " 92/100 | Loss: 0.2003 / 2.4957 | Metric: 0.9369 / 0.5235 | +0:00:25.476014\n", 407 | " 93/100 | Loss: 0.1826 / 2.4584 | Metric: 0.9435 / 0.5224 | +0:00:25.658069\n", 408 | " 94/100 | Loss: 0.2839 / 2.6165 | Metric: 0.9190 / 0.5104 | +0:00:25.498089\n", 409 | " 95/100 | Loss: 0.3082 / 2.6734 | Metric: 0.9001 / 0.4982 | +0:00:25.532489\n", 410 | " 96/100 | Loss: 0.2176 / 2.4900 | Metric: 0.9348 / 0.5201 | +0:00:25.512353\n", 411 | " 97/100 | Loss: 0.1912 / 2.4989 | Metric: 0.9411 / 0.5237 | +0:00:25.469654\n", 412 | " 98/100 | Loss: 0.1680 / 2.5038 | Metric: 0.9497 / 0.5241 | +0:00:25.611206\n", 413 | " 99/100 | Loss: 0.2187 / 2.5380 | Metric: 0.9370 / 0.5218 | +0:00:25.836507\n", 414 | "Experiment time: 0:42:33.825519\n" 415 | ] 416 | 
} 417 | ], 418 | "source": [ 419 | "exp_results = run_experiment(dl_train=dataloader_train,\n", 420 | " dl_train_val=dataloader_train_val,\n", 421 | " dl_validation=dataloader_validation,\n", 422 | " model=model,\n", 423 | " optimizer=optimizer,\n", 424 | " criterion=criterion,\n", 425 | " device=device,\n", 426 | " max_epoch=max_epoch,\n", 427 | " metric_fn=metric_fn,\n", 428 | " scheduler=scheduler,\n", 429 | " load_path=None,\n", 430 | " save_path=None,\n", 431 | " )" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 6, 437 | "metadata": {}, 438 | "outputs": [ 439 | { 440 | "data": { 441 | "image/png": "\n", 442 | "text/plain": [ 443 | "
" 444 | ] 445 | }, 446 | "metadata": { 447 | "needs_background": "light" 448 | }, 449 | "output_type": "display_data" 450 | } 451 | ], 452 | "source": [ 453 | "epochs = list(range(max_epoch))\n", 454 | "train_loss = exp_results[\"train_loss\"]\n", 455 | "valid_loss = exp_results[\"valid_loss\"]\n", 456 | "lines = plt.plot(epochs, train_loss, epochs, valid_loss)\n", 457 | "\n", 458 | "plt.legend(('Train', 'Validation'), loc='upper right')\n", 459 | "plt.title('Loss chart')\n", 460 | "plt.xlabel('Epochs')\n", 461 | "plt.ylabel('Loss')\n", 462 | "plt.grid(True)\n", 463 | "plt.show()" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": 7, 469 | "metadata": {}, 470 | "outputs": [ 471 | { 472 | "data": { 473 | "image/png": "\n", 474 | "text/plain": [ 475 | "
" 476 | ] 477 | }, 478 | "metadata": { 479 | "needs_background": "light" 480 | }, 481 | "output_type": "display_data" 482 | } 483 | ], 484 | "source": [ 485 | "epochs = list(range(max_epoch))\n", 486 | "train_metric = exp_results[\"train_met\"]\n", 487 | "valid_metric = exp_results[\"valid_met\"]\n", 488 | "lines = plt.plot(epochs, train_metric, epochs, valid_metric)\n", 489 | "\n", 490 | "plt.legend(('Train', 'Validation'), loc='upper left')\n", 491 | "plt.title('Accuracy chart')\n", 492 | "plt.xlabel('Epochs')\n", 493 | "plt.ylabel('Accuracy')\n", 494 | "plt.grid(True)\n", 495 | "plt.show()" 496 | ] 497 | } 498 | ], 499 | "metadata": { 500 | "kernelspec": { 501 | "display_name": "Python 3", 502 | "language": "python", 503 | "name": "python3" 504 | }, 505 | "language_info": { 506 | "codemirror_mode": { 507 | "name": "ipython", 508 | "version": 3 509 | }, 510 | "file_extension": ".py", 511 | "mimetype": "text/x-python", 512 | "name": "python", 513 | "nbconvert_exporter": "python", 514 | "pygments_lexer": "ipython3", 515 | "version": "3.6.7" 516 | } 517 | }, 518 | "nbformat": 4, 519 | "nbformat_minor": 4 520 | } 521 | -------------------------------------------------------------------------------- /efficientnet.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import collections.abc as container_abcs 7 | from torch.utils import model_zoo 8 | 9 | 10 | def _pair(x): 11 | if isinstance(x, container_abcs.Iterable): 12 | return x 13 | return (x, x) 14 | 15 | 16 | class SamePaddingConv2d(nn.Module): 17 | def __init__(self, 18 | in_spatial_shape, 19 | in_channels, 20 | out_channels, 21 | kernel_size, 22 | stride, 23 | dilation=1, 24 | enforce_in_spatial_shape=False, 25 | **kwargs): 26 | super(SamePaddingConv2d, self).__init__() 27 | 28 | self._in_spatial_shape = _pair(in_spatial_shape) 29 | # e.g. 
throw exception if input spatial shape does not match in_spatial_shape 30 | # when calling self.forward() 31 | self.enforce_in_spatial_shape = enforce_in_spatial_shape 32 | kernel_size = _pair(kernel_size) 33 | stride = _pair(stride) 34 | dilation = _pair(dilation) 35 | 36 | in_height, in_width = self._in_spatial_shape 37 | filter_height, filter_width = kernel_size 38 | stride_height, stride_width = stride 39 | dilation_height, dilation_width = dilation 40 | 41 | out_height = int(ceil(float(in_height) / float(stride_height))) 42 | out_width = int(ceil(float(in_width) / float(stride_width))) 43 | 44 | pad_along_height = max((out_height - 1) * stride_height + 45 | filter_height + (filter_height - 1) * (dilation_height - 1) - in_height, 0) 46 | pad_along_width = max((out_width - 1) * stride_width + 47 | filter_width + (filter_width - 1) * (dilation_width - 1) - in_width, 0) 48 | 49 | pad_top = pad_along_height // 2 50 | pad_bottom = pad_along_height - pad_top 51 | pad_left = pad_along_width // 2 52 | pad_right = pad_along_width - pad_left 53 | 54 | paddings = (pad_left, pad_right, pad_top, pad_bottom) 55 | if any(p > 0 for p in paddings): 56 | self.zero_pad = nn.ZeroPad2d(paddings) 57 | else: 58 | self.zero_pad = None 59 | self.conv = nn.Conv2d(in_channels=in_channels, 60 | out_channels=out_channels, 61 | kernel_size=kernel_size, 62 | stride=stride, 63 | dilation=dilation, 64 | **kwargs) 65 | 66 | self._out_spatial_shape = (out_height, out_width) 67 | 68 | @property 69 | def in_spatial_shape(self): 70 | return self._in_spatial_shape 71 | 72 | @property 73 | def out_spatial_shape(self): 74 | return self._out_spatial_shape 75 | 76 | @property 77 | def in_channels(self): 78 | return self.conv.in_channels 79 | 80 | @property 81 | def out_channels(self): 82 | return self.conv.out_channels 83 | 84 | def check_spatial_shape(self, x): 85 | if x.size(2) != self.in_spatial_shape[0] or \ 86 | x.size(3) != self.in_spatial_shape[1]: 87 | raise ValueError( 88 | "Expected input
spatial shape {}, got {} instead".format(self.in_spatial_shape, 89 | x.shape[2:])) 90 | 91 | def forward(self, x): 92 | if self.enforce_in_spatial_shape: 93 | self.check_spatial_shape(x) 94 | if self.zero_pad is not None: 95 | x = self.zero_pad(x) 96 | x = self.conv(x) 97 | return x 98 | 99 | 100 | class ConvBNAct(nn.Module): 101 | def __init__(self, 102 | out_channels, 103 | activation=None, 104 | bn_epsilon=None, 105 | bn_momentum=None, 106 | same_padding=False, 107 | **kwargs): 108 | super(ConvBNAct, self).__init__() 109 | 110 | _conv_cls = SamePaddingConv2d if same_padding else nn.Conv2d 111 | self.conv = _conv_cls(out_channels=out_channels, **kwargs) 112 | 113 | bn_kwargs = {} 114 | if bn_epsilon is not None: 115 | bn_kwargs["eps"] = bn_epsilon 116 | if bn_momentum is not None: 117 | bn_kwargs["momentum"] = bn_momentum 118 | 119 | self.bn = nn.BatchNorm2d(out_channels, **bn_kwargs) 120 | self.activation = activation 121 | 122 | @property 123 | def in_spatial_shape(self): 124 | if isinstance(self.conv, SamePaddingConv2d): 125 | return self.conv.in_spatial_shape 126 | else: 127 | return None 128 | 129 | @property 130 | def out_spatial_shape(self): 131 | if isinstance(self.conv, SamePaddingConv2d): 132 | return self.conv.out_spatial_shape 133 | else: 134 | return None 135 | 136 | @property 137 | def in_channels(self): 138 | return self.conv.in_channels 139 | 140 | @property 141 | def out_channels(self): 142 | return self.conv.out_channels 143 | 144 | def forward(self, x): 145 | x = self.conv(x) 146 | x = self.bn(x) 147 | if self.activation is not None: 148 | x = self.activation(x) 149 | return x 150 | 151 | 152 | class Swish(nn.Module): 153 | def __init__(self, 154 | beta=1.0, 155 | beta_learnable=False): 156 | super(Swish, self).__init__() 157 | 158 | if beta == 1.0 and not beta_learnable: 159 | self._op = self.simple_swish 160 | else: 161 | self.beta = nn.Parameter(torch.full([1], beta), 162 | requires_grad=beta_learnable) 163 | self._op = self.advanced_swish 
164 | 165 | def simple_swish(self, x): 166 | return x * torch.sigmoid(x) 167 | 168 | def advanced_swish(self, x): 169 | return x * torch.sigmoid(self.beta * x) 170 | 171 | def forward(self, x): 172 | return self._op(x) 173 | 174 | 175 | class DropConnect(nn.Module): 176 | def __init__(self, rate=0.5): 177 | super(DropConnect, self).__init__() 178 | self.keep_prob = None 179 | self.set_rate(rate) 180 | 181 | def set_rate(self, rate): 182 | if not 0 <= rate < 1: 183 | raise ValueError("rate must be 0<=rate<1, got {} instead".format(rate)) 184 | self.keep_prob = 1 - rate 185 | 186 | def forward(self, x): 187 | if self.training: 188 | random_tensor = self.keep_prob + torch.rand([x.size(0), 1, 1, 1], 189 | dtype=x.dtype, 190 | device=x.device) 191 | binary_tensor = torch.floor(random_tensor) 192 | return torch.mul(torch.div(x, self.keep_prob), binary_tensor) 193 | else: 194 | return x 195 | 196 | 197 | class SqueezeExcitate(nn.Module): 198 | def __init__(self, 199 | in_channels, 200 | se_size, 201 | activation=None): 202 | super(SqueezeExcitate, self).__init__() 203 | self.dim_reduce = nn.Conv2d(in_channels=in_channels, 204 | out_channels=se_size, 205 | kernel_size=1) 206 | self.dim_restore = nn.Conv2d(in_channels=se_size, 207 | out_channels=in_channels, 208 | kernel_size=1) 209 | self.activation = F.relu if activation is None else activation 210 | 211 | def forward(self, x): 212 | inp = x 213 | x = F.adaptive_avg_pool2d(x, (1, 1)) 214 | x = self.dim_reduce(x) 215 | x = self.activation(x) 216 | x = self.dim_restore(x) 217 | x = torch.sigmoid(x) 218 | return torch.mul(inp, x) 219 | 220 | 221 | class MBConvBlock(nn.Module): 222 | def __init__(self, 223 | in_spatial_shape, 224 | in_channels, 225 | out_channels, 226 | kernel_size, 227 | stride, 228 | expansion_factor, 229 | activation, 230 | bn_epsilon=None, 231 | bn_momentum=None, 232 | se_size=None, 233 | drop_connect_rate=None, 234 | bias=False): 235 | """ 236 | Initialize new MBConv block 237 | :param in_spatial_shape: 
image shape, e.g. tuple [height, width] or int size for [size, size] 238 | :param in_channels: number of input channels 239 | :param out_channels: number of output channels 240 | :param kernel_size: kernel size for depth-wise convolution 241 | :param stride: stride for depth-wise convolution 242 | :param expansion_factor: expansion factor 243 | :param bn_epsilon: batch normalization epsilon 244 | :param bn_momentum: batch normalization momentum 245 | :param se_size: number of features in reduction layer of Squeeze-and-Excitate layer 246 | :param activation: activation function 247 | :param drop_connect_rate: DropConnect rate 248 | :param bias: enable bias in convolution operations 249 | """ 250 | super(MBConvBlock, self).__init__() 251 | 252 | if se_size is not None and se_size < 1: 253 | raise ValueError("se_size must be >=1, got {} instead".format(se_size)) 254 | 255 | if drop_connect_rate is not None and not 0 <= drop_connect_rate < 1: 256 | raise ValueError("drop_connect_rate must be in range [0,1), got {} instead".format(drop_connect_rate)) 257 | 258 | if not (isinstance(expansion_factor, int) and expansion_factor >= 1): 259 | raise ValueError("expansion factor must be int and >=1, got {} instead".format(expansion_factor)) 260 | 261 | exp_channels = in_channels * expansion_factor 262 | kernel_size = _pair(kernel_size) 263 | stride = _pair(stride) 264 | 265 | self.activation = activation 266 | 267 | # expansion convolution 268 | if expansion_factor != 1: 269 | self.expand_conv = ConvBNAct(in_channels=in_channels, 270 | out_channels=exp_channels, 271 | kernel_size=(1, 1), 272 | bias=bias, 273 | activation=self.activation, 274 | bn_epsilon=bn_epsilon, 275 | bn_momentum=bn_momentum) 276 | else: 277 | self.expand_conv = None 278 | 279 | # depth-wise convolution 280 | self.dp_conv = ConvBNAct(in_spatial_shape=in_spatial_shape, 281 | in_channels=exp_channels, 282 | out_channels=exp_channels, 283 | kernel_size=kernel_size, 284 | stride=stride, 285 | 
groups=exp_channels, 286 | bias=bias, 287 | activation=self.activation, 288 | same_padding=True, 289 | bn_epsilon=bn_epsilon, 290 | bn_momentum=bn_momentum) 291 | 292 | if se_size is not None: 293 | self.se = SqueezeExcitate(exp_channels, 294 | se_size, 295 | activation=self.activation) 296 | else: 297 | self.se = None 298 | 299 | if drop_connect_rate is not None: 300 | self.drop_connect = DropConnect(drop_connect_rate) 301 | else: 302 | self.drop_connect = None 303 | 304 | if in_channels == out_channels and all(s == 1 for s in stride): 305 | self.skip_enabled = True 306 | else: 307 | self.skip_enabled = False 308 | 309 | # projection convolution 310 | self.project_conv = ConvBNAct(in_channels=exp_channels, 311 | out_channels=out_channels, 312 | kernel_size=(1, 1), 313 | bias=bias, 314 | activation=None, 315 | bn_epsilon=bn_epsilon, 316 | bn_momentum=bn_momentum) 317 | 318 | @property 319 | def in_spatial_shape(self): 320 | return self.dp_conv.in_spatial_shape 321 | 322 | @property 323 | def out_spatial_shape(self): 324 | return self.dp_conv.out_spatial_shape 325 | 326 | @property 327 | def in_channels(self): 328 | if self.expand_conv is not None: 329 | return self.expand_conv.in_channels 330 | else: 331 | return self.dp_conv.in_channels 332 | 333 | @property 334 | def out_channels(self): 335 | return self.project_conv.out_channels 336 | 337 | def forward(self, x): 338 | inp = x 339 | 340 | if self.expand_conv is not None: 341 | # expansion convolution applied only if expansion ratio > 1 342 | x = self.expand_conv(x) 343 | 344 | # depth-wise convolution 345 | x = self.dp_conv(x) 346 | 347 | # squeeze-and-excitate 348 | if self.se is not None: 349 | x = self.se(x) 350 | 351 | # projection convolution 352 | x = self.project_conv(x) 353 | 354 | if self.skip_enabled: 355 | # drop-connect applied only if skip connection enabled 356 | if self.drop_connect is not None: 357 | x = self.drop_connect(x) 358 | x = x + inp 359 | return x 360 | 361 | 362 | class 
EnetStage(nn.Module): 363 | def __init__(self, 364 | num_layers, 365 | in_spatial_shape, 366 | in_channels, 367 | out_channels, 368 | stride, 369 | se_ratio, 370 | drop_connect_rates, 371 | **kwargs): 372 | super(EnetStage, self).__init__() 373 | 374 | if not (isinstance(num_layers, int) and num_layers >= 1): 375 | raise ValueError("num_layers must be int and >=1, got {} instead".format(num_layers)) 376 | 377 | if not (isinstance(drop_connect_rates, container_abcs.Iterable) and 378 | len(drop_connect_rates) == num_layers): 379 | raise ValueError("drop_connect_rates must be iterable of " 380 | "length num_layers ({}), got {} instead".format(num_layers, drop_connect_rates)) 381 | 382 | self.num_layers = num_layers 383 | self.layers = nn.ModuleList() 384 | spatial_shape = in_spatial_shape 385 | for i in range(self.num_layers): 386 | se_size = max(1, in_channels // se_ratio) 387 | layer = MBConvBlock(in_spatial_shape=spatial_shape, 388 | in_channels=in_channels, 389 | out_channels=out_channels, 390 | stride=stride, 391 | se_size=se_size, 392 | drop_connect_rate=drop_connect_rates[i], 393 | **kwargs) 394 | self.layers.append(layer) 395 | spatial_shape = layer.out_spatial_shape 396 | # remaining MBConv blocks have stride 1 and in_channels=out_channels 397 | stride = 1 398 | in_channels = out_channels 399 | 400 | @property 401 | def in_spatial_shape(self): 402 | return self.layers[0].in_spatial_shape 403 | 404 | @property 405 | def out_spatial_shape(self): 406 | return self.layers[-1].out_spatial_shape 407 | 408 | @property 409 | def in_channels(self): 410 | return self.layers[0].in_channels 411 | 412 | @property 413 | def out_channels(self): 414 | return self.layers[-1].out_channels 415 | 416 | def forward(self, x): 417 | for layer in self.layers: 418 | x = layer(x) 419 | return x 420 | 421 | 422 | def round_filters(filters, width_coefficient, depth_divisor=8): 423 | """Round number of filters based on depth multiplier.""" 424 | min_depth = depth_divisor 425 | 426 | 
filters *= width_coefficient 427 | new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor) 428 | # Make sure that round down does not go down by more than 10%. 429 | if new_filters < 0.9 * filters: 430 | new_filters += depth_divisor 431 | return int(new_filters) 432 | 433 | 434 | def round_repeats(repeats, depth_coefficient): 435 | """Round number of block repeats based on depth multiplier.""" 436 | return int(ceil(depth_coefficient * repeats)) 437 | 438 | 439 | class EfficientNet(nn.Module): 440 | # (width_coefficient, depth_coefficient, dropout_rate, in_spatial_shape) 441 | coefficients = [ 442 | (1.0, 1.0, 0.2, 224), 443 | (1.0, 1.1, 0.2, 240), 444 | (1.1, 1.2, 0.3, 260), 445 | (1.2, 1.4, 0.3, 300), 446 | (1.4, 1.8, 0.4, 380), 447 | (1.6, 2.2, 0.4, 456), 448 | (1.8, 2.6, 0.5, 528), 449 | (2.0, 3.1, 0.5, 600), 450 | ] 451 | 452 | # block_repeat, kernel_size, stride, expansion_factor, input_channels, output_channels, se_ratio 453 | stage_args = [ 454 | [1, 3, 1, 1, 32, 16, 4], 455 | [2, 3, 2, 6, 16, 24, 4], 456 | [2, 5, 2, 6, 24, 40, 4], 457 | [3, 3, 2, 6, 40, 80, 4], 458 | [3, 5, 1, 6, 80, 112, 4], 459 | [4, 5, 2, 6, 112, 192, 4], 460 | [1, 3, 1, 6, 192, 320, 4], 461 | ] 462 | 463 | state_dict_urls = [ 464 | "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmliYV9HaE5PWWVEbXVMd3c/root/content", 465 | "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlicV9HaE5PWWVEbXVMd3c/root/content", 466 | "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmliNl9HaE5PWWVEbXVMd3c/root/content", 467 | "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmljS19HaE5PWWVEbXVMd3c/root/content", 468 | "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmljYV9HaE5PWWVEbXVMd3c/root/content", 469 | 
"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmljcV9HaE5PWWVEbXVMd3c/root/content", 470 | "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmljNl9HaE5PWWVEbXVMd3c/root/content", 471 | "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlkS19HaE5PWWVEbXVMd3c/root/content", 472 | ] 473 | 474 | dict_names = [ 475 | 'efficientnet-b0-d86f8792.pth', 476 | 'efficientnet-b1-82896633.pth', 477 | 'efficientnet-b2-e4b93854.pth', 478 | 'efficientnet-b3-3b9ca610.pth', 479 | 'efficientnet-b4-24436ca5.pth', 480 | 'efficientnet-b5-d8e577e8.pth', 481 | 'efficientnet-b6-f20845c7.pth', 482 | 'efficientnet-b7-86e8e374.pth' 483 | ] 484 | 485 | def __init__(self, 486 | b, 487 | in_channels=3, 488 | n_classes=1000, 489 | in_spatial_shape=None, 490 | activation=Swish(), 491 | bias=False, 492 | drop_connect_rate=0.2, 493 | dropout_rate=None, 494 | bn_epsilon=1e-3, 495 | bn_momentum=0.01, 496 | pretrained=False, 497 | progress=False): 498 | """ 499 | Initialize new EfficientNet model 500 | :param b: model index, i.e. 
0 for EfficientNet-B0 501 | :param in_channels: number of input channels 502 | :param n_classes: number of output classes 503 | :param in_spatial_shape: input image shape 504 | :param activation: activation function 505 | :param bias: enable bias in convolution operations 506 | :param drop_connect_rate: DropConnect rate 507 | :param dropout_rate: dropout rate, this will override default rate for each model 508 | :param bn_epsilon: batch normalization epsilon 509 | :param bn_momentum: batch normalization momentum 510 | :param pretrained: initialize model with weights pre-trained on ImageNet 511 | :param progress: show progress when downloading pre-trained weights 512 | """ 513 | 514 | super(EfficientNet, self).__init__() 515 | 516 | # verify all parameters 517 | EfficientNet.check_init_params(b, 518 | in_channels, 519 | n_classes, 520 | in_spatial_shape, 521 | activation, 522 | bias, 523 | drop_connect_rate, 524 | dropout_rate, 525 | bn_epsilon, 526 | bn_momentum, 527 | pretrained, 528 | progress) 529 | 530 | self.b = b 531 | self.in_channels = in_channels 532 | self.activation = activation 533 | self.drop_connect_rate = drop_connect_rate 534 | self._override_dropout_rate = dropout_rate 535 | 536 | width_coefficient, _, _, spatial_shape = EfficientNet.coefficients[self.b] 537 | 538 | if in_spatial_shape is not None: 539 | self.in_spatial_shape = _pair(in_spatial_shape) 540 | else: 541 | self.in_spatial_shape = _pair(spatial_shape) 542 | 543 | # initial convolution 544 | init_conv_out_channels = round_filters(32, width_coefficient) 545 | self.init_conv = ConvBNAct(in_spatial_shape=self.in_spatial_shape, 546 | in_channels=self.in_channels, 547 | out_channels=init_conv_out_channels, 548 | kernel_size=(3, 3), 549 | stride=(2, 2), 550 | bias=bias, 551 | activation=self.activation, 552 | same_padding=True, 553 | bn_epsilon=bn_epsilon, 554 | bn_momentum=bn_momentum) 555 | spatial_shape = self.init_conv.out_spatial_shape 556 | 557 | self.stages = nn.ModuleList() 558 | 
mbconv_idx = 0 559 | dc_rates = self.get_dc_rates() 560 | for stage_id in range(self.num_stages): 561 | kernel_size = self.get_stage_kernel_size(stage_id) 562 | stride = self.get_stage_stride(stage_id) 563 | expansion_factor = self.get_stage_expansion_factor(stage_id) 564 | stage_in_channels = self.get_stage_in_channels(stage_id) 565 | stage_out_channels = self.get_stage_out_channels(stage_id) 566 | stage_num_layers = self.get_stage_num_layers(stage_id) 567 | stage_dc_rates = dc_rates[mbconv_idx:mbconv_idx + stage_num_layers] 568 | stage_se_ratio = self.get_stage_se_ratio(stage_id) 569 | 570 | stage = EnetStage(num_layers=stage_num_layers, 571 | in_spatial_shape=spatial_shape, 572 | in_channels=stage_in_channels, 573 | out_channels=stage_out_channels, 574 | stride=stride, 575 | se_ratio=stage_se_ratio, 576 | drop_connect_rates=stage_dc_rates, 577 | kernel_size=kernel_size, 578 | expansion_factor=expansion_factor, 579 | activation=self.activation, 580 | bn_epsilon=bn_epsilon, 581 | bn_momentum=bn_momentum, 582 | bias=bias 583 | ) 584 | self.stages.append(stage) 585 | spatial_shape = stage.out_spatial_shape 586 | mbconv_idx += stage_num_layers 587 | 588 | head_conv_out_channels = round_filters(1280, width_coefficient) 589 | head_conv_in_channels = self.stages[-1].layers[-1].project_conv.out_channels 590 | self.head_conv = ConvBNAct(in_channels=head_conv_in_channels, 591 | out_channels=head_conv_out_channels, 592 | kernel_size=(1, 1), 593 | bias=bias, 594 | activation=self.activation, 595 | bn_epsilon=bn_epsilon, 596 | bn_momentum=bn_momentum) 597 | 598 | if self.dropout_rate > 0: 599 | self.dropout = nn.Dropout(p=self.dropout_rate) 600 | else: 601 | self.dropout = None 602 | 603 | self.avpool = nn.AdaptiveAvgPool2d((1, 1)) 604 | self.fc = nn.Linear(head_conv_out_channels, n_classes) 605 | 606 | if pretrained: 607 | self._load_state(self.b, in_channels, n_classes, progress) 608 | 609 | @property 610 | def num_stages(self): 611 | return len(EfficientNet.stage_args) 612 
| 613 | @property 614 | def width_coefficient(self): 615 | return EfficientNet.coefficients[self.b][0] 616 | 617 | @property 618 | def depth_coefficient(self): 619 | return EfficientNet.coefficients[self.b][1] 620 | 621 | @property 622 | def dropout_rate(self): 623 | if self._override_dropout_rate is None: 624 | return EfficientNet.coefficients[self.b][2] 625 | else: 626 | return self._override_dropout_rate 627 | 628 | def get_stage_kernel_size(self, stage): 629 | return EfficientNet.stage_args[stage][1] 630 | 631 | def get_stage_stride(self, stage): 632 | return EfficientNet.stage_args[stage][2] 633 | 634 | def get_stage_expansion_factor(self, stage): 635 | return EfficientNet.stage_args[stage][3] 636 | 637 | def get_stage_in_channels(self, stage): 638 | width_coefficient = self.width_coefficient 639 | in_channels = EfficientNet.stage_args[stage][4] 640 | return round_filters(in_channels, width_coefficient) 641 | 642 | def get_stage_out_channels(self, stage): 643 | width_coefficient = self.width_coefficient 644 | out_channels = EfficientNet.stage_args[stage][5] 645 | return round_filters(out_channels, width_coefficient) 646 | 647 | def get_stage_se_ratio(self, stage): 648 | return EfficientNet.stage_args[stage][6] 649 | 650 | def get_stage_num_layers(self, stage): 651 | depth_coefficient = self.depth_coefficient 652 | num_layers = EfficientNet.stage_args[stage][0] 653 | return round_repeats(num_layers, depth_coefficient) 654 | 655 | def get_num_mbconv_layers(self): 656 | total = 0 657 | for i in range(self.num_stages): 658 | total += self.get_stage_num_layers(i) 659 | return total 660 | 661 | def get_dc_rates(self): 662 | total_mbconv_layers = self.get_num_mbconv_layers() 663 | return [self.drop_connect_rate * i / total_mbconv_layers 664 | for i in range(total_mbconv_layers)] 665 | 666 | def _load_state(self, b, in_channels, n_classes, progress): 667 | state_dict = model_zoo.load_url(EfficientNet.state_dict_urls[b], progress=progress, 
file_name=EfficientNet.dict_names[b]) 668 | strict = True 669 | if in_channels != 3: 670 | state_dict.pop('init_conv.conv.conv.weight') 671 | strict = False 672 | if n_classes != 1000: 673 | state_dict.pop('fc.weight') 674 | state_dict.pop('fc.bias') 675 | strict = False 676 | self.load_state_dict(state_dict, strict=strict) 677 | print("Model weights loaded successfully.") 678 | 679 | def check_input(self, x): 680 | if x.dim() != 4: 681 | raise ValueError("Input x must be 4 dimensional tensor, got {} instead".format(x.dim())) 682 | if x.size(1) != self.in_channels: 683 | raise ValueError("Input must have {} channels, got {} instead".format(self.in_channels, 684 | x.size(1))) 685 | 686 | @staticmethod 687 | def check_init_params(b, 688 | in_channels, 689 | n_classes, 690 | in_spatial_shape, 691 | activation, 692 | bias, 693 | drop_connect_rate, 694 | override_dropout_rate, 695 | bn_epsilon, 696 | bn_momentum, 697 | pretrained, 698 | progress): 699 | 700 | if not isinstance(b, int): 701 | raise ValueError("b must be int, got {} instead".format(type(b))) 702 | elif not 0 <= b < len(EfficientNet.coefficients): 703 | raise ValueError("b must be in range 0<=b<=7, got {} instead".format(b)) 704 | 705 | if not isinstance(in_channels, int): 706 | raise ValueError("in_channels must be int, got {} instead".format(type(in_channels))) 707 | elif not in_channels > 0: 708 | raise ValueError("in_channels must be > 0, got {} instead".format(in_channels)) 709 | 710 | if not isinstance(n_classes, int): 711 | raise ValueError("n_classes must be int, got {} instead".format(type(n_classes))) 712 | elif not n_classes > 0: 713 | raise ValueError("n_classes must be > 0, got {} instead".format(n_classes)) 714 | 715 | if not (in_spatial_shape is None or 716 | isinstance(in_spatial_shape, int) or 717 | (isinstance(in_spatial_shape, container_abcs.Iterable) and 718 | len(in_spatial_shape) == 2 and 719 | all(isinstance(s, int) for s in in_spatial_shape))): 720 | raise 
ValueError("in_spatial_shape must be either None, int or iterable of ints of length 2" 721 | ", got {} instead".format(in_spatial_shape)) 722 | 723 | if activation is not None and not callable(activation): 724 | raise ValueError("activation must be callable but is not") 725 | 726 | if not isinstance(bias, bool): 727 | raise ValueError("bias must be bool, got {} instead".format(type(bias))) 728 | 729 | if not isinstance(drop_connect_rate, float): 730 | raise ValueError("drop_connect_rate must be float, got {} instead".format(type(drop_connect_rate))) 731 | elif not 0 <= drop_connect_rate < 1.0: 732 | raise ValueError("drop_connect_rate must be within range 0 <= drop_connect_rate < 1.0, " 733 | "got {} instead".format(drop_connect_rate)) 734 | 735 | if override_dropout_rate is not None: 736 | if not isinstance(override_dropout_rate, float): 737 | raise ValueError("dropout_rate must be either None or float, " 738 | "got {} instead".format(type(override_dropout_rate))) 739 | elif not 0 <= override_dropout_rate < 1.0: 740 | raise ValueError("dropout_rate must be within range 0 <= dropout_rate < 1.0, " 741 | "got {} instead".format(override_dropout_rate)) 742 | 743 | if not isinstance(bn_epsilon, float): 744 | raise ValueError("bn_epsilon must be float, got {} instead".format(bn_epsilon)) 745 | 746 | if not isinstance(bn_momentum, float): 747 | raise ValueError("bn_momentum must be float, got {} instead".format(bn_momentum)) 748 | 749 | if not isinstance(pretrained, bool): 750 | raise ValueError("pretrained must be bool, got {} instead".format(type(pretrained))) 751 | 752 | if not isinstance(progress, bool): 753 | raise ValueError("progress must be bool, got {} instead".format(type(progress))) 754 | 755 | def get_features(self, x): 756 | 757 | self.check_input(x) 758 | 759 | x = self.init_conv(x) 760 | out = [] 761 | for stage in self.stages: 762 | x = stage(x) 763 | out.append(x) 764 | return out 765 | 766 | def forward(self, x): 767 | 768 | x = self.get_features(x)[-1] 
769 | 770 | x = self.head_conv(x) 771 | 772 | x = self.avpool(x) 773 | x = torch.flatten(x, 1) 774 | 775 | if self.dropout is not None: 776 | x = self.dropout(x) 777 | x = self.fc(x) 778 | 779 | return x 780 | -------------------------------------------------------------------------------- /efficientnet_v2.py: -------------------------------------------------------------------------------- 1 | import collections.abc as container_abc 2 | from collections import OrderedDict 3 | from math import ceil, floor 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.utils import model_zoo 9 | 10 | 11 | def _pair(x): 12 | if isinstance(x, container_abc.Iterable): 13 | return x 14 | return (x, x) 15 | 16 | 17 | def torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride): 18 | if in_spatial_shape is None: 19 | return None 20 | # in_spatial_shape -> [H,W] 21 | hin, win = _pair(in_spatial_shape) 22 | kh, kw = _pair(kernel_size) 23 | sh, sw = _pair(stride) 24 | 25 | # dilation and padding are ignored since they are always fixed in efficientnetV2 26 | hout = int(floor((hin - kh - 1) / sh + 1)) 27 | wout = int(floor((win - kw - 1) / sw + 1)) 28 | return hout, wout 29 | 30 | 31 | def get_activation(act_fn: str, **kwargs): 32 | if act_fn in ('silu', 'swish'): 33 | return nn.SiLU(**kwargs) 34 | elif act_fn == 'relu': 35 | return nn.ReLU(**kwargs) 36 | elif act_fn == 'relu6': 37 | return nn.ReLU6(**kwargs) 38 | elif act_fn == 'elu': 39 | return nn.ELU(**kwargs) 40 | elif act_fn == 'leaky_relu': 41 | return nn.LeakyReLU(**kwargs) 42 | elif act_fn == 'selu': 43 | return nn.SELU(**kwargs) 44 | elif act_fn == 'mish': 45 | return nn.Mish(**kwargs) 46 | else: 47 | raise ValueError('Unsupported act_fn {}'.format(act_fn)) 48 | 49 | 50 | def round_filters(filters, width_coefficient, depth_divisor=8): 51 | """Round number of filters based on depth multiplier.""" 52 | min_depth = depth_divisor 53 | filters *= width_coefficient 54 | 
new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor) 55 | return int(new_filters) 56 | 57 | 58 | def round_repeats(repeats, depth_coefficient): 59 | """Round number of block repeats based on depth multiplier.""" 60 | return int(ceil(depth_coefficient * repeats)) 61 | 62 | 63 | class DropConnect(nn.Module): 64 | def __init__(self, rate=0.5): 65 | super(DropConnect, self).__init__() 66 | self.keep_prob = None 67 | self.set_rate(rate) 68 | 69 | def set_rate(self, rate): 70 | if not 0 <= rate < 1: 71 | raise ValueError("rate must be 0<=rate<1, got {} instead".format(rate)) 72 | self.keep_prob = 1 - rate 73 | 74 | def forward(self, x): 75 | if self.training: 76 | random_tensor = self.keep_prob + torch.rand([x.size(0), 1, 1, 1], 77 | dtype=x.dtype, 78 | device=x.device) 79 | binary_tensor = torch.floor(random_tensor) 80 | return torch.mul(torch.div(x, self.keep_prob), binary_tensor) 81 | else: 82 | return x 83 | 84 | 85 | class SamePaddingConv2d(nn.Module): 86 | def __init__(self, 87 | in_spatial_shape, 88 | in_channels, 89 | out_channels, 90 | kernel_size, 91 | stride, 92 | dilation=1, 93 | enforce_in_spatial_shape=False, 94 | **kwargs): 95 | super(SamePaddingConv2d, self).__init__() 96 | 97 | self._in_spatial_shape = _pair(in_spatial_shape) 98 | # e.g.
throw exception if input spatial shape does not match in_spatial_shape 99 | # when calling self.forward() 100 | self.enforce_in_spatial_shape = enforce_in_spatial_shape 101 | kernel_size = _pair(kernel_size) 102 | stride = _pair(stride) 103 | dilation = _pair(dilation) 104 | 105 | in_height, in_width = self._in_spatial_shape 106 | filter_height, filter_width = kernel_size 107 | stride_height, stride_width = stride 108 | dilation_height, dilation_width = dilation 109 | 110 | out_height = int(ceil(float(in_height) / float(stride_height))) 111 | out_width = int(ceil(float(in_width) / float(stride_width))) 112 | 113 | pad_along_height = max((out_height - 1) * stride_height + 114 | filter_height + (filter_height - 1) * (dilation_height - 1) - in_height, 0) 115 | pad_along_width = max((out_width - 1) * stride_width + 116 | filter_width + (filter_width - 1) * (dilation_width - 1) - in_width, 0) 117 | 118 | pad_top = pad_along_height // 2 119 | pad_bottom = pad_along_height - pad_top 120 | pad_left = pad_along_width // 2 121 | pad_right = pad_along_width - pad_left 122 | 123 | paddings = (pad_left, pad_right, pad_top, pad_bottom) 124 | if any(p > 0 for p in paddings): 125 | self.zero_pad = nn.ZeroPad2d(paddings) 126 | else: 127 | self.zero_pad = None 128 | self.conv = nn.Conv2d(in_channels=in_channels, 129 | out_channels=out_channels, 130 | kernel_size=kernel_size, 131 | stride=stride, 132 | dilation=dilation, 133 | **kwargs) 134 | 135 | self._out_spatial_shape = (out_height, out_width) 136 | 137 | @property 138 | def out_spatial_shape(self): 139 | return self._out_spatial_shape 140 | 141 | def check_spatial_shape(self, x): 142 | if x.size(2) != self._in_spatial_shape[0] or \ 143 | x.size(3) != self._in_spatial_shape[1]: 144 | raise ValueError( 145 | "Expected input spatial shape {}, got {} instead".format(self._in_spatial_shape, x.shape[2:])) 146 | 147 | def forward(self, x): 148 | if self.enforce_in_spatial_shape: 149 | self.check_spatial_shape(x) 150 | if self.zero_pad 
is not None: 151 | x = self.zero_pad(x) 152 | x = self.conv(x) 153 | return x 154 | 155 | 156 | class SqueezeExcitate(nn.Module): 157 | def __init__(self, 158 | in_channels, 159 | se_size, 160 | activation=None): 161 | super(SqueezeExcitate, self).__init__() 162 | self.dim_reduce = nn.Conv2d(in_channels=in_channels, 163 | out_channels=se_size, 164 | kernel_size=1) 165 | self.dim_restore = nn.Conv2d(in_channels=se_size, 166 | out_channels=in_channels, 167 | kernel_size=1) 168 | self.activation = F.relu if activation is None else activation 169 | 170 | def forward(self, x): 171 | inp = x 172 | x = F.adaptive_avg_pool2d(x, (1, 1)) 173 | x = self.dim_reduce(x) 174 | x = self.activation(x) 175 | x = self.dim_restore(x) 176 | x = torch.sigmoid(x) 177 | return torch.mul(inp, x) 178 | 179 | 180 | class MBConvBlockV2(nn.Module): 181 | def __init__(self, 182 | in_channels, 183 | out_channels, 184 | kernel_size, 185 | stride, 186 | expansion_factor, 187 | act_fn, 188 | act_kwargs=None, 189 | bn_epsilon=None, 190 | bn_momentum=None, 191 | se_size=None, 192 | drop_connect_rate=None, 193 | bias=False, 194 | tf_style_conv=False, 195 | in_spatial_shape=None): 196 | 197 | super().__init__() 198 | 199 | if act_kwargs is None: 200 | act_kwargs = {} 201 | exp_channels = in_channels * expansion_factor 202 | 203 | self.ops_lst = [] 204 | 205 | # expansion convolution 206 | if expansion_factor != 1: 207 | self.expand_conv = nn.Conv2d(in_channels=in_channels, 208 | out_channels=exp_channels, 209 | kernel_size=1, 210 | bias=bias) 211 | 212 | self.expand_bn = nn.BatchNorm2d(num_features=exp_channels, 213 | eps=bn_epsilon, 214 | momentum=bn_momentum) 215 | 216 | self.expand_act = get_activation(act_fn, **act_kwargs) 217 | self.ops_lst.extend([self.expand_conv, self.expand_bn, self.expand_act]) 218 | 219 | # depth-wise convolution 220 | if tf_style_conv: 221 | self.dp_conv = SamePaddingConv2d(in_spatial_shape=in_spatial_shape, 222 | in_channels=exp_channels, 223 | out_channels=exp_channels, 
224 | kernel_size=kernel_size, 225 | stride=stride, 226 | groups=exp_channels, 227 | bias=bias) 228 | self.out_spatial_shape = self.dp_conv.out_spatial_shape 229 | else: 230 | self.dp_conv = nn.Conv2d(in_channels=exp_channels, 231 | out_channels=exp_channels, 232 | kernel_size=kernel_size, 233 | stride=stride, 234 | padding=1, 235 | groups=exp_channels, 236 | bias=bias) 237 | self.out_spatial_shape = torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride) 238 | 239 | self.dp_bn = nn.BatchNorm2d(num_features=exp_channels, 240 | eps=bn_epsilon, 241 | momentum=bn_momentum) 242 | 243 | self.dp_act = get_activation(act_fn, **act_kwargs) 244 | self.ops_lst.extend([self.dp_conv, self.dp_bn, self.dp_act]) 245 | 246 | # Squeeze and Excitate 247 | if se_size is not None: 248 | self.se = SqueezeExcitate(exp_channels, 249 | se_size, 250 | activation=get_activation(act_fn, **act_kwargs)) 251 | self.ops_lst.append(self.se) 252 | 253 | # projection layer 254 | self.project_conv = nn.Conv2d(in_channels=exp_channels, 255 | out_channels=out_channels, 256 | kernel_size=1, 257 | bias=bias) 258 | 259 | self.project_bn = nn.BatchNorm2d(num_features=out_channels, 260 | eps=bn_epsilon, 261 | momentum=bn_momentum) 262 | 263 | # no activation function in projection layer 264 | 265 | self.ops_lst.extend([self.project_conv, self.project_bn]) 266 | 267 | self.skip_enabled = in_channels == out_channels and stride == 1 268 | 269 | if self.skip_enabled and drop_connect_rate is not None: 270 | self.drop_connect = DropConnect(drop_connect_rate) 271 | self.ops_lst.append(self.drop_connect) 272 | 273 | def forward(self, x): 274 | inp = x 275 | for op in self.ops_lst: 276 | x = op(x) 277 | if self.skip_enabled: 278 | return x + inp 279 | else: 280 | return x 281 | 282 | 283 | class FusedMBConvBlockV2(nn.Module): 284 | def __init__(self, 285 | in_channels, 286 | out_channels, 287 | kernel_size, 288 | stride, 289 | expansion_factor, 290 | act_fn, 291 | act_kwargs=None, 292 | bn_epsilon=None, 
293 | bn_momentum=None, 294 | se_size=None, 295 | drop_connect_rate=None, 296 | bias=False, 297 | tf_style_conv=False, 298 | in_spatial_shape=None): 299 | 300 | super().__init__() 301 | 302 | if act_kwargs is None: 303 | act_kwargs = {} 304 | exp_channels = in_channels * expansion_factor 305 | 306 | self.ops_lst = [] 307 | 308 | # expansion convolution 309 | expansion_out_shape = in_spatial_shape 310 | if expansion_factor != 1: 311 | if tf_style_conv: 312 | self.expand_conv = SamePaddingConv2d(in_spatial_shape=in_spatial_shape, 313 | in_channels=in_channels, 314 | out_channels=exp_channels, 315 | kernel_size=kernel_size, 316 | stride=stride, 317 | bias=bias) 318 | expansion_out_shape = self.expand_conv.out_spatial_shape 319 | else: 320 | self.expand_conv = nn.Conv2d(in_channels=in_channels, 321 | out_channels=exp_channels, 322 | kernel_size=kernel_size, 323 | padding=1, 324 | stride=stride, 325 | bias=bias) 326 | expansion_out_shape = torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride) 327 | 328 | self.expand_bn = nn.BatchNorm2d(num_features=exp_channels, 329 | eps=bn_epsilon, 330 | momentum=bn_momentum) 331 | 332 | self.expand_act = get_activation(act_fn, **act_kwargs) 333 | self.ops_lst.extend([self.expand_conv, self.expand_bn, self.expand_act]) 334 | 335 | # Squeeze and Excitate 336 | if se_size is not None: 337 | self.se = SqueezeExcitate(exp_channels, 338 | se_size, 339 | activation=get_activation(act_fn, **act_kwargs)) 340 | self.ops_lst.append(self.se) 341 | 342 | # projection layer 343 | kernel_size = 1 if expansion_factor != 1 else kernel_size 344 | stride = 1 if expansion_factor != 1 else stride 345 | if tf_style_conv: 346 | self.project_conv = SamePaddingConv2d(in_spatial_shape=expansion_out_shape, 347 | in_channels=exp_channels, 348 | out_channels=out_channels, 349 | kernel_size=kernel_size, 350 | stride=stride, 351 | bias=bias) 352 | self.out_spatial_shape = self.project_conv.out_spatial_shape 353 | else: 354 | self.project_conv = 
nn.Conv2d(in_channels=exp_channels, 355 | out_channels=out_channels, 356 | kernel_size=kernel_size, 357 | stride=stride, 358 | padding=1 if kernel_size > 1 else 0, 359 | bias=bias) 360 | self.out_spatial_shape = torch_conv_out_spatial_shape(expansion_out_shape, kernel_size, stride) 361 | 362 | self.project_bn = nn.BatchNorm2d(num_features=out_channels, 363 | eps=bn_epsilon, 364 | momentum=bn_momentum) 365 | 366 | self.ops_lst.extend( 367 | [self.project_conv, self.project_bn]) 368 | 369 | if expansion_factor == 1: 370 | self.project_act = get_activation(act_fn, **act_kwargs) 371 | self.ops_lst.append(self.project_act) 372 | 373 | self.skip_enabled = in_channels == out_channels and stride == 1 374 | 375 | if self.skip_enabled and drop_connect_rate is not None: 376 | self.drop_connect = DropConnect(drop_connect_rate) 377 | self.ops_lst.append(self.drop_connect) 378 | 379 | def forward(self, x): 380 | inp = x 381 | for op in self.ops_lst: 382 | x = op(x) 383 | if self.skip_enabled: 384 | return x + inp 385 | else: 386 | return x 387 | 388 | 389 | class EfficientNetV2(nn.Module): 390 | _models = {'b0': {'num_repeat': [1, 2, 2, 3, 5, 8], 391 | 'kernel_size': [3, 3, 3, 3, 3, 3], 392 | 'stride': [1, 2, 2, 2, 1, 2], 393 | 'expand_ratio': [1, 4, 4, 4, 6, 6], 394 | 'in_channel': [32, 16, 32, 48, 96, 112], 395 | 'out_channel': [16, 32, 48, 96, 112, 192], 396 | 'se_ratio': [None, None, None, 0.25, 0.25, 0.25], 397 | 'conv_type': [1, 1, 1, 0, 0, 0], 398 | 'is_feature_stage': [False, True, True, False, True, True], 399 | 'width_coefficient': 1.0, 400 | 'depth_coefficient': 1.0, 401 | 'train_size': 192, 402 | 'eval_size': 224, 403 | 'dropout': 0.2, 404 | 'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVBhWkZRcWNXR3dINmRLP2U9UUI5ZndH/root/content', 405 | 'model_name': 'efficientnet_v2_b0_21k_ft1k-a91e14c5.pth'}, 406 | 'b1': {'num_repeat': [1, 2, 2, 3, 5, 8], 407 | 'kernel_size': [3, 3, 3, 3, 3, 3], 408 | 'stride': [1, 2, 2, 2, 1, 
2], 409 | 'expand_ratio': [1, 4, 4, 4, 6, 6], 410 | 'in_channel': [32, 16, 32, 48, 96, 112], 411 | 'out_channel': [16, 32, 48, 96, 112, 192], 412 | 'se_ratio': [None, None, None, 0.25, 0.25, 0.25], 413 | 'conv_type': [1, 1, 1, 0, 0, 0], 414 | 'is_feature_stage': [False, True, True, False, True, True], 415 | 'width_coefficient': 1.0, 416 | 'depth_coefficient': 1.1, 417 | 'train_size': 192, 418 | 'eval_size': 240, 419 | 'dropout': 0.2, 420 | 'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVJnVGV5UndSY2J2amwtP2U9dTBiV1lO/root/content', 421 | 'model_name': 'efficientnet_v2_b1_21k_ft1k-58f4fb47.pth'}, 422 | 'b2': {'num_repeat': [1, 2, 2, 3, 5, 8], 423 | 'kernel_size': [3, 3, 3, 3, 3, 3], 424 | 'stride': [1, 2, 2, 2, 1, 2], 425 | 'expand_ratio': [1, 4, 4, 4, 6, 6], 426 | 'in_channel': [32, 16, 32, 48, 96, 112], 427 | 'out_channel': [16, 32, 48, 96, 112, 192], 428 | 'se_ratio': [None, None, None, 0.25, 0.25, 0.25], 429 | 'conv_type': [1, 1, 1, 0, 0, 0], 430 | 'is_feature_stage': [False, True, True, False, True, True], 431 | 'width_coefficient': 1.1, 432 | 'depth_coefficient': 1.2, 433 | 'train_size': 208, 434 | 'eval_size': 260, 435 | 'dropout': 0.3, 436 | 'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVY4M2NySVFZbU41X0tGP2U9ZERZVmxK/root/content', 437 | 'model_name': 'efficientnet_v2_b2_21k_ft1k-db4ac0ee.pth'}, 438 | 'b3': {'num_repeat': [1, 2, 2, 3, 5, 8], 439 | 'kernel_size': [3, 3, 3, 3, 3, 3], 440 | 'stride': [1, 2, 2, 2, 1, 2], 441 | 'expand_ratio': [1, 4, 4, 4, 6, 6], 442 | 'in_channel': [32, 16, 32, 48, 96, 112], 443 | 'out_channel': [16, 32, 48, 96, 112, 192], 444 | 'se_ratio': [None, None, None, 0.25, 0.25, 0.25], 445 | 'conv_type': [1, 1, 1, 0, 0, 0], 446 | 'is_feature_stage': [False, True, True, False, True, True], 447 | 'width_coefficient': 1.2, 448 | 'depth_coefficient': 1.4, 449 | 'train_size': 240, 450 | 'eval_size': 300, 451 | 'dropout': 0.3, 452 | 
'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVpkamdZUzhhaDdtTTZLP2U9anA4VWN2/root/content', 453 | 'model_name': 'efficientnet_v2_b3_21k_ft1k-3da5874c.pth'}, 454 | 's': {'num_repeat': [2, 4, 4, 6, 9, 15], 455 | 'kernel_size': [3, 3, 3, 3, 3, 3], 456 | 'stride': [1, 2, 2, 2, 1, 2], 457 | 'expand_ratio': [1, 4, 4, 4, 6, 6], 458 | 'in_channel': [24, 24, 48, 64, 128, 160], 459 | 'out_channel': [24, 48, 64, 128, 160, 256], 460 | 'se_ratio': [None, None, None, 0.25, 0.25, 0.25], 461 | 'conv_type': [1, 1, 1, 0, 0, 0], 462 | 'is_feature_stage': [False, True, True, False, True, True], 463 | 'width_coefficient': 1.0, 464 | 'depth_coefficient': 1.0, 465 | 'train_size': 300, 466 | 'eval_size': 384, 467 | 'dropout': 0.2, 468 | 'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmllbFF5VWJOZzd0cmhBbm8/root/content', 469 | 'model_name': 'efficientnet_v2_s_21k_ft1k-dbb43f38.pth'}, 470 | 'm': {'num_repeat': [3, 5, 5, 7, 14, 18, 5], 471 | 'kernel_size': [3, 3, 3, 3, 3, 3, 3], 472 | 'stride': [1, 2, 2, 2, 1, 2, 1], 473 | 'expand_ratio': [1, 4, 4, 4, 6, 6, 6], 474 | 'in_channel': [24, 24, 48, 80, 160, 176, 304], 475 | 'out_channel': [24, 48, 80, 160, 176, 304, 512], 476 | 'se_ratio': [None, None, None, 0.25, 0.25, 0.25, 0.25], 477 | 'conv_type': [1, 1, 1, 0, 0, 0, 0], 478 | 'is_feature_stage': [False, True, True, False, True, False, True], 479 | 'width_coefficient': 1.0, 480 | 'depth_coefficient': 1.0, 481 | 'train_size': 384, 482 | 'eval_size': 480, 483 | 'dropout': 0.3, 484 | 'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmllN1ZDazRFb0o1bnlyNUE/root/content', 485 | 'model_name': 'efficientnet_v2_m_21k_ft1k-da8e56c0.pth'}, 486 | 'l': {'num_repeat': [4, 7, 7, 10, 19, 25, 7], 487 | 'kernel_size': [3, 3, 3, 3, 3, 3, 3], 488 | 'stride': [1, 2, 2, 2, 1, 2, 1], 489 | 'expand_ratio': [1, 4, 4, 4, 6, 6, 6], 490 | 'in_channel': [32, 32, 64, 96, 192, 
224, 384], 491 | 'out_channel': [32, 64, 96, 192, 224, 384, 640], 492 | 'se_ratio': [None, None, None, 0.25, 0.25, 0.25, 0.25], 493 | 'conv_type': [1, 1, 1, 0, 0, 0, 0], 494 | 'is_feature_stage': [False, True, True, False, True, False, True], 495 | 'feature_stages': [1, 2, 4, 6], 496 | 'width_coefficient': 1.0, 497 | 'depth_coefficient': 1.0, 498 | 'train_size': 384, 499 | 'eval_size': 480, 500 | 'dropout': 0.4, 501 | 'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlmcmIyRHEtQTBhUTBhWVE/root/content', 502 | 'model_name': 'efficientnet_v2_l_21k_ft1k-08121eee.pth'}, 503 | 'xl': {'num_repeat': [4, 8, 8, 16, 24, 32, 8], 504 | 'kernel_size': [3, 3, 3, 3, 3, 3, 3], 505 | 'stride': [1, 2, 2, 2, 1, 2, 1], 506 | 'expand_ratio': [1, 4, 4, 4, 6, 6, 6], 507 | 'in_channel': [32, 32, 64, 96, 192, 256, 512], 508 | 'out_channel': [32, 64, 96, 192, 256, 512, 640], 509 | 'se_ratio': [None, None, None, 0.25, 0.25, 0.25, 0.25], 510 | 'conv_type': [1, 1, 1, 0, 0, 0, 0], 511 | 'is_feature_stage': [False, True, True, False, True, False, True], 512 | 'feature_stages': [1, 2, 4, 6], 513 | 'width_coefficient': 1.0, 514 | 'depth_coefficient': 1.0, 515 | 'train_size': 384, 516 | 'eval_size': 512, 517 | 'dropout': 0.4, 518 | 'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlmVXQtRHJLa21taUkxWkE/root/content', 519 | 'model_name': 'efficientnet_v2_xl_21k_ft1k-1fcc9744.pth'}} 520 | 521 | def __init__(self, 522 | model_name, 523 | in_channels=3, 524 | n_classes=1000, 525 | tf_style_conv=False, 526 | in_spatial_shape=None, 527 | activation='silu', 528 | activation_kwargs=None, 529 | bias=False, 530 | drop_connect_rate=0.2, 531 | dropout_rate=None, 532 | bn_epsilon=1e-3, 533 | bn_momentum=0.01, 534 | pretrained=False, 535 | progress=False, 536 | ): 537 | super().__init__() 538 | 539 | self.blocks = nn.ModuleList() 540 | self.model_name = model_name 541 | self.cfg = self._models[model_name] 542 | 543 | if 
tf_style_conv and in_spatial_shape is None: 544 | in_spatial_shape = self.cfg['eval_size'] 545 | 546 | activation_kwargs = {} if activation_kwargs is None else activation_kwargs 547 | dropout_rate = self.cfg['dropout'] if dropout_rate is None else dropout_rate 548 | _input_ch = in_channels 549 | 550 | self.feature_block_ids = [] 551 | 552 | # stem 553 | if tf_style_conv: 554 | self.stem_conv = SamePaddingConv2d( 555 | in_spatial_shape=in_spatial_shape, 556 | in_channels=in_channels, 557 | out_channels=round_filters(self.cfg['in_channel'][0], self.cfg['width_coefficient']), 558 | kernel_size=3, 559 | stride=2, 560 | bias=bias 561 | ) 562 | in_spatial_shape = self.stem_conv.out_spatial_shape 563 | else: 564 | self.stem_conv = nn.Conv2d( 565 | in_channels=in_channels, 566 | out_channels=round_filters(self.cfg['in_channel'][0], self.cfg['width_coefficient']), 567 | kernel_size=3, 568 | stride=2, 569 | padding=1, 570 | bias=bias 571 | ) 572 | 573 | self.stem_bn = nn.BatchNorm2d( 574 | num_features=round_filters(self.cfg['in_channel'][0], self.cfg['width_coefficient']), 575 | eps=bn_epsilon, 576 | momentum=bn_momentum) 577 | 578 | self.stem_act = get_activation(activation, **activation_kwargs) 579 | 580 | drop_connect_rates = self.get_dropconnect_rates(drop_connect_rate) 581 | 582 | stages = zip(*[self.cfg[x] for x in 583 | ['num_repeat', 'kernel_size', 'stride', 'expand_ratio', 'in_channel', 'out_channel', 'se_ratio', 584 | 'conv_type', 'is_feature_stage']]) 585 | 586 | idx = 0 587 | 588 | for stage_args in stages: 589 | (num_repeat, kernel_size, stride, expand_ratio, 590 | in_channels, out_channels, se_ratio, conv_type, is_feature_stage) = stage_args 591 | 592 | in_channels = round_filters( 593 | in_channels, self.cfg['width_coefficient']) 594 | out_channels = round_filters( 595 | out_channels, self.cfg['width_coefficient']) 596 | num_repeat = round_repeats( 597 | num_repeat, self.cfg['depth_coefficient']) 598 | 599 | conv_block = MBConvBlockV2 if conv_type == 0 else 
FusedMBConvBlockV2 600 | 601 | for _ in range(num_repeat): 602 | se_size = None if se_ratio is None else max(1, int(in_channels * se_ratio)) 603 | _b = conv_block(in_channels=in_channels, 604 | out_channels=out_channels, 605 | kernel_size=kernel_size, 606 | stride=stride, 607 | expansion_factor=expand_ratio, 608 | act_fn=activation, 609 | act_kwargs=activation_kwargs, 610 | bn_epsilon=bn_epsilon, 611 | bn_momentum=bn_momentum, 612 | se_size=se_size, 613 | drop_connect_rate=drop_connect_rates[idx], 614 | bias=bias, 615 | tf_style_conv=tf_style_conv, 616 | in_spatial_shape=in_spatial_shape 617 | ) 618 | self.blocks.append(_b) 619 | idx += 1 620 | if tf_style_conv: 621 | in_spatial_shape = _b.out_spatial_shape 622 | in_channels = out_channels 623 | stride = 1 624 | 625 | if is_feature_stage: 626 | self.feature_block_ids.append(idx - 1) 627 | 628 | head_conv_out_channels = round_filters(1280, self.cfg['width_coefficient']) 629 | 630 | self.head_conv = nn.Conv2d(in_channels=in_channels, 631 | out_channels=head_conv_out_channels, 632 | kernel_size=1, 633 | bias=bias) 634 | self.head_bn = nn.BatchNorm2d(num_features=head_conv_out_channels, 635 | eps=bn_epsilon, 636 | momentum=bn_momentum) 637 | self.head_act = get_activation(activation, **activation_kwargs) 638 | 639 | self.dropout = nn.Dropout(p=dropout_rate) 640 | 641 | self.avpool = nn.AdaptiveAvgPool2d((1, 1)) 642 | self.fc = nn.Linear(head_conv_out_channels, n_classes) 643 | 644 | if pretrained: 645 | self._load_state(_input_ch, n_classes, progress, tf_style_conv) 646 | 647 | return 648 | 649 | def _load_state(self, in_channels, n_classes, progress, tf_style_conv): 650 | state_dict = model_zoo.load_url(self.cfg['weight_url'], 651 | progress=progress, 652 | file_name=self.cfg['model_name']) 653 | 654 | strict = True 655 | 656 | if not tf_style_conv: 657 | state_dict = OrderedDict( 658 | [(k.replace('.conv.', '.'), v) if '.conv.' 
in k else (k, v) for k, v in state_dict.items()]) 659 | 660 | if in_channels != 3: 661 | if tf_style_conv: 662 | state_dict.pop('stem_conv.conv.weight') 663 | else: 664 | state_dict.pop('stem_conv.weight') 665 | strict = False 666 | 667 | if n_classes != 1000: 668 | state_dict.pop('fc.weight') 669 | state_dict.pop('fc.bias') 670 | strict = False 671 | 672 | self.load_state_dict(state_dict, strict=strict) 673 | print("Model weights loaded successfully.") 674 | 675 | def get_dropconnect_rates(self, drop_connect_rate): 676 | nr = self.cfg['num_repeat'] 677 | dc = self.cfg['depth_coefficient'] 678 | total = sum(round_repeats(nr[i], dc) for i in range(len(nr))) 679 | return [drop_connect_rate * i / total for i in range(total)] 680 | 681 | def get_features(self, x): 682 | x = self.stem_act(self.stem_bn(self.stem_conv(x))) 683 | 684 | features = [] 685 | feat_idx = 0 686 | for block_idx, block in enumerate(self.blocks): 687 | x = block(x) 688 | if block_idx == self.feature_block_ids[feat_idx]: 689 | features.append(x) 690 | feat_idx += 1 691 | 692 | return features 693 | 694 | def forward(self, x): 695 | x = self.stem_act(self.stem_bn(self.stem_conv(x))) 696 | for block in self.blocks: 697 | x = block(x) 698 | x = self.head_act(self.head_bn(self.head_conv(x))) 699 | x = self.dropout(torch.flatten(self.avpool(x), 1)) 700 | x = self.fc(x) 701 | 702 | return x 703 | -------------------------------------------------------------------------------- /imagenet_eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torchvision.transforms as transforms\n", 11 | "import torchvision.datasets as datasets\n", 12 | "from sklearn.metrics import accuracy_score\n", 13 | "from PIL import Image\n", 14 | "\n", 15 | "from efficientnet import EfficientNet\n", 16 | "from efficientnet_v2 
import EfficientNetV2\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "def eval_model(model, dataloader, device, criterion=None):\n", 26 | " loss_value = []\n", 27 | " y_pred = []\n", 28 | " y_true = []\n", 29 | "\n", 30 | " model.eval()\n", 31 | " with torch.no_grad():\n", 32 | " for xb, yb in dataloader:\n", 33 | " xb, yb = xb.to(device), yb.to(device)\n", 34 | " out = model(xb)\n", 35 | " if out.size(1) == 1:\n", 36 | " # regression\n", 37 | " out = torch.squeeze(out, 1)\n", 38 | "\n", 39 | " if criterion is not None:\n", 40 | " loss = criterion(out, yb)\n", 41 | " loss_value.append(loss.item())\n", 42 | "\n", 43 | " y_pred.append(out.detach().cpu())\n", 44 | " y_true.append(yb.detach().cpu())\n", 45 | "\n", 46 | " if criterion is not None:\n", 47 | " loss_value = sum(loss_value) / len(loss_value)\n", 48 | " return torch.cat(y_pred), torch.cat(y_true), loss_value\n", 49 | " else:\n", 50 | " return torch.cat(y_pred), torch.cat(y_true)\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## EfficientNetV2" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 67 | "\n", 68 | "modelname = 's'\n", 69 | "in_spatial_shape = EfficientNetV2._models[modelname]['eval_size']\n", 70 | "\n", 71 | "# Setting tf_style_conv=True and in_spatial_shape is only necessary when evaluating against the ImageNet dataset\n", 72 | "# Model names: 'b0', 'b1', 'b2', 'b3', 's', 'm', 'l', 'xl'\n", 73 | "model = EfficientNetV2(modelname,\n", 74 | " tf_style_conv=True,\n", 75 | " in_spatial_shape=in_spatial_shape,\n", 76 | " pretrained=True,\n", 77 | " progress=True)\n", 78 | "model.to(device)\n", 79 | "\n", 80 | "val_trainsforms = transforms.Compose([\n", 81 | " 
transforms.Resize(in_spatial_shape,\n", 82 | " interpolation=transforms.InterpolationMode.BICUBIC),\n", 83 | " transforms.CenterCrop(in_spatial_shape),\n", 84 | " transforms.ToTensor(),\n", 85 | " transforms.Normalize(mean=0.5,\n", 86 | " std=0.5),\n", 87 | "])\n", 88 | "\n", 89 | "val_dataset = datasets.ImageNet(root=\"/path/to/imagenet/val/subset\", split=\"val\",\n", 90 | " transform=val_trainsforms)\n", 91 | "\n", 92 | "val_loader = torch.utils.data.DataLoader(\n", 93 | " val_dataset,\n", 94 | " batch_size=32, shuffle=False,\n", 95 | " num_workers=2, pin_memory=True)\n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "y_pred, y_true = eval_model(model, val_loader, device)\n", 105 | "_, y_pred = torch.max(y_pred, 1)\n", 106 | "\n", 107 | "score = accuracy_score(y_pred, y_true)\n", 108 | "print(\"Accuracy: {:.3%}\".format(score))\n" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "Expected evaluation metric values on ImageNet validation set \n", 116 | "\n", 117 | "EfficientNetV2-b0 - 77.590%
\n", 118 | "EfficientNetV2-b1 - 78.872%
\n", 119 | "EfficientNetV2-b2 - 79.388%
\n", 120 | "EfficientNetV2-b3 - 82.260%
\n", 121 | "EfficientNetV2-S - 84.282%
\n", 122 | "EfficientNetV2-M - 85.596%
\n", 123 | "EfficientNetV2-L - 86.298%
\n", 124 | "EfficientNetV2-XL - 86.414%
" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "## EfficientNetV1" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 141 | "\n", 142 | "# EfficientNet model index, i.e. 0 for EfficientNet-B0\n", 143 | "idx = 0\n", 144 | "model = EfficientNet(idx, pretrained=True, progress=True)\n", 145 | "model.to(device)\n", 146 | "\n", 147 | "val_trainsforms = transforms.Compose([\n", 148 | " transforms.Resize(model.in_spatial_shape[0], interpolation=Image.BICUBIC),\n", 149 | " transforms.CenterCrop(model.in_spatial_shape),\n", 150 | " transforms.ToTensor(),\n", 151 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 152 | " std=[0.229, 0.224, 0.225]),\n", 153 | "])\n", 154 | "\n", 155 | "\n", 156 | "val_dataset = datasets.ImageNet(root=\"path/to/imagenet/dataset\", split=\"val\",\n", 157 | " transform=val_trainsforms)\n", 158 | "\n", 159 | "val_loader = torch.utils.data.DataLoader(\n", 160 | " val_dataset,\n", 161 | " batch_size=32, shuffle=False,\n", 162 | " num_workers=1, pin_memory=True)\n" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "y_pred, y_true = eval_model(model, val_loader, device)\n", 172 | "_, y_pred = torch.max(y_pred, 1)\n", 173 | "\n", 174 | "score = accuracy_score(y_pred, y_true)\n", 175 | "print(\"Accuracy: {:.3%}\".format(score))\n" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "Expected evaluation metric values on ImageNet validation set \n", 183 | "\n", 184 | "EfficientNet-B0 - 76.43%
\n", 185 | "EfficientNet-B1 - 78.396%
\n", 186 | "EfficientNet-B2 - 79.804%
\n", 187 | "EfficientNet-B3 - 81.542%
\n", 188 | "EfficientNet-B4 - 83.036%
\n", 189 | "EfficientNet-B5 - 83.79%
\n", 190 | "EfficientNet-B6 - 84.136%
\n", 191 | "EfficientNet-B7 - 84.578%
" 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "kernelspec": { 197 | "display_name": "Python 3", 198 | "language": "python", 199 | "name": "python3" 200 | }, 201 | "language_info": { 202 | "codemirror_mode": { 203 | "name": "ipython", 204 | "version": 3 205 | }, 206 | "file_extension": ".py", 207 | "mimetype": "text/x-python", 208 | "name": "python", 209 | "nbconvert_exporter": "python", 210 | "pygments_lexer": "ipython3", 211 | "version": "3.8.10" 212 | } 213 | }, 214 | "nbformat": 4, 215 | "nbformat_minor": 4 216 | } 217 | --------------------------------------------------------------------------------
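
The EfficientNetV2 cell of the notebook passes `in_spatial_shape` together with `tf_style_conv=True` because `SamePaddingConv2d` in `efficientnet_v2.py` computes its zero padding once, at construction time, from the expected input size. Below is a minimal sketch of that arithmetic for a single spatial axis, written in plain Python so it can be checked by hand. The helper name `same_padding_1d` is hypothetical and not part of the repository; the formulas mirror the ones in `SamePaddingConv2d.__init__`.

```python
from math import ceil


def same_padding_1d(in_size, kernel, stride=1, dilation=1):
    """Return (out_size, pad_before, pad_after) for one spatial axis.

    Hypothetical helper for illustration only; it follows the
    TensorFlow-style 'SAME' arithmetic used by SamePaddingConv2d.
    """
    out_size = int(ceil(in_size / stride))
    # Extent of the kernel once dilation is taken into account.
    effective_kernel = kernel + (kernel - 1) * (dilation - 1)
    total_pad = max((out_size - 1) * stride + effective_kernel - in_size, 0)
    pad_before = total_pad // 2
    # Any odd leftover pixel goes to the bottom/right side.
    pad_after = total_pad - pad_before
    return out_size, pad_before, pad_after


# Stem convolution of EfficientNetV2-S at its eval size: kernel 3, stride 2.
print(same_padding_1d(384, kernel=3, stride=2))  # (192, 0, 1)
```

For this stem layer the TF-style path pads one pixel on the bottom/right only, while the `nn.Conv2d(..., padding=1)` path pads one pixel on every side. Both yield a 192-pixel output, but the convolution windows are aligned differently, which is presumably why the notebook enables `tf_style_conv` when reproducing the reported ImageNet accuracy with the converted weights.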