├── test
│   ├── __init__.py
│   ├── test_oct_resnet.py
│   ├── test_mobilenet.py
│   └── oct_test.py
├── module
│   ├── __init__.py
│   ├── layers.py
│   ├── dropblock.py
│   └── octconv.py
├── .gitignore
├── models
│   ├── __init__.py
│   ├── base.py
│   ├── resnest.py
│   ├── proxyless_nas.py
│   ├── lamdba_net.py
│   ├── efficientnet.py
│   ├── fairnet.py
│   ├── regnet.py
│   ├── oct_resnet.py
│   ├── evo_norm.py
│   ├── mobilenet.py
│   ├── oct_resnet_re.py
│   ├── ghostnet.py
│   └── resnet.py
├── LICENSE
├── scripts
│   ├── generate_LMDB_dataset.py
│   ├── test_script.py
│   ├── utils.py
│   ├── train_script.py
│   ├── train_sample.py
│   └── distribute_train_script.py
└── README.md
/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /module/__init__.py: -------------------------------------------------------------------------------- 1 | from .octconv import * 2 | from .layers import * 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | test/ 2 | .idea/ 3 | __pycache__/ 4 | models/__pycache__/ 5 | module/__pycache__/ 6 | params/ 7 | test.py 8 | *.log 9 | *.lock 10 | 11 | -------------------------------------------------------------------------------- /test/test_oct_resnet.py: -------------------------------------------------------------------------------- 1 | from models.oct_resnet import * 2 | from torchtoolbox.tools.summary import summary 3 | import torch 4 | 5 | model = oct_resnet50v2(0) 6 | dt = torch.randn(1, 3, 224, 224) 7 | 8 | print(model) 9 | # out = model(dt) 10 | # summary(model, dt) 11 | # print(out.size()) 12 | -------------------------------------------------------------------------------- /test/test_mobilenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from models import MobileNetV3_Large, MobileNetV3_Small, MobileNetV1, MobileNetV2 3 | from torchtoolbox.tools import summary 4 | import torch 5 | 6 | model = MobileNetV3_Large() 7 | # model = MobileNetV1() 8 | # model = MobileNetV2() 9 | # model = MobileNetV3_Small() 10 | summary(model, torch.rand(1, 3, 224, 224)) 11 | -------------------------------------------------------------------------------- /test/oct_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from module.octconv import * 3 | 4 | # test first layer 5 | data1 = torch.randn(1, 64, 64, 64) 6 | fo = OctaveConv(64, 64, 0, 0.125, 3, 1, 2) 7 | out = fo(data1) 8 | print(out[0].size(), out[1].size()) 9 | 10 | # test oct layer 11 | oo = OctaveConv(64, 128, 0.125, 0.125, 3, 1, groups=128) 12 | out = oo(out[0], out[1]) 13 | print(out[0].size(), out[1].size()) 14 | 15 | # test last layer 16 | ol = OctaveConv(128, 128, 0.125, 0, 3, 1) 17 | out = ol(out[0], out[1]) 18 | print(out[0].size()) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | from .oct_resnet_re import * 5 | from .resnet import * 6 | from .fairnet import * 7 | from .mobilenet import * 8 | from .proxyless_nas import * 9 | from .efficientnet import * 10 | from .evo_norm import * 11 | from .resnest import * 12 | from .regnet import * 13 | from .ghostnet import * 14 | from .lamdba_net import * 15 | from 
torchvision.models.alexnet import * 16 | from torchvision.models.densenet import * 17 | from torchvision.models.googlenet import * 18 | from torchvision.models.inception import * 19 | from torchvision.models.shufflenetv2 import * 20 | from torchvision.models.vgg import * 21 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | __all__ = ['SELayer'] 4 | 5 | from torch import nn 6 | from torchtoolbox.nn import Activation 7 | 8 | 9 | class SELayer(nn.Module): 10 | def __init__(self, in_c, reducation_c, act='relu'): 11 | super(SELayer, self).__init__() 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.lin1 = nn.Linear(in_c, reducation_c) 14 | self.act = Activation(act, auto_optimize=True) 15 | self.lin2 = nn.Linear(reducation_c, in_c) 16 | 17 | def forward(self, x): 18 | out = self.avg_pool(x) 19 | out = out.view(out.size(0), -1) 20 | out = self.lin1(out) 21 | out = self.act(out) 22 | out = self.lin2(out).sigmoid()  # sigmoid gate of the SE block 23 | out = x * out.view(out.size(0), -1, 1, 1)  # reshape to (N, C, 1, 1) so the channel weights broadcast over H, W 24 | return out 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 X.Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /scripts/generate_LMDB_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | import argparse 4 | from torchvision.datasets import ImageNet 5 | from torchtoolbox.tools import check_dir 6 | from torchtoolbox.tools.convert_lmdb import generate_lmdb_dataset, raw_reader 7 | 8 | parser = argparse.ArgumentParser(description='Convert a ImageFolder dataset to LMDB format.') 9 | parser.add_argument('--data-dir', type=str, required=True, 10 | help='ImageFolder path, this param will give to ImageFolder Dataset.') 11 | parser.add_argument('--save-dir', type=str, required=True, 12 | help='Save dir.') 13 | parser.add_argument('--download', action='store_true', help='download dataset.') 14 | parser.add_argument('-j', dest='num_workers', type=int, default=0) 15 | parser.add_argument('--write-frequency', type=int, default=5000) 16 | parser.add_argument('--max-size', type=float, default=1.0, 17 | help='Maximum size database, this is rate, default is 1T, final setting would be ' 18 | '1T * `this param`') 19 | 20 | args = parser.parse_args() 21 | check_dir(args.save_dir) 22 | train_data_set = ImageNet(args.data_dir, 'train', args.download, loader=raw_reader) 23 | val_data_set = ImageNet(args.data_dir, 'val', args.download, loader=raw_reader) 24 | 25 | if __name__ == '__main__': 26 | generate_lmdb_dataset(train_data_set, args.save_dir, 'train', args.num_workers, 27 | args.max_size, args.write_frequency) 28 | generate_lmdb_dataset(val_data_set, args.save_dir, 'val', args.num_workers, 29 | args.max_size, args.write_frequency) 30 | -------------------------------------------------------------------------------- /module/layers.py: -------------------------------------------------------------------------------- 1 | __all__ = ['fs_bn', 'fs_relu'] 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class fs_bn(nn.Module): 7 | def __init__(self, channels, alpha): 8 | super().__init__() 9 | h_out = int((1 - alpha) * channels) 10 | l_out = int(alpha * channels) 11 | 12 | self.h_bn = nn.BatchNorm2d(h_out) 13 | self.l_bn = nn.BatchNorm2d(l_out) if alpha != 0 else None 14 | 15 | def forward(self, x_h, x_l=None): 16 | y_h = self.h_bn(x_h) 17 | y_l = self.l_bn(x_l) if x_l is not None else None 18 | return y_h, y_l 19 | 20 | 21 | class fs_relu(nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | def forward(self, x_h, x_l=None): 27 | y_h = self.relu(x_h) 28 | y_l = self.relu(x_l) if x_l is not None else None 29 | return y_h, y_l 30 | 31 | 32 | class se(nn.Module): 33 | def __init__(self, channels, reduction=4): 34 | super().__init__() 35 | self.se = nn.Sequential( 36 | nn.AdaptiveAvgPool2d(1), 37 | nn.Conv2d(channels, channels // reduction, 1), 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d(channels // reduction, channels, 1), 40 | nn.Sigmoid(), 41 | ) 42 | 43 | def forward(self, input): 44 | y = self.se(input) 45 | return input * y 46 | 47 | 48 | class cbam(nn.Module): 49 | def __init__(self, channel, reduction=4, k=3): 50 | super().__init__() 51 | self.avgpool = nn.AdaptiveAvgPool2d(1) 52 | self.maxpool = nn.AdaptiveMaxPool2d(1) 53 | self.cat = nn.Sequential( 54 | nn.Conv2d(channel, channel // reduction, 1), 55 | nn.ReLU(inplace=True), 56 | nn.Conv2d(channel // reduction, channel, 1), 57 | ) 58 | assert k in (3, 7) 59 | padding = 3 if k == 7 else 1 60 | self.sat = 
nn.Sequential( 61 | nn.Conv2d(2, 1, k, 1, padding, bias=False), 62 | nn.BatchNorm2d(1), 63 | ) 64 | self.sigmoid = nn.Sigmoid() 65 | 66 | def forward(self, x): 67 | # ChannelAttention 68 | avg_out = self.cat(self.avgpool(x)) 69 | max_out = self.cat(self.maxpool(x)) 70 | out = self.sigmoid(avg_out + max_out) 71 | x = x * out 72 | # SpatialAttention 73 | avg_out = torch.mean(x, dim=1, keepdim=True) 74 | max_out = torch.max(x, dim=1, keepdim=True)[0]  # torch.max returns (values, indices); keep only the values 75 | out = torch.cat([avg_out, max_out], dim=1) 76 | out = self.sat(out) 77 | out = self.sigmoid(out) 78 | return x * out 79 | -------------------------------------------------------------------------------- /module/dropblock.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | 5 | def drop_block(x, mask): 6 | return x * mask * mask.numel() / mask.sum() 7 | 8 | 9 | class DropBlock2d(nn.Module): 10 | r"""Randomly zeroes 2D spatial blocks of the input tensor. 11 | As described in the paper 12 | `DropBlock: A regularization method for convolutional networks`_ , 13 | dropping whole blocks of feature map allows to remove semantic 14 | information as compared to regular dropout. 15 | Args: 16 | p (float): probability of an element to be dropped. 17 | block_size (int): size of the block to drop 18 | Shape: 19 | - Input: `(N, C, H, W)` 20 | - Output: `(N, C, H, W)` 21 | .. _DropBlock: A regularization method for convolutional networks: 22 | https://arxiv.org/abs/1810.12890 23 | """ 24 | 25 | def __init__(self, p=0.1, block_size=7): 26 | super(DropBlock2d, self).__init__() 27 | assert 0 <= p <= 1 28 | self.p = p 29 | self.block_size = block_size 30 | 31 | def forward(self, x): 32 | if not self.training or self.p == 0: 33 | return x 34 | _, _, h, w = x.size() 35 | gamma = self.get_gamma(h, w) 36 | mask = self.get_mask(x, gamma) 37 | y = drop_block(x, mask) 38 | return y 39 | 40 | @torch.no_grad() 41 | def get_mask(self, x, gamma): 42 | mask = torch.bernoulli(torch.ones_like(x.sum(dim=0, keepdim=True)) * gamma) 43 | mask = 1 - torch.max_pool2d(mask, kernel_size=self.block_size, stride=1, padding=self.block_size // 2) 44 | return mask 45 | 46 | def get_gamma(self, h, w): 47 | return self.p * (h * w) / (self.block_size ** 2) / \ 48 | ((w - self.block_size + 1) * (h - self.block_size + 1)) 49 | 50 | 51 | class DropBlockScheduler(object): 52 | def __init__(self, model, batches: int, num_epochs: int, start_value=0.1, stop_value=1.): 53 | self.model = model 54 | self.iter = 0 55 | self.start_value = start_value 56 | self.num_iter = batches * num_epochs 57 | self.st_line = (stop_value - start_value) / self.num_iter 58 | self.groups = [] 59 | self.value = start_value 60 | 61 | def coll_dbs(md): 62 | if hasattr(md, 'block_size'): 63 | self.groups.append(md) 64 | 65 | model.apply(coll_dbs) 66 | 67 | def update_values(self, value): 68 | for db in self.groups: 69 | db.p = value 70 | 71 | def load_state_dict(self, state_dict): 72 | """Loads the schedulers state. 73 | 74 | Arguments: 75 | state_dict (dict): scheduler state. Should be an object returned 76 | from a call to :meth:`state_dict`.
77 | """ 78 | self.__dict__.update(state_dict) 79 | 80 | def get_value(self): 81 | self.value = self.st_line * self.iter + self.start_value 82 | 83 | def state_dict(self): 84 | return { 85 | key: value for key, 86 | value in self.__dict__.items() if (key != 'model' and key != 'groups')} 87 | 88 | def step(self): 89 | self.get_value() 90 | self.update_values(self.value) 91 | self.iter += 1 92 | -------------------------------------------------------------------------------- /scripts/test_script.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | import argparse 5 | import os 6 | import models 7 | import torch 8 | 9 | from torchtoolbox import metric 10 | from torchtoolbox.nn import SwitchNorm2d, Swish 11 | from torchtoolbox.data import ImageLMDB 12 | from torchtoolbox.tools import summary 13 | 14 | from torchvision import transforms 15 | from torchvision.datasets import ImageNet 16 | from torch.utils.data import DataLoader 17 | from torch import nn 18 | from tqdm import tqdm 19 | 20 | parser = argparse.ArgumentParser(description='Train a model on ImageNet.') 21 | parser.add_argument('--data-path', type=str, required=True, 22 | help='training and validation dataset.') 23 | parser.add_argument('--use-lmdb', action='store_true', 24 | help='use LMDB dataset/format') 25 | parser.add_argument('--batch-size', type=int, default=32, 26 | help='training batch size per device (CPU/GPU).') 27 | parser.add_argument('--dtype', type=str, default='float32', 28 | help='data type for training. default is float32') 29 | parser.add_argument('--devices', type=str, default='0', 30 | help='gpus to use.') 31 | parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int, 32 | help='number of preprocessing workers') 33 | parser.add_argument('--model', type=str, required=True, 34 | help='type of model to use. see vision_model for options.') 35 | parser.add_argument('--alpha', type=float, default=0, 36 | help='model param.') 37 | parser.add_argument('--input-size', type=int, default=224, 38 | help='size of the input image size. 
default is 224') 39 | parser.add_argument('--norm-layer', type=str, default='', 40 | help='Norm layer to use.') 41 | parser.add_argument('--activation', type=str, default='', 42 | help='activation to use.') 43 | parser.add_argument('--param-path', type=str, default='', 44 | help='param used to test.') 45 | 46 | args = parser.parse_args() 47 | 48 | 49 | def get_model(name, **kwargs): 50 | return models.__dict__[name](**kwargs) 51 | 52 | 53 | def set_model(drop_out, norm_layer, act): 54 | setting = {} 55 | if drop_out != 0: 56 | setting['dropout_rate'] = args.dropout 57 | if norm_layer != '': 58 | if args.norm_layer == 'switch': 59 | setting['norm_layer'] = SwitchNorm2d 60 | else: 61 | raise NotImplementedError 62 | if act != '': 63 | if args.activation == 'swish': 64 | setting['activation'] = Swish() 65 | elif args.activation == 'relu6': 66 | setting['activation'] = nn.ReLU6(inplace=True) 67 | else: 68 | raise NotImplementedError 69 | return setting 70 | 71 | 72 | classes = 1000 73 | num_training_samples = 1281167 74 | 75 | assert torch.cuda.is_available() 76 | device = torch.device("cuda:0") 77 | device_ids = args.devices.strip().split(',') 78 | device_ids = [int(device) for device in device_ids] 79 | 80 | dtype = args.dtype 81 | num_workers = args.num_workers 82 | batch_size = args.batch_size * len(device_ids) 83 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 84 | std=[0.229, 0.224, 0.225]) 85 | 86 | val_transform = transforms.Compose([ 87 | transforms.Resize(256), 88 | transforms.CenterCrop(224), 89 | transforms.ToTensor(), 90 | normalize, 91 | ]) 92 | 93 | if not args.use_lmdb: 94 | val_set = ImageNet(args.data_path, split='val', transform=val_transform) 95 | else: 96 | val_set = ImageLMDB(os.path.join(args.data_path, 'val.lmdb'), transform=val_transform) 97 | 98 | val_data = DataLoader(val_set, batch_size, False, pin_memory=True, num_workers=num_workers, drop_last=False) 99 | 100 | model_setting = set_model(0, args.norm_layer, args.activation) 101 | 102 | try: 103 | model = get_model(args.model, alpha=args.alpha, **model_setting) 104 | except TypeError: 105 | model = get_model(args.model, **model_setting) 106 | 107 | summary(model, torch.rand((1, 3, 224, 224))) 108 | 109 | model.to(device) 110 | model = nn.DataParallel(model) 111 | 112 | checkpoint = torch.load(args.param_path, map_location=device) 113 | model.load_state_dict(checkpoint['model']) 114 | print("Finish loading resume param.") 115 | 116 | top1_acc = metric.Accuracy(name='Top1 Accuracy') 117 | top5_acc = metric.TopKAccuracy(top=5, name='Top5 Accuracy') 118 | loss_record = metric.NumericalCost(name='Loss') 119 | 120 | Loss = nn.CrossEntropyLoss() 121 | 122 | 123 | @torch.no_grad() 124 | def test(): 125 | top1_acc.reset() 126 | top5_acc.reset() 127 | loss_record.reset() 128 | model.eval() 129 | for data, labels in tqdm(val_data): 130 | data = data.to(device, non_blocking=True) 131 | labels = labels.to(device, non_blocking=True) 132 | 133 | outputs = model(data) 134 | losses = Loss(outputs, labels) 135 | 136 | top1_acc.update(outputs, labels) 137 | top5_acc.update(outputs, labels) 138 | loss_record.update(losses) 139 | 140 | test_msg = 'Test: {}:{:.5}, {}:{:.5}, {}:{:.5}\n'.format( 141 | top1_acc.name, top1_acc.get(), top5_acc.name, top5_acc.get(), 142 | loss_record.name, loss_record.get()) 143 | print(test_msg) 144 | 145 | 146 | if __name__ == '__main__': 147 | test() 148 | -------------------------------------------------------------------------------- /scripts/utils.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | from cloghandler import ConcurrentRotatingFileHandler 5 | from torchtoolbox.nn import SwitchNorm2d 6 | from torch import nn 7 | import logging 8 | 9 | 10 | # init DALI 11 | # try: 12 | # import nvidia.dali as dali 13 | # from nvidia.dali.plugin.pytorch import DALIClassificationIterator 14 | # except ImportError: 15 | # print('DALI is not available') 16 | 17 | 18 | def get_model(models, name, **kwargs) -> nn.Module: 19 | return models.__dict__[name](**kwargs) 20 | 21 | 22 | def set_model(drop_out, norm_layer, act): 23 | setting = {} 24 | if drop_out != 0: 25 | setting['dropout_rate'] = drop_out 26 | if norm_layer != '': 27 | if norm_layer == 'switch': 28 | setting['norm_layer'] = SwitchNorm2d 29 | else: 30 | raise NotImplementedError 31 | if act != '': 32 | setting['activation'] = act 33 | 34 | return setting 35 | 36 | 37 | def get_logger(file_path): 38 | filehandler = ConcurrentRotatingFileHandler(file_path) 39 | streamhandler = logging.StreamHandler() 40 | 41 | logger = logging.getLogger('Distribute training logs.') 42 | logger.setLevel(logging.INFO) 43 | logger.addHandler(filehandler) 44 | logger.addHandler(streamhandler) 45 | return logger 46 | 47 | # class TrainPipe(dali.pipeline.Pipeline): 48 | # def __init__(self, data_dir, batch_size, num_threads, device_id, crop, color_jit=0.4, use_cpu=False): 49 | # super(TrainPipe, self).__init__(batch_size, num_threads, device_id) 50 | # dali_device = 'cpu' if use_cpu else 'gpu' 51 | # decoder_device = 'cpu' if use_cpu else 'mixed' 52 | # 53 | # device_memory_padding = 211025920 if decoder_device == 'mixed' else 0 54 | # host_memory_padding = 140544512 if decoder_device == 'mixed' else 0 55 | # 56 | # self.input = dali.ops.FileReader(file_root=data_dir, shard_id=device_id, num_shards=1, 57 | # shuffle_after_epoch=True) 58 | # 59 | # self.decode = dali.ops.ImageDecoderRandomCrop(device=decoder_device, output_type=dali.types.RGB, 60 | # device_memory_padding=device_memory_padding, 61 | # host_memory_padding=host_memory_padding, 62 | # num_attempts=100) 63 | # 64 | # self.res = dali.ops.Resize(device=dali_device, resize_x=crop, resize_y=crop, 65 | # interp_type=dali.types.INTERP_TRIANGULAR) 66 | # 67 | # self.bri = dali.ops.Brightness(device=dali_device) 68 | # self.con = dali.ops.Contrast(device=dali_device) 69 | # self.sat = dali.ops.Saturation(device=dali_device) 70 | # 71 | # self.cmnp = dali.ops.CropMirrorNormalize(device=dali_device, output_dtype=dali.types.FLOAT, 72 | # output_layout=dali.types.NCHW, 73 | # crop=(crop, crop), image_type=dali.types.RGB, 74 | # mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 75 | # std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) 76 | # 77 | # self.coin = dali.ops.CoinFlip(probability=0.5) 78 | # self.uniform = dali.ops.Uniform(range=(max(0., 1 - color_jit), 1 + color_jit)) 79 | # 80 | # def define_graph(self): 81 | # imgs, labels = self.input(name='Reader') 82 | # imgs = self.decode(imgs) 83 | # imgs = self.res(imgs) 84 | # imgs = self.bri(imgs, brightness=self.uniform()) 85 | # imgs = self.con(imgs, contrast=self.uniform()) 86 | # imgs = self.sat(imgs, saturation=self.uniform()) 87 | # imgs = self.cmnp(imgs, mirror=self.coin()) 88 | # return imgs.gpu(), labels.gpu() 89 | # 90 | # 91 | # class ValPipe(dali.pipeline.Pipeline): 92 | # def __init__(self, data_dir, batch_size, num_threads, device_id, resize, crop, use_cpu=False): 93 | # super(ValPipe, 
self).__init__(batch_size, num_threads, device_id) 94 | # dali_device = 'cpu' if use_cpu else 'gpu' 95 | # decoder_device = 'cpu' if use_cpu else 'mixed' 96 | # 97 | # self.input = dali.ops.FileReader(file_root=data_dir, shard_id=device_id, num_shards=1, 98 | # shuffle_after_epoch=True) 99 | # self.decode = dali.ops.ImageDecoder(device=decoder_device, output_type=dali.types.RGB) 100 | # self.res = dali.ops.Resize(device=dali_device, resize_shorter=resize, 101 | # interp_type=dali.types.INTERP_TRIANGULAR) 102 | # self.cmnp = dali.ops.CropMirrorNormalize(device=dali_device, output_dtype=dali.types.FLOAT, 103 | # output_layout=dali.types.NCHW, 104 | # crop=(crop, crop), image_type=dali.types.RGB, 105 | # mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 106 | # std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) 107 | # 108 | # def define_graph(self): 109 | # imgs, labels = self.input(name='Reader') 110 | # imgs = self.decode(imgs) 111 | # imgs = self.res(imgs) 112 | # imgs = self.cmnp(imgs) 113 | # return imgs.gpu(), labels.gpu() 114 | -------------------------------------------------------------------------------- /models/resnest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __all__ = ['resnest50'] 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class SplitAttention2d(nn.Module): 9 | def __init__(self, channels, inter_channels, radix, groups): 10 | super(SplitAttention2d, self).__init__() 11 | self.radix = radix 12 | self.channels = channels 13 | self.pool = nn.AdaptiveAvgPool2d(1) 14 | self.fc1 = nn.Conv2d(channels, inter_channels, 1, groups=groups, bias=False) 15 | self.bn1 = nn.BatchNorm2d(inter_channels) 16 | self.fc2 = nn.Conv2d(inter_channels, channels * radix, 1, groups=groups) 17 | self.act = nn.ReLU(inplace=True) 18 | 19 | def forward(self, x): 20 | n, c, h, w = x.size() # c = channels * radix 21 | sp = torch.reshape(x, (n, self.radix, self.channels, h, w)) 22 | x = torch.sum(sp, dim=1) 23 | x = self.pool(x) 24 | x = self.fc1(x) 25 | x = self.bn1(x) 26 | x = self.act(x) 27 | x = self.fc2(x).reshape((n, self.radix, self.channels, 1, 1)) 28 | x = torch.softmax(x, dim=1) 29 | x = torch.sum(x * sp, dim=1) 30 | return x 31 | 32 | 33 | class SplAtConv2d(nn.Module): 34 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0), 35 | dilation=(1, 1), groups=1, bias=True, radix=2, reduction_factor=4, **kwargs): 36 | super(SplAtConv2d, self).__init__() 37 | inter_channels = max(in_channels * radix // reduction_factor, 32) 38 | assert radix > 1 39 | self.radix = radix 40 | self.groups = groups 41 | self.channels = channels 42 | self.conv = nn.Conv2d(in_channels, channels * radix, kernel_size, stride, padding, 43 | dilation, groups * radix, bias, **kwargs) 44 | self.bn = nn.BatchNorm2d(channels * radix) 45 | self.act = nn.ReLU(inplace=True) 46 | self.spaconv = SplitAttention2d(channels, inter_channels, radix, groups) 47 | 48 | def forward(self, x): 49 | x = self.conv(x) 50 | x = self.bn(x) 51 | x = self.act(x) 52 | x = self.spaconv(x) 53 | return x 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, radix=1, groups=1, 60 | bottleneck_width=64, dilation=1): 61 | super(Bottleneck, self).__init__() 62 | group_width = int(planes * (bottleneck_width / 64.)) * groups 63 | self.radix = radix 64 | self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(group_width) 66 | 
self.conv2 = SplAtConv2d(group_width, group_width, kernel_size=3, 67 | stride=stride, padding=dilation, dilation=dilation, 68 | groups=groups, bias=False, radix=radix) 69 | self.conv3 = nn.Conv2d(group_width, planes * self.expansion, 1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 71 | self.act = nn.ReLU(inplace=True) 72 | self.downsample = nn.Identity() if downsample is None else downsample 73 | 74 | def forward(self, x): 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.act(out) 78 | 79 | out = self.conv2(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | identity = self.downsample(x) 85 | 86 | out += identity 87 | out = self.act(out) 88 | return out 89 | 90 | 91 | class ResNet(nn.Module): 92 | def __init__(self, layers, radix=1, groups=1, bottleneck_width=64, 93 | num_classes=1000): 94 | super(ResNet, self).__init__() 95 | self.groups = groups 96 | self.radix = radix 97 | self.bottleneck_width = bottleneck_width 98 | 99 | self.inplanes = 64 100 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 101 | self.bn1 = nn.BatchNorm2d(self.inplanes) 102 | self.act = nn.ReLU(inplace=True) 103 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 104 | 105 | self.layer1 = self._make_layer(64, layers[0]) 106 | self.layer2 = self._make_layer(128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(512, layers[3], stride=2) 109 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 110 | self.flatten = nn.Flatten() 111 | self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes) 112 | 113 | def _make_layer(self, planes, blocks, stride=1): 114 | downsample = None 115 | if stride != 1 or self.inplanes != planes * Bottleneck.expansion: 116 | downsample = nn.Sequential( 117 | nn.Conv2d(self.inplanes, planes * Bottleneck.expansion, 118 | kernel_size=1, stride=stride, bias=False), 119 | nn.BatchNorm2d(planes * Bottleneck.expansion) 120 | ) 121 | layers = [] 122 | layers.append(Bottleneck(self.inplanes, planes, stride, downsample, 123 | self.radix, self.groups, self.bottleneck_width)) 124 | self.inplanes = planes * Bottleneck.expansion 125 | for _ in range(1, blocks): 126 | layers.append(Bottleneck(self.inplanes, planes, 127 | radix=self.radix, groups=self.groups, 128 | bottleneck_width=self.bottleneck_width)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | x = self.bn1(x) 135 | x = self.act(x) 136 | x = self.maxpool(x) 137 | 138 | x = self.layer1(x) 139 | x = self.layer2(x) 140 | x = self.layer3(x) 141 | x = self.layer4(x) 142 | 143 | x = self.avgpool(x) 144 | x = self.flatten(x) 145 | x = self.fc(x) 146 | return x 147 | 148 | 149 | def resnest50(**kwargs): 150 | return ResNet([3, 4, 6, 3], radix=2, groups=1, bottleneck_width=64, **kwargs) 151 | -------------------------------------------------------------------------------- /models/proxyless_nas.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | """Paper `PROXYLESSNAS: DIRECT NEURAL ARCHITECTURE 4 | SEARCH ON TARGET TASK AND HARDWARE`, 5 | `https://arxiv.org/pdf/1812.00332.pdf`""" 6 | 7 | __all__ = ['ProxylessGPU', 'ProxylessCPU', 'ProxylessMobile'] 8 | 9 | from .fairnet import InvertedResidual 10 | from torch import nn 11 | 12 | 13 | class ProxylessGPU(nn.Module): 14 | def __init__(self, num_classes=1000, 
small_input=False): 15 | super(ProxylessGPU, self).__init__() 16 | self.first_block = nn.Sequential( 17 | nn.Conv2d(3, 40, 3, 2 if not small_input else 1, 1, bias=False), 18 | nn.BatchNorm2d(40), 19 | nn.ReLU6(inplace=True), 20 | nn.Conv2d(40, 40, 3, 1, 1, groups=40, bias=False), 21 | nn.BatchNorm2d(40), 22 | nn.ReLU6(inplace=True), 23 | nn.Conv2d(40, 24, 1, 1, bias=False), 24 | nn.BatchNorm2d(24) 25 | ) 26 | self.mb_blocks = nn.Sequential( 27 | InvertedResidual(24, 3, 32, 5, 2), 28 | InvertedResidual(32, 3, 56, 7, 2), 29 | InvertedResidual(56, 3, 56, 3, 1), 30 | InvertedResidual(56, 6, 112, 7, 2), 31 | InvertedResidual(112, 3, 112, 5, 1), 32 | InvertedResidual(112, 6, 128, 5, 1), 33 | InvertedResidual(128, 3, 128, 3, 1), 34 | InvertedResidual(128, 3, 128, 5, 1), 35 | InvertedResidual(128, 6, 256, 7, 2), 36 | InvertedResidual(256, 6, 256, 7, 1), 37 | InvertedResidual(256, 6, 256, 7, 1), 38 | InvertedResidual(256, 6, 256, 5, 1), 39 | InvertedResidual(256, 6, 432, 7, 1), 40 | ) 41 | self.last_block = nn.Sequential( 42 | nn.Conv2d(432, 1728, 1, 1, bias=False), 43 | nn.BatchNorm2d(1728), 44 | nn.ReLU6(inplace=True), 45 | nn.AdaptiveAvgPool2d(1), 46 | nn.Flatten(), 47 | ) 48 | self.output = nn.Linear(1728, num_classes) 49 | 50 | def forward(self, x): 51 | x = self.first_block(x) 52 | x = self.mb_blocks(x) 53 | x = self.last_block(x) 54 | x = self.output(x) 55 | return x 56 | 57 | 58 | class ProxylessCPU(nn.Module): 59 | def __init__(self, num_classes=1000, small_input=False): 60 | super(ProxylessCPU, self).__init__() 61 | self.first_block = nn.Sequential( 62 | nn.Conv2d(3, 40, 3, 2 if not small_input else 1, 1, bias=False), 63 | nn.BatchNorm2d(40), 64 | nn.ReLU6(inplace=True), 65 | nn.Conv2d(40, 40, 3, 1, 1, groups=40, bias=False), 66 | nn.BatchNorm2d(40), 67 | nn.ReLU6(inplace=True), 68 | nn.Conv2d(40, 24, 1, 1, bias=False), 69 | nn.BatchNorm2d(24) 70 | ) 71 | self.mb_blocks = nn.Sequential( 72 | InvertedResidual(24, 6, 32, 3, 2), 73 | InvertedResidual(32, 3, 32, 3, 1), 74 | InvertedResidual(32, 3, 32, 3, 1), 75 | InvertedResidual(32, 3, 32, 3, 1), 76 | InvertedResidual(32, 6, 48, 3, 2), 77 | InvertedResidual(48, 3, 48, 3, 1), 78 | InvertedResidual(48, 3, 48, 3, 1), 79 | InvertedResidual(48, 3, 48, 5, 1), 80 | InvertedResidual(48, 6, 88, 3, 2), 81 | InvertedResidual(88, 3, 88, 3, 1), 82 | InvertedResidual(88, 6, 104, 5, 1), 83 | InvertedResidual(104, 3, 104, 3, 1), 84 | InvertedResidual(104, 3, 104, 3, 1), 85 | InvertedResidual(104, 3, 104, 3, 1), 86 | InvertedResidual(104, 6, 216, 5, 2), 87 | InvertedResidual(216, 3, 216, 5, 1), 88 | InvertedResidual(216, 3, 216, 5, 1), 89 | InvertedResidual(216, 3, 216, 3, 1), 90 | InvertedResidual(216, 6, 360, 5, 1), 91 | ) 92 | self.last_block = nn.Sequential( 93 | nn.Conv2d(360, 1432, 1, 1, bias=False), 94 | nn.BatchNorm2d(1432), 95 | nn.ReLU6(inplace=True), 96 | nn.AdaptiveAvgPool2d(1), 97 | nn.Flatten(), 98 | ) 99 | self.output = nn.Linear(1432, num_classes) 100 | 101 | def forward(self, x): 102 | x = self.first_block(x) 103 | x = self.mb_blocks(x) 104 | x = self.last_block(x) 105 | x = self.output(x) 106 | return x 107 | 108 | 109 | class ProxylessMobile(nn.Module): 110 | def __init__(self, num_classes=1000, small_input=False): 111 | super(ProxylessMobile, self).__init__() 112 | self.first_block = nn.Sequential( 113 | nn.Conv2d(3, 32, 3, 2 if not small_input else 1, 1, bias=False), 114 | nn.BatchNorm2d(32), 115 | nn.ReLU6(inplace=True), 116 | nn.Conv2d(32, 32, 3, 1, 1, groups=40, bias=False), 117 | nn.BatchNorm2d(32), 118 | nn.ReLU6(inplace=True), 119 
| nn.Conv2d(32, 16, 1, 1, bias=False), 120 | nn.BatchNorm2d(16) 121 | ) 122 | self.mb_blocks = nn.Sequential( 123 | InvertedResidual(16, 3, 32, 5, 2), 124 | InvertedResidual(32, 3, 32, 3, 1), 125 | InvertedResidual(32, 3, 40, 7, 2), 126 | InvertedResidual(40, 3, 40, 3, 1), 127 | InvertedResidual(40, 3, 40, 5, 1), 128 | InvertedResidual(40, 3, 40, 5, 1), 129 | InvertedResidual(40, 6, 80, 7, 2), 130 | InvertedResidual(80, 3, 80, 5, 1), 131 | InvertedResidual(80, 3, 80, 5, 1), 132 | InvertedResidual(80, 3, 80, 5, 1), 133 | InvertedResidual(80, 6, 96, 5, 1), 134 | InvertedResidual(96, 3, 96, 5, 1), 135 | InvertedResidual(96, 3, 96, 5, 1), 136 | InvertedResidual(96, 3, 96, 5, 1), 137 | InvertedResidual(96, 6, 192, 7, 2), 138 | InvertedResidual(192, 6, 192, 7, 1), 139 | InvertedResidual(192, 3, 192, 7, 1), 140 | InvertedResidual(192, 3, 192, 7, 1), 141 | InvertedResidual(192, 6, 320, 7, 1), 142 | ) 143 | self.last_block = nn.Sequential( 144 | nn.Conv2d(320, 1280, 1, 1, bias=False), 145 | nn.BatchNorm2d(1280), 146 | nn.ReLU6(inplace=True), 147 | nn.AdaptiveAvgPool2d(1), 148 | nn.Flatten(), 149 | ) 150 | self.output = nn.Linear(1280, num_classes) 151 | 152 | def forward(self, x): 153 | x = self.first_block(x) 154 | x = self.mb_blocks(x) 155 | x = self.last_block(x) 156 | x = self.output(x) 157 | return x 158 | -------------------------------------------------------------------------------- /models/lamdba_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn, einsum 4 | from einops import rearrange 5 | from torchtoolbox.nn import Activation 6 | 7 | 8 | # helpers functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | 14 | def default(val, d): 15 | return val if exists(val) else d 16 | 17 | 18 | # lambda layer 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 22 | 23 | 24 | class LambdaLayer(nn.Module): 25 | def __init__(self, dim, *, dim_k, n=None, r=None, heads=4, dim_out=None, dim_u=1): 26 | super().__init__() 27 | dim_out = default(dim_out, dim) 28 | self.u = dim_u # intra-depth dimension 29 | self.heads = heads 30 | 31 | assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query' 32 | dim_v = dim_out // heads 33 | 34 | self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False) 35 | self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False) 36 | self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False) 37 | 38 | self.norm_q = nn.BatchNorm2d(dim_k * heads) 39 | self.norm_v = nn.BatchNorm2d(dim_v * dim_u) 40 | 41 | self.local_contexts = exists(r) 42 | if exists(r): 43 | assert (r % 2) == 1, 'Receptive kernel size should be odd' 44 | self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2)) 45 | else: 46 | assert exists(n), 'You must specify the total sequence length (h x w)' 47 | self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u)) 48 | 49 | def forward(self, x): 50 | b, c, hh, ww, u, h = *x.shape, self.u, self.heads 51 | 52 | q = self.to_q(x) 53 | k = self.to_k(x) 54 | v = self.to_v(x) 55 | 56 | q = self.norm_q(q) 57 | v = self.norm_v(v) 58 | 59 | q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h) 60 | k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u) 61 | v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u) 62 | 63 | k = k.softmax(dim=-1) 64 | 65 | λc = einsum('b u k 
m, b u v m -> b k v', k, v) 66 | Yc = einsum('b h k n, b k v -> b n h v', q, λc) 67 | 68 | if self.local_contexts: 69 | v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww) 70 | λp = self.pos_conv(v) 71 | Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3)) 72 | else: 73 | λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v) 74 | Yp = einsum('b h k n, b n k v -> b n h v', q, λp) 75 | 76 | Y = Yc + Yp 77 | out = rearrange(Y, 'b (hh ww) h v -> b (h v) hh ww', hh=hh, ww=ww) 78 | return out.contiguous() 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None): 85 | super(Bottleneck, self).__init__() 86 | 87 | self.conv1 = conv1x1(inplanes, planes) 88 | self.bn1 = nn.BatchNorm2d(planes) 89 | 90 | self.conv2 = LambdaLayer(planes, dim_k=16, r=15, heads=4, dim_u=1) 91 | self.pool = nn.AvgPool2d(3, 2, 1) if stride != 1 else nn.Identity() 92 | 93 | self.conv3 = conv1x1(planes, planes * self.expansion) 94 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 95 | self.act = nn.ReLU(inplace=True) 96 | self.downsample = nn.Identity() if downsample is None else downsample 97 | 98 | def forward(self, x): 99 | out = self.conv1(x) 100 | out = self.bn1(out) 101 | out = self.act(out) 102 | 103 | out = self.conv2(out) 104 | out = self.pool(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | identity = self.downsample(x) 110 | 111 | out += identity 112 | out = self.act(out) 113 | 114 | return out 115 | 116 | 117 | class LambdaResnet(nn.Module): 118 | def __init__(self, layers, num_classes=1000, small_input=False): 119 | super(LambdaResnet, self).__init__() 120 | self.inplanes = 64 121 | if small_input: 122 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 123 | bias=False) 124 | else: 125 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 126 | bias=False) 127 | self.bn1 = nn.BatchNorm2d(self.inplanes) 128 | self.act = nn.ReLU(inplace=True) 129 | if small_input: 130 | self.maxpool = nn.Identity() 131 | else: 132 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 133 | self.layer1 = self._make_layer(64, layers[0]) 134 | self.layer2 = self._make_layer(128, layers[1], stride=2) 135 | self.layer3 = self._make_layer(256, layers[2], stride=2) 136 | self.layer4 = self._make_layer(512, layers[3], stride=2) 137 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 138 | self.flatten = nn.Flatten() 139 | 140 | self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes) 141 | 142 | def _make_layer(self, planes, blocks, stride=1, ): 143 | 144 | downsample = None 145 | 146 | if stride != 1 or self.inplanes != planes * Bottleneck.expansion: 147 | downsample = nn.Sequential( 148 | conv1x1(self.inplanes, planes * Bottleneck.expansion, stride), 149 | nn.BatchNorm2d(planes * Bottleneck.expansion), 150 | ) 151 | 152 | layers = [] 153 | layers.append(Bottleneck(self.inplanes, planes, stride, downsample)) 154 | self.inplanes = planes * Bottleneck.expansion 155 | for _ in range(1, blocks): 156 | layers.append(Bottleneck(self.inplanes, planes)) 157 | 158 | return nn.Sequential(*layers) 159 | 160 | def forward(self, x): 161 | x = self.conv1(x) 162 | x = self.bn1(x) 163 | x = self.act(x) 164 | x = self.maxpool(x) 165 | 166 | x = self.layer1(x) 167 | x = self.layer2(x) 168 | x = self.layer3(x) 169 | x = self.layer4(x) 170 | 171 | x = self.avgpool(x) 172 | x = self.flatten(x) 173 | x = self.fc(x) 174 | 175 | return x 176 | 177 | 178 | def 
LambdaResnet18(**kwargs): 179 | """Constructs a ResNet-18 model. 180 | 181 | """ 182 | return LambdaResnet([2, 2, 2, 2], **kwargs) 183 | 184 | 185 | def LambdaResnet34(**kwargs): 186 | """Constructs a ResNet-34 model. 187 | 188 | """ 189 | return LambdaResnet([3, 4, 6, 3], **kwargs) 190 | 191 | 192 | def LambdaResnet50(**kwargs): 193 | """Constructs a ResNet-50 model. 194 | 195 | """ 196 | return LambdaResnet([3, 4, 6, 3], **kwargs) 197 | 198 | 199 | def LambdaResnet101(**kwargs): 200 | """Constructs a ResNet-101 model. 201 | 202 | """ 203 | return LambdaResnet([3, 4, 23, 3], **kwargs) 204 | 205 | 206 | def LambdaResnet152(**kwargs): 207 | """Constructs a ResNet-152 model. 208 | 209 | """ 210 | return LambdaResnet([3, 8, 36, 3], **kwargs) 211 | -------------------------------------------------------------------------------- /models/efficientnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | 5 | __all__ = ['EfficientNet', 'EfficientNet_B0', 'EfficientNet_B1', 'EfficientNet_B2', 6 | 'EfficientNet_B3', 'EfficientNet_B4', 'EfficientNet_B5', 'EfficientNet_B6', 7 | 'EfficientNet_B7'] 8 | 9 | import math 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class Swish(nn.Module): 15 | def forward(self, x): 16 | return x * torch.sigmoid(x) 17 | 18 | 19 | class DropConnect(nn.Module): 20 | def __init__(self, ratio): 21 | super().__init__() 22 | self.ratio = 1.0 - ratio 23 | 24 | def forward(self, x): 25 | if not self.training: 26 | return x 27 | 28 | random_tensor = self.ratio 29 | random_tensor += torch.rand([x.shape[0], 1, 1, 1], dtype=torch.float, device=x.device) 30 | random_tensor.requires_grad_(False) 31 | return x / self.ratio * random_tensor.floor() 32 | 33 | 34 | def _conv_bn(in_c, out_c, kernel_size, stride=1, groups=1, 35 | eps=1e-5, momentum=0.1, use_act=False): 36 | layer = [] 37 | layer.append(nn.Conv2d(in_c, out_c, kernel_size, stride, kernel_size // 2, groups=groups, bias=False)) 38 | layer.append(nn.BatchNorm2d(out_c, eps, momentum)) 39 | if use_act: 40 | layer.append(Swish()) 41 | return nn.Sequential(*layer) 42 | 43 | 44 | class SEModule(nn.Module): 45 | def __init__(self, in_, squeeze_ch): 46 | super().__init__() 47 | self.se = nn.Sequential( 48 | nn.AdaptiveAvgPool2d(1), 49 | nn.Conv2d(in_, squeeze_ch, kernel_size=1, stride=1, padding=0, bias=True), 50 | Swish(), 51 | nn.Conv2d(squeeze_ch, in_, kernel_size=1, stride=1, padding=0, bias=True), 52 | nn.Sigmoid(), 53 | ) 54 | 55 | def forward(self, x): 56 | return x * self.se(x) 57 | 58 | 59 | class MBConv(nn.Module): 60 | def __init__(self, in_c, out_c, expand, 61 | kernel_size, stride, se_ratio, 62 | dc_ratio): 63 | super().__init__() 64 | exp_c = in_c * expand 65 | self.layer1 = _conv_bn(in_c, exp_c, 1, use_act=True) if expand != 1 else nn.Identity() 66 | self.layer2 = _conv_bn(exp_c, exp_c, kernel_size, stride, 67 | groups=exp_c, use_act=True) 68 | self.se_layer = SEModule(exp_c, int(in_c * se_ratio)) if se_ratio > 0 else nn.Identity() 69 | self.layer3 = _conv_bn(exp_c, out_c, 1) 70 | self.skip = True if stride == 1 and in_c == out_c else False 71 | self.dropconnect = DropConnect(dc_ratio) if self.skip and dc_ratio > 0 else nn.Identity() 72 | 73 | def forward(self, inputs): 74 | x = self.layer1(inputs) 75 | x = self.layer2(x) 76 | x = self.se_layer(x) 77 | x = self.layer3(x) 78 | if self.skip: 79 | x = self.dropconnect(x) + inputs 80 | return x 81 | 82 | 83 | class EfficientNet(nn.Module): 84 | def 
__init__(self, width_coeff, depth_coeff, 85 | depth_div=8, min_depth=None, 86 | dropout_rate=0., drop_connect_rate=0, 87 | num_classes=1000, small_input=False): 88 | super().__init__() 89 | min_depth = min_depth or depth_div 90 | 91 | def renew_ch(x): 92 | if not width_coeff: 93 | return x 94 | x *= width_coeff  # scale the base width by the width multiplier before rounding 95 | new_x = max(min_depth, int(x + depth_div / 2) // depth_div * depth_div) 96 | if new_x < 0.9 * x:  # do not let rounding shrink the width by more than 10% 97 | new_x += depth_div 98 | return new_x 99 | 100 | def renew_repeat(x): 101 | return int(math.ceil(x * depth_coeff)) 102 | 103 | self.first_conv = _conv_bn(3, renew_ch(32), 3, 2 if not small_input else 1, use_act=True) 104 | self.blocks = nn.Sequential( 105 | self._make_layer(renew_ch(32), renew_ch(16), 1, 3, 1, renew_repeat(1), 0.25, drop_connect_rate), 106 | self._make_layer(renew_ch(16), renew_ch(24), 6, 3, 2, renew_repeat(2), 0.25, drop_connect_rate), 107 | self._make_layer(renew_ch(24), renew_ch(40), 6, 5, 2, renew_repeat(2), 0.25, drop_connect_rate), 108 | self._make_layer(renew_ch(40), renew_ch(80), 6, 3, 2, renew_repeat(3), 0.25, drop_connect_rate), 109 | self._make_layer(renew_ch(80), renew_ch(112), 6, 5, 1, renew_repeat(3), 0.25, drop_connect_rate), 110 | self._make_layer(renew_ch(112), renew_ch(192), 6, 5, 2, renew_repeat(4), 0.25, drop_connect_rate), 111 | self._make_layer(renew_ch(192), renew_ch(320), 6, 3, 1, renew_repeat(1), 0.25, drop_connect_rate), 112 | ) 113 | self.last_process = nn.Sequential( 114 | *_conv_bn(renew_ch(320), renew_ch(1280), 1, use_act=True), 115 | nn.AdaptiveAvgPool2d(1), 116 | nn.Dropout2d(dropout_rate, True) if dropout_rate > 0 else nn.Identity(), 117 | ) 118 | self.output = nn.Linear(renew_ch(1280), num_classes) 119 | 120 | def _make_layer(self, in_c, out_c, expand, kernel_size, stride, repeats, se_ratio, drop_connect_ratio): 121 | layers = [] 122 | layers.append(MBConv(in_c, out_c, expand, kernel_size, stride, se_ratio, drop_connect_ratio)) 123 | for _ in range(repeats - 1): 124 | layers.append(MBConv(out_c, out_c, expand, kernel_size, 1, se_ratio, drop_connect_ratio)) 125 | return nn.Sequential(*layers) 126 | 127 | def forward(self, x): 128 | x = self.first_conv(x) 129 | x = self.blocks(x) 130 | x = self.last_process(x) 131 | x = x.view(x.shape[0], -1) 132 | x = self.output(x) 133 | return x 134 | 135 | 136 | def EfficientNet_B0(num_classes=1000, **kwargs): 137 | model = EfficientNet(1., 1., dropout_rate=0, num_classes=num_classes, **kwargs) 138 | return model 139 | 140 | 141 | def EfficientNet_B1(num_classes=1000, **kwargs): 142 | # input size should be 240(~1.07x) 143 | model = EfficientNet(1., 1.1, dropout_rate=0.2, num_classes=num_classes, **kwargs) 144 | return model 145 | 146 | 147 | def EfficientNet_B2(num_classes=1000, **kwargs): 148 | # input size should be 260(~1.16x) 149 | model = EfficientNet(1.1, 1.2, dropout_rate=0.3, num_classes=num_classes, **kwargs) 150 | return model 151 | 152 | 153 | def EfficientNet_B3(num_classes=1000, **kwargs): 154 | # input size should be 300(~1.34x) 155 | model = EfficientNet(1.2, 1.4, dropout_rate=0.3, num_classes=num_classes, **kwargs) 156 | return model 157 | 158 | 159 | def EfficientNet_B4(num_classes=1000, **kwargs): 160 | # input size should be 380(~1.70x) 161 | model = EfficientNet(1.4, 1.8, dropout_rate=0.4, num_classes=num_classes, **kwargs) 162 | return model 163 | 164 | 165 | def EfficientNet_B5(num_classes=1000, **kwargs): 166 | # input size should be 456(~2.036x) 167 | model = EfficientNet(1.6, 2.2, dropout_rate=0.4, num_classes=num_classes, **kwargs) 168 | return model 169 | 170 | 171 | def 
EfficientNet_B6(num_classes=1000, **kwargs): 172 | # input size should be 528(~2.357x) 173 | model = EfficientNet(1.8, 2.6, dropout_rate=0.5, num_classes=num_classes, **kwargs) 174 | return model 175 | 176 | 177 | def EfficientNet_B7(num_classes=1000, **kwargs): 178 | # input size should be 600(~2.679x) 179 | model = EfficientNet(2., 3.1, dropout_rate=0.5, num_classes=num_classes, **kwargs) 180 | return model 181 | -------------------------------------------------------------------------------- /models/fairnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | __all__ = ['FairNasA', 'FairNasB', 'FairNasC', 4 | 'InvertedResidual'] 5 | 6 | from torch import nn 7 | 8 | 9 | class InvertedResidual(nn.Module): 10 | def __init__(self, in_c, expansion, out_c, kernel_size, stride): 11 | super(InvertedResidual, self).__init__() 12 | hidden_c = round(in_c * expansion) 13 | self.skip = stride == 1 and in_c == out_c 14 | self.act = nn.ReLU6(inplace=True) 15 | 16 | self.conv1 = nn.Conv2d(in_c, hidden_c, 1, 1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(hidden_c) 18 | self.conv2 = nn.Conv2d(hidden_c, hidden_c, kernel_size, stride, 19 | kernel_size // 2, groups=hidden_c, bias=False) 20 | self.bn2 = nn.BatchNorm2d(hidden_c) 21 | self.conv3 = nn.Conv2d(hidden_c, out_c, 1, 1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(out_c) 23 | 24 | def forward(self, x): 25 | skip = x 26 | x = self.conv1(x) 27 | x = self.bn1(x) 28 | x = self.act(x) 29 | x = self.conv2(x) 30 | x = self.bn2(x) 31 | x = self.act(x) 32 | x = self.conv3(x) 33 | x = self.bn3(x) 34 | if self.skip: 35 | x = skip + x 36 | return x 37 | 38 | 39 | class FairNasA(nn.Module): 40 | def __init__(self, num_classes=1000, small_input=False): 41 | super(FairNasA, self).__init__() 42 | self.first_block = nn.Sequential( 43 | nn.Conv2d(3, 32, 3, 2 if not small_input else 1, 1, bias=False), 44 | nn.BatchNorm2d(32), 45 | nn.ReLU6(inplace=True), 46 | nn.Conv2d(32, 32, 3, 1, 1, groups=32, bias=False), 47 | nn.BatchNorm2d(32), 48 | nn.ReLU6(inplace=True), 49 | nn.Conv2d(32, 16, 1, 1, bias=False), 50 | nn.BatchNorm2d(16) 51 | ) 52 | self.mb_blocks = nn.Sequential( 53 | InvertedResidual(16, 3, 32, 7, 2), 54 | InvertedResidual(32, 3, 32, 3, 1), 55 | InvertedResidual(32, 3, 40, 7, 2), 56 | InvertedResidual(40, 6, 40, 3, 1), 57 | InvertedResidual(40, 6, 40, 7, 1), 58 | InvertedResidual(40, 3, 40, 3, 1), 59 | InvertedResidual(40, 3, 80, 3, 2), 60 | InvertedResidual(80, 6, 80, 7, 1), 61 | InvertedResidual(80, 6, 80, 7, 1), 62 | InvertedResidual(80, 3, 80, 5, 1), 63 | InvertedResidual(80, 6, 96, 3, 1), 64 | InvertedResidual(96, 3, 96, 5, 1), 65 | InvertedResidual(96, 3, 96, 5, 1), 66 | InvertedResidual(96, 3, 96, 3, 1), 67 | InvertedResidual(96, 6, 192, 3, 2), 68 | InvertedResidual(192, 6, 192, 7, 1), 69 | InvertedResidual(192, 6, 192, 3, 1), 70 | InvertedResidual(192, 6, 192, 7, 1), 71 | InvertedResidual(192, 6, 320, 5, 1), 72 | ) 73 | self.last_block = nn.Sequential( 74 | nn.Conv2d(320, 1280, 1, 1, bias=False), 75 | nn.BatchNorm2d(1280), 76 | nn.ReLU6(inplace=True), 77 | nn.AdaptiveAvgPool2d(1), 78 | nn.Flatten(), 79 | ) 80 | self.output = nn.Linear(1280, num_classes) 81 | 82 | def forward(self, x): 83 | x = self.first_block(x) 84 | x = self.mb_blocks(x) 85 | x = self.last_block(x) 86 | x = self.output(x) 87 | return x 88 | 89 | 90 | class FairNasB(nn.Module): 91 | def __init__(self, num_classes=1000, small_input=False): 92 | super(FairNasB, self).__init__() 93 
| self.first_block = nn.Sequential( 94 | nn.Conv2d(3, 32, 3, 2 if not small_input else 1, 1, bias=False), 95 | nn.BatchNorm2d(32), 96 | nn.ReLU6(inplace=True), 97 | nn.Conv2d(32, 32, 3, 1, 1, groups=32, bias=False), 98 | nn.BatchNorm2d(32), 99 | nn.ReLU6(inplace=True), 100 | nn.Conv2d(32, 16, 1, 1, bias=False), 101 | nn.BatchNorm2d(16) 102 | ) 103 | self.mb_blocks = nn.Sequential( 104 | InvertedResidual(16, 3, 32, 5, 2), 105 | InvertedResidual(32, 3, 32, 3, 1), 106 | InvertedResidual(32, 3, 40, 5, 2), 107 | InvertedResidual(40, 3, 40, 3, 1), 108 | InvertedResidual(40, 6, 40, 3, 1), 109 | InvertedResidual(40, 3, 40, 5, 1), 110 | InvertedResidual(40, 3, 80, 7, 2), 111 | InvertedResidual(80, 3, 80, 3, 1), 112 | InvertedResidual(80, 6, 80, 3, 1), 113 | InvertedResidual(80, 3, 80, 5, 1), 114 | InvertedResidual(80, 3, 96, 3, 1), 115 | InvertedResidual(96, 6, 96, 3, 1), 116 | InvertedResidual(96, 3, 96, 7, 1), 117 | InvertedResidual(96, 3, 96, 3, 1), 118 | InvertedResidual(96, 6, 192, 7, 2), 119 | InvertedResidual(192, 6, 192, 5, 1), 120 | InvertedResidual(192, 6, 192, 7, 1), 121 | InvertedResidual(192, 6, 192, 3, 1), 122 | InvertedResidual(192, 6, 320, 5, 1), 123 | ) 124 | self.last_block = nn.Sequential( 125 | nn.Conv2d(320, 1280, 1, 1, bias=False), 126 | nn.BatchNorm2d(1280), 127 | nn.ReLU6(inplace=True), 128 | nn.AdaptiveAvgPool2d(1), 129 | nn.Flatten(), 130 | ) 131 | self.output = nn.Linear(1280, num_classes) 132 | 133 | def forward(self, x): 134 | x = self.first_block(x) 135 | x = self.mb_blocks(x) 136 | x = self.last_block(x) 137 | x = self.output(x) 138 | return x 139 | 140 | 141 | class FairNasC(nn.Module): 142 | def __init__(self, num_classes=1000, small_input=False): 143 | super(FairNasC, self).__init__() 144 | self.first_block = nn.Sequential( 145 | nn.Conv2d(3, 32, 3, 2 if not small_input else 1, 1, bias=False), 146 | nn.BatchNorm2d(32), 147 | nn.ReLU6(inplace=True), 148 | nn.Conv2d(32, 32, 3, 1, 1, groups=32, bias=False), 149 | nn.BatchNorm2d(32), 150 | nn.ReLU6(inplace=True), 151 | nn.Conv2d(32, 16, 1, 1, bias=False), 152 | nn.BatchNorm2d(16) 153 | ) 154 | self.mb_blocks = nn.Sequential( 155 | InvertedResidual(16, 3, 32, 5, 2), 156 | InvertedResidual(32, 3, 32, 3, 1), 157 | InvertedResidual(32, 3, 40, 7, 2), 158 | InvertedResidual(40, 3, 40, 3, 1), 159 | InvertedResidual(40, 3, 40, 3, 1), 160 | InvertedResidual(40, 3, 40, 3, 1), 161 | InvertedResidual(40, 3, 80, 3, 2), 162 | InvertedResidual(80, 3, 80, 3, 1), 163 | InvertedResidual(80, 3, 80, 3, 1), 164 | InvertedResidual(80, 6, 80, 3, 1), 165 | InvertedResidual(80, 3, 96, 3, 1), 166 | InvertedResidual(96, 3, 96, 3, 1), 167 | InvertedResidual(96, 3, 96, 3, 1), 168 | InvertedResidual(96, 3, 96, 3, 1), 169 | InvertedResidual(96, 6, 192, 7, 2), 170 | InvertedResidual(192, 6, 192, 7, 1), 171 | InvertedResidual(192, 6, 192, 3, 1), 172 | InvertedResidual(192, 6, 192, 3, 1), 173 | InvertedResidual(192, 6, 320, 5, 1), 174 | ) 175 | self.last_block = nn.Sequential( 176 | nn.Conv2d(320, 1280, 1, 1, bias=False), 177 | nn.BatchNorm2d(1280), 178 | nn.ReLU6(inplace=True), 179 | nn.AdaptiveAvgPool2d(1), 180 | nn.Flatten(), 181 | ) 182 | self.output = nn.Linear(1280, num_classes) 183 | 184 | def forward(self, x): 185 | x = self.first_block(x) 186 | x = self.mb_blocks(x) 187 | x = self.last_block(x) 188 | x = self.output(x) 189 | return x 190 | -------------------------------------------------------------------------------- /module/octconv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 
-*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | __all__ = ['OctaveConv', 4 | 'OctConv', 'OctConvFirst', 'OctDwConv', 'OctConvLast'] 5 | from torch import nn 6 | 7 | 8 | class OctaveConv(nn.Module): 9 | def __init__(self, in_channels, channels, alpha_in, alpha_out, 10 | kernel_size, padding=0, stride=1, groups=1, bias=False): 11 | super().__init__() 12 | assert stride in (1, 2), 'stride should be 1 or 2.' 13 | assert 0 <= alpha_in < 1 and 0 <= alpha_out < 1, 'Wrong setting with alpha' 14 | self.alpha_in, self.alpha_out = alpha_in, alpha_out 15 | self.stride = stride 16 | self.depth_wise = depth_wise = True if channels == groups else False 17 | h_in = int((1 - alpha_in) * in_channels) 18 | l_in = int(alpha_in * in_channels) 19 | h_out = int((1 - alpha_out) * channels) 20 | l_out = int(alpha_out * channels) 21 | 22 | self.downsample = nn.AvgPool2d(kernel_size=2, stride=2) 23 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 24 | self.W_HH = nn.Conv2d(h_in, h_out, kernel_size, 1, padding, 1, min(h_in, groups), bias) 25 | if alpha_out != 0 and alpha_in != 0: 26 | self.W_LL = nn.Conv2d(l_in, l_out, kernel_size, 1, padding, 1, min(l_in, groups), bias) 27 | if alpha_out != 0 and not depth_wise: 28 | self.W_HL = nn.Conv2d(h_in, l_out, kernel_size, 1, padding, 1, min(h_in, groups), bias) 29 | if alpha_in != 0 and not depth_wise: 30 | self.W_LH = nn.Conv2d(l_in, h_out, kernel_size, 1, padding, 1, min(l_in, groups), bias) 31 | 32 | def forward(self, x_h, x_l=None): 33 | # vanilla layer 34 | if self.alpha_in == self.alpha_out == 0: 35 | y_hh = self.W_HH(x_h) if self.stride == 1 else self.W_HH(self.downsample(x_h)) 36 | return y_hh, None 37 | # first oct layer(first layer should not be depth wise layer) 38 | elif self.alpha_in == 0: 39 | x_h = x_h if self.stride == 1 else self.downsample(x_h) 40 | y_hh = self.W_HH(x_h) 41 | y_hl = self.W_HL(self.downsample(x_h)) 42 | return y_hh, y_hl 43 | # last oct layer 44 | elif self.alpha_out == 0: 45 | y_hh = self.W_HH(x_h) if self.stride == 1 else self.W_HH(self.downsample(x_h)) 46 | if not self.depth_wise: 47 | y_lh = self.upsample(self.W_LH(x_l)) if self.stride == 1 else self.W_LH(x_l) 48 | y_h_out = y_hh + y_lh 49 | else: 50 | y_h_out = y_hh 51 | return y_h_out, None 52 | # oct layer 53 | else: 54 | y_hh = self.W_HH(x_h) if self.stride == 1 else self.W_HH(self.downsample(x_h)) 55 | y_ll = self.W_LL(x_l) if self.stride == 1 else self.W_LL(self.downsample(x_l)) 56 | if not self.depth_wise: 57 | y_lh = self.upsample(self.W_LH(x_l)) if self.stride == 1 else self.W_LH(x_l) 58 | x_h = x_h if self.stride == 1 else self.downsample(x_h) 59 | y_hl = self.W_HL(self.downsample(x_h)) 60 | y_h_out = y_hh + y_lh 61 | y_l_out = y_ll + y_hl 62 | else: 63 | y_h_out = y_hh 64 | y_l_out = y_ll 65 | return y_h_out, y_l_out 66 | 67 | 68 | # Helper layer to avoid using if/else in forward 69 | class OctConvFirst(nn.Module): 70 | def __init__(self, in_channels, channels, alpha, kernel_size, 71 | padding=0, stride=1, groups=1, bias=False): 72 | assert stride in (1, 2), 'stride should be 1 or 2.' 
73 | assert 0 <= alpha < 1, 'Wrong setting with alpha' 74 | assert groups < channels, 'First OctConv does not support dw_conv' 75 | super(OctConvFirst, self).__init__() 76 | h_out = int((1 - alpha) * channels) 77 | l_out = int(alpha * channels) 78 | 79 | self.stride = stride 80 | self.W_HH = nn.Conv2d(in_channels, h_out, kernel_size, 1, padding, groups=groups, bias=bias) 81 | self.W_HL = nn.Conv2d(in_channels, l_out, kernel_size, 1, padding, groups=groups, bias=bias) 82 | self.downsample = nn.AvgPool2d(kernel_size=2, stride=2) 83 | 84 | def forward(self, x): 85 | if self.stride == 2: 86 | x = self.downsample(x) 87 | y_hh = self.W_HH(x) 88 | y_ll = self.W_HL(self.downsample(x)) 89 | return y_hh, y_ll 90 | 91 | 92 | class OctConv(nn.Module): 93 | def __init__(self, in_channels, channels, alpha, kernel_size, 94 | padding=0, stride=1, groups=1, bias=False): 95 | assert stride in (1, 2), 'stride should be 1 or 2.' 96 | assert 0 < alpha < 1, 'Wrong setting with alpha' 97 | assert groups < channels, 'Use OctDwConv for dw conv' 98 | super(OctConv, self).__init__() 99 | h_in = int((1 - alpha) * in_channels) 100 | l_in = int(alpha * in_channels) 101 | h_out = int((1 - alpha) * channels) 102 | l_out = int(alpha * channels) 103 | 104 | self.stride = stride 105 | self.W_HH = nn.Conv2d(h_in, h_out, kernel_size, 1, padding, 1, groups, bias) 106 | self.W_LL = nn.Conv2d(l_in, l_out, kernel_size, 1, padding, 1, groups, bias) 107 | self.W_HL = nn.Conv2d(h_in, l_out, kernel_size, 1, padding, 1, groups, bias) 108 | self.W_LH = nn.Conv2d(l_in, h_out, kernel_size, 1, padding, 1, groups, bias) 109 | 110 | self.downsample = nn.AvgPool2d(kernel_size=2, stride=2) 111 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 112 | 113 | def forward(self, x_h, x_l): 114 | if self.stride == 1: 115 | y_hh = self.W_HH(x_h) 116 | y_ll = self.W_LL(x_l) 117 | 118 | y_lh = self.upsample(self.W_LH(x_l)) 119 | y_hl = self.W_HL(self.downsample(x_h)) 120 | else: 121 | y_lh = self.W_LH(x_l) 122 | 123 | x_h = self.downsample(x_h) 124 | x_l = self.downsample(x_l) 125 | 126 | y_hh = self.W_HH(x_h) 127 | y_ll = self.W_LL(x_l) 128 | y_hl = self.W_HL(self.downsample(x_h)) 129 | 130 | y_h_out = y_hh + y_lh 131 | y_l_out = y_ll + y_hl 132 | return y_h_out, y_l_out 133 | 134 | 135 | class OctDwConv(nn.Module): 136 | def __init__(self, in_channels, channels, alpha, kernel_size, 137 | padding=0, stride=1, groups=1, bias=False): 138 | assert stride in (1, 2), 'stride should be 1 or 2.' 139 | assert 0 < alpha < 1, 'Wrong setting with alpha' 140 | assert groups == channels, 'This layer is for dw conv' 141 | super(OctDwConv, self).__init__() 142 | h_in = int((1 - alpha) * in_channels) 143 | l_in = int(alpha * in_channels) 144 | h_out = int((1 - alpha) * channels) 145 | l_out = int(alpha * channels) 146 | 147 | self.stride = stride 148 | self.W_HH = nn.Conv2d(h_in, h_out, kernel_size, 1, padding, 1, groups, bias) 149 | self.W_LL = nn.Conv2d(l_in, l_out, kernel_size, 1, padding, 1, groups, bias) 150 | 151 | self.downsample = nn.AvgPool2d(kernel_size=2, stride=2) 152 | 153 | def forward(self, x_h, x_l): 154 | if self.stride != 1: 155 | x_h = self.downsample(x_h) 156 | x_l = self.downsample(x_l) 157 | y_h_out = self.W_HH(x_h) 158 | y_l_out = self.W_LL(x_l) 159 | 160 | return y_h_out, y_l_out 161 | 162 | 163 | class OctConvLast(nn.Module): 164 | def __init__(self, in_channels, channels, alpha, kernel_size, 165 | padding=0, stride=1, groups=1, bias=False): 166 | assert stride in (1, 2), 'stride should be 1 or 2.' 
167 | assert 0 < alpha < 1, 'Wrong setting with alpha' 168 | assert groups < channels, 'Use OctDwConvLast for dw conv' 169 | super(OctConvLast, self).__init__() 170 | h_in = int((1 - alpha) * in_channels) 171 | l_in = int(alpha * in_channels) 172 | 173 | self.stride = stride 174 | self.W_HH = nn.Conv2d(h_in, channels, kernel_size, 1, padding, 1, groups, bias) 175 | self.W_LH = nn.Conv2d(l_in, channels, kernel_size, 1, padding, 1, groups, bias) 176 | 177 | self.downsample = nn.AvgPool2d(kernel_size=2, stride=2) 178 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 179 | 180 | def forward(self, x_h, x_l): 181 | if self.stride == 1: 182 | y_hh = self.W_HH(x_h) 183 | y_lh = self.upsample(self.W_LH(x_l)) 184 | else: 185 | x_h = self.downsample(x_h) 186 | y_hh = self.W_HH(x_h) 187 | y_lh = self.W_LH(x_l) 188 | y_h_out = y_hh + y_lh 189 | return y_h_out 190 | -------------------------------------------------------------------------------- /models/regnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __all__ = ['RegNetX200MF', 'RegNetX400MF', 'RegNetX600MF', 'RegNetX800MF', 3 | 'RegNetX1_6GF', 'RegNetX3_2GF', 'RegNetX4_0GF', 'RegNetX6_4GF', 'RegNetX8_0GF', 4 | 'RegNetY200MF', 'RegNetY400MF', 'RegNetY600MF', 'RegNetY800MF', 5 | 'RegNetY1_6GF', 'RegNetY3_2GF', 'RegNetY4_0GF', 'RegNetY6_4GF', 'RegNetY8_0GF' 6 | ] 7 | 8 | from torch import nn 9 | from torchtoolbox.tools import make_divisible 10 | 11 | 12 | class Stem(nn.Module): 13 | def __init__(self, in_c, out_c): 14 | super(Stem, self).__init__() 15 | self.block = nn.Sequential( 16 | nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1, bias=False), 17 | nn.BatchNorm2d(out_c), 18 | nn.ReLU(inplace=True), 19 | ) 20 | 21 | def forward(self, x): 22 | return self.block(x) 23 | 24 | 25 | class SE(nn.Module): 26 | def __init__(self, in_c, reduction_ratio=0.25): 27 | super(SE, self).__init__() 28 | reducation_c = int(in_c * reduction_ratio) 29 | self.block = nn.Sequential( 30 | nn.AdaptiveAvgPool2d(1), 31 | nn.Conv2d(in_c, reducation_c, kernel_size=1, bias=True), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(reducation_c, in_c, kernel_size=1, bias=True), 34 | nn.Sigmoid() 35 | ) 36 | 37 | def forward(self, x): 38 | return x * self.block(x) 39 | 40 | 41 | class Stage(nn.Module): 42 | def __init__(self, in_c, out_c, stride, bottleneck_ratio, group_width, reduction_ratio=0): 43 | super(Stage, self).__init__() 44 | width = make_divisible(out_c * bottleneck_ratio) 45 | groups = width // group_width 46 | 47 | self.block = nn.Sequential( 48 | nn.Conv2d(in_c, width, kernel_size=1, bias=False), 49 | nn.BatchNorm2d(width), 50 | nn.ReLU(inplace=True), 51 | 52 | nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False), 53 | nn.BatchNorm2d(width), 54 | nn.ReLU(inplace=True), 55 | SE(width, reduction_ratio) if reduction_ratio != 0 else nn.Identity(), 56 | 57 | nn.Conv2d(width, out_c, kernel_size=1, bias=False), 58 | nn.BatchNorm2d(out_c) 59 | ) 60 | 61 | if in_c != out_c or stride != 1: 62 | self.skip_connection = nn.Sequential( 63 | nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, bias=False), 64 | nn.BatchNorm2d(out_c) 65 | ) 66 | else: 67 | self.skip_connection = nn.Identity() 68 | 69 | self.act = nn.ReLU(inplace=True) 70 | 71 | def forward(self, x): 72 | skip = self.skip_connection(x) 73 | x = self.block(x) 74 | x = self.act(x + skip) 75 | return x 76 | 77 | 78 | class Head(nn.Module): 79 | def __init__(self, in_c, out_c): 80 | super(Head, 
self).__init__() 81 | self.block = nn.Sequential( 82 | nn.AdaptiveAvgPool2d(1), 83 | nn.Flatten(), 84 | nn.Linear(in_c, out_c, bias=True) 85 | ) 86 | 87 | def forward(self, x): 88 | return self.block(x) 89 | 90 | 91 | class RegNet(nn.Module): 92 | def __init__(self, d, w, g, num_classes=1000, b=1, se=False): 93 | super(RegNet, self).__init__() 94 | self.reduction_ratio = 0.25 if se else 0 95 | self.bottleneck_ratio = b 96 | self.group_width = g 97 | stem_c = 32 98 | 99 | self.stem = Stem(3, stem_c) 100 | self.stage = nn.Sequential( 101 | self._make_layer(stem_c, w[0], d[0], 2), 102 | self._make_layer(w[0], w[1], d[1], 2), 103 | self._make_layer(w[1], w[2], d[2], 2), 104 | self._make_layer(w[2], w[3], d[3], 2)) 105 | self.head = Head(w[3], num_classes) 106 | 107 | def _make_layer(self, in_c, out_c, blocks, stride=2): 108 | layers = [] 109 | layers.append(Stage(in_c, out_c, stride, self.bottleneck_ratio, 110 | self.group_width, self.reduction_ratio)) 111 | for _ in range(1, blocks): 112 | layers.append(Stage(out_c, out_c, 1, self.bottleneck_ratio, 113 | self.group_width, self.reduction_ratio)) 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | x = self.stem(x) 118 | x = self.stage(x) 119 | x = self.head(x) 120 | return x 121 | 122 | 123 | _regnetx_config = { 124 | '200MF': {'d': [1, 1, 4, 7], 'w': [24, 56, 152, 368], 'g': 8}, 125 | '400MF': {'d': [1, 2, 7, 12], 'w': [32, 64, 160, 384], 'g': 16}, 126 | '600MF': {'d': [1, 3, 5, 7], 'w': [48, 96, 240, 528], 'g': 24}, 127 | '800MF': {'d': [1, 3, 5, 7], 'w': [64, 128, 288, 672], 'g': 16}, 128 | '1.6GF': {'d': [2, 4, 10, 2], 'w': [72, 168, 408, 912], 'g': 24}, 129 | '3.2GF': {'d': [2, 6, 15, 2], 'w': [96, 192, 432, 1008], 'g': 48}, 130 | '4.0GF': {'d': [2, 5, 14, 2], 'w': [80, 240, 560, 1360], 'g': 40}, 131 | '6.4GF': {'d': [2, 4, 10, 1], 'w': [168, 392, 784, 1624], 'g': 56}, 132 | '8.0GF': {'d': [2, 5, 15, 1], 'w': [80, 240, 720, 1920], 'g': 120}, 133 | '12GF': {'d': [2, 5, 11, 1], 'w': [224, 448, 896, 2240], 'g': 112}, 134 | '16GF': {'d': [2, 6, 13, 1], 'w': [256, 512, 896, 2048], 'g': 128}, 135 | '32GF': {'d': [2, 7, 13, 1], 'w': [336, 672, 1344, 2520], 'g': 168}, 136 | } 137 | 138 | _regnety_config = { 139 | '200MF': {'d': [1, 1, 4, 7], 'w': [24, 56, 152, 368], 'g': 8}, 140 | '400MF': {'d': [1, 3, 6, 6], 'w': [48, 104, 208, 440], 'g': 8}, 141 | '600MF': {'d': [1, 3, 7, 4], 'w': [48, 112, 256, 608], 'g': 16}, 142 | '800MF': {'d': [1, 3, 8, 2], 'w': [64, 128, 320, 768], 'g': 16}, 143 | '1.6GF': {'d': [2, 6, 17, 2], 'w': [48, 120, 336, 888], 'g': 24}, 144 | '3.2GF': {'d': [2, 5, 13, 1], 'w': [72, 216, 576, 1512], 'g': 24}, 145 | '4.0GF': {'d': [2, 6, 12, 2], 'w': [128, 192, 512, 1088], 'g': 64}, 146 | '6.4GF': {'d': [2, 7, 14, 2], 'w': [144, 288, 576, 1296], 'g': 72}, 147 | '8.0GF': {'d': [2, 4, 10, 1], 'w': [168, 448, 896, 2016], 'g': 56}, 148 | '12GF': {'d': [2, 5, 11, 1], 'w': [224, 448, 896, 2240], 'g': 112}, 149 | '16GF': {'d': [2, 4, 11, 1], 'w': [224, 448, 1232, 3024], 'g': 112}, 150 | '32GF': {'d': [2, 5, 12, 1], 'w': [232, 696, 1392, 3712], 'g': 232}, 151 | } 152 | 153 | 154 | def _regnet(name, b=1, se=False, **kwargs): 155 | config = _regnetx_config[name] if not se \ 156 | else _regnety_config[name] 157 | 158 | d, w, g = config['d'], config['w'], config['g'] 159 | return RegNet(d, w, g, b=b, se=se, **kwargs) 160 | 161 | 162 | def RegNetX200MF(**kwargs): 163 | return _regnet('200MF', **kwargs) 164 | 165 | 166 | def RegNetX400MF(**kwargs): 167 | return _regnet('400MF', **kwargs) 168 | 169 | 170 | def 
RegNetX600MF(**kwargs): 171 | return _regnet('600MF', **kwargs) 172 | 173 | 174 | def RegNetX800MF(**kwargs): 175 | return _regnet('800MF', **kwargs) 176 | 177 | 178 | def RegNetX1_6GF(**kwargs): 179 | return _regnet('1.6GF', **kwargs) 180 | 181 | 182 | def RegNetX3_2GF(**kwargs): 183 | return _regnet('3.2GF', **kwargs) 184 | 185 | 186 | def RegNetX4_0GF(**kwargs): 187 | return _regnet('4.0GF', **kwargs) 188 | 189 | 190 | def RegNetX6_4GF(**kwargs): 191 | return _regnet('6.4GF', **kwargs) 192 | 193 | 194 | def RegNetX8_0GF(**kwargs): 195 | return _regnet('8.0GF', **kwargs) 196 | 197 | 198 | def RegNetX12GF(**kwargs): 199 | return _regnet('12GF', **kwargs) 200 | 201 | 202 | def RegNetX16GF(**kwargs): 203 | return _regnet('16GF', **kwargs) 204 | 205 | 206 | def RegNetX32GF(**kwargs): 207 | return _regnet('32GF', **kwargs) 208 | 209 | 210 | def RegNetY200MF(**kwargs): 211 | return _regnet('200MF', se=True, **kwargs) 212 | 213 | 214 | def RegNetY400MF(**kwargs): 215 | return _regnet('400MF', se=True, **kwargs) 216 | 217 | 218 | def RegNetY600MF(**kwargs): 219 | return _regnet('600MF', se=True, **kwargs) 220 | 221 | 222 | def RegNetY800MF(**kwargs): 223 | return _regnet('800MF', se=True, **kwargs) 224 | 225 | 226 | def RegNetY1_6GF(**kwargs): 227 | return _regnet('1.6GF', se=True, **kwargs) 228 | 229 | 230 | def RegNetY3_2GF(**kwargs): 231 | return _regnet('3.2GF', se=True, **kwargs) 232 | 233 | 234 | def RegNetY4_0GF(**kwargs): 235 | return _regnet('4.0GF', se=True, **kwargs) 236 | 237 | 238 | def RegNetY6_4GF(**kwargs): 239 | return _regnet('6.4GF', se=True, **kwargs) 240 | 241 | 242 | def RegNetY8_0GF(**kwargs): 243 | return _regnet('8.0GF', se=True, **kwargs) 244 | 245 | 246 | def RegNetY12GF(**kwargs): 247 | return _regnet('12GF', se=True, **kwargs) 248 | 249 | 250 | def RegNetY16GF(**kwargs): 251 | return _regnet('16GF', se=True, **kwargs) 252 | 253 | 254 | def RegNetY32GF(**kwargs): 255 | return _regnet('32GF', se=True, **kwargs) 256 | -------------------------------------------------------------------------------- /models/oct_resnet.py: -------------------------------------------------------------------------------- 1 | __all__ = ['oct_resnet50', 'oct_resnet50v2'] 2 | 3 | from module import * 4 | from torchtoolbox.nn import AdaptiveSequential 5 | from torch import nn 6 | 7 | 8 | def check_status(alpha_in, alpha_out): 9 | alpha_in = alpha_out if alpha_in == 0 else alpha_in 10 | alpha_in = 0 if alpha_out == 0 else alpha_in 11 | return alpha_in, alpha_out 12 | 13 | 14 | class OctBottleneck(nn.Module): 15 | expansion = 4 16 | 17 | def __init__(self, inplanes, planes, alpha_in, alpha_out, 18 | stride=1, groups=1, base_width=64): 19 | super(OctBottleneck, self).__init__() 20 | width = int(planes * (base_width / 64.)) * groups 21 | if stride != 1 or inplanes != planes * self.expansion: 22 | self.downsample = AdaptiveSequential( 23 | OctaveConv(inplanes, planes * self.expansion, alpha_in, alpha_out, 24 | 1, stride=stride, bias=False), 25 | fs_bn(planes * self.expansion, alpha_out) 26 | ) 27 | else: 28 | self.downsample = None 29 | 30 | self.conv1 = OctaveConv(inplanes, width, alpha_in, alpha_out, 1, bias=False) 31 | self.bn1 = fs_bn(width, alpha_out) 32 | alpha_in, alpha_out = check_status(alpha_in, alpha_out) 33 | self.conv2 = OctaveConv(width, width, alpha_in, alpha_out, 3, 1, stride, 34 | groups, False) 35 | self.bn2 = fs_bn(width, alpha_out) 36 | self.conv3 = OctaveConv(width, planes * self.expansion, alpha_in, alpha_out, 37 | 1, bias=False) 38 | self.bn3 = fs_bn(planes * self.expansion, 
alpha_out) 39 | self.relu = fs_relu() 40 | 41 | def forward(self, x_h, x_l=None): 42 | r_h, r_l = x_h, x_l 43 | x_h, x_l = self.conv1(x_h, x_l) 44 | x_h, x_l = self.bn1(x_h, x_l) 45 | x_h, x_l = self.relu(x_h, x_l) 46 | 47 | x_h, x_l = self.conv2(x_h, x_l) 48 | x_h, x_l = self.bn2(x_h, x_l) 49 | x_h, x_l = self.relu(x_h, x_l) 50 | 51 | x_h, x_l = self.conv3(x_h, x_l) 52 | x_h, x_l = self.bn3(x_h, x_l) 53 | 54 | if self.downsample: 55 | r_h, r_l = self.downsample(r_h, r_l) 56 | y_h, y_l = self.relu(x_h + r_h, None if x_l is None and r_l is None else x_l + r_l) 57 | return y_h, y_l 58 | 59 | 60 | class OctBottleneckV2(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, alpha_in, alpha_out, stride=1, 64 | groups=1, base_width=64): 65 | super().__init__() 66 | width = int(planes * (base_width / 64.)) * groups 67 | if stride != 1 or inplanes != planes * self.expansion: 68 | self.downsample = AdaptiveSequential( 69 | OctaveConv(inplanes, planes * self.expansion, alpha_in, alpha_out, 70 | 1, stride=stride, bias=False), 71 | ) 72 | else: 73 | self.downsample = None 74 | self.bn1 = fs_bn(inplanes, alpha_in) 75 | self.conv1 = OctaveConv(inplanes, width, alpha_in, alpha_out, 1, bias=False) 76 | alpha_in, alpha_out = check_status(alpha_in, alpha_out) 77 | self.bn2 = fs_bn(width, alpha_in) 78 | self.conv2 = OctaveConv(width, width, alpha_in, alpha_out, 3, 1, 79 | stride, groups, False) 80 | self.bn3 = fs_bn(width, alpha_in) 81 | self.conv3 = OctaveConv(width, planes * self.expansion, alpha_in, alpha_in, 82 | 1, bias=False) 83 | self.relu = fs_relu() 84 | 85 | def forward(self, x_h, x_l=None): 86 | r_h, r_l = x_h, x_l 87 | x_h, x_l = self.bn1(x_h, x_l) 88 | x_h, x_l = self.relu(x_h, x_l) 89 | if self.downsample: 90 | r_h, r_l = self.downsample(x_h, x_l) 91 | x_h, x_l = self.conv1(x_h, x_l) 92 | 93 | x_h, x_l = self.bn2(x_h, x_l) 94 | x_h, x_l = self.relu(x_h, x_l) 95 | x_h, x_l = self.conv2(x_h, x_l) 96 | 97 | x_h, x_l = self.bn3(x_h, x_l) 98 | x_h, x_l = self.relu(x_h, x_l) 99 | x_h, x_l = self.conv3(x_h, x_l) 100 | 101 | y_h, y_l = x_h + r_h, None if x_l is None and r_l is None else x_l + r_l 102 | return y_h, y_l 103 | 104 | 105 | class OctResNet(nn.Module): 106 | def __init__(self, alpha, layers, num_classes=1000, groups=1, width_per_group=64): 107 | super(OctResNet, self).__init__() 108 | self.inplanes = 64 109 | self.groups = groups 110 | self.base_width = width_per_group 111 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 112 | bias=False) 113 | self.bn1 = nn.BatchNorm2d(self.inplanes) 114 | self.relu = nn.ReLU(inplace=True) 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | self.layer1 = self._make_layer(alpha, 64, layers[0], 1, 'start') 117 | self.layer2 = self._make_layer(alpha, 128, layers[1], 2) 118 | self.layer3 = self._make_layer(alpha, 256, layers[2], 2) 119 | self.layer4 = self._make_layer(alpha, 512, layers[3], 2, 'end') 120 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 121 | self.fc = nn.Linear(512 * OctBottleneck.expansion, num_classes) 122 | 123 | def _make_layer(self, alpha, planes, blocks, stride=1, status='normal'): 124 | assert status in ('start', 'normal', 'end') 125 | layers = [] 126 | layers.append(OctBottleneck(self.inplanes, planes, 127 | alpha if status != 'start' else 0, 128 | alpha if status != 'end' else 0, 129 | stride, self.groups, self.base_width)) 130 | self.inplanes = planes * OctBottleneck.expansion 131 | alpha = 0 if status == 'end' else alpha 132 | for _ in range(1, blocks): 133 | 
layers.append(OctBottleneck(self.inplanes, planes, alpha, alpha, 1, 134 | self.groups, self.base_width)) 135 | return AdaptiveSequential(*layers) 136 | 137 | def forward(self, x): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x_h, x_l = self.layer1(x) 144 | x_h, x_l = self.layer2(x_h, x_l) 145 | x_h, x_l = self.layer3(x_h, x_l) 146 | x, _ = self.layer4(x_h, x_l) 147 | 148 | x = self.avgpool(x) 149 | x = x.reshape(x.size(0), -1) 150 | x = self.fc(x) 151 | 152 | return x 153 | 154 | 155 | class OctResNetV2(nn.Module): 156 | def __init__(self, alpha, layers, num_classes=1000, groups=1, width_per_group=64): 157 | super().__init__() 158 | self.inplanes = 64 159 | self.groups = groups 160 | self.base_width = width_per_group 161 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 162 | bias=False) 163 | self.bn1 = nn.BatchNorm2d(self.inplanes) 164 | self.relu = nn.ReLU(inplace=True) 165 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 166 | self.layer1 = self._make_layer(alpha, 64, layers[0], 1, 'start') 167 | self.layer2 = self._make_layer(alpha, 128, layers[1], 2) 168 | self.layer3 = self._make_layer(alpha, 256, layers[2], 2) 169 | self.layer4 = self._make_layer(alpha, 512, layers[3], 2, 'end') 170 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 171 | self.fc = nn.Linear(512 * OctBottleneck.expansion, num_classes) 172 | 173 | def _make_layer(self, alpha, planes, blocks, stride=1, status='normal'): 174 | assert status in ('start', 'normal', 'end') 175 | layers = [] 176 | layers.append(OctBottleneckV2(self.inplanes, planes, 177 | alpha if status != 'start' else 0, 178 | alpha if status != 'end' else 0, 179 | stride, self.groups, self.base_width)) 180 | self.inplanes = planes * OctBottleneckV2.expansion 181 | alpha = 0 if status == 'end' else alpha 182 | for _ in range(1, blocks): 183 | layers.append(OctBottleneckV2(self.inplanes, planes, alpha, alpha, 1, 184 | self.groups, self.base_width)) 185 | return AdaptiveSequential(*layers) 186 | 187 | def forward(self, x): 188 | x = self.conv1(x) 189 | x = self.bn1(x) 190 | x = self.relu(x) 191 | x = self.maxpool(x) 192 | 193 | x_h, x_l = self.layer1(x) 194 | x_h, x_l = self.layer2(x_h, x_l) 195 | x_h, x_l = self.layer3(x_h, x_l) 196 | x, _ = self.layer4(x_h, x_l) 197 | 198 | x = self.avgpool(x) 199 | x = x.reshape(x.size(0), -1) 200 | x = self.fc(x) 201 | 202 | return x 203 | 204 | 205 | def oct_resnet50(alpha, **kwargs): 206 | """Constructs a OctResNet-50 model. 207 | 208 | Args: 209 | progress (bool): If True, displays a progress bar of the download to stderr 210 | """ 211 | return OctResNet(alpha, [3, 4, 6, 3], **kwargs) 212 | 213 | 214 | def oct_resnet50v2(alpha, **kwargs): 215 | """Constructs a OctResNet-50 model. 216 | 217 | Args: 218 | progress (bool): If True, displays a progress bar of the download to stderr 219 | """ 220 | return OctResNetV2(alpha, [3, 4, 6, 3], **kwargs) 221 | -------------------------------------------------------------------------------- /models/evo_norm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | __all__ = ['EvoResNet', 'evo_resnet18', 'evo_resnet34', 'evo_resnet50', 'evo_resnet101', 5 | 'evo_resnet152', 'evo_resnext101_32x8d', 'evo_resnext50_32x4d'] 6 | 7 | import torch 8 | from torch import nn 9 | 10 | try: 11 | # Not in stable release. 
12 | from torchtoolbox.nn import EvoNormB0, EvoNormS0 13 | except ImportError: 14 | def instance_std(x, eps=1e-5): 15 | var = torch.var(x, dim=(2, 3), keepdim=True) 16 | std = torch.sqrt(var + eps) 17 | return std 18 | 19 | 20 | def group_std(x: torch.Tensor, groups=32, eps=1e-5): 21 | n, c, h, w = x.size() 22 | x = torch.reshape(x, (n, groups, c // groups, h, w)) 23 | var = torch.var(x, dim=(2, 3, 4), keepdim=True) 24 | std = torch.sqrt(var + eps) 25 | return torch.reshape(std, (n, c, h, w)) 26 | 27 | 28 | def evo_norm(x, prefix, running_var, v, weight, bias, 29 | training, momentum, eps=0.1, groups=32): 30 | if prefix == 'b0': 31 | if training: 32 | var = torch.var(x, dim=(0, 2, 3), keepdim=True) 33 | running_var.mul_(momentum) 34 | running_var.add_((1 - momentum) * var) 35 | else: 36 | var = running_var 37 | if v is not None: 38 | den = torch.max((var + eps).sqrt(), v * x + instance_std(x, eps)) 39 | x = x / den * weight + bias 40 | else: 41 | x = x * weight + bias 42 | elif prefix == 's0': 43 | if v is not None: 44 | x = x * torch.sigmoid(v * x) / group_std(x, groups, eps) * weight + bias 45 | else: 46 | x = x * weight + bias 47 | else: 48 | raise NotImplementedError 49 | return x 50 | 51 | 52 | class _EvoNorm(nn.Module): 53 | def __init__(self, prefix, num_features, eps=1e-5, momentum=0.9, groups=32, 54 | affine=True): 55 | super(_EvoNorm, self).__init__() 56 | assert prefix in ('s0', 'b0') 57 | self.prefix = prefix 58 | self.groups = groups 59 | self.num_features = num_features 60 | self.eps = eps 61 | self.momentum = momentum 62 | self.affine = affine 63 | if self.affine: 64 | self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 65 | self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 66 | self.v = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_parameter('v', None) 71 | self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | if self.affine: 76 | torch.nn.init.ones_(self.weight) 77 | torch.nn.init.zeros_(self.bias) 78 | torch.nn.init.ones_(self.v) 79 | 80 | def _check_input_dim(self, x): 81 | if x.dim() != 4: 82 | raise ValueError('expected 4D input (got {}D input)' 83 | .format(x.dim())) 84 | 85 | def forward(self, x): 86 | self._check_input_dim(x) 87 | return evo_norm(x, self.prefix, self.running_var, self.v, 88 | self.weight, self.bias, self.training, 89 | self.momentum, self.eps, self.groups) 90 | 91 | 92 | class EvoNormB0(_EvoNorm): 93 | def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True): 94 | super(EvoNormB0, self).__init__('b0', num_features, eps, momentum, 95 | affine=affine) 96 | 97 | 98 | class EvoNormS0(_EvoNorm): 99 | def __init__(self, num_features, groups=32, affine=True): 100 | super(EvoNormS0, self).__init__('s0', num_features, groups=groups, 101 | affine=affine) 102 | 103 | 104 | class BN_RELU(nn.Module): 105 | def __init__(self, num_features): 106 | super(BN_RELU, self).__init__() 107 | self.block = nn.Sequential( 108 | nn.BatchNorm2d(num_features), 109 | nn.ReLU(inplace=True) 110 | ) 111 | 112 | def forward(self, x): 113 | return self.block(x) 114 | 115 | 116 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 117 | """3x3 convolution with padding""" 118 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 119 | padding=1, groups=groups, bias=False) 120 | 121 | 122 | def conv1x1(in_planes, 
out_planes, stride=1): 123 | """1x1 convolution""" 124 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 125 | 126 | 127 | class Bottleneck(nn.Module): 128 | expansion = 4 129 | 130 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 131 | base_width=64, evo_norm=True): 132 | super(Bottleneck, self).__init__() 133 | width = int(planes * (base_width / 64.)) * groups 134 | self.evonorm1 = EvoNormB0(inplanes) if evo_norm else BN_RELU(inplanes) 135 | self.conv1 = conv1x1(inplanes, width) 136 | self.evonorm2 = EvoNormB0(width) if evo_norm else BN_RELU(width) 137 | self.conv2 = conv3x3(width, width, stride, groups) 138 | self.evonorm3 = EvoNormB0(width) if evo_norm else BN_RELU(width) 139 | self.conv3 = conv1x1(width, planes * self.expansion) 140 | self.downsample = downsample 141 | 142 | def forward(self, x): 143 | identity = x 144 | out = self.evonorm1(x) 145 | if self.downsample is not None: 146 | identity = self.downsample(out) 147 | out = self.conv1(out) 148 | out = self.evonorm2(out) 149 | out = self.conv2(out) 150 | out = self.evonorm3(out) 151 | out = self.conv3(out) 152 | out += identity 153 | return out 154 | 155 | 156 | class EvoResNet(nn.Module): 157 | def __init__(self, layers, num_classes=1000, groups=1, width_per_group=64, 158 | dropout_rate=None, small_input=False, evo_norm=True): 159 | super(EvoResNet, self).__init__() 160 | self.inplanes = 64 161 | self.evo_norm = evo_norm 162 | self.groups = groups 163 | self.base_width = width_per_group 164 | if small_input: 165 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 166 | bias=False) 167 | else: 168 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 169 | bias=False) 170 | self.evonorm1 = EvoNormB0(self.inplanes) if evo_norm else BN_RELU(self.inplanes) 171 | if small_input: 172 | self.maxpool = nn.Identity() 173 | else: 174 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 175 | self.layer1 = self._make_layer(64, layers[0]) 176 | self.layer2 = self._make_layer(128, layers[1], stride=2) 177 | self.layer3 = self._make_layer(256, layers[2], stride=2) 178 | self.layer4 = self._make_layer(512, layers[3], stride=2) 179 | 180 | self.evonorm2 = EvoNormB0(512 * Bottleneck.expansion) if evo_norm else BN_RELU(512 * Bottleneck.expansion) 181 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 182 | self.flatten = nn.Flatten() 183 | self.dropout = nn.Dropout(dropout_rate, inplace=True) if dropout_rate is not None else nn.Identity() 184 | self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes) 185 | 186 | def _make_layer(self, planes, blocks, stride=1): 187 | downsample = None 188 | 189 | if stride != 1 or self.inplanes != planes * Bottleneck.expansion: 190 | downsample = nn.Sequential( 191 | conv1x1(self.inplanes, planes * Bottleneck.expansion, stride), 192 | ) 193 | 194 | layers = [] 195 | layers.append(Bottleneck(self.inplanes, planes, stride, downsample, self.groups, 196 | self.base_width, self.evo_norm)) 197 | self.inplanes = planes * Bottleneck.expansion 198 | for _ in range(1, blocks): 199 | layers.append(Bottleneck(self.inplanes, planes, groups=self.groups, 200 | base_width=self.base_width, evo_norm=self.evo_norm)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def forward(self, x): 205 | x = self.conv1(x) 206 | x = self.evonorm1(x) 207 | x = self.maxpool(x) 208 | 209 | x = self.layer1(x) 210 | x = self.layer2(x) 211 | x = self.layer3(x) 212 | x = self.layer4(x) 213 | x = self.evonorm2(x) 214 | 215 | x 
= self.avgpool(x) 216 | x = self.flatten(x) 217 | x = self.dropout(x) 218 | x = self.fc(x) 219 | 220 | return x 221 | 222 | 223 | def evo_resnet18(**kwargs): 224 | """Constructs a ResNet-18 model. 225 | 226 | """ 227 | return EvoResNet([2, 2, 2, 2], **kwargs) 228 | 229 | 230 | def evo_resnet34(**kwargs): 231 | """Constructs a ResNet-34 model. 232 | 233 | """ 234 | return EvoResNet([3, 4, 6, 3], **kwargs) 235 | 236 | 237 | def evo_resnet50(**kwargs): 238 | """Constructs a ResNet-50 model. 239 | 240 | """ 241 | return EvoResNet([3, 4, 6, 3], **kwargs) 242 | 243 | 244 | def evo_resnet101(**kwargs): 245 | """Constructs a ResNet-101 model. 246 | 247 | """ 248 | return EvoResNet([3, 4, 23, 3], **kwargs) 249 | 250 | 251 | def evo_resnet152(**kwargs): 252 | """Constructs a ResNet-152 model. 253 | 254 | """ 255 | return EvoResNet([3, 8, 36, 3], **kwargs) 256 | 257 | 258 | def evo_resnext50_32x4d(**kwargs): 259 | """Constructs a ResNeXt-50 32x4d model. 260 | 261 | """ 262 | kwargs['groups'] = 32 263 | kwargs['width_per_group'] = 4 264 | return EvoResNet([3, 4, 6, 3], **kwargs) 265 | 266 | 267 | def evo_resnext101_32x8d(**kwargs): 268 | """Constructs a ResNeXt-101 32x8d model. 269 | 270 | """ 271 | kwargs['groups'] = 32 272 | kwargs['width_per_group'] = 8 273 | return EvoResNet([3, 4, 23, 3], **kwargs) 274 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | __all__ = ['MobileNetV1', 'MobileNetV2', 'MobileNetV3_Large', 'MobileNetV3_Small'] 4 | 5 | from torchtoolbox.nn import Activation 6 | from functools import partial 7 | from torch import nn 8 | import math 9 | import torch.nn.functional as F 10 | 11 | 12 | def make_divisible(x, divisible_by=8): 13 | return int(math.ceil(x * 1. / divisible_by) * divisible_by) 14 | 15 | 16 | class HardSwish(nn.Module): 17 | def __init__(self, inplace=False): 18 | super(HardSwish, self).__init__() 19 | self.inplace = inplace 20 | 21 | def forward(self, x): 22 | return x * F.relu6(x + 3., inplace=self.inplace) / 6. 23 | 24 | 25 | class HardSigmoid(nn.Module): 26 | def __init__(self, inplace=False): 27 | super(HardSigmoid, self).__init__() 28 | self.inplace = inplace 29 | 30 | def forward(self, x): 31 | return F.relu6(x + 3., inplace=self.inplace) / 6. 
32 | 33 | 34 | class SE_Module(nn.Module): 35 | def __init__(self, channels, reduction=4): 36 | super(SE_Module, self).__init__() 37 | reduction_c = make_divisible(channels // reduction) 38 | self.out = nn.Sequential( 39 | nn.Conv2d(channels, reduction_c, 1, bias=True), 40 | nn.ReLU(inplace=True), 41 | nn.Conv2d(reduction_c, channels, 1, bias=True), 42 | HardSigmoid() 43 | ) 44 | 45 | def forward(self, x): 46 | y = F.adaptive_avg_pool2d(x, 1) 47 | y = self.out(y) 48 | return x * y 49 | 50 | 51 | class MobileNetBottleneck(nn.Module): 52 | def __init__(self, in_c, expansion, out_c, kernel_size, stride, se=False, 53 | activation='relu6', first_conv=True, skip=True, linear=True): 54 | super(MobileNetBottleneck, self).__init__() 55 | 56 | self.act = Activation(activation, auto_optimize=True) # [bug]no use when linear=True 57 | hidden_c = round(in_c * expansion) 58 | self.linear = linear 59 | self.skip = stride == 1 and in_c == out_c and skip 60 | 61 | seq = [] 62 | if first_conv and in_c != hidden_c: 63 | seq.append(nn.Conv2d(in_c, hidden_c, 1, 1, bias=False)) 64 | seq.append(nn.BatchNorm2d(hidden_c)) 65 | seq.append(Activation(activation, auto_optimize=True)) 66 | seq.append(nn.Conv2d(hidden_c, hidden_c, kernel_size, stride, 67 | kernel_size // 2, groups=hidden_c, bias=False)) 68 | seq.append(nn.BatchNorm2d(hidden_c)) 69 | seq.append(Activation(activation, auto_optimize=True)) 70 | if se: 71 | seq.append(SE_Module(hidden_c)) 72 | seq.append(nn.Conv2d(hidden_c, out_c, 1, 1, bias=False)) 73 | seq.append(nn.BatchNorm2d(out_c)) 74 | 75 | self.seq = nn.Sequential(*seq) 76 | 77 | def forward(self, x): 78 | skip = x 79 | x = self.seq(x) 80 | if self.skip: 81 | x = skip + x 82 | if not self.linear: 83 | x = self.act(x) 84 | return x 85 | 86 | 87 | class MobileNetV1(nn.Module): 88 | def __init__(self, num_classes=1000, small_input=False): 89 | super(MobileNetV1, self).__init__() 90 | self.first_block = nn.Sequential( 91 | nn.Conv2d(3, 32, 3, 2 if not small_input else 1, 1, bias=False), 92 | nn.BatchNorm2d(32), 93 | nn.ReLU(inplace=True), 94 | ) 95 | MB1_Bottleneck = partial(MobileNetBottleneck, first_conv=False, 96 | activation='relu', skip=False, linear=False) 97 | self.mb_block = nn.Sequential( 98 | MB1_Bottleneck(32, 1, 64, 3, 1), 99 | MB1_Bottleneck(64, 1, 128, 3, 2), 100 | MB1_Bottleneck(128, 1, 128, 3, 1), 101 | MB1_Bottleneck(128, 1, 256, 3, 2), 102 | MB1_Bottleneck(256, 1, 256, 3, 1), 103 | MB1_Bottleneck(256, 1, 512, 3, 2), 104 | MB1_Bottleneck(512, 1, 512, 3, 1), 105 | MB1_Bottleneck(512, 1, 512, 3, 1), 106 | MB1_Bottleneck(512, 1, 512, 3, 1), 107 | MB1_Bottleneck(512, 1, 512, 3, 1), 108 | MB1_Bottleneck(512, 1, 512, 3, 1), 109 | MB1_Bottleneck(512, 1, 1024, 3, 2), 110 | MB1_Bottleneck(1024, 1, 1024, 3, 1), 111 | ) 112 | self.last_block = nn.Sequential( 113 | nn.AdaptiveAvgPool2d(1), 114 | nn.Flatten(), 115 | ) 116 | self.output = nn.Linear(1024, num_classes) 117 | 118 | def forward(self, x): 119 | x = self.first_block(x) 120 | x = self.mb_block(x) 121 | x = self.last_block(x) 122 | x = self.output(x) 123 | return x 124 | 125 | 126 | class MobileNetV2(nn.Module): 127 | def __init__(self, num_classes=1000, small_input=False): 128 | super(MobileNetV2, self).__init__() 129 | self.first_block = nn.Sequential( 130 | nn.Conv2d(3, 32, 3, 2 if not small_input else 1, 1, bias=False), 131 | nn.BatchNorm2d(32), 132 | nn.ReLU6(inplace=True), 133 | nn.Conv2d(32, 32, 3, 1, 1, groups=32, bias=False), 134 | nn.BatchNorm2d(32), 135 | nn.ReLU6(inplace=True), 136 | nn.Conv2d(32, 16, 1, 1, bias=False), 137 | 
nn.BatchNorm2d(16), 138 | ) 139 | self.mb_block = nn.Sequential( 140 | MobileNetBottleneck(16, 6, 24, 3, 2), 141 | MobileNetBottleneck(24, 6, 24, 3, 1), 142 | MobileNetBottleneck(24, 6, 32, 3, 2), 143 | MobileNetBottleneck(32, 6, 32, 3, 1), 144 | MobileNetBottleneck(32, 6, 32, 3, 1), 145 | MobileNetBottleneck(32, 6, 64, 3, 2), 146 | MobileNetBottleneck(64, 6, 64, 3, 1), 147 | MobileNetBottleneck(64, 6, 64, 3, 1), 148 | MobileNetBottleneck(64, 6, 64, 3, 1), 149 | MobileNetBottleneck(64, 6, 96, 3, 1), 150 | MobileNetBottleneck(96, 6, 96, 3, 1), 151 | MobileNetBottleneck(96, 6, 96, 3, 1), 152 | MobileNetBottleneck(96, 6, 160, 3, 2), 153 | MobileNetBottleneck(160, 6, 160, 3, 1), 154 | MobileNetBottleneck(160, 6, 160, 3, 1), 155 | MobileNetBottleneck(160, 6, 320, 3, 1), 156 | ) 157 | self.last_block = nn.Sequential( 158 | nn.Conv2d(320, 1280, 1, 1, bias=False), 159 | nn.BatchNorm2d(1280), 160 | nn.ReLU6(inplace=True), 161 | nn.AdaptiveAvgPool2d(1), 162 | nn.Flatten() 163 | ) 164 | self.output = nn.Linear(1280, num_classes) 165 | 166 | def forward(self, x): 167 | x = self.first_block(x) 168 | x = self.mb_block(x) 169 | x = self.last_block(x) 170 | x = self.output(x) 171 | return x 172 | 173 | 174 | class MobileNetV3_Large(nn.Module): 175 | def __init__(self, num_classes=1000, small_input=False, dropout_rate=0.2): 176 | super(MobileNetV3_Large, self).__init__() 177 | self.first_block = nn.Sequential( 178 | nn.Conv2d(3, 16, 3, 2 if not small_input else 1, 1, bias=False), 179 | nn.BatchNorm2d(16), 180 | HardSwish(inplace=True), 181 | ) 182 | self.mb_block = nn.Sequential( 183 | MobileNetBottleneck(16, 1, 16, 3, 1, False, 'relu'), 184 | MobileNetBottleneck(16, 4, 24, 3, 2, False, 'relu'), 185 | MobileNetBottleneck(24, 3, 24, 3, 1, False, 'relu'), 186 | MobileNetBottleneck(24, 3, 40, 5, 2, True, 'relu'), 187 | MobileNetBottleneck(40, 3, 40, 5, 1, True, 'relu'), 188 | MobileNetBottleneck(40, 3, 40, 5, 1, True, 'relu'), 189 | MobileNetBottleneck(40, 6, 80, 3, 2, False, 'h_swish'), 190 | MobileNetBottleneck(80, 2.5, 80, 3, 1, False, 'h_swish'), 191 | MobileNetBottleneck(80, 2.3, 80, 3, 1, False, 'h_swish'), 192 | MobileNetBottleneck(80, 2.3, 80, 3, 1, False, 'h_swish'), 193 | MobileNetBottleneck(80, 6, 112, 3, 1, True, 'h_swish'), 194 | MobileNetBottleneck(112, 6, 112, 3, 1, True, 'h_swish'), 195 | MobileNetBottleneck(112, 6, 160, 5, 2, True, 'h_swish'), 196 | MobileNetBottleneck(160, 6, 160, 5, 1, True, 'h_swish'), 197 | MobileNetBottleneck(160, 6, 160, 5, 1, True, 'h_swish'), 198 | ) 199 | self.last_block = nn.Sequential( 200 | nn.Conv2d(160, 960, 1, bias=False), 201 | nn.BatchNorm2d(960), 202 | HardSwish(inplace=True), 203 | nn.AdaptiveAvgPool2d(1), 204 | nn.Conv2d(960, 1280, 1, bias=False), 205 | HardSwish(), 206 | nn.Dropout2d(p=dropout_rate, inplace=True), 207 | nn.Flatten(), 208 | ) 209 | self.output = nn.Linear(1280, num_classes) 210 | 211 | def forward(self, x): 212 | x = self.first_block(x) 213 | x = self.mb_block(x) 214 | x = self.last_block(x) 215 | x = self.output(x) 216 | return x 217 | 218 | 219 | class MobileNetV3_Small(nn.Module): 220 | def __init__(self, num_classes=1000, small_input=False, dropout_rate=0.2): 221 | super(MobileNetV3_Small, self).__init__() 222 | self.first_block = nn.Sequential( 223 | nn.Conv2d(3, 16, 3, 2 if not small_input else 1, 1, bias=False), 224 | nn.BatchNorm2d(16), 225 | HardSwish(inplace=True), 226 | ) 227 | self.mb_block = nn.Sequential( 228 | MobileNetBottleneck(16, 1, 16, 3, 2, True, 'relu'), 229 | MobileNetBottleneck(16, 4.5, 24, 3, 2, False, 'relu'), 230 
| MobileNetBottleneck(24, 88 / 24, 24, 3, 1, False, 'relu'), 231 | MobileNetBottleneck(24, 4, 40, 5, 2, True, 'h_swish'), 232 | MobileNetBottleneck(40, 6, 40, 5, 1, True, 'h_swish'), 233 | MobileNetBottleneck(40, 6, 40, 5, 1, True, 'h_swish'), 234 | MobileNetBottleneck(40, 3, 48, 5, 1, True, 'h_swish'), 235 | MobileNetBottleneck(48, 3, 48, 5, 1, True, 'h_swish'), 236 | MobileNetBottleneck(48, 6, 96, 5, 2, True, 'h_swish'), 237 | MobileNetBottleneck(96, 6, 96, 5, 1, True, 'h_swish'), 238 | MobileNetBottleneck(96, 6, 96, 5, 1, True, 'h_swish'), 239 | ) 240 | self.last_block = nn.Sequential( 241 | nn.Conv2d(96, 576, 1, bias=False), 242 | nn.BatchNorm2d(576), 243 | HardSwish(inplace=True), 244 | nn.AdaptiveAvgPool2d(1), 245 | nn.Conv2d(576, 1280, 1, bias=False), 246 | HardSwish(), 247 | nn.Dropout2d(p=dropout_rate, inplace=True), 248 | nn.Flatten(), 249 | ) 250 | self.output = nn.Linear(1280, num_classes) 251 | 252 | def forward(self, x): 253 | x = self.first_block(x) 254 | x = self.mb_block(x) 255 | x = self.last_block(x) 256 | x = self.output(x) 257 | return x 258 | -------------------------------------------------------------------------------- /models/oct_resnet_re.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | """This file is totally same as oct_resnet, I just want to avoid if/else in forward.""" 4 | 5 | __all__ = ['OctResnet', 'oct_resnet50', 'oct_resnet101', 'oct_resnet152', 6 | 'oct_resnet50_32x4d', 'oct_resnet101_32x8d'] 7 | 8 | from torch import nn 9 | from torchtoolbox.nn import AdaptiveSequential 10 | from module import OctConv, OctConvFirst, OctConvLast 11 | 12 | 13 | class fs_bn(nn.Module): 14 | def __init__(self, channels, alpha): 15 | super().__init__() 16 | h_out = int((1 - alpha) * channels) 17 | l_out = int(alpha * channels) 18 | 19 | self.h_bn = nn.BatchNorm2d(h_out) 20 | self.l_bn = nn.BatchNorm2d(l_out) 21 | 22 | def forward(self, x_h, x_l=None): 23 | y_h = self.h_bn(x_h) 24 | y_l = self.l_bn(x_l) 25 | return y_h, y_l 26 | 27 | 28 | class fs_relu(nn.Module): 29 | def __init__(self): 30 | super().__init__() 31 | self.relu = nn.ReLU(inplace=True) 32 | 33 | def forward(self, x_h, x_l=None): 34 | y_h = self.relu(x_h) 35 | y_l = self.relu(x_l) 36 | return y_h, y_l 37 | 38 | 39 | class OctBottleneck_First(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, inplanes, planes, alpha, stride=1, groups=1, base_width=64): 43 | super(OctBottleneck_First, self).__init__() 44 | width = int(planes * (base_width / 64.)) * groups 45 | 46 | self.conv1 = OctConvFirst(inplanes, width, alpha, 1, bias=False) 47 | self.bn1 = fs_bn(width, alpha) 48 | self.conv2 = OctConv(width, width, alpha, 3, 1, 1, groups, False) 49 | self.bn2 = fs_bn(width, alpha) 50 | self.conv3 = OctConv(width, planes * self.expansion, alpha, 1, bias=False) 51 | self.bn3 = fs_bn(planes * self.expansion, alpha) 52 | 53 | self.relu = fs_relu() 54 | if stride != 1 or inplanes != planes * self.expansion: 55 | self.downsample = AdaptiveSequential( 56 | OctConvFirst(inplanes, planes * self.expansion, alpha, 57 | 1, stride=stride, bias=False), 58 | fs_bn(planes * self.expansion, alpha) 59 | ) 60 | else: 61 | self.downsample = nn.Identity() 62 | 63 | def forward(self, x): 64 | r_h, r_l = self.downsample(x) 65 | x_h, x_l = self.conv1(x) 66 | x_h, x_l = self.bn1(x_h, x_l) 67 | x_h, x_l = self.relu(x_h, x_l) 68 | 69 | x_h, x_l = self.conv2(x_h, x_l) 70 | x_h, x_l = self.bn2(x_h, x_l) 71 | x_h, x_l = 
self.relu(x_h, x_l) 72 | 73 | x_h, x_l = self.conv3(x_h, x_l) 74 | x_h, x_l = self.bn3(x_h, x_l) 75 | 76 | y_h, y_l = self.relu(x_h + r_h, x_l + r_l) 77 | 78 | return y_h, y_l 79 | 80 | 81 | class OctBottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, alpha, stride=1, groups=1, base_width=64): 85 | super(OctBottleneck, self).__init__() 86 | width = int(planes * (base_width / 64.)) * groups 87 | 88 | self.conv1 = OctConv(inplanes, width, alpha, 1, bias=False) 89 | self.bn1 = fs_bn(width, alpha) 90 | self.conv2 = OctConv(width, width, alpha, 3, 1, stride, groups, False) 91 | self.bn2 = fs_bn(width, alpha) 92 | self.conv3 = OctConv(width, planes * self.expansion, alpha, 1, bias=False) 93 | self.bn3 = fs_bn(planes * self.expansion, alpha) 94 | 95 | self.relu = fs_relu() 96 | if stride != 1 or inplanes != planes * self.expansion: 97 | self.downsample = AdaptiveSequential( 98 | OctConv(inplanes, planes * self.expansion, alpha, 99 | 1, stride=stride, bias=False), 100 | fs_bn(planes * self.expansion, alpha) 101 | ) 102 | else: 103 | self.downsample = AdaptiveSequential() 104 | 105 | def forward(self, x_h, x_l): 106 | r_h, r_l = self.downsample(x_h, x_l) 107 | x_h, x_l = self.conv1(x_h, x_l) 108 | x_h, x_l = self.bn1(x_h, x_l) 109 | x_h, x_l = self.relu(x_h, x_l) 110 | 111 | x_h, x_l = self.conv2(x_h, x_l) 112 | x_h, x_l = self.bn2(x_h, x_l) 113 | x_h, x_l = self.relu(x_h, x_l) 114 | 115 | x_h, x_l = self.conv3(x_h, x_l) 116 | x_h, x_l = self.bn3(x_h, x_l) 117 | 118 | y_h, y_l = self.relu(x_h + r_h, x_l + r_l) 119 | return y_h, y_l 120 | 121 | 122 | class Bottleneck(nn.Module): 123 | expansion = 4 124 | 125 | def __init__(self, inplanes, planes, stride=1, groups=1, base_width=64): 126 | super(Bottleneck, self).__init__() 127 | width = int(planes * (base_width / 64.)) * groups 128 | 129 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False) 130 | self.bn1 = nn.BatchNorm2d(width) 131 | self.conv2 = nn.Conv2d(width, width, 3, stride, 1, groups=groups, bias=False) 132 | self.bn2 = nn.BatchNorm2d(width) 133 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False) 134 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 135 | 136 | self.relu = nn.ReLU() 137 | if stride != 1 or inplanes != planes * self.expansion: 138 | self.downsample = AdaptiveSequential( 139 | nn.Conv2d(inplanes, planes * self.expansion, 140 | 1, stride=stride, bias=False), 141 | nn.BatchNorm2d(planes * self.expansion) 142 | ) 143 | else: 144 | self.downsample = nn.Identity() 145 | 146 | def forward(self, x): 147 | r = self.downsample(x) 148 | 149 | x = self.conv1(x) 150 | x = self.bn1(x) 151 | x = self.relu(x) 152 | 153 | x = self.conv2(x) 154 | x = self.bn2(x) 155 | x = self.relu(x) 156 | 157 | x = self.conv3(x) 158 | x = self.bn3(x) 159 | 160 | y = self.relu(x + r) 161 | return y 162 | 163 | 164 | class OctBottleneck_Last(nn.Module): 165 | expansion = 4 166 | 167 | def __init__(self, inplanes, planes, alpha, stride=1, groups=1, base_width=64): 168 | super(OctBottleneck_Last, self).__init__() 169 | width = int(planes * (base_width / 64.)) * groups 170 | 171 | self.conv1 = OctConvLast(inplanes, width, alpha, 1, bias=False) 172 | self.bn1 = nn.BatchNorm2d(width) 173 | self.conv2 = nn.Conv2d(width, width, 3, stride, 1, groups=groups, bias=False) 174 | self.bn2 = nn.BatchNorm2d(width) 175 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False) 176 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, alpha) 177 | 178 | self.relu = nn.ReLU() 179 | if stride != 1 or 
inplanes != planes * self.expansion: 180 | self.downsample = AdaptiveSequential( 181 | OctConvLast(inplanes, planes * self.expansion, alpha, 182 | 1, stride=stride, bias=False), 183 | nn.BatchNorm2d(planes * self.expansion) 184 | ) 185 | else: 186 | self.downsample = nn.Identity() 187 | 188 | def forward(self, x_h, x_l): 189 | r = self.downsample(x_h, x_l) 190 | x = self.conv1(x_h, x_l) 191 | x = self.bn1(x) 192 | x = self.relu(x) 193 | 194 | x = self.conv2(x) 195 | x = self.bn2(x) 196 | x = self.relu(x) 197 | 198 | x = self.conv3(x) 199 | x = self.bn3(x) 200 | 201 | x = self.relu(x + r) 202 | return x 203 | 204 | 205 | class OctResnet(nn.Module): 206 | def __init__(self, alpha, layers, num_classes=1000, groups=1, width_per_group=64): 207 | super(OctResnet, self).__init__() 208 | self.inplanes = 64 209 | self.groups = groups 210 | self.alpha = alpha 211 | self.base_width = width_per_group 212 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 213 | bias=False) 214 | self.bn1 = nn.BatchNorm2d(self.inplanes) 215 | self.relu = nn.ReLU(inplace=True) 216 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 217 | 218 | self.layer1 = self._make_layer(64, layers[0], 1, 'start') 219 | self.layer2 = self._make_layer(128, layers[1], 2) 220 | self.layer3 = self._make_layer(256, layers[2], 2) 221 | self.layer4 = self._make_layer(512, layers[3], 2, 'end') 222 | 223 | self.avgpool = nn.AdaptiveAvgPool2d(1) 224 | self.fc = nn.Sequential( 225 | nn.Flatten(), 226 | nn.Linear(512 * OctBottleneck.expansion, num_classes) 227 | ) 228 | 229 | def _make_layer(self, planes, blocks, stride=1, status='normal'): 230 | assert status in ('start', 'normal', 'end') 231 | layers = [] 232 | if status == 'start': 233 | layers.append(OctBottleneck_First(self.inplanes, planes, self.alpha, stride, 234 | self.groups, self.base_width)) 235 | elif status == 'normal': 236 | layers.append(OctBottleneck(self.inplanes, planes, self.alpha, stride, 237 | self.groups, self.base_width)) 238 | else: 239 | layers.append(OctBottleneck_Last(self.inplanes, planes, self.alpha, stride, 240 | self.groups, self.base_width)) 241 | self.inplanes = planes * OctBottleneck.expansion 242 | for _ in range(1, blocks): 243 | if status != 'end': 244 | layers.append(OctBottleneck(self.inplanes, planes, self.alpha, 1, 245 | self.groups, self.base_width)) 246 | else: 247 | layers.append(Bottleneck(self.inplanes, planes, 1, 248 | self.groups, self.base_width)) 249 | return AdaptiveSequential(*layers) 250 | 251 | def forward(self, x): 252 | x = self.conv1(x) 253 | x = self.bn1(x) 254 | x = self.relu(x) 255 | x = self.maxpool(x) 256 | 257 | x_h, x_l = self.layer1(x) 258 | x_h, x_l = self.layer2(x_h, x_l) 259 | x_h, x_l = self.layer3(x_h, x_l) 260 | x = self.layer4(x_h, x_l) 261 | 262 | x = self.avgpool(x) 263 | x = self.fc(x) 264 | return x 265 | 266 | 267 | def oct_resnet50(alpha, **kwargs): 268 | return OctResnet(alpha, [3, 4, 6, 3], **kwargs) 269 | 270 | 271 | def oct_resnet101(alpha, **kwargs): 272 | return OctResnet(alpha, [3, 4, 23, 3], **kwargs) 273 | 274 | 275 | def oct_resnet152(alpha, **kwargs): 276 | return OctResnet(alpha, [3, 8, 36, 3], **kwargs) 277 | 278 | 279 | def oct_resnet50_32x4d(alpha, **kwargs): 280 | kwargs['groups'] = 32 281 | kwargs['width_per_group'] = 4 282 | return OctResnet(alpha, [3, 4, 6, 3], **kwargs) 283 | 284 | 285 | def oct_resnet101_32x8d(alpha, **kwargs): 286 | kwargs['groups'] = 32 287 | kwargs['width_per_group'] = 8 288 | return OctResnet(alpha, [3, 4, 32, 3], **kwargs) 289 | 
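A minimal shape-check sketch for the split helper layers defined in module/octconv.py (not part of the repository files; the channel sizes and alpha below are arbitrary). `OctConvFirst` turns a full feature map into a high/low-frequency pair, `OctConv` keeps both branches, and `OctConvLast` merges them back into a single map:

```python
import torch
from module import OctConvFirst, OctConv, OctConvLast  # exported via module/__init__.py

x = torch.randn(1, 64, 56, 56)
first = OctConvFirst(64, 64, alpha=0.25, kernel_size=3, padding=1)   # full map -> (high, low)
mid = OctConv(64, 128, alpha=0.25, kernel_size=3, padding=1)         # (high, low) -> (high, low)
last = OctConvLast(128, 256, alpha=0.25, kernel_size=1)              # (high, low) -> full map

x_h, x_l = first(x)        # (1, 48, 56, 56) and (1, 16, 28, 28)
x_h, x_l = mid(x_h, x_l)   # (1, 96, 56, 56) and (1, 32, 28, 28)
y = last(x_h, x_l)
print(y.size())            # torch.Size([1, 256, 56, 56])
```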
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ModelZoo for Pytorch 2 | 3 | This is a model zoo project under Pytorch. In this repo I will implement some basic classification 4 | models which perform well on ImageNet, then train them in as fair a way as possible and 5 | try my best to get SOTA models on ImageNet. In this repo I'll only consider FP16. 6 | 7 | 8 | ## Usage 9 | ### Environment 10 | - OS: Ubuntu 18.04 11 | - CUDA: 10.1, CuDNN: 7.6 12 | - Devices: I use 8 * RTX 2080ti (8 * V100 should be much better /cry). This project runs in FP16 precision, so it's recommended to use FP16-friendly devices like the 13 | RTX series or V100. If you want to fully reproduce my results, you'd better use the same batch size as me. 14 | 15 | ### Requirement 16 | - Pytorch: >= 1.6.0 (needs torch.cuda.amp, available since version 1.6) 17 | - [TorchToolbox](https://github.com/deeplearningforfun/torch-toolbox): stable version. 18 | Helper functions to make your code simpler and more readable; it's an optional toolbox, 19 | and if you don't want to use it you can just write the helpers yourself. 20 | 21 | ### LMDB Dataset 22 | - Not necessary. 23 | 24 | If you find any IO bottleneck, please use an LMDB-format dataset. A good way is to try both and find out 25 | which one is faster. 26 | 27 | I provide a conversion script [here](scripts/generate_LMDB_dataset.py). 28 | 29 | ### Train script 30 | 31 | ```shell 32 | python distribute_train_script.py --params 33 | ``` 34 | Here is an example 35 | ```shell 36 | python distribute_train_script.py --data-path /s4/piston/ImageNet --batch-size 256 --dtype float16 \ 37 | -j 48 --epochs 360 --lr 2.6 --warmup-epochs 5 --label-smoothing \ 38 | --no-wd --wd 0.00003 --model GhostNet --log-interval 150 --model-info \ 39 | --dist-url tcp://127.0.0.1:26548 --world-size 1 --rank 0 40 | ``` 41 | 42 | ## ToDo 43 | - [x] Resume training 44 | - ~~Try Nvidia-DALI~~ 45 | - [x] Multi-node (distributed) training by ~~Apex or BytePS~~ Pytorch 46 | - [x] I may try AutoAugment. This project aims to train models ourselves to observe and learn; 47 | it's impossible for me to train this myself, and just copying feels meaningless.
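The distributed script above is the intended entry point. For a quick local sanity check of the FP16 setup, a minimal single-GPU AMP step might look like this (a sketch only, with placeholder hyper-parameters; it is not the repository's actual training loop):

```python
import torch
from models import MobileNetV3_Large  # any classifier exported by the models package works here

model = MobileNetV3_Large(num_classes=1000).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

images = torch.randn(8, 3, 224, 224, device='cuda')
targets = torch.randint(0, 1000, (8,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():        # run the forward pass in FP16 where it is safe
    loss = criterion(model(images), targets)
scaler.scale(loss).backward()          # scale the loss to avoid FP16 gradient underflow
scaler.step(optimizer)
scaler.update()
```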
48 | 49 | ## Baseline models 50 | 51 | |model | epochs| dtype |batch size*|gpus | lr | tricks|Params(M)/FLOPs |top1/top5 |params/logs| 52 | |:----:|:-----:|:-----:|:---------:|:----:|:---:|:------:|:---------------:|:---------:|:---------:| 53 | |resnet50|120 |FP16 |128 | 8 |0.4 | - | 25.6/4.1G |77.36/- |[Google Drive](https://drive.google.com/drive/folders/1orshUNj-4LroO2q-vyd45c_Iz7alQ50M?usp=sharing)| 54 | |resnet101|120 |FP16 |128 | 8 |0.4 | - | 44.7/7.8G |79.13/94.38|[Google Drive](https://drive.google.com/drive/folders/1nmdpX39_9KidxxUXuL0uDYpDGjavQS0M?usp=sharing)| 55 | |resnet50v2|120|FP16 |128 | 8 |0.4 | - | 25.6/4.1G |77.06/93.44|[Google Drive](https://drive.google.com/drive/folders/1W_GBANCv0eOQaTmDFZ-NrNJlUay5NP-C?usp=sharing)| 56 | |resnet101v2|120|FP16 |128 | 8 |0.4 | - | 44.6/7.8G |78.90/94.39|[Google Drive](https://drive.google.com/drive/folders/1L4r5S9MciLUkBzzjZwZ-vlC2xH1O1Csj?usp=sharing)| 57 | |ResNext50_32x4d|120|FP16 |128 | 8 |0.4 | - | 25.1/4.2G |79.00/94.39|| 58 | |RegNetX4_0GF|120|FP16 |128 | 8 |0.4 | - | 22.2/4.0G |78.40/94.04|| 59 | |RegNetY4_0GF|120|FP16 |128 | 8 |0.4 | - | 22.1/4.0G |79.22/94.57|| 60 | |RegNetY6_4GF|120|FP16 |128 | 8 |0.4 | - | 31.2/6.4G |79.69/94.82|| 61 | |ResNeST50 |120|FP16 |128 | 8 |0.4 | - | 27.5/4.1G |78.62/94.28|| 62 | |mobilenetv1|150|FP16 |256 | 8 |0.4 | - | 4.3/572.2M |72.17/90.70|[Google Drive](https://drive.google.com/drive/folders/1n_4WTnh-anrszm1VCo35etmUsG7O4j9e?usp=sharing)| 63 | |mobilenetv2|150|FP16 |256 | 8 |0.4 | - | 3.5/305.3M |71.94/90.59|[Google Drive](https://drive.google.com/drive/folders/1PqqyZ02L4h42KOVPSO6e9A0a_gVCir_b?usp=sharing)| 64 | |mobilenetv3 Large|360|FP16 |256 | 8 |2.6 |Label smoothing No decay bias Dropout| 5.5/219M |75.64/92.61 |[Google Drive](https://drive.google.com/drive/folders/1pZSDhNuSxSIyKq4Leyam9m5iQr1Xcpf6?usp=sharing)| 65 | |mobilenetv3 Small|360|FP16 |256 | 8 |2.6 |Label smoothing No decay bias Dropout| 3.0/57.8M |67.83/87.78 || 66 | |GhostNet1.3 |360|FP16 |400 | 8 |2.6 |Label smoothing No decay bias Dropout| 7.4/230.4M |75.78/92.77 |[Google Drive](https://drive.google.com/drive/folders/1-s0VujC1BAC-H2NGTtQ1-QHOlMNerieD?usp=sharing)| 67 | 68 | 69 | 70 | - I use nesterov SGD and cosine lr decay with 5 warmup epochs by default[2][3] (to save time); this setup is common and effective. 71 | - *Batch size is what each GPU holds. Total batch size is (batch size * gpus). 72 | 73 | 74 | ## Optimized Models (with tricks) 75 | - In progress. 76 | 77 | ## Ablation Study on Tricks 78 | 79 | Lots of tricks have been proposed over the years to improve accuracy. (If you have another idea, please open an issue.) 80 | I want to verify them in a fair way. 81 | 82 | 83 | Tricks: RandomRotation, OctConv[14], Dropout, Label Smoothing[4], Sync BN, ~~SwitchNorm[6]~~, Mixup[17], no decay bias[7], 84 | Cutout[5], Relu6[18], ~~swish activation[10]~~, Stochastic Depth[9], Lookahead Optimizer[11], Pre-active(ResnetV2)[12], 85 | ~~DCNv2[13]~~, LIP[16]. 86 | 87 | - A struck-through line means the trick ran me out of memory. 88 | 89 | Special: zero-initialize the last BN, call it 'Zero γ'; it only applies to post-active models. 90 | 91 | I'll only use 120 epochs and a 128*8 batch size to train them. 92 | I know some tricks may need more training time or a larger batch size, but that wouldn't be fair to the others. 93 | You can think of the numbers as the performance achievable under this fixed setup.
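For reference, the 'Zero γ' trick above can be applied with a small helper like the following (a sketch that assumes the last BatchNorm of each post-active bottleneck lives in an attribute named `bn3`, as in the bottleneck blocks in this repo; adjust the name for other block definitions):

```python
from torch import nn

def zero_init_last_bn(model, attr='bn3'):
    """Zero the affine weight (gamma) of the last BN in each post-active block."""
    for m in model.modules():
        bn = getattr(m, attr, None)
        if isinstance(bn, nn.BatchNorm2d):
            nn.init.zeros_(bn.weight)  # the residual branch starts as an identity mapping

# e.g. call zero_init_last_bn(model) right after constructing a post-active ResNet
```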
94 | 95 | 96 | |model | epochs| dtype |batch size*|gpus | lr | tricks|degree|top1/top5 |improve |params/logs| 97 | |:----:|:-----:|:-----:|:---------:|:----:|:---:|:------:|:----:|:---------:|:------:|:----:| 98 | |resnet50|120 |FP16 |128 | 8 |0.4 | - | - |77.36/- |baseline|[Google Drive](https://drive.google.com/drive/folders/1orshUNj-4LroO2q-vyd45c_Iz7alQ50M?usp=sharing)| 99 | |resnet50|120 |FP16 |128 | 8 |0.4 |Label smoothing|smoothing=0.1|77.78/93.80 |**+0.42** |[Google Drive](https://drive.google.com/drive/folders/1CO8Fmbiy1TgEvdpU-KKV7AHIa7EanaqG?usp=sharing)| 100 | |resnet50|120 |FP16 |128 | 8 |0.4 |No decay bias |- |77.28/93.61 |-0.08 |[Google Drive](https://drive.google.com/drive/folders/1oYC3EjLn-2nnWrS_UrhaP_3YY3uhWzhz?usp=sharing)| 101 | |resnet50|120 |FP16 |128 | 8 |0.4 |Sync BN |- |77.31/93.49 |-0.05 |[Google Drive](https://drive.google.com/drive/folders/1QW2LSl7JsTcnCGM289N9wA-xkjkuhBvg?usp=sharing)| 102 | |resnet50|120 |FP16 |128 | 8 |0.4 |Mixup |alpha=0.2 |77.49/93.73 |**+0.13** |missing| 103 | |resnet50|120 |FP16 |128 | 8 |0.4 |RandomRotation |degree=15 |76.64/93.28 |-1.15 |[Google Drive](https://drive.google.com/drive/folders/1FYmTVStop4VT5LA9RCPUbWPnzGsEJoCy?usp=sharing)| 104 | |resnet50|120 |FP16 |128 | 8 |0.4 |Cutout |read code |77.44/93.62 |**+0.08** |[Google Drive](https://drive.google.com/drive/folders/1HhDTDkj6Zg_oJT-5TQZu1RP-CYs1fr3U?usp=sharing)| 105 | |resnet50|120 |FP16 |128 | 8 |0.4 |Dropout |rate=0.3 |77.11/93.58 |-0.25 |[Google Drive](https://drive.google.com/drive/folders/1sA6e8sewz-Za6ySUUJcLpiTjV9V1Fk8f?usp=sharing)| 106 | |resnet50|120 |FP16 |128 | 8 |0.4 |Lookahead-SGD | - |77.23/93.39 |-0.13 |[Google Drive](https://drive.google.com/drive/folders/1gC8pD7CDDQ7haBKhNBNqj8i9Xsk3cNla?usp=sharing)| 107 | |resnet50v2|120 |FP16 |128 | 8 |0.4 |pre-active | - |77.06/93.44 |-0.30 |[Google Drive](https://drive.google.com/drive/folders/1W_GBANCv0eOQaTmDFZ-NrNJlUay5NP-C?usp=sharing)| 108 | |oct_resnet50|120 |FP16 |128 | 8 |0.4 |OctConv |alpha=0.125 |-|-|| 109 | |resnet50|120 |FP16 |128 | 8 |0.4 |Relu6 | |77.28/93.5 |-0.08 |[Google Drive](https://drive.google.com/drive/folders/1en9SQq2ZeswaZoTiYDAR_vQS3YAJU5gq?usp=sharing)| 110 | |resnet50|120 |FP16 |128 | 8 |0.4 | | - |77.00/- |DDP baseline|| 111 | |resnet50|120 |FP16 |128 | 8 |0.4 |Gradient Centralization|Conv only|77.40/93.57 |**+0.40**|| 112 | |resnet50|120 |FP16 |128 | 8 |0.4 |Zero γ | |77.24/- |**+0.24**|| 113 | |resnet50|120 |FP16 |128 | 8 |0.4 |No decay bias | |77.74/93.77 |**+0.74**|| 114 | |resnet50|120 |FP16 |128 | 8 |0.4 |RandAugment |n=2,m=9 |76.44/93.18 |-0.96|| 115 | |resnet50|120 |FP16 |128 | 8 |0.4 |AutoAugment | |76.50/93.23 |-0.50|| 116 | 117 | 118 | - More epochs for `Mixup`, `Cutout`, `Dropout` may get better results. 119 | - Auto/Rand Augment may train 180 epochs better. 
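As a reference for the 'No decay bias' rows (the `--no-wd` flag in the train script), here is a hedged sketch of how biases and norm parameters can be excluded from weight decay; the lr/momentum/weight-decay values below are placeholders borrowed from the baseline settings, not a prescription:

```python
import torch
from models import MobileNetV2

def split_weight_decay_params(model):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors cover BN/affine weights; also exclude every bias explicitly
        if param.ndim == 1 or name.endswith('.bias'):
            no_decay.append(param)
        else:
            decay.append(param)
    return decay, no_decay

model = MobileNetV2()
decay, no_decay = split_weight_decay_params(model)
optimizer = torch.optim.SGD(
    [{'params': decay, 'weight_decay': 3e-5},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=0.4, momentum=0.9, nesterov=True)
```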
120 | 121 | ## Citation 122 | ``` 123 | @misc{ModelZoo.pytorch, 124 | title = {Basic deep conv neural network reproduce and explore}, 125 | author = {X.Yang}, 126 | URL = {https://github.com/PistonY/ModelZoo.pytorch}, 127 | year = {2019} 128 | } 129 | ``` 130 | 131 | ## Reference 132 | - [1] [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) 133 | - [2] [Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour](https://arxiv.org/pdf/1706.02677.pdf) 134 | - [3] [Bag of Tricks for Image Classification with Convolutional Neural Networks](https://arxiv.org/pdf/1812.01187.pdf) 135 | - [4] [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/pdf/1512.00567.pdf) 136 | - [5] [Improved Regularization of Convolutional Neural Networks with Cutout](https://arxiv.org/pdf/1708.04552.pdf) 137 | - [6] [Differentiable Learning-to-Normalize via Switchable Normalization](https://arxiv.org/pdf/1806.10779.pdf) [OpenSourse](https://github.com/switchablenorms/Switchable-Normalization) 138 | - [7] [Highly Scalable Deep Learning Training System with Mixed-Precision: Training ImageNet in Four Minutes](https://arxiv.org/pdf/1807.11205.pdf) 139 | - [8] [MIXED PRECISION TRAINING](https://arxiv.org/pdf/1710.03740.pdf) 140 | - [9] [Deep Networks with Stochastic Depth](https://arxiv.org/pdf/1603.09382.pdf) 141 | - [10] [SEARCHING FOR ACTIVATION FUNCTIONS](https://arxiv.org/pdf/1710.05941.pdf) 142 | - [11] [Lookahead Optimizer: k steps forward, 1 step back](https://arxiv.org/abs/1907.08610) 143 | - [12] [Identity Mappings in Deep Residual Networks](https://arxiv.org/pdf/1603.05027.pdf) 144 | - [13] [Deformable ConvNets v2: More Deformable, Better Results](https://arxiv.org/pdf/1811.11168.pdf) 145 | - [14] [Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution](https://export.arxiv.org/pdf/1904.05049) 146 | - [15] [Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) 147 | - [16] [LIP: Local Importance-based Pooling](https://arxiv.org/pdf/1908.04156v1.pdf) 148 | - [17] [mixup: BEYOND EMPIRICAL RISK MINIMIZATION](https://arxiv.org/pdf/1710.09412.pdf) 149 | - [18] [Gradient Centralization: A New Optimization Technique for Deep Neural Networks](https://arxiv.org/pdf/2004.01461.pdf) -------------------------------------------------------------------------------- /models/ghostnet.py: -------------------------------------------------------------------------------- 1 | __all__ = ['GhostNet'] 2 | 3 | import math 4 | import torch 5 | from torch import nn 6 | from torchtoolbox.nn import Activation 7 | 8 | 9 | def make_divisible(v, divisible_by, min_value=None): 10 | """ 11 | This function is taken from the original tf repo. 12 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 13 | """ 14 | if min_value is None: 15 | min_value = divisible_by 16 | new_v = max(min_value, int(v + divisible_by / 2) // divisible_by * divisible_by) 17 | # Make sure that round down does not go down by more than 10%. 
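    # Illustrative worked examples (added note, not in the original source):
    #   make_divisible(16 * 1.3, 4): int(20.8 + 2) // 4 * 4 = 20, and 20 >= 0.9 * 20.8, so 20 is returned.
    #   make_divisible(10, 8): int(10 + 4) // 8 * 8 = 8, but 8 < 0.9 * 10, so the check below bumps it to 16.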
18 | if new_v < 0.9 * v: 19 | new_v += divisible_by 20 | return new_v 21 | 22 | 23 | class SE(nn.Module): 24 | def __init__(self, in_c, reduction_ratio=0.25): 25 | super(SE, self).__init__() 26 | reducation_c = make_divisible(in_c * reduction_ratio, 4) 27 | self.block = nn.Sequential( 28 | nn.AdaptiveAvgPool2d(1), 29 | nn.Conv2d(in_c, reducation_c, kernel_size=1, bias=True), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(reducation_c, in_c, kernel_size=1, bias=True), 32 | nn.Hardsigmoid() 33 | ) 34 | 35 | def forward(self, x): 36 | return x * self.block(x) 37 | 38 | 39 | class GhostModule(nn.Module): 40 | def __init__(self, in_c, out_c, kernel_size=1, ratio=2, dw_size=3, stride=1, act=True, act_type='relu'): 41 | super(GhostModule, self).__init__() 42 | if ratio != 2: 43 | print("Please change output channels manually.") 44 | init_c = math.ceil(out_c / ratio) 45 | new_c = init_c * (ratio - 1) 46 | 47 | self.primary_conv = nn.Sequential( 48 | nn.Conv2d(in_c, init_c, kernel_size, stride, kernel_size // 2, bias=False), 49 | nn.BatchNorm2d(init_c), 50 | Activation(act_type) if act else nn.Identity() 51 | ) 52 | 53 | self.cheap_operation = nn.Sequential( 54 | nn.Conv2d(init_c, new_c, dw_size, 1, dw_size // 2, groups=init_c, bias=False), 55 | nn.BatchNorm2d(new_c), 56 | Activation(act_type) if act else nn.Identity() 57 | ) 58 | 59 | def forward(self, x): 60 | x1 = self.primary_conv(x) 61 | x2 = self.cheap_operation(x1) 62 | out = torch.cat([x1, x2], dim=1) 63 | # if ratio != 2, you need return out[:,:out_c,:,:] 64 | return out 65 | 66 | 67 | class GhostBottleneck(nn.Module): 68 | def __init__(self, in_c, mid_c, out_c, dw_kernel_size=3, stride=1, se_ratio=None, act_type='relu'): 69 | super(GhostBottleneck, self).__init__() 70 | self.ghost1 = GhostModule(in_c, mid_c, act=True, act_type=act_type) 71 | if stride > 1: 72 | self.dw_conv = nn.Sequential( 73 | nn.Conv2d(mid_c, mid_c, dw_kernel_size, stride, 74 | dw_kernel_size // 2, groups=mid_c, bias=False), 75 | nn.BatchNorm2d(mid_c) 76 | ) 77 | else: 78 | self.dw_conv = nn.Identity() 79 | self.se = SE(mid_c, reduction_ratio=se_ratio) if se_ratio is not None else nn.Identity() 80 | self.ghost2 = GhostModule(mid_c, out_c, act=False, act_type=act_type) 81 | 82 | if in_c == out_c and stride == 1: 83 | self.shortcut = nn.Identity() 84 | else: 85 | self.shortcut = nn.Sequential( 86 | nn.Conv2d(in_c, in_c, dw_kernel_size, stride, 87 | dw_kernel_size // 2, groups=in_c, bias=False), 88 | nn.BatchNorm2d(in_c), 89 | nn.Conv2d(in_c, out_c, 1, bias=False), 90 | nn.BatchNorm2d(out_c) 91 | ) 92 | 93 | def forward(self, x): 94 | residual = self.shortcut(x) 95 | x = self.ghost1(x) 96 | x = self.dw_conv(x) 97 | x = self.se(x) 98 | x = self.ghost2(x) 99 | return x + residual 100 | 101 | 102 | class Stem(nn.Module): 103 | def __init__(self, out_c, act_type='relu'): 104 | super(Stem, self).__init__() 105 | self.stem = nn.Sequential( 106 | nn.Conv2d(3, out_c, 3, 2, 1, bias=False), 107 | nn.BatchNorm2d(out_c), 108 | Activation(act_type) 109 | ) 110 | 111 | def forward(self, x): 112 | return self.stem(x) 113 | 114 | 115 | class Head(nn.Module): 116 | def __init__(self, in_c, mid_c, out_c, dropout, act_type='relu'): 117 | super(Head, self).__init__() 118 | self.head = nn.Sequential( 119 | nn.AdaptiveAvgPool2d(1), 120 | nn.Conv2d(in_c, mid_c, 1, bias=True), 121 | Activation(act_type), 122 | nn.Flatten(), 123 | nn.Dropout(dropout) if dropout > 0 else nn.Identity(), 124 | nn.Linear(mid_c, out_c) 125 | ) 126 | 127 | def forward(self, x): 128 | return self.head(x) 129 | 130 | 131 | 
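# Added note (illustrative): with the default ratio=2 used throughout this file, GhostModule
# splits out_c evenly, e.g. out_c=120 gives init_c = ceil(120 / 2) = 60 primary channels plus
# new_c = 60 cheap depthwise channels, so cat([x1, x2]) has exactly out_c channels. With
# ratio=3 and out_c=100, init_c=34 and new_c=68 give 102 channels, which is why forward()
# notes that out[:, :out_c] is needed when ratio != 2.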
class GhostNet(nn.Module): 132 | def __init__(self, num_classes=1000, width=1.3, dropout=0): 133 | super(GhostNet, self).__init__() 134 | assert dropout >= 0, "Use = 0 to disable or > 0 to enable." 135 | self.width = width 136 | stem_c = make_divisible(16 * width, 4) 137 | self.stem = Stem(stem_c) 138 | self.stage = nn.Sequential( 139 | # stage1 140 | GhostBottleneck(stem_c, self.get_c(16), self.get_c(16), 3, 1), 141 | # stage2 142 | GhostBottleneck(self.get_c(16), self.get_c(48), self.get_c(24), 3, 2), 143 | GhostBottleneck(self.get_c(24), self.get_c(72), self.get_c(24), 3, 1), 144 | # stage3 145 | GhostBottleneck(self.get_c(24), self.get_c(72), self.get_c(40), 5, 2, 0.25), 146 | GhostBottleneck(self.get_c(40), self.get_c(120), self.get_c(40), 5, 1, 0.25), 147 | # stage4 148 | GhostBottleneck(self.get_c(40), self.get_c(240), self.get_c(80), 3, 2), 149 | GhostBottleneck(self.get_c(80), self.get_c(200), self.get_c(80), 3, 1), 150 | GhostBottleneck(self.get_c(80), self.get_c(184), self.get_c(80), 3, 1), 151 | GhostBottleneck(self.get_c(80), self.get_c(184), self.get_c(80), 3, 1), 152 | GhostBottleneck(self.get_c(80), self.get_c(480), self.get_c(112), 3, 1, 0.25), 153 | GhostBottleneck(self.get_c(112), self.get_c(672), self.get_c(112), 3, 1, 0.25), 154 | # stage5 155 | GhostBottleneck(self.get_c(112), self.get_c(672), self.get_c(160), 5, 2, 0.25), 156 | GhostBottleneck(self.get_c(160), self.get_c(960), self.get_c(160), 5, 1), 157 | GhostBottleneck(self.get_c(160), self.get_c(960), self.get_c(160), 5, 1, 0.25), 158 | GhostBottleneck(self.get_c(160), self.get_c(960), self.get_c(160), 5, 1), 159 | GhostBottleneck(self.get_c(160), self.get_c(960), self.get_c(160), 5, 1, 0.25), 160 | # conv-bn-act 161 | nn.Conv2d(self.get_c(160), self.get_c(960), 1, bias=False), 162 | nn.BatchNorm2d(self.get_c(960)), 163 | nn.ReLU(inplace=True), 164 | ) 165 | 166 | self.head = Head(self.get_c(960), 1280, num_classes, dropout) 167 | 168 | def get_c(self, c): 169 | return make_divisible(c * self.width, 4) 170 | 171 | def forward(self, x): 172 | x = self.stem(x) 173 | x = self.stage(x) 174 | x = self.head(x) 175 | return x 176 | 177 | 178 | class GhostNet600(nn.Module): 179 | def __init__(self, num_classes=1000, width=1.75, dropout=0.8): 180 | super(GhostNet600, self).__init__() 181 | assert dropout >= 0, "Use = 0 to disable or > 0 to enable." 
182 | self.width = width 183 | stem_c = make_divisible(16 * width, 4) 184 | self.stem = Stem(stem_c, 'h_swish') 185 | self.stage = nn.Sequential( 186 | # stage1 187 | GhostBottleneck(stem_c, self.get_c(16), self.get_c(16), 3, 1, 0.1, 'h_swish'), 188 | # stage2 189 | GhostBottleneck(self.get_c(16), self.get_c(48), self.get_c(24), 3, 2, 0.1, 'h_swish'), 190 | GhostBottleneck(self.get_c(24), self.get_c(72), self.get_c(24), 3, 1, 0.1, 'h_swish'), 191 | # stage3 192 | GhostBottleneck(self.get_c(24), self.get_c(72), self.get_c(40), 5, 2, 0.1, 'h_swish'), 193 | GhostBottleneck(self.get_c(40), self.get_c(120), self.get_c(40), 3, 1, 0.1, 'h_swish'), 194 | GhostBottleneck(self.get_c(40), self.get_c(120), self.get_c(40), 3, 1, 0.1, 'h_swish'), 195 | # stage4 196 | GhostBottleneck(self.get_c(40), self.get_c(240), self.get_c(80), 3, 2, 0.1, 'h_swish'), 197 | GhostBottleneck(self.get_c(80), self.get_c(200), self.get_c(80), 3, 1, 0.1, 'h_swish'), 198 | GhostBottleneck(self.get_c(80), self.get_c(200), self.get_c(80), 3, 1, 0.1, 'h_swish'), 199 | GhostBottleneck(self.get_c(80), self.get_c(200), self.get_c(80), 3, 1, 0.1, 'h_swish'), 200 | GhostBottleneck(self.get_c(80), self.get_c(480), self.get_c(112), 3, 1, 0.1, 'h_swish'), 201 | GhostBottleneck(self.get_c(112), self.get_c(672), self.get_c(112), 3, 1, 0.1, 'h_swish'), 202 | GhostBottleneck(self.get_c(112), self.get_c(672), self.get_c(112), 3, 1, 0.1, 'h_swish'), 203 | # stage5 204 | GhostBottleneck(self.get_c(112), self.get_c(672), self.get_c(160), 5, 2, 0.1, 'h_swish'), 205 | GhostBottleneck(self.get_c(160), self.get_c(960), self.get_c(160), 3, 1, 0.1, 'h_swish'), 206 | GhostBottleneck(self.get_c(160), self.get_c(960), self.get_c(160), 3, 1, 0.1, 'h_swish'), 207 | GhostBottleneck(self.get_c(160), self.get_c(960), self.get_c(160), 3, 1, 0.1, 'h_swish'), 208 | GhostBottleneck(self.get_c(160), self.get_c(960), self.get_c(160), 3, 1, 0.1, 'h_swish'), 209 | GhostBottleneck(self.get_c(160), self.get_c(960), self.get_c(160), 3, 1, 0.1, 'h_swish'), 210 | # conv-bn-act 211 | nn.Conv2d(self.get_c(160), self.get_c(960), 1, bias=False), 212 | nn.BatchNorm2d(self.get_c(960)), 213 | Activation('h_swish'), 214 | ) 215 | 216 | self.head = Head(self.get_c(960), 1400, num_classes, dropout, 'h_swish') 217 | 218 | def get_c(self, c): 219 | return make_divisible(c * self.width, 4) 220 | 221 | def forward(self, x): 222 | x = self.stem(x) 223 | x = self.stage(x) 224 | x = self.head(x) 225 | return x 226 | 227 | 228 | class TinyGhostNet(nn.Module): 229 | def __init__(self, width_coeff, depth_coeff, depth_div=4, 230 | min_depth=None, num_classes=1000, dropout=0): 231 | super().__init__() 232 | assert dropout >= 0, "Use = 0 to disable or > 0 to enable." 
233 | min_depth = min_depth or depth_div 234 | 235 | def renew_ch(x): 236 | if not width_coeff: 237 | return x 238 | 239 | new_x = max(min_depth, int(x + depth_div / 2) // depth_div * depth_div) 240 | if new_x < 0.9 * new_x: 241 | new_x += depth_div 242 | return new_x 243 | 244 | def renew_repeat(x): 245 | return int(math.ceil(x * depth_coeff)) 246 | 247 | self.stem = Stem(renew_ch(28)) 248 | self.stage = nn.Sequential( 249 | # stage1 250 | self._make_layer(renew_ch(28), renew_ch(28), 1, 3, 1, renew_repeat(1), 0.1), 251 | # stage2 252 | self._make_layer(renew_ch(28), renew_ch(44), 3, 3, 2, renew_repeat(1), 0.1), 253 | self._make_layer(renew_ch(44), renew_ch(44), 3, 3, 1, renew_repeat(1), 0.1), 254 | # stage3 255 | self._make_layer(renew_ch(44), renew_ch(72), 3, 3, 2, renew_repeat(1), 0.1), 256 | self._make_layer(renew_ch(72), renew_ch(72), 3, 3, 1, renew_repeat(2), 0.1), 257 | # stage4 258 | self._make_layer(renew_ch(72), renew_ch(140), 6, 3, 2, renew_repeat(1), 0.1), 259 | self._make_layer(renew_ch(140), renew_ch(140), 2.5, 3, 1, renew_repeat(3), 0.1), 260 | self._make_layer(renew_ch(140), renew_ch(196), 6, 3, 1, renew_repeat(3), 0.1), 261 | # stage5 262 | self._make_layer(renew_ch(196), renew_ch(280), 6, 3, 2, renew_repeat(1), 0.1), 263 | self._make_layer(renew_ch(280), renew_ch(280), 6, 3, 1, renew_repeat(5), 0.1), 264 | nn.Conv2d(renew_ch(280), renew_ch(1680), 1, bias=False), 265 | nn.BatchNorm2d(renew_ch(1680)), 266 | nn.ReLU(inplace=True) 267 | ) 268 | self.head = Head(renew_ch(1680), renew_ch(1400), num_classes, dropout) 269 | 270 | def _make_layer(self, in_c, out_c, expand, kernel_size, stride, repeats, se_ratio): 271 | layers = [] 272 | 273 | layers.append(GhostBottleneck(in_c, int(in_c * expand), out_c, kernel_size, stride, se_ratio)) 274 | for _ in range(repeats - 1): 275 | layers.append(GhostBottleneck(out_c, int(out_c * expand), out_c, kernel_size, 1, se_ratio)) 276 | return nn.Sequential(*layers) 277 | 278 | def forward(self, x): 279 | x = self.stem(x) 280 | x = self.stage(x) 281 | x = self.head(x) 282 | return x 283 | 284 | 285 | def GhostNetA(num_classes=1000, dropout=0, **kwargs): 286 | return TinyGhostNet(1., 1., num_classes=num_classes, dropout=dropout, **kwargs) 287 | 288 | 289 | if __name__ == '__main__': 290 | from torchtoolbox.tools import summary 291 | 292 | model = GhostNet600() 293 | x = torch.rand(size=(1, 3, 224, 224)) 294 | summary(model, x) 295 | -------------------------------------------------------------------------------- /scripts/train_script.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | import argparse, time, logging, os 5 | import models 6 | import torch 7 | import warnings 8 | import apex 9 | from torch.utils.data import DistributedSampler 10 | from scripts.utils import get_model, set_model 11 | 12 | from torchtoolbox import metric 13 | from torchtoolbox.nn import LabelSmoothingLoss 14 | from torchtoolbox.optimizer import CosineWarmupLr, Lookahead 15 | from torchtoolbox.nn.init import KaimingInitializer 16 | from torchtoolbox.tools import no_decay_bias, \ 17 | mixup_data, mixup_criterion, check_dir 18 | from torchtoolbox.data import ImageLMDB 19 | 20 | from torchvision import transforms 21 | from torchvision.datasets import ImageNet 22 | from torch.utils.data import DataLoader 23 | from torch import nn 24 | from torch import optim 25 | from apex import amp 26 | 27 | parser = argparse.ArgumentParser(description='Train a model on 
ImageNet.') 28 | parser.add_argument('--data-path', type=str, required=True, 29 | help='training and validation dataset.') 30 | parser.add_argument('--use-lmdb', action='store_true', 31 | help='use LMDB dataset/format') 32 | parser.add_argument('--batch-size', type=int, default=32, 33 | help='training batch size per device (CPU/GPU).') 34 | parser.add_argument('--dtype', type=str, default='float32', 35 | help='data type for training. default is float32') 36 | parser.add_argument('--devices', type=str, default='0', 37 | help='gpus to use.') 38 | parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int, 39 | help='number of preprocessing workers') 40 | parser.add_argument('--epochs', type=int, default=1, 41 | help='number of training epochs.') 42 | parser.add_argument('--lr', type=float, default=0, 43 | help='learning rate. default is 0.') 44 | parser.add_argument('--momentum', type=float, default=0.9, 45 | help='momentum value for optimizer, default is 0.9.') 46 | parser.add_argument('--wd', type=float, default=0.0001, 47 | help='weight decay rate. default is 0.0001.') 48 | parser.add_argument('--dropout', type=float, default=0., 49 | help='model dropout rate.') 50 | parser.add_argument('--sync-bn', action='store_true', 51 | help='use Apex Sync-BN.') 52 | parser.add_argument('--lookahead', action='store_true', 53 | help='use lookahead optimizer.') 54 | parser.add_argument('--warmup-lr', type=float, default=0.0, 55 | help='starting warmup learning rate. default is 0.0.') 56 | parser.add_argument('--warmup-epochs', type=int, default=0, 57 | help='number of warmup epochs.') 58 | parser.add_argument('--model', type=str, required=True, 59 | help='type of model to use. see vision_model for options.') 60 | parser.add_argument('--alpha', type=float, default=0, 61 | help='model param.') 62 | parser.add_argument('--input-size', type=int, default=224, 63 | help='size of the input image size. default is 224') 64 | parser.add_argument('--crop-ratio', type=float, default=0.875, 65 | help='Crop ratio during validation. default is 0.875') 66 | parser.add_argument('--norm-layer', type=str, default='', 67 | help='Norm layer to use.') 68 | parser.add_argument('--activation', type=str, default='', 69 | help='activation to use.') 70 | parser.add_argument('--mixup', action='store_true', 71 | help='whether train the model with mix-up. default is false.') 72 | parser.add_argument('--mixup-alpha', type=float, default=0.2, 73 | help='beta distribution parameter for mixup sampling, default is 0.2.') 74 | parser.add_argument('--mixup-off-epoch', type=int, default=0, 75 | help='how many epochs to train without mixup, default is 0.') 76 | parser.add_argument('--label-smoothing', action='store_true', 77 | help='use label smoothing or not in training. 
default is false.') 78 | parser.add_argument('--no-wd', action='store_true', 79 | help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.') 80 | parser.add_argument('--save-dir', type=str, default='params', 81 | help='directory of saved models') 82 | parser.add_argument('--log-interval', type=int, default=50, 83 | help='Number of batches to wait before logging.') 84 | parser.add_argument('--logging-file', type=str, default='train_imagenet.log', 85 | help='name of training log file') 86 | parser.add_argument('--resume-epoch', type=int, default=0, 87 | help='epoch to resume training from.') 88 | parser.add_argument('--resume-param', type=str, default='', 89 | help='resume training param path.') 90 | parser.add_argument("--local_rank", default=0, type=int) 91 | args = parser.parse_args() 92 | 93 | filehandler = logging.FileHandler(args.logging_file) 94 | streamhandler = logging.StreamHandler() 95 | 96 | logger = logging.getLogger('') 97 | logger.setLevel(logging.INFO) 98 | logger.addHandler(filehandler) 99 | logger.addHandler(streamhandler) 100 | 101 | logger.info(args) 102 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 103 | 104 | torch.backends.cudnn.benchmark = True 105 | 106 | classes = 1000 107 | num_training_samples = 1281167 108 | 109 | check_dir(args.save_dir) 110 | assert torch.cuda.is_available(), \ 111 | "Please don't waste of your time,it's impossible to train on CPU." 112 | 113 | device = torch.device("cuda:0") 114 | device_ids = args.devices.strip().split(',') 115 | device_ids = [int(device) for device in device_ids] 116 | 117 | dtype = args.dtype 118 | epochs = args.epochs 119 | resume_epoch = args.resume_epoch 120 | num_workers = args.num_workers 121 | initializer = KaimingInitializer() 122 | batch_size = args.batch_size * len(device_ids) 123 | batches_pre_epoch = num_training_samples // batch_size 124 | lr = 0.1 * (args.batch_size // 32) if args.lr == 0 else args.lr 125 | 126 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 127 | std=[0.229, 0.224, 0.225]) 128 | 129 | train_transform = transforms.Compose([ 130 | transforms.RandomResizedCrop(224), 131 | # Cutout(), 132 | # transforms.RandomRotation(15), 133 | transforms.RandomHorizontalFlip(), 134 | transforms.ColorJitter(0.4, 0.4, 0.4), 135 | transforms.ToTensor(), 136 | normalize, 137 | ]) 138 | 139 | val_transform = transforms.Compose([ 140 | transforms.Resize(256), 141 | transforms.CenterCrop(224), 142 | transforms.ToTensor(), 143 | normalize, 144 | ]) 145 | 146 | torch.distributed.init_process_group(backend="nccl") 147 | if not args.use_lmdb: 148 | train_set = ImageNet(args.data_path, split='train', transform=train_transform) 149 | val_set = ImageNet(args.data_path, split='val', transform=val_transform) 150 | else: 151 | train_set = ImageLMDB(os.path.join(args.data_path, 'train.lmdb'), transform=train_transform) 152 | val_set = ImageLMDB(os.path.join(args.data_path, 'val.lmdb'), transform=val_transform) 153 | 154 | train_sampler = DistributedSampler(train_set) 155 | train_data = DataLoader(train_set, batch_size, False, pin_memory=True, num_workers=num_workers, drop_last=True, 156 | sampler=train_sampler) 157 | val_data = DataLoader(val_set, batch_size, False, pin_memory=True, num_workers=num_workers, drop_last=False) 158 | 159 | model_setting = set_model(args.dropout, args.norm_layer, args.activation) 160 | 161 | try: 162 | model = get_model(models, args.model, alpha=args.alpha, **model_setting) 163 | except TypeError: 164 | model = 
get_model(models, args.model, **model_setting) 165 | 166 | model.apply(initializer) 167 | model.to(device) 168 | parameters = model.parameters() if not args.no_wd else no_decay_bias(model) 169 | optimizer = optim.SGD(parameters, lr=lr, momentum=args.momentum, 170 | weight_decay=args.wd, nesterov=True) 171 | 172 | if args.sync_bn: 173 | logger.info('Use Apex Synced BN.') 174 | model = apex.parallel.convert_syncbn_model(model) 175 | 176 | if dtype == 'float16': 177 | logger.info('Train with FP16.') 178 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 179 | 180 | if args.lookahead: 181 | logger.info('Use lookahead optimizer.') 182 | optimizer = Lookahead(optimizer) 183 | 184 | model = nn.parallel.DistributedDataParallel(model) 185 | # model = nn.DataParallel(model) 186 | lr_scheduler = CosineWarmupLr(optimizer, batches_pre_epoch, epochs, 187 | base_lr=args.lr, warmup_epochs=args.warmup_epochs) 188 | if resume_epoch > 0: 189 | checkpoint = torch.load(args.resume_param) 190 | model.load_state_dict(checkpoint['model']) 191 | optimizer.load_state_dict(checkpoint['optimizer']) 192 | amp.load_state_dict(checkpoint['amp']) 193 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 194 | print("Finish loading resume param.") 195 | 196 | top1_acc = metric.Accuracy(name='Top1 Accuracy') 197 | top5_acc = metric.TopKAccuracy(top=5, name='Top5 Accuracy') 198 | loss_record = metric.NumericalCost(name='Loss') 199 | 200 | Loss = nn.CrossEntropyLoss() if not args.label_smoothing else \ 201 | LabelSmoothingLoss(classes, smoothing=0.1) 202 | 203 | 204 | @torch.no_grad() 205 | def test(epoch=0, save_status=True): 206 | top1_acc.reset() 207 | top5_acc.reset() 208 | loss_record.reset() 209 | model.eval() 210 | for data, labels in val_data: 211 | data = data.to(device, non_blocking=True) 212 | labels = labels.to(device, non_blocking=True) 213 | 214 | outputs = model(data) 215 | losses = Loss(outputs, labels) 216 | 217 | top1_acc.update(outputs, labels) 218 | top5_acc.update(outputs, labels) 219 | loss_record.update(losses) 220 | 221 | test_msg = 'Test Epoch {}: {}:{:.5}, {}:{:.5}, {}:{:.5}\n'.format( 222 | epoch, top1_acc.name, top1_acc.get(), top5_acc.name, top5_acc.get(), 223 | loss_record.name, loss_record.get()) 224 | logger.info(test_msg) 225 | if save_status: 226 | checkpoint = { 227 | 'model': model.state_dict(), 228 | 'optimizer': optimizer.state_dict(), 229 | 'amp': amp.state_dict(), 230 | 'lr_scheduler': lr_scheduler.state_dict(), 231 | } 232 | torch.save(checkpoint, '{}/{}_{}_{:.5}.pt'.format( 233 | args.save_dir, args.model, epoch, top1_acc.get())) 234 | 235 | 236 | def train(): 237 | for epoch in range(resume_epoch, epochs): 238 | train_sampler.set_epoch(epoch) 239 | top1_acc.reset() 240 | loss_record.reset() 241 | tic = time.time() 242 | 243 | model.train() 244 | for i, (data, labels) in enumerate(train_data): 245 | data = data.to(device, non_blocking=True) 246 | labels = labels.to(device, non_blocking=True) 247 | 248 | optimizer.zero_grad() 249 | outputs = model(data) 250 | loss = Loss(outputs, labels) 251 | 252 | with amp.scale_loss(loss, optimizer) as scaled_loss: 253 | scaled_loss.backward() 254 | optimizer.step() 255 | 256 | lr_scheduler.step() 257 | top1_acc.update(outputs, labels) 258 | loss_record.update(loss) 259 | 260 | if i % args.log_interval == 0 and i != 0: 261 | logger.info('Epoch {}, Iter {}, {}:{:.5}, {}:{:.5}, {} samples/s. 
lr: {:.5}.'.format( 262 | epoch, i, top1_acc.name, top1_acc.get(), 263 | loss_record.name, loss_record.get(), 264 | int((i * batch_size) // (time.time() - tic)), 265 | lr_scheduler.learning_rate 266 | )) 267 | 268 | train_speed = int(num_training_samples // (time.time() - tic)) 269 | epoch_msg = 'Train Epoch {}: {}:{:.5}, {}:{:.5}, {} samples/s.'.format( 270 | epoch, top1_acc.name, top1_acc.get(), loss_record.name, loss_record.get(), train_speed) 271 | logger.info(epoch_msg) 272 | test(epoch) 273 | 274 | 275 | def train_mixup(): 276 | mixup_off_epoch = epochs if args.mixup_off_epoch == 0 else args.mixup_off_epoch 277 | for epoch in range(resume_epoch, epochs): 278 | train_sampler.set_epoch(epoch) 279 | loss_record.reset() 280 | alpha = args.mixup_alpha if epoch < mixup_off_epoch else 0 281 | tic = time.time() 282 | 283 | model.train() 284 | for i, (data, labels) in enumerate(train_data): 285 | data = data.to(device, non_blocking=True) 286 | labels = labels.to(device, non_blocking=True) 287 | 288 | data, labels_a, labels_b, lam = mixup_data(data, labels, alpha) 289 | optimizer.zero_grad() 290 | outputs = model(data) 291 | loss = mixup_criterion(Loss, outputs, labels_a, labels_b, lam) 292 | 293 | with amp.scale_loss(loss, optimizer) as scaled_loss: 294 | scaled_loss.backward() 295 | optimizer.step() 296 | 297 | loss_record.update(loss) 298 | lr_scheduler.step() 299 | 300 | if i % args.log_interval == 0 and i != 0: 301 | logger.info('Epoch {}, Iter {}, {}:{:.5}, {} samples/s.'.format( 302 | epoch, i, loss_record.name, loss_record.get(), 303 | int((i * batch_size) // (time.time() - tic)) 304 | )) 305 | 306 | train_speed = int(num_training_samples // (time.time() - tic)) 307 | train_msg = 'Train Epoch {}: {}:{:.5}, {} samples/s, lr:{:.5}'.format( 308 | epoch, loss_record.name, loss_record.get(), 309 | train_speed, lr_scheduler.learning_rate) 310 | logger.info(train_msg) 311 | test(epoch) 312 | 313 | 314 | if __name__ == '__main__': 315 | if args.mixup: 316 | logger.info('Train using Mixup.') 317 | train_mixup() 318 | else: 319 | train() 320 | -------------------------------------------------------------------------------- /scripts/train_sample.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | import argparse, time, logging, os 5 | import models 6 | import torch 7 | import warnings 8 | import apex 9 | from scripts.utils import get_model, set_model 10 | 11 | from torchtoolbox import metric 12 | from torchtoolbox.nn import LabelSmoothingLoss 13 | from torchtoolbox.optimizer import CosineWarmupLr, Lookahead 14 | from torchtoolbox.nn.init import KaimingInitializer 15 | from torchtoolbox.tools import no_decay_bias, \ 16 | mixup_data, mixup_criterion, check_dir, summary 17 | 18 | from torchvision import transforms 19 | from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST, ImageFolder 20 | from torch.utils.data import DataLoader 21 | from torch import nn 22 | from torch.nn import functional as F 23 | from torch import optim 24 | from apex import amp 25 | 26 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 27 | 28 | parser = argparse.ArgumentParser(description='Train a model on ImageNet.') 29 | parser.add_argument('--data-path', type=str, required=True, 30 | help='training and validation dataset.') 31 | parser.add_argument('--dataset', type=str, default='mnist', 32 | help='Which dataset to train, default is mnist') 33 | parser.add_argument('--batch-size', type=int, 
default=32, 34 | help='training batch size per device (CPU/GPU).') 35 | parser.add_argument('--dtype', type=str, default='float32', 36 | help='data type for training. default is float32') 37 | parser.add_argument('--devices', type=str, default='0', 38 | help='gpus to use.') 39 | parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int, 40 | help='number of preprocessing workers') 41 | parser.add_argument('--epochs', type=int, default=1, 42 | help='number of training epochs.') 43 | parser.add_argument('--lr', type=float, default=0, 44 | help='learning rate. default is 0.') 45 | parser.add_argument('--momentum', type=float, default=0.9, 46 | help='momentum value for optimizer, default is 0.9.') 47 | parser.add_argument('--wd', type=float, default=0.0001, 48 | help='weight decay rate. default is 0.0001.') 49 | parser.add_argument('--dropout', type=float, default=0., 50 | help='model dropout rate.') 51 | parser.add_argument('--sync-bn', action='store_true', 52 | help='use Apex Sync-BN.') 53 | parser.add_argument('--lookahead', action='store_true', 54 | help='use lookahead optimizer.') 55 | parser.add_argument('--warmup-lr', type=float, default=0.0, 56 | help='starting warmup learning rate. default is 0.0.') 57 | parser.add_argument('--warmup-epochs', type=int, default=0, 58 | help='number of warmup epochs.') 59 | parser.add_argument('--model', type=str, required=True, 60 | help='type of model to use. see vision_model for options.') 61 | parser.add_argument('--alpha', type=float, default=0, 62 | help='model param.') 63 | parser.add_argument('--input-size', type=int, default=32, 64 | help='size of the input image size. default is 224') 65 | parser.add_argument('--padding', type=int, default=4, 66 | help='pad input image') 67 | parser.add_argument('--norm-layer', type=str, default='', 68 | help='Norm layer to use.') 69 | parser.add_argument('--activation', type=str, default='', 70 | help='activation to use.') 71 | parser.add_argument('--mixup', action='store_true', 72 | help='whether train the model with mix-up. default is false.') 73 | parser.add_argument('--mixup-alpha', type=float, default=0.2, 74 | help='beta distribution parameter for mixup sampling, default is 0.2.') 75 | parser.add_argument('--mixup-off-epoch', type=int, default=0, 76 | help='how many epochs to train without mixup, default is 0.') 77 | parser.add_argument('--label-smoothing', action='store_true', 78 | help='use label smoothing or not in training. 
default is false.') 79 | parser.add_argument('--no-wd', action='store_true', 80 | help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.') 81 | parser.add_argument('--save-dir', type=str, default='params', 82 | help='directory of saved models') 83 | parser.add_argument('--log-interval', type=int, default=50, 84 | help='Number of batches to wait before logging.') 85 | parser.add_argument('--logging-file', type=str, default='train_sample.log', 86 | help='name of training log file') 87 | parser.add_argument('--resume-epoch', type=int, default=0, 88 | help='epoch to resume training from.') 89 | parser.add_argument('--resume-param', type=str, default='', 90 | help='resume training param path.') 91 | parser.add_argument("--local_rank", default=0, type=int) 92 | args = parser.parse_args() 93 | 94 | 95 | def get_dataset(name): 96 | path = args.data_path 97 | download = not os.path.exists(path) or not os.listdir(path) 98 | if name == 'cifar10': 99 | train_ = CIFAR10(path, train=True, transform=train_transform, download=download) 100 | val_ = CIFAR10(path, train=False, transform=val_transform, download=download) 101 | elif name == 'cifar100': 102 | train_ = CIFAR100(path, train=True, transform=train_transform, download=download) 103 | val_ = CIFAR100(path, train=False, transform=val_transform, download=download) 104 | elif name == 'mnist': 105 | train_ = MNIST(path, train=True, transform=train_transform, download=download) 106 | val_ = MNIST(path, train=False, transform=val_transform, download=download) 107 | elif name == 'fashion_mnist': 108 | train_ = FashionMNIST(path, train=True, transform=train_transform, download=download) 109 | val_ = FashionMNIST(path, train=False, transform=val_transform, download=download) 110 | else: 111 | train_ = ImageFolder(os.path.join(path, 'train'), transform=train_transform) 112 | val_ = ImageFolder(os.path.join(path, 'val'), transform=val_transform) 113 | return train_, val_ 114 | 115 | 116 | filehandler = logging.FileHandler(args.logging_file) 117 | streamhandler = logging.StreamHandler() 118 | 119 | logger = logging.getLogger('') 120 | logger.setLevel(logging.INFO) 121 | logger.addHandler(filehandler) 122 | logger.addHandler(streamhandler) 123 | 124 | logger.info(args) 125 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 126 | 127 | torch.backends.cudnn.benchmark = True 128 | 129 | train_transform = transforms.Compose([ 130 | transforms.Pad(args.padding), 131 | transforms.RandomCrop(args.input_size), 132 | # Cutout(), 133 | transforms.RandomHorizontalFlip(), 134 | transforms.ToTensor(), 135 | ]) 136 | 137 | val_transform = transforms.ToTensor() 138 | train_set, val_set = get_dataset(args.dataset) 139 | 140 | classes = len(train_set.classes) 141 | num_training_samples = len(train_set) 142 | 143 | check_dir(args.save_dir) 144 | assert torch.cuda.is_available(), \ 145 | "Please don't waste of your time,it's impossible to train on CPU." 
146 | 147 | device = torch.device("cuda:0") 148 | device_ids = args.devices.strip().split(',') 149 | device_ids = [int(device) for device in device_ids] 150 | 151 | dtype = args.dtype 152 | epochs = args.epochs 153 | resume_epoch = args.resume_epoch 154 | num_workers = args.num_workers 155 | initializer = KaimingInitializer() 156 | batch_size = args.batch_size * len(device_ids) 157 | batches_pre_epoch = num_training_samples // batch_size 158 | lr = 0.1 * (args.batch_size // 32) if args.lr == 0 else args.lr 159 | 160 | train_data = DataLoader(train_set, batch_size, False, pin_memory=True, num_workers=num_workers, drop_last=True) 161 | val_data = DataLoader(val_set, batch_size, False, pin_memory=True, num_workers=num_workers, drop_last=False) 162 | 163 | model_setting = set_model(args.dropout, args.norm_layer, args.activation) 164 | 165 | try: 166 | model = get_model(models, args.model, alpha=args.alpha, small_input=True, 167 | return_feature=True, norm_feature=True, **model_setting) 168 | except TypeError: 169 | model = get_model(models, args.model, small_input=True, 170 | return_feature=True, norm_feature=True, **model_setting) 171 | 172 | summary(model, torch.rand((1, 3, args.input_size, args.input_size))) 173 | model.apply(initializer) 174 | model.to(device) 175 | parameters = model.parameters() if not args.no_wd else no_decay_bias(model) 176 | optimizer = optim.SGD(parameters, lr=lr, momentum=args.momentum, 177 | weight_decay=args.wd, nesterov=True) 178 | 179 | if args.sync_bn: 180 | logger.info('Use Apex Synced BN.') 181 | model = apex.parallel.convert_syncbn_model(model) 182 | 183 | if dtype == 'float16': 184 | logger.info('Train with FP16.') 185 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 186 | 187 | if args.lookahead: 188 | logger.info('Use lookahead optimizer.') 189 | optimizer = Lookahead(optimizer) 190 | 191 | model = nn.DataParallel(model) 192 | lr_scheduler = CosineWarmupLr(optimizer, batches_pre_epoch, epochs, 193 | base_lr=args.lr, warmup_epochs=args.warmup_epochs) 194 | if resume_epoch > 0: 195 | checkpoint = torch.load(args.resume_param) 196 | model.load_state_dict(checkpoint['model']) 197 | optimizer.load_state_dict(checkpoint['optimizer']) 198 | amp.load_state_dict(checkpoint['amp']) 199 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 200 | print("Finish loading resume param.") 201 | 202 | top1_acc = metric.Accuracy(name='Top1 Accuracy') 203 | top5_acc = metric.TopKAccuracy(top=5, name='Top5 Accuracy') 204 | loss_record = metric.NumericalCost(name='Loss') 205 | 206 | Loss = nn.CrossEntropyLoss() if not args.label_smoothing else \ 207 | LabelSmoothingLoss(classes, smoothing=0.1) 208 | 209 | 210 | @torch.no_grad() 211 | def test(epoch=0, save_status=False): 212 | top1_acc.reset() 213 | top5_acc.reset() 214 | loss_record.reset() 215 | model.eval() 216 | for data, labels in val_data: 217 | data = data.to(device, non_blocking=True) 218 | labels = labels.to(device, non_blocking=True) 219 | 220 | outputs = model(data) 221 | losses = Loss(outputs, labels) 222 | 223 | top1_acc.update(outputs, labels) 224 | top5_acc.update(outputs, labels) 225 | loss_record.update(losses) 226 | 227 | test_msg = 'Test Epoch {}: {}:{:.5}, {}:{:.5}, {}:{:.5}\n'.format( 228 | epoch, top1_acc.name, top1_acc.get(), top5_acc.name, top5_acc.get(), 229 | loss_record.name, loss_record.get()) 230 | logger.info(test_msg) 231 | if save_status: 232 | checkpoint = { 233 | 'model': model.state_dict(), 234 | 'optimizer': optimizer.state_dict(), 235 | 'amp': amp.state_dict(), 236 
| 'lr_scheduler': lr_scheduler.state_dict(), 237 | } 238 | torch.save(checkpoint, '{}/{}_{}_{:.5}.pt'.format( 239 | args.save_dir, args.model, epoch, top1_acc.get())) 240 | 241 | 242 | def train(): 243 | for epoch in range(resume_epoch, epochs): 244 | top1_acc.reset() 245 | loss_record.reset() 246 | tic = time.time() 247 | 248 | model.train() 249 | for i, (data, labels) in enumerate(train_data): 250 | data = data.to(device, non_blocking=True) 251 | labels = labels.to(device, non_blocking=True) 252 | 253 | optimizer.zero_grad() 254 | outputs = model(data) 255 | loss = Loss(outputs, labels) 256 | 257 | with amp.scale_loss(loss, optimizer) as scaled_loss: 258 | scaled_loss.backward() 259 | optimizer.step() 260 | 261 | lr_scheduler.step() 262 | top1_acc.update(outputs, labels) 263 | loss_record.update(loss) 264 | 265 | if i % args.log_interval == 0 and i != 0: 266 | logger.info('Epoch {}, Iter {}, {}:{:.5}, {}:{:.5}, {} samples/s. lr: {:.5}.'.format( 267 | epoch, i, top1_acc.name, top1_acc.get(), 268 | loss_record.name, loss_record.get(), 269 | int((i * batch_size) // (time.time() - tic)), 270 | lr_scheduler.learning_rate 271 | )) 272 | 273 | train_speed = int(num_training_samples // (time.time() - tic)) 274 | epoch_msg = 'Train Epoch {}: {}:{:.5}, {}:{:.5}, {} samples/s.'.format( 275 | epoch, top1_acc.name, top1_acc.get(), loss_record.name, loss_record.get(), train_speed) 276 | logger.info(epoch_msg) 277 | test(epoch) 278 | 279 | 280 | def train_mixup(): 281 | mixup_off_epoch = epochs if args.mixup_off_epoch == 0 else args.mixup_off_epoch 282 | for epoch in range(resume_epoch, epochs): 283 | loss_record.reset() 284 | alpha = args.mixup_alpha if epoch < mixup_off_epoch else 0 285 | tic = time.time() 286 | 287 | model.train() 288 | for i, (data, labels) in enumerate(train_data): 289 | data = data.to(device, non_blocking=True) 290 | labels = labels.to(device, non_blocking=True) 291 | 292 | data, labels_a, labels_b, lam = mixup_data(data, labels, alpha) 293 | optimizer.zero_grad() 294 | outputs = model(data) 295 | loss = mixup_criterion(Loss, outputs, labels_a, labels_b, lam) 296 | 297 | with amp.scale_loss(loss, optimizer) as scaled_loss: 298 | scaled_loss.backward() 299 | optimizer.step() 300 | 301 | loss_record.update(loss) 302 | lr_scheduler.step() 303 | 304 | if i % args.log_interval == 0 and i != 0: 305 | logger.info('Epoch {}, Iter {}, {}:{:.5}, {} samples/s.'.format( 306 | epoch, i, loss_record.name, loss_record.get(), 307 | int((i * batch_size) // (time.time() - tic)) 308 | )) 309 | 310 | train_speed = int(num_training_samples // (time.time() - tic)) 311 | train_msg = 'Train Epoch {}: {}:{:.5}, {} samples/s, lr:{:.5}'.format( 312 | epoch, loss_record.name, loss_record.get(), 313 | train_speed, lr_scheduler.learning_rate) 314 | logger.info(train_msg) 315 | test(epoch) 316 | 317 | 318 | if __name__ == '__main__': 319 | if args.mixup: 320 | logger.info('Train using Mixup.') 321 | train_mixup() 322 | else: 323 | train() 324 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | import torch.nn as nn 5 | from torchtoolbox.nn import Activation 6 | from module.dropblock import DropBlock2d 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | 'ResNetV2', 'resnet18v2', 'resnet34v2', 'resnet50v2', 
'resnet101v2', 11 | 'resnet152v2', 'resnext50v2_32x4d', 'resnext101v2_32x8d'] 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, groups=groups, bias=False) 18 | 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 29 | base_width=64, norm_layer=None, activation=None): 30 | super(BasicBlock, self).__init__() 31 | if norm_layer is None: 32 | norm_layer = nn.BatchNorm2d 33 | if groups != 1 or base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.bn1 = norm_layer(planes) 38 | self.act = Activation(activation, auto_optimize=True) 39 | self.conv2 = conv3x3(planes, planes) 40 | self.bn2 = norm_layer(planes) 41 | self.downsample = nn.Identity() if downsample is None else downsample 42 | self.stride = stride 43 | 44 | def forward(self, x): 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.act(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | identity = self.downsample(x) 53 | 54 | out += identity 55 | out = self.act(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 64 | base_width=64, norm_layer=None, activation=None, drop_block=False, 65 | drop_prob=0.1, block_size=7): 66 | super(Bottleneck, self).__init__() 67 | if norm_layer is None: 68 | norm_layer = nn.BatchNorm2d 69 | width = int(planes * (base_width / 64.)) * groups 70 | 71 | self.conv1 = conv1x1(inplanes, width) 72 | self.bn1 = norm_layer(width) 73 | self.conv2 = conv3x3(width, width, stride, groups) 74 | self.bn2 = norm_layer(width) 75 | self.conv3 = conv1x1(width, planes * self.expansion) 76 | self.bn3 = norm_layer(planes * self.expansion) 77 | self.act = Activation(activation, auto_optimize=True) 78 | self.downsample = nn.Identity() if downsample is None else downsample 79 | self.dropblock = DropBlock2d(drop_prob, block_size) if drop_block else nn.Identity() 80 | 81 | def forward(self, x): 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.dropblock(out) 85 | out = self.act(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.dropblock(out) 90 | out = self.act(out) 91 | 92 | out = self.conv3(out) 93 | out = self.bn3(out) 94 | 95 | identity = self.downsample(x) 96 | 97 | out += identity 98 | out = self.dropblock(out) 99 | out = self.act(out) 100 | 101 | return out 102 | 103 | 104 | class BasicBlockV2(nn.Module): 105 | expansion = 1 106 | 107 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 108 | base_width=64, norm_layer=None, activation=None): 109 | super(BasicBlockV2, self).__init__() 110 | if norm_layer is None: 111 | norm_layer = nn.BatchNorm2d 112 | if groups != 1 or base_width != 64: 113 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 114 | 115 | self.bn1 = norm_layer(inplanes) 116 | self.conv1 = conv3x3(inplanes, planes, stride) 117 | self.bn2 = norm_layer(planes) 118 | self.conv2 = conv3x3(planes, planes) 119 | self.downsample = downsample 120 | self.act = 
Activation(activation, auto_optimize=True) 121 | 122 | def forward(self, x): 123 | identity = x 124 | out = self.bn1(x) 125 | out = self.act(out) 126 | if self.downsample is not None: 127 | identity = self.downsample(out) 128 | out = self.conv1(out) 129 | 130 | out = self.bn2(out) 131 | out = self.act(out) 132 | out = self.conv2(out) 133 | 134 | out += identity 135 | 136 | return out 137 | 138 | 139 | class BottleneckV2(nn.Module): 140 | expansion = 4 141 | 142 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 143 | base_width=64, norm_layer=None, activation=None): 144 | super(BottleneckV2, self).__init__() 145 | if norm_layer is None: 146 | norm_layer = nn.BatchNorm2d 147 | width = int(planes * (base_width / 64.)) * groups 148 | 149 | self.bn1 = norm_layer(inplanes) 150 | self.conv1 = conv1x1(inplanes, width) 151 | self.bn2 = norm_layer(width) 152 | self.conv2 = conv3x3(width, width, stride, groups) 153 | self.bn3 = norm_layer(width) 154 | self.conv3 = conv1x1(width, planes * self.expansion) 155 | self.act = Activation(activation, auto_optimize=True) 156 | 157 | self.downsample = downsample 158 | 159 | def forward(self, x): 160 | identity = x 161 | out = self.bn1(x) 162 | out = self.act(out) 163 | if self.downsample is not None: 164 | identity = self.downsample(out) 165 | out = self.conv1(out) 166 | 167 | out = self.bn2(out) 168 | out = self.act(out) 169 | out = self.conv2(out) 170 | 171 | out = self.bn3(out) 172 | out = self.act(out) 173 | out = self.conv3(out) 174 | 175 | out += identity 176 | return out 177 | 178 | 179 | class ResNet(nn.Module): 180 | def __init__(self, block, layers, num_classes=1000, groups=1, width_per_group=64, 181 | norm_layer=None, activation='relu', dropout_rate=None, small_input=False, 182 | drop_block=True): 183 | super(ResNet, self).__init__() 184 | if norm_layer is None: 185 | norm_layer = nn.BatchNorm2d 186 | self._norm_layer = norm_layer 187 | self._activation = activation 188 | 189 | self.inplanes = 64 190 | 191 | self.groups = groups 192 | self.base_width = width_per_group 193 | if small_input: 194 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 195 | bias=False) 196 | else: 197 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 198 | bias=False) 199 | self.bn1 = norm_layer(self.inplanes) 200 | self.act = Activation(activation, auto_optimize=True) 201 | if small_input: 202 | self.maxpool = nn.Identity() 203 | else: 204 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 205 | self.layer1 = self._make_layer(block, 64, layers[0]) 206 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 207 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_block=drop_block) 208 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_block=drop_block) 209 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 210 | self.flatten = nn.Flatten() 211 | self.dropout = nn.Dropout(dropout_rate, inplace=True) if dropout_rate is not None else nn.Identity() 212 | self.fc = nn.Linear(512 * block.expansion, num_classes) 213 | 214 | def _make_layer(self, block, planes, blocks, stride=1, drop_block=False): 215 | norm_layer = self._norm_layer 216 | downsample = None 217 | 218 | if stride != 1 or self.inplanes != planes * block.expansion: 219 | downsample = nn.Sequential( 220 | conv1x1(self.inplanes, planes * block.expansion, stride), 221 | norm_layer(planes * block.expansion), 222 | ) 223 | 224 | layers = [] 225 | layers.append(block(self.inplanes, 
planes, stride, downsample, self.groups, 226 | self.base_width, norm_layer, self._activation, drop_block)) 227 | self.inplanes = planes * block.expansion 228 | for _ in range(1, blocks): 229 | layers.append(block(self.inplanes, planes, groups=self.groups, 230 | base_width=self.base_width, norm_layer=norm_layer, 231 | activation=self._activation, drop_block=drop_block)) 232 | 233 | return nn.Sequential(*layers) 234 | 235 | def forward(self, x): 236 | x = self.conv1(x) 237 | x = self.bn1(x) 238 | x = self.act(x) 239 | x = self.maxpool(x) 240 | 241 | x = self.layer1(x) 242 | x = self.layer2(x) 243 | x = self.layer3(x) 244 | x = self.layer4(x) 245 | 246 | x = self.avgpool(x) 247 | x = self.flatten(x) 248 | x = self.dropout(x) 249 | x = self.fc(x) 250 | 251 | return x 252 | 253 | 254 | class ResNetV2(nn.Module): 255 | def __init__(self, block, layers, num_classes=1000, groups=1, width_per_group=64, 256 | norm_layer=None, activation='relu', dropout_rate=None, small_input=False): 257 | super(ResNetV2, self).__init__() 258 | if norm_layer is None: 259 | norm_layer = nn.BatchNorm2d 260 | 261 | self._norm_layer = norm_layer 262 | self._activation = activation 263 | 264 | self.inplanes = 64 265 | 266 | self.groups = groups 267 | self.base_width = width_per_group 268 | if small_input: 269 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 270 | bias=False) 271 | else: 272 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 273 | bias=False) 274 | self.bn1 = norm_layer(self.inplanes) 275 | self.act = Activation(activation, auto_optimize=True) 276 | 277 | if small_input: 278 | self.maxpool = nn.Identity() 279 | else: 280 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 281 | self.layer1 = self._make_layer(block, 64, layers[0]) 282 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 283 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 284 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 285 | self.bn_last = norm_layer(512 * block.expansion) 286 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 287 | self.flatten = nn.Flatten() 288 | self.dropout = nn.Dropout(dropout_rate, inplace=True) if dropout_rate is not None else nn.Identity() 289 | self.fc = nn.Linear(512 * block.expansion, num_classes) 290 | 291 | def _make_layer(self, block, planes, blocks, stride=1): 292 | norm_layer = self._norm_layer 293 | downsample = None 294 | 295 | if stride != 1 or self.inplanes != planes * block.expansion: 296 | downsample = nn.Sequential( 297 | conv1x1(self.inplanes, planes * block.expansion, stride), 298 | ) 299 | 300 | layers = [] 301 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 302 | self.base_width, norm_layer, self._activation)) 303 | self.inplanes = planes * block.expansion 304 | for _ in range(1, blocks): 305 | layers.append(block(self.inplanes, planes, groups=self.groups, 306 | base_width=self.base_width, norm_layer=norm_layer, 307 | activation=self._activation)) 308 | 309 | return nn.Sequential(*layers) 310 | 311 | def forward(self, x): 312 | x = self.conv1(x) 313 | x = self.bn1(x) 314 | x = self.act(x) 315 | x = self.maxpool(x) 316 | 317 | x = self.layer1(x) 318 | x = self.layer2(x) 319 | x = self.layer3(x) 320 | x = self.layer4(x) 321 | x = self.bn_last(x) 322 | x = self.act(x) 323 | 324 | x = self.avgpool(x) 325 | x = self.flatten(x) 326 | x = self.dropout(x) 327 | x = self.fc(x) 328 | 329 | return x 330 | 331 | 332 | def _resnet(block, layers, version=1, 
**kwargs): 333 | assert version in (1, 2) 334 | 335 | if version == 1: 336 | model = ResNet(block, layers, **kwargs) 337 | else: 338 | model = ResNetV2(block, layers, **kwargs) 339 | return model 340 | 341 | 342 | def resnet18(**kwargs): 343 | """Constructs a ResNet-18 model. 344 | 345 | """ 346 | return _resnet(BasicBlock, [2, 2, 2, 2], **kwargs) 347 | 348 | 349 | def resnet18v2(**kwargs): 350 | """Constructs a ResNet-18v2 model. 351 | 352 | """ 353 | return _resnet(BasicBlockV2, [2, 2, 2, 2], 2, **kwargs) 354 | 355 | 356 | def resnet34(**kwargs): 357 | """Constructs a ResNet-34 model. 358 | 359 | """ 360 | return _resnet(BasicBlock, [3, 4, 6, 3], **kwargs) 361 | 362 | 363 | def resnet34v2(**kwargs): 364 | """Constructs a ResNet-34v2 model. 365 | 366 | """ 367 | return _resnet(BasicBlockV2, [3, 4, 6, 3], 2, **kwargs) 368 | 369 | 370 | def resnet50(**kwargs): 371 | """Constructs a ResNet-50 model. 372 | 373 | """ 374 | return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs) 375 | 376 | 377 | def resnet50v2(**kwargs): 378 | """Constructs a ResNet-50v2 model. 379 | 380 | """ 381 | return _resnet(BottleneckV2, [3, 4, 6, 3], 2, **kwargs) 382 | 383 | 384 | def resnet101(**kwargs): 385 | """Constructs a ResNet-101 model. 386 | 387 | """ 388 | return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs) 389 | 390 | 391 | def resnet101v2(**kwargs): 392 | """Constructs a ResNet-101v2 model. 393 | 394 | """ 395 | return _resnet(BottleneckV2, [3, 4, 23, 3], 2, **kwargs) 396 | 397 | 398 | def resnet152(**kwargs): 399 | """Constructs a ResNet-152 model. 400 | 401 | """ 402 | return _resnet(Bottleneck, [3, 8, 36, 3], **kwargs) 403 | 404 | 405 | def resnet152v2(**kwargs): 406 | """Constructs a ResNet-152v2 model. 407 | 408 | """ 409 | return _resnet(BottleneckV2, [3, 8, 36, 3], 2, **kwargs) 410 | 411 | 412 | def resnext50_32x4d(**kwargs): 413 | """Constructs a ResNeXt-50 32x4d model. 414 | 415 | """ 416 | kwargs['groups'] = 32 417 | kwargs['width_per_group'] = 4 418 | return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs) 419 | 420 | 421 | def resnext50v2_32x4d(**kwargs): 422 | """Constructs a ResNeXt-50v2 32x4d model. 423 | 424 | """ 425 | kwargs['groups'] = 32 426 | kwargs['width_per_group'] = 4 427 | return _resnet(BottleneckV2, [3, 4, 6, 3], 2, **kwargs) 428 | 429 | 430 | def resnext101_32x8d(**kwargs): 431 | """Constructs a ResNeXt-101 32x8d model. 432 | 433 | """ 434 | kwargs['groups'] = 32 435 | kwargs['width_per_group'] = 8 436 | return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs) 437 | 438 | 439 | def resnext101v2_32x8d(**kwargs): 440 | """Constructs a ResNeXt-101v2 32x8d model.
441 | 442 | """ 443 | kwargs['groups'] = 32 444 | kwargs['width_per_group'] = 8 445 | return _resnet(BottleneckV2, [3, 4, 23, 3], 2, **kwargs) 446 | -------------------------------------------------------------------------------- /scripts/distribute_train_script.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : DevinYang(pistonyang@gmail.com) 3 | 4 | import argparse 5 | import time 6 | import os 7 | import models 8 | import torch 9 | import warnings 10 | 11 | from scripts.utils import get_logger, get_model 12 | from torchtoolbox import metric 13 | from torchtoolbox.nn import LabelSmoothingLoss 14 | from torchtoolbox.optimizer import CosineWarmupLr, Lookahead 15 | from torchtoolbox.optimizer.sgd_gc import SGD_GC 16 | from torchtoolbox.nn.init import KaimingInitializer 17 | from torchtoolbox.tools import no_decay_bias, \ 18 | mixup_data, mixup_criterion, check_dir, summary 19 | from torchtoolbox.transform import Cutout, ImageNetPolicy, \ 20 | RandAugment 21 | 22 | from torchvision.datasets import ImageNet 23 | from torch import multiprocessing as mp 24 | from torch.utils.data import DataLoader 25 | from torch.utils.data import DistributedSampler 26 | from torchvision import transforms 27 | from torch import nn 28 | from torch import optim 29 | from torch.cuda.amp import autocast, GradScaler 30 | 31 | # from module.aa import ImageNetPolicy 32 | 33 | # from module.dropblock import DropBlockScheduler 34 | 35 | parser = argparse.ArgumentParser(description='Train a model on ImageNet.') 36 | parser.add_argument('--data-path', type=str, required=True, 37 | help='training and validation dataset.') 38 | parser.add_argument('--batch-size', type=int, default=32, 39 | help='training batch size per device (CPU/GPU).') 40 | parser.add_argument('--dtype', type=str, default='float32', 41 | help='data type for training. default is float32') 42 | parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int, 43 | help='number of preprocessing workers') 44 | parser.add_argument('--epochs', type=int, default=1, 45 | help='number of training epochs.') 46 | parser.add_argument('--lr', type=float, default=0, 47 | help='learning rate. default is 0.') 48 | parser.add_argument('--momentum', type=float, default=0.9, 49 | help='momentum value for optimizer, default is 0.9.') 50 | parser.add_argument('--wd', type=float, default=0.0001, 51 | help='weight decay rate. default is 0.0001.') 52 | parser.add_argument('--dropout', type=float, default=0., 53 | help='model dropout rate.') 54 | parser.add_argument('--lookahead', action='store_true', 55 | help='use lookahead optimizer.') 56 | parser.add_argument('--warmup-lr', type=float, default=0.0, 57 | help='starting warmup learning rate. default is 0.0.') 58 | parser.add_argument('--warmup-epochs', type=int, default=0, 59 | help='number of warmup epochs.') 60 | parser.add_argument('--model', type=str, required=True, 61 | help='type of model to use. see vision_model for options.') 62 | parser.add_argument('--alpha', type=float, default=0, 63 | help='model param.') 64 | parser.add_argument('--input-size', type=int, default=224, 65 | help='size of the input image size. default is 224') 66 | parser.add_argument('--crop-ratio', type=float, default=0.875, 67 | help='Crop ratio during validation. default is 0.875') 68 | parser.add_argument('--mixup', action='store_true', 69 | help='whether train the model with mix-up.
default is false.') 70 | parser.add_argument('--mixup-alpha', type=float, default=0.2, 71 | help='beta distribution parameter for mixup sampling, default is 0.2.') 72 | parser.add_argument('--label-smoothing', action='store_true', 73 | help='use label smoothing in training. default is false.') 74 | parser.add_argument('--no-wd', action='store_true', 75 | help='whether to remove weight decay on bias and on beta/gamma of batchnorm layers.') 76 | parser.add_argument('--last-gamma', action='store_true', 77 | help='initialize the last BN gamma in each Bottleneck block to zero.') 78 | parser.add_argument('--sgd-gc', action='store_true', 79 | help='use the SGD_GC optimizer (SGD with Gradient Centralization).') 80 | parser.add_argument('--transform', type=str, default='normal', 81 | help='training transform: normal, aa (AutoAugment) or ra (RandAugment).') 82 | parser.add_argument('--drop-block', action='store_true', 83 | help='use DropBlock.') 84 | parser.add_argument('--save-dir', type=str, default='params', 85 | help='directory for saved models.') 86 | parser.add_argument('--model-info', action='store_true', 87 | help='show model information.') 88 | parser.add_argument('--log-interval', type=int, default=50, 89 | help='number of batches to wait before logging.') 90 | parser.add_argument('--logging-file', type=str, default='distribute_train_imagenet.log', 91 | help='name of the training log file.') 92 | parser.add_argument('--resume-epoch', type=int, default=0, 93 | help='epoch to resume training from.') 94 | parser.add_argument('--resume-param', type=str, default='', 95 | help='path to the checkpoint to resume from.') 96 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:26548', type=str, 97 | help='url used to set up distributed training.') 98 | parser.add_argument('--rank', required=True, type=int, 99 | help='node rank for distributed training.') 100 | parser.add_argument('--world-size', required=True, type=int, 101 | help='number of nodes for distributed training.') 102 | # Enabled by default 103 | # parser.add_argument('--multiprocessing-distributed', action='store_true', 104 | # help='Use multi-processing distributed training to launch ' 105 | # 'N processes per node, which has N GPUs. This is the ' 106 | # 'fastest way to use PyTorch for either single node or ' 107 | # 'multi node data parallel training') 108 | 109 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 110 | assert torch.cuda.is_available(), \ 111 | "A CUDA device is required; training on CPU is not supported."
112 | 113 | 114 | class ZeroLastGamma(object): 115 | def __init__(self, block_name='Bottleneck', bn_name='bn3'): 116 | self.block_name = block_name 117 | self.bn_name = bn_name 118 | 119 | def __call__(self, module): 120 | if module.__class__.__name__ == self.block_name: 121 | target_bn = getattr(module, self.bn_name) 122 | nn.init.zeros_(target_bn.weight)  # start each residual branch as identity 123 | 124 | 125 | def main(): 126 | args = parser.parse_args() 127 | logger = get_logger(args.logging_file) 128 | logger.info(args) 129 | args.save_dir = os.path.join(os.getcwd(), args.save_dir) 130 | check_dir(args.save_dir) 131 | 132 | assert args.world_size >= 1 133 | 134 | args.classes = 1000 135 | args.num_training_samples = 1281167 136 | args.world = args.rank  # keep the node index for logging before rank is expanded per GPU 137 | ngpus_per_node = torch.cuda.device_count() 138 | args.world_size = ngpus_per_node * args.world_size 139 | args.mix_precision_training = (args.dtype == 'float16') 140 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 141 | 142 | 143 | def main_worker(gpu, ngpus_per_node, args): 144 | args.gpu = gpu 145 | logger = get_logger(args.logging_file) 146 | logger.info("Use GPU: {} for training".format(args.gpu)) 147 | 148 | args.rank = args.rank * ngpus_per_node + gpu 149 | torch.distributed.init_process_group(backend="nccl", init_method=args.dist_url, 150 | world_size=args.world_size, rank=args.rank) 151 | 152 | epochs = args.epochs 153 | input_size = args.input_size 154 | resume_epoch = args.resume_epoch 155 | initializer = KaimingInitializer() 156 | zero_gamma = ZeroLastGamma() 157 | is_first_rank = (args.rank % ngpus_per_node == 0) 158 | 159 | batches_per_epoch = args.num_training_samples // (args.batch_size * ngpus_per_node) 160 | lr = 0.1 * (args.batch_size * ngpus_per_node // 32) if args.lr == 0 else args.lr 161 | 162 | model = get_model(models, args.model) 163 | 164 | model.apply(initializer) 165 | if args.last_gamma: 166 | model.apply(zero_gamma) 167 | logger.info('Apply zero last gamma init.') 168 | 169 | if is_first_rank and args.model_info: 170 | summary(model, torch.rand((1, 3, input_size, input_size))) 171 | 172 | parameters = model.parameters() if not args.no_wd else no_decay_bias(model) 173 | if args.sgd_gc: 174 | logger.info('Use SGD_GC optimizer.') 175 | optimizer = SGD_GC(parameters, lr=lr, momentum=args.momentum, 176 | weight_decay=args.wd, nesterov=True) 177 | else: 178 | optimizer = optim.SGD(parameters, lr=lr, momentum=args.momentum, 179 | weight_decay=args.wd, nesterov=True) 180 | 181 | lr_scheduler = CosineWarmupLr(optimizer, batches_per_epoch, epochs, 182 | base_lr=lr, warmup_epochs=args.warmup_epochs) 183 | 184 | # dropblock_scheduler = DropBlockScheduler(model, batches_per_epoch, epochs) 185 | 186 | if args.lookahead: 187 | optimizer = Lookahead(optimizer) 188 | logger.info('Use lookahead optimizer.') 189 | 190 | torch.cuda.set_device(args.gpu) 191 | model.cuda(args.gpu) 192 | args.num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node) 193 | 194 | if args.mix_precision_training and is_first_rank: 195 | logger.info('Train with FP16.') 196 | 197 | scaler = GradScaler(enabled=args.mix_precision_training) 198 | model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 199 | 200 | Loss = nn.CrossEntropyLoss().cuda(args.gpu) if not args.label_smoothing else \ 201 | LabelSmoothingLoss(args.classes, smoothing=0.1).cuda(args.gpu) 202 | 203 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 204 | std=[0.229, 0.224, 0.225]) 205 | 206 | if 
args.transform == 'aa': 207 | if is_first_rank: 208 | print('Using AutoAugment transform.') 209 | train_transform = transforms.Compose([ 210 | transforms.RandomResizedCrop(input_size), 211 | transforms.RandomHorizontalFlip(), 212 | ImageNetPolicy(), 213 | transforms.ToTensor(), 214 | normalize, 215 | ]) 216 | elif args.transform == 'ra': 217 | if is_first_rank: 218 | print('Using RandAugment transform.') 219 | train_transform = transforms.Compose([ 220 | transforms.RandomResizedCrop(input_size), 221 | transforms.RandomHorizontalFlip(), 222 | RandAugment(n=2, m=9), 223 | transforms.ToTensor(), 224 | normalize, 225 | ]) 226 | else: 227 | train_transform = transforms.Compose([ 228 | transforms.RandomResizedCrop(input_size), 229 | # Cutout(), 230 | transforms.RandomHorizontalFlip(), 231 | transforms.ColorJitter(0.4, 0.4, 0.4), 232 | transforms.ToTensor(), 233 | normalize, 234 | ]) 235 | 236 | val_transform = transforms.Compose([ 237 | transforms.Resize(int(input_size / args.crop_ratio)), 238 | transforms.CenterCrop(input_size), 239 | transforms.ToTensor(), 240 | normalize, 241 | ]) 242 | 243 | train_set = ImageNet(args.data_path, split='train', transform=train_transform) 244 | val_set = ImageNet(args.data_path, split='val', transform=val_transform) 245 | 246 | train_sampler = DistributedSampler(train_set) 247 | train_loader = DataLoader(train_set, args.batch_size, False, pin_memory=True, 248 | num_workers=args.num_workers, drop_last=True, sampler=train_sampler) 249 | val_loader = DataLoader(val_set, args.batch_size, False, pin_memory=True, num_workers=args.num_workers, 250 | drop_last=False) 251 | 252 | if resume_epoch > 0: 253 | loc = 'cuda:{}'.format(args.gpu) 254 | checkpoint = torch.load(args.resume_param, map_location=loc) 255 | model.load_state_dict(checkpoint['model']) 256 | optimizer.load_state_dict(checkpoint['optimizer']) 257 | scaler.load_state_dict(checkpoint['scaler']) 258 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 259 | print("Finished loading resume parameters.") 260 | 261 | torch.backends.cudnn.benchmark = True 262 | 263 | top1_acc = metric.Accuracy(name='Top1 Accuracy') 264 | top5_acc = metric.TopKAccuracy(top=5, name='Top5 Accuracy') 265 | loss_record = metric.NumericalCost(name='Loss') 266 | 267 | for epoch in range(resume_epoch, epochs): 268 | tic = time.time() 269 | train_sampler.set_epoch(epoch) 270 | if not args.mixup: 271 | train_one_epoch(model, train_loader, Loss, optimizer, epoch, lr_scheduler, 272 | logger, top1_acc, loss_record, scaler, args) 273 | else: 274 | train_one_epoch_mixup(model, train_loader, Loss, optimizer, epoch, lr_scheduler, 275 | logger, loss_record, scaler, args) 276 | train_speed = int(args.num_training_samples // (time.time() - tic)) 277 | if is_first_rank: 278 | logger.info('Epoch finished, speed: {} samples/s'.format(train_speed)) 279 | test(model, val_loader, Loss, epoch, logger, top1_acc, top5_acc, loss_record, args) 280 | 281 | if args.rank % ngpus_per_node == 0: 282 | checkpoint = { 283 | 'model': model.state_dict(), 284 | 'optimizer': optimizer.state_dict(), 285 | 'scaler': scaler.state_dict(), 286 | 'lr_scheduler': lr_scheduler.state_dict(), 287 | } 288 | torch.save(checkpoint, '{}/{}_{}_{:.5}.pt'.format( 289 | args.save_dir, args.model, epoch, top1_acc.get())) 290 | 291 | 292 | @torch.no_grad() 293 | def test(model, val_loader, criterion, epoch, logger, top1_acc, top5_acc, loss_record, args): 294 | top1_acc.reset() 295 | top5_acc.reset() 296 | loss_record.reset() 297 | 298 | model.eval() 299 | for data, labels in val_loader: 300 | data = 
data.cuda(args.gpu, non_blocking=True) 301 | labels = labels.cuda(args.gpu, non_blocking=True) 302 | 303 | outputs = model(data) 304 | losses = criterion(outputs, labels) 305 | 306 | top1_acc.update(outputs, labels) 307 | top5_acc.update(outputs, labels) 308 | loss_record.update(losses) 309 | 310 | test_msg = 'Test Epoch {}, Node {}, GPU {}: {}:{:.5}, {}:{:.5}, {}:{:.5}'.format( 311 | epoch, args.world, args.gpu, top1_acc.name, top1_acc.get(), top5_acc.name, 312 | top5_acc.get(), loss_record.name, loss_record.get()) 313 | logger.info(test_msg) 314 | 315 | 316 | def train_one_epoch(model, train_loader, criterion, optimizer, epoch, lr_scheduler, 317 | logger, top1_acc, loss_record, scaler, args): 318 | top1_acc.reset() 319 | loss_record.reset() 320 | tic = time.time() 321 | 322 | model.train() 323 | for i, (data, labels) in enumerate(train_loader): 324 | data = data.cuda(args.gpu, non_blocking=True) 325 | labels = labels.cuda(args.gpu, non_blocking=True) 326 | 327 | optimizer.zero_grad() 328 | with autocast(enabled=args.mix_precision_training): 329 | outputs = model(data) 330 | loss = criterion(outputs, labels) 331 | scaler.scale(loss).backward() 332 | scaler.step(optimizer) 333 | scaler.update() 334 | lr_scheduler.step() 335 | 336 | top1_acc.update(outputs, labels) 337 | loss_record.update(loss) 338 | 339 | if i % args.log_interval == 0 and i != 0: 340 | logger.info('Epoch {}, Node {}, GPU {}, Iter {}, {}:{:.5}, {}:{:.5}, {} samples/s. lr: {:.5}.'.format( 341 | epoch, args.world, args.gpu, i, top1_acc.name, top1_acc.get(), 342 | loss_record.name, loss_record.get(), 343 | int((i * args.batch_size) // (time.time() - tic)), 344 | lr_scheduler.learning_rate 345 | )) 346 | 347 | 348 | def train_one_epoch_mixup(model, train_loader, criterion, optimizer, epoch, lr_scheduler, 349 | logger, loss_record, scaler, args): 350 | loss_record.reset() 351 | tic = time.time() 352 | 353 | model.train() 354 | for i, (data, labels) in enumerate(train_loader): 355 | data = data.cuda(args.gpu, non_blocking=True) 356 | labels = labels.cuda(args.gpu, non_blocking=True) 357 | 358 | data, labels_a, labels_b, lam = mixup_data(data, labels, args.mixup_alpha) 359 | optimizer.zero_grad() 360 | with autocast(args.mix_precision_training): 361 | outputs = model(data) 362 | loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam) 363 | scaler.scale(loss).backward() 364 | scaler.step(optimizer) 365 | scaler.update() 366 | 367 | loss_record.update(loss) 368 | lr_scheduler.step() 369 | 370 | if i % args.log_interval == 0 and i != 0: 371 | logger.info('Epoch {}, Node {}, GPU {}, Iter {}, {}:{:.5}, {} samples/s, lr: {:.5}.'.format( 372 | epoch, args.world, args.gpu, i, loss_record.name, loss_record.get(), 373 | int((i * args.batch_size) // (time.time() - tic)), 374 | lr_scheduler.learning_rate)) 375 | 376 | 377 | if __name__ == '__main__': 378 | main() 379 | --------------------------------------------------------------------------------
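Note: as a quick sanity check of the v2 (pre-activation) ResNet constructors defined in models/resnet.py, a minimal smoke test in the style of test/test_oct_resnet.py could look like the sketch below. It is illustrative only (not a file in the repository) and assumes the default constructor arguments build a standard 1000-class ImageNet head.

from models.resnet import resnet50v2, resnext101v2_32x8d
import torch

# Build a pre-activation ResNet-50 and run a dummy forward pass.
model = resnet50v2()
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.size())  # torch.Size([1, 1000]) with the assumed default ImageNet head

# The ResNeXt v2 variant builds the same way; groups/width are set inside the constructor.
model = resnext101v2_32x8d()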