├── .gitignore
├── LICENSE
├── README.md
├── efficientnet.py
└── eval.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Huijun Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# EfficientNet.PyTorch
This repository contains a concise, modular, human-friendly **PyTorch** implementation of **[EfficientNet](https://arxiv.org/abs/1905.11946)** with **[Pre-trained Weights](https://drive.google.com/open?id=1C5IhQd8UfvVY32GYhyQAjvY92IzPi7fY)**.


## Dependencies

- [PyTorch (1.4.1+)](http://pytorch.org)
- [torchstat](https://github.com/Swall0w/torchstat)
- [pytorch_memlab](https://github.com/Stonesjtu/pytorch_memlab)


## Result Details (Val.)

| *Name*            |*# Params*| *# FLOPs*  |*Top-1 Acc.*| *Pretrained* |
|:-----------------:|:--------:|:----------:|:----------:|:------------:|
| `efficientnet-b0` | 5.3M | 0.39B | 74.2 | [GoogleDrive](https://drive.google.com/open?id=1GAB04ft47OhmG_AbrQCcYiezJ8o00veX) |
| `efficientnet-b1` | 7.8M | 0.70B | 78.0 | [GoogleDrive](https://drive.google.com/open?id=1h_JT21EcPEmy7eNgnbI4-ORwRVJNxGsu) |
| `efficientnet-b2` | 9.2M | 1.0B | 81.8 | [GoogleDrive](https://drive.google.com/open?id=1CapQmg4Yvrdzzi3XaJkjWVFHOPT74Zat) |
| `efficientnet-b3` | 12M | 1.8B | 82.7 | [GoogleDrive](https://drive.google.com/open?id=1pJwZcIDBg236uWkYcqjT-WMqaFUWTI_U) |
| `efficientnet-b4` | 19M | 4.2B | 84.6 | [GoogleDrive](https://drive.google.com/open?id=1uHUfuxwz99t3YhSGLzr44Sty6yUbukQW) |
| `efficientnet-b5` | 30M | 9.9B | 86.1 | [GoogleDrive](https://drive.google.com/open?id=1G6B1rYedovUyG9tNwqS2BBeQE4v5S_A_) |
| `efficientnet-b6` | 43M | 19B | 86.0 | [GoogleDrive](https://drive.google.com/open?id=1py6oQlFvwh7wRf6fqIj835jdhXDXOF2Y) |
| `efficientnet-b7` | 66M | 37B | 86.7 | [GoogleDrive](https://drive.google.com/open?id=1duvROO9nVSnO9FC6u0EXQqcNpiG-WKbB) |
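

## Usage

A minimal sketch of building the model and loading one of the checkpoints above for inference. The checkpoint file name below is a placeholder for a file downloaded from the table; the `model_state` key follows the loading convention used in `eval.py`:

```python
import torch

from efficientnet import EfficientNet

model = EfficientNet(arch="b0", num_classes=1000)

# Placeholder path -- point this at a checkpoint downloaded from the table above.
checkpoint = torch.load("efficientnet_b0.pkl", map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
model.eval()

# b0 expects 224x224 inputs; see img_preparam in eval.py for the other variants.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```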
--------------------------------------------------------------------------------
/efficientnet.py:
--------------------------------------------------------------------------------
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
# LightNet++: Boosted Light-weighted Networks for Real-time Semantic Segmentation
# ---------------------------------------------------------------------------------------------------------------- #
# PyTorch implementation for EfficientNet
# class:
#     > Swish
#     > SEBlock
#     > MBConvBlock
#     > EfficientNet
# ---------------------------------------------------------------------------------------------------------------- #
# Author: Huijun Liu M.Sc.
# Date: 08.02.2020
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
from torch.nn import functional as F
from collections import OrderedDict

from torch import nn
import torch
import math


# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
# Swish: Swish Activation Function, x * sigmoid(x)
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
class Swish(nn.Module):
    def __init__(self, inplace=True):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        # The in-place variant saves memory by overwriting its input.
        return x.mul_(x.sigmoid()) if self.inplace else x.mul(x.sigmoid())


class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1,
                 groups=1, dilate=1):

        super(ConvBlock, self).__init__()
        dilate = 1 if stride > 1 else dilate
        padding = ((kernel_size - 1) // 2) * dilate

        self.conv_block = nn.Sequential(OrderedDict([
            ("conv", nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                               kernel_size=kernel_size, stride=stride, padding=padding,
                               dilation=dilate, groups=groups, bias=False)),
            ("norm", nn.BatchNorm2d(num_features=out_planes,
                                    eps=1e-3, momentum=0.01)),
            ("act", Swish(inplace=True))
        ]))

    def forward(self, x):
        return self.conv_block(x)


# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
# SEBlock: Squeeze & Excitation (SE)
#          namely, Channel-wise Attention
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
class SEBlock(nn.Module):
    def __init__(self, in_planes, reduced_dim):
        super(SEBlock, self).__init__()
        self.channel_se = nn.Sequential(OrderedDict([
            ("linear1", nn.Conv2d(in_planes, reduced_dim, kernel_size=1, stride=1, padding=0, bias=True)),
            ("act", Swish(inplace=True)),
            ("linear2", nn.Conv2d(reduced_dim, in_planes, kernel_size=1, stride=1, padding=0, bias=True))
        ]))

    def forward(self, x):
        x_se = torch.sigmoid(self.channel_se(F.adaptive_avg_pool2d(x, output_size=(1, 1))))
        return torch.mul(x, x_se)


# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
# MBConvBlock: MBConvBlock for EfficientNet
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
class MBConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes,
                 expand_ratio, kernel_size, stride, dilate,
                 reduction_ratio=4, dropout_rate=0.2):
        super(MBConvBlock, self).__init__()
        self.dropout_rate = dropout_rate
        self.expand_ratio = expand_ratio
        self.use_se = (reduction_ratio is not None) and (reduction_ratio > 1)
        self.use_residual = in_planes == out_planes and stride == 1

        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        dilate = 1 if stride > 1 else dilate
        hidden_dim = in_planes * expand_ratio
        reduced_dim = max(1, int(in_planes / reduction_ratio))

        # step 1. Expansion phase/Point-wise convolution
        if expand_ratio != 1:
            self.expansion = ConvBlock(in_planes, hidden_dim, 1)

        # step 2. Depth-wise convolution phase
        self.depth_wise = ConvBlock(hidden_dim, hidden_dim, kernel_size,
                                    stride=stride, groups=hidden_dim, dilate=dilate)

        # step 3. Squeeze and Excitation
        if self.use_se:
            self.se_block = SEBlock(hidden_dim, reduced_dim)

        # step 4. Point-wise convolution phase (linear projection, no activation)
        self.point_wise = nn.Sequential(OrderedDict([
            ("conv", nn.Conv2d(in_channels=hidden_dim,
                               out_channels=out_planes, kernel_size=1,
                               stride=1, padding=0, dilation=1, groups=1, bias=False)),
            ("norm", nn.BatchNorm2d(out_planes, eps=1e-3, momentum=0.01))
        ]))

    def forward(self, x):
        res = x

        # step 1. Expansion phase/Point-wise convolution
        if self.expand_ratio != 1:
            x = self.expansion(x)

        # step 2. Depth-wise convolution phase
        x = self.depth_wise(x)

        # step 3. Squeeze and Excitation
        if self.use_se:
            x = self.se_block(x)

        # step 4. Point-wise convolution phase
        x = self.point_wise(x)

        # step 5. Skip connection and drop connect
        # (drop connect is approximated here by channel-wise dropout on the residual branch)
        if self.use_residual:
            if self.training and (self.dropout_rate is not None):
                x = F.dropout2d(input=x, p=self.dropout_rate,
                                training=self.training, inplace=True)
            x = x + res

        return x


# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
# EfficientNet: EfficientNet Implementation
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
class EfficientNet(nn.Module):
    def __init__(self, arch="b0", num_classes=1000):
        super(EfficientNet, self).__init__()

        arch_params = {
            # arch width_multi depth_multi input_h dropout_rate
            'b0': (1.0, 1.0, 224, 0.2),
            'b1': (1.0, 1.1, 240, 0.2),
            'b2': (1.1, 1.2, 260, 0.3),
            'b3': (1.2, 1.4, 300, 0.3),
            'b4': (1.4, 1.8, 380, 0.4),
            'b5': (1.6, 2.2, 456, 0.4),
            'b6': (1.8, 2.6, 528, 0.5),
            'b7': (2.0, 3.1, 600, 0.5),
        }
        width_multi, depth_multi, net_h, dropout_rate = arch_params[arch]

        settings = [
            # t: expansion, c: out channels, n: repeats, k: kernel size, s: stride, d: dilation
            [1, 16, 1, 3, 1, 1],   # 3x3, 112 -> 112
            [6, 24, 2, 3, 2, 1],   # 3x3, 112 ->  56
            [6, 40, 2, 5, 2, 1],   # 5x5,  56 ->  28
            [6, 80, 3, 3, 2, 1],   # 3x3,  28 ->  14
            [6, 112, 3, 5, 1, 1],  # 5x5,  14 ->  14
            [6, 192, 4, 5, 2, 1],  # 5x5,  14 ->   7
            [6, 320, 1, 3, 1, 1],  # 3x3,   7 ->   7
        ]
        self.dropout_rate = dropout_rate
        out_channels = self._round_filters(32, width_multi)
        self.mod1 = ConvBlock(3, out_channels, kernel_size=3, stride=2, groups=1, dilate=1)

        in_channels = out_channels
        drop_rate = self.dropout_rate
        mod_id = 0
        for t, c, n, k, s, d in settings:
            out_channels = self._round_filters(c, width_multi)
            repeats = self._round_repeats(n, depth_multi)

            if self.dropout_rate:
                drop_rate = self.dropout_rate * float(mod_id + 1) / len(settings)

            # Create blocks for module
            blocks = []
            for block_id in range(repeats):
                stride = s if block_id == 0 else 1
                dilate = d if stride == 1 else 1

                blocks.append(("block%d" % (block_id + 1), MBConvBlock(in_channels, out_channels,
                                                                       expand_ratio=t, kernel_size=k,
                                                                       stride=stride, dilate=dilate,
                                                                       dropout_rate=drop_rate)))
                in_channels = out_channels
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
            mod_id += 1

        self.last_channels = self._round_filters(1280, width_multi)
        self.last_feat = ConvBlock(in_channels, self.last_channels, 1)

        self.classifier = nn.Linear(self.last_channels, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                fan_out = m.weight.size(0)
                init_range = 1.0 / math.sqrt(fan_out)
                nn.init.uniform_(m.weight, -init_range, init_range)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _make_divisible(value, divisor=8):
        new_value = max(divisor, int(value + divisor / 2) // divisor * divisor)
        if new_value < 0.9 * value:
            new_value += divisor
        return new_value

    def _round_filters(self, filters, width_multi):
        if width_multi == 1.0:
            return filters
        return int(self._make_divisible(filters * width_multi))

    @staticmethod
    def _round_repeats(repeats, depth_multi):
        if depth_multi == 1.0:
            return repeats
        return int(math.ceil(depth_multi * repeats))

    def forward(self, x):
        # Channel counts below are the b0 base values (before width scaling).
        x = self.mod2(self.mod1(x))  # (N, 16, H/2, W/2)
        x = self.mod3(x)             # (N, 24, H/4, W/4)
        x = self.mod4(x)             # (N, 40, H/8, W/8)
        x = self.mod6(self.mod5(x))  # (N, 112, H/16, W/16)
        x = self.mod8(self.mod7(x))  # (N, 320, H/32, W/32)
        x = self.last_feat(x)

        x = F.adaptive_avg_pool2d(x, (1, 1)).view(-1, self.last_channels)
        if self.training and (self.dropout_rate is not None):
            x = F.dropout(input=x, p=self.dropout_rate,
                          training=self.training, inplace=True)
        x = self.classifier(x)
        return x


if __name__ == "__main__":
    import os
    import time
    from torchstat import stat
    from pytorch_memlab import MemReporter

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    arch = "b6"
    img_preparam = {"b0": (224, 0.875),
                    "b1": (240, 0.882),
                    "b2": (260, 0.890),
                    "b3": (300, 0.904),
                    "b4": (380, 0.922),
                    "b5": (456, 0.934),
                    "b6": (528, 0.942),
                    "b7": (600, 0.949)}
    net_h = img_preparam[arch][0]
    model = EfficientNet(arch=arch, num_classes=1000)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-1,
                                momentum=0.90, weight_decay=1.0e-4, nesterov=True)

    # stat(model, (3, net_h, net_h))

    model = model.cuda().train()
    loss_func = nn.CrossEntropyLoss().cuda()
    dummy_in = torch.randn(2, 3, net_h, net_h).cuda().requires_grad_()
    dummy_target = torch.ones(2).long().cuda()
    reporter = MemReporter(model)

    optimizer.zero_grad()
    dummy_out = model(dummy_in)
    loss = loss_func(dummy_out, dummy_target)
    print('========================================== before backward ===========================================')
    reporter.report()

    loss.backward()
    optimizer.step()
    print('========================================== after backward =============================================')
    reporter.report()
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import os
import math
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from PIL import Image
from tqdm import tqdm


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): correct is non-contiguous after the transpose above
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == "__main__":
    from efficientnet import EfficientNet

    data_root = "/home/liuhuijun/Datasets/ImageNet"
    val_dir = os.path.join(data_root, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    arch = "b7"
    img_preparam = {"b0": (224, 0.875),
                    "b1": (240, 0.882),
                    "b2": (260, 0.890),
                    "b3": (300, 0.904),
                    "b4": (380, 0.922),
                    "b5": (456, 0.934),
                    "b6": (528, 0.942),
                    "b7": (600, 0.949)}
    valid_dataset = datasets.ImageFolder(val_dir, transforms.Compose([transforms.Resize(int(img_preparam[arch][0] / img_preparam[arch][1]), Image.BICUBIC),
                                                                      transforms.CenterCrop(img_preparam[arch][0]),
                                                                      transforms.ToTensor(),
                                                                      normalize]))

    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=128, shuffle=False,
                                               num_workers=16, pin_memory=False)
    num_batches = int(math.ceil(len(valid_loader.dataset) / float(valid_loader.batch_size)))

    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
    model = EfficientNet(arch=arch, num_classes=1000).cuda()
    used_gpus = [idx for idx in range(torch.cuda.device_count())]
    model = torch.nn.DataParallel(model, device_ids=used_gpus).cuda()

    checkpoint = torch.load("/home/liuhuijun/TrainLog/release/imagenet/efficientnet_{}_top1v_86.7.pkl".format(arch))
    pre_weight = checkpoint['model_state']
    model_dict = model.state_dict()
    # Checkpoint keys lack the DataParallel "module." prefix, so add it here.
    pretrained_dict = {"module." + k: v for k, v in pre_weight.items() if "module." + k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    model.eval()

    top1 = AverageMeter()
    top5 = AverageMeter()
    with torch.no_grad():
        pbar = tqdm(total=num_batches)
        for i_val, (images, labels) in enumerate(valid_loader):
            images = images.cuda()
            labels = torch.squeeze(labels.cuda())

            net_out = model(images)

            prec1, prec5 = accuracy(net_out, labels, topk=(1, 5))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            pbar.update(1)
            pbar.set_description("> Eval")
            pbar.set_postfix(Top1=top1.avg, Top5=top5.avg)

        pbar.close()
--------------------------------------------------------------------------------