├── Knowledge_Distillation.md ├── LICENSE ├── README.md ├── config ├── __pycache__ │ ├── config.cpython-38.pyc │ └── sgd_config.cpython-38.pyc ├── config.py └── sgd_config.py ├── export.py ├── figure ├── base_lr.png ├── custom_lr.png ├── warmup_custom_lr.png └── warmup_lr.png ├── main.py ├── metrice.py ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── convnext.cpython-38.pyc │ ├── cspnet.cpython-38.pyc │ ├── densenet.cpython-38.pyc │ ├── dpn.cpython-38.pyc │ ├── efficientnet.cpython-38.pyc │ ├── efficientnetv2.cpython-38.pyc │ ├── ghostnet.cpython-38.pyc │ ├── mnasnet.cpython-38.pyc │ ├── mobilenetv2.cpython-38.pyc │ ├── mobilenetv3.cpython-38.pyc │ ├── repghost.cpython-38.pyc │ ├── repvgg.cpython-38.pyc │ ├── resnest.cpython-38.pyc │ ├── resnet.cpython-38.pyc │ ├── sequencer.cpython-38.pyc │ ├── shufflenetv2.cpython-38.pyc │ ├── shufflenetv2.cpython-39.pyc │ ├── vgg.cpython-38.pyc │ └── vovnet.cpython-38.pyc ├── convnext.py ├── cspnet.py ├── densenet.py ├── dpn.py ├── efficientnetv2.py ├── ghostnet.py ├── mnasnet.py ├── mobilenetv2.py ├── mobilenetv3.py ├── repghost.py ├── repvgg.py ├── resnest.py ├── resnet.py ├── sequencer.py ├── shufflenetv2.py ├── vgg.py └── vovnet.py ├── predict.py ├── processing.py ├── requirements.txt └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── utils.cpython-38.pyc ├── utils.cpython-39.pyc ├── utils_aug.cpython-38.pyc ├── utils_aug.cpython-39.pyc ├── utils_distill.cpython-38.pyc ├── utils_fit.cpython-38.pyc ├── utils_fit.cpython-39.pyc ├── utils_loss.cpython-38.pyc ├── utils_model.cpython-38.pyc └── utils_model.cpython-39.pyc ├── utils.py ├── utils_aug.py ├── utils_distill.py ├── utils_fit.py ├── utils_loss.py └── utils_model.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 He JiaJie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
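For orientation before the individual source files below: main.py trains a classifier and writes checkpoints as a dict holding the pickled model object together with its options and best metrics, and export.py / metrice.py reload that dict via torch.load(...)['model']. The following is a minimal sketch (not part of the repository) of loading such a checkpoint for standalone inference; the runs/exp path and the 3x224x224 random input are simply the scripts' defaults and an illustrative assumption, not code taken from the repo:

import torch

# The checkpoint stores the whole model object (saved in half precision by main.py),
# so the repository's model/ and utils/ packages must be importable when unpickling.
ckpt = torch.load('runs/exp/best.pt', map_location='cpu')
model = ckpt['model'].float().eval()
print(ckpt['best_metrice'])  # metrics recorded when best.pt was written

with torch.inference_mode():
    dummy = torch.rand(1, 3, 224, 224)          # default --image_channel / --image_size
    probs = torch.softmax(model(dummy), dim=-1)
    print(probs.argmax(-1))

For real images one would apply the preprocessing saved in runs/exp/preprocess.transforms (as metrice.py does) instead of feeding a random tensor.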
-------------------------------------------------------------------------------- /config/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/config/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /config/__pycache__/sgd_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/config/__pycache__/sgd_config.cpython-38.pyc -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from argparse import Namespace 4 | from utils.utils_aug import CutOut, Create_Albumentations_From_Name 5 | 6 | class Config: 7 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR 8 | lr_scheduler_params = { 9 | 'T_max': 10, 10 | 'eta_min': 1e-6 11 | } 12 | random_seed = 0 13 | plot_train_batch_count = 5 14 | custom_augment = transforms.Compose([ 15 | # transforms.RandomChoice([ 16 | # transforms.RandomHorizontalFlip(p=0.5), 17 | # transforms.RandomVerticalFlip(p=0.5), 18 | # ]), 19 | # transforms.RandomRotation(45), 20 | # Create_Albumentations_From_Name('PixelDropout', p=1.0), 21 | # Create_Albumentations_From_Name('RandomGridShuffle', grid=(16, 16)) 22 | ]) 23 | 24 | def _get_opt(self): 25 | config_dict = {name:getattr(self, name) for name in dir(self) if name[0] != '_'} 26 | return Namespace(**config_dict) 27 | 28 | if __name__ == '__main__': 29 | config = Config() 30 | print(config._get_opt()) -------------------------------------------------------------------------------- /config/sgd_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from argparse import Namespace 4 | from utils.utils_aug import CutOut, Create_Albumentations_From_Name 5 | 6 | class Config: 7 | lr_scheduler = torch.optim.lr_scheduler.StepLR 8 | lr_scheduler_params = { 9 | 'gamma': 0.8, 10 | 'step_size': 5 11 | } 12 | random_seed = 0 13 | plot_train_batch_count = 5 14 | custom_augment = transforms.Compose([ 15 | # transforms.RandomChoice([ 16 | # transforms.RandomHorizontalFlip(p=0.5), 17 | # transforms.RandomVerticalFlip(p=0.5), 18 | # ]), 19 | # transforms.RandomRotation(45), 20 | ]) 21 | 22 | def _get_opt(self): 23 | config_dict = {name:getattr(self, name) for name in dir(self) if name[0] != '_'} 24 | return Namespace(**config_dict) 25 | 26 | if __name__ == '__main__': 27 | config = Config() 28 | print(config._get_opt()) -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import numpy as np 3 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 4 | import torch 5 | import torch.nn as nn 6 | from utils.utils import select_device, model_fuse 7 | 8 | def export_torchscript(opt, model, img, prefix='TorchScript'): 9 | print('Starting TorchScript export with pytorch %s...' 
% torch.__version__) 10 | f = os.path.join(opt.save_path, 'best.ts') 11 | ts = torch.jit.trace(model, img, strict=False) 12 | ts.save(f) 13 | print(f'Export TorchScript Model Successfully.\nSave sa {f}') 14 | 15 | def export_onnx(opt, model, img, prefix='ONNX'): 16 | import onnx 17 | f = os.path.join(opt.save_path, 'best.onnx') 18 | print('Starting ONNX export with onnx %s...' % onnx.__version__) 19 | if opt.dynamic: 20 | dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}, 'output':{0: 'batch'}} 21 | else: 22 | dynamic_axes = None 23 | 24 | torch.onnx.export( 25 | (model.to('cpu') if opt.dynamic else model), 26 | (img.to('cpu') if opt.dynamic else img), 27 | f, verbose=False, opset_version=13, input_names=['images'], output_names=['output'], dynamic_axes=dynamic_axes) 28 | 29 | onnx_model = onnx.load(f) # load onnx model 30 | onnx.checker.check_model(onnx_model) # check onnx model 31 | 32 | if opt.simplify: 33 | try: 34 | import onnxsim 35 | print('\nStarting to simplify ONNX...') 36 | onnx_model, check = onnxsim.simplify(onnx_model) 37 | assert check, 'assert check failed' 38 | except Exception as e: 39 | print(f'Simplifier failure: {e}') 40 | onnx.save(onnx_model, f) 41 | 42 | print(f'Export Onnx Model Successfully.\nSave sa {f}') 43 | 44 | def export_engine(opt, model, img, workspace=4, prefix='TensorRT'): 45 | export_onnx(opt, model, img) 46 | onnx_file = os.path.join(opt.save_path, 'best.onnx') 47 | assert img.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' 48 | import tensorrt as trt 49 | print('Starting TensorRT export with TensorRT %s...' % trt.__version__) 50 | f = os.path.join(opt.save_path, 'best.engine') 51 | 52 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if opt.verbose else trt.Logger() 53 | builder = trt.Builder(TRT_LOGGER) 54 | config = builder.create_builder_config() 55 | config.max_workspace_size = workspace * 1 << 30 56 | 57 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 58 | network = builder.create_network(flag) 59 | parser = trt.OnnxParser(network, TRT_LOGGER) 60 | if not parser.parse_from_file(str(onnx_file)): 61 | raise RuntimeError(f'failed to load ONNX file: {onnx_file}') 62 | 63 | inputs = [network.get_input(i) for i in range(network.num_inputs)] 64 | outputs = [network.get_output(i) for i in range(network.num_outputs)] 65 | for inp in inputs: 66 | print(f'input {inp.name} with shape {inp.shape} and dtype {inp.dtype}') 67 | for out in outputs: 68 | print(f'output {out.name} with shape {out.shape} and dtype {out.dtype}') 69 | 70 | if opt.dynamic: 71 | if img.shape[0] <= 1: 72 | print(f"{prefix} WARNING: --dynamic model requires maximum --batch-size argument") 73 | profile = builder.create_optimization_profile() 74 | for inp in inputs: 75 | profile.set_shape(inp.name, (1, *img.shape[1:]), (max(1, img.shape[0] // 2), *img.shape[1:]), img.shape) 76 | config.add_optimization_profile(profile) 77 | 78 | print(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and opt.half else 32} engine in {f}') 79 | if builder.platform_has_fast_fp16 and opt.half: 80 | config.set_flag(trt.BuilderFlag.FP16) 81 | with builder.build_engine(network, config) as engine, open(f, 'wb') as t: 82 | t.write(engine.serialize()) 83 | print(f'Export TensorRT Model Successfully.\nSave sa {f}') 84 | 85 | def parse_opt(): 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument('--save_path', type=str, default=r'runs/exp', help='save path for model and log') 88 | 
parser.add_argument('--image_size', type=int, default=224, help='image size') 89 | parser.add_argument('--image_channel', type=int, default=3, help='image channel') 90 | parser.add_argument('--batch_size', type=int, default=1, help='batch size') 91 | parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX batchsize') 92 | parser.add_argument('--simplify', action='store_true', help='simplify onnx model') 93 | parser.add_argument('--half', action="store_true", help='FP32 to FP16') 94 | parser.add_argument('--verbose', action="store_true", help='TensorRT:verbose export log') 95 | parser.add_argument('--export', default='torchscript', type=str, choices=['onnx', 'torchscript', 'tensorrt'], help='export type') 96 | parser.add_argument('--device', type=str, default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 97 | 98 | opt = parser.parse_known_args()[0] 99 | if not os.path.exists(os.path.join(opt.save_path, 'best.pt')): 100 | raise Exception('best.pt not found. please check your --save_path folder') 101 | DEVICE = select_device(opt.device) 102 | if opt.half: 103 | assert DEVICE.type != 'cpu', '--half only supported with GPU export' 104 | assert not opt.dynamic, '--half not compatible with --dynamic' 105 | ckpt = torch.load(os.path.join(opt.save_path, 'best.pt')) 106 | model = ckpt['model'].float().to(DEVICE) 107 | model_fuse(model) 108 | img = torch.rand((opt.batch_size, opt.image_channel, opt.image_size, opt.image_size)).to(DEVICE) 109 | 110 | return opt, (model.half() if opt.half else model), (img.half() if opt.half else img), DEVICE 111 | 112 | if __name__ == '__main__': 113 | opt, model, img, DEVICE = parse_opt() 114 | 115 | if opt.export == 'onnx': 116 | export_onnx(opt, model, img) 117 | elif opt.export == 'torchscript': 118 | export_torchscript(opt, model, img) 119 | elif opt.export == 'tensorrt': 120 | export_engine(opt, model, img) -------------------------------------------------------------------------------- /figure/base_lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/figure/base_lr.png -------------------------------------------------------------------------------- /figure/custom_lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/figure/custom_lr.png -------------------------------------------------------------------------------- /figure/warmup_custom_lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/figure/warmup_custom_lr.png -------------------------------------------------------------------------------- /figure/warmup_lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/figure/warmup_lr.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | from PIL import ImageFile 4 | ImageFile.LOAD_TRUNCATED_IMAGES = True 5 | import os, argparse, shutil, random, imp 6 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 7 | import 
numpy as np 8 | import torch, torchvision, time, datetime, copy 9 | from sklearn.utils.class_weight import compute_class_weight 10 | from copy import deepcopy 11 | from utils.utils_fit import fitting, fitting_distill 12 | from utils.utils_model import select_model 13 | from utils import utils_aug 14 | from utils.utils import save_model, plot_train_batch, WarmUpLR, show_config, setting_optimizer, check_batch_size, \ 15 | plot_log, update_opt, load_weights, get_channels, dict_to_PrettyTable, ModelEMA, select_device 16 | from utils.utils_distill import * 17 | from utils.utils_loss import * 18 | 19 | torch.backends.cudnn.deterministic = True 20 | def set_seed(seed): 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | def parse_opt(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--model_name', type=str, default='resnet18', help='model name') 29 | parser.add_argument('--pretrained', action="store_true", help='using pretrain weight') 30 | parser.add_argument('--weight', type=str, default='', help='loading weight path') 31 | parser.add_argument('--config', type=str, default='config/config.py', help='config path') 32 | parser.add_argument('--device', type=str, default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 33 | 34 | parser.add_argument('--train_path', type=str, default=r'dataset/train', help='train data path') 35 | parser.add_argument('--val_path', type=str, default=r'dataset/val', help='val data path') 36 | parser.add_argument('--test_path', type=str, default=r'dataset/test', help='test data path') 37 | parser.add_argument('--label_path', type=str, default=r'dataset/label.txt', help='label path') 38 | parser.add_argument('--image_size', type=int, default=224, help='image size') 39 | parser.add_argument('--image_channel', type=int, default=3, help='image channel') 40 | parser.add_argument('--workers', type=int, default=4, help='dataloader workers') 41 | parser.add_argument('--batch_size', type=int, default=64, help='batch size (-1 for autobatch)') 42 | parser.add_argument('--epoch', type=int, default=100, help='epoch') 43 | parser.add_argument('--save_path', type=str, default=r'runs/exp', help='save path for model and log') 44 | parser.add_argument('--resume', action="store_true", help='resume from save_path traning') 45 | 46 | # optimizer parameters 47 | parser.add_argument('--loss', type=str, choices=['PolyLoss', 'CrossEntropyLoss', 'FocalLoss'], 48 | default='CrossEntropyLoss', help='loss function') 49 | parser.add_argument('--optimizer', type=str, choices=['SGD', 'AdamW', 'RMSProp'], default='AdamW', help='optimizer') 50 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 51 | parser.add_argument('--label_smoothing', type=float, default=0.1, help='label smoothing') 52 | parser.add_argument('--class_balance', action="store_true", help='using class balance in loss') 53 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight_decay') 54 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum in optimizer') 55 | parser.add_argument('--amp', action="store_true", help='using AMP(Automatic Mixed Precision)') 56 | parser.add_argument('--warmup', action="store_true", help='using WarmUp LR') 57 | parser.add_argument('--warmup_ratios', type=float, default=0.05, 58 | help='warmup_epochs = int(warmup_ratios * epoch) if warmup=True') 59 | parser.add_argument('--warmup_minlr', type=float, default=1e-6, 60 | help='minimum lr in warmup(also 
as minimum lr in training)') 61 | parser.add_argument('--metrice', type=str, choices=['loss', 'acc', 'mean_acc'], default='acc', help='best.pt save relu') 62 | parser.add_argument('--patience', type=int, default=30, help='EarlyStopping patience (--metrice without improvement)') 63 | 64 | # Data Processing parameters 65 | parser.add_argument('--imagenet_meanstd', action="store_true", help='using ImageNet Mean and Std') 66 | parser.add_argument('--mixup', type=str, choices=['mixup', 'cutmix', 'none'], default='none', help='MixUp Methods') 67 | parser.add_argument('--Augment', type=str, 68 | choices=['RandAugment', 'AutoAugment', 'TrivialAugmentWide', 'AugMix', 'none'], default='none', 69 | help='Data Augment') 70 | parser.add_argument('--test_tta', action="store_true", help='using TTA') 71 | 72 | # Knowledge Distillation parameters 73 | parser.add_argument('--kd', action="store_true", help='Knowledge Distillation') 74 | parser.add_argument('--kd_ratio', type=float, default=0.7, help='Knowledge Distillation Loss ratio') 75 | parser.add_argument('--kd_method', type=str, choices=['SoftTarget', 'MGD', 'SP', 'AT'], default='SoftTarget', help='Knowledge Distillation Method') 76 | parser.add_argument('--teacher_path', type=str, default='', help='teacher model path') 77 | 78 | # Tricks parameters 79 | parser.add_argument('--rdrop', action="store_true", help='using R-Drop') 80 | parser.add_argument('--ema', action="store_true", help='using EMA(Exponential Moving Average) Reference to YOLOV5') 81 | 82 | opt = parser.parse_known_args()[0] 83 | if opt.resume: 84 | opt.resume = True 85 | if not os.path.exists(os.path.join(opt.save_path, 'last.pt')): 86 | raise Exception('last.pt not found. please check your --save_path folder and --resume parameters') 87 | ckpt = torch.load(os.path.join(opt.save_path, 'last.pt')) 88 | opt = ckpt['opt'] 89 | opt.resume = True 90 | print('found checkpoint from {}, model type:{}\n{}'.format(opt.save_path, ckpt['model'].name, dict_to_PrettyTable(ckpt['best_metrice'], 'Best Metrice'))) 91 | else: 92 | if os.path.exists(opt.save_path): 93 | shutil.rmtree(opt.save_path) 94 | os.makedirs(opt.save_path) 95 | config = imp.load_source('config', opt.config).Config() 96 | shutil.copy(__file__, os.path.join(opt.save_path, 'main.py')) 97 | shutil.copy(opt.config, os.path.join(opt.save_path, 'config.py')) 98 | opt = update_opt(opt, config._get_opt()) 99 | 100 | set_seed(opt.random_seed) 101 | show_config(deepcopy(opt)) 102 | 103 | CLASS_NUM = len(os.listdir(opt.train_path)) 104 | DEVICE = select_device(opt.device, opt.batch_size) 105 | 106 | train_transform, test_transform = utils_aug.get_dataprocessing(torchvision.datasets.ImageFolder(opt.train_path), 107 | opt) 108 | train_dataset = torchvision.datasets.ImageFolder(opt.train_path, transform=train_transform) 109 | test_dataset = torchvision.datasets.ImageFolder(opt.val_path, transform=test_transform) 110 | if opt.resume: 111 | model = ckpt['model'].to(DEVICE).float() 112 | else: 113 | model = select_model(opt.model_name, CLASS_NUM, (opt.image_size, opt.image_size), opt.image_channel, 114 | opt.pretrained) 115 | model = load_weights(model, opt).to(DEVICE) 116 | plot_train_batch(copy.deepcopy(train_dataset), opt) 117 | 118 | batch_size = opt.batch_size if opt.batch_size != -1 else check_batch_size(model, opt.image_size, amp=opt.amp) 119 | 120 | if opt.class_balance: 121 | class_weight = np.sqrt(compute_class_weight('balanced', classes=np.unique(train_dataset.targets), y=train_dataset.targets)) 122 | else: 123 | class_weight = 
np.ones_like(np.unique(train_dataset.targets)) 124 | print('class weight: {}'.format(class_weight)) 125 | 126 | train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=opt.workers) 127 | test_dataset = torch.utils.data.DataLoader(test_dataset, max(batch_size // (10 if opt.test_tta else 1), 1), 128 | shuffle=False, num_workers=(0 if opt.test_tta else opt.workers)) 129 | scaler = torch.cuda.amp.GradScaler(enabled=(opt.amp if torch.cuda.is_available() else False)) 130 | ema = ModelEMA(model) if opt.ema else None 131 | optimizer = setting_optimizer(opt, model) 132 | lr_scheduler = WarmUpLR(optimizer, opt) 133 | if opt.resume: 134 | optimizer.load_state_dict(ckpt['optimizer']) 135 | lr_scheduler.load_state_dict(ckpt['lr_scheduler']) 136 | loss = ckpt['loss'].to(DEVICE) 137 | scaler.load_state_dict(ckpt['scaler']) 138 | if opt.ema: 139 | ema.ema = ckpt['ema'].to(DEVICE).float() 140 | ema.updates = ckpt['updates'] 141 | else: 142 | loss = eval(opt.loss)(label_smoothing=opt.label_smoothing, 143 | weight=torch.from_numpy(class_weight).to(DEVICE).float()) 144 | if opt.rdrop: 145 | loss = RDropLoss(loss) 146 | return opt, model, ema, train_dataset, test_dataset, optimizer, scaler, lr_scheduler, loss, DEVICE, CLASS_NUM, ( 147 | ckpt['epoch'] if opt.resume else 0), (ckpt['best_metrice'] if opt.resume else None) 148 | 149 | 150 | if __name__ == '__main__': 151 | opt, model, ema, train_dataset, test_dataset, optimizer, scaler, lr_scheduler, loss, DEVICE, CLASS_NUM, begin_epoch, best_metrice = parse_opt() 152 | 153 | if not opt.resume: 154 | save_epoch = 0 155 | with open(os.path.join(opt.save_path, 'train.log'), 'w+') as f: 156 | if opt.kd: 157 | f.write('epoch,lr,loss,kd_loss,acc,mean_acc,test_loss,test_acc,test_mean_acc') 158 | else: 159 | f.write('epoch,lr,loss,acc,mean_acc,test_loss,test_acc,test_mean_acc') 160 | else: 161 | save_epoch = torch.load(os.path.join(opt.save_path, 'last.pt'))['best_epoch'] 162 | 163 | if opt.kd: 164 | if not os.path.exists(os.path.join(opt.teacher_path, 'best.pt')): 165 | raise Exception('teacher best.pt not found. 
please check your --teacher_path folder') 166 | teacher_ckpt = torch.load(os.path.join(opt.teacher_path, 'best.pt')) 167 | teacher_model = teacher_ckpt['model'].float().to(DEVICE).eval() 168 | print('found teacher checkpoint from {}, model type:{}\n{}'.format(opt.teacher_path, teacher_model.name, dict_to_PrettyTable(teacher_ckpt['best_metrice'], 'Best Metrice'))) 169 | 170 | if opt.resume: 171 | kd_loss = torch.load(os.path.join(opt.save_path, 'last.pt'))['kd_loss'].to(DEVICE) 172 | else: 173 | if opt.kd_method == 'SoftTarget': 174 | kd_loss = SoftTarget().to(DEVICE) 175 | elif opt.kd_method == 'MGD': 176 | kd_loss = MGD(get_channels(model, opt), get_channels(teacher_model, opt)).to(DEVICE) 177 | optimizer.add_param_group({'params': kd_loss.parameters(), 'weight_decay': opt.weight_decay}) 178 | elif opt.kd_method == 'SP': 179 | kd_loss = SP().to(DEVICE) 180 | elif opt.kd_method == 'AT': 181 | kd_loss = AT().to(DEVICE) 182 | 183 | print('{} begin train!'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) 184 | for epoch in range(begin_epoch, opt.epoch): 185 | if epoch > (save_epoch + opt.patience) and opt.patience != 0: 186 | print('No Improve from {} to {}, EarlyStopping.'.format(save_epoch + 1, epoch)) 187 | break 188 | 189 | begin = time.time() 190 | if opt.kd: 191 | metrice = fitting_distill(teacher_model, model, ema, loss, kd_loss, optimizer, train_dataset, test_dataset, CLASS_NUM, DEVICE, scaler, '{}/{}'.format(epoch + 1,opt.epoch), opt) 192 | else: 193 | metrice = fitting(model, ema, loss, optimizer, train_dataset, test_dataset, CLASS_NUM, DEVICE, scaler,'{}/{}'.format(epoch + 1, opt.epoch), opt) 194 | 195 | with open(os.path.join(opt.save_path, 'train.log'), 'a+') as f: 196 | f.write( 197 | '\n{},{:.10f},{}'.format(epoch + 1, optimizer.param_groups[2]['lr'], metrice[1])) 198 | 199 | n_lr = optimizer.param_groups[2]['lr'] 200 | lr_scheduler.step() 201 | 202 | if best_metrice is None: 203 | best_metrice = metrice[0] 204 | else: 205 | if eval('{} {} {}'.format(metrice[0]['test_{}'.format(opt.metrice)], '<' if opt.metrice == 'loss' else '>', best_metrice['test_{}'.format(opt.metrice)])): 206 | best_metrice = metrice[0] 207 | save_model( 208 | os.path.join(opt.save_path, 'best.pt'), 209 | **{ 210 | 'model': (deepcopy(ema.ema).to('cpu').half() if opt.ema else deepcopy(model).to('cpu').half()), 211 | 'opt': opt, 212 | 'best_metrice': best_metrice, 213 | } 214 | ) 215 | save_epoch = epoch 216 | 217 | save_model( 218 | os.path.join(opt.save_path, 'last.pt'), 219 | **{ 220 | 'model': deepcopy(model).to('cpu').half(), 221 | 'ema': (deepcopy(ema.ema).to('cpu').half() if opt.ema else None), 222 | 'updates': (ema.updates if opt.ema else None), 223 | 'opt': opt, 224 | 'epoch': epoch + 1, 225 | 'optimizer' : optimizer.state_dict(), 226 | 'lr_scheduler': lr_scheduler.state_dict(), 227 | 'best_metrice': best_metrice, 228 | 'loss': deepcopy(loss).to('cpu'), 229 | 'kd_loss': (deepcopy(kd_loss).to('cpu') if opt.kd else None), 230 | 'scaler': scaler.state_dict(), 231 | 'best_epoch': save_epoch, 232 | } 233 | ) 234 | 235 | print(dict_to_PrettyTable(metrice[0], '{} epoch:{}/{}, best_epoch:{}, time:{:.2f}s, lr:{:.8f}'.format( 236 | datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 237 | epoch + 1, opt.epoch, save_epoch + 1, time.time() - begin, n_lr, 238 | ))) 239 | 240 | plot_log(opt) 241 | -------------------------------------------------------------------------------- /metrice.py: -------------------------------------------------------------------------------- 1 | import warnings, sys, 
datetime, random 2 | warnings.filterwarnings("ignore") 3 | from PIL import ImageFile 4 | ImageFile.LOAD_TRUNCATED_IMAGES = True 5 | import os, torch, argparse, time, torchvision, tqdm 6 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 7 | import numpy as np 8 | from utils import utils_aug 9 | from utils.utils import classification_metrice, Metrice_Dataset, visual_predictions, visual_tsne, dict_to_PrettyTable, Model_Inference, select_device, model_fuse 10 | 11 | torch.backends.cudnn.deterministic = True 12 | def set_seed(seed): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | 18 | def parse_opt(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--train_path', type=str, default=r'dataset/train', help='train data path') 21 | parser.add_argument('--val_path', type=str, default=r'dataset/val', help='val data path') 22 | parser.add_argument('--test_path', type=str, default=r'dataset/test', help='test data path') 23 | parser.add_argument('--label_path', type=str, default=r'dataset/label.txt', help='label path') 24 | parser.add_argument('--device', type=str, default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 25 | parser.add_argument('--task', type=str, choices=['train', 'val', 'test', 'fps'], default='test', help='train, val, test, fps') 26 | parser.add_argument('--workers', type=int, default=4, help='dataloader workers') 27 | parser.add_argument('--batch_size', type=int, default=64, help='batch size') 28 | parser.add_argument('--save_path', type=str, default=r'runs/exp', help='save path for model and log') 29 | parser.add_argument('--test_tta', action="store_true", help='using TTA Tricks') 30 | parser.add_argument('--visual', action="store_true", help='visual dataset identification') 31 | parser.add_argument('--tsne', action="store_true", help='visual tsne') 32 | parser.add_argument('--half', action="store_true", help='use FP16 half-precision inference') 33 | parser.add_argument('--model_type', type=str, choices=['torch', 'torchscript', 'onnx', 'tensorrt'], default='torch', help='model type(default: torch)') 34 | 35 | opt = parser.parse_known_args()[0] 36 | 37 | DEVICE = select_device(opt.device, opt.batch_size) 38 | if opt.half and DEVICE.type == 'cpu': 39 | raise Exception('half inference only supported GPU.') 40 | if not os.path.exists(os.path.join(opt.save_path, 'best.pt')): 41 | raise Exception('best.pt not found. 
please check your --save_path folder') 42 | ckpt = torch.load(os.path.join(opt.save_path, 'best.pt')) 43 | train_opt = ckpt['opt'] 44 | set_seed(train_opt.random_seed) 45 | model = Model_Inference(DEVICE, opt) 46 | 47 | print('found checkpoint from {}, model type:{}\n{}'.format(opt.save_path, ckpt['model'].name, dict_to_PrettyTable(ckpt['best_metrice'], 'Best Metrice'))) 48 | 49 | test_transform = utils_aug.get_dataprocessing_teststage(train_opt, opt, torch.load(os.path.join(opt.save_path, 'preprocess.transforms'))) 50 | 51 | if opt.task == 'fps': 52 | inputs = torch.rand((opt.batch_size, train_opt.image_channel, train_opt.image_size, train_opt.image_size)).to(DEVICE) 53 | if opt.half and torch.cuda.is_available(): 54 | inputs = inputs.half() 55 | warm_up, test_time = 100, 300 56 | fps_arr = [] 57 | for i in tqdm.tqdm(range(test_time + warm_up)): 58 | since = time.time() 59 | with torch.inference_mode(): 60 | model(inputs) 61 | if i > warm_up: 62 | fps_arr.append(time.time() - since) 63 | fps = np.mean(fps_arr) 64 | print('{:.6f} seconds, {:.2f} fps, @batch_size {}'.format(fps, 1 / fps, opt.batch_size)) 65 | sys.exit(0) 66 | else: 67 | save_path = os.path.join(opt.save_path, opt.task, datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')) 68 | if not os.path.exists(save_path): 69 | os.makedirs(save_path) 70 | 71 | CLASS_NUM = len(os.listdir(eval('opt.{}_path'.format(opt.task)))) 72 | test_dataset = Metrice_Dataset(torchvision.datasets.ImageFolder(eval('opt.{}_path'.format(opt.task)), transform=test_transform)) 73 | test_dataset = torch.utils.data.DataLoader(test_dataset, opt.batch_size, shuffle=False, 74 | num_workers=(0 if opt.test_tta else opt.workers)) 75 | 76 | try: 77 | with open(opt.label_path, encoding='utf-8') as f: 78 | label = list(map(lambda x: x.strip(), f.readlines())) 79 | except Exception as e: 80 | with open(opt.label_path, encoding='gbk') as f: 81 | label = list(map(lambda x: x.strip(), f.readlines())) 82 | 83 | return opt, model, test_dataset, DEVICE, CLASS_NUM, label, save_path 84 | 85 | 86 | if __name__ == '__main__': 87 | opt, model, test_dataset, DEVICE, CLASS_NUM, label, save_path = parse_opt() 88 | y_true, y_pred, y_score, y_feature, img_path = [], [], [], [], [] 89 | with torch.inference_mode(): 90 | for x, y, path in tqdm.tqdm(test_dataset, desc='Test Stage'): 91 | x = (x.half().to(DEVICE) if opt.half else x.to(DEVICE)) 92 | if opt.test_tta: 93 | bs, ncrops, c, h, w = x.size() 94 | pred = model(x.view(-1, c, h, w)) 95 | pred = pred.view(bs, ncrops, -1).mean(1) 96 | 97 | if opt.tsne: 98 | pred_feature = model.forward_features(x.view(-1, c, h, w)) 99 | pred_feature = pred_feature.view(bs, ncrops, -1).mean(1) 100 | else: 101 | pred = model(x) 102 | 103 | if opt.tsne: 104 | pred_feature = model.forward_features(x) 105 | try: 106 | pred = torch.softmax(pred, 1) 107 | except: 108 | pred = torch.softmax(torch.from_numpy(pred), 1) # using torch.softmax will faster than numpy 109 | 110 | y_true.extend(list(y.cpu().detach().numpy())) 111 | y_pred.extend(list(pred.argmax(-1).cpu().detach().numpy())) 112 | y_score.extend(list(pred.max(-1)[0].cpu().detach().numpy())) 113 | img_path.extend(list(path)) 114 | 115 | if opt.tsne: 116 | y_feature.extend(list(pred_feature.cpu().detach().numpy())) 117 | 118 | classification_metrice(np.array(y_true), np.array(y_pred), CLASS_NUM, label, save_path) 119 | if opt.visual: 120 | visual_predictions(np.array(y_true), np.array(y_pred), np.array(y_score), np.array(img_path), label, save_path) 121 | if opt.tsne: 122 | 
visual_tsne(np.array(y_feature), np.array(y_pred), np.array(img_path), label, save_path) -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .shufflenetv2 import * 2 | from .mobilenetv2 import * 3 | from .mobilenetv3 import * 4 | from .resnet import * 5 | from .densenet import * 6 | from .vgg import * 7 | from .efficientnetv2 import * 8 | from .mnasnet import * 9 | from .vovnet import * 10 | from .convnext import * 11 | from .resnest import * 12 | from .ghostnet import * 13 | from .repvgg import * 14 | from .sequencer import * 15 | from .cspnet import * 16 | from .dpn import * 17 | from .repghost import * -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/convnext.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/convnext.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/cspnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/cspnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/densenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/densenet.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/dpn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/dpn.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/efficientnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/efficientnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/efficientnetv2.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/efficientnetv2.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/ghostnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/ghostnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/mnasnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/mnasnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/mobilenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/mobilenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/mobilenetv3.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/mobilenetv3.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/repghost.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/repghost.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/repvgg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/repvgg.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/resnest.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/resnest.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/sequencer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/sequencer.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/shufflenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/shufflenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/shufflenetv2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/shufflenetv2.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/vgg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/vgg.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/vovnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/model/__pycache__/vovnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/convnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from timm.models.layers import trunc_normal_, DropPath 6 | from utils.utils import load_weights_from_state_dict 7 | 8 | __all__ = ['convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'] 9 | 10 | class Block(nn.Module): 11 | r""" ConvNeXt Block. There are two equivalent implementations: 12 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 13 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 14 | We use (2) as we find it slightly faster in PyTorch 15 | 16 | Args: 17 | dim (int): Number of input channels. 18 | drop_path (float): Stochastic depth rate. Default: 0.0 19 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 20 | """ 21 | 22 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 23 | super().__init__() 24 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 25 | self.norm = LayerNorm(dim, eps=1e-6) 26 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 27 | self.act = nn.GELU() 28 | self.pwconv2 = nn.Linear(4 * dim, dim) 29 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 30 | requires_grad=True) if layer_scale_init_value > 0 else None 31 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 32 | 33 | def forward(self, x): 34 | input = x 35 | x = self.dwconv(x) 36 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 37 | x = self.norm(x) 38 | x = self.pwconv1(x) 39 | x = self.act(x) 40 | x = self.pwconv2(x) 41 | if self.gamma is not None: 42 | x = self.gamma * x 43 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 44 | 45 | x = input + self.drop_path(x) 46 | return x 47 | 48 | 49 | class ConvNeXt(nn.Module): 50 | r""" ConvNeXt 51 | A PyTorch impl of : `A ConvNet for the 2020s` - 52 | https://arxiv.org/pdf/2201.03545.pdf 53 | Args: 54 | in_chans (int): Number of input image channels. 
Default: 3 55 | num_classes (int): Number of classes for classification head. Default: 1000 56 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 57 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 58 | drop_path_rate (float): Stochastic depth rate. Default: 0. 59 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 60 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 61 | """ 62 | 63 | def __init__(self, in_chans=3, num_classes=1000, 64 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 65 | layer_scale_init_value=1e-6, head_init_scale=1., 66 | ): 67 | super().__init__() 68 | 69 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 70 | stem = nn.Sequential( 71 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 72 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 73 | ) 74 | self.downsample_layers.append(stem) 75 | for i in range(3): 76 | downsample_layer = nn.Sequential( 77 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 78 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), 79 | ) 80 | self.downsample_layers.append(downsample_layer) 81 | 82 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 83 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 84 | cur = 0 85 | for i in range(4): 86 | stage = nn.Sequential( 87 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 88 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 89 | ) 90 | self.stages.append(stage) 91 | cur += depths[i] 92 | 93 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 94 | self.head = nn.Linear(dims[-1], num_classes) 95 | 96 | self.apply(self._init_weights) 97 | self.head.weight.data.mul_(head_init_scale) 98 | self.head.bias.data.mul_(head_init_scale) 99 | 100 | def _init_weights(self, m): 101 | if isinstance(m, (nn.Conv2d, nn.Linear)): 102 | trunc_normal_(m.weight, std=.02) 103 | nn.init.constant_(m.bias, 0) 104 | 105 | def forward_features(self, x, need_fea=False): 106 | if need_fea: 107 | features = [] 108 | for i in range(4): 109 | x = self.downsample_layers[i](x) 110 | x = self.stages[i](x) 111 | features.append(x) 112 | return features, self.norm(features[-1].mean([-2, -1])) 113 | else: 114 | for i in range(4): 115 | x = self.downsample_layers[i](x) 116 | x = self.stages[i](x) 117 | return self.norm(x.mean([-2, -1])) 118 | 119 | def forward(self, x, need_fea=False): 120 | if need_fea: 121 | features, features_fc = self.forward_features(x, need_fea=need_fea) 122 | x = self.head(features_fc) 123 | return features, features_fc, x 124 | else: 125 | x = self.forward_features(x) 126 | x = self.head(x) 127 | return x 128 | 129 | def cam_layer(self): 130 | return self.stages[-1] 131 | 132 | 133 | class LayerNorm(nn.Module): 134 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 135 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 136 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 137 | with shape (batch_size, channels, height, width). 
138 | """ 139 | 140 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 141 | super().__init__() 142 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 143 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 144 | self.eps = eps 145 | self.data_format = data_format 146 | if self.data_format not in ["channels_last", "channels_first"]: 147 | raise NotImplementedError 148 | self.normalized_shape = (normalized_shape,) 149 | 150 | def forward(self, x): 151 | if self.data_format == "channels_last": 152 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 153 | elif self.data_format == "channels_first": 154 | u = x.mean(1, keepdim=True) 155 | s = (x - u).pow(2).mean(1, keepdim=True) 156 | x = (x - u) / torch.sqrt(s + self.eps) 157 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 158 | return x 159 | 160 | 161 | model_urls = { 162 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 163 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 164 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 165 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 166 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 167 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 168 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 169 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 170 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 171 | } 172 | 173 | def convnext_tiny(pretrained=False, in_22k=False, **kwargs): 174 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 175 | if pretrained: 176 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 177 | state_dict = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)['model'] 178 | model = load_weights_from_state_dict(model, state_dict) 179 | return model 180 | 181 | def convnext_small(pretrained=False, in_22k=False, **kwargs): 182 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 183 | if pretrained: 184 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 185 | state_dict = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)['model'] 186 | model = load_weights_from_state_dict(model, state_dict) 187 | return model 188 | 189 | 190 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 191 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 192 | if pretrained: 193 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 194 | state_dict = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)['model'] 195 | model = load_weights_from_state_dict(model, state_dict) 196 | return model 197 | 198 | 199 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 200 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 201 | if pretrained: 202 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 203 | state_dict = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", 
check_hash=True)['model'] 204 | model = load_weights_from_state_dict(model, state_dict) 205 | return model 206 | 207 | 208 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 209 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 210 | if pretrained: 211 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 212 | url = model_urls['convnext_xlarge_22k'] 213 | state_dict = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)['model'] 214 | model = load_weights_from_state_dict(model, state_dict) 215 | return model 216 | 217 | if __name__ == '__main__': 218 | inputs = torch.rand((1, 3, 224, 224)) 219 | model = convnext_small(pretrained=True) 220 | model.eval() 221 | out = model(inputs) 222 | print('out shape:{}'.format(out.size())) 223 | feas, fea_fc, out = model(inputs, True) 224 | for idx, fea in enumerate(feas): 225 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 226 | print('fc shape:{}'.format(fea_fc.size())) 227 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.checkpoint as cp 7 | from collections import OrderedDict 8 | from torchvision._internally_replaced_utils import load_state_dict_from_url 9 | from torch import Tensor 10 | from typing import Any, List, Tuple 11 | from utils.utils import load_weights_from_state_dict, fuse_conv_bn 12 | 13 | __all__ = ['densenet121', 'densenet169', 'densenet201', 'densenet161'] 14 | 15 | model_urls = { 16 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 17 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 18 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 19 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 20 | } 21 | 22 | 23 | class _DenseLayer(nn.Module): 24 | def __init__( 25 | self, 26 | num_input_features: int, 27 | growth_rate: int, 28 | bn_size: int, 29 | drop_rate: float, 30 | memory_efficient: bool = False 31 | ) -> None: 32 | super(_DenseLayer, self).__init__() 33 | self.norm1: nn.BatchNorm2d 34 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)) 35 | self.relu1: nn.ReLU 36 | self.add_module('relu1', nn.ReLU(inplace=True)) 37 | self.conv1: nn.Conv2d 38 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 39 | growth_rate, kernel_size=1, stride=1, 40 | bias=False)) 41 | self.norm2: nn.BatchNorm2d 42 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)) 43 | self.relu2: nn.ReLU 44 | self.add_module('relu2', nn.ReLU(inplace=True)) 45 | self.conv2: nn.Conv2d 46 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 47 | kernel_size=3, stride=1, padding=1, 48 | bias=False)) 49 | self.drop_rate = float(drop_rate) 50 | self.memory_efficient = memory_efficient 51 | 52 | def bn_function(self, inputs: List[Tensor]) -> Tensor: 53 | concated_features = torch.cat(inputs, 1) 54 | bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 55 | return bottleneck_output 56 | 57 | # todo: rewrite when torchscript supports any 58 | def any_requires_grad(self, input: List[Tensor]) -> bool: 59 | for tensor in input: 60 | if 
tensor.requires_grad: 61 | return True 62 | return False 63 | 64 | @torch.jit.unused # noqa: T484 65 | def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor: 66 | def closure(*inputs): 67 | return self.bn_function(inputs) 68 | 69 | return cp.checkpoint(closure, *input) 70 | 71 | @torch.jit._overload_method # noqa: F811 72 | def forward(self, input: List[Tensor]) -> Tensor: 73 | pass 74 | 75 | @torch.jit._overload_method # noqa: F811 76 | def forward(self, input: Tensor) -> Tensor: 77 | pass 78 | 79 | # torchscript does not yet support *args, so we overload method 80 | # allowing it to take either a List[Tensor] or single Tensor 81 | def forward(self, input: Tensor) -> Tensor: # noqa: F811 82 | if isinstance(input, Tensor): 83 | prev_features = [input] 84 | else: 85 | prev_features = input 86 | 87 | if self.memory_efficient and self.any_requires_grad(prev_features): 88 | if torch.jit.is_scripting(): 89 | raise Exception("Memory Efficient not supported in JIT") 90 | 91 | bottleneck_output = self.call_checkpoint_bottleneck(prev_features) 92 | else: 93 | bottleneck_output = self.bn_function(prev_features) 94 | 95 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 96 | if self.drop_rate > 0: 97 | new_features = F.dropout(new_features, p=self.drop_rate, 98 | training=self.training) 99 | return new_features 100 | 101 | 102 | class _DenseBlock(nn.ModuleDict): 103 | _version = 2 104 | 105 | def __init__( 106 | self, 107 | num_layers: int, 108 | num_input_features: int, 109 | bn_size: int, 110 | growth_rate: int, 111 | drop_rate: float, 112 | memory_efficient: bool = False 113 | ) -> None: 114 | super(_DenseBlock, self).__init__() 115 | for i in range(num_layers): 116 | layer = _DenseLayer( 117 | num_input_features + i * growth_rate, 118 | growth_rate=growth_rate, 119 | bn_size=bn_size, 120 | drop_rate=drop_rate, 121 | memory_efficient=memory_efficient, 122 | ) 123 | self.add_module('denselayer%d' % (i + 1), layer) 124 | 125 | def forward(self, init_features: Tensor) -> Tensor: 126 | features = [init_features] 127 | for name, layer in self.items(): 128 | new_features = layer(features) 129 | features.append(new_features) 130 | return torch.cat(features, 1) 131 | 132 | 133 | class _Transition(nn.Sequential): 134 | def __init__(self, num_input_features: int, num_output_features: int) -> None: 135 | super(_Transition, self).__init__() 136 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 137 | self.add_module('relu', nn.ReLU(inplace=True)) 138 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 139 | kernel_size=1, stride=1, bias=False)) 140 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 141 | 142 | 143 | class DenseNet(nn.Module): 144 | r"""Densenet-BC model class, based on 145 | `"Densely Connected Convolutional Networks" `_. 146 | 147 | Args: 148 | growth_rate (int) - how many filters to add each layer (`k` in paper) 149 | block_config (list of 4 ints) - how many layers in each pooling block 150 | num_init_features (int) - the number of filters to learn in the first convolution layer 151 | bn_size (int) - multiplicative factor for number of bottle neck layers 152 | (i.e. bn_size * k features in the bottleneck layer) 153 | drop_rate (float) - dropout rate after each dense layer 154 | num_classes (int) - number of classification classes 155 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 156 | but slower. Default: *False*. See `"paper" `_. 
157 | """ 158 | 159 | def __init__( 160 | self, 161 | growth_rate: int = 32, 162 | block_config: Tuple[int, int, int, int] = (6, 12, 24, 16), 163 | num_init_features: int = 64, 164 | bn_size: int = 4, 165 | drop_rate: float = 0, 166 | num_classes: int = 1000, 167 | memory_efficient: bool = False 168 | ) -> None: 169 | 170 | super(DenseNet, self).__init__() 171 | 172 | # First convolution 173 | self.features = nn.Sequential(OrderedDict([ 174 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 175 | padding=3, bias=False)), 176 | ('norm0', nn.BatchNorm2d(num_init_features)), 177 | ('relu0', nn.ReLU(inplace=True)), 178 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 179 | ])) 180 | 181 | # Each denseblock 182 | num_features = num_init_features 183 | for i, num_layers in enumerate(block_config): 184 | block = _DenseBlock( 185 | num_layers=num_layers, 186 | num_input_features=num_features, 187 | bn_size=bn_size, 188 | growth_rate=growth_rate, 189 | drop_rate=drop_rate, 190 | memory_efficient=memory_efficient 191 | ) 192 | self.features.add_module('denseblock%d' % (i + 1), block) 193 | num_features = num_features + num_layers * growth_rate 194 | if i != len(block_config) - 1: 195 | trans = _Transition(num_input_features=num_features, 196 | num_output_features=num_features // 2) 197 | self.features.add_module('transition%d' % (i + 1), trans) 198 | num_features = num_features // 2 199 | 200 | # Final batch norm 201 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 202 | 203 | # Linear layer 204 | self.classifier = nn.Linear(num_features, num_classes) 205 | 206 | # Official init from torch repo. 207 | for m in self.modules(): 208 | if isinstance(m, nn.Conv2d): 209 | nn.init.kaiming_normal_(m.weight) 210 | elif isinstance(m, nn.BatchNorm2d): 211 | nn.init.constant_(m.weight, 1) 212 | nn.init.constant_(m.bias, 0) 213 | elif isinstance(m, nn.Linear): 214 | nn.init.constant_(m.bias, 0) 215 | 216 | def forward(self, x: Tensor, need_fea=False) -> Tensor: 217 | if need_fea: 218 | features, features_fc = self.forward_features(x, need_fea=True) 219 | out = self.classifier(features_fc) 220 | return features, features_fc, out 221 | else: 222 | features = self.forward_features(x) 223 | out = self.classifier(features) 224 | return out 225 | 226 | def forward_features(self, x, need_fea=False): 227 | if need_fea: 228 | input_size = x.size(2) 229 | scale = [4, 8, 16, 32] 230 | features = [None, None, None, None] 231 | for idx, layer in enumerate(self.features): 232 | x = layer(x) 233 | if input_size // x.size(2) in scale: 234 | features[scale.index(input_size // x.size(2))] = x 235 | else: 236 | features[-1] = F.relu(features[-1], inplace=True) 237 | out = F.adaptive_avg_pool2d(features[-1], (1, 1)) 238 | out = torch.flatten(out, 1) 239 | return features, out 240 | else: 241 | features = self.features(x) 242 | out = F.relu(features, inplace=True) 243 | out = F.adaptive_avg_pool2d(out, (1, 1)) 244 | out = torch.flatten(out, 1) 245 | return out 246 | 247 | def cam_layer(self): 248 | return self.features[-1] 249 | 250 | def switch_to_deploy(self): 251 | self.features.conv0 = fuse_conv_bn(self.features.conv0, self.features.norm0) 252 | del self.features.norm0 253 | 254 | def load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: 255 | # '.'s are no longer allowed in module names, but previous _DenseLayer 256 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 257 | # They are also in the checkpoints in model_urls. 
This pattern is used 258 | # to find such keys. 259 | pattern = re.compile( 260 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 261 | state_dict = load_state_dict_from_url(model_url, progress=progress) 262 | for key in list(state_dict.keys()): 263 | res = pattern.match(key) 264 | if res: 265 | new_key = res.group(1) + res.group(2) 266 | state_dict[new_key] = state_dict[key] 267 | del state_dict[key] 268 | load_weights_from_state_dict(model, state_dict) 269 | # model_dict = model.state_dict() 270 | # weight_dict = {} 271 | # for k, v in state_dict.items(): 272 | # if k in model_dict: 273 | # if np.shape(model_dict[k]) == np.shape(v): 274 | # weight_dict[k] = v 275 | # pretrained_dict = weight_dict 276 | # model_dict.update(pretrained_dict) 277 | # model.load_state_dict(model_dict) 278 | 279 | def _densenet( 280 | arch: str, 281 | growth_rate: int, 282 | block_config: Tuple[int, int, int, int], 283 | num_init_features: int, 284 | pretrained: bool, 285 | progress: bool, 286 | **kwargs: Any 287 | ) -> DenseNet: 288 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 289 | if pretrained: 290 | load_state_dict(model, model_urls[arch], progress) 291 | return model 292 | 293 | 294 | def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 295 | r"""Densenet-121 model from 296 | `"Densely Connected Convolutional Networks" `_. 297 | The required minimum input size of the model is 29x29. 298 | 299 | Args: 300 | pretrained (bool): If True, returns a model pre-trained on ImageNet 301 | progress (bool): If True, displays a progress bar of the download to stderr 302 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 303 | but slower. Default: *False*. See `"paper" `_. 304 | """ 305 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, 306 | **kwargs) 307 | 308 | 309 | def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 310 | r"""Densenet-161 model from 311 | `"Densely Connected Convolutional Networks" `_. 312 | The required minimum input size of the model is 29x29. 313 | 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | progress (bool): If True, displays a progress bar of the download to stderr 317 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 318 | but slower. Default: *False*. See `"paper" `_. 319 | """ 320 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, 321 | **kwargs) 322 | 323 | 324 | def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 325 | r"""Densenet-169 model from 326 | `"Densely Connected Convolutional Networks" `_. 327 | The required minimum input size of the model is 29x29. 328 | 329 | Args: 330 | pretrained (bool): If True, returns a model pre-trained on ImageNet 331 | progress (bool): If True, displays a progress bar of the download to stderr 332 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 333 | but slower. Default: *False*. See `"paper" `_. 334 | """ 335 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, 336 | **kwargs) 337 | 338 | 339 | def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 340 | r"""Densenet-201 model from 341 | `"Densely Connected Convolutional Networks" `_. 
342 | The required minimum input size of the model is 29x29. 343 | 344 | Args: 345 | pretrained (bool): If True, returns a model pre-trained on ImageNet 346 | progress (bool): If True, displays a progress bar of the download to stderr 347 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 348 | but slower. Default: *False*. See `"paper" `_. 349 | """ 350 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, 351 | **kwargs) 352 | 353 | if __name__ == '__main__': 354 | inputs = torch.rand((1, 3, 224, 224)) 355 | model = densenet121(pretrained=True) 356 | model.eval() 357 | out = model(inputs) 358 | print('out shape:{}'.format(out.size())) 359 | feas, fea_fc, out = model(inputs, True) 360 | for idx, fea in enumerate(feas): 361 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 362 | print('fc shape:{}'.format(fea_fc.size())) 363 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/dpn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | from typing import Tuple 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from timm.models.layers import BatchNormAct2d, ConvNormAct, create_conv2d 11 | from utils.utils import load_weights_from_state_dict 12 | 13 | urls_dict = { 14 | 'dpn68': 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth', 15 | 'dpn68b': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth', 16 | 'dpn92': 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth', 17 | 'dpn98': 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth', 18 | 'dpn131': 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth', 19 | 'dpn107': 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth' 20 | } 21 | 22 | __all__ = list(urls_dict.keys()) 23 | 24 | class CatBnAct(nn.Module): 25 | def __init__(self, in_chs, norm_layer=BatchNormAct2d): 26 | super(CatBnAct, self).__init__() 27 | self.bn = norm_layer(in_chs, eps=0.001) 28 | 29 | @torch.jit._overload_method # noqa: F811 30 | def forward(self, x): 31 | # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor) 32 | pass 33 | 34 | @torch.jit._overload_method # noqa: F811 35 | def forward(self, x): 36 | # type: (torch.Tensor) -> (torch.Tensor) 37 | pass 38 | 39 | def forward(self, x): 40 | if isinstance(x, tuple): 41 | x = torch.cat(x, dim=1) 42 | return self.bn(x) 43 | 44 | 45 | class BnActConv2d(nn.Module): 46 | def __init__(self, in_chs, out_chs, kernel_size, stride, groups=1, norm_layer=BatchNormAct2d): 47 | super(BnActConv2d, self).__init__() 48 | self.bn = norm_layer(in_chs, eps=0.001) 49 | self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, groups=groups) 50 | 51 | def forward(self, x): 52 | return self.conv(self.bn(x)) 53 | 54 | class DualPathBlock(nn.Module): 55 | def __init__( 56 | self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False): 57 | super(DualPathBlock, self).__init__() 58 | self.num_1x1_c = num_1x1_c 59 | self.inc = inc 60 | self.b = b 61 | if block_type == 'proj': 62 | self.key_stride = 1 63 | 
self.has_proj = True 64 | elif block_type == 'down': 65 | self.key_stride = 2 66 | self.has_proj = True 67 | else: 68 | assert block_type == 'normal' 69 | self.key_stride = 1 70 | self.has_proj = False 71 | 72 | self.c1x1_w_s1 = None 73 | self.c1x1_w_s2 = None 74 | if self.has_proj: 75 | # Using different member names here to allow easier parameter key matching for conversion 76 | if self.key_stride == 2: 77 | self.c1x1_w_s2 = BnActConv2d( 78 | in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2) 79 | else: 80 | self.c1x1_w_s1 = BnActConv2d( 81 | in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1) 82 | 83 | self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1) 84 | self.c3x3_b = BnActConv2d( 85 | in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, stride=self.key_stride, groups=groups) 86 | if b: 87 | self.c1x1_c = CatBnAct(in_chs=num_3x3_b) 88 | self.c1x1_c1 = create_conv2d(num_3x3_b, num_1x1_c, kernel_size=1) 89 | self.c1x1_c2 = create_conv2d(num_3x3_b, inc, kernel_size=1) 90 | else: 91 | self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1) 92 | self.c1x1_c1 = None 93 | self.c1x1_c2 = None 94 | 95 | @torch.jit._overload_method # noqa: F811 96 | def forward(self, x): 97 | # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] 98 | pass 99 | 100 | @torch.jit._overload_method # noqa: F811 101 | def forward(self, x): 102 | # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 103 | pass 104 | 105 | def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: 106 | if isinstance(x, tuple): 107 | x_in = torch.cat(x, dim=1) 108 | else: 109 | x_in = x 110 | if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None: 111 | # self.has_proj == False, torchscript requires condition on module == None 112 | x_s1 = x[0] 113 | x_s2 = x[1] 114 | else: 115 | # self.has_proj == True 116 | if self.c1x1_w_s1 is not None: 117 | # self.key_stride = 1 118 | x_s = self.c1x1_w_s1(x_in) 119 | else: 120 | # self.key_stride = 2 121 | x_s = self.c1x1_w_s2(x_in) 122 | x_s1 = x_s[:, :self.num_1x1_c, :, :] 123 | x_s2 = x_s[:, self.num_1x1_c:, :, :] 124 | x_in = self.c1x1_a(x_in) 125 | x_in = self.c3x3_b(x_in) 126 | x_in = self.c1x1_c(x_in) 127 | if self.c1x1_c1 is not None: 128 | # self.b == True, using None check for torchscript compat 129 | out1 = self.c1x1_c1(x_in) 130 | out2 = self.c1x1_c2(x_in) 131 | else: 132 | out1 = x_in[:, :self.num_1x1_c, :, :] 133 | out2 = x_in[:, self.num_1x1_c:, :, :] 134 | resid = x_s1 + out1 135 | dense = torch.cat([x_s2, out2], dim=1) 136 | return resid, dense 137 | 138 | 139 | class DPN(nn.Module): 140 | def __init__( 141 | self, small=False, num_init_features=64, k_r=96, groups=32, global_pool='avg', 142 | b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32, 143 | num_classes=1000, in_chans=3, drop_rate=0., fc_act_layer=nn.ELU): 144 | super(DPN, self).__init__() 145 | self.num_classes = num_classes 146 | self.drop_rate = drop_rate 147 | self.b = b 148 | assert output_stride == 32 # FIXME look into dilation support 149 | norm_layer = partial(BatchNormAct2d, eps=.001) 150 | fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act_layer, inplace=False) 151 | bw_factor = 1 if small else 4 152 | blocks = OrderedDict() 153 | 154 | # conv1 155 | blocks['conv1_1'] = ConvNormAct( 156 | in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer) 157 | blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, 
stride=2, padding=1) 158 | self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')] 159 | 160 | # conv2 161 | bw = 64 * bw_factor 162 | inc = inc_sec[0] 163 | r = (k_r * bw) // (64 * bw_factor) 164 | blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b) 165 | in_chs = bw + 3 * inc 166 | for i in range(2, k_sec[0] + 1): 167 | blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) 168 | in_chs += inc 169 | self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')] 170 | 171 | # conv3 172 | bw = 128 * bw_factor 173 | inc = inc_sec[1] 174 | r = (k_r * bw) // (64 * bw_factor) 175 | blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) 176 | in_chs = bw + 3 * inc 177 | for i in range(2, k_sec[1] + 1): 178 | blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) 179 | in_chs += inc 180 | self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')] 181 | 182 | # conv4 183 | bw = 256 * bw_factor 184 | inc = inc_sec[2] 185 | r = (k_r * bw) // (64 * bw_factor) 186 | blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) 187 | in_chs = bw + 3 * inc 188 | for i in range(2, k_sec[2] + 1): 189 | blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) 190 | in_chs += inc 191 | self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')] 192 | 193 | # conv5 194 | bw = 512 * bw_factor 195 | inc = inc_sec[3] 196 | r = (k_r * bw) // (64 * bw_factor) 197 | blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) 198 | in_chs = bw + 3 * inc 199 | for i in range(2, k_sec[3] + 1): 200 | blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) 201 | in_chs += inc 202 | self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')] 203 | 204 | blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer) 205 | 206 | self.num_features = in_chs 207 | self.features = nn.Sequential(blocks) 208 | 209 | # Using 1x1 conv for the FC layer to allow the extra pooling scheme 210 | self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) 211 | self.classifier = nn.Linear(in_features=self.num_features, out_features=num_classes) 212 | 213 | def forward_features(self, x, need_fea=False): 214 | if need_fea: 215 | input_size = x.size(2) 216 | scale = [4, 8, 16, 32] 217 | features = [None, None, None, None] 218 | for idx, layer in enumerate(self.features): 219 | x = layer(x) 220 | temp_x = torch.cat(x, dim=1) if type(x) is tuple else x 221 | if input_size // temp_x.size(2) in scale: 222 | features[scale.index(input_size // temp_x.size(2))] = temp_x 223 | out = self.global_pool(x) 224 | return features, out.flatten(1) 225 | else: 226 | x = self.features(x) 227 | x = self.global_pool(x) 228 | return x.flatten(1) 229 | 230 | def forward_head(self, x): 231 | return self.classifier(x) 232 | 233 | def forward(self, x, need_fea=False): 234 | if need_fea: 235 | features, features_fc = self.forward_features(x, need_fea=need_fea) 236 | x = self.forward_head(features_fc) 237 | return features, features_fc, x 238 | else: 239 | x = self.forward_features(x) 240 | x = self.forward_head(x) 241 | return x 242 | 243 | def cam_layer(self): 244 | return self.features[-1] 245 | 246 | 247 | def _create_dpn(variant, pretrained=False, **kwargs): 248 | model = DPN(**kwargs) 249 | 
if pretrained: 250 | state_dict = torch.hub.load_state_dict_from_url(urls_dict[variant], progress=True, check_hash=True) 251 | model = load_weights_from_state_dict(model, state_dict) 252 | return model 253 | 254 | 255 | def dpn68(pretrained=False, **kwargs): 256 | model_kwargs = dict( 257 | small=True, num_init_features=10, k_r=128, groups=32, 258 | k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) 259 | return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs) 260 | 261 | 262 | def dpn68b(pretrained=False, **kwargs): 263 | model_kwargs = dict( 264 | small=True, num_init_features=10, k_r=128, groups=32, 265 | b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) 266 | return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs) 267 | 268 | 269 | def dpn92(pretrained=False, **kwargs): 270 | model_kwargs = dict( 271 | num_init_features=64, k_r=96, groups=32, 272 | k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs) 273 | return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs) 274 | 275 | 276 | def dpn98(pretrained=False, **kwargs): 277 | model_kwargs = dict( 278 | num_init_features=96, k_r=160, groups=40, 279 | k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs) 280 | return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs) 281 | 282 | 283 | def dpn131(pretrained=False, **kwargs): 284 | model_kwargs = dict( 285 | num_init_features=128, k_r=160, groups=40, 286 | k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs) 287 | return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs) 288 | 289 | 290 | def dpn107(pretrained=False, **kwargs): 291 | model_kwargs = dict( 292 | num_init_features=128, k_r=200, groups=50, 293 | k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs) 294 | return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs) 295 | 296 | if __name__ == '__main__': 297 | inputs = torch.rand((1, 3, 224, 224)) 298 | model = dpn68(pretrained=False) 299 | model.eval() 300 | out = model(inputs) 301 | print('out shape:{}'.format(out.size())) 302 | feas, fea_fc, out = model(inputs, True) 303 | for idx, fea in enumerate(feas): 304 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 305 | print('fc shape:{}'.format(fea_fc.size())) 306 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/ghostnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | from torch.hub import load_state_dict_from_url 6 | from utils.utils import load_weights_from_state_dict, fuse_conv_bn 7 | 8 | __all__ = ['ghostnet'] 9 | 10 | def _make_divisible(v, divisor, min_value=None): 11 | """ 12 | This function is taken from the original tf repo. 13 | It ensures that all layers have a channel number that is divisible by 8 14 | It can be seen here: 15 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 16 | """ 17 | if min_value is None: 18 | min_value = divisor 19 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 20 | # Make sure that round down does not go down by more than 10%. 
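    # Worked example (added comment, not in the original source): with divisor=4,
    # _make_divisible(30, 4) rounds 30 -> int(30 + 2) // 4 * 4 = 32 and returns it
    # directly, while v=10, divisor=8 rounds down to 8, which is below 0.9 * 10,
    # so the guard below bumps the result up to 16.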
21 | if new_v < 0.9 * v: 22 | new_v += divisor 23 | return new_v 24 | 25 | 26 | class SELayer(nn.Module): 27 | def __init__(self, channel, reduction=4): 28 | super(SELayer, self).__init__() 29 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 30 | self.fc = nn.Sequential( 31 | nn.Linear(channel, channel // reduction), 32 | nn.ReLU(inplace=True), 33 | nn.Linear(channel // reduction, channel), ) 34 | 35 | def forward(self, x): 36 | b, c, _, _ = x.size() 37 | y = self.avg_pool(x).view(b, c) 38 | y = self.fc(y).view(b, c, 1, 1) 39 | y = torch.clamp(y, 0, 1) 40 | return x * y 41 | 42 | 43 | def depthwise_conv(inp, oup, kernel_size=3, stride=1, relu=False): 44 | return nn.Sequential( 45 | nn.Conv2d(inp, oup, kernel_size, stride, kernel_size//2, groups=inp, bias=False), 46 | nn.BatchNorm2d(oup), 47 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 48 | ) 49 | 50 | class GhostModule(nn.Module): 51 | def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): 52 | super(GhostModule, self).__init__() 53 | self.oup = oup 54 | init_channels = math.ceil(oup / ratio) 55 | new_channels = init_channels*(ratio-1) 56 | 57 | self.primary_conv = nn.Sequential( 58 | nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), 59 | nn.BatchNorm2d(init_channels), 60 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 61 | ) 62 | 63 | self.cheap_operation = nn.Sequential( 64 | nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), 65 | nn.BatchNorm2d(new_channels), 66 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 67 | ) 68 | 69 | def forward(self, x): 70 | x1 = self.primary_conv(x) 71 | x2 = self.cheap_operation(x1) 72 | out = torch.cat([x1,x2], dim=1) 73 | return out[:,:self.oup,:,:] 74 | 75 | def switch_to_deploy(self): 76 | self.primary_conv = nn.Sequential( 77 | fuse_conv_bn(self.primary_conv[0], self.primary_conv[1]), 78 | self.primary_conv[2] 79 | ) 80 | self.cheap_operation = nn.Sequential( 81 | fuse_conv_bn(self.cheap_operation[0], self.cheap_operation[1]), 82 | self.cheap_operation[2] 83 | ) 84 | 85 | class GhostBottleneck(nn.Module): 86 | def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se): 87 | super(GhostBottleneck, self).__init__() 88 | assert stride in [1, 2] 89 | 90 | self.conv = nn.Sequential( 91 | # pw 92 | GhostModule(inp, hidden_dim, kernel_size=1, relu=True), 93 | # dw 94 | depthwise_conv(hidden_dim, hidden_dim, kernel_size, stride, relu=False) if stride==2 else nn.Sequential(), 95 | # Squeeze-and-Excite 96 | SELayer(hidden_dim) if use_se else nn.Sequential(), 97 | # pw-linear 98 | GhostModule(hidden_dim, oup, kernel_size=1, relu=False), 99 | ) 100 | 101 | if stride == 1 and inp == oup: 102 | self.shortcut = nn.Sequential() 103 | else: 104 | self.shortcut = nn.Sequential( 105 | depthwise_conv(inp, inp, 3, stride, relu=True), 106 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 107 | nn.BatchNorm2d(oup), 108 | ) 109 | 110 | def forward(self, x): 111 | return self.conv(x) + self.shortcut(x) 112 | 113 | def switch_to_deploy(self): 114 | if len(self.conv[1]) > 0: 115 | self.conv = nn.Sequential( 116 | self.conv[0], 117 | fuse_conv_bn(self.conv[1][0], self.conv[1][1]), 118 | self.conv[1][2], 119 | self.conv[2], 120 | self.conv[3], 121 | ) 122 | else: 123 | self.conv = nn.Sequential( 124 | self.conv[0], 125 | self.conv[2], 126 | self.conv[3], 127 | ) 128 | if len(self.shortcut) != 0: 129 | self.shortcut = nn.Sequential( 130 | fuse_conv_bn(self.shortcut[0][0], self.shortcut[0][1]), 131 | 
self.shortcut[0][2], 132 | fuse_conv_bn(self.shortcut[1], self.shortcut[2]) 133 | ) 134 | 135 | class GhostNet(nn.Module): 136 | def __init__(self, cfgs, num_classes=1000, width_mult=1.): 137 | super(GhostNet, self).__init__() 138 | # setting of inverted residual blocks 139 | self.cfgs = cfgs 140 | 141 | # building first layer 142 | output_channel = _make_divisible(16 * width_mult, 4) 143 | layers = [nn.Sequential( 144 | nn.Conv2d(3, output_channel, 3, 2, 1, bias=False), 145 | nn.BatchNorm2d(output_channel), 146 | nn.ReLU(inplace=True) 147 | )] 148 | input_channel = output_channel 149 | 150 | # building inverted residual blocks 151 | block = GhostBottleneck 152 | for k, exp_size, c, use_se, s in self.cfgs: 153 | output_channel = _make_divisible(c * width_mult, 4) 154 | hidden_channel = _make_divisible(exp_size * width_mult, 4) 155 | layers.append(block(input_channel, hidden_channel, output_channel, k, s, use_se)) 156 | input_channel = output_channel 157 | self.features = nn.Sequential(*layers) 158 | 159 | # building last several layers 160 | output_channel = _make_divisible(exp_size * width_mult, 4) 161 | self.squeeze = nn.Sequential( 162 | nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False), 163 | nn.BatchNorm2d(output_channel), 164 | nn.ReLU(inplace=True), 165 | nn.AdaptiveAvgPool2d((1, 1)), 166 | ) 167 | input_channel = output_channel 168 | 169 | output_channel = 1280 170 | self.classifier = nn.Sequential( 171 | nn.Linear(input_channel, output_channel, bias=False), 172 | nn.BatchNorm1d(output_channel), 173 | nn.ReLU(inplace=True), 174 | nn.Dropout(0.2), 175 | nn.Linear(output_channel, num_classes), 176 | ) 177 | 178 | self._initialize_weights() 179 | 180 | def forward(self, x, need_fea=False): 181 | if need_fea: 182 | features, features_fc = self.forward_features(x, need_fea) 183 | x = self.classifier(features_fc) 184 | return features, features_fc, x 185 | else: 186 | x = self.forward_features(x) 187 | x = self.classifier(x) 188 | return x 189 | 190 | def forward_features(self, x, need_fea=False): 191 | if need_fea: 192 | input_size = x.size(2) 193 | scale = [4, 8, 16, 32] 194 | features = [None, None, None, None] 195 | for idx, layer in enumerate(self.features): 196 | x = layer(x) 197 | if input_size // x.size(2) in scale: 198 | features[scale.index(input_size // x.size(2))] = x 199 | x = self.squeeze(x) 200 | return features, x.view(x.size(0), -1) 201 | else: 202 | x = self.features(x) 203 | x = self.squeeze(x) 204 | return x.view(x.size(0), -1) 205 | 206 | def _initialize_weights(self): 207 | for m in self.modules(): 208 | if isinstance(m, nn.Conv2d): 209 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 210 | elif isinstance(m, nn.BatchNorm2d): 211 | m.weight.data.fill_(1) 212 | m.bias.data.zero_() 213 | 214 | def cam_layer(self): 215 | return self.features[-1] 216 | 217 | def switch_to_deploy(self): 218 | self.features[0] = nn.Sequential( 219 | fuse_conv_bn(self.features[0][0], self.features[0][1]), 220 | self.features[0][2] 221 | ) 222 | self.squeeze = nn.Sequential( 223 | fuse_conv_bn(self.squeeze[0], self.squeeze[1]), 224 | self.squeeze[2], 225 | self.squeeze[3] 226 | ) 227 | 228 | def ghostnet(pretrained=False, **kwargs): 229 | """ 230 | Constructs a GhostNet model 231 | """ 232 | cfgs = [ 233 | # k, t, c, SE, s 234 | [3, 16, 16, 0, 1], 235 | [3, 48, 24, 0, 2], 236 | [3, 72, 24, 0, 1], 237 | [5, 72, 40, 1, 2], 238 | [5, 120, 40, 1, 1], 239 | [3, 240, 80, 0, 2], 240 | [3, 200, 80, 0, 1], 241 | [3, 184, 80, 0, 1], 242 | [3, 184, 80, 0, 1], 243 | 
[3, 480, 112, 1, 1], 244 | [3, 672, 112, 1, 1], 245 | [5, 672, 160, 1, 2], 246 | [5, 960, 160, 0, 1], 247 | [5, 960, 160, 1, 1], 248 | [5, 960, 160, 0, 1], 249 | [5, 960, 160, 1, 1] 250 | ] 251 | model = GhostNet(cfgs, **kwargs) 252 | if pretrained: 253 | state_dict = load_state_dict_from_url('https://github.com/z1069614715/pretrained-weights/releases/download/ghostnet_1x_v1.0/ghostnet_1x-f97d70db.pth', 254 | progress=True) 255 | model = load_weights_from_state_dict(model, state_dict) 256 | return model 257 | 258 | if __name__ == '__main__': 259 | inputs = torch.rand((1, 3, 224, 224)) 260 | model = ghostnet(pretrained=True) 261 | model.eval() 262 | out = model(inputs) 263 | print('out shape:{}'.format(out.size())) 264 | feas, fea_fc, out = model(inputs, True) 265 | for idx, fea in enumerate(feas): 266 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 267 | print('fc shape:{}'.format(fea_fc.size())) 268 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/mnasnet.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch import Tensor 5 | import torch.nn as nn 6 | import numpy as np 7 | from torchvision._internally_replaced_utils import load_state_dict_from_url 8 | from typing import Any, Dict, List 9 | from utils.utils import load_weights_from_state_dict, fuse_conv_bn 10 | 11 | __all__ = ['mnasnet1_0'] 12 | 13 | _MODEL_URLS = { 14 | "mnasnet0_5": 15 | "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", 16 | "mnasnet0_75": None, 17 | "mnasnet1_0": 18 | "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", 19 | "mnasnet1_3": None 20 | } 21 | 22 | # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is 23 | # 1.0 - tensorflow. 24 | _BN_MOMENTUM = 1 - 0.9997 25 | 26 | 27 | class _InvertedResidual(nn.Module): 28 | 29 | def __init__( 30 | self, 31 | in_ch: int, 32 | out_ch: int, 33 | kernel_size: int, 34 | stride: int, 35 | expansion_factor: int, 36 | bn_momentum: float = 0.1 37 | ) -> None: 38 | super(_InvertedResidual, self).__init__() 39 | assert stride in [1, 2] 40 | assert kernel_size in [3, 5] 41 | mid_ch = in_ch * expansion_factor 42 | self.apply_residual = (in_ch == out_ch and stride == 1) 43 | self.layers = nn.Sequential( 44 | # Pointwise 45 | nn.Conv2d(in_ch, mid_ch, 1, bias=False), 46 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum), 47 | nn.ReLU(inplace=True), 48 | # Depthwise 49 | nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, 50 | stride=stride, groups=mid_ch, bias=False), 51 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum), 52 | nn.ReLU(inplace=True), 53 | # Linear pointwise. Note that there's no activation. 
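            # (Added note: the projection is deliberately kept linear -- a ReLU at
            # this point would discard information in the narrow bottleneck output,
            # following the linear-bottleneck argument of the MobileNetV2 paper.)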
54 | nn.Conv2d(mid_ch, out_ch, 1, bias=False), 55 | nn.BatchNorm2d(out_ch, momentum=bn_momentum)) 56 | 57 | def forward(self, input: Tensor) -> Tensor: 58 | if self.apply_residual: 59 | return self.layers(input) + input 60 | else: 61 | return self.layers(input) 62 | 63 | def switch_to_deploy(self): 64 | self.layers = nn.Sequential( 65 | fuse_conv_bn(self.layers[0], self.layers[1]), 66 | self.layers[2], 67 | fuse_conv_bn(self.layers[3], self.layers[4]), 68 | self.layers[5], 69 | fuse_conv_bn(self.layers[6], self.layers[7]) 70 | ) 71 | 72 | 73 | def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, 74 | bn_momentum: float) -> nn.Sequential: 75 | """ Creates a stack of inverted residuals. """ 76 | assert repeats >= 1 77 | # First one has no skip, because feature map size changes. 78 | first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, 79 | bn_momentum=bn_momentum) 80 | remaining = [] 81 | for _ in range(1, repeats): 82 | remaining.append( 83 | _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, 84 | bn_momentum=bn_momentum)) 85 | return nn.Sequential(first, *remaining) 86 | 87 | 88 | def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int: 89 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default 90 | bias, will round up, unless the number is no more than 10% greater than the 91 | smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ 92 | assert 0.0 < round_up_bias < 1.0 93 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) 94 | return new_val if new_val >= round_up_bias * val else new_val + divisor 95 | 96 | 97 | def _get_depths(alpha: float) -> List[int]: 98 | """ Scales tensor depths as in reference MobileNet code, prefers rouding up 99 | rather than down. """ 100 | depths = [32, 16, 24, 40, 80, 96, 192, 320] 101 | return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] 102 | 103 | 104 | class MNASNet(torch.nn.Module): 105 | """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This 106 | implements the B1 variant of the model. 107 | >>> model = MNASNet(1.0, num_classes=1000) 108 | >>> x = torch.rand(1, 3, 224, 224) 109 | >>> y = model(x) 110 | >>> y.dim() 111 | 2 112 | >>> y.nelement() 113 | 1000 114 | """ 115 | # Version 2 adds depth scaling in the initial stages of the network. 116 | _version = 2 117 | 118 | def __init__( 119 | self, 120 | alpha: float, 121 | num_classes: int = 1000, 122 | dropout: float = 0.2 123 | ) -> None: 124 | super(MNASNet, self).__init__() 125 | assert alpha > 0.0 126 | self.alpha = alpha 127 | self.num_classes = num_classes 128 | depths = _get_depths(alpha) 129 | layers = [ 130 | # First layer: regular conv. 131 | nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False), 132 | nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), 133 | nn.ReLU(inplace=True), 134 | # Depthwise separable, no skip. 135 | nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, 136 | groups=depths[0], bias=False), 137 | nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), 138 | nn.ReLU(inplace=True), 139 | nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False), 140 | nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM), 141 | # MNASNet blocks: stacks of inverted residuals. 
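            # (Added note: each _stack(...) below reads as (in_ch, out_ch, kernel_size,
            # stride, expansion_factor, repeats, bn_momentum); only the first block of
            # a stack downsamples, the remaining repeats keep stride 1 and out_ch channels.)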
142 | _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM), 143 | _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM), 144 | _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM), 145 | _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM), 146 | _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM), 147 | _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM), 148 | # Final mapping to classifier input. 149 | nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False), 150 | nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), 151 | nn.ReLU(inplace=True), 152 | ] 153 | self.layers = nn.Sequential(*layers) 154 | self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), 155 | nn.Linear(1280, num_classes)) 156 | self._initialize_weights() 157 | 158 | def switch_to_deploy(self): 159 | self.layers = nn.Sequential( 160 | fuse_conv_bn(self.layers[0], self.layers[1]), 161 | self.layers[2], 162 | fuse_conv_bn(self.layers[3], self.layers[4]), 163 | self.layers[5], 164 | fuse_conv_bn(self.layers[6], self.layers[7]), 165 | self.layers[8:14], 166 | fuse_conv_bn(self.layers[14], self.layers[15]), 167 | self.layers[16] 168 | ) 169 | 170 | def forward(self, x: Tensor, need_fea=False) -> Tensor: 171 | if need_fea: 172 | features, features_fc = self.forward_features(x, need_fea) 173 | # Equivalent to global avgpool and removing H and W dimensions. 174 | x = self.classifier(features_fc) 175 | return features, features_fc, x 176 | else: 177 | x = self.forward_features(x) 178 | # Equivalent to global avgpool and removing H and W dimensions. 179 | x = self.classifier(x) 180 | return x 181 | 182 | def forward_features(self, x, need_fea=False): 183 | if need_fea: 184 | input_size = x.size(2) 185 | scale = [4, 8, 16, 32] 186 | features = [None, None, None, None] 187 | for idx, layer in enumerate(self.layers): 188 | x = layer(x) 189 | if input_size // x.size(2) in scale: 190 | features[scale.index(input_size // x.size(2))] = x 191 | return features, x.mean([2, 3]) 192 | else: 193 | x = self.layers(x) 194 | # Equivalent to global avgpool and removing H and W dimensions. 195 | x = x.mean([2, 3]) 196 | return x 197 | 198 | def _initialize_weights(self) -> None: 199 | for m in self.modules(): 200 | if isinstance(m, nn.Conv2d): 201 | nn.init.kaiming_normal_(m.weight, mode="fan_out", 202 | nonlinearity="relu") 203 | if m.bias is not None: 204 | nn.init.zeros_(m.bias) 205 | elif isinstance(m, nn.BatchNorm2d): 206 | nn.init.ones_(m.weight) 207 | nn.init.zeros_(m.bias) 208 | elif isinstance(m, nn.Linear): 209 | nn.init.kaiming_uniform_(m.weight, mode="fan_out", 210 | nonlinearity="sigmoid") 211 | nn.init.zeros_(m.bias) 212 | 213 | def cam_layer(self): 214 | return self.layers[-1] 215 | 216 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool, 217 | missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None: 218 | version = local_metadata.get("version", None) 219 | assert version in [1, 2] 220 | 221 | if version == 1 and not self.alpha == 1.0: 222 | # In the initial version of the model (v1), stem was fixed-size. 223 | # All other layer configurations were the same. This will patch 224 | # the model so that it's identical to v1. Model with alpha 1.0 is 225 | # unaffected. 
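            # (Added note: the v1_stem built below recreates the fixed-width 32/16-channel
            # stem used by the original v1 checkpoints and swaps it into self.layers[0:9]
            # before delegating to the parent loader, so old state_dicts keep matching
            # module names.)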
226 | depths = _get_depths(self.alpha) 227 | v1_stem = [ 228 | nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), 229 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 230 | nn.ReLU(inplace=True), 231 | nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, 232 | bias=False), 233 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 234 | nn.ReLU(inplace=True), 235 | nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), 236 | nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), 237 | _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM), 238 | ] 239 | for idx, layer in enumerate(v1_stem): 240 | self.layers[idx] = layer 241 | 242 | # The model is now identical to v1, and must be saved as such. 243 | self._version = 1 244 | warnings.warn( 245 | "A new version of MNASNet model has been implemented. " 246 | "Your checkpoint was saved using the previous version. " 247 | "This checkpoint will load and work as before, but " 248 | "you may want to upgrade by training a newer model or " 249 | "transfer learning from an updated ImageNet checkpoint.", 250 | UserWarning) 251 | 252 | super(MNASNet, self)._load_from_state_dict( 253 | state_dict, prefix, local_metadata, strict, missing_keys, 254 | unexpected_keys, error_msgs) 255 | 256 | 257 | def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None: 258 | if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: 259 | raise ValueError( 260 | "No checkpoint is available for model type {}".format(model_name)) 261 | checkpoint_url = _MODEL_URLS[model_name] 262 | state_dict = load_state_dict_from_url(checkpoint_url, progress=progress) 263 | model = load_weights_from_state_dict(model, state_dict) 264 | 265 | def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: 266 | r"""MNASNet with depth multiplier of 0.5 from 267 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 268 | `_. 269 | 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | model = MNASNet(0.5, **kwargs) 275 | if pretrained: 276 | _load_pretrained("mnasnet0_5", model, progress) 277 | return model 278 | 279 | 280 | def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: 281 | r"""MNASNet with depth multiplier of 0.75 from 282 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 283 | `_. 284 | 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | progress (bool): If True, displays a progress bar of the download to stderr 288 | """ 289 | model = MNASNet(0.75, **kwargs) 290 | if pretrained: 291 | _load_pretrained("mnasnet0_75", model, progress) 292 | return model 293 | 294 | 295 | def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: 296 | r"""MNASNet with depth multiplier of 1.0 from 297 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 298 | `_. 
299 | 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | model = MNASNet(1.0, **kwargs) 305 | if pretrained: 306 | _load_pretrained("mnasnet1_0", model, progress) 307 | return model 308 | 309 | 310 | def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: 311 | r"""MNASNet with depth multiplier of 1.3 from 312 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 313 | `_. 314 | 315 | Args: 316 | pretrained (bool): If True, returns a model pre-trained on ImageNet 317 | progress (bool): If True, displays a progress bar of the download to stderr 318 | """ 319 | model = MNASNet(1.3, **kwargs) 320 | if pretrained: 321 | _load_pretrained("mnasnet1_3", model, progress) 322 | return model 323 | 324 | if __name__ == '__main__': 325 | inputs = torch.rand((1, 3, 224, 224)) 326 | model = mnasnet0_5(pretrained=True) 327 | model.eval() 328 | out = model(inputs) 329 | print('out shape:{}'.format(out.size())) 330 | feas, fea_fc, out = model(inputs, True) 331 | for idx, fea in enumerate(feas): 332 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 333 | print('fc shape:{}'.format(fea_fc.size())) 334 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | import numpy as np 4 | from torch import nn 5 | from torch import Tensor 6 | from torchvision.ops.misc import ConvNormActivation 7 | from torchvision._internally_replaced_utils import load_state_dict_from_url 8 | from torchvision.models._utils import _make_divisible 9 | from typing import Callable, Any, Optional, List 10 | from utils.utils import load_weights_from_state_dict, fuse_conv_bn 11 | 12 | __all__ = ['mobilenet_v2'] 13 | 14 | 15 | model_urls = { 16 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 17 | } 18 | 19 | 20 | # necessary for backwards compatibility 21 | class _DeprecatedConvBNAct(ConvNormActivation): 22 | def __init__(self, *args, **kwargs): 23 | warnings.warn( 24 | "The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. 
" 25 | "Use torchvision.ops.misc.ConvNormActivation instead.", FutureWarning) 26 | if kwargs.get("norm_layer", None) is None: 27 | kwargs["norm_layer"] = nn.BatchNorm2d 28 | if kwargs.get("activation_layer", None) is None: 29 | kwargs["activation_layer"] = nn.ReLU6 30 | super().__init__(*args, **kwargs) 31 | 32 | 33 | ConvBNReLU = _DeprecatedConvBNAct 34 | ConvBNActivation = _DeprecatedConvBNAct 35 | 36 | 37 | class InvertedResidual(nn.Module): 38 | def __init__( 39 | self, 40 | inp: int, 41 | oup: int, 42 | stride: int, 43 | expand_ratio: int, 44 | norm_layer: Optional[Callable[..., nn.Module]] = None 45 | ) -> None: 46 | super(InvertedResidual, self).__init__() 47 | self.stride = stride 48 | assert stride in [1, 2] 49 | 50 | if norm_layer is None: 51 | norm_layer = nn.BatchNorm2d 52 | 53 | hidden_dim = int(round(inp * expand_ratio)) 54 | self.use_res_connect = self.stride == 1 and inp == oup 55 | 56 | layers: List[nn.Module] = [] 57 | if expand_ratio != 1: 58 | # pw 59 | layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, 60 | activation_layer=nn.ReLU6)) 61 | layers.extend([ 62 | # dw 63 | ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer, 64 | activation_layer=nn.ReLU6), 65 | # pw-linear 66 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 67 | norm_layer(oup), 68 | ]) 69 | self.conv = nn.Sequential(*layers) 70 | self.out_channels = oup 71 | self._is_cn = stride > 1 72 | 73 | def forward(self, x: Tensor) -> Tensor: 74 | if self.use_res_connect: 75 | return x + self.conv(x) 76 | else: 77 | return self.conv(x) 78 | 79 | def switch_to_deploy(self): 80 | if len(self.conv) == 4: 81 | self.conv = nn.Sequential( 82 | fuse_conv_bn(self.conv[0][0], self.conv[0][1]), 83 | self.conv[0][2], 84 | fuse_conv_bn(self.conv[1][0], self.conv[1][1]), 85 | self.conv[1][2], 86 | fuse_conv_bn(self.conv[2], self.conv[3]), 87 | ) 88 | else: 89 | self.conv = nn.Sequential( 90 | fuse_conv_bn(self.conv[0][0], self.conv[0][1]), 91 | self.conv[0][2], 92 | fuse_conv_bn(self.conv[1], self.conv[2]), 93 | ) 94 | 95 | class MobileNetV2(nn.Module): 96 | def __init__( 97 | self, 98 | num_classes: int = 1000, 99 | width_mult: float = 1.0, 100 | inverted_residual_setting: Optional[List[List[int]]] = None, 101 | round_nearest: int = 8, 102 | block: Optional[Callable[..., nn.Module]] = None, 103 | norm_layer: Optional[Callable[..., nn.Module]] = None 104 | ) -> None: 105 | """ 106 | MobileNet V2 main class 107 | 108 | Args: 109 | num_classes (int): Number of classes 110 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 111 | inverted_residual_setting: Network structure 112 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 113 | Set to 1 to turn off rounding 114 | block: Module specifying inverted residual building block for mobilenet 115 | norm_layer: Module specifying the normalization layer to use 116 | 117 | """ 118 | super(MobileNetV2, self).__init__() 119 | 120 | if block is None: 121 | block = InvertedResidual 122 | 123 | if norm_layer is None: 124 | norm_layer = nn.BatchNorm2d 125 | 126 | input_channel = 32 127 | last_channel = 1280 128 | 129 | if inverted_residual_setting is None: 130 | inverted_residual_setting = [ 131 | # t, c, n, s 132 | [1, 16, 1, 1], 133 | [6, 24, 2, 2], 134 | [6, 32, 3, 2], 135 | [6, 64, 4, 2], 136 | [6, 96, 3, 1], 137 | [6, 160, 3, 2], 138 | [6, 320, 1, 1], 139 | ] 140 | 141 | # only check the first element, assuming 
user knows t,c,n,s are required 142 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 143 | raise ValueError("inverted_residual_setting should be non-empty " 144 | "or a 4-element list, got {}".format(inverted_residual_setting)) 145 | 146 | # building first layer 147 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 148 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 149 | features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, 150 | activation_layer=nn.ReLU6)] 151 | # building inverted residual blocks 152 | for t, c, n, s in inverted_residual_setting: 153 | output_channel = _make_divisible(c * width_mult, round_nearest) 154 | for i in range(n): 155 | stride = s if i == 0 else 1 156 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 157 | input_channel = output_channel 158 | # building last several layers 159 | features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, 160 | activation_layer=nn.ReLU6)) 161 | # make it nn.Sequential 162 | self.features = nn.Sequential(*features) 163 | 164 | # building classifier 165 | self.classifier = nn.Sequential( 166 | nn.Dropout(0.2), 167 | nn.Linear(self.last_channel, num_classes), 168 | ) 169 | 170 | # weight initialization 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d): 173 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 174 | if m.bias is not None: 175 | nn.init.zeros_(m.bias) 176 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 177 | nn.init.ones_(m.weight) 178 | nn.init.zeros_(m.bias) 179 | elif isinstance(m, nn.Linear): 180 | nn.init.normal_(m.weight, 0, 0.01) 181 | nn.init.zeros_(m.bias) 182 | 183 | def _forward_impl(self, x: Tensor, need_fea=False) -> Tensor: 184 | if need_fea: 185 | features, features_fc = self.forward_features(x, need_fea) 186 | x = self.classifier(features_fc) 187 | return features, features_fc, x 188 | else: 189 | x = self.forward_features(x) 190 | x = self.classifier(x) 191 | return x 192 | 193 | def forward(self, x: Tensor, need_fea=False) -> Tensor: 194 | return self._forward_impl(x, need_fea) 195 | 196 | def forward_features(self, x, need_fea=False): 197 | if need_fea: 198 | input_size = x.size(2) 199 | scale = [4, 8, 16, 32] 200 | features = [None, None, None, None] 201 | for idx, layer in enumerate(self.features): 202 | x = layer(x) 203 | if input_size // x.size(2) in scale: 204 | features[scale.index(input_size // x.size(2))] = x 205 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) 206 | x = torch.flatten(x, 1) 207 | return features, x 208 | else: 209 | x = self.features(x) 210 | # Cannot use "squeeze" as batch-size can be 1 211 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) 212 | x = torch.flatten(x, 1) 213 | return x 214 | 215 | def cam_layer(self): 216 | return self.features[-1] 217 | 218 | def switch_to_deploy(self): 219 | self.features[0] = nn.Sequential( 220 | fuse_conv_bn(self.features[0][0], self.features[0][1]), 221 | self.features[0][2] 222 | ) 223 | self.features[-1] = nn.Sequential( 224 | fuse_conv_bn(self.features[-1][0], self.features[-1][1]), 225 | self.features[-1][2] 226 | ) 227 | 228 | 229 | def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2: 230 | """ 231 | Constructs a MobileNetV2 architecture from 232 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
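    Example (illustrative sketch, not part of the original docstring):
        >>> net = mobilenet_v2(pretrained=False, width_mult=0.5, num_classes=10)
        >>> net(torch.rand(1, 3, 224, 224)).shape
        torch.Size([1, 10])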
233 | 234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | progress (bool): If True, displays a progress bar of the download to stderr 237 | """ 238 | model = MobileNetV2(**kwargs) 239 | if pretrained: 240 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 241 | progress=progress) 242 | model = load_weights_from_state_dict(model, state_dict) 243 | return model 244 | 245 | if __name__ == '__main__': 246 | inputs = torch.rand((1, 3, 224, 224)) 247 | model = mobilenet_v2(pretrained=True) 248 | model.eval() 249 | out = model(inputs) 250 | print('out shape:{}'.format(out.size())) 251 | feas, fea_fc, out = model(inputs, True) 252 | for idx, fea in enumerate(feas): 253 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 254 | print('fc shape:{}'.format(fea_fc.size())) 255 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | from torch import nn, Tensor 6 | from typing import Any, Callable, List, Optional, Sequence 7 | from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer 8 | from torchvision._internally_replaced_utils import load_state_dict_from_url 9 | from torchvision.models._utils import _make_divisible 10 | from utils.utils import load_weights_from_state_dict, fuse_conv_bn 11 | 12 | __all__ = ["mobilenetv3_large", "mobilenetv3_small"] 13 | 14 | 15 | model_urls = { 16 | "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", 17 | "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", 18 | } 19 | 20 | 21 | class SqueezeExcitation(SElayer): 22 | """DEPRECATED 23 | """ 24 | def __init__(self, input_channels: int, squeeze_factor: int = 4): 25 | squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) 26 | super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid) 27 | self.relu = self.activation 28 | delattr(self, 'activation') 29 | warnings.warn( 30 | "This SqueezeExcitation class is deprecated and will be removed in future versions. 
" 31 | "Use torchvision.ops.misc.SqueezeExcitation instead.", FutureWarning) 32 | 33 | 34 | class InvertedResidualConfig: 35 | # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper 36 | def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool, 37 | activation: str, stride: int, dilation: int, width_mult: float): 38 | self.input_channels = self.adjust_channels(input_channels, width_mult) 39 | self.kernel = kernel 40 | self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) 41 | self.out_channels = self.adjust_channels(out_channels, width_mult) 42 | self.use_se = use_se 43 | self.use_hs = activation == "HS" 44 | self.stride = stride 45 | self.dilation = dilation 46 | 47 | @staticmethod 48 | def adjust_channels(channels: int, width_mult: float): 49 | return _make_divisible(channels * width_mult, 8) 50 | 51 | 52 | class InvertedResidual(nn.Module): 53 | # Implemented as described at section 5 of MobileNetV3 paper 54 | def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module], 55 | se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)): 56 | super().__init__() 57 | if not (1 <= cnf.stride <= 2): 58 | raise ValueError('illegal stride value') 59 | 60 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels 61 | 62 | layers: List[nn.Module] = [] 63 | activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU 64 | 65 | # expand 66 | if cnf.expanded_channels != cnf.input_channels: 67 | layers.append(Conv2dNormActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, 68 | norm_layer=norm_layer, activation_layer=activation_layer)) 69 | 70 | # depthwise 71 | stride = 1 if cnf.dilation > 1 else cnf.stride 72 | layers.append(Conv2dNormActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, 73 | stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, 74 | norm_layer=norm_layer, activation_layer=activation_layer)) 75 | if cnf.use_se: 76 | squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8) 77 | layers.append(se_layer(cnf.expanded_channels, squeeze_channels)) 78 | 79 | # project 80 | layers.append(Conv2dNormActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, 81 | activation_layer=None)) 82 | 83 | self.block = nn.Sequential(*layers) 84 | self.out_channels = cnf.out_channels 85 | self._is_cn = cnf.stride > 1 86 | 87 | def forward(self, input: Tensor) -> Tensor: 88 | result = self.block(input) 89 | if self.use_res_connect: 90 | result += input 91 | return result 92 | 93 | def switch_to_deploy(self): 94 | new_layers = [] 95 | for i in range(len(self.block)): 96 | if type(self.block[i]) is Conv2dNormActivation: 97 | new_layers.append(fuse_conv_bn(self.block[i][0], self.block[i][1])) 98 | if len(self.block[i]) == 3: 99 | new_layers.append(self.block[i][2]) 100 | else: 101 | new_layers.append(self.block[i]) 102 | self.block = nn.Sequential(*new_layers) 103 | 104 | class MobileNetV3(nn.Module): 105 | 106 | def __init__( 107 | self, 108 | inverted_residual_setting: List[InvertedResidualConfig], 109 | last_channel: int, 110 | num_classes: int = 1000, 111 | block: Optional[Callable[..., nn.Module]] = None, 112 | norm_layer: Optional[Callable[..., nn.Module]] = None, 113 | **kwargs: Any 114 | ) -> None: 115 | """ 116 | MobileNet V3 main class 117 | 118 | Args: 119 | inverted_residual_setting (List[InvertedResidualConfig]): Network structure 120 | 
last_channel (int): The number of channels on the penultimate layer 121 | num_classes (int): Number of classes 122 | block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet 123 | norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use 124 | """ 125 | super().__init__() 126 | 127 | if not inverted_residual_setting: 128 | raise ValueError("The inverted_residual_setting should not be empty") 129 | elif not (isinstance(inverted_residual_setting, Sequence) and 130 | all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])): 131 | raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") 132 | 133 | if block is None: 134 | block = InvertedResidual 135 | 136 | if norm_layer is None: 137 | norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) 138 | 139 | layers: List[nn.Module] = [] 140 | 141 | # building first layer 142 | firstconv_output_channels = inverted_residual_setting[0].input_channels 143 | layers.append(Conv2dNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, 144 | activation_layer=nn.Hardswish)) 145 | 146 | # building inverted residual blocks 147 | for cnf in inverted_residual_setting: 148 | layers.append(block(cnf, norm_layer)) 149 | 150 | # building last several layers 151 | lastconv_input_channels = inverted_residual_setting[-1].out_channels 152 | lastconv_output_channels = 6 * lastconv_input_channels 153 | layers.append(Conv2dNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, 154 | norm_layer=norm_layer, activation_layer=nn.Hardswish)) 155 | 156 | self.features = nn.Sequential(*layers) 157 | self.avgpool = nn.AdaptiveAvgPool2d(1) 158 | self.classifier = nn.Sequential( 159 | nn.Linear(lastconv_output_channels, last_channel), 160 | nn.Hardswish(inplace=True), 161 | nn.Dropout(p=0.2, inplace=True), 162 | nn.Linear(last_channel, num_classes), 163 | ) 164 | 165 | for m in self.modules(): 166 | if isinstance(m, nn.Conv2d): 167 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 168 | if m.bias is not None: 169 | nn.init.zeros_(m.bias) 170 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 171 | nn.init.ones_(m.weight) 172 | nn.init.zeros_(m.bias) 173 | elif isinstance(m, nn.Linear): 174 | nn.init.normal_(m.weight, 0, 0.01) 175 | nn.init.zeros_(m.bias) 176 | 177 | def switch_to_deploy(self): 178 | new_layers = [] 179 | for i in range(len(self.features)): 180 | if type(self.features[i]) is Conv2dNormActivation: 181 | new_layers.append(fuse_conv_bn(self.features[i][0], self.features[i][1])) 182 | if len(self.features[i]) == 3: 183 | new_layers.append(self.features[i][2]) 184 | else: 185 | new_layers.append(self.features[i]) 186 | self.features = nn.Sequential(*new_layers) 187 | 188 | def _forward_impl(self, x: Tensor, need_fea=False) -> Tensor: 189 | if need_fea: 190 | features, features_fc = self.forward_features(x, need_fea) 191 | x = self.classifier(features_fc) 192 | return features, features_fc, x 193 | else: 194 | x = self.forward_features(x) 195 | x = self.classifier(x) 196 | return x 197 | 198 | def forward(self, x: Tensor, need_fea=False) -> Tensor: 199 | return self._forward_impl(x, need_fea) 200 | 201 | def forward_features(self, x, need_fea=False): 202 | if need_fea: 203 | input_size = x.size(2) 204 | scale = [4, 8, 16, 32] 205 | features = [None, None, None, None] 206 | for idx, layer in enumerate(self.features): 207 | x = layer(x) 208 | if input_size 
// x.size(2) in scale: 209 | features[scale.index(input_size // x.size(2))] = x 210 | x = self.avgpool(x) 211 | x = torch.flatten(x, 1) 212 | return features, x 213 | else: 214 | x = self.features(x) 215 | x = self.avgpool(x) 216 | x = torch.flatten(x, 1) 217 | return x 218 | 219 | def cam_layer(self): 220 | return self.features[-1] 221 | 222 | def _mobilenet_v3_conf(arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, 223 | **kwargs: Any): 224 | reduce_divider = 2 if reduced_tail else 1 225 | dilation = 2 if dilated else 1 226 | 227 | bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) 228 | adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) 229 | 230 | if arch == "mobilenet_v3_large": 231 | inverted_residual_setting = [ 232 | bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), 233 | bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1 234 | bneck_conf(24, 3, 72, 24, False, "RE", 1, 1), 235 | bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2 236 | bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), 237 | bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), 238 | bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3 239 | bneck_conf(80, 3, 200, 80, False, "HS", 1, 1), 240 | bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), 241 | bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), 242 | bneck_conf(80, 3, 480, 112, True, "HS", 1, 1), 243 | bneck_conf(112, 3, 672, 112, True, "HS", 1, 1), 244 | bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4 245 | bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), 246 | bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), 247 | ] 248 | last_channel = adjust_channels(1280 // reduce_divider) # C5 249 | elif arch == "mobilenet_v3_small": 250 | inverted_residual_setting = [ 251 | bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1 252 | bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2 253 | bneck_conf(24, 3, 88, 24, False, "RE", 1, 1), 254 | bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3 255 | bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), 256 | bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), 257 | bneck_conf(40, 5, 120, 48, True, "HS", 1, 1), 258 | bneck_conf(48, 5, 144, 48, True, "HS", 1, 1), 259 | bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4 260 | bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), 261 | bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), 262 | ] 263 | last_channel = adjust_channels(1024 // reduce_divider) # C5 264 | else: 265 | raise ValueError("Unsupported model type {}".format(arch)) 266 | 267 | return inverted_residual_setting, last_channel 268 | 269 | 270 | def _mobilenet_v3_model( 271 | arch: str, 272 | inverted_residual_setting: List[InvertedResidualConfig], 273 | last_channel: int, 274 | pretrained: bool, 275 | progress: bool, 276 | **kwargs: Any 277 | ): 278 | model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) 279 | if pretrained: 280 | if model_urls.get(arch, None) is None: 281 | raise ValueError("No checkpoint is available for model type {}".format(arch)) 282 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 283 | model = load_weights_from_state_dict(model, state_dict) 284 | return model 285 | 286 | 287 | def mobilenetv3_large(pretrained: bool = False, 
progress: bool = True, **kwargs: Any) -> MobileNetV3: 288 | """ 289 | Constructs a large MobileNetV3 architecture from 290 | `"Searching for MobileNetV3" `_. 291 | 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | progress (bool): If True, displays a progress bar of the download to stderr 295 | """ 296 | arch = "mobilenet_v3_large" 297 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) 298 | return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 299 | 300 | 301 | def mobilenetv3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 302 | """ 303 | Constructs a small MobileNetV3 architecture from 304 | `"Searching for MobileNetV3" `_. 305 | 306 | Args: 307 | pretrained (bool): If True, returns a model pre-trained on ImageNet 308 | progress (bool): If True, displays a progress bar of the download to stderr 309 | """ 310 | arch = "mobilenet_v3_small" 311 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) 312 | return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 313 | 314 | if __name__ == '__main__': 315 | inputs = torch.rand((1, 3, 224, 224)) 316 | model = mobilenetv3_small(pretrained=True) 317 | model.eval() 318 | out = model(inputs) 319 | print('out shape:{}'.format(out.size())) 320 | feas, fea_fc, out = model(inputs, True) 321 | for idx, fea in enumerate(feas): 322 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 323 | print('fc shape:{}'.format(fea_fc.size())) 324 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | import numpy as np 5 | from torchvision._internally_replaced_utils import load_state_dict_from_url 6 | from typing import Type, Any, Callable, Union, List, Optional 7 | from utils.utils import load_weights_from_state_dict, fuse_conv_bn 8 | 9 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 11 | 'wide_resnet50_2', 'wide_resnet101_2'] 12 | 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 20 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 21 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=dilation, groups=groups, bias=False, dilation=dilation) 31 | 32 | 33 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 34 | """1x1 
convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion: int = 1 40 | 41 | def __init__( 42 | self, 43 | inplanes: int, 44 | planes: int, 45 | stride: int = 1, 46 | downsample: Optional[nn.Module] = None, 47 | groups: int = 1, 48 | base_width: int = 64, 49 | dilation: int = 1, 50 | norm_layer: Optional[Callable[..., nn.Module]] = None 51 | ) -> None: 52 | super(BasicBlock, self).__init__() 53 | if norm_layer is None: 54 | norm_layer = nn.BatchNorm2d 55 | if groups != 1 or base_width != 64: 56 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 57 | if dilation > 1: 58 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 59 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 60 | self.conv1 = conv3x3(inplanes, planes, stride) 61 | self.bn1 = norm_layer(planes) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.conv2 = conv3x3(planes, planes) 64 | self.bn2 = norm_layer(planes) 65 | self.downsample = downsample 66 | self.stride = stride 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | identity = x 70 | 71 | out = self.conv1(x) 72 | if hasattr(self, 'bn1'): 73 | out = self.bn1(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv2(out) 77 | if hasattr(self, 'bn2'): 78 | out = self.bn2(out) 79 | 80 | if self.downsample is not None: 81 | identity = self.downsample(x) 82 | 83 | out += identity 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | def switch_to_deploy(self): 89 | self.conv1 = fuse_conv_bn(self.conv1, self.bn1) 90 | del self.bn1 91 | self.conv2 = fuse_conv_bn(self.conv2, self.bn2) 92 | del self.bn2 93 | 94 | class Bottleneck(nn.Module): 95 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 96 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 97 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 98 | # This variant is also known as ResNet V1.5 and improves accuracy according to 99 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
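# The hasattr(self, 'bn1') / hasattr(self, 'bn2') / hasattr(self, 'bn3') guards in the forward()
# methods exist because switch_to_deploy() folds each BatchNorm into the preceding convolution via
# fuse_conv_bn and then deletes the BN module, so the fused model skips those branches at inference.
# Illustrative sketch of standard conv-BN folding (assumed to be what fuse_conv_bn in utils/utils.py
# implements; the names below are explanatory only and describe eval-mode BatchNorm):
#   scale   = bn.weight / torch.sqrt(bn.running_var + bn.eps)   # one factor per output channel
#   w_fused = conv.weight * scale.reshape(-1, 1, 1, 1)
#   b_fused = bn.bias - bn.running_mean * scale                 # conv layers here use bias=False
# A Conv2d carrying (w_fused, b_fused) reproduces conv -> bn exactly at inference time.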
100 | 101 | expansion: int = 4 102 | 103 | def __init__( 104 | self, 105 | inplanes: int, 106 | planes: int, 107 | stride: int = 1, 108 | downsample: Optional[nn.Module] = None, 109 | groups: int = 1, 110 | base_width: int = 64, 111 | dilation: int = 1, 112 | norm_layer: Optional[Callable[..., nn.Module]] = None 113 | ) -> None: 114 | super(Bottleneck, self).__init__() 115 | if norm_layer is None: 116 | norm_layer = nn.BatchNorm2d 117 | width = int(planes * (base_width / 64.)) * groups 118 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 119 | self.conv1 = conv1x1(inplanes, width) 120 | self.bn1 = norm_layer(width) 121 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 122 | self.bn2 = norm_layer(width) 123 | self.conv3 = conv1x1(width, planes * self.expansion) 124 | self.bn3 = norm_layer(planes * self.expansion) 125 | self.relu = nn.ReLU(inplace=True) 126 | self.downsample = downsample 127 | self.stride = stride 128 | 129 | def forward(self, x: Tensor) -> Tensor: 130 | identity = x 131 | 132 | out = self.conv1(x) 133 | if hasattr(self, 'bn1'): 134 | out = self.bn1(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv2(out) 138 | if hasattr(self, 'bn2'): 139 | out = self.bn2(out) 140 | out = self.relu(out) 141 | 142 | out = self.conv3(out) 143 | if hasattr(self, 'bn3'): 144 | out = self.bn3(out) 145 | 146 | if self.downsample is not None: 147 | identity = self.downsample(x) 148 | 149 | out += identity 150 | out = self.relu(out) 151 | 152 | return out 153 | 154 | def switch_to_deploy(self): 155 | self.conv1 = fuse_conv_bn(self.conv1, self.bn1) 156 | del self.bn1 157 | self.conv2 = fuse_conv_bn(self.conv2, self.bn2) 158 | del self.bn2 159 | self.conv3 = fuse_conv_bn(self.conv3, self.bn3) 160 | del self.bn3 161 | 162 | class ResNet(nn.Module): 163 | 164 | def __init__( 165 | self, 166 | block: Type[Union[BasicBlock, Bottleneck]], 167 | layers: List[int], 168 | num_classes: int = 1000, 169 | zero_init_residual: bool = False, 170 | groups: int = 1, 171 | width_per_group: int = 64, 172 | replace_stride_with_dilation: Optional[List[bool]] = None, 173 | norm_layer: Optional[Callable[..., nn.Module]] = None 174 | ) -> None: 175 | super(ResNet, self).__init__() 176 | if norm_layer is None: 177 | norm_layer = nn.BatchNorm2d 178 | self._norm_layer = norm_layer 179 | 180 | self.inplanes = 64 181 | self.dilation = 1 182 | if replace_stride_with_dilation is None: 183 | # each element in the tuple indicates if we should replace 184 | # the 2x2 stride with a dilated convolution instead 185 | replace_stride_with_dilation = [False, False, False] 186 | if len(replace_stride_with_dilation) != 3: 187 | raise ValueError("replace_stride_with_dilation should be None " 188 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 189 | self.groups = groups 190 | self.base_width = width_per_group 191 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 192 | bias=False) 193 | self.bn1 = norm_layer(self.inplanes) 194 | self.relu = nn.ReLU(inplace=True) 195 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 196 | self.layer1 = self._make_layer(block, 64, layers[0]) 197 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 198 | dilate=replace_stride_with_dilation[0]) 199 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 200 | dilate=replace_stride_with_dilation[1]) 201 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 202 | dilate=replace_stride_with_dilation[2]) 203 | 
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 204 | self.fc = nn.Linear(512 * block.expansion, num_classes) 205 | 206 | for m in self.modules(): 207 | if isinstance(m, nn.Conv2d): 208 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 209 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 210 | nn.init.constant_(m.weight, 1) 211 | nn.init.constant_(m.bias, 0) 212 | 213 | # Zero-initialize the last BN in each residual branch, 214 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 215 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 216 | if zero_init_residual: 217 | for m in self.modules(): 218 | if isinstance(m, Bottleneck): 219 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 220 | elif isinstance(m, BasicBlock): 221 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 222 | 223 | def switch_to_deploy(self): 224 | self.conv1 = fuse_conv_bn(self.conv1, self.bn1) 225 | del self.bn1 226 | 227 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 228 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 229 | norm_layer = self._norm_layer 230 | downsample = None 231 | previous_dilation = self.dilation 232 | if dilate: 233 | self.dilation *= stride 234 | stride = 1 235 | if stride != 1 or self.inplanes != planes * block.expansion: 236 | downsample = nn.Sequential( 237 | conv1x1(self.inplanes, planes * block.expansion, stride), 238 | norm_layer(planes * block.expansion), 239 | ) 240 | 241 | layers = [] 242 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 243 | self.base_width, previous_dilation, norm_layer)) 244 | self.inplanes = planes * block.expansion 245 | for _ in range(1, blocks): 246 | layers.append(block(self.inplanes, planes, groups=self.groups, 247 | base_width=self.base_width, dilation=self.dilation, 248 | norm_layer=norm_layer)) 249 | 250 | return nn.Sequential(*layers) 251 | 252 | def _forward_impl(self, x: Tensor, need_fea=False) -> Tensor: 253 | # See note [TorchScript super()] 254 | if need_fea: 255 | features, features_fc = self.forward_features(x, need_fea) 256 | x = self.fc(features_fc) 257 | return features, features_fc, x 258 | else: 259 | x = self.forward_features(x) 260 | x = self.fc(x) 261 | return x 262 | 263 | def forward(self, x: Tensor, need_fea=False) -> Tensor: 264 | return self._forward_impl(x, need_fea) 265 | 266 | def forward_features(self, x, need_fea=False): 267 | x = self.conv1(x) 268 | if hasattr(self, 'bn1'): 269 | x = self.bn1(x) 270 | x = self.relu(x) 271 | x = self.maxpool(x) 272 | if need_fea: 273 | x1 = self.layer1(x) 274 | x2 = self.layer2(x1) 275 | x3 = self.layer3(x2) 276 | x4 = self.layer4(x3) 277 | 278 | x = self.avgpool(x4) 279 | x = torch.flatten(x, 1) 280 | return [x1, x2, x3, x4], x 281 | else: 282 | x = self.layer1(x) 283 | x = self.layer2(x) 284 | x = self.layer3(x) 285 | x = self.layer4(x) 286 | 287 | x = self.avgpool(x) 288 | x = torch.flatten(x, 1) 289 | return x 290 | 291 | def cam_layer(self): 292 | return self.layer4 293 | 294 | def _resnet( 295 | arch: str, 296 | block: Type[Union[BasicBlock, Bottleneck]], 297 | layers: List[int], 298 | pretrained: bool, 299 | progress: bool, 300 | **kwargs: Any 301 | ) -> ResNet: 302 | model = ResNet(block, layers, **kwargs) 303 | if pretrained: 304 | state_dict = load_state_dict_from_url(model_urls[arch], 305 | progress=progress) 306 | load_weights_from_state_dict(model, state_dict) 307 | return 
model 308 | 309 | 310 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 311 | r"""ResNet-18 model from 312 | `"Deep Residual Learning for Image Recognition" `_. 313 | 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | progress (bool): If True, displays a progress bar of the download to stderr 317 | """ 318 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 319 | **kwargs) 320 | 321 | 322 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 323 | r"""ResNet-34 model from 324 | `"Deep Residual Learning for Image Recognition" `_. 325 | 326 | Args: 327 | pretrained (bool): If True, returns a model pre-trained on ImageNet 328 | progress (bool): If True, displays a progress bar of the download to stderr 329 | """ 330 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 331 | **kwargs) 332 | 333 | 334 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 335 | r"""ResNet-50 model from 336 | `"Deep Residual Learning for Image Recognition" `_. 337 | 338 | Args: 339 | pretrained (bool): If True, returns a model pre-trained on ImageNet 340 | progress (bool): If True, displays a progress bar of the download to stderr 341 | """ 342 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 343 | **kwargs) 344 | 345 | 346 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 347 | r"""ResNet-101 model from 348 | `"Deep Residual Learning for Image Recognition" `_. 349 | 350 | Args: 351 | pretrained (bool): If True, returns a model pre-trained on ImageNet 352 | progress (bool): If True, displays a progress bar of the download to stderr 353 | """ 354 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 355 | **kwargs) 356 | 357 | 358 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 359 | r"""ResNet-152 model from 360 | `"Deep Residual Learning for Image Recognition" `_. 361 | 362 | Args: 363 | pretrained (bool): If True, returns a model pre-trained on ImageNet 364 | progress (bool): If True, displays a progress bar of the download to stderr 365 | """ 366 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 367 | **kwargs) 368 | 369 | 370 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 371 | r"""ResNeXt-50 32x4d model from 372 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 373 | 374 | Args: 375 | pretrained (bool): If True, returns a model pre-trained on ImageNet 376 | progress (bool): If True, displays a progress bar of the download to stderr 377 | """ 378 | kwargs['groups'] = 32 379 | kwargs['width_per_group'] = 4 380 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 381 | pretrained, progress, **kwargs) 382 | 383 | 384 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 385 | r"""ResNeXt-101 32x8d model from 386 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
387 | 388 | Args: 389 | pretrained (bool): If True, returns a model pre-trained on ImageNet 390 | progress (bool): If True, displays a progress bar of the download to stderr 391 | """ 392 | kwargs['groups'] = 32 393 | kwargs['width_per_group'] = 8 394 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 395 | pretrained, progress, **kwargs) 396 | 397 | 398 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 399 | r"""Wide ResNet-50-2 model from 400 | `"Wide Residual Networks" `_. 401 | 402 | The model is the same as ResNet except for the bottleneck number of channels 403 | which is twice larger in every block. The number of channels in outer 1x1 404 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 405 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 406 | 407 | Args: 408 | pretrained (bool): If True, returns a model pre-trained on ImageNet 409 | progress (bool): If True, displays a progress bar of the download to stderr 410 | """ 411 | kwargs['width_per_group'] = 64 * 2 412 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 413 | pretrained, progress, **kwargs) 414 | 415 | 416 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 417 | r"""Wide ResNet-101-2 model from 418 | `"Wide Residual Networks" `_. 419 | 420 | The model is the same as ResNet except for the bottleneck number of channels 421 | which is twice larger in every block. The number of channels in outer 1x1 422 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 423 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 424 | 425 | Args: 426 | pretrained (bool): If True, returns a model pre-trained on ImageNet 427 | progress (bool): If True, displays a progress bar of the download to stderr 428 | """ 429 | kwargs['width_per_group'] = 64 * 2 430 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 431 | pretrained, progress, **kwargs) 432 | 433 | if __name__ == '__main__': 434 | inputs = torch.rand((1, 3, 224, 224)) 435 | model = resnet18(pretrained=True) 436 | model.eval() 437 | out = model(inputs) 438 | print('out shape:{}'.format(out.size())) 439 | feas, fea_fc, out = model(inputs, True) 440 | for idx, fea in enumerate(feas): 441 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 442 | print('fc shape:{}'.format(fea_fc.size())) 443 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | import torch.nn as nn 5 | from torchvision._internally_replaced_utils import load_state_dict_from_url 6 | from typing import Callable, Any, List 7 | from utils.utils import load_weights_from_state_dict, fuse_conv_bn 8 | 9 | __all__ = [ 10 | 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0' 11 | ] 12 | 13 | model_urls = { 14 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 15 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 16 | 'shufflenetv2_x1.5': None, 17 | 'shufflenetv2_x2.0': None, 18 | } 19 | 20 | 21 | def channel_shuffle(x: Tensor, groups: int) -> Tensor: 22 | batchsize, num_channels, height, width = x.size() 23 | channels_per_group = num_channels // groups 24 | 25 | # reshape 26 | x = x.view(batchsize, groups, 27 | 
channels_per_group, height, width) 28 | 29 | x = torch.transpose(x, 1, 2).contiguous() 30 | 31 | # flatten 32 | x = x.view(batchsize, -1, height, width) 33 | 34 | return x 35 | 36 | 37 | class InvertedResidual(nn.Module): 38 | def __init__( 39 | self, 40 | inp: int, 41 | oup: int, 42 | stride: int 43 | ) -> None: 44 | super(InvertedResidual, self).__init__() 45 | 46 | if not (1 <= stride <= 3): 47 | raise ValueError('illegal stride value') 48 | self.stride = stride 49 | 50 | branch_features = oup // 2 51 | assert (self.stride != 1) or (inp == branch_features << 1) 52 | 53 | if self.stride > 1: 54 | self.branch1 = nn.Sequential( 55 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 56 | nn.BatchNorm2d(inp), 57 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 58 | nn.BatchNorm2d(branch_features), 59 | nn.ReLU(inplace=True), 60 | ) 61 | else: 62 | self.branch1 = nn.Sequential() 63 | 64 | self.branch2 = nn.Sequential( 65 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 66 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 67 | nn.BatchNorm2d(branch_features), 68 | nn.ReLU(inplace=True), 69 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 70 | nn.BatchNorm2d(branch_features), 71 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 72 | nn.BatchNorm2d(branch_features), 73 | nn.ReLU(inplace=True), 74 | ) 75 | 76 | @staticmethod 77 | def depthwise_conv( 78 | i: int, 79 | o: int, 80 | kernel_size: int, 81 | stride: int = 1, 82 | padding: int = 0, 83 | bias: bool = False 84 | ) -> nn.Conv2d: 85 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 86 | 87 | def forward(self, x: Tensor) -> Tensor: 88 | if self.stride == 1: 89 | x1, x2 = x.chunk(2, dim=1) 90 | out = torch.cat((x1, self.branch2(x2)), dim=1) 91 | else: 92 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 93 | 94 | out = channel_shuffle(out, 2) 95 | 96 | return out 97 | 98 | def switch_to_deploy(self): 99 | if len(self.branch1) > 0: 100 | self.branch1 = nn.Sequential( 101 | fuse_conv_bn(self.branch1[0], self.branch1[1]), 102 | fuse_conv_bn(self.branch1[2], self.branch1[3]), 103 | self.branch1[4] 104 | ) 105 | self.branch2 = nn.Sequential( 106 | fuse_conv_bn(self.branch2[0], self.branch2[1]), 107 | self.branch2[2], 108 | fuse_conv_bn(self.branch2[3], self.branch2[4]), 109 | fuse_conv_bn(self.branch2[5], self.branch2[6]), 110 | self.branch2[7] 111 | ) 112 | 113 | class ShuffleNetV2(nn.Module): 114 | def __init__( 115 | self, 116 | stages_repeats: List[int], 117 | stages_out_channels: List[int], 118 | num_classes: int = 1000, 119 | inverted_residual: Callable[..., nn.Module] = InvertedResidual 120 | ) -> None: 121 | super(ShuffleNetV2, self).__init__() 122 | 123 | if len(stages_repeats) != 3: 124 | raise ValueError('expected stages_repeats as list of 3 positive ints') 125 | if len(stages_out_channels) != 5: 126 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 127 | self._stage_out_channels = stages_out_channels 128 | 129 | input_channels = 3 130 | output_channels = self._stage_out_channels[0] 131 | self.conv1 = nn.Sequential( 132 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 133 | nn.BatchNorm2d(output_channels), 134 | nn.ReLU(inplace=True), 135 | ) 136 | input_channels = output_channels 137 | 138 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 139 | 140 | # 
Static annotations for mypy 141 | self.stage2: nn.Sequential 142 | self.stage3: nn.Sequential 143 | self.stage4: nn.Sequential 144 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 145 | for name, repeats, output_channels in zip( 146 | stage_names, stages_repeats, self._stage_out_channels[1:]): 147 | seq = [inverted_residual(input_channels, output_channels, 2)] 148 | for i in range(repeats - 1): 149 | seq.append(inverted_residual(output_channels, output_channels, 1)) 150 | setattr(self, name, nn.Sequential(*seq)) 151 | input_channels = output_channels 152 | 153 | output_channels = self._stage_out_channels[-1] 154 | self.conv5 = nn.Sequential( 155 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 156 | nn.BatchNorm2d(output_channels), 157 | nn.ReLU(inplace=True), 158 | ) 159 | 160 | self.fc = nn.Linear(output_channels, num_classes) 161 | 162 | def switch_to_deploy(self): 163 | self.conv1 = nn.Sequential( 164 | fuse_conv_bn(self.conv1[0], self.conv1[1]), 165 | self.conv1[2] 166 | ) 167 | self.conv5 = nn.Sequential( 168 | fuse_conv_bn(self.conv5[0], self.conv5[1]), 169 | self.conv5[2] 170 | ) 171 | 172 | def _forward_impl(self, x: Tensor, need_fea=False) -> Tensor: 173 | if need_fea: 174 | features, features_fc = self.forward_features(x, need_fea) 175 | return features, features_fc, self.fc(features_fc) 176 | else: 177 | # See note [TorchScript super()] 178 | x = self.forward_features(x) 179 | x = self.fc(x) 180 | return x 181 | 182 | def forward(self, x: Tensor, need_fea=False) -> Tensor: 183 | return self._forward_impl(x, need_fea) 184 | 185 | def forward_features(self, x, need_fea=False): 186 | x = self.conv1(x) 187 | x = self.maxpool(x) 188 | if need_fea: 189 | x2 = self.stage2(x) 190 | x3 = self.stage3(x2) 191 | x4 = self.stage4(x3) 192 | x4 = self.conv5(x4) 193 | return [x, x2, x3, x4], x4.mean([2, 3]) 194 | else: 195 | # See note [TorchScript super()] 196 | x = self.stage2(x) 197 | x = self.stage3(x) 198 | x = self.stage4(x) 199 | x = self.conv5(x) 200 | x = x.mean([2, 3]) # globalpool 201 | return x 202 | 203 | def cam_layer(self): 204 | return self.stage4 205 | 206 | 207 | def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2: 208 | model = ShuffleNetV2(*args, **kwargs) 209 | 210 | if pretrained: 211 | model_url = model_urls[arch] 212 | if model_url is None: 213 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 214 | else: 215 | state_dict = load_state_dict_from_url(model_url, progress=progress) 216 | model = load_weights_from_state_dict(model, state_dict) 217 | return model 218 | 219 | 220 | def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: 221 | """ 222 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 223 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 224 | `_. 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | progress (bool): If True, displays a progress bar of the download to stderr 229 | """ 230 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, 231 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 232 | 233 | 234 | def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: 235 | """ 236 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 237 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 238 | `_. 
239 | 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, 245 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 246 | 247 | 248 | def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: 249 | """ 250 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in 251 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 252 | `_. 253 | 254 | Args: 255 | pretrained (bool): If True, returns a model pre-trained on ImageNet 256 | progress (bool): If True, displays a progress bar of the download to stderr 257 | """ 258 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, 259 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) 260 | 261 | 262 | def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: 263 | """ 264 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in 265 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 266 | `_. 267 | 268 | Args: 269 | pretrained (bool): If True, returns a model pre-trained on ImageNet 270 | progress (bool): If True, displays a progress bar of the download to stderr 271 | """ 272 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, 273 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) 274 | 275 | if __name__ == '__main__': 276 | inputs = torch.rand((1, 3, 224, 224)) 277 | model = shufflenet_v2_x1_0(pretrained=True) 278 | model.eval() 279 | out = model(inputs) 280 | print('out shape:{}'.format(out.size())) 281 | feas, fea_fc, out = model(inputs, True) 282 | for idx, fea in enumerate(feas): 283 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 284 | print('fc shape:{}'.format(fea_fc.size())) 285 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torchvision._internally_replaced_utils import load_state_dict_from_url 5 | from typing import Union, List, Dict, Any, cast 6 | from utils.utils import load_weights_from_state_dict, fuse_conv_bn 7 | 8 | __all__ = [ 9 | 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 10 | 'vgg19_bn', 'vgg19', 11 | ] 12 | 13 | 14 | model_urls = { 15 | 'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth', 16 | 'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth', 17 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 18 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 19 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 20 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 21 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 22 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 23 | } 24 | 25 | 26 | class VGG(nn.Module): 27 | 28 | def __init__( 29 | self, 30 | features: nn.Module, 31 | num_classes: int = 1000, 32 | init_weights: bool = True 33 | ) -> None: 34 | super(VGG, self).__init__() 35 | self.features = features 36 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 37 | self.classifier = nn.Sequential( 38 | nn.Linear(512 * 7 * 7, 
4096), 39 | nn.ReLU(True), 40 | nn.Dropout(), 41 | nn.Linear(4096, 4096), 42 | nn.ReLU(True), 43 | nn.Dropout(), 44 | nn.Linear(4096, num_classes), 45 | ) 46 | if init_weights: 47 | self._initialize_weights() 48 | 49 | def _initialize_weights(self) -> None: 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 53 | if m.bias is not None: 54 | nn.init.constant_(m.bias, 0) 55 | elif isinstance(m, nn.BatchNorm2d): 56 | nn.init.constant_(m.weight, 1) 57 | nn.init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.Linear): 59 | nn.init.normal_(m.weight, 0, 0.01) 60 | nn.init.constant_(m.bias, 0) 61 | 62 | def forward(self, x: torch.Tensor, need_fea=False) -> torch.Tensor: 63 | if need_fea: 64 | features, features_fc = self.forward_features(x, need_fea) 65 | return features, features_fc, self.classifier(features_fc) 66 | else: 67 | x = self.forward_features(x) 68 | x = self.classifier(x) 69 | return x 70 | 71 | def forward_features(self, x, need_fea=False): 72 | if need_fea: 73 | input_size = x.size(2) 74 | scale = [4, 8, 16, 32] 75 | features = [None, None, None, None] 76 | for idx, layer in enumerate(self.features): 77 | x = layer(x) 78 | if input_size // x.size(2) in scale: 79 | features[scale.index(input_size // x.size(2))] = x 80 | x = self.avgpool(x) 81 | x = torch.flatten(x, 1) 82 | return features, x 83 | else: 84 | x = self.features(x) 85 | x = self.avgpool(x) 86 | x = torch.flatten(x, 1) 87 | return x 88 | 89 | def cam_layer(self): 90 | return self.features[-1] 91 | 92 | def switch_to_deploy(self): 93 | new_features = [] 94 | for i in range(len(self.features)): 95 | if type(self.features[i]) is nn.BatchNorm2d: 96 | new_features[-1] = fuse_conv_bn(new_features[-1], self.features[i]) 97 | else: 98 | new_features.append(self.features[i]) 99 | self.features = nn.Sequential(*new_features) 100 | 101 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 102 | layers: List[nn.Module] = [] 103 | in_channels = 3 104 | for v in cfg: 105 | if v == 'M': 106 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 107 | else: 108 | v = cast(int, v) 109 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 110 | if batch_norm: 111 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 112 | else: 113 | layers += [conv2d, nn.ReLU(inplace=True)] 114 | in_channels = v 115 | return nn.Sequential(*layers) 116 | 117 | 118 | cfgs: Dict[str, List[Union[str, int]]] = { 119 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 120 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 121 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 122 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 123 | } 124 | 125 | 126 | def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: 127 | if pretrained: 128 | kwargs['init_weights'] = False 129 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 130 | if pretrained: 131 | state_dict = load_state_dict_from_url(model_urls[arch], 132 | progress=progress) 133 | model = load_weights_from_state_dict(model, state_dict) 134 | return model 135 | 136 | 137 | def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 138 | r"""VGG 11-layer model (configuration "A") from 139 | `"Very Deep Convolutional Networks For 
Large-Scale Image Recognition" `_. 140 | The required minimum input size of the model is 32x32. 141 | 142 | Args: 143 | pretrained (bool): If True, returns a model pre-trained on ImageNet 144 | progress (bool): If True, displays a progress bar of the download to stderr 145 | """ 146 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 147 | 148 | 149 | def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 150 | r"""VGG 11-layer model (configuration "A") with batch normalization 151 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 152 | The required minimum input size of the model is 32x32. 153 | 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on ImageNet 156 | progress (bool): If True, displays a progress bar of the download to stderr 157 | """ 158 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 159 | 160 | 161 | def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 162 | r"""VGG 13-layer model (configuration "B") 163 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 164 | The required minimum input size of the model is 32x32. 165 | 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | progress (bool): If True, displays a progress bar of the download to stderr 169 | """ 170 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 171 | 172 | 173 | def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 174 | r"""VGG 13-layer model (configuration "B") with batch normalization 175 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 176 | The required minimum input size of the model is 32x32. 177 | 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | progress (bool): If True, displays a progress bar of the download to stderr 181 | """ 182 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 183 | 184 | 185 | def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 186 | r"""VGG 16-layer model (configuration "D") 187 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 188 | The required minimum input size of the model is 32x32. 189 | 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | progress (bool): If True, displays a progress bar of the download to stderr 193 | """ 194 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 195 | 196 | 197 | def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 198 | r"""VGG 16-layer model (configuration "D") with batch normalization 199 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 200 | The required minimum input size of the model is 32x32. 201 | 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | progress (bool): If True, displays a progress bar of the download to stderr 205 | """ 206 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 207 | 208 | 209 | def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 210 | r"""VGG 19-layer model (configuration "E") 211 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 212 | The required minimum input size of the model is 32x32. 
213 | 214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | progress (bool): If True, displays a progress bar of the download to stderr 217 | """ 218 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 219 | 220 | 221 | def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 222 | r"""VGG 19-layer model (configuration 'E') with batch normalization 223 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. 224 | The required minimum input size of the model is 32x32. 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | progress (bool): If True, displays a progress bar of the download to stderr 229 | """ 230 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 231 | 232 | if __name__ == '__main__': 233 | inputs = torch.rand((1, 3, 224, 224)) 234 | model = vgg11(pretrained=True) 235 | model.eval() 236 | out = model(inputs) 237 | print('out shape:{}'.format(out.size())) 238 | feas, fea_fc, out = model(inputs, True) 239 | for idx, fea in enumerate(feas): 240 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 241 | print('fc shape:{}'.format(fea_fc.size())) 242 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /model/vovnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch.hub import load_state_dict_from_url 6 | from collections import OrderedDict 7 | from utils.utils import load_weights_from_state_dict, fuse_conv_bn 8 | 9 | __all__ = ['vovnet39', 'vovnet57'] 10 | 11 | 12 | model_urls = { 13 | 'vovnet39': 'https://github.com/z1069614715/pretrained-weights/releases/download/vovnet_v1.0/vovnet39_torchvision.pth', 14 | 'vovnet57': 'https://github.com/z1069614715/pretrained-weights/releases/download/vovnet_v1.0/vovnet57_torchvision.pth' 15 | } 16 | 17 | 18 | def conv3x3(in_channels, out_channels, module_name, postfix, 19 | stride=1, groups=1, kernel_size=3, padding=1): 20 | """3x3 convolution with padding""" 21 | return [ 22 | ('{}_{}/conv'.format(module_name, postfix), 23 | nn.Conv2d(in_channels, out_channels, 24 | kernel_size=kernel_size, 25 | stride=stride, 26 | padding=padding, 27 | groups=groups, 28 | bias=False)), 29 | ('{}_{}/norm'.format(module_name, postfix), 30 | nn.BatchNorm2d(out_channels)), 31 | ('{}_{}/relu'.format(module_name, postfix), 32 | nn.ReLU(inplace=True)), 33 | ] 34 | 35 | 36 | def conv1x1(in_channels, out_channels, module_name, postfix, 37 | stride=1, groups=1, kernel_size=1, padding=0): 38 | """1x1 convolution""" 39 | return [ 40 | ('{}_{}/conv'.format(module_name, postfix), 41 | nn.Conv2d(in_channels, out_channels, 42 | kernel_size=kernel_size, 43 | stride=stride, 44 | padding=padding, 45 | groups=groups, 46 | bias=False)), 47 | ('{}_{}/norm'.format(module_name, postfix), 48 | nn.BatchNorm2d(out_channels)), 49 | ('{}_{}/relu'.format(module_name, postfix), 50 | nn.ReLU(inplace=True)), 51 | ] 52 | 53 | 54 | class _OSA_module(nn.Module): 55 | def __init__(self, 56 | in_ch, 57 | stage_ch, 58 | concat_ch, 59 | layer_per_block, 60 | module_name, 61 | identity=False): 62 | super(_OSA_module, self).__init__() 63 | 64 | self.identity = identity 65 | self.layers = nn.ModuleList() 66 | in_channel = in_ch 67 | for i in range(layer_per_block): 68 | self.layers.append(nn.Sequential( 69 | 
OrderedDict(conv3x3(in_channel, stage_ch, module_name, i)))) 70 | in_channel = stage_ch 71 | 72 | # feature aggregation 73 | in_channel = in_ch + layer_per_block * stage_ch 74 | self.concat = nn.Sequential( 75 | OrderedDict(conv1x1(in_channel, concat_ch, module_name, 'concat'))) 76 | 77 | def forward(self, x): 78 | identity_feat = x 79 | output = [] 80 | output.append(x) 81 | for layer in self.layers: 82 | x = layer(x) 83 | output.append(x) 84 | 85 | x = torch.cat(output, dim=1) 86 | xt = self.concat(x) 87 | 88 | if self.identity: 89 | xt = xt + identity_feat 90 | 91 | return xt 92 | 93 | def switch_to_deploy(self): 94 | new_features = [] 95 | for i in range(len(self.layers)): 96 | if type(self.layers[i]) is nn.Sequential: 97 | new_features.append(nn.Sequential( 98 | fuse_conv_bn(self.layers[i][0], self.layers[i][1]), 99 | self.layers[i][2] 100 | )) 101 | elif type(self.layers[i]) is nn.BatchNorm2d: 102 | new_features[-1] = fuse_conv_bn(new_features[-1], self.layers[i]) 103 | print(1) 104 | else: 105 | new_features.append(self.layers[i]) 106 | self.layers = nn.Sequential(*new_features) 107 | 108 | self.concat = nn.Sequential( 109 | fuse_conv_bn(self.concat[0], self.concat[1]), 110 | self.concat[2] 111 | ) 112 | 113 | class _OSA_stage(nn.Sequential): 114 | def __init__(self, 115 | in_ch, 116 | stage_ch, 117 | concat_ch, 118 | block_per_stage, 119 | layer_per_block, 120 | stage_num): 121 | super(_OSA_stage, self).__init__() 122 | 123 | if not stage_num == 2: 124 | self.add_module('Pooling', 125 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)) 126 | 127 | module_name = f'OSA{stage_num}_1' 128 | self.add_module(module_name, 129 | _OSA_module(in_ch, 130 | stage_ch, 131 | concat_ch, 132 | layer_per_block, 133 | module_name)) 134 | for i in range(block_per_stage-1): 135 | module_name = f'OSA{stage_num}_{i+2}' 136 | self.add_module(module_name, 137 | _OSA_module(concat_ch, 138 | stage_ch, 139 | concat_ch, 140 | layer_per_block, 141 | module_name, 142 | identity=True)) 143 | 144 | 145 | class VoVNet(nn.Module): 146 | def __init__(self, 147 | config_stage_ch, 148 | config_concat_ch, 149 | block_per_stage, 150 | layer_per_block, 151 | num_classes=1000): 152 | super(VoVNet, self).__init__() 153 | 154 | # Stem module 155 | stem = conv3x3(3, 64, 'stem', '1', 2) 156 | stem += conv3x3(64, 64, 'stem', '2', 1) 157 | stem += conv3x3(64, 128, 'stem', '3', 2) 158 | self.add_module('stem', nn.Sequential(OrderedDict(stem))) 159 | 160 | stem_out_ch = [128] 161 | in_ch_list = stem_out_ch + config_concat_ch[:-1] 162 | self.stage_names = [] 163 | for i in range(4): #num_stages 164 | name = 'stage%d' % (i+2) 165 | self.stage_names.append(name) 166 | self.add_module(name, 167 | _OSA_stage(in_ch_list[i], 168 | config_stage_ch[i], 169 | config_concat_ch[i], 170 | block_per_stage[i], 171 | layer_per_block, 172 | i+2)) 173 | 174 | self.classifier = nn.Linear(config_concat_ch[-1], num_classes) 175 | 176 | for m in self.modules(): 177 | if isinstance(m, nn.Conv2d): 178 | nn.init.kaiming_normal_(m.weight) 179 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 180 | nn.init.constant_(m.weight, 1) 181 | nn.init.constant_(m.bias, 0) 182 | elif isinstance(m, nn.Linear): 183 | nn.init.constant_(m.bias, 0) 184 | 185 | def switch_to_deploy(self): 186 | self.stem = nn.Sequential( 187 | fuse_conv_bn(self.stem[0], self.stem[1]), 188 | self.stem[2], 189 | fuse_conv_bn(self.stem[3], self.stem[4]), 190 | self.stem[5], 191 | fuse_conv_bn(self.stem[6], self.stem[7]), 192 | self.stem[8], 193 | ) 194 | 195 | def forward(self, x, 
need_fea=False): 196 | if need_fea: 197 | features, features_fc = self.forward_features(x, need_fea) 198 | return features, features_fc, self.classifier(features_fc) 199 | else: 200 | x = self.forward_features(x) 201 | x = self.classifier(x) 202 | return x 203 | 204 | def forward_features(self, x, need_fea=False): 205 | if need_fea: 206 | features = [] 207 | x = self.stem(x) 208 | for name in self.stage_names: 209 | x = getattr(self, name)(x) 210 | features.append(x) 211 | x = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1) 212 | return features, x 213 | else: 214 | x = self.stem(x) 215 | for name in self.stage_names: 216 | x = getattr(self, name)(x) 217 | x = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1) 218 | return x 219 | 220 | def cam_layer(self): 221 | return getattr(self, self.stage_names[-1]) 222 | 223 | def _vovnet(arch, 224 | config_stage_ch, 225 | config_concat_ch, 226 | block_per_stage, 227 | layer_per_block, 228 | pretrained, 229 | progress, 230 | **kwargs): 231 | model = VoVNet(config_stage_ch, config_concat_ch, 232 | block_per_stage, layer_per_block, 233 | **kwargs) 234 | if pretrained: 235 | state_dict = load_state_dict_from_url(model_urls[arch], 236 | progress=progress) 237 | for keys in list(state_dict.keys()): 238 | state_dict[f'{keys.replace("module.", "")}'] = state_dict[keys] 239 | del state_dict[keys] 240 | model = load_weights_from_state_dict(model, state_dict) 241 | return model 242 | 243 | 244 | def vovnet57(pretrained=False, progress=True, **kwargs): 245 | r"""Constructs a VoVNet-57 model as described in 246 | `"An Energy and GPU-Computation Efficient Backbone Networks" 247 | `_. 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | progress (bool): If True, displays a progress bar of the download to stderr 251 | """ 252 | return _vovnet('vovnet57', [128, 160, 192, 224], [256, 512, 768, 1024], 253 | [1,1,4,3], 5, pretrained, progress, **kwargs) 254 | 255 | 256 | def vovnet39(pretrained=False, progress=True, **kwargs): 257 | r"""Constructs a VoVNet-39 model as described in 258 | `"An Energy and GPU-Computation Efficient Backbone Networks" 259 | `_. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | progress (bool): If True, displays a progress bar of the download to stderr 263 | """ 264 | return _vovnet('vovnet39', [128, 160, 192, 224], [256, 512, 768, 1024], 265 | [1,1,2,2], 5, pretrained, progress, **kwargs) 266 | 267 | 268 | def vovnet27_slim(pretrained=False, progress=True, **kwargs): 269 | r"""Constructs a VoVNet-39 model as described in 270 | `"An Energy and GPU-Computation Efficient Backbone Networks" 271 | `_. 
272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | progress (bool): If True, displays a progress bar of the download to stderr 275 | """ 276 | return _vovnet('vovnet27_slim', [64, 80, 96, 112], [128, 256, 384, 512], 277 | [1,1,1,1], 5, pretrained, progress, **kwargs) 278 | 279 | if __name__ == '__main__': 280 | inputs = torch.rand((1, 3, 224, 224)) 281 | model = vovnet39(pretrained=False) 282 | model.eval() 283 | out = model(inputs) 284 | print('out shape:{}'.format(out.size())) 285 | feas, fea_fc, out = model(inputs, True) 286 | for idx, fea in enumerate(feas): 287 | print('feature {} shape:{}'.format(idx + 1, fea.size())) 288 | print('fc shape:{}'.format(fea_fc.size())) 289 | print('out shape:{}'.format(out.size())) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | from PIL import ImageFile 4 | ImageFile.LOAD_TRUNCATED_IMAGES = True 5 | import os, torch, argparse, datetime, tqdm, random 6 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | from copy import deepcopy 10 | from utils import utils_aug 11 | from utils.utils import predict_single_image, cam_visual, dict_to_PrettyTable, select_device, model_fuse 12 | from utils.utils_model import select_model 13 | 14 | def set_seed(seed): 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | def parse_opt(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--source', type=str, default=r'', help='source data path(file, folder)') 23 | parser.add_argument('--label_path', type=str, default=r'dataset/label.txt', help='label path') 24 | parser.add_argument('--save_path', type=str, default=r'runs/exp', help='save path for model and log') 25 | parser.add_argument('--test_tta', action="store_true", help='using TTA Tricks') 26 | parser.add_argument('--cam_visual', action="store_true", help='visual cam') 27 | parser.add_argument('--cam_type', type=str, choices=['GradCAM', 'HiResCAM', 'ScoreCAM', 'GradCAMPlusPlus', 'AblationCAM', 'XGradCAM', 'EigenCAM', 'FullGrad'], default='FullGrad', help='cam type') 28 | parser.add_argument('--half', action="store_true", help='use FP16 half-precision inference') 29 | parser.add_argument('--device', type=str, default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 30 | 31 | opt = parser.parse_known_args()[0] 32 | 33 | if not os.path.exists(os.path.join(opt.save_path, 'best.pt')): 34 | raise Exception('best.pt not found. please check your --save_path folder') 35 | ckpt = torch.load(os.path.join(opt.save_path, 'best.pt')) 36 | DEVICE = select_device(opt.device) 37 | if opt.half and DEVICE.type == 'cpu': 38 | raise Exception('half inference only supported GPU.') 39 | if opt.half and opt.cam_visual: 40 | raise Exception('cam visual only supported cpu. 
please set device=cpu.') 41 | if (opt.device != 'cpu') and opt.cam_visual: 42 | raise Exception('cam visual only supported FP32.') 43 | with open(opt.label_path) as f: 44 | CLASS_NUM = len(f.readlines()) 45 | model = select_model(ckpt['model'].name, CLASS_NUM) 46 | model.load_state_dict(ckpt['model'].float().state_dict(), strict=False) 47 | model_fuse(model) 48 | model = (model.half() if opt.half else model) 49 | model.to(DEVICE) 50 | model.eval() 51 | train_opt = ckpt['opt'] 52 | set_seed(train_opt.random_seed) 53 | 54 | print('found checkpoint from {}, model type:{}\n{}'.format(opt.save_path, ckpt['model'].name, dict_to_PrettyTable(ckpt['best_metrice'], 'Best Metrice'))) 55 | test_transform = utils_aug.get_dataprocessing_teststage(train_opt, opt, torch.load(os.path.join(opt.save_path, 'preprocess.transforms'))) 56 | 57 | try: 58 | with open(opt.label_path, encoding='utf-8') as f: 59 | label = list(map(lambda x: x.strip(), f.readlines())) 60 | except Exception as e: 61 | with open(opt.label_path, encoding='gbk') as f: 62 | label = list(map(lambda x: x.strip(), f.readlines())) 63 | 64 | return opt, DEVICE, model, test_transform, label 65 | 66 | if __name__ == '__main__': 67 | opt, DEVICE, model, test_transform, label = parse_opt() 68 | 69 | if opt.cam_visual: 70 | cam_model = cam_visual(model, test_transform, DEVICE, opt) 71 | 72 | if os.path.isdir(opt.source): 73 | save_path = os.path.join(opt.save_path, 'predict', datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')) 74 | os.makedirs(os.path.join(save_path)) 75 | result = [] 76 | for file in tqdm.tqdm(os.listdir(opt.source)): 77 | pred, pred_result = predict_single_image(os.path.join(opt.source, file), model, test_transform, DEVICE, half=opt.half) 78 | result.append('{},{},{}'.format(os.path.join(opt.source, file), label[pred], pred_result[pred])) 79 | 80 | plt.figure(figsize=(6, 6)) 81 | if opt.cam_visual: 82 | cam_output = cam_model(os.path.join(opt.source, file)) 83 | plt.imshow(cam_output) 84 | else: 85 | plt.imshow(plt.imread(os.path.join(opt.source, file))) 86 | plt.axis('off') 87 | plt.title('predict label:{}\npredict probability:{:.4f}'.format(label[pred], float(pred_result[pred]))) 88 | plt.tight_layout() 89 | plt.savefig(os.path.join(save_path, file)) 90 | plt.clf() 91 | plt.close() 92 | 93 | with open(os.path.join(save_path, 'result.csv'), 'w+') as f: 94 | f.write('img_path,pred_class,pred_class_probability\n') 95 | f.write('\n'.join(result)) 96 | elif os.path.isfile(opt.source): 97 | pred, pred_result = predict_single_image(opt.source, model, test_transform, DEVICE, half=opt.half) 98 | 99 | plt.figure(figsize=(6, 6)) 100 | if opt.cam_visual: 101 | cam_output = cam_model(opt.source, pred) 102 | plt.imshow(cam_output) 103 | else: 104 | plt.imshow(plt.imread(opt.source)) 105 | plt.axis('off') 106 | plt.title('predict label:{}\npredict probability:{:.4f}'.format(label[pred], float(pred_result[pred]))) 107 | plt.tight_layout() 108 | plt.show() -------------------------------------------------------------------------------- /processing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | import os, shutil, argparse 4 | import numpy as np 5 | 6 | # set random seed 7 | np.random.seed(0) 8 | ''' 9 | This file help us to split the dataset. 10 | It's going to be a training set, a validation set, a test set. 11 | We need to get all the image data into --data_path 12 | Example: 13 | dataset/train/dog/*.(jpg, png, bmp, ...) 
14 | dataset/train/cat/*.(jpg, png, bmp, ...) 15 | dataset/train/person/*.(jpg, png, bmp, ...) 16 | and so on... 17 | 18 | program flow: 19 | 1. generate label.txt. 20 | 2. rename --data_path. 21 | 3. split dataset. 22 | ''' 23 | 24 | def parse_opt(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--data_path', type=str, default=r'dataset/train', help='all data path') 27 | parser.add_argument('--label_path', type=str, default=r'dataset/label.txt', help='label txt save path') 28 | parser.add_argument('--val_size', type=float, default=0.1, help='size of val set') 29 | parser.add_argument('--test_size', type=float, default=0.2, help='size of test set') 30 | opt = parser.parse_known_args()[0] 31 | return opt 32 | 33 | if __name__ == '__main__': 34 | opt = parse_opt() 35 | with open(opt.label_path, 'w+', encoding='utf-8') as f: 36 | f.write('\n'.join(os.listdir(opt.data_path))) 37 | 38 | str_len = len(str(len(os.listdir(opt.data_path)))) 39 | 40 | for idx, i in enumerate(os.listdir(opt.data_path)): 41 | os.rename(r'{}/{}'.format(opt.data_path, i), r'{}/{}'.format(opt.data_path, str(idx).zfill(str_len))) 42 | 43 | os.chdir(opt.data_path) 44 | 45 | for i in os.listdir(os.getcwd()): 46 | base_path = os.path.join(os.getcwd(), i) 47 | base_arr = os.listdir(base_path) 48 | np.random.shuffle(base_arr) 49 | 50 | val_path = base_path.replace('train', 'val') 51 | if not os.path.exists(val_path): 52 | os.makedirs(val_path) 53 | val_need_copy = base_arr[int(len(base_arr) * (1 - opt.val_size - opt.test_size)):int(len(base_arr) * (1 - opt.test_size))] 54 | for j in val_need_copy: 55 | shutil.copy(os.path.join(base_path, j), os.path.join(val_path, j)) 56 | 57 | test_path = base_path.replace('train', 'test') 58 | if not os.path.exists(test_path): 59 | os.makedirs(test_path) 60 | test_need_copy = base_arr[int(len(base_arr) * (1 - opt.test_size)):] 61 | for j in test_need_copy: 62 | shutil.move(os.path.join(base_path, j), os.path.join(test_path, j)) 63 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Pytorch-Classifier requirements 2 | # Usage: pip install -r requirements.txt 3 | 4 | # Base ------------------------------------------------------------------------ 5 | opencv-python 6 | grad-cam 7 | timm 8 | scikit-learn 9 | matplotlib 10 | prettytable 11 | pillow 12 | thop 13 | rfconv 14 | albumentations 15 | pycm 16 | 17 | # Export ---------------------------------------------------------------------- 18 | # onnx # ONNX export 19 | # onnx-simplifier # ONNX simplifier 20 | # nvidia-pyindex # TensorRT export 21 | # nvidia-tensorrt # TensorRT export 22 | 23 | # Export Inference ---------------------------------------------------------------- 24 | # onnxruntime # ONNX CPU Inference 25 | # onnxruntime-gpu # ONNX GPU Inference -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_aug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils_aug.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_aug.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils_aug.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_distill.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils_distill.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_fit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils_fit.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_fit.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils_fit.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils_loss.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils_model.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_model.cpython-39.pyc: 
https://raw.githubusercontent.com/z1069614715/pytorch-classifier/ab9e3392855651b3ff622fb0efe034c840f58c44/utils/__pycache__/utils_model.cpython-39.pyc -------------------------------------------------------------------------------- /utils/utils_aug.py: -------------------------------------------------------------------------------- 1 | import torch, tqdm 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | from PIL import Image 5 | from copy import deepcopy 6 | import albumentations as A 7 | 8 | def get_mean_and_std(dataset, opt): 9 | '''Compute the mean and std value of dataset.''' 10 | if opt.imagenet_meanstd: 11 | print('using ImageNet Mean and Std. Mean:[0.485, 0.456, 0.406] Std:[0.229, 0.224, 0.225].') 12 | return [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 13 | else: 14 | print('Calculating the mean and std of the dataset...') 15 | mean = torch.zeros(3) 16 | std = torch.zeros(3) 17 | for inputs, targets in tqdm.tqdm(dataset): 18 | inputs = transforms.ToTensor()(inputs) 19 | for i in range(3): 20 | mean[i] += inputs[i, :, :].mean() 21 | std[i] += inputs[i, :, :].std() 22 | mean.div_(len(dataset)) 23 | std.div_(len(dataset)) 24 | print('Calculation complete. Mean:[{:.3f}, {:.3f}, {:.3f}] Std:[{:.3f}, {:.3f}, {:.3f}].'.format(*list(mean.detach().numpy()), *list(std.detach().numpy()))) 25 | return mean, std 26 | 27 | def get_processing(dataset, opt): 28 | return transforms.Compose( 29 | [transforms.ToTensor(), 30 | transforms.Normalize(*get_mean_and_std(dataset, opt))]) 31 | 32 | def rand_bbox(size, lam): 33 | W = size[2] 34 | H = size[3] 35 | cut_rat = np.sqrt(1. - lam) 36 | cut_w = int(W * cut_rat) 37 | cut_h = int(H * cut_rat) 38 | 39 | # uniform 40 | cx = np.random.randint(W) 41 | cy = np.random.randint(H) 42 | 43 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 44 | bby1 = np.clip(cy - cut_h // 2, 0, H) 45 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 46 | bby2 = np.clip(cy + cut_h // 2, 0, H) 47 | 48 | return bbx1, bby1, bbx2, bby2 49 | 50 | def mixup_data(x, y, opt, alpha=1.0): 51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 52 | '''Returns mixed inputs, pairs of targets, and lambda''' 53 | if alpha > 0: 54 | lam = np.random.beta(alpha, alpha) 55 | else: 56 | lam = 1 57 | 58 | batch_size = x.size()[0] 59 | index = torch.randperm(batch_size).to(device) 60 | 61 | if opt.mixup == 'mixup': 62 | mixed_x = lam * x + (1 - lam) * x[index, :] 63 | elif opt.mixup == 'cutmix': 64 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam) 65 | mixed_x = deepcopy(x) 66 | mixed_x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2] 67 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2])) 68 | else: 69 | raise Exception('Unsupported MixUp Methods.')
70 | return mixed_x, y, y[index], lam 71 | 72 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 73 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 74 | 75 | def select_Augment(opt): 76 | if opt.Augment == 'RandAugment': 77 | return transforms.RandAugment() 78 | elif opt.Augment == 'AutoAugment': 79 | return transforms.AutoAugment() 80 | elif opt.Augment == 'TrivialAugmentWide': 81 | return transforms.TrivialAugmentWide() 82 | elif opt.Augment == 'AugMix': 83 | return transforms.AugMix() 84 | else: 85 | return None 86 | 87 | def get_dataprocessing(dataset, opt, preprocess=None): 88 | if not preprocess: 89 | preprocess = get_processing(dataset, opt) 90 | torch.save(preprocess, r'{}/preprocess.transforms'.format(opt.save_path)) 91 | 92 | if len(opt.custom_augment.transforms) == 0: 93 | augment = select_Augment(opt) 94 | else: 95 | augment = opt.custom_augment 96 | 97 | if augment is None: 98 | train_transform = transforms.Compose( 99 | [transforms.Resize((int(opt.image_size + opt.image_size * 0.1))), 100 | transforms.RandomCrop((opt.image_size, opt.image_size)), 101 | preprocess 102 | ]) 103 | else: 104 | train_transform = transforms.Compose( 105 | [transforms.Resize((int(opt.image_size + opt.image_size * 0.1))), 106 | transforms.RandomCrop((opt.image_size, opt.image_size)), 107 | augment, 108 | preprocess 109 | ]) 110 | 111 | if opt.test_tta: 112 | test_transform = transforms.Compose([ 113 | transforms.Resize((int(opt.image_size + opt.image_size * 0.1))), 114 | transforms.TenCrop((opt.image_size, opt.image_size)), 115 | transforms.Lambda(lambda crops: torch.stack([preprocess(crop) for crop in crops])) 116 | ]) 117 | else: 118 | test_transform = transforms.Compose([ 119 | transforms.Resize((opt.image_size)), 120 | transforms.CenterCrop((opt.image_size, opt.image_size)), 121 | preprocess 122 | ]) 123 | 124 | return train_transform, test_transform 125 | 126 | def get_dataprocessing_teststage(train_opt, opt, preprocess): 127 | if opt.test_tta: 128 | test_transform = transforms.Compose([ 129 | transforms.Resize((int(train_opt.image_size + train_opt.image_size * 0.1))), 130 | transforms.TenCrop((train_opt.image_size, train_opt.image_size)), 131 | transforms.Lambda(lambda crops: torch.stack([preprocess(crop) for crop in crops])) 132 | ]) 133 | else: 134 | test_transform = transforms.Compose([ 135 | transforms.Resize((train_opt.image_size)), 136 | transforms.CenterCrop((train_opt.image_size, train_opt.image_size)), 137 | preprocess 138 | ]) 139 | return test_transform 140 | 141 | class CutOut(object): 142 | def __init__(self, n_holes=4, length=16): 143 | self.n_holes = n_holes 144 | self.length = length 145 | 146 | def __call__(self, img): 147 | img = np.array(img) 148 | h, w = img.shape[:2] 149 | mask = np.ones_like(img, np.float32) 150 | 151 | for n in range(self.n_holes): 152 | y = np.random.randint(h) 153 | x = np.random.randint(w) 154 | 155 | y1 = np.clip(y - self.length // 2, 0, h) 156 | y2 = np.clip(y + self.length // 2, 0, h) 157 | x1 = np.clip(x - self.length // 2, 0, w) 158 | x2 = np.clip(x + self.length // 2, 0, w) 159 | 160 | mask[y1:y2, x1:x2] = 0.0 161 | return Image.fromarray(np.array(img * mask, dtype=np.uint8)) 162 | 163 | def __str__(self): 164 | return 'CutOut' 165 | 166 | class Create_Albumentations_From_Name(object): 167 | # https://albumentations.ai/docs/api_reference/augmentations/transforms/ 168 | def __init__(self, name, **kwargs): 169 | self.name = name 170 | self.transform = eval('A.{}'.format(name))(**kwargs) 171 | 172 | def __call__(self, 
img): 173 | img = np.array(img) 174 | return Image.fromarray(np.array(self.transform(image=img)['image'], dtype=np.uint8)) 175 | 176 | def __str__(self): 177 | return self.name -------------------------------------------------------------------------------- /utils/utils_distill.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | __all__ = ['SoftTarget', 'MGD', 'SP', 'AT'] 6 | 7 | class SoftTarget(nn.Module): 8 | def __init__(self, T=4): 9 | super(SoftTarget, self).__init__() 10 | self.kl_loss = nn.KLDivLoss(reduction='batchmean') 11 | self.T = T 12 | 13 | def forward(self, student_pred, teacher_pred): 14 | student_pred_logsoftmax = torch.log_softmax(student_pred / self.T, dim=1) 15 | teacher_pred_softmax = torch.softmax(teacher_pred / self.T, dim=1) 16 | kd_loss = self.kl_loss(student_pred_logsoftmax, teacher_pred_softmax) * self.T * self.T 17 | return kd_loss 18 | 19 | def __str__(self): 20 | return 'SoftTarget' 21 | 22 | class MGD(nn.Module): 23 | """PyTorch version of `Masked Generative Distillation` 24 | Args: 25 | student_channels(int): Number of channels in the student's feature map. 26 | teacher_channels(int): Number of channels in the teacher's feature map. 27 | name (str): the loss name of the layer 28 | alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00007 29 | lambda_mgd (float, optional): masked ratio. Defaults to 0.15 30 | """ 31 | def __init__(self, 32 | student_channels, 33 | teacher_channels, 34 | alpha_mgd=0.00007, 35 | lambda_mgd=0.15, 36 | ): 37 | super(MGD, self).__init__() 38 | self.alpha_mgd = alpha_mgd 39 | self.lambda_mgd = lambda_mgd 40 | 41 | if student_channels != teacher_channels: 42 | self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0) 43 | else: 44 | self.align = None 45 | 46 | self.generation = nn.Sequential( 47 | nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1)) 50 | 51 | 52 | def forward(self, 53 | preds_S, 54 | preds_T): 55 | """Forward function. 
56 | Args: 57 | preds_S(Tensor): Bs*C*H*W, student's feature map 58 | preds_T(Tensor): Bs*C*H*W, teacher's feature map 59 | """ 60 | assert preds_S.shape[-2:] == preds_T.shape[-2:] 61 | 62 | if self.align is not None: 63 | preds_S = self.align(preds_S) 64 | 65 | loss = self.get_dis_loss(preds_S, preds_T)*self.alpha_mgd 66 | 67 | return loss 68 | 69 | def get_dis_loss(self, preds_S, preds_T): 70 | loss_mse = nn.MSELoss(reduction='sum') 71 | N, C, H, W = preds_T.shape 72 | 73 | device = preds_S.device 74 | mat = torch.rand((N,C,1,1)).to(device) 75 | mat = torch.where(mat < self.lambda_mgd, 0, 1).to(device) 76 | 77 | masked_fea = torch.mul(preds_S, mat) 78 | new_fea = self.generation(masked_fea) 79 | 80 | dis_loss = loss_mse(new_fea, preds_T)/N 81 | 82 | return dis_loss 83 | 84 | def __str__(self): 85 | return 'MGD' 86 | 87 | class SP(nn.Module): 88 | ''' 89 | Similarity-Preserving Knowledge Distillation 90 | https://arxiv.org/pdf/1907.09682.pdf 91 | ''' 92 | 93 | def __init__(self): 94 | super(SP, self).__init__() 95 | 96 | def matmul_and_normalize(self, z): 97 | z = torch.flatten(z, 1) 98 | return F.normalize(torch.matmul(z, torch.t(z)), 1) 99 | 100 | def forward(self, fm_s, fm_t): 101 | g_t = self.matmul_and_normalize(fm_t) 102 | g_s = self.matmul_and_normalize(fm_s) 103 | 104 | sp_loss = torch.norm(g_t - g_s) ** 2 105 | sp_loss = sp_loss.sum() 106 | return sp_loss / (fm_s.size(0) ** 2) 107 | 108 | def __str__(self): 109 | return 'SP' 110 | 111 | class AT(nn.Module): 112 | ''' 113 | Paying More Attention to Attention: Improving the Performance of Convolutional 114 | Neural Netkworks wia Attention Transfer 115 | https://arxiv.org/pdf/1612.03928.pdf 116 | ''' 117 | def __init__(self): 118 | super(AT, self).__init__() 119 | 120 | def forward(self, fm_s, fm_t): 121 | fm_s_att = self.attention_map(fm_s) 122 | fm_t_att = self.attention_map(fm_t) 123 | 124 | return (fm_s_att - fm_t_att).pow(2).mean() 125 | 126 | def attention_map(self, x): 127 | return F.normalize(x.pow(2).mean(1).view(x.size(0), -1)) 128 | 129 | def __str__(self): 130 | return 'AT' -------------------------------------------------------------------------------- /utils/utils_fit.py: -------------------------------------------------------------------------------- 1 | import torch, tqdm 2 | import numpy as np 3 | from copy import deepcopy 4 | from .utils_aug import mixup_data, mixup_criterion 5 | from .utils import Train_Metrice 6 | import time 7 | 8 | def fitting(model, ema, loss, optimizer, train_dataset, test_dataset, CLASS_NUM, DEVICE, scaler, show_thing, opt): 9 | model.train() 10 | metrice = Train_Metrice(CLASS_NUM) 11 | for x, y in tqdm.tqdm(train_dataset, desc='{} Train Stage'.format(show_thing)): 12 | x, y = x.to(DEVICE).float(), y.to(DEVICE).long() 13 | 14 | with torch.cuda.amp.autocast(opt.amp): 15 | if opt.rdrop: 16 | if opt.mixup != 'none' and np.random.rand() > 0.5: 17 | x_mixup, y_a, y_b, lam = mixup_data(x, y, opt) 18 | pred = model(x_mixup) 19 | pred2 = model(x_mixup) 20 | l = mixup_criterion(loss, [pred, pred2], y_a, y_b, lam) 21 | pred = model(x) 22 | else: 23 | pred = model(x) 24 | pred2 = model(x) 25 | l = loss([pred, pred2], y) 26 | else: 27 | if opt.mixup != 'none' and np.random.rand() > 0.5: 28 | x_mixup, y_a, y_b, lam = mixup_data(x, y, opt) 29 | pred = model(x_mixup) 30 | l = mixup_criterion(loss, pred, y_a, y_b, lam) 31 | pred = model(x) 32 | else: 33 | 34 | pred = model(x) 35 | l = loss(pred, y) 36 | 37 | 38 | metrice.update_loss(float(l.data)) 39 | metrice.update_y(y, pred) 40 | 41 | 
scaler.scale(l).backward() 42 | 43 | scaler.step(optimizer) 44 | scaler.update() 45 | optimizer.zero_grad() 46 | if ema: 47 | ema.update(model) 48 | 49 | if ema: 50 | model_eval = ema.ema 51 | else: 52 | model_eval = model.eval() 53 | with torch.inference_mode(): 54 | for x, y in tqdm.tqdm(test_dataset, desc='{} Test Stage'.format(show_thing)): 55 | x, y = x.to(DEVICE).float(), y.to(DEVICE).long() 56 | 57 | with torch.cuda.amp.autocast(opt.amp): 58 | if opt.test_tta: 59 | bs, ncrops, c, h, w = x.size() 60 | pred = model_eval(x.view(-1, c, h, w)) 61 | pred = pred.view(bs, ncrops, -1).mean(1) 62 | l = loss(pred, y) 63 | else: 64 | pred = model_eval(x) 65 | l = loss(pred, y) 66 | 67 | metrice.update_loss(float(l.data), isTest=True) 68 | metrice.update_y(y, pred, isTest=True) 69 | 70 | return metrice.get() 71 | 72 | 73 | def fitting_distill(teacher_model, student_model, ema, loss, kd_loss, optimizer, train_dataset, test_dataset, CLASS_NUM, 74 | DEVICE, scaler, show_thing, opt): 75 | student_model.train() 76 | metrice = Train_Metrice(CLASS_NUM) 77 | for x, y in tqdm.tqdm(train_dataset, desc='{} Train Stage'.format(show_thing)): 78 | x, y = x.to(DEVICE).float(), y.to(DEVICE).long() 79 | 80 | with torch.cuda.amp.autocast(opt.amp): 81 | if opt.mixup != 'none' and np.random.rand() > 0.5: 82 | x_mixup, y_a, y_b, lam = mixup_data(x, y, opt) 83 | s_features, s_features_fc, s_pred = student_model(x_mixup, need_fea=True) 84 | t_features, t_features_fc, t_pred = teacher_model(x_mixup, need_fea=True) 85 | l = mixup_criterion(loss, s_pred, y_a, y_b, lam) 86 | pred = student_model(x) 87 | else: 88 | s_features, s_features_fc, s_pred = student_model(x, need_fea=True) 89 | t_features, t_features_fc, t_pred = teacher_model(x, need_fea=True) 90 | l = loss(s_pred, y) 91 | if str(kd_loss) in ['SoftTarget']: 92 | kd_l = kd_loss(s_pred, t_pred) 93 | elif str(kd_loss) in ['MGD']: 94 | kd_l = kd_loss(s_features[-1], t_features[-1]) 95 | elif str(kd_loss) in ['SP']: 96 | kd_l = kd_loss(s_features[2], t_features[2]) + kd_loss(s_features[3], t_features[3]) 97 | elif str(kd_loss) in ['AT']: 98 | kd_l = kd_loss(s_features[2], t_features[2]) + kd_loss(s_features[3], t_features[3]) 99 | 100 | if str(kd_loss) in ['SoftTarget', 'SP', 'MGD']: 101 | kd_l *= (opt.kd_ratio / (1 - opt.kd_ratio)) if opt.kd_ratio < 1 else opt.kd_ratio 102 | elif str(kd_loss) in ['AT']: 103 | kd_l *= opt.kd_ratio 104 | 105 | metrice.update_loss(float(l.data)) 106 | metrice.update_loss(float(kd_l.data), isKd=True) 107 | if opt.mixup != 'none': 108 | metrice.update_y(y, pred) 109 | else: 110 | metrice.update_y(y, s_pred) 111 | 112 | scaler.scale(l + kd_l).backward() 113 | 114 | scaler.step(optimizer) 115 | scaler.update() 116 | optimizer.zero_grad() 117 | if ema: 118 | ema.update(student_model) 119 | 120 | if ema: 121 | model_eval = ema.ema 122 | else: 123 | model_eval = student_model.eval() 124 | with torch.inference_mode(): 125 | for x, y in tqdm.tqdm(test_dataset, desc='{} Test Stage'.format(show_thing)): 126 | x, y = x.to(DEVICE).float(), y.to(DEVICE).long() 127 | 128 | with torch.cuda.amp.autocast(opt.amp): 129 | if opt.test_tta: 130 | bs, ncrops, c, h, w = x.size() 131 | pred = model_eval(x.view(-1, c, h, w)) 132 | pred = pred.view(bs, ncrops, -1).mean(1) 133 | l = loss(pred, y) 134 | else: 135 | pred = model_eval(x) 136 | l = loss(pred, y) 137 | 138 | metrice.update_loss(float(l.data), isTest=True) 139 | metrice.update_y(y, pred, isTest=True) 140 | 141 | return metrice.get() 
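A minimal sketch of how the distillation loop above combines the hard-label loss with the SoftTarget loss from utils/utils_distill.py; the batch size, class count and kd_ratio value below are made up for illustration, and the import assumes the repository root is on PYTHONPATH.

import torch
import torch.nn as nn
from utils.utils_distill import SoftTarget  # assumes the repository root is on PYTHONPATH

# stand-in logits for a batch of 8 samples and 5 classes (illustrative shapes only)
student_logits = torch.randn(8, 5, requires_grad=True)
teacher_logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))

ce_loss = nn.CrossEntropyLoss()(student_logits, labels)    # hard-label loss
kd_loss = SoftTarget(T=4)(student_logits, teacher_logits)  # soft-label distillation loss

kd_ratio = 0.7  # hypothetical value standing in for opt.kd_ratio
# the same re-weighting rule fitting_distill() applies for SoftTarget/SP/MGD
kd_loss = kd_loss * (kd_ratio / (1 - kd_ratio)) if kd_ratio < 1 else kd_loss * kd_ratio
(ce_loss + kd_loss).backward()  # the real loop routes this sum through the GradScaler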
-------------------------------------------------------------------------------- /utils/utils_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | __all__ = ['PolyLoss', 'CrossEntropyLoss', 'FocalLoss', 'RDropLoss'] 7 | 8 | class PolyLoss(torch.nn.Module): 9 | """ 10 | PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions 11 | 12 | """ 13 | def __init__(self, label_smoothing: float = 0.0, weight: torch.Tensor = None, epsilon=2.0): 14 | super().__init__() 15 | self.epsilon = epsilon 16 | self.label_smoothing = label_smoothing 17 | self.weight = weight 18 | 19 | def forward(self, outputs, targets): 20 | ce = F.cross_entropy(outputs, targets, label_smoothing=self.label_smoothing, weight=self.weight) 21 | pt = F.one_hot(targets, outputs.size()[1]) * F.softmax(outputs, 1) 22 | 23 | return (ce + self.epsilon * (1.0 - pt.sum(dim=1))).mean() 24 | 25 | class CrossEntropyLoss(nn.Module): 26 | def __init__(self, label_smoothing: float = 0.0, weight: torch.Tensor = None): 27 | super(CrossEntropyLoss, self).__init__() 28 | self.cross_entropy = nn.CrossEntropyLoss(weight=weight, label_smoothing=label_smoothing) 29 | 30 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 31 | return self.cross_entropy(input, target) 32 | 33 | class FocalLoss(nn.Module): 34 | def __init__(self, label_smoothing:float = 0.0, weight: torch.Tensor = None, gamma:float = 2.0): 35 | super(FocalLoss, self).__init__() 36 | self.label_smoothing = label_smoothing 37 | self.weight = weight 38 | self.gamma = gamma 39 | 40 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 41 | target_onehot = F.one_hot(target, num_classes=input.size(1)) 42 | target_onehot_labelsmoothing = torch.clamp(target_onehot.float(), min=self.label_smoothing/(input.size(1)-1), max=1.0-self.label_smoothing) 43 | input_softmax = F.softmax(input, dim=1) + 1e-7 44 | input_logsoftmax = torch.log(input_softmax) 45 | ce = -1 * input_logsoftmax * target_onehot_labelsmoothing 46 | fl = torch.pow((1 - input_softmax), self.gamma) * ce 47 | fl = fl.sum(1) * self.weight[target.long()] 48 | return fl.mean() 49 | 50 | class RDropLoss(nn.Module): 51 | def __init__(self, loss, a=0.3): 52 | super(RDropLoss, self).__init__() 53 | self.loss = loss 54 | self.a = a 55 | 56 | def forward(self, input, target: torch.Tensor) -> torch.Tensor: 57 | if type(input) is list: 58 | input1, input2 = input 59 | main_loss = (self.loss(input1, target) + self.loss(input2, target)) * 0.5 60 | kl_loss = self.compute_kl_loss(input1, input2) 61 | return main_loss + self.a * kl_loss 62 | else: 63 | return self.loss(input, target) 64 | 65 | def compute_kl_loss(self, p, q, pad_mask=None): 66 | p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none') 67 | q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none') 68 | 69 | # pad_mask is for seq-level tasks 70 | if pad_mask is not None: 71 | p_loss.masked_fill_(pad_mask, 0.) 72 | q_loss.masked_fill_(pad_mask, 0.) 
73 | 74 | # You can choose whether to use function "sum" and "mean" depending on your task 75 | p_loss = p_loss.sum() 76 | q_loss = q_loss.sum() 77 | 78 | loss = (p_loss + q_loss) / 2 79 | return loss -------------------------------------------------------------------------------- /utils/utils_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import model as models 4 | from thop import clever_format, profile 5 | 6 | def select_model(name, num_classes, input_shape=None, channels=None, pretrained=False): 7 | if 'shufflenet_v2' in name: 8 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 9 | model.fc = nn.Sequential( 10 | nn.Dropout(p=0.2), 11 | nn.Linear(model.fc.in_features, num_classes) 12 | ) 13 | elif name == 'mobilenetv2': 14 | model = models.mobilenet_v2(pretrained=pretrained) 15 | model.classifier = nn.Sequential( 16 | nn.Dropout(0.2), 17 | nn.Linear(model.last_channel, num_classes), 18 | ) 19 | elif 'mobilenetv3' in name: 20 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 21 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes) 22 | elif name.startswith('resnet'): 23 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 24 | model.fc = nn.Sequential( 25 | nn.Dropout(0.2), 26 | nn.Linear(model.fc.in_features, num_classes) 27 | ) 28 | elif name == 'wide_resnet50': 29 | model = models.wide_resnet50_2(pretrained=pretrained) 30 | model.fc = nn.Sequential( 31 | nn.Dropout(0.2), 32 | nn.Linear(model.fc.in_features, num_classes) 33 | ) 34 | elif name == 'wide_resnet101': 35 | model = models.wide_resnet101_2(pretrained=pretrained) 36 | model.fc = nn.Sequential( 37 | nn.Dropout(0.2), 38 | nn.Linear(model.fc.in_features, num_classes) 39 | ) 40 | elif name == 'resnext50': 41 | model = models.resnext50_32x4d(pretrained=pretrained) 42 | model.fc = nn.Sequential( 43 | nn.Dropout(0.2), 44 | nn.Linear(model.fc.in_features, num_classes) 45 | ) 46 | elif name == 'resnext101': 47 | model = models.resnext101_32x8d(pretrained=pretrained) 48 | model.fc = nn.Sequential( 49 | nn.Dropout(0.2), 50 | nn.Linear(model.fc.in_features, num_classes) 51 | ) 52 | elif 'resnest' in name: 53 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 54 | model.fc = nn.Sequential( 55 | nn.Dropout(0.2), 56 | nn.Linear(model.fc.in_features, num_classes) 57 | ) 58 | elif 'densenet' in name: 59 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 60 | model.classifier = nn.Sequential( 61 | nn.Dropout(0.2), 62 | nn.Linear(model.classifier.in_features, num_classes) 63 | ) 64 | elif 'vgg' in name: 65 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 66 | model.classifier[-1] = nn.Linear(in_features=model.classifier[-1].in_features, out_features=num_classes) 67 | elif 'efficientnet' in name: 68 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 69 | model.classifier[-1] = nn.Linear(in_features=model.classifier[-1].in_features, out_features=num_classes) 70 | elif name.startswith('mnasnet'): 71 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 72 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes) 73 | elif 'vovnet' in name: 74 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 75 | model.classifier = nn.Sequential( 76 | nn.Dropout(0.2), 77 | nn.Linear(in_features=model.classifier.in_features, out_features=num_classes) 78 | ) 79 | elif 
'convnext' in name: 80 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 81 | model.head = nn.Sequential( 82 | nn.Dropout(0.2), 83 | nn.Linear(in_features=model.head.in_features, out_features=num_classes) 84 | ) 85 | elif name == 'ghostnet': 86 | model = models.ghostnet(pretrained=pretrained) 87 | model.classifier[-1] = nn.Linear(in_features=model.classifier[-1].in_features, out_features=num_classes) 88 | elif 'RepVGG' in name: 89 | model = models.get_RepVGG_func_by_name(name, pretrained) 90 | model.linear = nn.Sequential( 91 | nn.Dropout(0.2), 92 | nn.Linear(model.linear.in_features, num_classes) 93 | ) 94 | elif 'sequencer2d' in name: 95 | model = eval('models.{}(pretrained={}, in_chans={}, img_size={})'.format(name, pretrained, channels, input_shape[0])) 96 | model.head = nn.Sequential( 97 | nn.Dropout(0.2), 98 | nn.Linear(model.head.in_features, num_classes) 99 | ) 100 | elif name.startswith('csp') or name.startswith('darknet') or name.startswith('cs3'): 101 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 102 | model.head = nn.Sequential( 103 | nn.Dropout(0.2), 104 | nn.Linear(model.head.in_features, num_classes) 105 | ) 106 | elif name.startswith('dpn'): 107 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 108 | model.classifier = nn.Sequential( 109 | nn.Dropout(0.2), 110 | nn.Linear(in_features=model.classifier.in_features, out_features=num_classes) 111 | ) 112 | elif name.startswith('repghostnet'): 113 | model = eval('models.{}(pretrained={})'.format(name, pretrained)) 114 | model.classifier = nn.Sequential( 115 | nn.Dropout(0.2), 116 | nn.Linear(in_features=model.classifier.in_features, out_features=num_classes) 117 | ) 118 | else: 119 | raise Exception('Unsupported Model Name.') 120 | 121 | if input_shape and channels: 122 | # compute the parameter count and FLOPs 123 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 124 | dummy_input = torch.randn(1, channels, input_shape[0], input_shape[1]).to(device) 125 | flops, params = profile(model.to(device), (dummy_input,), verbose=False) 126 | #--------------------------------------------------------# 127 | # flops * 2 because profile() does not count a convolution as two operations 128 | # some papers count a convolution as two operations (multiply and add); in that case multiply by 2 129 | # other papers only count the multiplications and ignore the additions; in that case do not multiply by 2 130 | # --------------------------------------------------------# 131 | # flops = flops * 2 132 | flops, params = clever_format([flops, params], "%.3f") 133 | print('Select Model: {}'.format(name)) 134 | print('Total FLOPS: %s' % (flops)) 135 | print('Total params: %s' % (params)) 136 | model.name = name 137 | return model 138 | 139 | if __name__ == '__main__': 140 | model = select_model(name='shufflenet_v2_x0_5', num_classes=5, channels=3, input_shape=(224, 224)) 141 | model = select_model(name='shufflenet_v2_x1_0', num_classes=5, channels=3, input_shape=(224, 224)) 142 | # model = select_model(name='mobilenetv2', num_classes=5, channels=3, input_shape=(224, 224)) 143 | # model = select_model(name='mobilenetv3_large', num_classes=5, channels=3, input_shape=(224, 224)) 144 | # model = select_model(name='mobilenetv3_small', num_classes=5, channels=3, input_shape=(224, 224)) 145 | # model = select_model(name='resnet18', num_classes=5, channels=3, input_shape=(224, 224)) 146 | # model = select_model(name='resnet34', num_classes=5, channels=3, input_shape=(224, 224)) 147 | # model = select_model(name='resnet50', num_classes=5, channels=3, input_shape=(224, 224)) 148 | # model = select_model(name='resnet101', num_classes=5, channels=3, input_shape=(224, 224)) 149 | # model =
select_model(name='wide_resnet50', num_classes=5, channels=3, input_shape=(224, 224)) 150 | # model = select_model(name='wide_resnet101', num_classes=5, channels=3, input_shape=(224, 224)) 151 | # model = select_model(name='resnext50', num_classes=5, channels=3, input_shape=(224, 224)) 152 | # model = select_model(name='resnext101', num_classes=5, channels=3, input_shape=(224, 224)) 153 | # model = select_model(name='resnest50', num_classes=5, channels=3, input_shape=(224, 224)) 154 | # model = select_model(name='resnest101', num_classes=5, channels=3, input_shape=(224, 224)) 155 | # model = select_model(name='resnest200', num_classes=5, channels=3, input_shape=(224, 224)) 156 | # model = select_model(name='resnest269', num_classes=5, channels=3, input_shape=(224, 224)) 157 | # model = select_model(name='densenet121', num_classes=5, channels=3, input_shape=(224, 224)) 158 | # model = select_model(name='densenet161', num_classes=5, channels=3, input_shape=(224, 224)) 159 | # model = select_model(name='densenet169', num_classes=5, channels=3, input_shape=(224, 224)) 160 | # model = select_model(name='densenet201', num_classes=5, channels=3, input_shape=(224, 224)) 161 | # model = select_model(name='vgg11', num_classes=5, channels=3, input_shape=(224, 224)) 162 | # model = select_model(name='vgg11_bn', num_classes=5, channels=3, input_shape=(224, 224)) 163 | # model = select_model(name='vgg13', num_classes=5, channels=3, input_shape=(224, 224)) 164 | # model = select_model(name='vgg13_bn', num_classes=5, channels=3, input_shape=(224, 224)) 165 | # model = select_model(name='vgg16', num_classes=5, channels=3, input_shape=(224, 224)) 166 | # model = select_model(name='vgg16_bn', num_classes=5, channels=3, input_shape=(224, 224)) 167 | # model = select_model(name='vgg19', num_classes=5, channels=3, input_shape=(224, 224)) 168 | # model = select_model(name='vgg19_bn', num_classes=5, channels=3, input_shape=(224, 224)) 169 | # model = select_model(name='efficientnet_b0', num_classes=5, channels=3, input_shape=(224, 224)) 170 | # model = select_model(name='efficientnet_b1', num_classes=5, channels=3, input_shape=(224, 224)) 171 | # model = select_model(name='efficientnet_b2', num_classes=5, channels=3, input_shape=(224, 224)) 172 | # model = select_model(name='efficientnet_b3', num_classes=5, channels=3, input_shape=(224, 224)) 173 | # model = select_model(name='efficientnet_b4', num_classes=5, channels=3, input_shape=(224, 224)) 174 | # model = select_model(name='efficientnet_b5', num_classes=5, channels=3, input_shape=(224, 224)) 175 | # model = select_model(name='efficientnet_b6', num_classes=5, channels=3, input_shape=(224, 224)) 176 | # model = select_model(name='efficientnet_b7', num_classes=5, channels=3, input_shape=(224, 224)) 177 | # model = select_model(name='efficientnet_v2_s', num_classes=5, channels=3, input_shape=(224, 224)) 178 | # model = select_model(name='efficientnet_v2_m', num_classes=5, channels=3, input_shape=(224, 224)) 179 | # model = select_model(name='efficientnet_v2_l', num_classes=5, channels=3, input_shape=(224, 224)) 180 | # model = select_model(name='mnasnet', num_classes=5, channels=3, input_shape=(224, 224)) 181 | # model = select_model(name='vovnet39', num_classes=5, channels=3, input_shape=(224, 224)) 182 | # model = select_model(name='vovnet57', num_classes=5, channels=3, input_shape=(224, 224)) 183 | # model = select_model(name='convnext_tiny', num_classes=5, channels=3, input_shape=(224, 224)) 184 | # model = select_model(name='convnext_small', num_classes=5, 
channels=3, input_shape=(224, 224)) 185 | # model = select_model(name='convnext_base', num_classes=5, channels=3, input_shape=(224, 224)) 186 | # model = select_model(name='convnext_large', num_classes=5, channels=3, input_shape=(224, 224)) 187 | # model = select_model(name='convnext_xlarge', num_classes=5, channels=3, input_shape=(224, 224)) 188 | # model = select_model(name='ghostnet', num_classes=5, channels=3, input_shape=(224, 224)) 189 | # model = select_model(name='RepVGG-A0', num_classes=5, channels=3, input_shape=(224, 224)) 190 | # model = select_model(name='sequencer2d_s', num_classes=5, channels=3, input_shape=(224, 224)) 191 | # model = select_model(name='cspresnet50', num_classes=5, channels=3, input_shape=(224, 224)) 192 | # model = select_model(name='dpn98', num_classes=5, channels=3, input_shape=(224, 224), pretrained=True) 193 | pass --------------------------------------------------------------------------------
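To close, a minimal usage sketch for select_model(); the call simply mirrors the first line of the __main__ block above, the import assumes the repository root is on PYTHONPATH, and the forward pass is only a shape sanity check.

import torch
from utils.utils_model import select_model  # assumes the repository root is on PYTHONPATH

# same arguments as the first call in the __main__ block above
model = select_model(name='shufflenet_v2_x0_5', num_classes=5, channels=3, input_shape=(224, 224))
model.eval()

# select_model() moves the model to CUDA (when available) while profiling FLOPs, so match devices
device = next(model.parameters()).device
with torch.inference_mode():
    logits = model(torch.randn(1, 3, 224, 224).to(device))
print(logits.shape)  # torch.Size([1, 5]) - one logit per class
print(model.name)    # select_model() attaches the chosen name to the returned module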