├── README.md
├── RM+AMC
│   ├── README.md
│   ├── amc_search.py
│   ├── amc_train.py
│   ├── benchmark.py
│   ├── data.py
│   ├── models
│   │   ├── __pycache__
│   │   │   └── rmnet.cpython-36.pyc
│   │   └── rmnet.py
│   └── utils.py
├── models
│   ├── __init__.py
│   ├── easy_mb2_mb1.py
│   ├── easy_resnet2vgg.py
│   ├── easy_resnet50_to_vgg.py
│   ├── preactresnet2vgg.py
│   ├── resnet2vgg.py
│   ├── rmnet_from_scratch.py
│   ├── rmnet_pruning.py
│   ├── rmnext.py
│   ├── rmobilenet.py
│   ├── rmrepse.py
│   └── rmrepvgg.py
├── ppt_souce_files
│   ├── RMNet.pptx
│   ├── downsample_prelu.pptx
│   ├── downsample_relu.pptx
│   ├── improving_repvgg.pptx
│   ├── mobilenetv2_1.pptx
│   ├── prune.pptx
│   ├── readme.md
│   ├── repvgg_vs_resnet.pptx
│   ├── resnet.pptx
│   └── resnet_vs_repvgg.pptx
├── train.py
├── train_pruning.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # RMNet: Equivalently Removing Residual Connection from Networks
2 | 
3 | This repository is the official implementation of "[RMNet: Equivalently Removing Residual Connection from Networks](https://arxiv.org/abs/2111.00687)". Feel free to discuss the paper with me on [Zhihu (知乎)](https://zhuanlan.zhihu.com/p/453479354).
4 | 
5 | ## Updates
6 | 
7 | Feb 18, 2022: For better understanding, we implemented a simplified version of the RM Operation on [ResNet](models/easy_resnet2vgg.py) and [MobileNetV2](models/easy_mb2_mb1.py).
8 | 
9 | Jan 25, 2022: RM+AMC pruning:
10 | 
11 | https://github.com/fxmeng/RMNet/blob/aec110b528c2646a19a20777bd5b93500e9b74a3/RM+AMC/README.md
12 | 
13 | 
14 | Dec 24, 2021: RMNet pruning:
15 | 
16 | `python train_pruning.py --sr xxx --threshold xxx`
17 | 
18 | `python train_pruning.py --eval xxx/ckpt.pth`
19 | 
20 | `python train_pruning.py --finetune xxx/ckpt.pth`
21 | 
22 | Nov 15, 2021: The RM Operation now supports PreActResNet.
23 | 
24 | Nov 13, 2021: The RM Operation now supports SEBlock.
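As a quick sanity check of the RM equivalence, the simplified ResNet version above can be exercised directly — a minimal sketch, assuming it is run from the repository root (note that importing `models/easy_resnet2vgg.py` also executes its own small demo at import time):

```python
import torch
from torchvision.models import resnet18
from models.easy_resnet2vgg import resnet_to_vgg  # simplified RM implementation in this repo

model = resnet18().eval()
x = torch.randn(1, 3, 224, 224)
y_residual = model(x)
resnet_to_vgg(model)          # removes the residual connections in place
y_plain = model.eval()(x)
print(torch.allclose(y_residual, y_plain, atol=1e-5))  # expected: True (up to numerical error)
```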
25 | 
26 | 
27 | ## Requirements
28 | 
29 | To install requirements:
30 | 
31 | ```setup
32 | pip install torch
33 | pip install torchvision
34 | ```
35 | 
36 | ## Training
37 | 
38 | To train the models in the paper, run this command:
39 | 
40 | ```train
41 | python train.py -a rmrep_69 --dist-url 'tcp://127.0.0.1:23333' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 --workers 32 [imagenet-folder with train and val folders]
42 | ```
43 | 
44 | ## Evaluation
45 | 
46 | To evaluate our pre-trained models on ImageNet, run:
47 | 
48 | ```eval
49 | python train.py -a rmrep_69 -e checkpoint/rmrep_69.pth.tar [imagenet-folder with train and val folders]
50 | ```
51 | 
52 | ## Results
53 | 
54 | Our models achieve the following performance:
55 | 
56 | ### RM helps pruning achieve better performance [drive.google](https://drive.google.com/drive/folders/1Mu3fXmZPm2EB9Bv17e41H3EfBOLlJYcw?usp=share_link)
57 | | Method | Speed (Imgs/Sec) | Acc (%) |
58 | | ----------------- | ----------------- | ---------- |
59 | |Baseline|3752|71.79|
60 | |AMC(0.75)|4873|70.94|
61 | |AMC(0.7)|4949|70.84|
62 | |AMC(0.5)|5483|68.89|
63 | |RM+AMC(0.75)|5120|**73.21**|
64 | |RM+AMC(0.7)|5238|72.63|
65 | |RM+AMC(0.6)|5675|71.88|
66 | |RM+AMC(0.5)|**6250**|71.01|
67 | 
68 | ### RM helps RepVGG achieve better performance even when the depth is large
69 | | Arch | Top-1 Accuracy (%) | Top-5 Accuracy (%) | Train FLOPs (G) | Test FLOPs (G) |
70 | | ----------------------- | ----------------- | ----------------- | ----------- | ---------- |
71 | | RepVGG-21 | 72.508 | 90.840 | 2.4 | 2.1 |
72 | | **RepVGG-21(RM 0.25)** | **72.590** | **90.924** | **2.1** | **2.1** |
73 | | RepVGG-37 | 74.408 | 91.900 | 4.4 | 4.0 |
74 | | **RepVGG-37(RM 0.25)** | **74.478** | **91.892** | **3.9** | **4.0** |
75 | | RepVGG-69 | 74.526 | 92.182 | 8.6 | 7.7 |
76 | | **RepVGG-69(RM 0.5)** | **75.088** | **92.144** | **6.5** | **7.7** |
77 | | RepVGG-133 | 70.912 | 89.788 | 16.8 | 15.1 |
78 | | **RepVGG-133(RM 0.75)** | **74.560** | **92.000** | **10.6** | **15.1** |
79 | 
80 | 
81 | ### Image Classification on ImageNet [drive.google](https://drive.google.com/drive/folders/1Mu3fXmZPm2EB9Bv17e41H3EfBOLlJYcw?usp=share_link)
82 | | Model name | Top-1 Accuracy (%) | Top-5 Accuracy (%) |
83 | | ------------------ |---------------- | -------------- |
84 | | RMNeXt 41x5\_16 | 78.498 | 94.086 |
85 | | RMNeXt 50x5\_32 | 79.076 | 94.444 |
86 | | RMNeXt 50x6\_32 | 79.57 | 94.644 |
87 | | RMNeXt 101x6\_16 | 80.07 | 94.918 |
88 | | RMNeXt 152x6\_32 | 80.356 | 80.356 |
89 | 
90 | ## Citation
91 | 
92 | If you find this code useful, please cite the following paper:
93 | 
94 | ```
95 | @misc{meng2021rmnet,
96 |       title={RMNet: Equivalently Removing Residual Connection from Networks},
97 |       author={Fanxu Meng and Hao Cheng and Jiaxin Zhuang and Ke Li and Xing Sun},
98 |       year={2021},
99 |       eprint={2111.00687},
100 |       archivePrefix={arXiv},
101 |       primaryClass={cs.CV}
102 | }
103 | ```
104 | 
105 | ## Contributing
106 | 
107 | Our code is based on [RepVGG](https://github.com/DingXiaoH/RepVGG) and [nni/amc pruning](https://github.com/microsoft/nni/tree/master/examples/model_compress/pruning/amc).
--------------------------------------------------------------------------------
/RM+AMC/README.md:
--------------------------------------------------------------------------------
1 | # AMC Pruning
2 | This example shows how to use the AMCPruner.
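Both pipelines below rely on helpers from `models/rmnet.py` (imported by the scripts when run from the `RM+AMC` directory). For reference, the RM conversion used in the "RM + AMC pruning" pipeline (see Step 2 there) can also be invoked directly — a minimal sketch, assuming a CIFAR-sized MobileNetV2:

```python
import torch
from models.rmnet import MobileNetV2, mobilenetv2_to_mobilenetv1

net = MobileNetV2(n_class=10).eval()          # residual MobileNetV2
x = torch.randn(1, 3, 32, 32)
y_v2 = net(x)
net = mobilenetv2_to_mobilenetv1(net)         # RM: remove the residual connections
y_v1 = net.eval()(x)
print(torch.allclose(y_v2, y_v1, atol=1e-4))  # expected: True (up to numerical error)
```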
3 | 
4 | ## Step 1: Train a model for pruning
5 | Run the following command to train a MobileNetV2 model:
6 | ```bash
7 | python3 amc_train.py --model_type mobilenetv2 --dataset cifar10
8 | ```
9 | Once finished, the saved checkpoint can be found at:
10 | ```
11 | logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
12 | ```
13 | 
14 | ## Step 2: Pruning with AMCPruner
15 | Run the following command to prune the trained model:
16 | ```bash
17 | python3 amc_search.py --model_type mobilenetv2 --dataset cifar10 --ckpt logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth --flops_ratio 0.5
18 | ```
19 | Once finished, the pruned model and mask can be found at:
20 | ```
21 | logs/mobilenetv2_cifar10_r0.5_search-run2
22 | ```
23 | 
24 | ## Step 3: Finetune the pruned model
25 | Run `amc_train.py` again with `--ckpt` and `--mask` to speed up and finetune the pruned model:
26 | ```bash
27 | python3 amc_train.py --model_type mobilenetv2 --dataset cifar10 --ckpt logs/mobilenetv2_cifar10_r0.5_search-run2/best_model.pth --mask logs/mobilenetv2_cifar10_r0.5_search-run2/best_mask.pth
28 | ```
29 | Once finished, the saved checkpoint can be found at:
30 | ```
31 | logs/mobilenetv2_cifar10_finetune-run4/ckpt.best.pth
32 | ```
33 | 
34 | # RM + AMC pruning
35 | 
36 | ## Step 1: Train a model for pruning
37 | Run the following command to train a MobileNetV2 model:
38 | ```bash
39 | python3 amc_train.py --model_type mobilenetv2 --dataset cifar10
40 | ```
41 | Once finished, the saved checkpoint can be found at:
42 | ```
43 | logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
44 | ```
45 | 
46 | ## Step 2: Convert MobileNetV2 to MobileNetV1 and finetune it
47 | Run `amc_train.py` again with `--model_type mobilenetv1` and `--ckpt` pointing to the MobileNetV2 checkpoint; the RM operation converts the network to a residual-free MobileNetV1-style model, which is then finetuned:
48 | ```bash
49 | python3 amc_train.py --model_type mobilenetv1 --dataset cifar10 --ckpt logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
50 | ```
51 | Once finished, the saved checkpoint can be found at:
52 | ```
53 | logs/mobilenetv1_cifar10_finetune-run2/ckpt.best.pth
54 | ```
55 | 
56 | ## Step 3: Pruning with AMCPruner
57 | Run the following command to prune the converted model:
58 | ```bash
59 | python3 amc_search.py --model_type mobilenetv1 --dataset cifar10 --ckpt logs/mobilenetv1_cifar10_finetune-run2/ckpt.best.pth --flops_ratio 0.5
60 | ```
61 | Once finished, the pruned model and mask can be found at:
62 | ```
63 | logs/mobilenetv1_cifar10_r0.5_search-run3
64 | ```
65 | 
66 | ## Step 4: Finetune the pruned model
67 | Run `amc_train.py` again with `--ckpt` and `--mask` to speed up and finetune the pruned model:
68 | ```bash
69 | python3 amc_train.py --model_type mobilenetv1 --dataset cifar10 --ckpt logs/mobilenetv1_cifar10_r0.5_search-run3/best_model.pth --mask logs/mobilenetv1_cifar10_r0.5_search-run3/best_mask.pth
70 | ```
71 | Once finished, the saved checkpoint can be found at:
72 | ```
73 | logs/mobilenetv1_cifar10_finetune-run4/ckpt.best.pth
74 | ```
75 | | Method | Speed (Imgs/Sec) | Acc (%) |
76 | | ----------------- | ----------------- | ---------- |
77 | |Baseline|3752|71.79|
78 | |AMC(0.75)|4873|70.94|
79 | |AMC(0.7)|4949|70.84|
80 | |AMC(0.5)|5483|68.89|
81 | |RM+AMC(0.75)|5120|**73.21**|
82 | |RM+AMC(0.7)|5238|72.63|
83 | |RM+AMC(0.6)|5675|71.88|
84 | |RM+AMC(0.5)|**6250**|71.01|
--------------------------------------------------------------------------------
/RM+AMC/amc_search.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
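# This script loads a checkpoint with ckpt_to_model() (models/rmnet.py),
# wraps the network with NNI's AMCPruner, and lets the RL agent search
# per-layer sparsities until the target --flops_ratio is met; the best
# pruned model and mask are written to the auto-generated log directory.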
3 | 
4 | import sys
5 | import argparse
6 | import time
7 | 
8 | import torch
9 | import torch.nn as nn
10 | from nni.algorithms.compression.pytorch.pruning import AMCPruner
11 | from data import get_split_dataset
12 | from utils import AverageMeter, accuracy
13 | from models.rmnet import *
14 | 
15 | def parse_args():
16 |     parser = argparse.ArgumentParser(description='AMC search script')
17 |     parser.add_argument('--model_type', default='mobilenetv2', type=str,
18 |                         choices=['mobilenetv1', 'mobilenetv2'],
19 |                         help='model to prune')
20 |     parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar10/imagenet)')
21 |     parser.add_argument('--batch_size', default=50, type=int, help='batch size')
22 |     parser.add_argument('--data_root', default='/dev/shm', type=str, help='dataset path')
23 |     parser.add_argument('--flops_ratio', default=0.5, type=float, help='target FLOPs ratio to preserve')
24 |     parser.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity')
25 |     parser.add_argument('--rbound', default=1., type=float, help='maximum sparsity')
26 |     parser.add_argument('--ckpt_path', default=None, type=str, help='manual path of checkpoint')
27 | 
28 |     parser.add_argument('--train_episode', default=800, type=int, help='number of training episodes')
29 |     parser.add_argument('--n_gpu', default=1, type=int, help='number of GPUs to use')
30 |     parser.add_argument('--n_worker', default=16, type=int, help='number of data loader workers')
31 |     parser.add_argument('--suffix', default=None, type=str, help='suffix of auto-generated log directory')
32 | 
33 |     return parser.parse_args()
34 | 
35 | 
36 | def get_model_and_checkpoint(model_type, dataset, checkpoint_path, n_gpu=1):
37 |     if dataset == 'imagenet':
38 |         n_class = 1000
39 |     elif dataset == 'cifar10':
40 |         n_class = 10
41 |     else:
42 |         raise ValueError('unsupported dataset')
43 | 
44 |     print('=> Loading checkpoint {} ..'.format(checkpoint_path))
45 |     net = ckpt_to_model(checkpoint_path, model_type, n_class)
46 | 
47 |     if torch.cuda.is_available() and n_gpu > 0:
48 |         net = net.cuda()
49 |         if n_gpu > 1:
50 |             net = torch.nn.DataParallel(net, range(n_gpu))
51 | 
52 |     return net
53 | 
54 | def init_data(args):
55 |     # split the train set into train + val
56 |     # for CIFAR, split 5k for val
57 |     # for ImageNet, split 3k for val
58 |     val_size = 5000 if 'cifar' in args.dataset else 3000
59 |     train_loader, val_loader, _ = get_split_dataset(
60 |         args.dataset, args.batch_size,
61 |         args.n_worker, val_size,
62 |         data_root=args.data_root,
63 |         shuffle=False
64 |     )  # same sampling
65 |     return train_loader, val_loader
66 | 
67 | def validate(val_loader, model, verbose=False):
68 |     batch_time = AverageMeter()
69 |     losses = AverageMeter()
70 |     top1 = AverageMeter()
71 |     top5 = AverageMeter()
72 | 
73 |     criterion = nn.CrossEntropyLoss().to(device)
74 |     # switch to evaluate mode
75 |     model.eval()
76 |     end = time.time()
77 | 
78 |     t1 = time.time()
79 |     with torch.no_grad():
80 |         for i, (input, target) in enumerate(val_loader):
81 |             input = input.to(device)
82 |             target = target.to(device)
83 | 
84 | 
85 |             # compute output
86 |             output = model(input)
87 |             loss = criterion(output, target)
88 | 
89 |             # measure accuracy and record loss
90 |             prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
91 |             losses.update(loss.item(), input.size(0))
92 |             top1.update(prec1.item(), input.size(0))
93 | 
top5.update(prec5.item(), input.size(0)) 94 | 95 | # measure elapsed time 96 | batch_time.update(time.time() - end) 97 | end = time.time() 98 | t2 = time.time() 99 | if verbose: 100 | print('* Test loss: %.3f top1: %.3f top5: %.3f time: %.3f' % 101 | (losses.avg, top1.avg, top5.avg, t2 - t1)) 102 | return top5.avg 103 | 104 | 105 | if __name__ == "__main__": 106 | args = parse_args() 107 | 108 | device = torch.device('cuda') if torch.cuda.is_available() and args.n_gpu > 0 else torch.device('cpu') 109 | 110 | model = get_model_and_checkpoint(args.model_type, args.dataset, checkpoint_path=args.ckpt_path, n_gpu=args.n_gpu) 111 | _, val_loader = init_data(args) 112 | 113 | config_list = [{ 114 | 'op_types': ['Conv2d', 'Linear'] 115 | }] 116 | pruner = AMCPruner( 117 | model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset, 118 | train_episode=args.train_episode, flops_ratio=args.flops_ratio, lbound=args.lbound, 119 | rbound=args.rbound, suffix=args.suffix) 120 | pruner.compress() 121 | -------------------------------------------------------------------------------- /RM+AMC/amc_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import sys 5 | import os 6 | import time 7 | import argparse 8 | import shutil 9 | import math 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from tensorboardX import SummaryWriter 16 | from nni.algorithms.compression.pytorch.pruning.amc.lib.net_measure import measure_model 17 | from nni.algorithms.compression.pytorch.pruning.amc.lib.utils import get_output_folder 18 | from nni.compression.pytorch import ModelSpeedup 19 | 20 | from data import get_dataset 21 | from utils import AverageMeter, accuracy, progress_bar 22 | 23 | from models.rmnet import * 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description='AMC train / fine-tune script') 27 | parser.add_argument('--model_type', default='mobilenetv2', type=str, 28 | choices=['mobilenetv2', 'mobilenetv1'], 29 | help='name of the model to train') 30 | parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train') 31 | parser.add_argument('--lr', default=0.05, type=float, help='learning rate') 32 | parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use') 33 | parser.add_argument('--batch_size', default=256, type=int, help='batch size') 34 | parser.add_argument('--n_worker', default=32, type=int, help='number of data loader worker') 35 | parser.add_argument('--lr_type', default='cos', type=str, help='lr scheduler (exp/cos/step3/fixed)') 36 | parser.add_argument('--n_epoch', default=150, type=int, help='number of epochs to train') 37 | parser.add_argument('--wd', default=4e-5, type=float, help='weight decay') 38 | parser.add_argument('--seed', default=None, type=int, help='random seed to set') 39 | parser.add_argument('--data_root', default='/dev/shm', type=str, help='dataset path') 40 | # resume 41 | parser.add_argument('--ckpt_path', default=None, type=str, help='checkpoint path to fine tune') 42 | parser.add_argument('--mask_path', default=None, type=str, help='mask path for speedup') 43 | 44 | # run eval 45 | parser.add_argument('--eval', action='store_true', help='Simply run eval') 46 | parser.add_argument('--calc_flops', action='store_true', help='Calculate flops') 47 | 48 | return parser.parse_args() 49 | 50 | def 
get_model(args): 51 | print('=> Building model..') 52 | 53 | if args.dataset == 'imagenet': 54 | n_class = 1000 55 | elif args.dataset == 'cifar10': 56 | n_class = 10 57 | else: 58 | raise NotImplementedError 59 | 60 | # the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py 61 | print('=> Loading checkpoint {} ..'.format(args.ckpt_path)) 62 | net = ckpt_to_model(args.ckpt_path,args.model_type,n_class) 63 | if args.mask_path is not None: 64 | SZ = 224 if args.dataset == 'imagenet' else 32 65 | data = torch.randn(2, 3, SZ, SZ) 66 | ms = ModelSpeedup(net, data, args.mask_path, torch.device('cpu')) 67 | ms.speedup_model() 68 | 69 | net.to(args.device) 70 | if torch.cuda.is_available() and args.n_gpu > 1: 71 | net = torch.nn.DataParallel(net, list(range(args.n_gpu))) 72 | print(net) 73 | return net 74 | 75 | def train(epoch, train_loader, device): 76 | print('\nEpoch: %d' % epoch) 77 | net.train() 78 | 79 | batch_time = AverageMeter() 80 | losses = AverageMeter() 81 | top1 = AverageMeter() 82 | top5 = AverageMeter() 83 | end = time.time() 84 | 85 | for batch_idx, (inputs, targets) in enumerate(train_loader): 86 | inputs, targets = inputs.to(device), targets.to(device) 87 | optimizer.zero_grad() 88 | outputs = net(inputs) 89 | loss = criterion(outputs, targets) 90 | 91 | loss.backward() 92 | optimizer.step() 93 | 94 | # measure accuracy and record loss 95 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 96 | losses.update(loss.item(), inputs.size(0)) 97 | top1.update(prec1.item(), inputs.size(0)) 98 | top5.update(prec5.item(), inputs.size(0)) 99 | # timing 100 | batch_time.update(time.time() - end) 101 | end = time.time() 102 | 103 | progress_bar(batch_idx, len(train_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%' 104 | .format(losses.avg, top1.avg, top5.avg)) 105 | writer.add_scalar('loss/train', losses.avg, epoch) 106 | writer.add_scalar('acc/train_top1', top1.avg, epoch) 107 | writer.add_scalar('acc/train_top5', top5.avg, epoch) 108 | 109 | def test(epoch, test_loader, device, save=True): 110 | global best_acc 111 | net.eval() 112 | 113 | batch_time = AverageMeter() 114 | losses = AverageMeter() 115 | top1 = AverageMeter() 116 | top5 = AverageMeter() 117 | end = time.time() 118 | 119 | with torch.no_grad(): 120 | for batch_idx, (inputs, targets) in enumerate(test_loader): 121 | inputs, targets = inputs.to(device), targets.to(device) 122 | outputs = net(inputs) 123 | loss = criterion(outputs, targets) 124 | 125 | # measure accuracy and record loss 126 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 127 | losses.update(loss.item(), inputs.size(0)) 128 | top1.update(prec1.item(), inputs.size(0)) 129 | top5.update(prec5.item(), inputs.size(0)) 130 | # timing 131 | batch_time.update(time.time() - end) 132 | end = time.time() 133 | 134 | progress_bar(batch_idx, len(test_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%' 135 | .format(losses.avg, top1.avg, top5.avg)) 136 | 137 | if save: 138 | writer.add_scalar('loss/test', losses.avg, epoch) 139 | writer.add_scalar('acc/test_top1', top1.avg, epoch) 140 | writer.add_scalar('acc/test_top5', top5.avg, epoch) 141 | 142 | is_best = False 143 | if top1.avg > best_acc: 144 | best_acc = top1.avg 145 | is_best = True 146 | 147 | print('Current best acc: {}'.format(best_acc)) 148 | save_checkpoint({ 149 | 'epoch': epoch, 150 | 'model': args.model_type, 151 | 'dataset': args.dataset, 152 | 'state_dict': net.module.state_dict() if isinstance(net, nn.DataParallel) else 
net.state_dict(), 153 | 'acc': top1.avg, 154 | 'optimizer': optimizer.state_dict(), 155 | }, is_best, checkpoint_dir=log_dir) 156 | 157 | def adjust_learning_rate(optimizer, epoch): 158 | if args.lr_type == 'cos': # cos without warm-up 159 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.n_epoch)) 160 | elif args.lr_type == 'exp': 161 | step = 1 162 | decay = 0.96 163 | lr = args.lr * (decay ** (epoch // step)) 164 | elif args.lr_type == 'fixed': 165 | lr = args.lr 166 | else: 167 | raise NotImplementedError 168 | print('=> lr: {}'.format(lr)) 169 | for param_group in optimizer.param_groups: 170 | param_group['lr'] = lr 171 | return lr 172 | 173 | def save_checkpoint(state, is_best, checkpoint_dir='.'): 174 | filename = os.path.join(checkpoint_dir, 'ckpt.pth') 175 | print('=> Saving checkpoint to {}'.format(filename)) 176 | torch.save(state, filename) 177 | if is_best: 178 | shutil.copyfile(filename, filename.replace('.pth', '.best.pth')) 179 | 180 | if __name__ == '__main__': 181 | args = parse_args() 182 | 183 | if torch.cuda.is_available(): 184 | torch.backends.cudnn.benchmark = True 185 | args.device = torch.device('cuda') if torch.cuda.is_available() and args.n_gpu > 0 else torch.device('cpu') 186 | 187 | best_acc = 0 # best test accuracy 188 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 189 | 190 | if args.seed is not None: 191 | np.random.seed(args.seed) 192 | torch.manual_seed(args.seed) 193 | torch.cuda.manual_seed(args.seed) 194 | 195 | print('=> Preparing data..') 196 | train_loader, val_loader, n_class = get_dataset(args.dataset, args.batch_size, args.n_worker, 197 | data_root=args.data_root) 198 | 199 | net = get_model(args) # for measure 200 | 201 | if args.calc_flops: 202 | IMAGE_SIZE = 224 if args.dataset == 'imagenet' else 32 203 | n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE, args.device) 204 | print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M'.format(n_params / 1e6, n_flops / 1e6)) 205 | exit(0) 206 | 207 | criterion = nn.CrossEntropyLoss() 208 | print('Using SGD...') 209 | print('weight decay = {}'.format(args.wd)) 210 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd) 211 | 212 | if args.eval: # just run eval 213 | print('=> Start evaluation...') 214 | test(0, val_loader, args.device, save=False) 215 | else: # train 216 | print('=> Start training...') 217 | print('Training {} on {}...'.format(args.model_type, args.dataset)) 218 | train_type = 'train' if args.ckpt_path is None else 'finetune' 219 | log_dir = get_output_folder('./logs', '{}_{}_{}'.format(args.model_type, args.dataset, train_type)) 220 | print('=> Saving logs to {}'.format(log_dir)) 221 | # tf writer 222 | writer = SummaryWriter(logdir=log_dir) 223 | 224 | for epoch in range(start_epoch, start_epoch + args.n_epoch): 225 | lr = adjust_learning_rate(optimizer, epoch) 226 | train(epoch, train_loader, args.device) 227 | test(epoch, val_loader, args.device) 228 | 229 | writer.close() 230 | print('=> Best top-1 acc: {}%'.format(best_acc)) 231 | -------------------------------------------------------------------------------- /RM+AMC/benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import csv 5 | import json 6 | import time 7 | import logging 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | from collections import OrderedDict 12 | from contextlib import suppress 13 | from functools import partial 14 | 
from timm.models import is_model, list_models 15 | from timm.data import resolve_data_config 16 | from timm.utils import setup_default_logging 17 | from models.rmnet import * 18 | import thop 19 | 20 | has_apex = False 21 | try: 22 | from apex import amp 23 | has_apex = True 24 | except ImportError: 25 | pass 26 | 27 | has_native_amp = False 28 | try: 29 | if getattr(torch.cuda.amp, 'autocast') is not None: 30 | has_native_amp = True 31 | except AttributeError: 32 | pass 33 | 34 | torch.backends.cudnn.benchmark = True 35 | _logger = logging.getLogger('validate') 36 | 37 | parser = argparse.ArgumentParser(description='PyTorch Benchmark') 38 | 39 | # benchmark specific args 40 | parser.add_argument('--model-list', 41 | metavar='NAME', 42 | default='', 43 | help='txt file based list of model names to benchmark') 44 | parser.add_argument( 45 | '--bench', 46 | default='both', 47 | type=str, 48 | help= 49 | "Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'" 50 | ) 51 | parser.add_argument( 52 | '--detail', 53 | action='store_true', 54 | default=False, 55 | help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False' 56 | ) 57 | parser.add_argument('--results-file', 58 | default='', 59 | type=str, 60 | metavar='FILENAME', 61 | help='Output csv file for validation results (summary)') 62 | parser.add_argument('--num-warm-iter', 63 | default=50, 64 | type=int, 65 | metavar='N', 66 | help='Number of warmup iterations (default: 10)') 67 | parser.add_argument('--num-bench-iter', 68 | default=50, 69 | type=int, 70 | metavar='N', 71 | help='Number of benchmark iterations (default: 40)') 72 | 73 | # common inference / train args 74 | parser.add_argument('--model', 75 | '-m', 76 | metavar='NAME', 77 | default='resnet50', 78 | help='model architecture (default: resnet50)') 79 | parser.add_argument('--ckpt_path', default=None, type=str, help='checkpoint path to fine tune') 80 | parser.add_argument('-b', 81 | '--batch-size', 82 | default=128, 83 | type=int, 84 | metavar='N', 85 | help='mini-batch size (default: 256)') 86 | parser.add_argument('--img-size', 87 | default=None, 88 | type=int, 89 | metavar='N', 90 | help='Input image dimension, uses model default if empty') 91 | parser.add_argument( 92 | '--input-size', 93 | default=(3,32,32), 94 | nargs=3, 95 | type=int, 96 | metavar='N N N', 97 | help= 98 | 'Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty' 99 | ) 100 | parser.add_argument('--num-classes', 101 | type=int, 102 | default=10, 103 | help='Number classes in dataset') 104 | parser.add_argument( 105 | '--gp', 106 | default=None, 107 | type=str, 108 | metavar='POOL', 109 | help= 110 | 'Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.' 111 | ) 112 | parser.add_argument('--channels-last', 113 | action='store_true', 114 | default=False, 115 | help='Use channels_last memory layout') 116 | parser.add_argument( 117 | '--amp', 118 | action='store_true', 119 | default=False, 120 | help= 121 | 'use PyTorch Native AMP for mixed precision training. Overrides --precision arg.' 122 | ) 123 | parser.add_argument( 124 | '--precision', 125 | default='float32', 126 | type=str, 127 | help='Numeric precision. 
One of (amp, float32, float16, bfloat16, tf32)') 128 | parser.add_argument('--torchscript', 129 | dest='torchscript', 130 | action='store_true', 131 | help='convert model torchscript for inference') 132 | 133 | 134 | def timestamp(sync=False): 135 | return time.perf_counter() 136 | 137 | 138 | def cuda_timestamp(sync=False, device=None): 139 | if sync: 140 | torch.cuda.synchronize(device=device) 141 | return time.perf_counter() 142 | 143 | 144 | def count_params(model: nn.Module): 145 | return sum([m.numel() for m in model.parameters()]) 146 | 147 | 148 | def resolve_precision(precision: str): 149 | assert precision in ('amp', 'float16', 'bfloat16', 'float32') 150 | use_amp = False 151 | model_dtype = torch.float32 152 | data_dtype = torch.float32 153 | if precision == 'amp': 154 | use_amp = True 155 | elif precision == 'float16': 156 | model_dtype = torch.float16 157 | data_dtype = torch.float16 158 | elif precision == 'bfloat16': 159 | model_dtype = torch.bfloat16 160 | data_dtype = torch.bfloat16 161 | return use_amp, model_dtype, data_dtype 162 | 163 | 164 | class BenchmarkRunner: 165 | 166 | def __init__(self, 167 | model_name, 168 | detail=False, 169 | device='cuda', 170 | torchscript=False, 171 | precision='float32', 172 | num_warm_iter=10, 173 | num_bench_iter=50, 174 | **kwargs): 175 | self.model_name = model_name 176 | self.detail = detail 177 | self.device = device 178 | self.use_amp, self.model_dtype, self.data_dtype = resolve_precision( 179 | precision) 180 | self.channels_last = kwargs.pop('channels_last', False) 181 | self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress 182 | self.num_classes = kwargs.pop('num_classes') 183 | self.ckpt_path = kwargs.pop('ckpt_path') 184 | self.model = ckpt_to_model(self.ckpt_path,self.model_name,int(self.num_classes)) 185 | #self.model = torch.load(model_name) 186 | print(self.model) 187 | 188 | 189 | self.model.to( 190 | device=self.device, 191 | dtype=self.model_dtype, 192 | memory_format=torch.channels_last if self.channels_last else None) 193 | self.param_count = count_params(self.model) 194 | _logger.info('Model %s created, param count: %d' % 195 | (model_name, self.param_count)) 196 | if torchscript: 197 | self.model = torch.jit.script(self.model) 198 | 199 | data_config = resolve_data_config(kwargs, 200 | model=self.model, 201 | use_test_size=True) 202 | self.input_size = data_config['input_size'] 203 | self.batch_size = kwargs.pop('batch_size', 256) 204 | print(thop.profile(self.model,(torch.randn(1,3,self.input_size[1],self.input_size[2]).to(self.device),))) 205 | 206 | self.example_inputs = None 207 | self.num_warm_iter = num_warm_iter 208 | self.num_bench_iter = num_bench_iter 209 | self.log_freq = num_bench_iter // 5 210 | if 'cuda' in self.device: 211 | self.time_fn = partial(cuda_timestamp, device=self.device) 212 | else: 213 | self.time_fn = timestamp 214 | 215 | def _init_input(self): 216 | self.example_inputs = torch.randn((self.batch_size,) + self.input_size, 217 | device=self.device, 218 | dtype=self.data_dtype) 219 | if self.channels_last: 220 | self.example_inputs = self.example_inputs.contiguous( 221 | memory_format=torch.channels_last) 222 | 223 | 224 | class InferenceBenchmarkRunner(BenchmarkRunner): 225 | 226 | def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): 227 | super().__init__(model_name=model_name, 228 | device=device, 229 | torchscript=torchscript, 230 | **kwargs) 231 | self.model.eval() 232 | 233 | def run(self): 234 | 235 | def _step(): 236 | t_step_start = 
self.time_fn() 237 | with self.amp_autocast(): 238 | output = self.model(self.example_inputs) 239 | t_step_end = self.time_fn(True) 240 | return t_step_end - t_step_start 241 | 242 | _logger.info( 243 | f'Running inference benchmark on {self.model_name} for {self.num_bench_iter} steps w/ ' 244 | f'input size {self.input_size} and batch size {self.batch_size}.') 245 | 246 | with torch.no_grad(): 247 | self._init_input() 248 | 249 | for _ in range(self.num_warm_iter): 250 | _step() 251 | 252 | total_step = 0. 253 | num_samples = 0 254 | t_run_start = self.time_fn() 255 | for i in range(self.num_bench_iter): 256 | delta_fwd = _step() 257 | total_step += delta_fwd 258 | num_samples += self.batch_size 259 | num_steps = i + 1 260 | if num_steps % self.log_freq == 0: 261 | _logger.info(f"Infer [{num_steps}/{self.num_bench_iter}]." 262 | f" {num_samples / total_step:0.2f} samples/sec." 263 | f" {1000 * total_step / num_steps:0.3f} ms/step.") 264 | t_run_end = self.time_fn(True) 265 | t_run_elapsed = t_run_end - t_run_start 266 | 267 | results = dict( 268 | samples_per_sec=round(num_samples / t_run_elapsed, 2), 269 | step_time=round(1000 * total_step / self.num_bench_iter, 3), 270 | batch_size=self.batch_size, 271 | img_size=self.input_size[-1], 272 | param_count=round(self.param_count / 1e6, 2), 273 | ) 274 | 275 | _logger.info( 276 | f"Inference benchmark of {self.model_name} done. " 277 | f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step" 278 | ) 279 | 280 | return results 281 | 282 | 283 | def decay_batch_exp(batch_size, factor=0.5, divisor=16): 284 | out_batch_size = batch_size * factor 285 | if out_batch_size > divisor: 286 | out_batch_size = (out_batch_size + 1) // divisor * divisor 287 | else: 288 | out_batch_size = batch_size - 1 289 | return max(0, int(out_batch_size)) 290 | 291 | 292 | def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): 293 | batch_size = initial_batch_size 294 | results = dict() 295 | while batch_size >= 1: 296 | try: 297 | bench = bench_fn(model_name=model_name, 298 | batch_size=batch_size, 299 | **bench_kwargs) 300 | results = bench.run() 301 | return results 302 | except RuntimeError as e: 303 | torch.cuda.empty_cache() 304 | batch_size = decay_batch_exp(batch_size) 305 | print( 306 | f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.' 307 | ) 308 | return results 309 | 310 | 311 | def benchmark(args): 312 | if args.amp: 313 | _logger.warning("Overriding precision to 'amp' since --amp flag set.") 314 | args.precision = 'amp' 315 | _logger.info(f'Benchmarking in {args.precision} precision. ' 316 | f'{"NHWC" if args.channels_last else "NCHW"} layout. 
' 317 | f'torchscript {"enabled" if args.torchscript else "disabled"}') 318 | 319 | bench_kwargs = vars(args).copy() 320 | bench_kwargs.pop('amp') 321 | model = bench_kwargs.pop('model') 322 | batch_size = bench_kwargs.pop('batch_size') 323 | 324 | bench_fns = (InferenceBenchmarkRunner,) 325 | prefixes = ('infer',) 326 | model_results = OrderedDict(model=model) 327 | for prefix, bench_fn in zip(prefixes, bench_fns): 328 | run_results = _try_run(model, 329 | bench_fn, 330 | initial_batch_size=batch_size, 331 | bench_kwargs=bench_kwargs) 332 | if prefix: 333 | run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()} 334 | model_results.update(run_results) 335 | param_count = model_results.pop('infer_param_count', 336 | model_results.pop('train_param_count', 0)) 337 | model_results.setdefault('param_count', param_count) 338 | model_results.pop('train_param_count', 0) 339 | return model_results 340 | 341 | 342 | def main(): 343 | setup_default_logging() 344 | args = parser.parse_args() 345 | model_cfgs = [] 346 | model_names = [] 347 | 348 | if args.model_list: 349 | args.model = '' 350 | with open(args.model_list) as f: 351 | model_names = [line.rstrip() for line in f] 352 | model_cfgs = [(n, None) for n in model_names] 353 | elif args.model == 'all': 354 | # validate all models in a list of names with pretrained checkpoints 355 | args.pretrained = True 356 | model_names = list_models(pretrained=True, exclude_filters=['*in21k']) 357 | model_cfgs = [(n, None) for n in model_names] 358 | elif not is_model(args.model): 359 | # model name doesn't exist, try as wildcard filter 360 | model_names = list_models(args.model) 361 | model_cfgs = [(n, None) for n in model_names] 362 | 363 | if len(model_cfgs): 364 | results_file = args.results_file or './benchmark.csv' 365 | _logger.info( 366 | 'Running bulk validation on these pretrained models: {}'.format( 367 | ', '.join(model_names))) 368 | results = [] 369 | try: 370 | for m, _ in model_cfgs: 371 | if not m: 372 | continue 373 | args.model = m 374 | r = benchmark(args) 375 | results.append(r) 376 | except KeyboardInterrupt as e: 377 | pass 378 | sort_key = 'train_samples_per_sec' if 'train' in args.bench else 'infer_samples_per_sec' 379 | results = sorted(results, key=lambda x: x[sort_key], reverse=True) 380 | if len(results): 381 | write_results(results_file, results) 382 | 383 | import json 384 | json_str = json.dumps(results, indent=4) 385 | print(json_str) 386 | else: 387 | benchmark(args) 388 | 389 | 390 | def write_results(results_file, results): 391 | with open(results_file, mode='w') as cf: 392 | dw = csv.DictWriter(cf, fieldnames=results[0].keys()) 393 | dw.writeheader() 394 | for r in results: 395 | dw.writerow(r) 396 | cf.flush() 397 | 398 | 399 | if __name__ == '__main__': 400 | main() 401 | -------------------------------------------------------------------------------- /RM+AMC/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
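# data.py: dataset helpers. get_dataset() returns full train/test loaders;
# get_split_dataset() carves a fixed-size validation subset out of the
# training set (amc_search.py uses that subset as the reward signal for the
# AMC agent).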
3 | 4 | import torch 5 | import torch.nn.parallel 6 | import torch.optim 7 | import torch.utils.data 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | import torchvision.datasets as datasets 11 | from torch.utils.data.sampler import SubsetRandomSampler 12 | import numpy as np 13 | 14 | import os 15 | 16 | 17 | def get_dataset(dset_name, batch_size, n_worker, data_root='../../data'): 18 | cifar_tran_train = [ 19 | transforms.RandomCrop(32, padding=4), 20 | transforms.RandomHorizontalFlip(), 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 23 | ] 24 | cifar_tran_test = [ 25 | transforms.ToTensor(), 26 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 27 | ] 28 | print('=> Preparing data..') 29 | if dset_name == 'cifar10': 30 | transform_train = transforms.Compose(cifar_tran_train) 31 | transform_test = transforms.Compose(cifar_tran_test) 32 | trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train) 33 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, 34 | num_workers=n_worker, pin_memory=True, sampler=None) 35 | testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test) 36 | val_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, 37 | num_workers=n_worker, pin_memory=True) 38 | n_class = 10 39 | elif dset_name == 'imagenet': 40 | # get dir 41 | traindir = os.path.join(data_root, 'train') 42 | valdir = os.path.join(data_root, 'val') 43 | 44 | # preprocessing 45 | input_size = 224 46 | imagenet_tran_train = [ 47 | transforms.RandomResizedCrop(input_size, scale=(0.2, 1.0)), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 51 | ] 52 | imagenet_tran_test = [ 53 | transforms.Resize(int(input_size / 0.875)), 54 | transforms.CenterCrop(input_size), 55 | transforms.ToTensor(), 56 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 57 | ] 58 | 59 | train_loader = torch.utils.data.DataLoader( 60 | datasets.ImageFolder(traindir, transforms.Compose(imagenet_tran_train)), 61 | batch_size=batch_size, shuffle=True, 62 | num_workers=n_worker, pin_memory=True, sampler=None) 63 | 64 | val_loader = torch.utils.data.DataLoader( 65 | datasets.ImageFolder(valdir, transforms.Compose(imagenet_tran_test)), 66 | batch_size=batch_size, shuffle=False, 67 | num_workers=n_worker, pin_memory=True) 68 | n_class = 1000 69 | 70 | else: 71 | raise NotImplementedError 72 | 73 | return train_loader, val_loader, n_class 74 | 75 | 76 | def get_split_dataset(dset_name, batch_size, n_worker, val_size, data_root='../data', shuffle=True): 77 | ''' 78 | split the train set into train / val for rl search 79 | ''' 80 | if shuffle: 81 | index_sampler = SubsetRandomSampler 82 | else: # every time we use the same order for the split subset 83 | class SubsetSequentialSampler(SubsetRandomSampler): 84 | def __iter__(self): 85 | return (self.indices[i] for i in torch.arange(len(self.indices)).int()) 86 | index_sampler = SubsetSequentialSampler 87 | 88 | print('=> Preparing data: {}...'.format(dset_name)) 89 | if dset_name == 'cifar10': 90 | transform_train = transforms.Compose([ 91 | transforms.RandomCrop(32, padding=4), 92 | transforms.RandomHorizontalFlip(), 93 | transforms.ToTensor(), 94 | transforms.Normalize((0.4914, 0.4822, 
0.4465), (0.2023, 0.1994, 0.2010)),
95 |         ])
96 |         transform_test = transforms.Compose([
97 |             transforms.ToTensor(),
98 |             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
99 |         ])
100 |         trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
101 |         valset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_test)
102 |         n_train = len(trainset)
103 |         indices = list(range(n_train))
104 |         # keep the indices in order so the train/val split is deterministic
105 |         #np.random.shuffle(indices)
106 |         assert val_size < n_train
107 |         train_idx, val_idx = indices[val_size:], indices[:val_size]
108 | 
109 |         train_sampler = index_sampler(train_idx)
110 |         val_sampler = index_sampler(val_idx)
111 | 
112 |         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, sampler=train_sampler,
113 |                                                    num_workers=n_worker, pin_memory=True)
114 |         val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, sampler=val_sampler,
115 |                                                  num_workers=n_worker, pin_memory=True)
116 |         n_class = 10
117 |     elif dset_name == 'imagenet':
118 |         train_dir = os.path.join(data_root, 'train')
119 |         val_dir = os.path.join(data_root, 'val')
120 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
121 |                                          std=[0.229, 0.224, 0.225])
122 |         input_size = 224
123 |         train_transform = transforms.Compose([
124 |             transforms.RandomResizedCrop(input_size),
125 |             transforms.RandomHorizontalFlip(),
126 |             transforms.ToTensor(),
127 |             normalize,
128 |         ])
129 |         test_transform = transforms.Compose([
130 |             transforms.Resize(int(input_size/0.875)),
131 |             transforms.CenterCrop(input_size),
132 |             transforms.ToTensor(),
133 |             normalize,
134 |         ])
135 | 
136 |         trainset = datasets.ImageFolder(train_dir, train_transform)
137 |         valset = datasets.ImageFolder(train_dir, test_transform)
138 |         n_train = len(trainset)
139 |         indices = list(range(n_train))
140 |         np.random.shuffle(indices)
141 |         assert val_size < n_train
142 |         train_idx, val_idx = indices[val_size:], indices[:val_size]
143 | 
144 |         train_sampler = index_sampler(train_idx)
145 |         val_sampler = index_sampler(val_idx)
146 | 
147 |         train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
148 |                                                    num_workers=n_worker, pin_memory=True)
149 |         val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
150 |                                                  num_workers=n_worker, pin_memory=True)
151 | 
152 |         n_class = 1000
153 |     else:
154 |         raise NotImplementedError
155 | 
156 |     return train_loader, val_loader, n_class
157 | 
--------------------------------------------------------------------------------
/RM+AMC/models/__pycache__/rmnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/RM+AMC/models/__pycache__/rmnet.cpython-36.pyc
--------------------------------------------------------------------------------
/RM+AMC/models/rmnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
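# rmnet.py: MobileNetV2 definition plus the RM operation that converts it to
# a residual-free MobileNetV1-style network. rm_r() widens a residual
# InvertedResidual so extra channels carry the block input unchanged
# (dirac-initialized conv weights, identity BatchNorm statistics, PReLU with
# slope 1), then folds the residual addition into the final 1x1 conv;
# fuse_cbcb() merges the back-to-back conv1x1+BN pairs left between blocks,
# so the converted network matches the original in eval mode.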
3 | import torch 4 | import torch.nn as nn 5 | import math 6 | 7 | def conv_bn(inp, oup, stride): 8 | return nn.Sequential( 9 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 10 | nn.BatchNorm2d(oup), 11 | nn.ReLU6(inplace=True) 12 | ) 13 | 14 | 15 | def conv_1x1_bn(inp, oup): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 18 | nn.BatchNorm2d(oup), 19 | nn.ReLU6(inplace=True) 20 | ) 21 | 22 | 23 | class InvertedResidual(nn.Module): 24 | def __init__(self, inp, oup, stride, expand_ratio): 25 | super(InvertedResidual, self).__init__() 26 | self.stride = stride 27 | assert stride in [1, 2] 28 | 29 | hidden_dim = round(inp * expand_ratio) 30 | self.use_res_connect = self.stride == 1 and inp == oup 31 | 32 | if expand_ratio == 1: 33 | self.conv = nn.Sequential( 34 | # dw 35 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 36 | nn.BatchNorm2d(hidden_dim), 37 | nn.ReLU6(inplace=True), 38 | # pw-linear 39 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 40 | nn.BatchNorm2d(oup), 41 | ) 42 | else: 43 | self.conv = nn.Sequential( 44 | # pw 45 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 46 | nn.BatchNorm2d(hidden_dim), 47 | nn.ReLU6(inplace=True), 48 | # dw 49 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 50 | nn.BatchNorm2d(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # pw-linear 53 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 54 | nn.BatchNorm2d(oup), 55 | ) 56 | 57 | 58 | def forward(self, x): 59 | if self.use_res_connect: 60 | return x + self.conv(x) 61 | else: 62 | return self.conv(x) 63 | 64 | 65 | class MobileNetV2(nn.Module): 66 | def __init__(self, 67 | n_class=1000, 68 | input_size=224, 69 | width_mult=1, 70 | input_channel = 32, 71 | last_channel = 1280, 72 | interverted_residual_setting = [ 73 | # t, c, n, s 74 | [1, 16, 1, 1], 75 | [6, 24, 2, 2], 76 | [6, 32, 3, 2], 77 | [6, 64, 4, 2], 78 | [6, 96, 3, 1], 79 | [6, 160, 3, 2], 80 | [6, 320, 1, 1], 81 | ]): 82 | super(MobileNetV2, self).__init__() 83 | block = InvertedResidual 84 | 85 | # building first layer 86 | assert input_size % 32 == 0 87 | input_channel = int(input_channel * width_mult) 88 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 89 | self.features = [conv_bn(3, input_channel, 2 if n_class==1000 else 1)] 90 | # building inverted residual blocks 91 | for t, c, n, s in interverted_residual_setting: 92 | output_channel = int(c * width_mult) 93 | for i in range(n): 94 | if i == 0: 95 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 96 | else: 97 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 98 | input_channel = output_channel 99 | # building last several layers 100 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 101 | # make it nn.Sequential 102 | self.features = nn.Sequential(*self.features) 103 | 104 | # building classifier 105 | self.classifier = nn.Sequential( 106 | nn.Dropout(0.2), 107 | nn.Linear(self.last_channel, n_class), 108 | ) 109 | 110 | self._initialize_weights() 111 | 112 | def forward(self, x): 113 | x = self.features(x) 114 | # it's same with .mean(3).mean(2), but 115 | # speedup only suport the mean option 116 | # whose output only have two dimensions 117 | x = x.mean([2, 3]) 118 | x = self.classifier(x) 119 | return x 120 | 121 | def _initialize_weights(self): 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | n = m.kernel_size[0] * m.kernel_size[1] * 
m.out_channels 125 | m.weight.data.normal_(0, math.sqrt(2. / n)) 126 | if m.bias is not None: 127 | m.bias.data.zero_() 128 | elif isinstance(m, nn.BatchNorm2d): 129 | if m.weight is not None: 130 | m.weight.data.fill_(1) 131 | if m.bias is not None: 132 | m.bias.data.zero_() 133 | elif isinstance(m, nn.Linear): 134 | n = m.weight.size(1) 135 | m.weight.data.normal_(0, 0.01) 136 | m.bias.data.zero_() 137 | 138 | def ckpt_to_mobilenetv2(ckpt,n_class=1000): 139 | interverted_residual_setting=[ 140 | [1, 16, 1, 1], 141 | [6, 24, 1, 2], 142 | [6, 24, 1, 1], 143 | [6, 32, 1, 2], 144 | [6, 32, 1, 1], 145 | [6, 32, 1, 1], 146 | [6, 64, 1, 2], 147 | [6, 64, 1, 1], 148 | [6, 64, 1, 1], 149 | [6, 64, 1, 1], 150 | [6, 96, 1, 1], 151 | [6, 96, 1, 1], 152 | [6, 96, 1, 1], 153 | [6, 160, 1, 2], 154 | [6, 160, 1, 1], 155 | [6, 160, 1, 1], 156 | [6, 320, 1, 1]] 157 | input_channel=ckpt['features.1.conv.4.weight'].shape[0] 158 | interverted_residual_setting[0][1]=input_channel 159 | for i in range(2,18): 160 | interverted_residual_setting[i-1][0]=ckpt['features.%d.conv.6.weight'%i].shape[1]/input_channel 161 | input_channel=ckpt['features.%d.conv.6.weight'%i].shape[0] 162 | interverted_residual_setting[i-1][1]=input_channel 163 | input_channel=ckpt['features.0.1.weight'].shape[0] 164 | last_channel=ckpt['features.18.1.weight'].shape[0] 165 | model=MobileNetV2(n_class=n_class,input_channel=input_channel,interverted_residual_setting=interverted_residual_setting,last_channel=last_channel) 166 | model.load_state_dict(ckpt) 167 | return model 168 | 169 | def ckpt_to_mobilenetv1(ckpt,n_class=1000): 170 | model=mobilenetv2_to_mobilenetv1(MobileNetV2(n_class=n_class)) 171 | channels=[] 172 | for k,v in ckpt.items(): 173 | if len(v.shape)==4: 174 | channels.append(v.shape[0]) 175 | in_channels=3 176 | features=[] 177 | for m in model.features: 178 | if isinstance(m,nn.Conv2d): 179 | out_channels=channels[0] 180 | features.append(nn.Conv2d(in_channels,out_channels,kernel_size=m.kernel_size,stride=m.stride,padding=m.padding,groups=in_channels if m.groups!=1 else 1,bias=m.bias)) 181 | channels.pop(0) 182 | in_channels=out_channels 183 | elif isinstance(m,nn.BatchNorm2d): 184 | features.append(nn.BatchNorm2d(in_channels)) 185 | elif isinstance(m,nn.ReLU6): 186 | features.append(nn.ReLU6(inplace=True)) 187 | elif isinstance(m,nn.PReLU): 188 | features.append(nn.PReLU(in_channels)) 189 | else: 190 | print(m) 191 | model.features=nn.Sequential(*features) 192 | model.classifier[1]=nn.Linear(out_channels,n_class) 193 | model.load_state_dict(ckpt) 194 | return model 195 | 196 | def ckpt_to_model(ckpt_path,dist_model_type,n_class=1000): 197 | if ckpt_path is None: 198 | model = MobileNetV2(n_class=n_class) 199 | if dist_model_type == 'mobilenetv1': 200 | model = mobilenetv2_to_mobilenetv1(model) 201 | return model 202 | else: 203 | ckpt = torch.load(ckpt_path,map_location='cpu') 204 | if 'mobilenetv1' in ckpt_path: 205 | model_type='mobilenetv1' 206 | else: 207 | model_type='mobilenetv2' 208 | if 'state_dict' in ckpt.keys(): 209 | ckpt=ckpt['state_dict'] 210 | ckpt = {k.replace('module.',''):v for k,v in ckpt.items()} 211 | if model_type == 'mobilenetv1': 212 | return ckpt_to_mobilenetv1(ckpt,n_class) 213 | elif model_type == 'mobilenetv2': 214 | model = ckpt_to_mobilenetv2(ckpt,n_class) 215 | if dist_model_type == 'mobilenetv1': 216 | model = mobilenetv2_to_mobilenetv1(model) 217 | return model 218 | else: 219 | raise NotImplementedError 220 | 221 | def fuse_cbcb(conv1,bn1,conv2,bn2): 222 | inp=conv1.in_channels 223 | 
mid=conv1.out_channels 224 | oup=conv2.out_channels 225 | conv1=torch.nn.utils.fuse_conv_bn_eval(conv1.eval(),bn1.eval()) 226 | fused_conv=nn.Conv2d(inp,oup,1,bias=False) 227 | fused_conv.weight.data=(conv2.weight.data.view(oup,mid)@conv1.weight.data.view(mid,-1)).view(oup,inp,1,1) 228 | bn2.running_mean-=conv2.weight.data.view(oup,mid)@conv1.bias.data 229 | return fused_conv,bn2 230 | 231 | def fuse_cb(conv_w, bn_rm, bn_rv,bn_w,bn_b, eps): 232 | bn_var_rsqrt = torch.rsqrt(bn_rv + eps) 233 | conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) 234 | conv_b = bn_rm * bn_var_rsqrt * bn_w-bn_b 235 | return conv_w,conv_b 236 | 237 | def rm_r(model): 238 | inp = model.conv[0].in_channels 239 | mid = inp+model.conv[0].out_channels 240 | oup = model.conv[6].out_channels 241 | 242 | running1 = nn.BatchNorm2d(inp,affine=False) 243 | running2 = nn.BatchNorm2d(oup,affine=False) 244 | 245 | idconv1 = nn.Conv2d(inp, mid, kernel_size=1, bias=False).eval() 246 | idbn1=nn.BatchNorm2d(mid).eval() 247 | 248 | nn.init.dirac_(idconv1.weight.data[:inp]) 249 | bn_var_sqrt=torch.sqrt(running1.running_var + running1.eps) 250 | idbn1.weight.data[:inp]=bn_var_sqrt 251 | idbn1.bias.data[:inp]=running1.running_mean 252 | idbn1.running_mean.data[:inp]=running1.running_mean 253 | idbn1.running_var.data[:inp]=running1.running_var 254 | 255 | idconv1.weight.data[inp:]=model.conv[0].weight.data 256 | idbn1.weight.data[inp:]=model.conv[1].weight.data 257 | idbn1.bias.data[inp:]=model.conv[1].bias.data 258 | idbn1.running_mean.data[inp:]=model.conv[1].running_mean 259 | idbn1.running_var.data[inp:]=model.conv[1].running_var 260 | idrelu1 = nn.PReLU(mid) 261 | torch.nn.init.ones_(idrelu1.weight.data[:inp]) 262 | torch.nn.init.zeros_(idrelu1.weight.data[inp:]) 263 | 264 | idconv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=model.stride, padding=1,groups=mid, bias=False).eval() 265 | idbn2=nn.BatchNorm2d(mid).eval() 266 | 267 | nn.init.dirac_(idconv2.weight.data[:inp],groups=inp) 268 | idbn2.weight.data[:inp]=idbn1.weight.data[:inp] 269 | idbn2.bias.data[:inp]=idbn1.bias.data[:inp] 270 | idbn2.running_mean.data[:inp]=idbn1.running_mean.data[:inp] 271 | idbn2.running_var.data[:inp]=idbn1.running_var.data[:inp] 272 | 273 | idconv2.weight.data[inp:]=model.conv[3].weight.data 274 | idbn2.weight.data[inp:]=model.conv[4].weight.data 275 | idbn2.bias.data[inp:]=model.conv[4].bias.data 276 | idbn2.running_mean.data[inp:]=model.conv[4].running_mean 277 | idbn2.running_var.data[inp:]=model.conv[4].running_var 278 | idrelu2 = nn.PReLU(mid) 279 | torch.nn.init.ones_(idrelu2.weight.data[:inp]) 280 | torch.nn.init.zeros_(idrelu2.weight.data[inp:]) 281 | 282 | idconv3 = nn.Conv2d(mid, oup, kernel_size=1, bias=False).eval() 283 | idbn3=nn.BatchNorm2d(oup).eval() 284 | 285 | nn.init.dirac_(idconv3.weight.data[:,:inp]) 286 | idconv3.weight.data[:,inp:],bias=fuse_cb(model.conv[6].weight,model.conv[7].running_mean,model.conv[7].running_var,model.conv[7].weight,model.conv[7].bias,model.conv[7].eps) 287 | bn_var_sqrt=torch.sqrt(running2.running_var + running2.eps) 288 | idbn3.weight.data=bn_var_sqrt 289 | idbn3.bias.data=running2.running_mean 290 | idbn3.running_mean.data=running2.running_mean+bias 291 | idbn3.running_var.data=running2.running_var 292 | return [idconv1,idbn1,idrelu1,idconv2,idbn2,idrelu2,idconv3,idbn3] 293 | 294 | def mobilenetv2_to_mobilenetv1(model): 295 | features=[] 296 | for m in model.features: 297 | if isinstance(m,InvertedResidual)and m.use_res_connect: 298 | features+=rm_r(m) 299 | else: 
300 | for mm in m.modules(): 301 | if not list(mm.children()): 302 | features.append(mm) 303 | 304 | new_features=[] 305 | while features: 306 | if isinstance(features[0],nn.Conv2d) and isinstance(features[1],nn.BatchNorm2d) and isinstance(features[2],nn.Conv2d) and isinstance(features[3],nn.BatchNorm2d): 307 | conv,bn = fuse_cbcb(features[0],features[1],features[2],features[3]) 308 | new_features.append(conv) 309 | new_features.append(bn) 310 | features=features[4:] 311 | else: 312 | new_features.append(features.pop(0)) 313 | 314 | model.features=nn.Sequential(*new_features) 315 | return model 316 | -------------------------------------------------------------------------------- /RM+AMC/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import sys 5 | import os 6 | import time 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | if self.count > 0: 24 | self.avg = self.sum / self.count 25 | 26 | def accumulate(self, val, n=1): 27 | self.sum += val 28 | self.count += n 29 | if self.count > 0: 30 | self.avg = self.sum / self.count 31 | 32 | 33 | def accuracy(output, target, topk=(1, 5)): 34 | """Computes the precision@k for the specified values of k""" 35 | batch_size = target.size(0) 36 | num = output.size(1) 37 | target_topk = [] 38 | appendices = [] 39 | for k in topk: 40 | if k <= num: 41 | target_topk.append(k) 42 | else: 43 | appendices.append([0.0]) 44 | topk = target_topk 45 | maxk = max(topk) 46 | _, pred = output.topk(maxk, 1, True, True) 47 | pred = pred.t() 48 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 49 | 50 | res = [] 51 | for k in topk: 52 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 53 | res.append(correct_k.mul_(100.0 / batch_size)) 54 | return res + appendices 55 | 56 | 57 | # Custom progress bar 58 | _, term_width = os.popen('stty size', 'r').read().split() 59 | term_width = int(term_width) 60 | TOTAL_BAR_LENGTH = 40. 61 | last_time = time.time() 62 | begin_time = last_time 63 | 64 | 65 | def progress_bar(current, total, msg=None): 66 | def format_time(seconds): 67 | days = int(seconds / 3600 / 24) 68 | seconds = seconds - days * 3600 * 24 69 | hours = int(seconds / 3600) 70 | seconds = seconds - hours * 3600 71 | minutes = int(seconds / 60) 72 | seconds = seconds - minutes * 60 73 | secondsf = int(seconds) 74 | seconds = seconds - secondsf 75 | millis = int(seconds * 1000) 76 | 77 | f = '' 78 | i = 1 79 | if days > 0: 80 | f += str(days) + 'D' 81 | i += 1 82 | if hours > 0 and i <= 2: 83 | f += str(hours) + 'h' 84 | i += 1 85 | if minutes > 0 and i <= 2: 86 | f += str(minutes) + 'm' 87 | i += 1 88 | if secondsf > 0 and i <= 2: 89 | f += str(secondsf) + 's' 90 | i += 1 91 | if millis > 0 and i <= 2: 92 | f += str(millis) + 'ms' 93 | i += 1 94 | if f == '': 95 | f = '0ms' 96 | return f 97 | 98 | global last_time, begin_time 99 | if current == 0: 100 | begin_time = time.time() # Reset for new bar. 
101 | 
102 |     cur_len = int(TOTAL_BAR_LENGTH*current/total)
103 |     rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
104 | 
105 |     sys.stdout.write(' [')
106 |     for i in range(cur_len):
107 |         sys.stdout.write('=')
108 |     sys.stdout.write('>')
109 |     for i in range(rest_len):
110 |         sys.stdout.write('.')
111 |     sys.stdout.write(']')
112 | 
113 |     cur_time = time.time()
114 |     step_time = cur_time - last_time
115 |     last_time = cur_time
116 |     tot_time = cur_time - begin_time
117 | 
118 |     L = []
119 |     L.append(' Step: %s' % format_time(step_time))
120 |     L.append(' | Tot: %s' % format_time(tot_time))
121 |     if msg:
122 |         L.append(' | ' + msg)
123 | 
124 |     msg = ''.join(L)
125 |     sys.stdout.write(msg)
126 |     for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
127 |         sys.stdout.write(' ')
128 | 
129 |     # Go back to the center of the bar.
130 |     for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
131 |         sys.stdout.write('\b')
132 |     sys.stdout.write(' %d/%d ' % (current+1, total))
133 | 
134 |     if current < total-1:
135 |         sys.stdout.write('\r')
136 |     else:
137 |         sys.stdout.write('\n')
138 |     sys.stdout.flush()
139 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet2vgg import *
2 | from .rmnext import *
3 | from .rmobilenet import *
4 | from .rmrepvgg import *
5 | from .rmnet_pruning import *
--------------------------------------------------------------------------------
/models/easy_mb2_mb1.py:
--------------------------------------------------------------------------------
1 | # A simplified version for better understanding
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 | from torchvision.models.mobilenetv2 import InvertedResidual, mobilenet_v2
7 | 
8 | def rm_r_InvertedResidual(block):
9 |     inp = block.conv[0][0].in_channels
10 |     mid = inp+block.conv[0][0].out_channels
11 |     oup = block.conv[2].out_channels
12 | 
13 |     #merge conv1 and bn1
14 |     conv1=nn.utils.fuse_conv_bn_eval(block.conv[0][0],block.conv[0][1])
15 |     #new conv1
16 |     idconv1 = nn.Conv2d(inp, mid, kernel_size=1).eval()
17 |     idrelu1 = nn.PReLU(mid)
18 |     #original channels
19 |     idconv1.weight.data[inp:]=conv1.weight.data
20 |     idconv1.bias.data[inp:]=conv1.bias.data
21 |     torch.nn.init.zeros_(idrelu1.weight.data[inp:])
22 |     #preserve the input feature maps via dirac-initialized channels
23 |     nn.init.dirac_(idconv1.weight.data[:inp])
24 |     nn.init.zeros_(idconv1.bias.data[:inp])
25 |     torch.nn.init.ones_(idrelu1.weight.data[:inp])
26 | 
27 |     #merge conv2 and bn2
28 |     conv2=nn.utils.fuse_conv_bn_eval(block.conv[1][0],block.conv[1][1])
29 |     #new conv2
30 |     idconv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=block.stride, padding=1, groups=mid).eval()
31 |     idrelu2 = nn.PReLU(mid)
32 |     #original channels
33 |     idconv2.weight.data[inp:]=conv2.weight.data
34 |     idconv2.bias.data[inp:]=conv2.bias.data
35 |     torch.nn.init.zeros_(idrelu2.weight.data[inp:])
36 |     #preserve the input feature maps via dirac-initialized channels
37 |     nn.init.dirac_(idconv2.weight.data[:inp], groups=inp)
38 |     nn.init.zeros_(idconv2.bias.data[:inp])
39 |     torch.nn.init.ones_(idrelu2.weight.data[:inp])
40 | 
41 |     #merge conv3 and bn3
42 |     conv3=nn.utils.fuse_conv_bn_eval(block.conv[2],block.conv[3])
43 |     #new conv3
44 |     idconv3 = nn.Conv2d(mid, oup, kernel_size=1).eval()
45 |     #original channels
46 |     idconv3.weight.data[:,inp:]=conv3.weight.data
47 |     idconv3.bias.data=conv3.bias.data
48 |     #merge the input feature maps into the output feature maps
52 | 53 | def fuse_conv1_conv2(conv1,conv2): 54 | inp=conv1.in_channels 55 | mid=conv1.out_channels 56 | oup=conv2.out_channels 57 | fused_conv=nn.Conv2d(inp,oup,1) 58 | fused_conv.weight.data=(conv2.weight.data.view(oup,mid)@conv1.weight.data.view(mid,-1)).view(oup,inp,1,1) 59 | fused_conv.bias.data=conv2.bias.data+conv2.weight.data.view(oup,mid)@conv1.bias.data 60 | return fused_conv 61 | 62 | def mobilenetv2_to_mobilenetv1(model): 63 | model.eval() 64 | features=[] 65 | for m in model.features: 66 | if isinstance(m,InvertedResidual) and m.use_res_connect: 67 | features+=rm_r_InvertedResidual(m) 68 | else: 69 | for mm in m.modules(): 70 | if not list(mm.children()): 71 | #fuse conv and bn 72 | if isinstance(mm,nn.Conv2d): 73 | conv=mm 74 | continue 75 | elif isinstance(mm,nn.BatchNorm2d): 76 | mm=nn.utils.fuse_conv_bn_eval(conv,mm) 77 | features.append(mm) 78 | 79 | #fuse consecutive convolutional layers 80 | new_features=[features[0]] 81 | for m in features[1:]: 82 | if isinstance(m,nn.Conv2d) and isinstance(new_features[-1],nn.Conv2d): 83 | new_features[-1]=fuse_conv1_conv2(new_features[-1],m) 84 | else: 85 | new_features.append(m) 86 | model.features=nn.Sequential(*new_features) 87 | return model 88 | 89 | model=mobilenet_v2() 90 | x=torch.randn(2,3,224,224) 91 | model(x) 92 | print(model.eval()(x)) 93 | mobilenetv2_to_mobilenetv1(model) 94 | model(x) 95 | print(model.eval()(x)) 96 |
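fuse_conv1_conv2 above collapses two back-to-back 1x1 convolutions into one, since W2(W1 x + b1) + b2 = (W2 W1) x + (W2 b1 + b2). A standalone numerical check with arbitrary toy shapes:

```python
import torch
import torch.nn as nn

conv1, conv2 = nn.Conv2d(8, 16, 1), nn.Conv2d(16, 4, 1)
fused = nn.Conv2d(8, 4, 1)
# weights multiply as matrices; the first bias folds into the second
fused.weight.data = (conv2.weight.data.view(4, 16) @ conv1.weight.data.view(16, 8)).view(4, 8, 1, 1)
fused.bias.data = conv2.bias.data + conv2.weight.data.view(4, 16) @ conv1.bias.data

x = torch.randn(2, 8, 14, 14)
print((fused(x) - conv2(conv1(x))).abs().max())  # ~1e-6, floating-point rounding only
```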
-------------------------------------------------------------------------------- /models/easy_resnet2vgg.py: -------------------------------------------------------------------------------- 1 | # A simplified version for better understanding 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from torchvision.models import resnet 6 | def rm_r_BasicBlock(block): 7 | block.eval() 8 | in_planes = block.conv1.in_channels 9 | mid_planes = in_planes + block.conv1.out_channels 10 | out_planes = block.conv2.out_channels 11 | 12 | #merge conv1 and bn1 13 | block.conv1=nn.utils.fuse_conv_bn_eval(block.conv1,block.bn1) 14 | #new conv1 15 | idconv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=block.stride, padding=1).eval() 16 | #original channels 17 | idconv1.weight.data[in_planes:]=block.conv1.weight.data 18 | idconv1.bias.data[in_planes:]=block.conv1.bias.data 19 | #preserve input feature maps with dirac-initialized channels 20 | nn.init.dirac_(idconv1.weight.data[:in_planes]) 21 | nn.init.zeros_(idconv1.bias.data[:in_planes]) 22 | 23 | #merge conv2 and bn2 24 | block.conv2=nn.utils.fuse_conv_bn_eval(block.conv2,block.bn2) 25 | #new conv 26 | idconv2 = nn.Conv2d(mid_planes, out_planes, kernel_size=3, stride=1, padding=1).eval() 27 | #original channels 28 | idconv2.weight.data[:,in_planes:]=block.conv2.weight.data 29 | idconv2.bias.data=block.conv2.bias.data 30 | #merge input feature maps into output feature maps 31 | if in_planes==out_planes: 32 | nn.init.dirac_(idconv2.weight.data[:,:in_planes]) 33 | else: 34 | #if there is a downsample layer 35 | downsample=nn.utils.fuse_conv_bn_eval(block.downsample[0],block.downsample[1]) 36 | #conv1*1 -> conv3*3 37 | idconv2.weight.data[:,:in_planes]=F.pad(downsample.weight.data, [1, 1, 1, 1]) 38 | idconv2.bias.data+=downsample.bias.data 39 | return nn.Sequential(*[idconv1,block.relu,idconv2,block.relu]) 40 | 41 | def resnet_to_vgg(model): 42 | model.layer1=nn.Sequential(*[rm_r_BasicBlock(block) for block in model.layer1]) 43 | model.layer2=nn.Sequential(*[rm_r_BasicBlock(block) for block in model.layer2]) 44 | model.layer3=nn.Sequential(*[rm_r_BasicBlock(block) for block in model.layer3]) 45 | model.layer4=nn.Sequential(*[rm_r_BasicBlock(block) for block in model.layer4]) 46 | 47 | model=resnet.resnet18() 48 | x=torch.randn(2,3,224,224) 49 | model(x) 50 | print(model.eval()(x)) 51 | resnet_to_vgg(model) 52 | print(model.eval()(x)) 53 | -------------------------------------------------------------------------------- /models/easy_resnet50_to_vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from torchvision.models import resnet 5 | 6 | def resnet50_to_vgg(model): 7 | def rm_r_Bottleneck(block): 8 | block.eval() 9 | in_planes = block.conv1.in_channels 10 | mid_planes = in_planes + block.conv1.out_channels 11 | out_planes = block.conv3.out_channels 12 | 13 | #merge conv1 and bn1 14 | block.conv1=nn.utils.fuse_conv_bn_eval(block.conv1,block.bn1) 15 | #new conv1 16 | idconv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1).eval() 17 | #original channels 18 | idconv1.weight.data[in_planes:]=block.conv1.weight.data 19 | idconv1.bias.data[in_planes:]=block.conv1.bias.data 20 | #preserve input feature maps with dirac-initialized channels 21 | nn.init.dirac_(idconv1.weight.data[:in_planes]) 22 | nn.init.zeros_(idconv1.bias.data[:in_planes]) 23 | 24 | 25 | #merge conv2 and bn2 26 | block.conv2=nn.utils.fuse_conv_bn_eval(block.conv2,block.bn2) 27 | #new conv2 28 | idconv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=block.stride, padding=1).eval() 29 | #original channels 30 | idconv2.weight.data[in_planes:][:,in_planes:]=block.conv2.weight.data 31 | nn.init.zeros_(idconv2.weight.data[in_planes:][:,:in_planes]) 32 | idconv2.bias.data[in_planes:]=block.conv2.bias.data 33 | #preserve input feature maps with dirac-initialized channels 34 | nn.init.dirac_(idconv2.weight.data[:in_planes]) 35 | nn.init.zeros_(idconv2.bias.data[:in_planes]) 36 | 37 | #merge conv3 and bn3 38 | block.conv3=nn.utils.fuse_conv_bn_eval(block.conv3,block.bn3) 39 | #new conv3 40 | idconv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1).eval() 41 | #original channels 42 | idconv3.weight.data[:,in_planes:]=block.conv3.weight.data 43 | idconv3.bias.data=block.conv3.bias.data 44 | #merge input feature maps into output feature maps 45 | if in_planes==out_planes: 46 | nn.init.dirac_(idconv3.weight.data[:,:in_planes]) 47 | else: 48 | #if there is a downsample layer 49 | downsample=nn.utils.fuse_conv_bn_eval(block.downsample[0],block.downsample[1]) 50 | idconv3.weight.data[:,:in_planes]=downsample.weight.data 51 | idconv3.bias.data+=downsample.bias.data 52 | return nn.Sequential(*[idconv1,block.relu,idconv2,block.relu,idconv3,block.relu]) 53 | 54 | model.layer1=nn.Sequential(*[rm_r_Bottleneck(block) for block in model.layer1]) 55 | model.layer2=nn.Sequential(*[rm_r_Bottleneck(block) for block in model.layer2]) 56 | model.layer3=nn.Sequential(*[rm_r_Bottleneck(block) for block in model.layer3]) 57 | model.layer4=nn.Sequential(*[rm_r_Bottleneck(block) for block in model.layer4]) 58 | 59 | model=resnet.resnet50() 60 | x=torch.randn(2,3,224,224) 61 | model(x) 62 | print(model.eval()(x)) 63 | resnet50_to_vgg(model) 64 | print(model.eval()(x)) 65 | -------------------------------------------------------------------------------- /models/preactresnet2vgg.py: 
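preactresnet2vgg.py below leans on one identity repeatedly: at eval time, a BatchNorm whose weight is sqrt(running_var + eps) and whose bias equals running_mean reproduces its input exactly, which is how the recorded `running` statistics re-enter the merged layers. A quick standalone check of that identity (toy sizes, not repo code):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).eval()
bn.running_mean.data = torch.randn(8)       # pretend these were collected in training
bn.running_var.data = torch.rand(8) + 0.5
bn.weight.data = torch.sqrt(bn.running_var + bn.eps)
bn.bias.data = bn.running_mean.clone()

x = torch.randn(2, 8, 4, 4)
# BN(x) = (x - mean) / sqrt(var + eps) * weight + bias = x
print((bn(x) - x).abs().max())  # ~1e-7
```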
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ResBlock(nn.Module): 6 | def __init__(self, planes): 7 | super(ResBlock, self).__init__() 8 | self.in_planes = planes 9 | self.out_planes = planes 10 | self.mid_planes = planes*2 11 | self.bn1 = nn.BatchNorm2d(planes) 12 | self.relu1 = nn.ReLU(inplace=True) 13 | self.conv1 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 17 | 18 | def forward(self, x): 19 | out = self.bn1(x) 20 | out = self.relu1(out) 21 | out = self.conv1(out) 22 | out = self.bn2(out) 23 | out = self.relu2(out) 24 | out = self.conv2(out) 25 | return out+x 26 | 27 | def deploy(self): 28 | idconv0 = nn.Conv2d(self.in_planes, self.mid_planes, kernel_size=1, bias=False).eval() 29 | nn.init.dirac_(idconv0.weight.data[:self.out_planes]) 30 | nn.init.dirac_(idconv0.weight.data[self.out_planes:]) 31 | 32 | 33 | idbn1=nn.BatchNorm2d(self.mid_planes) 34 | bn_var_sqrt=torch.sqrt(self.bn1.running_var + self.bn1.eps) 35 | idbn1.weight.data[:self.out_planes]=bn_var_sqrt 36 | idbn1.bias.data[:self.out_planes]=self.bn1.running_mean 37 | idbn1.running_mean.data[:self.out_planes]=self.bn1.running_mean 38 | idbn1.running_var.data[:self.out_planes]=self.bn1.running_var 39 | 40 | idbn1.weight.data[self.out_planes:]=self.bn1.weight.data 41 | idbn1.bias.data[self.out_planes:]=self.bn1.bias.data 42 | idbn1.running_mean.data[self.out_planes:]=self.bn1.running_mean 43 | idbn1.running_var.data[self.out_planes:]=self.bn1.running_var 44 | 45 | self.relu1 = nn.PReLU(self.mid_planes) 46 | torch.nn.init.ones_(self.relu1.weight.data[:self.out_planes]) 47 | torch.nn.init.zeros_(self.relu1.weight.data[self.out_planes:]) 48 | 49 | 50 | idconv1 = nn.Conv2d(self.mid_planes, self.mid_planes, kernel_size=3, padding=1, bias=False).eval() 51 | nn.init.dirac_(idconv1.weight.data[:self.out_planes]) 52 | torch.nn.init.zeros_(idconv1.weight.data[self.out_planes:][:,:self.out_planes]) 53 | idconv1.weight.data[self.out_planes:][:,self.out_planes:]=self.conv1.weight.data 54 | 55 | 56 | idbn2=nn.BatchNorm2d(self.mid_planes) 57 | idbn2.weight.data[:self.out_planes]=idbn1.weight.data[:self.out_planes] 58 | idbn2.bias.data[:self.out_planes]=idbn1.bias.data[:self.out_planes] 59 | idbn2.running_mean.data[:self.out_planes]=idbn1.running_mean.data[:self.out_planes] 60 | idbn2.running_var.data[:self.out_planes]=idbn1.running_var.data[:self.out_planes] 61 | 62 | idbn2.weight.data[self.out_planes:]=self.bn2.weight.data 63 | idbn2.bias.data[self.out_planes:]=self.bn2.bias.data 64 | idbn2.running_mean.data[self.out_planes:]=self.bn2.running_mean 65 | idbn2.running_var.data[self.out_planes:]=self.bn2.running_var 66 | 67 | self.relu2 = nn.PReLU(self.mid_planes) 68 | torch.nn.init.ones_(self.relu2.weight.data[:self.out_planes]) 69 | torch.nn.init.zeros_(self.relu2.weight.data[self.out_planes:]) 70 | 71 | 72 | idconv2 = nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3, stride=1, padding=1, bias=False).eval() 73 | nn.init.dirac_(idconv2.weight.data[:,:self.out_planes]) 74 | idconv2.weight.data[:,self.out_planes:]=self.conv2.weight 75 | 76 | return [idconv0, idbn1, self.relu1, idconv1, idbn2, self.relu2, idconv2] 77 | 78 | class ResChannel(nn.Module): 79 | def __init__(self, planes): 80 | super(ResChannel, 
self).__init__() 81 | self.in_planes = planes 82 | self.out_planes = planes*2 83 | self.mid_planes = planes*3 84 | 85 | self.bn1 = nn.BatchNorm2d(self.in_planes) 86 | self.relu1 = nn.ReLU(inplace=True) 87 | self.running = nn.BatchNorm2d(self.in_planes,affine=False) 88 | self.downsample = nn.Conv2d(self.in_planes, self.out_planes, kernel_size=1, stride=2, bias=False) 89 | self.conv1 = nn.Conv2d(self.in_planes, self.out_planes, kernel_size=3, stride=2, padding=1, bias=False) 90 | self.bn2 = nn.BatchNorm2d(self.out_planes) 91 | self.relu2 = nn.ReLU(inplace=True) 92 | self.conv2 = nn.Conv2d(self.out_planes, self.out_planes, kernel_size=3, stride=1, padding=1, bias=False) 93 | 94 | def forward(self, x): 95 | out = self.bn1(x) 96 | out = self.relu1(out) 97 | self.running(out) 98 | shortcut = self.downsample(out) 99 | out = self.conv1(out) 100 | out = self.bn2(out) 101 | out = self.relu2(out) 102 | out = self.conv2(out) 103 | return out+shortcut 104 | 105 | def deploy(self): 106 | idconv1 = nn.Conv2d(self.in_planes, self.mid_planes, kernel_size=3, stride=2, padding=1, bias=False).eval() 107 | nn.init.dirac_(idconv1.weight.data[:self.in_planes]) 108 | idconv1.weight.data[self.in_planes:]=self.conv1.weight.data 109 | 110 | idbn2=nn.BatchNorm2d(self.mid_planes) 111 | bn_var_sqrt=torch.sqrt(self.running.running_var + self.running.eps) 112 | idbn2.weight.data[:self.in_planes]=bn_var_sqrt 113 | idbn2.bias.data[:self.in_planes]=self.running.running_mean 114 | idbn2.running_mean.data[:self.in_planes]=self.running.running_mean 115 | idbn2.running_var.data[:self.in_planes]=self.running.running_var 116 | 117 | idbn2.weight.data[self.in_planes:]=self.bn2.weight.data 118 | idbn2.bias.data[self.in_planes:]=self.bn2.bias.data 119 | idbn2.running_mean.data[self.in_planes:]=self.bn2.running_mean 120 | idbn2.running_var.data[self.in_planes:]=self.bn2.running_var 121 | 122 | idconv2 = nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3, stride=1, padding=1, bias=False).eval() 123 | idconv2.weight.data[:,:self.in_planes]=F.pad(self.downsample.weight.data, [1, 1, 1, 1]) 124 | idconv2.weight.data[:,self.in_planes:]=self.conv2.weight.data 125 | return [self.bn1, self.relu1, idconv1, idbn2, self.relu2, idconv2] 126 | 127 | class PreActResVGG(nn.Module): 128 | def __init__(self, num_blocks, num_classes=10): 129 | super(PreActResVGG, self).__init__() 130 | self.in_planes = 64 131 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 132 | self.layer1 = self._make_layer(64, num_blocks[0], stride=1) 133 | self.layer2 = self._make_layer(128, num_blocks[1], stride=2) 134 | self.layer3 = self._make_layer(256, num_blocks[2], stride=2) 135 | self.layer4 = self._make_layer(512, num_blocks[3], stride=2) 136 | 137 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 138 | self.flat = nn.Flatten(start_dim=1) 139 | self.fc = nn.Linear(512, num_classes) 140 | 141 | def _make_layer(self, planes, num_blocks, stride): 142 | layers = [] 143 | if stride==2: 144 | layers.append(ResChannel(self.in_planes)) 145 | else: 146 | layers.append(ResBlock(self.in_planes)) 147 | self.in_planes = planes 148 | for i in range(num_blocks-1): 149 | layers.append(ResBlock(self.in_planes)) 150 | return nn.Sequential(*layers) 151 | 152 | def forward(self, x): 153 | out = self.conv1(x) 154 | out = self.layer1(out) 155 | out = self.layer2(out) 156 | out = self.layer3(out) 157 | out = self.layer4(out) 158 | out = self.gap(out) 159 | out = self.flat(out) 160 | out = self.fc(out) 161 | return out 162 | 163 | def rm(self): 164 | def 
foo(net): 165 | global blocks 166 | childrens = list(net.children()) 167 | if isinstance(net, ResBlock)or isinstance(net, ResChannel): 168 | blocks+=net.deploy() 169 | elif not childrens: 170 | blocks+=[net] 171 | else: 172 | for c in childrens: 173 | foo(c) 174 | global blocks 175 | blocks =[] 176 | foo(self) 177 | return nn.Sequential(*blocks) 178 | 179 | def deploy(self): 180 | model=self.rm() 181 | blocks=[] 182 | c11=None 183 | for m in model[::-1]: 184 | if isinstance(m,nn.Conv2d): 185 | if m.kernel_size==(1,1): 186 | c11=m 187 | else: 188 | if c11 is not None: 189 | c31=nn.Conv2d(m.in_channels,c11.out_channels,3,stride=m.stride, padding=1,bias=False) 190 | c31.weight.data=(c11.weight.data.view(c11.out_channels,c11.in_channels)@m.weight.data.view(m.out_channels,-1)).view(c11.out_channels,m.in_channels,3,3) 191 | c11=None 192 | blocks.append(c31) 193 | else: 194 | blocks.append(m) 195 | else: 196 | blocks.append(m) 197 | return nn.Sequential(*blocks[::-1]) 198 | 199 | 200 | def preactresvgg18(num_classes=10): 201 | return PreActResVGG([2, 2, 2, 2], num_classes=num_classes) 202 | 203 | def preactresvgg34(num_classes=10): 204 | return PreActResVGG([3, 4, 6, 3], num_classes=num_classes) 205 | -------------------------------------------------------------------------------- /models/resnet2vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class ResBlock(nn.Module): 6 | def __init__(self, in_planes, mid_planes, out_planes, stride=1): 7 | super(ResBlock, self).__init__() 8 | 9 | assert mid_planes > in_planes 10 | 11 | self.in_planes = in_planes 12 | self.mid_planes = mid_planes - out_planes +in_planes 13 | self.out_planes = out_planes 14 | self.stride = stride 15 | 16 | self.conv1 = nn.Conv2d(in_planes, self.mid_planes - in_planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(self.mid_planes - in_planes) 18 | 19 | self.conv2 = nn.Conv2d(self.mid_planes - in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(out_planes) 21 | 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | self.downsample=nn.Sequential() 25 | if self.in_planes != self.out_planes or self.stride != 1: 26 | self.downsample=nn.Sequential( 27 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(out_planes)) 29 | self.running1 = nn.BatchNorm2d(in_planes,affine=False) 30 | self.running2 = nn.BatchNorm2d(out_planes,affine=False) 31 | 32 | def forward(self, x): 33 | if self.in_planes == self.out_planes and self.stride == 1: 34 | self.running1(x) 35 | out = self.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.downsample(x) 38 | self.running2(out) 39 | return self.relu(out) 40 | 41 | def deploy(self, merge_bn=False): 42 | idconv1 = nn.Conv2d(self.in_planes, self.mid_planes, kernel_size=3, stride=self.stride, padding=1, bias=False).eval() 43 | idbn1=nn.BatchNorm2d(self.mid_planes).eval() 44 | 45 | nn.init.dirac_(idconv1.weight.data[:self.in_planes]) 46 | bn_var_sqrt=torch.sqrt(self.running1.running_var + self.running1.eps) 47 | idbn1.weight.data[:self.in_planes]=bn_var_sqrt 48 | idbn1.bias.data[:self.in_planes]=self.running1.running_mean 49 | idbn1.running_mean.data[:self.in_planes]=self.running1.running_mean 50 | idbn1.running_var.data[:self.in_planes]=self.running1.running_var 51 | 52 | 
idconv1.weight.data[self.in_planes:]=self.conv1.weight.data 53 | idbn1.weight.data[self.in_planes:]=self.bn1.weight.data 54 | idbn1.bias.data[self.in_planes:]=self.bn1.bias.data 55 | idbn1.running_mean.data[self.in_planes:]=self.bn1.running_mean 56 | idbn1.running_var.data[self.in_planes:]=self.bn1.running_var 57 | 58 | idconv2 = nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3, stride=1, padding=1, bias=False).eval() 59 | idbn2=nn.BatchNorm2d(self.out_planes).eval() 60 | downsample_bias=0 61 | if self.in_planes==self.out_planes: 62 | nn.init.dirac_(idconv2.weight.data[:,:self.in_planes]) 63 | else: 64 | idconv2.weight.data[:,:self.in_planes],downsample_bias=self.fuse(F.pad(self.downsample[0].weight.data, [1, 1, 1, 1]),self.downsample[1].running_mean,self.downsample[1].running_var,self.downsample[1].weight,self.downsample[1].bias,self.downsample[1].eps) 65 | 66 | idconv2.weight.data[:,self.in_planes:],bias=self.fuse(self.conv2.weight,self.bn2.running_mean,self.bn2.running_var,self.bn2.weight,self.bn2.bias,self.bn2.eps) 67 | 68 | bn_var_sqrt=torch.sqrt(self.running2.running_var + self.running2.eps) 69 | idbn2.weight.data=bn_var_sqrt 70 | idbn2.bias.data=self.running2.running_mean 71 | idbn2.running_mean.data=self.running2.running_mean+bias+downsample_bias 72 | idbn2.running_var.data=self.running2.running_var 73 | 74 | if merge_bn: 75 | return [torch.nn.utils.fuse_conv_bn_eval(idconv1,idbn1),self.relu,torch.nn.utils.fuse_conv_bn_eval(idconv2,idbn2),self.relu] 76 | else: 77 | return [idconv1,idbn1,self.relu,idconv2,idbn2,self.relu] 78 | 79 | 80 | def fuse(self,conv_w, bn_rm, bn_rv,bn_w,bn_b, eps): 81 | bn_var_rsqrt = torch.rsqrt(bn_rv + eps) 82 | conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) 83 | conv_b = bn_rm * bn_var_rsqrt * bn_w-bn_b 84 | return conv_w,conv_b 85 | 86 | class RMNet(nn.Module): 87 | def __init__(self, block, num_blocks, num_classes=10,base_wide=64): 88 | super(RMNet, self).__init__() 89 | self.in_planes = base_wide 90 | self.conv1 = nn.Conv2d(3, base_wide, kernel_size=7 if num_classes==1000 else 3, stride=2 if num_classes==1000 else 1, padding=3 if num_classes==1000 else 1, bias=False) 91 | self.bn1 = nn.BatchNorm2d(base_wide) 92 | self.relu = nn.ReLU(inplace=True) 93 | if num_classes==1000: 94 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 95 | self.layer1 = self._make_layer(block, base_wide, num_blocks[0], stride=1) 96 | self.layer2 = self._make_layer(block, base_wide*2, num_blocks[1], stride=2) 97 | self.layer3 = self._make_layer(block, base_wide*4, num_blocks[2], stride=2) 98 | self.layer4 = None 99 | if len(num_blocks)==4: 100 | self.layer4 = self._make_layer(block, base_wide*8, num_blocks[3], stride=2) 101 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 102 | self.flat = nn.Flatten(start_dim=1) 103 | self.fc = nn.Linear(self.in_planes, num_classes) 104 | 105 | def _make_layer(self, block, planes, num_blocks, stride): 106 | strides = [stride] + [1]*(num_blocks-1) 107 | layers = [] 108 | for stride in strides: 109 | layers.append(block(self.in_planes, planes*2, planes, stride)) 110 | self.in_planes = planes 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | out = self.relu(self.bn1(self.conv1(x))) 115 | if self.fc.out_features==1000: 116 | out = self.maxpool(out) 117 | out = self.layer1(out) 118 | out = self.layer2(out) 119 | out = self.layer3(out) 120 | if self.layer4 is not None: 121 | out = self.layer4(out) 122 | out = self.gap(out) 123 | out = self.flat(out) 124 | out = 
self.fc(out) 125 | return out 126 | 127 | def deploy(self, merge_bn=False): 128 | def foo(net): 129 | global blocks 130 | childrens = list(net.children()) 131 | if isinstance(net, ResBlock): 132 | blocks+=net.deploy(merge_bn) 133 | elif not childrens: 134 | if isinstance(net,nn.BatchNorm2d) and isinstance(blocks[-1],nn.Conv2d): 135 | blocks[-1]=torch.nn.utils.fuse_conv_bn_eval(blocks[-1],net) 136 | else: 137 | blocks+=[net] 138 | else: 139 | for c in childrens: 140 | foo(c) 141 | global blocks 142 | 143 | blocks =[] 144 | foo(self.eval()) 145 | return nn.Sequential(*blocks) 146 | 147 | def rmnet18(num_classes=1000): 148 | return RMNet(ResBlock, [2, 2, 2, 2], num_classes=num_classes) 149 | 150 | def rmnet34(num_classes=1000): 151 | return RMNet(ResBlock, [3, 4, 6, 3], num_classes=num_classes) 152 | 153 | def rmnet20(num_classes=10): 154 | return RMNet(ResBlock, [3, 3, 3],num_classes=num_classes,base_wide=16) 155 | 156 | def rmnet32(num_classes=10): 157 | return RMNet(ResBlock, [5, 5, 5],num_classes=num_classes,base_wide=16) 158 | 159 | def rmnet44(num_classes=10): 160 | return RMNet(ResBlock, [7, 7, 7],num_classes=num_classes,base_wide=16) 161 | 162 | def rmnet56(num_classes=10): 163 | return RMNet(ResBlock, [9, 9, 9],num_classes=num_classes,base_wide=16) 164 | 165 | def rmnet110(num_classes=10): 166 | return RMNet(ResBlock, [18, 18, 18],num_classes=num_classes,base_wide=16) 167 | 168 | def rmnet1202(num_classes=10): 169 | return RMNet(ResBlock, [200, 200, 200],num_classes=num_classes,base_wide=16) 170 | -------------------------------------------------------------------------------- /models/rmnet_from_scratch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class RMBlockFromScratch(nn.Module): 6 | def __init__(self, in_planes, mid_planes, out_planes, stride=1): 7 | super(RMBlockFromScratch, self).__init__() 8 | 9 | assert mid_planes > in_planes 10 | 11 | self.in_planes = in_planes 12 | self.mid_planes = mid_planes - out_planes +in_planes 13 | self.out_planes = out_planes 14 | self.stride = stride 15 | 16 | self.conv1 = nn.Conv2d(in_planes, self.mid_planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(self.mid_planes) 18 | self.relu1 = nn.ReLU(inplace=True) 19 | 20 | self.conv2 = nn.Conv2d(self.mid_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(out_planes) 22 | self.relu2 = nn.ReLU(inplace=True) 23 | 24 | def forward(self, x): 25 | nn.init.dirac_(self.conv1.weight.data[:self.in_planes]) 26 | x = self.conv1(x) 27 | if self.training: 28 | x_mean = x.mean(dim=(0,2,3)) 29 | x_var = x.var(dim=(0,2,3),unbiased=False) 30 | self.bn1.weight.data[:self.in_planes]=torch.sqrt(x_var[:self.in_planes]+self.bn1.eps) 31 | self.bn1.bias.data[:self.in_planes]=x_mean[:self.in_planes] 32 | else: 33 | self.bn1.weight.data[:self.in_planes]=torch.sqrt(self.bn1.running_var[:self.in_planes]+self.bn1.eps) 34 | self.bn1.bias.data[:self.in_planes]=self.bn1.running_mean[:self.in_planes] 35 | x = F.batch_norm(x,self.bn1.running_mean,self.bn1.running_var,self.bn1.weight,self.bn1.bias,training=self.training) 36 | out = self.relu1(x) 37 | if self.in_planes==self.out_planes and self.stride==1: 38 | nn.init.dirac_(self.conv2.weight.data[:,:self.in_planes]) 39 | out = self.bn2(self.conv2(out)) 40 | return self.relu2(out) 41 | -------------------------------------------------------------------------------- 
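rmnet_pruning.py below gates every channel with a grouped 1x1 convolution: each channel gets a single learnable scale that update_mask pushes toward zero with an L1 term and then thresholds, so a zero scale marks a prunable channel. A toy illustration of such a gate (hypothetical threshold and sizes, not the repo's training loop):

```python
import torch
import torch.nn as nn

C = 8
mask = nn.Conv2d(C, C, 1, groups=C, bias=False)  # one scalar weight per channel
nn.init.ones_(mask.weight)

mask.weight.data[2] = 1e-4   # pretend training drove two scales below the threshold
mask.weight.data[5] = -1e-4
threshold = 1e-3
mask.weight.data *= (mask.weight.data.abs() > threshold)  # zeroed scale = dead channel

x = torch.randn(1, C, 4, 4)
out = mask(x)
print([int(out[0, c].abs().sum() > 0) for c in range(C)])  # channels 2 and 5 are all-zero
```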
/models/rmnet_pruning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class ResBlock(nn.Module): 6 | def __init__(self, in_planes, mid_planes, out_planes, stride=1): 7 | super(ResBlock, self).__init__() 8 | 9 | assert mid_planes > in_planes 10 | 11 | self.in_planes = in_planes 12 | self.mid_planes = mid_planes - out_planes +in_planes 13 | self.out_planes = out_planes 14 | self.stride = stride 15 | 16 | self.conv1 = nn.Conv2d(in_planes, self.mid_planes - in_planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(self.mid_planes - in_planes) 18 | self.mask1 = nn.Conv2d(self.mid_planes - in_planes,self.mid_planes - in_planes,1,groups=self.mid_planes - in_planes,bias=False) 19 | 20 | self.conv2 = nn.Conv2d(self.mid_planes - in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(out_planes) 22 | self.mask2 = nn.Conv2d(out_planes,out_planes,1,groups=out_planes,bias=False) 23 | 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | self.downsample=nn.Sequential() 27 | if self.in_planes != self.out_planes or self.stride != 1: 28 | self.downsample=nn.Sequential( 29 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(out_planes)) 31 | 32 | self.mask_res = nn.Sequential(*[nn.Conv2d(self.in_planes,self.in_planes,1,groups=self.in_planes,bias=False), 33 | nn.ReLU(inplace=True)]) 34 | self.running1 = nn.BatchNorm2d(in_planes,affine=False) 35 | self.running2 = nn.BatchNorm2d(out_planes,affine=False) 36 | nn.init.ones_(self.mask1.weight) 37 | nn.init.ones_(self.mask2.weight) 38 | nn.init.ones_(self.mask_res[0].weight) 39 | 40 | def forward(self, x): 41 | if self.in_planes == self.out_planes and self.stride == 1: 42 | self.running1(x) 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.mask1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | out += self.downsample(self.mask_res(x)) 51 | self.running2(out) 52 | out = self.mask2(out) 53 | out = self.relu(out) 54 | return out 55 | 56 | def deploy(self,merge_bn=False): 57 | idconv1 = nn.Conv2d(self.in_planes, self.mid_planes, kernel_size=3, stride=self.stride, padding=1, bias=False).eval() 58 | idbn1=nn.BatchNorm2d(self.mid_planes).eval() 59 | 60 | nn.init.dirac_(idconv1.weight.data[:self.in_planes]) 61 | bn_var_sqrt=torch.sqrt(self.running1.running_var + self.running1.eps) 62 | idbn1.weight.data[:self.in_planes]=bn_var_sqrt 63 | idbn1.bias.data[:self.in_planes]=self.running1.running_mean 64 | idbn1.running_mean.data[:self.in_planes]=self.running1.running_mean 65 | idbn1.running_var.data[:self.in_planes]=self.running1.running_var 66 | 67 | idconv1.weight.data[self.in_planes:]=self.conv1.weight.data 68 | idbn1.weight.data[self.in_planes:]=self.bn1.weight.data 69 | idbn1.bias.data[self.in_planes:]=self.bn1.bias.data 70 | idbn1.running_mean.data[self.in_planes:]=self.bn1.running_mean 71 | idbn1.running_var.data[self.in_planes:]=self.bn1.running_var 72 | 73 | mask1=nn.Conv2d(self.mid_planes,self.mid_planes,1,groups=self.mid_planes,bias=False) 74 | mask1.weight.data[:self.in_planes]=self.mask_res[0].weight.data*(self.mask_res[0].weight.data>0) 75 | mask1.weight.data[self.in_planes:]=self.mask1.weight.data 76 | 77 | 78 | idconv2 = nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3, stride=1, padding=1, bias=False).eval() 79 | 
idbn2=nn.BatchNorm2d(self.out_planes).eval() 80 | downsample_bias=0 81 | if self.in_planes==self.out_planes: 82 | nn.init.dirac_(idconv2.weight.data[:,:self.in_planes]) 83 | else: 84 | idconv2.weight.data[:,:self.in_planes],downsample_bias=self.fuse(F.pad(self.downsample[0].weight.data, [1, 1, 1, 1]),self.downsample[1].running_mean,self.downsample[1].running_var,self.downsample[1].weight,self.downsample[1].bias,self.downsample[1].eps) 85 | 86 | idconv2.weight.data[:,self.in_planes:],bias=self.fuse(self.conv2.weight,self.bn2.running_mean,self.bn2.running_var,self.bn2.weight,self.bn2.bias,self.bn2.eps) 87 | 88 | bn_var_sqrt=torch.sqrt(self.running2.running_var + self.running2.eps) 89 | idbn2.weight.data=bn_var_sqrt 90 | idbn2.bias.data=self.running2.running_mean 91 | idbn2.running_mean.data=self.running2.running_mean+bias+downsample_bias 92 | idbn2.running_var.data=self.running2.running_var 93 | 94 | 95 | idbn1.weight.data*=mask1.weight.data.reshape(-1) 96 | idbn1.bias.data*=mask1.weight.data.reshape(-1) 97 | idbn2.weight.data*=self.mask2.weight.data.reshape(-1) 98 | idbn2.bias.data*=self.mask2.weight.data.reshape(-1) 99 | return [idconv1,idbn1,nn.ReLU(True),idconv2,idbn2,nn.ReLU(True)] 100 | 101 | 102 | def fuse(self,conv_w, bn_rm, bn_rv,bn_w,bn_b, eps): 103 | bn_var_rsqrt = torch.rsqrt(bn_rv + eps) 104 | conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) 105 | conv_b = bn_rm * bn_var_rsqrt * bn_w-bn_b 106 | return conv_w,conv_b 107 | 108 | 109 | class RMNetPruning(nn.Module): 110 | def __init__(self, block, num_blocks, num_classes=10,base_wide=64): 111 | super(RMNetPruning, self).__init__() 112 | self.in_planes = base_wide 113 | self.conv1 = nn.Conv2d(3, base_wide, kernel_size=7 if num_classes==1000 else 3, stride=2 if num_classes==1000 else 1, padding=3 if num_classes==1000 else 1, bias=False) 114 | self.bn1 = nn.BatchNorm2d(base_wide) 115 | self.mask1 = nn.Conv2d(base_wide,base_wide,1,groups=base_wide,bias=False) 116 | nn.init.ones_(self.mask1.weight) 117 | self.relu = nn.ReLU(inplace=True) 118 | if num_classes==1000: 119 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 120 | 121 | self.layer1 = self._make_layer(block, base_wide, num_blocks[0], stride=1) 122 | self.layer2 = self._make_layer(block, base_wide*2, num_blocks[1], stride=2) 123 | self.layer3 = self._make_layer(block, base_wide*4, num_blocks[2], stride=2) 124 | self.layer4 = None 125 | if len(num_blocks)==4: 126 | self.layer4 = self._make_layer(block, base_wide*8, num_blocks[3], stride=2) 127 | 128 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 129 | self.flat = nn.Flatten(start_dim=1) 130 | self.fc = nn.Linear(self.in_planes, num_classes) 131 | 132 | def _make_layer(self, block, planes, num_blocks, stride): 133 | strides = [stride] + [1]*(num_blocks-1) 134 | layers = [] 135 | for stride in strides: 136 | layers.append(block(self.in_planes, planes*2, planes, stride)) 137 | self.in_planes = planes 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | out = self.bn1(self.conv1(x)) 142 | out = self.mask1(out) 143 | out = self.relu(out) 144 | if self.fc.out_features==1000: 145 | out = self.maxpool(out) 146 | out = self.layer1(out) 147 | out = self.layer2(out) 148 | out = self.layer3(out) 149 | if self.layer4 is not None: 150 | out = self.layer4(out) 151 | 152 | out = self.gap(out) 153 | out = self.flat(out) 154 | out = self.fc(out) 155 | return out 156 | 157 | def update_mask(self,sr,threshold): 158 | for m in self.modules(): 159 | if 
isinstance(m,nn.Conv2d): 160 | if m.kernel_size==(1,1) and m.groups!=1: 161 | m.weight.grad.data.add_(sr * torch.sign(m.weight.data)) 162 | m1 = m.weight.data.abs()>threshold 163 | m.weight.grad.data*=m1 164 | m.weight.data*=m1 165 | 166 | def fix_mask(self): 167 | for m in self.modules(): 168 | if isinstance(m,nn.Conv2d): 169 | if m.kernel_size==(1,1) and m.groups!=1: 170 | m.weight.requires_grad=False 171 | 172 | def deploy(self): 173 | def foo(net): 174 | global blocks 175 | childrens = list(net.children()) 176 | if isinstance(net, ResBlock): 177 | blocks+=net.deploy() 178 | 179 | elif not childrens: 180 | if isinstance(net,nn.Conv2d) and net.groups!=1: 181 | blocks[-1].weight.data*=net.weight.data.reshape(-1) 182 | blocks[-1].bias.data*=net.weight.data.reshape(-1) 183 | else: 184 | blocks+=[net] 185 | else: 186 | for c in childrens: 187 | foo(c) 188 | global blocks 189 | 190 | blocks =[] 191 | foo(self.eval()) 192 | return nn.Sequential(*blocks) 193 | 194 | def prune(self,use_bn=True): 195 | features=[] 196 | in_mask=torch.ones(3)>0 197 | blocks=self.deploy() 198 | for i,m in enumerate(blocks): 199 | if isinstance(m,nn.BatchNorm2d): 200 | mask=m.weight.data.abs().reshape(-1)>0 201 | conv=nn.Conv2d(int(in_mask.sum()),int(mask.sum()),blocks[i-1].kernel_size,stride=blocks[i-1].stride,padding=blocks[i-1].padding,bias=False) 202 | conv.weight.data=blocks[i-1].weight.data[mask][:,in_mask] 203 | bn=nn.BatchNorm2d(int(mask.sum())) 204 | bn.weight.data=m.weight.data[mask] 205 | bn.bias.data=m.bias.data[mask] 206 | bn.running_mean=m.running_mean[mask] 207 | bn.running_var=m.running_var[mask] 208 | if use_bn: 209 | features.extend([conv,bn]) 210 | else: 211 | features.extend([nn.utils.fuse_conv_bn_eval(conv.eval(),bn.eval())]) 212 | in_mask=mask 213 | elif isinstance(m,nn.Conv2d): 214 | if not isinstance(blocks[i+1],nn.BatchNorm2d): 215 | print(m) 216 | elif isinstance(m,nn.Linear): 217 | linear=nn.Linear(int(in_mask.sum()),m.out_features) 218 | linear.weight.data=m.weight.data[:,in_mask] 219 | linear.bias.data=m.bias.data 220 | features.append(linear) 221 | else: 222 | features.append(m) 223 | return nn.Sequential(*features) 224 | 225 | 226 | def rmnet_pruning_18(num_classes=1000): 227 | return RMNetPruning(ResBlock, [2, 2, 2, 2], num_classes=num_classes) 228 | 229 | def rmnet_pruning_34(num_classes=1000): 230 | return RMNetPruning(ResBlock, [3, 4, 6, 3], num_classes=num_classes) 231 | -------------------------------------------------------------------------------- /models/rmnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class RepBlock(nn.Module): 7 | def __init__(self, in_planes, out_planes, stride=1): 8 | 9 | super(RepBlock, self).__init__() 10 | 11 | self.in_planes = in_planes 12 | self.out_planes = out_planes 13 | self.stride = stride 14 | 15 | self.conv33 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn33 = nn.BatchNorm2d(out_planes) 17 | self.conv11 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 18 | self.bn11 = nn.BatchNorm2d(out_planes) 19 | #self.running = nn.BatchNorm2d(out_planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | 22 | def forward(self, x): 23 | out = self.bn33(self.conv33(x)) 24 | out += self.bn11(self.conv11(x)) 25 | #self.running(out) 26 | return self.relu(out) 27 | 28 | def deploy(self, merge_bn=False): 29 | self.eval() 30 | conv33_bn33 = 
torch.nn.utils.fuse_conv_bn_eval(self.conv33, self.bn33).eval() 31 | conv11_bn11 = torch.nn.utils.fuse_conv_bn_eval(self.conv11, self.bn11).eval() 32 | conv33_bn33.weight.data += F.pad(conv11_bn11.weight.data, [1, 1, 1, 1]) 33 | conv33_bn33.bias.data += conv11_bn11.bias.data 34 | 35 | #self.running.weight.data = torch.sqrt(self.running.running_var + self.running.eps) 36 | #self.running.bias.data = self.running.running_mean 37 | #if merge_bn: 38 | return [conv33_bn33, self.relu] 39 | #else: 40 | # return [conv33_bn33, self.running, self.relu] 41 | 42 | 43 | class RMBlock(nn.Module): 44 | def __init__(self, in_planes, out_planes, stride, expand_ratio=2, cpg=1): 45 | super(RMBlock, self).__init__() 46 | self.in_planes = in_planes 47 | self.mid_planes = out_planes * expand_ratio - out_planes 48 | self.out_planes = out_planes 49 | self.stride = stride 50 | self.cpg = cpg 51 | assert self.mid_planes % cpg == 0 and (self.mid_planes + self.in_planes) % cpg == 0 52 | self.groups = self.mid_planes // cpg 53 | 54 | self.conv1 = nn.Conv2d(in_planes, self.mid_planes, kernel_size=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(self.mid_planes) 56 | self.conv2 = nn.Conv2d(self.mid_planes, self.mid_planes, kernel_size=3, stride=stride, padding=1, groups=self.groups, bias=False) 57 | self.bn2 = nn.BatchNorm2d(self.mid_planes) 58 | self.conv3 = nn.Conv2d(self.mid_planes, out_planes, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(out_planes) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.downsample = nn.Sequential() 62 | if self.stride != 1 or in_planes != out_planes: 63 | self.downsample = nn.Sequential( 64 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False), 65 | nn.BatchNorm2d(out_planes) 66 | ) 67 | self.running1 = nn.BatchNorm2d(self.in_planes, affine=False) 68 | self.running2 = nn.BatchNorm2d(self.out_planes, affine=False) 69 | 70 | def forward(self, x): 71 | self.running1(x) 72 | out = self.relu(self.bn1(self.conv1(x))) 73 | out = self.relu(self.bn2(self.conv2(out))) 74 | out = self.bn3(self.conv3(out)) 75 | out += self.downsample(x) 76 | self.running2(out) 77 | return self.relu(out) 78 | 79 | def deploy(self, merge_bn=False): 80 | self.mid_planes = self.conv2.in_channels + self.in_planes 81 | self.groups = self.mid_planes // self.cpg 82 | idconv1 = nn.Conv2d(self.in_planes, self.mid_planes, kernel_size=1, bias=False).eval() 83 | idbn1 = nn.BatchNorm2d(self.mid_planes).eval() 84 | 85 | nn.init.dirac_(idconv1.weight.data[:self.in_planes]) 86 | bn_var_sqrt = torch.sqrt(self.running1.running_var + self.running1.eps) 87 | idbn1.weight.data[:self.in_planes] = bn_var_sqrt 88 | idbn1.bias.data[:self.in_planes] = self.running1.running_mean 89 | idbn1.running_mean.data[:self.in_planes] = self.running1.running_mean 90 | idbn1.running_var.data[:self.in_planes] = self.running1.running_var 91 | 92 | idconv1.weight.data[self.in_planes:] = self.conv1.weight.data 93 | idbn1.weight.data[self.in_planes:] = self.bn1.weight.data 94 | idbn1.bias.data[self.in_planes:] = self.bn1.bias.data 95 | idbn1.running_mean.data[self.in_planes:] = self.bn1.running_mean 96 | idbn1.running_var.data[self.in_planes:] = self.bn1.running_var 97 | 98 | idconv2 = nn.Conv2d(self.mid_planes, self.mid_planes, kernel_size=3, stride=self.stride, padding=1, groups=self.groups, bias=False).eval() 99 | idbn2 = nn.BatchNorm2d(self.mid_planes).eval() 100 | 101 | idbn2.weight.data[:self.in_planes] = idbn1.weight.data[:self.in_planes] 102 | idbn2.bias.data[:self.in_planes] = idbn1.bias.data[:self.in_planes] 103 | 
idbn2.running_mean.data[:self.in_planes] = idbn1.running_mean.data[:self.in_planes] 104 | idbn2.running_var.data[:self.in_planes] = idbn1.running_var.data[:self.in_planes] 105 | nn.init.dirac_(idconv2.weight.data[:self.in_planes], groups=self.groups - self.conv2.groups) 106 | 107 | idconv2.weight.data[self.in_planes:] = self.conv2.weight.data 108 | idbn2.weight.data[self.in_planes:] = self.bn2.weight.data 109 | idbn2.bias.data[self.in_planes:] = self.bn2.bias.data 110 | idbn2.running_mean.data[self.in_planes:] = self.bn2.running_mean 111 | idbn2.running_var.data[self.in_planes:] = self.bn2.running_var 112 | 113 | idconv3 = nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=1, bias=False).eval() 114 | idbn3 = nn.BatchNorm2d(self.out_planes).eval() 115 | 116 | downsample_bias = 0 117 | if self.in_planes == self.out_planes and self.stride == 1: 118 | nn.init.dirac_(idconv3.weight.data[:, :self.in_planes]) 119 | else: 120 | idconv3.weight.data[:, :self.in_planes], downsample_bias = self.fuse(self.downsample[0].weight, self.downsample[1].running_mean, self.downsample[1].running_var, self.downsample[1].weight, self.downsample[1].bias, self.downsample[1].eps) 121 | 122 | idconv3.weight.data[:, self.in_planes:], bias = self.fuse(self.conv3.weight, self.bn3.running_mean, self.bn3.running_var, self.bn3.weight, self.bn3.bias, self.bn3.eps) 123 | bn_var_sqrt = torch.sqrt(self.running2.running_var + self.running2.eps) 124 | idbn3.weight.data = bn_var_sqrt 125 | idbn3.bias.data = self.running2.running_mean 126 | idbn3.running_mean.data = self.running2.running_mean + bias + downsample_bias 127 | idbn3.running_var.data = self.running2.running_var 128 | 129 | if merge_bn: 130 | return [torch.nn.utils.fuse_conv_bn_eval(idconv1, idbn1), self.relu, torch.nn.utils.fuse_conv_bn_eval(idconv2, idbn2), self.relu, torch.nn.utils.fuse_conv_bn_eval(idconv3, idbn3), self.relu] 131 | else: 132 | return [idconv1, idbn1, self.relu, idconv2, idbn2, self.relu, idconv3, idbn3, self.relu] 133 | 134 | def fuse(self, conv_w, bn_rm, bn_rv, bn_w, bn_b, eps): 135 | bn_var_rsqrt = torch.rsqrt(bn_rv + eps) 136 | conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) 137 | conv_b = bn_rm * bn_var_rsqrt * bn_w - bn_b 138 | return conv_w, conv_b 139 | 140 | 141 | class RMNeXt(nn.Module): 142 | def __init__(self, num_blocks, num_classes=1000, base_wide=64, expand_ratio=2, cpg=1): 143 | super(RMNeXt, self).__init__() 144 | self.in_planes = min(64, base_wide) 145 | if num_classes==1000: 146 | self.layer0 = nn.Sequential(*[RepBlock(3, self.in_planes, stride=2), RepBlock(self.in_planes, self.in_planes, stride=2)]) 147 | else: 148 | self.layer0 = RepBlock(3, self.in_planes, stride=1) 149 | self.layer1 = self._make_layer(base_wide, num_blocks[0], expand_ratio, cpg, stride=1) 150 | self.layer2 = self._make_layer(base_wide * 2, num_blocks[1], expand_ratio, cpg * 2, stride=2) 151 | self.layer3 = self._make_layer(base_wide * 4, num_blocks[2], expand_ratio, cpg * 4, stride=2) 152 | self.layer4 = self._make_layer(base_wide * 8, num_blocks[3], expand_ratio, cpg * 8, stride=2) 153 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 154 | self.flat = nn.Flatten(start_dim=1) 155 | self.fc = nn.Linear(self.in_planes, num_classes) 156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 159 | if isinstance(m, RMBlock): 160 | nn.init.constant_(m.bn3.weight, 0) 161 | 162 | def _make_layer(self, planes, num_blocks, expand_ratio, cpg, 
stride=1): 163 | strides = [stride] + [1] * (num_blocks - 1) 164 | layers = [] 165 | for stride in strides: 166 | layers.append(RMBlock(self.in_planes, planes, stride, expand_ratio, cpg)) 167 | self.in_planes = planes 168 | return nn.Sequential(*layers) 169 | 170 | def forward(self, x): 171 | out = self.layer0(x) 172 | out = self.layer1(out) 173 | out = self.layer2(out) 174 | out = self.layer3(out) 175 | out = self.layer4(out) 176 | out = self.gap(out) 177 | out = self.flat(out) 178 | out = self.fc(out) 179 | return out 180 | 181 | def deploy(self, merge_bn=False): 182 | def foo(net): 183 | global blocks 184 | childrens = list(net.children()) 185 | if isinstance(net, RMBlock) or isinstance(net, RepBlock): 186 | blocks += net.deploy(merge_bn) 187 | elif not childrens: 188 | blocks += [net] 189 | else: 190 | for c in childrens: 191 | foo(c) 192 | 193 | global blocks 194 | 195 | blocks = [] 196 | foo(self.eval()) 197 | return nn.Sequential(*blocks) 198 | 199 | 200 | def rmnext41_64x5_g16(num_classes=1000): 201 | return RMNeXt([2, 3, 5, 3], num_classes=num_classes, base_wide=64, expand_ratio=5, cpg=16) 202 | 203 | 204 | def rmnext50_64x5_g32(num_classes=1000): 205 | return RMNeXt([3, 4, 6, 3], num_classes=num_classes, base_wide=64, expand_ratio=5, cpg=32) 206 | 207 | 208 | def rmnext50_64x6_g32(num_classes=1000): 209 | return RMNeXt([3, 4, 6, 3], num_classes=num_classes, base_wide=64, expand_ratio=6, cpg=32) 210 | 211 | 212 | def rmnext101_64x6_g16(num_classes=1000): 213 | return RMNeXt([3, 4, 23, 3], num_classes=num_classes, base_wide=64, expand_ratio=6, cpg=16) 214 | 215 | 216 | def rmnext152_64x6_g32(num_classes=1000): 217 | return RMNeXt([3, 8, 36, 3], num_classes=num_classes, base_wide=64, expand_ratio=6, cpg=32) 218 | -------------------------------------------------------------------------------- /models/rmobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import math 5 | 6 | def fuse_cbcb(conv1,bn1,conv2,bn2): 7 | inp=conv1.in_channels 8 | mid=conv1.out_channels 9 | oup=conv2.out_channels 10 | conv1=torch.nn.utils.fuse_conv_bn_eval(conv1.eval(),bn1.eval()) 11 | fused_conv=nn.Conv2d(inp,oup,1,bias=False) 12 | fused_conv.weight.data=(conv2.weight.data.view(oup,mid)@conv1.weight.data.view(mid,-1)).view(oup,inp,1,1) 13 | bn2.running_mean-=conv2.weight.data.view(oup,mid)@conv1.bias.data 14 | return fused_conv,bn2 15 | 16 | class InvertedResidual(nn.Module): 17 | def __init__(self, inp, oup, stride, expand_ratio, free=1): 18 | super(InvertedResidual, self).__init__() 19 | self.in_planes=inp 20 | self.out_planes=oup 21 | 22 | self.stride = stride 23 | self.free = free 24 | assert stride in [1, 2] 25 | 26 | hidden_dim = round(inp * expand_ratio) 27 | self.use_res_connect = self.stride == 1 and inp == oup 28 | self.mid_planes=hidden_dim+ inp if self.use_res_connect else 0 29 | 30 | self.conv = nn.Sequential( 31 | # pw 32 | nn.Conv2d(inp*free, hidden_dim, 1, 1, 0, bias=False), 33 | nn.BatchNorm2d(hidden_dim), 34 | nn.ReLU6(inplace=True), 35 | # dw 36 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 37 | nn.BatchNorm2d(hidden_dim), 38 | nn.ReLU6(inplace=True), 39 | # pw-linear 40 | nn.Conv2d(hidden_dim, oup*free, 1, 1, 0, bias=False), 41 | nn.BatchNorm2d(oup*free), 42 | ) 43 | if self.use_res_connect: 44 | self.running1 = nn.BatchNorm2d(self.in_planes,affine=False) 45 | self.running2 = nn.BatchNorm2d(self.in_planes*free,affine=False) 46 | 47 | def 
forward(self, x): 48 | if self.use_res_connect: 49 | out = self.conv(x) 50 | out[:,:self.in_planes] += x[:,:self.in_planes] 51 | self.running1(x[:,:self.in_planes]) 52 | self.running2(out) 53 | return out 54 | else: 55 | return self.conv(x) 56 | 57 | def deploy(self): 58 | if self.use_res_connect: 59 | idconv1 = nn.Conv2d(self.in_planes*self.free, self.mid_planes, kernel_size=1, bias=False).eval() 60 | idbn1=nn.BatchNorm2d(self.mid_planes).eval() 61 | 62 | nn.init.dirac_(idconv1.weight.data[:self.in_planes]) 63 | bn_var_sqrt=torch.sqrt(self.running1.running_var + self.running1.eps) 64 | idbn1.weight.data[:self.in_planes]=bn_var_sqrt 65 | idbn1.bias.data[:self.in_planes]=self.running1.running_mean 66 | idbn1.running_mean.data[:self.in_planes]=self.running1.running_mean 67 | idbn1.running_var.data[:self.in_planes]=self.running1.running_var 68 | 69 | idconv1.weight.data[self.in_planes:]=self.conv[0].weight.data 70 | idbn1.weight.data[self.in_planes:]=self.conv[1].weight.data 71 | idbn1.bias.data[self.in_planes:]=self.conv[1].bias.data 72 | idbn1.running_mean.data[self.in_planes:]=self.conv[1].running_mean 73 | idbn1.running_var.data[self.in_planes:]=self.conv[1].running_var 74 | idrelu1 = nn.PReLU(self.mid_planes) 75 | torch.nn.init.ones_(idrelu1.weight.data[:self.in_planes]) 76 | torch.nn.init.zeros_(idrelu1.weight.data[self.in_planes:]) 77 | 78 | 79 | 80 | idconv2 = nn.Conv2d(self.mid_planes, self.mid_planes, kernel_size=3, stride=self.stride, padding=1,groups=self.mid_planes, bias=False).eval() 81 | idbn2=nn.BatchNorm2d(self.mid_planes).eval() 82 | 83 | nn.init.dirac_(idconv2.weight.data[:self.in_planes],groups=self.in_planes) 84 | idbn2.weight.data[:self.in_planes]=idbn1.weight.data[:self.in_planes] 85 | idbn2.bias.data[:self.in_planes]=idbn1.bias.data[:self.in_planes] 86 | idbn2.running_mean.data[:self.in_planes]=idbn1.running_mean.data[:self.in_planes] 87 | idbn2.running_var.data[:self.in_planes]=idbn1.running_var.data[:self.in_planes] 88 | 89 | idconv2.weight.data[self.in_planes:]=self.conv[3].weight.data 90 | idbn2.weight.data[self.in_planes:]=self.conv[4].weight.data 91 | idbn2.bias.data[self.in_planes:]=self.conv[4].bias.data 92 | idbn2.running_mean.data[self.in_planes:]=self.conv[4].running_mean 93 | idbn2.running_var.data[self.in_planes:]=self.conv[4].running_var 94 | idrelu2 = nn.PReLU(self.mid_planes) 95 | torch.nn.init.ones_(idrelu2.weight.data[:self.in_planes]) 96 | torch.nn.init.zeros_(idrelu2.weight.data[self.in_planes:]) 97 | 98 | idconv3 = nn.Conv2d(self.mid_planes, self.in_planes*self.free, kernel_size=1, bias=False).eval() 99 | idbn3=nn.BatchNorm2d(self.in_planes*self.free).eval() 100 | 101 | nn.init.dirac_(idconv3.weight.data[:,:self.in_planes]) 102 | idconv3.weight.data[:,self.in_planes:],bias=self.fuse(self.conv[6].weight,self.conv[7].running_mean,self.conv[7].running_var,self.conv[7].weight,self.conv[7].bias,self.conv[7].eps) 103 | bn_var_sqrt=torch.sqrt(self.running2.running_var + self.running2.eps) 104 | idbn3.weight.data=bn_var_sqrt 105 | idbn3.bias.data=self.running2.running_mean 106 | idbn3.running_mean.data=self.running2.running_mean+bias 107 | idbn3.running_var.data=self.running2.running_var 108 | 109 | self.use_res_connect=False 110 | self.running1 = None 111 | self.running2 = None 112 | self.conv=nn.Sequential(*[idconv1,idbn1,idrelu1,idconv2,idbn2,idrelu2,idconv3,idbn3]) 113 | 114 | def fuse(self,conv_w, bn_rm, bn_rv,bn_w,bn_b, eps): 115 | bn_var_rsqrt = torch.rsqrt(bn_rv + eps) 116 | conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * 
(len(conv_w.shape) - 1)) 117 | conv_b = bn_rm * bn_var_rsqrt * bn_w-bn_b 118 | return conv_w,conv_b 119 | 120 | 121 | class RMobileNet(nn.Module): 122 | def __init__(self, setting, input_channel, output_channel, last_channel, t_free=1, n_class=100): 123 | super(RMobileNet, self).__init__() 124 | self.features = [ 125 | nn.Sequential( 126 | nn.Conv2d(3, input_channel, 3, 2 if n_class==1000 else 1, 1, bias=False), 127 | nn.BatchNorm2d(input_channel), 128 | nn.ReLU6(inplace=True) 129 | ) 130 | ] 131 | self.features.append( 132 | nn.Sequential( 133 | # dw 134 | nn.Conv2d(input_channel, input_channel, 3, stride=1, padding=1, groups=input_channel, bias=False), 135 | nn.BatchNorm2d(input_channel), 136 | nn.ReLU6(inplace=True), 137 | # pw-linear 138 | nn.Conv2d(input_channel, output_channel * t_free, 1, 1, 0, bias=False), 139 | nn.BatchNorm2d(output_channel * t_free), 140 | ) 141 | ) 142 | input_channel = output_channel 143 | for t, output_channel, n, s in setting: 144 | for i in range(n): 145 | self.features.append(InvertedResidual(input_channel, output_channel, s, expand_ratio=t,free=t_free)) 146 | input_channel = output_channel 147 | self.features.append( 148 | nn.Sequential( 149 | nn.Conv2d(input_channel * t_free, last_channel, 1, 1, 0, bias=False), 150 | nn.BatchNorm2d(last_channel), 151 | nn.ReLU6(inplace=True) 152 | ) 153 | ) 154 | self.features = nn.Sequential(*self.features) 155 | self.classifier = nn.Sequential( 156 | nn.Dropout(0.2), 157 | nn.Linear(last_channel, n_class), 158 | ) 159 | 160 | self._initialize_weights() 161 | 162 | def forward(self, x): 163 | x = self.features(x) 164 | x = x.mean([2, 3]) 165 | x = self.classifier(x) 166 | return x 167 | def rm(self): 168 | for m in self.features: 169 | if isinstance(m,InvertedResidual): 170 | m.deploy() 171 | return self 172 | 173 | def deploy(self): 174 | self.rm() 175 | features=[] 176 | for m in self.features.modules(): 177 | if isinstance(m,nn.Conv2d) or isinstance(m,nn.BatchNorm2d) or isinstance(m,nn.PReLU) or isinstance(m,nn.ReLU6): 178 | features.append(m) 179 | new_features=[] 180 | while len(features)>3: 181 | if isinstance(features[0],nn.Conv2d) and isinstance(features[1],nn.BatchNorm2d) and isinstance(features[2],nn.Conv2d) and isinstance(features[3],nn.BatchNorm2d): 182 | conv,bn = fuse_cbcb(features[0],features[1],features[2],features[3]) 183 | new_features.append(conv) 184 | new_features.append(bn) 185 | features=features[4:] 186 | else: 187 | new_features.append(features.pop(0)) 188 | new_features+=features 189 | self.features=nn.Sequential(*new_features) 190 | return self 191 | 192 | def _initialize_weights(self): 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 196 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 197 | if m.bias is not None: 198 | m.bias.data.zero_() 199 | elif isinstance(m, nn.BatchNorm2d): 200 | if m.weight is not None: 201 | m.weight.data.fill_(1) 202 | if m.bias is not None: 203 | m.bias.data.zero_() 204 | elif isinstance(m, nn.Linear): 205 | n = m.weight.size(1) 206 | m.weight.data.normal_(0, 0.01) 207 | m.bias.data.zero_() 208 | 209 | def mobilenetv1_cifar(n_class=100,width_mult=1,t_free=8): 210 | input_channel = int(32 * width_mult) 211 | output_channel = int(32 * width_mult) 212 | last_channel = 1024 213 | setting =[ 214 | [2,int(32 * width_mult),1,2], 215 | [3,int(32 * width_mult),1,1], 216 | 217 | [4,int(64 * width_mult),1,2], 218 | [3,int(64 * width_mult),1,1], 219 | 220 | [4,int(128 * width_mult),1,2], 221 | [3,int(128 * width_mult),5,1], 222 | 223 | [4,int(256 * width_mult),1,2], 224 | [3,int(256 * width_mult),1,1] 225 | ] 226 | return RMobileNet(setting, input_channel,output_channel,last_channel,t_free,n_class) 227 | 228 | def mobilenetv2_imagenet(n_class=1000,width_mult=1,t_free=1,pretrained=True): 229 | input_channel = int(32 * width_mult) 230 | output_channel = int(16 * width_mult) 231 | last_channel = 1280 232 | setting = [ 233 | [6,int(24 * width_mult),1,2], 234 | [6, int(24 * width_mult), 1, 1], 235 | [6, int(32 * width_mult), 1, 2], 236 | [6, int(32 * width_mult), 2, 1], 237 | [6, int(64 * width_mult), 1, 2], 238 | [6, int(64 * width_mult), 3, 1], 239 | [6, int(96 * width_mult), 1, 1], 240 | [6, int(96 * width_mult), 2, 1], 241 | [6, int(160 * width_mult), 1, 2], 242 | [6, int(160 * width_mult), 2, 1], 243 | [6, int(320 * width_mult), 1, 1] 244 | ] 245 | 246 | model = RMobileNet(setting, input_channel,output_channel,last_channel,t_free,n_class) 247 | if pretrained: 248 | assert t_free==1 249 | pretrain_model=torchvision.models.mobilenet_v2(pretrained=True) 250 | for i in range(1,18): 251 | blocks=[] 252 | for m in pretrain_model.features[i].modules(): 253 | if isinstance(m,nn.Conv2d) or isinstance(m,nn.BatchNorm2d) or isinstance(m,nn.ReLU6): 254 | blocks.append(m) 255 | pretrain_model.features[i].conv=nn.Sequential(*blocks) 256 | pretrain_model.features[1]=pretrain_model.features[1].conv 257 | print(model.load_state_dict(pretrain_model.state_dict(),strict=False)) 258 | return model 259 | -------------------------------------------------------------------------------- /models/rmrepse.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | class SEBlock(nn.Module): 7 | def __init__(self, input_channels, internal_neurons,ratio): 8 | super(SEBlock, self).__init__() 9 | self.input_channels = input_channels 10 | self.internal_neurons=internal_neurons 11 | self.rmplanes=int(input_channels*ratio) 12 | self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True) 13 | self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels-self.rmplanes, kernel_size=1, stride=1, bias=True) 14 | 15 | def forward(self, inputs): 16 | x = F.avg_pool2d(inputs, kernel_size=inputs.size(3)) 17 | x = self.down(x) 18 | x = F.relu(x) 19 | x = self.up(x) 20 | x = torch.sigmoid(x) 21 | x = x.view(-1, self.input_channels-self.rmplanes, 1, 1) 22 | return torch.cat([inputs[:,:self.rmplanes],inputs[:,self.rmplanes:] * x],dim=1) 23 | 24 | def deploy(self): 25 | up = nn.Conv2d(in_channels=self.internal_neurons, out_channels=self.input_channels, kernel_size=1, stride=1, bias=True) 26 | 
nn.init.zeros_(up.weight.data[:self.rmplanes]) 27 | up.weight.data[self.rmplanes:]=self.up.weight.data 28 | nn.init.constant_(up.bias.data[:self.rmplanes],100) 29 | up.bias.data[self.rmplanes:]=self.up.bias.data 30 | self.rmplanes=0 31 | self.up=up 32 | 33 | class RMBlock(nn.Module): 34 | def __init__(self, in_planes, out_planes, ratio=0.5, stride=1, dilation=1, use_se=False): 35 | 36 | super(RMBlock, self).__init__() 37 | self.in_planes = in_planes 38 | self.out_planes = out_planes 39 | self.stride = stride 40 | self.dilation = dilation 41 | self.use_se=use_se 42 | self.rmplanes=int(out_planes*ratio) 43 | assert not ratio or (in_planes==out_planes and stride==1) 44 | 45 | self.conv33 = nn.Conv2d(in_planes, self.out_planes-self.rmplanes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) 46 | self.bn33 = nn.BatchNorm2d(self.out_planes-self.rmplanes) 47 | self.conv11 = nn.Conv2d(in_planes, self.out_planes-self.rmplanes, kernel_size=1, stride=stride, padding=0, bias=False) 48 | self.bn11 = nn.BatchNorm2d(self.out_planes-self.rmplanes) 49 | if self.in_planes==self.out_planes and stride==1: 50 | self.bn00 = nn.BatchNorm2d(self.out_planes-self.rmplanes) 51 | self.se = SEBlock(out_planes,out_planes//16,ratio) if use_se else nn.Sequential() 52 | self.relu = nn.ReLU(True) 53 | 54 | def forward(self, x): 55 | out = self.bn33(self.conv33(x)) 56 | out += self.bn11(self.conv11(x)) 57 | if self.in_planes==self.out_planes and self.stride==1: 58 | out += self.bn00(x[:,self.rmplanes:]) 59 | out = torch.cat([x[:,:self.rmplanes], out],dim=1) 60 | return self.relu(self.se(out)) 61 | 62 | def res2rep(self): 63 | if self.rmplanes: 64 | conv33 = nn.Conv2d(self.in_planes, self.out_planes, kernel_size=3, stride=self.stride, padding=self.dilation, dilation=self.dilation, bias=False) 65 | bn33 = nn.BatchNorm2d(self.out_planes) 66 | conv11 = nn.Conv2d(self.in_planes, self.out_planes, kernel_size=1, stride=self.stride, padding=0, bias=False) 67 | bn11 = nn.BatchNorm2d(self.out_planes) 68 | bn00 = nn.BatchNorm2d(self.out_planes) 69 | 70 | nn.init.zeros_(conv33.weight.data) 71 | conv33.weight.data[self.rmplanes:]=self.conv33.weight.data 72 | bn33.weight.data[self.rmplanes:]=self.bn33.weight.data 73 | bn33.bias.data[self.rmplanes:]=self.bn33.bias.data 74 | bn33.running_mean.data[self.rmplanes:]=self.bn33.running_mean.data 75 | bn33.running_var.data[self.rmplanes:]=self.bn33.running_var.data 76 | 77 | nn.init.zeros_(conv11.weight.data) 78 | conv11.weight.data[self.rmplanes:]=self.conv11.weight.data 79 | bn11.weight.data[self.rmplanes:]=self.bn11.weight.data 80 | bn11.bias.data[self.rmplanes:]=self.bn11.bias.data 81 | bn11.running_mean.data[self.rmplanes:]=self.bn11.running_mean.data 82 | bn11.running_var.data[self.rmplanes:]=self.bn11.running_var.data 83 | 84 | 85 | bn00.weight.data[self.rmplanes:]=self.bn00.weight.data 86 | bn00.bias.data[self.rmplanes:]=self.bn00.bias.data 87 | bn00.running_mean.data[self.rmplanes:]=self.bn00.running_mean.data 88 | bn00.running_var.data[self.rmplanes:]=self.bn00.running_var.data 89 | 90 | self.conv33=conv33 91 | self.bn33=bn33 92 | self.conv11=conv11 93 | self.bn11=bn11 94 | self.bn00=bn00 95 | if self.use_se: 96 | self.se.deploy() 97 | self.rmplanes=0 98 | return self 99 | 100 | def rep2vgg(self): 101 | if self.rmplanes: 102 | self.res2rep() 103 | self.eval() 104 | 105 | conv33_bn33 = torch.nn.utils.fuse_conv_bn_eval(self.conv33, self.bn33).eval() 106 | conv11_bn11 = torch.nn.utils.fuse_conv_bn_eval(self.conv11, self.bn11).eval() 107 | 
conv33_bn33.weight.data += F.pad(conv11_bn11.weight.data, [1, 1, 1, 1]) 108 | conv33_bn33.bias.data += conv11_bn11.bias.data 109 | if self.in_planes == self.out_planes and self.stride == 1: 110 | conv00 = nn.Conv2d(self.in_planes, self.out_planes, kernel_size=3, padding=1, dilation=self.dilation, bias=False).eval() 111 | nn.init.dirac_(conv00.weight.data) 112 | conv00_bn00 = torch.nn.utils.fuse_conv_bn_eval(conv00, self.bn00) 113 | conv33_bn33.weight.data += conv00_bn00.weight.data 114 | conv33_bn33.bias.data += conv00_bn00.bias.data 115 | if self.use_se: 116 | return [conv33_bn33,self.se,nn.ReLU(True)] 117 | else: 118 | return [conv33_bn33,nn.ReLU(True)] 119 | 120 | class RMRepSE(nn.Module): 121 | def __init__(self, num_blocks, num_classes=1000, base_wide=64,use_se=False): 122 | super(RMRepSE, self).__init__() 123 | feature=[] 124 | in_planes=3 125 | for t,stride,dilation,ratio in num_blocks: 126 | out_planes=int(t*base_wide) 127 | feature.append(RMBlock(in_planes, out_planes, ratio, stride, dilation, use_se)) 128 | in_planes=out_planes 129 | 130 | self.feature=nn.Sequential(*feature) 131 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 132 | self.flat = nn.Flatten(start_dim=1) 133 | self.fc = nn.Linear(out_planes, num_classes) 134 | 135 | def forward(self, x): 136 | out = self.feature(x) 137 | out = self.gap(out) 138 | out = self.flat(out) 139 | out = self.fc(out) 140 | return out 141 | 142 | def res2rep(self): 143 | for m in self.feature: 144 | if isinstance(m,RMBlock): 145 | m.res2rep() 146 | 147 | def deploy(self): 148 | blocks=[] 149 | for m in self.feature: 150 | if isinstance(m,RMBlock): 151 | blocks+=m.rep2vgg() 152 | blocks.append(self.gap) 153 | blocks.append(self.flat) 154 | blocks.append(self.fc) 155 | return nn.Sequential(*blocks) 156 | 157 | def rmrepse(num_classes=1000, ratio=0, use_se=False): 158 | return RMRepSE([[1,2,1,0]]*1+ 159 | [[2.5,2,1,0],[2.5,1,1,0]]*1+ 160 | [[2.5,1,1,ratio],[2.5,1,1,0]]*1+ 161 | [[5, 2,1,0],[5, 1,1,0]]*1+ 162 | [[5, 1,1,ratio],[5, 1,1,0]]*2+ 163 | [[10, 2,1,0],[10, 1,1,0]]*1+ 164 | [[10, 1,1,ratio],[10, 1,1,0]]*7+ 165 | [[40, 2,1,0]], 166 | num_classes=num_classes,use_se=use_se) 167 | -------------------------------------------------------------------------------- /models/rmrepvgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | class RepBlock(nn.Module): 7 | def __init__(self, in_planes, out_planes, stride=1): 8 | 9 | super(RepBlock, self).__init__() 10 | 11 | self.in_planes = in_planes 12 | self.out_planes = out_planes 13 | self.stride = stride 14 | 15 | self.conv33 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn33 = nn.BatchNorm2d(out_planes) 17 | self.conv11 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 18 | self.bn11 = nn.BatchNorm2d(out_planes) 19 | if self.in_planes == self.out_planes and self.stride == 1: 20 | self.bn00 = nn.BatchNorm2d(out_planes) 21 | 22 | def forward(self, x): 23 | out = self.bn33(self.conv33(x)) 24 | out += self.bn11(self.conv11(x)) 25 | if self.in_planes == self.out_planes and self.stride == 1: 26 | out += self.bn00(x) 27 | return F.relu(out) 28 | 29 | def deploy(self): 30 | self.eval() 31 | conv33_bn33 = torch.nn.utils.fuse_conv_bn_eval(self.conv33, self.bn33).eval() 32 | conv11_bn11 = torch.nn.utils.fuse_conv_bn_eval(self.conv11, self.bn11).eval() 33 | conv33_bn33.weight.data += 
F.pad(conv11_bn11.weight.data, [1, 1, 1, 1]) 34 | conv33_bn33.bias.data += conv11_bn11.bias.data 35 | if self.in_planes == self.out_planes and self.stride == 1: 36 | conv00 = nn.Conv2d(self.in_planes, self.out_planes, kernel_size=3, padding=1, bias=False).eval() 37 | nn.init.dirac_(conv00.weight.data) 38 | conv00_bn00 = torch.nn.utils.fuse_conv_bn_eval(conv00, self.bn00) 39 | conv33_bn33.weight.data += conv00_bn00.weight.data 40 | conv33_bn33.bias.data += conv00_bn00.bias.data 41 | return [conv33_bn33,nn.ReLU(True)] 42 | 43 | class RMBlock(nn.Module): 44 | def __init__(self, planes, ratio=0.5): 45 | 46 | super(RMBlock, self).__init__() 47 | self.planes = planes 48 | self.rmplanes=int(planes*ratio) 49 | 50 | self.conv33 = nn.Conv2d(planes, self.planes-self.rmplanes, kernel_size=3, padding=1, bias=False) 51 | self.bn33 = nn.BatchNorm2d(self.planes-self.rmplanes) 52 | self.conv11 = nn.Conv2d(planes, self.planes-self.rmplanes, kernel_size=1, padding=0, bias=False) 53 | self.bn11 = nn.BatchNorm2d(self.planes-self.rmplanes) 54 | self.bn00 = nn.BatchNorm2d(self.planes-self.rmplanes) 55 | 56 | def forward(self, x): 57 | out = self.bn33(self.conv33(x)) 58 | out += self.bn11(self.conv11(x)) 59 | out += self.bn00(x[:,self.rmplanes:]) 60 | return F.relu(torch.cat([x[:,:self.rmplanes],out],dim=1)) 61 | 62 | def deploy(self): 63 | self.eval() 64 | conv33=nn.utils.fuse_conv_bn_eval(self.conv33,self.bn33) 65 | conv11=nn.utils.fuse_conv_bn_eval(self.conv11,self.bn11) 66 | conv00=nn.Conv2d(self.planes,self.planes-self.rmplanes,kernel_size=3,padding=1,bias=False).eval() 67 | nn.init.zeros_(conv00.weight.data[:,:self.rmplanes]) 68 | nn.init.dirac_(conv00.weight.data[:,self.rmplanes:]) 69 | conv00=nn.utils.fuse_conv_bn_eval(conv00,self.bn00) 70 | conv3=nn.Conv2d(self.planes,self.planes,kernel_size=3,padding=1) 71 | conv1=nn.Conv2d(self.planes,self.planes,kernel_size=1) 72 | conv0=nn.Conv2d(self.planes,self.planes,kernel_size=3,padding=1) 73 | nn.init.zeros_(conv3.weight.data[:self.rmplanes]) 74 | nn.init.zeros_(conv1.weight.data[:self.rmplanes]) 75 | nn.init.dirac_(conv0.weight.data[:self.rmplanes]) 76 | nn.init.zeros_(conv3.bias.data[:self.rmplanes]) 77 | nn.init.zeros_(conv1.bias.data[:self.rmplanes]) 78 | nn.init.zeros_(conv0.bias.data[:self.rmplanes]) 79 | conv3.weight.data[self.rmplanes:]=conv33.weight.data 80 | conv1.weight.data[self.rmplanes:]=conv11.weight.data 81 | conv0.weight.data[self.rmplanes:]=conv00.weight.data 82 | conv3.bias.data[self.rmplanes:]=conv33.bias.data 83 | conv1.bias.data[self.rmplanes:]=conv11.bias.data 84 | conv0.bias.data[self.rmplanes:]=conv00.bias.data 85 | 86 | conv3.weight.data += F.pad(conv1.weight.data, [1, 1, 1, 1]) 87 | conv3.bias.data += conv1.bias.data 88 | conv3.weight.data += conv0.weight.data 89 | conv3.bias.data += conv0.bias.data 90 | return [conv3,nn.ReLU(True)] 91 | 92 | 93 | class RMRep(nn.Module): 94 | def __init__(self, num_blocks, num_classes=1000, base_wide=64,ratio=0.5): 95 | super(RMRep, self).__init__() 96 | feature=[] 97 | in_planes=3 98 | for b,t,s,n in num_blocks: 99 | out_planes=t*base_wide 100 | for i in range(n): 101 | if b=='rm_rep': 102 | feature.append(RMBlock(out_planes,ratio)) 103 | feature.append(RepBlock(in_planes,out_planes,s)) 104 | in_planes=out_planes 105 | 106 | self.feature=nn.Sequential(*feature) 107 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 108 | self.flat = nn.Flatten(start_dim=1) 109 | self.fc = nn.Linear(base_wide*num_blocks[-1][1], num_classes) 110 | 111 | def forward(self, x): 112 | out = self.feature(x) 113 | out = 
self.gap(out) 114 | out = self.flat(out) 115 | out = self.fc(out) 116 | return out 117 | 118 | def deploy(self): 119 | blocks=[] 120 | for m in self.feature: 121 | if isinstance(m,RepBlock) or isinstance(m,RMBlock): 122 | blocks+=m.deploy() 123 | blocks.append(self.gap) 124 | blocks.append(self.flat) 125 | blocks.append(self.fc) 126 | return nn.Sequential(*blocks) 127 | 128 | 129 | 130 | def repvgg_21(num_classes=1000,depth=2): 131 | return RMRep([['rep',1,2,2] if num_classes==1000 else ['rep',1,1,1], 132 | ['rm_rep',1,1,depth], 133 | ['rep_rep',2,2,1], 134 | ['rm_rep',2,1,depth], 135 | ['rep',4,2,1], 136 | ['rm_rep',4,1,depth], 137 | ['rep',8,2,1], 138 | ['rm_rep',8,1,depth]], 139 | num_classes=num_classes,ratio=0) 140 | 141 | def repvgg_37(num_classes=1000,depth=4): 142 | return RMRep([['rep',1,2,2] if num_classes==1000 else ['rep',1,1,1], 143 | ['rm_rep',1,1,depth], 144 | ['rep_rep',2,2,1], 145 | ['rm_rep',2,1,depth], 146 | ['rep',4,2,1], 147 | ['rm_rep',4,1,depth], 148 | ['rep',8,2,1], 149 | ['rm_rep',8,1,depth]], 150 | num_classes=num_classes,ratio=0) 151 | 152 | def repvgg_69(num_classes=1000,depth=8): 153 | return RMRep([['rep',1,2,2] if num_classes==1000 else ['rep',1,1,1], 154 | ['rm_rep',1,1,depth], 155 | ['rep_rep',2,2,1], 156 | ['rm_rep',2,1,depth], 157 | ['rep',4,2,1], 158 | ['rm_rep',4,1,depth], 159 | ['rep',8,2,1], 160 | ['rm_rep',8,1,depth]], 161 | num_classes=num_classes,ratio=0) 162 | 163 | def repvgg_133(num_classes=1000,depth=16): 164 | return RMRep([['rep',1,2,2] if num_classes==1000 else ['rep',1,1,1], 165 | ['rm_rep',1,1,depth], 166 | ['rep_rep',2,2,1], 167 | ['rm_rep',2,1,depth], 168 | ['rep',4,2,1], 169 | ['rm_rep',4,1,depth], 170 | ['rep',8,2,1], 171 | ['rm_rep',8,1,depth]], 172 | num_classes=num_classes,ratio=0) 173 | 174 | 175 | 176 | def rmrep_21(num_classes=1000,depth=2): 177 | return RMRep([['rep',1,2,2] if num_classes==1000 else ['rep',1,1,1], 178 | ['rm_rep',1,1,depth], 179 | ['rep_rep',2,2,1], 180 | ['rm_rep',2,1,depth], 181 | ['rep',4,2,1], 182 | ['rm_rep',4,1,depth], 183 | ['rep',8,2,1], 184 | ['rm_rep',8,1,depth]], 185 | num_classes=num_classes,ratio=0.25) 186 | 187 | def rmrep_37(num_classes=1000,depth=4): 188 | return RMRep([['rep',1,2,2] if num_classes==1000 else ['rep',1,1,1], 189 | ['rm_rep',1,1,depth], 190 | ['rep_rep',2,2,1], 191 | ['rm_rep',2,1,depth], 192 | ['rep',4,2,1], 193 | ['rm_rep',4,1,depth], 194 | ['rep',8,2,1], 195 | ['rm_rep',8,1,depth]], 196 | num_classes=num_classes,ratio=0.25) 197 | 198 | def rmrep_69(num_classes=1000,depth=8): 199 | return RMRep([['rep',1,2,2] if num_classes==1000 else ['rep',1,1,1], 200 | ['rm_rep',1,1,depth], 201 | ['rep_rep',2,2,1], 202 | ['rm_rep',2,1,depth], 203 | ['rep',4,2,1], 204 | ['rm_rep',4,1,depth], 205 | ['rep',8,2,1], 206 | ['rm_rep',8,1,depth]], 207 | num_classes=num_classes,ratio=0.5) 208 | 209 | def rmrep_133(num_classes=1000,depth=16): 210 | return RMRep([['rep',1,2,2] if num_classes==1000 else ['rep',1,1,1], 211 | ['rm_rep',1,1,depth], 212 | ['rep_rep',2,2,1], 213 | ['rm_rep',2,1,depth], 214 | ['rep',4,2,1], 215 | ['rm_rep',4,1,depth], 216 | ['rep',8,2,1], 217 | ['rm_rep',8,1,depth]], 218 | num_classes=num_classes,ratio=0.75) 219 | 220 | 221 | -------------------------------------------------------------------------------- /ppt_souce_files/RMNet.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/ppt_souce_files/RMNet.pptx 
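
Editor's note: the following usage sketch is not part of the repository; it illustrates how the multi-branch `RepBlock`/`RMBlock` modules defined in `models/rmrepvgg.py` above collapse into an equivalent plain conv + ReLU stack via `deploy()`. It assumes the snippet is run from the repository root so that `models/rmrepvgg.py` is importable; the variable names are illustrative only.

```python
# Sketch: check that deploy() yields a numerically equivalent plain network.
import torch
from models.rmrepvgg import rmrep_69  # assumes the repo root is on sys.path

model = rmrep_69(num_classes=1000).eval()  # eval(): conv-BN fusion uses running stats
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y_block_form = model(x)   # multi-branch (training-time) form
    plain = model.deploy()    # nn.Sequential of single 3x3 convs + ReLU
    y_plain_form = plain(x)
# The gap should sit at float32 rounding level (~1e-5), which is the
# equivalence the RM operation relies on.
print((y_block_form - y_plain_form).abs().max().item())
```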
-------------------------------------------------------------------------------- /ppt_souce_files/downsample_prelu.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/ppt_souce_files/downsample_prelu.pptx -------------------------------------------------------------------------------- /ppt_souce_files/downsample_relu.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/ppt_souce_files/downsample_relu.pptx -------------------------------------------------------------------------------- /ppt_souce_files/improving_repvgg.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/ppt_souce_files/improving_repvgg.pptx -------------------------------------------------------------------------------- /ppt_souce_files/mobilenetv2_1.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/ppt_souce_files/mobilenetv2_1.pptx -------------------------------------------------------------------------------- /ppt_souce_files/prune.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/ppt_souce_files/prune.pptx -------------------------------------------------------------------------------- /ppt_souce_files/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /ppt_souce_files/repvgg_vs_resnet.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/ppt_souce_files/repvgg_vs_resnet.pptx -------------------------------------------------------------------------------- /ppt_souce_files/resnet.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/ppt_souce_files/resnet.pptx -------------------------------------------------------------------------------- /ppt_souce_files/resnet_vs_repvgg.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxmeng/RMNet/4754b8e8a0c8b0f6e4e0c8c77f6cde47712c8feb/ppt_souce_files/resnet_vs_repvgg.pptx -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torchvision.transforms as transforms 18 | import torchvision.datasets as datasets 19 | from torch.optim.lr_scheduler import CosineAnnealingLR 20 | from utils import AverageMeter, accuracy, 
ProgressMeter
21 | 
22 | import models
23 | 
24 | IMAGENET_TRAINSET_SIZE = 1281167
25 | 
26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
27 | parser.add_argument('data', metavar='DIR',
28 |                     help='path to dataset')
29 | parser.add_argument('-a', '--arch', metavar='ARCH', default='RepVGG-A0')
30 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
31 |                     help='number of data loading workers (default: 8)')
32 | parser.add_argument('--epochs', default=120, type=int, metavar='N',
33 |                     help='number of total epochs to run')
34 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
35 |                     help='manual epoch number (useful on restarts)')
36 | parser.add_argument('-b', '--batch-size', default=256, type=int,
37 |                     metavar='N',
38 |                     help='mini-batch size (default: 256), this is the total '
39 |                          'batch size of all GPUs on the current node when '
40 |                          'using Data Parallel or Distributed Data Parallel')
41 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
42 |                     metavar='LR', help='initial learning rate', dest='lr')
43 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
44 |                     help='momentum')
45 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
46 |                     metavar='W', help='weight decay (default: 1e-4)',
47 |                     dest='weight_decay')
48 | parser.add_argument('-p', '--print-freq', default=10, type=int,
49 |                     metavar='N', help='print frequency (default: 10)')
50 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
51 |                     help='path to latest checkpoint (default: none)')
52 | parser.add_argument('-e', '--evaluate', type=str, default=None,
53 |                     help='path to a checkpoint to evaluate on the validation set')
54 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
55 |                     help='use pre-trained model')
56 | parser.add_argument('--world-size', default=-1, type=int,
57 |                     help='number of nodes for distributed training')
58 | parser.add_argument('--rank', default=-1, type=int,
59 |                     help='node rank for distributed training')
60 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
61 |                     help='url used to set up distributed training')
62 | parser.add_argument('--dist-backend', default='nccl', type=str,
63 |                     help='distributed backend')
64 | parser.add_argument('--seed', default=None, type=int,
65 |                     help='seed for initializing training.')
66 | parser.add_argument('--gpu', default=None, type=int,
67 |                     help='GPU id to use.')
68 | parser.add_argument('--multiprocessing-distributed', action='store_true',
69 |                     help='Use multi-processing distributed training to launch '
70 |                          'N processes per node, which has N GPUs. This is the '
71 |                          'fastest way to use PyTorch for either single node or '
72 |                          'multi node data parallel training')
73 | best_acc1 = 0
74 | 
75 | def sgd_optimizer(model, lr, momentum, weight_decay):
76 |     params = []
77 |     for key, value in model.named_parameters():
78 |         if not value.requires_grad:
79 |             continue
80 |         apply_weight_decay = weight_decay
81 |         apply_lr = lr
82 |         if 'bias' in key or 'bn' in key:
83 |             apply_weight_decay = 0  # no weight decay on BN affine parameters and biases
84 |             print('set weight decay=0 for {}'.format(key))
85 |         if 'bias' in key:
86 |             apply_lr = 2 * lr  # Just a Caffe-style common practice. Made no difference.
87 | params += [{'params': [value], 'lr': apply_lr, 'weight_decay': apply_weight_decay}] 88 | optimizer = torch.optim.SGD(params, lr, momentum=momentum) 89 | return optimizer 90 | 91 | def main(): 92 | args = parser.parse_args() 93 | 94 | if args.seed is not None: 95 | random.seed(args.seed) 96 | torch.manual_seed(args.seed) 97 | cudnn.deterministic = True 98 | warnings.warn('You have chosen to seed training. ' 99 | 'This will turn on the CUDNN deterministic setting, ' 100 | 'which can slow down your training considerably! ' 101 | 'You may see unexpected behavior when restarting ' 102 | 'from checkpoints.') 103 | 104 | if args.gpu is not None: 105 | warnings.warn('You have chosen a specific GPU. This will completely ' 106 | 'disable data parallelism.') 107 | 108 | if args.dist_url == "env://" and args.world_size == -1: 109 | args.world_size = int(os.environ["WORLD_SIZE"]) 110 | 111 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 112 | 113 | ngpus_per_node = torch.cuda.device_count() 114 | if args.multiprocessing_distributed: 115 | # Since we have ngpus_per_node processes per node, the total world_size 116 | # needs to be adjusted accordingly 117 | args.world_size = ngpus_per_node * args.world_size 118 | # Use torch.multiprocessing.spawn to launch distributed processes: the 119 | # main_worker process function 120 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 121 | else: 122 | # Simply call main_worker function 123 | main_worker(args.gpu, ngpus_per_node, args) 124 | 125 | 126 | def main_worker(gpu, ngpus_per_node, args): 127 | global best_acc1 128 | args.gpu = gpu 129 | 130 | if args.gpu is not None: 131 | print("Use GPU: {} for training".format(args.gpu)) 132 | 133 | if args.distributed: 134 | if args.dist_url == "env://" and args.rank == -1: 135 | args.rank = int(os.environ["RANK"]) 136 | if args.multiprocessing_distributed: 137 | # For multiprocessing distributed training, rank needs to be the 138 | # global rank among all the processes 139 | args.rank = args.rank * ngpus_per_node + gpu 140 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 141 | world_size=args.world_size, rank=args.rank) 142 | 143 | model = models.__dict__[args.arch]() 144 | 145 | if not torch.cuda.is_available(): 146 | print('using CPU, this will be slow') 147 | elif args.distributed: 148 | # For multiprocessing distributed, DistributedDataParallel constructor 149 | # should always set the single device scope, otherwise, 150 | # DistributedDataParallel will use all available devices. 
151 | if args.gpu is not None: 152 | torch.cuda.set_device(args.gpu) 153 | model.cuda(args.gpu) 154 | # When using a single GPU per process and per 155 | # DistributedDataParallel, we need to divide the batch size 156 | # ourselves based on the total number of GPUs we have 157 | args.batch_size = int(args.batch_size / ngpus_per_node) 158 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 159 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 160 | else: 161 | model.cuda() 162 | # DistributedDataParallel will divide and allocate batch_size to all 163 | # available GPUs if device_ids are not set 164 | model = torch.nn.parallel.DistributedDataParallel(model) 165 | elif args.gpu is not None: 166 | torch.cuda.set_device(args.gpu) 167 | model = model.cuda(args.gpu) 168 | else: 169 | # DataParallel will divide and allocate batch_size to all available GPUs 170 | model = torch.nn.DataParallel(model).cuda() 171 | 172 | # define loss function (criterion) and optimizer 173 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 174 | 175 | # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 176 | optimizer = sgd_optimizer(model, args.lr, args.momentum, args.weight_decay) # better for repvgg 177 | 178 | lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node) 179 | 180 | # optionally resume from a checkpoint 181 | if args.resume: 182 | if os.path.isfile(args.resume): 183 | print("=> loading checkpoint '{}'".format(args.resume)) 184 | if args.gpu is None: 185 | checkpoint = torch.load(args.resume) 186 | else: 187 | # Map model to be loaded to specified single gpu. 188 | loc = 'cuda:{}'.format(args.gpu) 189 | checkpoint = torch.load(args.resume, map_location=loc) 190 | args.start_epoch = checkpoint['epoch'] 191 | best_acc1 = checkpoint['best_acc1'] 192 | if args.gpu is not None: 193 | # best_acc1 may be from a checkpoint from a different GPU 194 | best_acc1 = best_acc1.to(args.gpu) 195 | model.load_state_dict(checkpoint['state_dict']) 196 | optimizer.load_state_dict(checkpoint['optimizer']) 197 | lr_scheduler.load_state_dict(checkpoint['scheduler']) 198 | print("=> loaded checkpoint '{}' (epoch {})" 199 | .format(args.resume, checkpoint['epoch'])) 200 | else: 201 | print("=> no checkpoint found at '{}'".format(args.resume)) 202 | 203 | cudnn.benchmark = True 204 | 205 | # Data loading code 206 | traindir = os.path.join(args.data, 'train') 207 | valdir = os.path.join(args.data, 'val') 208 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 209 | std=[0.229, 0.224, 0.225]) 210 | 211 | train_dataset = datasets.ImageFolder( 212 | traindir, 213 | transforms.Compose([ 214 | transforms.RandomResizedCrop(224), 215 | transforms.RandomHorizontalFlip(), 216 | transforms.ToTensor(), 217 | normalize, 218 | ])) 219 | 220 | if args.distributed: 221 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 222 | else: 223 | train_sampler = None 224 | 225 | train_loader = torch.utils.data.DataLoader( 226 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 227 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 228 | 229 | val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ 230 | transforms.Resize(256), 231 | transforms.CenterCrop(224), 232 | transforms.ToTensor(), 233 | normalize, 234 | ])) 235 | val_loader = torch.utils.data.DataLoader( 236 | 
val_dataset, 237 | batch_size=args.batch_size, shuffle=False, 238 | num_workers=args.workers, pin_memory=True) 239 | 240 | if args.evaluate is not None: 241 | checkpoint = torch.load(args.evaluate) 242 | model.load_state_dict(checkpoint["state_dict"]) 243 | model = model.module.cpu().deploy() 244 | model = torch.nn.DataParallel(model).cuda() 245 | print(model) 246 | validate(val_loader, model, criterion, args) 247 | return 248 | 249 | for epoch in range(args.start_epoch, args.epochs): 250 | if args.distributed: 251 | train_sampler.set_epoch(epoch) 252 | 253 | # train for one epoch 254 | train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler) 255 | 256 | # evaluate on validation set 257 | acc1 = validate(val_loader, model, criterion, args) 258 | 259 | # remember best acc@1 and save checkpoint 260 | is_best = acc1 > best_acc1 261 | best_acc1 = max(acc1, best_acc1) 262 | 263 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 264 | and args.rank % ngpus_per_node == 0): 265 | save_checkpoint({ 266 | 'epoch': epoch + 1, 267 | 'arch': args.arch, 268 | 'state_dict': model.state_dict(), 269 | 'best_acc1': best_acc1, 270 | 'optimizer': optimizer.state_dict(), 271 | 'scheduler': lr_scheduler.state_dict(), 272 | }, is_best) 273 | 274 | 275 | def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler): 276 | batch_time = AverageMeter('Time', ':6.3f') 277 | data_time = AverageMeter('Data', ':6.3f') 278 | losses = AverageMeter('Loss', ':.4e') 279 | top1 = AverageMeter('Acc@1', ':6.2f') 280 | top5 = AverageMeter('Acc@5', ':6.2f') 281 | progress = ProgressMeter( 282 | len(train_loader), 283 | [batch_time, data_time, losses, top1, top5, ], 284 | prefix="Epoch: [{}]".format(epoch)) 285 | 286 | # switch to train mode 287 | model.train() 288 | 289 | end = time.time() 290 | for i, (images, target) in enumerate(train_loader): 291 | # measure data loading time 292 | data_time.update(time.time() - end) 293 | 294 | if args.gpu is not None: 295 | images = images.cuda(args.gpu, non_blocking=True) 296 | if torch.cuda.is_available(): 297 | target = target.cuda(args.gpu, non_blocking=True) 298 | 299 | # compute output 300 | output = model(images) 301 | loss = criterion(output, target) 302 | 303 | # measure accuracy and record loss 304 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 305 | losses.update(loss.item(), images.size(0)) 306 | top1.update(acc1[0], images.size(0)) 307 | top5.update(acc5[0], images.size(0)) 308 | 309 | # compute gradient and do SGD step 310 | optimizer.zero_grad() 311 | loss.backward() 312 | optimizer.step() 313 | 314 | # measure elapsed time 315 | batch_time.update(time.time() - end) 316 | end = time.time() 317 | 318 | lr_scheduler.step() 319 | 320 | if i % args.print_freq == 0: 321 | progress.display(i) 322 | if i % 1000 == 0: 323 | print('cur lr: ', lr_scheduler.get_lr()[0]) 324 | 325 | 326 | def validate(val_loader, model, criterion, args): 327 | batch_time = AverageMeter('Time', ':6.3f') 328 | losses = AverageMeter('Loss', ':.4e') 329 | top1 = AverageMeter('Acc@1', ':6.2f') 330 | top5 = AverageMeter('Acc@5', ':6.2f') 331 | progress = ProgressMeter( 332 | len(val_loader), 333 | [batch_time, losses, top1, top5], 334 | prefix='Test: ') 335 | 336 | # switch to evaluate mode 337 | model.eval() 338 | 339 | with torch.no_grad(): 340 | end = time.time() 341 | for i, (images, target) in enumerate(val_loader): 342 | if args.gpu is not None: 343 | images = images.cuda(args.gpu, non_blocking=True) 344 | if 
torch.cuda.is_available():
345 |                 target = target.cuda(args.gpu, non_blocking=True)
346 | 
347 |             # compute output
348 |             output = model(images)
349 |             loss = criterion(output, target)
350 | 
351 |             # measure accuracy and record loss
352 |             acc1, acc5 = accuracy(output, target, topk=(1, 5))
353 |             losses.update(loss.item(), images.size(0))
354 |             top1.update(acc1[0], images.size(0))
355 |             top5.update(acc5[0], images.size(0))
356 | 
357 |             # measure elapsed time
358 |             batch_time.update(time.time() - end)
359 |             end = time.time()
360 | 
361 |             if i % args.print_freq == 0:
362 |                 progress.display(i)
363 | 
364 |         # TODO: this should also be done with the ProgressMeter
365 |         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
366 |               .format(top1=top1, top5=top5))
367 | 
368 |     return top1.avg
369 | 
370 | 
371 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
372 |     torch.save(state, filename)
373 |     if is_best:
374 |         shutil.copyfile(filename, 'model_best.pth.tar')
375 | 
376 | 
377 | if __name__ == '__main__':
378 |     main()
379 | 
380 | 
--------------------------------------------------------------------------------
/train_pruning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | import os
9 | import argparse
10 | import models
11 | import thop
12 | 
13 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
14 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
15 | parser.add_argument('--sr', default=0, type=float, help='sparsity regularization factor for pruning')
16 | parser.add_argument('--threshold', default=0, type=float, help='pruning threshold')
17 | parser.add_argument('--finetune', type=str, help='path to a checkpoint to prune and finetune')
18 | parser.add_argument('--debn', action='store_true',default=False)
19 | parser.add_argument('--eval', type=str, help='path to a checkpoint to evaluate')
20 | 
21 | args = parser.parse_args()
22 | 
23 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
24 | best_acc = 0  # best test accuracy
25 | start_epoch = 0  # start from epoch 0 or last checkpoint epoch
26 | 
27 | # Data
28 | print('==> Preparing data..')
29 | transform_train = transforms.Compose([
30 |     transforms.RandomCrop(32, padding=4),
31 |     transforms.RandomHorizontalFlip(),
32 |     transforms.ToTensor(),
33 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
34 | ])
35 | 
36 | transform_test = transforms.Compose([
37 |     transforms.ToTensor(),
38 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
39 | ])
40 | 
41 | trainset = torchvision.datasets.CIFAR10(
42 |     root='/dev/shm', train=True, download=True, transform=transform_train)
43 | trainloader = torch.utils.data.DataLoader(
44 |     trainset, batch_size=128, shuffle=True, num_workers=2)
45 | 
46 | testset = torchvision.datasets.CIFAR10(
47 |     root='/dev/shm', train=False, download=True, transform=transform_test)
48 | testloader = torch.utils.data.DataLoader(
49 |     testset, batch_size=100, shuffle=False, num_workers=2)
50 | 
51 | # Model
52 | print('==> Building model..')
53 | net = models.rmnet_pruning_18(10).to(device)
54 | if args.sr*args.threshold==0:  # sparsity training disabled, so keep the pruning mask fixed
55 |     net.fix_mask()
56 | if args.finetune or args.eval:
57 |     if args.finetune:
58 |         ckpt=torch.load(args.finetune)
59 |     else:
60 |         ckpt=torch.load(args.eval)
61 |     net.load_state_dict(ckpt)
62 |     net=net.cpu().prune(not args.debn).cuda()
63 |     print(net)
64 | 
65 | criterion = 
nn.CrossEntropyLoss() 66 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 67 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) 68 | 69 | 70 | # Training 71 | def train(epoch): 72 | net.train() 73 | train_loss = 0 74 | correct = 0 75 | total = 0 76 | for batch_idx, (inputs, targets) in enumerate(trainloader): 77 | inputs, targets = inputs.to(device), targets.to(device) 78 | optimizer.zero_grad() 79 | outputs = net(inputs) 80 | loss = criterion(outputs, targets) 81 | loss.backward() 82 | if args.sr*args.threshold>0 and not args.finetune: 83 | net.update_mask(args.sr,args.threshold) 84 | optimizer.step() 85 | 86 | train_loss += loss.item() 87 | _, predicted = outputs.max(1) 88 | total += targets.size(0) 89 | correct += predicted.eq(targets).sum().item() 90 | 91 | 92 | def test(epoch): 93 | global best_acc 94 | net.eval() 95 | test_loss = 0 96 | correct = 0 97 | total = 0 98 | with torch.no_grad(): 99 | for batch_idx, (inputs, targets) in enumerate(testloader): 100 | inputs, targets = inputs.to(device), targets.to(device) 101 | outputs = net(inputs) 102 | loss = criterion(outputs, targets) 103 | 104 | test_loss += loss.item() 105 | _, predicted = outputs.max(1) 106 | total += targets.size(0) 107 | correct += predicted.eq(targets).sum().item() 108 | 109 | print('Epoch: %d Acc: %.3f%%' %(epoch, 100.*correct/total)) 110 | 111 | # Save checkpoint. 112 | acc = 100.*correct/total 113 | if acc > best_acc: 114 | print('Saving..') 115 | if args.finetune: 116 | save_dir= args.finetune.replace('ckpt','finetune_lr%f'%args.lr) 117 | else: 118 | save_dir='./lr_%f_sr_%f_thres_%f'%( args.lr, args.sr,args.threshold) 119 | if not os.path.isdir(save_dir): 120 | os.mkdir(save_dir) 121 | save_dir+='/ckpt.pth' 122 | torch.save(net.state_dict(), save_dir) 123 | best_acc = acc 124 | 125 | if args.eval: 126 | best_acc=100 127 | test(0) 128 | flops,params=thop.profile(net,(torch.randn(1,3,224,224).to(device),)) 129 | print('flops:%.2fM,\tparams:%.2fM'%(flops/1e6,params/1e6)) 130 | else: 131 | for epoch in range(start_epoch, start_epoch+200): 132 | train(epoch) 133 | test(epoch) 134 | scheduler.step() 135 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class AverageMeter(object): 4 | """Computes and stores the average and current value""" 5 | def __init__(self, name, fmt=':f'): 6 | self.name = name 7 | self.fmt = fmt 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | def __str__(self): 23 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 24 | return fmtstr.format(**self.__dict__) 25 | 26 | 27 | class ProgressMeter(object): 28 | def __init__(self, num_batches, meters, prefix=""): 29 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 30 | self.meters = meters 31 | self.prefix = prefix 32 | 33 | def display(self, batch): 34 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 35 | entries += [str(meter) for meter in self.meters] 36 | print('\t'.join(entries)) 37 | 38 | def _get_batch_fmtstr(self, num_batches): 39 | num_digits = len(str(num_batches // 1)) 40 | fmt = '{:' + str(num_digits) + 'd}' 41 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 42 | 43 | 
44 | def accuracy(output, target, topk=(1,)): 45 | """Computes the accuracy over the k top predictions for the specified values of k""" 46 | with torch.no_grad(): 47 | maxk = max(topk) 48 | batch_size = target.size(0) 49 | 50 | _, pred = output.topk(maxk, 1, True, True) 51 | pred = pred.t() 52 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 53 | 54 | res = [] 55 | for k in topk: 56 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 57 | res.append(correct_k.mul_(100.0 / batch_size)) 58 | return res --------------------------------------------------------------------------------
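
Editor's note: a minimal, self-contained demo (not part of the repository) of the helpers in `utils.py` above, showing what `accuracy`, `AverageMeter`, and `ProgressMeter` produce on a toy batch. All names below are illustrative.

```python
import torch
from utils import AverageMeter, ProgressMeter, accuracy

logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.1, 3.0, 0.2]])  # batch of 2 samples, 3 classes
target = torch.tensor([0, 2])             # second sample is misclassified
top1, = accuracy(logits, target, topk=(1,))
print(top1)                               # tensor([50.]) -> 50% top-1 accuracy

meter = AverageMeter('Acc@1', ':6.2f')
meter.update(top1.item(), n=target.size(0))
progress = ProgressMeter(num_batches=1, meters=[meter], prefix='Demo: ')
progress.display(0)                       # e.g. "Demo: [0/1]  Acc@1  50.00 ( 50.00)"
```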