├── LICENSE
├── README.md
├── cifar_train_eval.py
├── doc
│   └── tensorboard.png
├── imagenet_dali_loader.py
├── imagenet_torch_loader.py
├── models
│   ├── __init__.py
│   ├── model_utils
│   │   ├── __init__.py
│   │   ├── bn_fuse.py
│   │   └── quant_dorefa.py
│   ├── resnet_cifar.py
│   ├── resnet_imagenet.py
│   └── test_fused_quant_model.py
└── utils
    ├── bar_show.py
    └── preprocess.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Jzz
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dorefa-net
2 | A PyTorch implementation of [DoReFa-Net](https://arxiv.org/abs/1606.06160). The code is inspired by [LaVieEnRoseSMZ](https://github.com/LaVieEnRoseSMZ/AutoBNN) and [zzzxxxttt](https://github.com/kuangliu/pytorch-cifar).
3 | 
4 | ## Requirements
5 | * python > 3.5
6 | * torch >= 1.1.0
7 | * torchvision >= 0.4.0
8 | * tb-nightly, future (for tensorboard)
9 | * nvidia-dali >= 0.12 (faster [dataloader](https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/index.html#))
10 | 
11 | ## Cifar-10 Accuracy
12 | 
13 | Quantized models are trained from scratch.
14 | 
15 | | Model | W_bit | A_bit | Acc |
16 | | :-: | :-: | :-: |:-: |
17 | | resnet-18 | 32 | 32 | 94.71% |
18 | | resnet-18 | 4 | 4 | 94.36% |
19 | | resnet-18 | 1 | 4 | 93.87% |
20 | 
21 | 
22 | ## ImageNet Accuracy
23 | 
24 | Quantized models are trained from scratch.
25 | 
26 | | Model | W_bit | A_bit | Top1 |Top5 |
27 | | :-: | :-: | :-: |:-: |:-: |
28 | | resnet-18 | 32 | 32 | 69.80% |89.32% |
29 | | resnet-18 | 4 | 4 | 66.60% |87.15% |
30 | 
31 | ## Usage
32 | Download the ImageNet dataset and move the validation images into labeled subfolders. To do this, you can use the following [script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
33 | - To train the model
34 | ```
35 | python3 cifar_train_eval.py
36 | python3 imagenet_torch_loader.py --multiprocessing-distributed  # or: python3 imagenet_dali_loader.py
37 | ```
38 | - To check the tensorboard log
39 | ```
40 | tensorboard --logdir='your_log_dir'
41 | ```
42 | 
43 | then navigate to http://localhost:6006.
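All of the training options are plain argparse flags, so a specific bit-width configuration can be picked from the command line. For example, to train the 4-bit CIFAR-10 model from the accuracy table above (the flag names below are the ones defined in `cifar_train_eval.py`):
```
python3 cifar_train_eval.py --Wbits 4 --Abits 4 --log_name resnet_4w4f_cifar --pretrain_dir resnet_4w4f_cifar
```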
44 | 
45 | - To test the quantized model and BN fusion
46 |   - convert to the quantized model for inference
47 |     ```
48 |     python3 test_fused_quant_model.py
49 |     ```
50 |   - test BN fusion on the float model
51 |     ```
52 |     python3 bn_fuse.py
53 |     ```
54 | Obviously, this fusion method is not suitable for quantized models. We will change the BN fusion in the future, following section 3.2.2 of the [paper](https://arxiv.org/pdf/1806.08342.pdf).
55 | 
56 | This BN-fusion benchmark is not rigorous, but it is enough to show the speedup qualitatively.
57 | 
58 | 
59 | | Model on CPU | before fuse | after fuse |
60 | | :-: | :-: | :-: |
61 | | resnet-18 | 0.74 s | 0.51 s |
62 | | resnet-34 | 1.41 s | 0.92 s |
63 | | resnet-50 | 1.96 s | 1.02 s |
64 | 
65 | 
66 | ## To do
67 | - [x] Train on imagenet2012
68 | - [x] Fold BN
69 | - [x] Test speedup from quantization and BN folding
70 | - [ ] Deploy models to embedded devices
71 | - [ ] ...
72 | 
--------------------------------------------------------------------------------
/cifar_train_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | from datetime import datetime
5 | 
6 | import torch
7 | import torch.optim as optim
8 | import torch.backends.cudnn as cudnn
9 | from torch.utils.tensorboard import SummaryWriter
10 | 
11 | cudnn.benchmark = True
12 | import torchvision
13 | 
14 | from models.resnet_cifar import *
15 | from utils.preprocess import *
16 | from utils.bar_show import progress_bar
17 | 
18 | # Training settings
19 | parser = argparse.ArgumentParser(description='dorefa-net implementation')
20 | 
21 | parser.add_argument('--root_dir', type=str, default='./')
22 | parser.add_argument('--data_dir', type=str, default='./data')
23 | parser.add_argument('--log_name', type=str, default='resnet_8w8f_cifar')
24 | parser.add_argument('--pretrain', action='store_true', default=False)
25 | parser.add_argument('--pretrain_dir', type=str, default='resnet_8w8f_cifar')
26 | 
27 | parser.add_argument('--cifar', type=int, default=10)
28 | parser.add_argument('--lr', type=float, default=0.1)
29 | parser.add_argument('--wd', type=float, default=1e-4)
30 | parser.add_argument('--train_batch_size', type=int, default=256)
31 | parser.add_argument('--eval_batch_size', type=int, default=100)
32 | parser.add_argument('--max_epochs', type=int, default=250)
33 | parser.add_argument('--log_interval', type=int, default=40)
34 | parser.add_argument('--num_workers', type=int, default=2)
35 | parser.add_argument('--Wbits', type=int, default=8)
36 | parser.add_argument('--Abits', type=int, default=8)
37 | 
38 | cfg = parser.parse_args()
39 | 
40 | best_acc = 0  # best test accuracy
41 | start_epoch = 0
42 | 
43 | cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name)
44 | cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.pretrain_dir)
45 | 
46 | os.makedirs(cfg.log_dir, exist_ok=True)
47 | os.makedirs(cfg.ckpt_dir, exist_ok=True)
48 | 
49 | def main():
50 |     if cfg.cifar == 10:
51 |         print('training CIFAR-10 !')
52 |         dataset = torchvision.datasets.CIFAR10
53 |     elif cfg.cifar == 100:
54 |         print('training CIFAR-100 !')
55 |         dataset = torchvision.datasets.CIFAR100
56 |     else:
57 |         assert False, 'dataset unknown !'
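    # Note: cifar_transform below comes from utils/preprocess.py (truncated at the
    # end of this listing); it is assumed to return a torchvision
    # transforms.Compose -- typically random crop + horizontal flip + normalization
    # for training, and plain ToTensor + normalization for evaluation.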
58 | 59 | print('===> Preparing data ..') 60 | train_dataset = dataset(root=cfg.data_dir, train=True, download=True, 61 | transform=cifar_transform(is_training=True)) 62 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.train_batch_size, shuffle=True, 63 | num_workers=cfg.num_workers) 64 | 65 | eval_dataset = dataset(root=cfg.data_dir, train=False, download=True, 66 | transform=cifar_transform(is_training=False)) 67 | eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=cfg.eval_batch_size, shuffle=False, 68 | num_workers=cfg.num_workers) 69 | 70 | print('===> Building ResNet..') 71 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 72 | model = ResNet18(wbit=cfg.Wbits,abit=cfg.Abits).to(device) 73 | 74 | if device == 'cuda': 75 | model = torch.nn.DataParallel(model) 76 | cudnn.benchmark = True 77 | 78 | optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=cfg.wd) 79 | # optimizer = torch.optim.Adam(model.parameters(),lr=cfg.lr,weight_decay=cfg.wd) 80 | lr_schedu = optim.lr_scheduler.MultiStepLR(optimizer, [90, 150, 200], gamma=0.1) 81 | criterion = torch.nn.CrossEntropyLoss().cuda() 82 | summary_writer = SummaryWriter(cfg.log_dir) 83 | 84 | if cfg.pretrain: 85 | ckpt = torch.load(os.path.join(cfg.ckpt_dir, f'checkpoint.t7')) 86 | model.load_state_dict(ckpt['model_state_dict']) 87 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 88 | start_epoch = ckpt['epoch'] 89 | print('===> Load last checkpoint data') 90 | else: 91 | start_epoch = 0 92 | print('===> Start from scratch') 93 | 94 | 95 | def train(epoch): 96 | print('\nEpoch: %d' % epoch) 97 | model.train() 98 | train_loss, correct, total = 0, 0 ,0 99 | 100 | for batch_idx, (inputs, targets) in enumerate(train_loader): 101 | inputs, targets = inputs.to('cuda'), targets.to('cuda') 102 | optimizer.zero_grad() 103 | outputs = model(inputs) 104 | loss = criterion(outputs, targets) 105 | loss.backward() 106 | optimizer.step() 107 | 108 | train_loss += loss.item() 109 | _, predicted = outputs.max(1) 110 | total += targets.size(0) 111 | correct += predicted.eq(targets).sum().item() 112 | 113 | progress_bar(batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 114 | % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 115 | 116 | if batch_idx % cfg.log_interval == 0: #every log_interval mini_batches... 117 | summary_writer.add_scalar('Loss/train', train_loss / (batch_idx + 1), epoch * len(train_loader) + batch_idx) 118 | summary_writer.add_scalar('Accuracy/train', 100. 
* correct / total, epoch * len(train_loader) + batch_idx) 119 | summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], epoch * len(train_loader) + batch_idx) 120 | # for tag, value in model.named_parameters(): 121 | # tag = tag.replace('.', '/') 122 | # summary_writer.add_histogram(tag, value.detach(), global_step=epoch * len(train_loader) + batch_idx) 123 | # summary_writer.add_histogram(tag + '/grad', value.grad.detach(), global_step=epoch * len(train_loader) + batch_idx) 124 | 125 | 126 | 127 | 128 | def test(epoch): 129 | # pass 130 | global best_acc 131 | model.eval() 132 | 133 | test_loss, correct, total = 0, 0, 0 134 | with torch.no_grad(): 135 | for batch_idx, (inputs, targets) in enumerate(eval_loader): 136 | inputs, targets = inputs.to('cuda'), targets.to('cuda') 137 | outputs = model(inputs) 138 | loss = criterion(outputs, targets) 139 | 140 | test_loss += loss.item() 141 | _, predicted = outputs.max(1) 142 | total += targets.size(0) 143 | correct += predicted.eq(targets).sum().item() 144 | 145 | progress_bar(batch_idx, len(eval_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 146 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 147 | 148 | if batch_idx % cfg.log_interval == 0: # every log_interval mini_batches... 149 | summary_writer.add_scalar('Loss/test', test_loss / (batch_idx + 1), epoch * len(train_loader) + batch_idx) 150 | summary_writer.add_scalar('Accuracy/test', 100. * correct / total, epoch * len(train_loader) + batch_idx) 151 | 152 | acc = 100. * correct / total 153 | if acc > best_acc: 154 | print('Saving..') 155 | state = { 156 | 'model_state_dict': model.state_dict(), 157 | 'optimizer_state_dict': optimizer.state_dict(), 158 | 'acc': acc, 159 | 'epoch': epoch, 160 | } 161 | torch.save(state, os.path.join(cfg.ckpt_dir, f'checkpoint.t7')) 162 | best_acc = acc 163 | 164 | for epoch in range(start_epoch, cfg.max_epochs): 165 | train(epoch) 166 | test(epoch) 167 | lr_schedu.step(epoch) 168 | summary_writer.close() 169 | 170 | 171 | if __name__ == '__main__': 172 | main() 173 | 174 | -------------------------------------------------------------------------------- /doc/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jzz24/pytorch_quantization/0c2d93c8ce4f85dd2c34ea6f36c58d14db21bf8e/doc/tensorboard.png -------------------------------------------------------------------------------- /imagenet_dali_loader.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import warnings 4 | from datetime import datetime 5 | 6 | import torch 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.utils.data.distributed 11 | import torchvision.models as models 12 | 13 | cudnn.benchmark = True 14 | 15 | from models.resnet_imagenet import * 16 | from utils.preprocess import * 17 | from utils.bar_show import * 18 | import warnings 19 | warnings.filterwarnings("ignore") 20 | 21 | # Training settings 22 | parser = argparse.ArgumentParser(description='dorefa-net imagenet2012 implementation') 23 | 24 | parser.add_argument('--root_dir', type=str, default='./') 25 | parser.add_argument('--data_dir', type=str, default='/imagenet2012_datasets') 26 | parser.add_argument('--log_name', type=str, default='resnet_imagenet_float') 27 | parser.add_argument('--pretrain', action='store_true', default=False) 28 | 
parser.add_argument('--pretrain_dir', type=str, default='resnet_float') 29 | 30 | parser.add_argument('--lr', type=float, default=0.1) 31 | parser.add_argument('--wd', type=float, default=1e-4) 32 | parser.add_argument('--train_batch_size', type=int, default=256) 33 | parser.add_argument('--eval_batch_size', type=int, default=100) 34 | parser.add_argument('--max_epochs', type=int, default=90) 35 | parser.add_argument('--log_interval', type=int, default=40) 36 | parser.add_argument('--num_workers', type=int, default=8) 37 | parser.add_argument('--Wbits', type=int, default=8) 38 | parser.add_argument('--Abits', type=int, default=8) 39 | 40 | cfg = parser.parse_args() 41 | 42 | best_acc = 0 # best test accuracy 43 | start_epoch = 0 44 | TOTAL_TRAIN_PICS = 1271171 45 | TOTAL_EVAL_PICS = 50000 46 | 47 | cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name) 48 | cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.pretrain_dir) 49 | 50 | os.makedirs(cfg.log_dir, exist_ok=True) 51 | os.makedirs(cfg.ckpt_dir, exist_ok=True) 52 | 53 | 54 | def main(): 55 | 56 | # nvidia dali dataloader 57 | train_loader = get_imagenet_iter_dali(type='train', image_dir=cfg.data_dir, batch_size=cfg.train_batch_size, 58 | num_threads=16, crop=224, device_id=0, num_gpus=2) 59 | eval_loader = get_imagenet_iter_dali(type='val', image_dir=cfg.data_dir, batch_size=cfg.eval_batch_size, 60 | num_threads=8, crop=224, device_id=0, num_gpus=2) 61 | 62 | print('===> Building ResNet..') 63 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 64 | model = resnet18(wbit=cfg.Wbits, abit=cfg.Abits, pretrained=False) 65 | 66 | if device == 'cuda': 67 | model = torch.nn.DataParallel(model) 68 | cudnn.benchmark = True 69 | 70 | optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=cfg.wd) 71 | lr_schedu = optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 90], gamma=0.1) 72 | criterion = torch.nn.CrossEntropyLoss().cuda() 73 | summary_writer = SummaryWriter(cfg.log_dir) 74 | 75 | if cfg.pretrain: 76 | ckpt = torch.load(os.path.join(cfg.ckpt_dir, f'checkpoint.t7')) 77 | model.load_state_dict(ckpt['model_state_dict']) 78 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 79 | start_epoch = ckpt['epoch'] 80 | print('===> Load last checkpoint data') 81 | else: 82 | start_epoch = 0 83 | print('===> Start from scratch') 84 | 85 | for epoch in range(start_epoch, cfg.max_epochs): 86 | train(epoch, model, train_loader, criterion, optimizer, summary_writer) 87 | test(epoch, model, eval_loader, criterion, optimizer, summary_writer) 88 | lr_schedu.step(epoch) 89 | summary_writer.close() 90 | 91 | 92 | def train(epoch, model, train_loader, criterion, optimizer, summary_writer): 93 | 94 | print('\nEpoch: %d' % epoch) 95 | 96 | batch_time = AverageMeter('Time', ':6.3f') 97 | data_time = AverageMeter('Data', ':6.3f') 98 | losses = AverageMeter('Loss', ':.4e') 99 | top1 = AverageMeter('Acc@1', ':6.2f') 100 | top5 = AverageMeter('Acc@5', ':6.2f') 101 | 102 | # switch to train mode 103 | model.train() 104 | 105 | end = time.time() 106 | for batch_idx, data in enumerate(train_loader): 107 | 108 | #measure data loading time 109 | data_time.update(time.time() - end) 110 | 111 | inputs = data[0]["data"].cuda(non_blocking=True) 112 | targets = data[0]["label"].squeeze().long().cuda(non_blocking=True) 113 | 114 | #compute output 115 | outputs = model(inputs) 116 | loss = criterion(outputs, targets) 117 | 118 | #measure acc and record loss 119 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) 120 | 
losses.update(loss.item(), inputs.size(0)) 121 | top1.update(acc1[0], inputs.size(0)) 122 | top5.update(acc5[0], inputs.size(0)) 123 | 124 | #compute gradient and do SGD step 125 | optimizer.zero_grad() 126 | loss.backward() 127 | optimizer.step() 128 | 129 | #measure elapsed time 130 | batch_time.update(time.time() - end) 131 | end = time.time() 132 | 133 | num_batch_per_epoch = TOTAL_TRAIN_PICS // inputs.size(0) + 1 134 | progress_bar(batch_idx, num_batch_per_epoch, 'Loss: %.3f | Acc1: %.3f%% Acc5: %.3f%% ' 135 | % (losses.avg, top1.avg, top5.avg)) 136 | 137 | if batch_idx % cfg.log_interval == 0: #every log_interval mini_batches... 138 | summary_writer.add_scalar('Loss/train', losses.avg, epoch * num_batch_per_epoch + batch_idx) 139 | summary_writer.add_scalar('Accuracy/train', top1.avg, epoch * num_batch_per_epoch + batch_idx) 140 | summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], epoch * num_batch_per_epoch + batch_idx) 141 | # for tag, value in model.named_parameters(): 142 | # tag = tag.replace('.', '/') 143 | # summary_writer.add_histogram(tag, value.detach(), global_step=epoch * len(train_loader) + batch_idx) 144 | # summary_writer.add_histogram(tag + '/grad', value.grad.detach(), global_step=epoch * len(train_loader) + batch_idx) 145 | 146 | 147 | def test(epoch, model, eval_loader, criterion, optimizer, summary_writer): 148 | # pass 149 | global best_acc 150 | batch_time = AverageMeter('Time', ':6.3f') 151 | losses = AverageMeter('Loss', ':.4e') 152 | top1 = AverageMeter('Acc@1', ':6.2f') 153 | top5 = AverageMeter('Acc@5', ':6.2f') 154 | 155 | # switch to evaluate mode 156 | model.eval() 157 | 158 | with torch.no_grad(): 159 | end = time.time() 160 | for batch_idx, data in enumerate(eval_loader): 161 | 162 | inputs = data[0]["data"].cuda(non_blocking=True) 163 | targets = data[0]["label"].squeeze().long().cuda(non_blocking=True) 164 | 165 | #compute output 166 | outputs = model(inputs) 167 | loss = criterion(outputs, targets) 168 | 169 | #measure acc and record loss 170 | acc1, acc5 = accuracy(outputs, targets, topk=(1,5)) 171 | losses.update(loss.item(), inputs.size(0)) 172 | top1.update(acc1[0], inputs.size(0)) 173 | top5.update(acc5[0], inputs.size(0)) 174 | 175 | # measure elapsed time 176 | batch_time.update(time.time() - end) 177 | end = time.time() 178 | 179 | num_batch_per_epoch = TOTAL_EVAL_PICS // inputs.size(0) 180 | progress_bar(batch_idx, num_batch_per_epoch, 'Loss: %.3f | Acc1: %.3f%% Acc5: %.3f%% ' 181 | % (losses.avg, top1.avg, top5.avg)) 182 | 183 | if batch_idx % cfg.log_interval == 0: # every log_interval mini_batches... 
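                    # The global_step written below counts mini-batches seen so far:
                    # (completed epochs) * (batches per epoch) + current batch index.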
184 | summary_writer.add_scalar('Loss/test', losses.avg, epoch * num_batch_per_epoch + batch_idx) 185 | summary_writer.add_scalar('Accuracy/test', top1.avg, epoch * num_batch_per_epoch + batch_idx) 186 | 187 | acc = top1.avg 188 | if acc > best_acc: 189 | print('Saving..') 190 | state = { 191 | 'model_state_dict': model.state_dict(), 192 | 'optimizer_state_dict': optimizer.state_dict(), 193 | 'acc': acc, 194 | 'epoch': epoch, 195 | } 196 | torch.save(state, os.path.join(cfg.ckpt_dir, f'checkpoint.t7')) 197 | best_acc = acc 198 | 199 | if __name__ == '__main__': 200 | main() 201 | 202 | -------------------------------------------------------------------------------- /imagenet_torch_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy 4 | import argparse 5 | import warnings 6 | from datetime import datetime 7 | 8 | import torch 9 | import torchvision.datasets as datasets 10 | import torch.optim as optim 11 | import torch.backends.cudnn as cudnn 12 | from torch.utils.tensorboard import SummaryWriter 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | import torch.utils.data.distributed 16 | import torchvision.models as models 17 | 18 | 19 | cudnn.benchmark = True 20 | 21 | from models.resnet_imagenet import * 22 | 23 | 24 | 25 | from utils.preprocess import * 26 | from utils.bar_show import * 27 | import warnings 28 | warnings.filterwarnings("ignore") 29 | 30 | # Training settings 31 | parser = argparse.ArgumentParser(description='dorefa-net imagenet2012 implementation') 32 | 33 | parser.add_argument('--root_dir', type=str, default='./') 34 | parser.add_argument('--data_dir', type=str, default='/imagenet2012_datasets') 35 | parser.add_argument('--log_name', type=str, default='resnet_imagenet_4w4f') 36 | parser.add_argument('--pretrain', action='store_true', default=False) 37 | parser.add_argument('--pretrain_dir', type=str, default='resnet_4w4f') 38 | 39 | 40 | parser.add_argument('--lr', type=float, default=0.1) 41 | parser.add_argument('--wd', type=float, default=1e-4) 42 | parser.add_argument('--train_batch_size', type=int, default=256) 43 | parser.add_argument('--eval_batch_size', type=int, default=100) 44 | parser.add_argument('--max_epochs', type=int, default=90) 45 | parser.add_argument('--log_interval', type=int, default=40) 46 | parser.add_argument('--num_workers', type=int, default=6) 47 | parser.add_argument('--Wbits', type=int, default=4) 48 | parser.add_argument('--Abits', type=int, default=4) 49 | 50 | parser.add_argument('--world-size', default=1, type=int, 51 | help='number of nodes for distributed training') 52 | parser.add_argument('--rank', default=0, type=int, 53 | help='node rank for distributed training') 54 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:2345', type=str, 55 | help='url used to set up distributed training') 56 | parser.add_argument('--dist-backend', default='nccl', type=str, 57 | help='distributed backend') 58 | parser.add_argument('--seed', default=None, type=int, 59 | help='seed for initializing training. ') 60 | parser.add_argument('--gpu', default=None, type=int, 61 | help='GPU id to use.') 62 | parser.add_argument('--multiprocessing-distributed', action='store_true', 63 | help='Use multi-processing distributed training to launch ' 64 | 'N processes per node, which has N GPUs. 
This is the ' 65 | 'fastest way to use PyTorch for either single node or ' 66 | 'multi node data parallel training') 67 | 68 | 69 | cfg = parser.parse_args() 70 | 71 | 72 | best_acc = 0 # best test accuracy 73 | start_epoch = 0 74 | 75 | cfg.log_dir = os.path.join(cfg.root_dir, 'logs', cfg.log_name) 76 | cfg.ckpt_dir = os.path.join(cfg.root_dir, 'ckpt', cfg.pretrain_dir) 77 | 78 | os.makedirs(cfg.log_dir, exist_ok=True) 79 | os.makedirs(cfg.ckpt_dir, exist_ok=True) 80 | 81 | def main(): 82 | 83 | if cfg.gpu is not None: 84 | warnings.warn('You have chosen a specific GPU. This will completely ' 85 | 'disable data parallelism.') 86 | 87 | if cfg.dist_url == "env://" and cfg.world_size == -1: 88 | cfg.world_size = int(os.environ["WORLD_SIZE"]) 89 | 90 | cfg.distributed = cfg.world_size > 1 or cfg.multiprocessing_distributed 91 | 92 | ngpus_per_node = torch.cuda.device_count() 93 | if cfg.multiprocessing_distributed: 94 | # Since we have ngpus_per_node processes per node, the total world_size 95 | # needs to be adjusted accordingly 96 | cfg.world_size = ngpus_per_node * cfg.world_size 97 | # Use torch.multiprocessing.spawn to launch distributed processes: the 98 | # main_worker process function 99 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, cfg)) 100 | else: 101 | # Simply call main_worker function 102 | main_worker(cfg.gpu, ngpus_per_node, cfg) 103 | 104 | 105 | def main_worker(gpu, ngpus_per_node, cfg): 106 | cfg.gpu = gpu 107 | 108 | if cfg.gpu is not None: 109 | print("Use GPU: {} for training".format(cfg.gpu)) 110 | 111 | if cfg.distributed: 112 | if cfg.dist_url == "env://" and cfg.rank == -1: 113 | cfg.rank = int(os.environ["RANK"]) 114 | if cfg.multiprocessing_distributed: 115 | # For multiprocessing distributed training, rank needs to be the 116 | # global rank among all the processes 117 | cfg.rank = cfg.rank * ngpus_per_node + gpu 118 | dist.init_process_group(backend=cfg.dist_backend, init_method=cfg.dist_url, 119 | world_size=cfg.world_size, rank=cfg.rank) 120 | 121 | print('===> Building ResNet..') 122 | 123 | model = resnet18(wbit=cfg.Wbits, abit=cfg.Abits, pretrained=False) 124 | # model = models.__dict__[resnet18(pretrained=False)] 125 | 126 | if cfg.distributed: 127 | # For multiprocessing distributed, DistributedDataParallel constructor 128 | # should always set the single device scope, otherwise, 129 | # DistributedDataParallel will use all available devices. 
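            # A typical launch for this branch (one node, all visible GPUs) is:
            #   python3 imagenet_torch_loader.py --multiprocessing-distributed
            # mp.spawn() above then starts one main_worker() process per GPU,
            # each arriving here with its own `gpu` index.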
130 | if cfg.gpu is not None: 131 | torch.cuda.set_device(cfg.gpu) 132 | model.cuda(cfg.gpu) 133 | # When using a single GPU per process and per 134 | # DistributedDataParallel, we need to divide the batch size 135 | # ourselves based on the total number of GPUs we have 136 | cfg.train_batch_size = int(cfg.train_batch_size / ngpus_per_node) 137 | cfg.workers = int((cfg.num_workers + ngpus_per_node - 1) / ngpus_per_node) 138 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.gpu]) 139 | else: 140 | model.cuda() 141 | # DistributedDataParallel will divide and allocate batch_size to all 142 | # available GPUs if device_ids are not set 143 | model = torch.nn.parallel.DistributedDataParallel(model) 144 | elif cfg.gpu is not None: 145 | torch.cuda.set_device(cfg.gpu) 146 | model = model.cuda(cfg.gpu) 147 | else: 148 | # DataParallel will divide and allocate batch_size to all available GPUs 149 | model = torch.nn.DataParallel(model).cuda() 150 | 151 | 152 | optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=cfg.wd) 153 | lr_schedu = optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 90], gamma=0.1) 154 | criterion = torch.nn.CrossEntropyLoss().cuda(cfg.gpu) 155 | summary_writer = SummaryWriter(cfg.log_dir) 156 | 157 | # Data loading code 158 | traindir = os.path.join(cfg.data_dir, 'train') 159 | valdir = os.path.join(cfg.data_dir, 'val') 160 | 161 | train_dataset = datasets.ImageFolder(traindir, imgnet_transform(is_training=True)) 162 | val_dataset = datasets.ImageFolder(valdir,imgnet_transform(is_training=False)) 163 | 164 | if cfg.distributed: 165 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 166 | else: 167 | train_sampler = None 168 | 169 | train_loader = torch.utils.data.DataLoader( 170 | train_dataset, batch_size=cfg.train_batch_size, shuffle=(train_sampler is None), num_workers=cfg.num_workers, pin_memory=True, sampler=train_sampler) 171 | eval_loader = torch.utils.data.DataLoader( 172 | val_dataset, batch_size=cfg.eval_batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True) 173 | 174 | 175 | 176 | if cfg.pretrain: 177 | ckpt = torch.load(os.path.join(cfg.ckpt_dir, f'checkpoint.t7')) 178 | model.load_state_dict(ckpt['model_state_dict']) 179 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 180 | start_epoch = ckpt['epoch'] 181 | print('===> Load last checkpoint data') 182 | else: 183 | start_epoch = 0 184 | print('===> Start from scratch') 185 | 186 | for epoch in range(start_epoch, cfg.max_epochs): 187 | if cfg.distributed: 188 | train_sampler.set_epoch(epoch) 189 | train(epoch, model, train_loader, criterion, optimizer, summary_writer) 190 | test(epoch, model, eval_loader, criterion, optimizer, summary_writer) 191 | lr_schedu.step(epoch) 192 | summary_writer.close() 193 | 194 | 195 | def train(epoch, model, train_loader, criterion, optimizer, summary_writer): 196 | 197 | print('\nEpoch: %d' % epoch) 198 | 199 | batch_time = AverageMeter('Time', ':6.3f') 200 | data_time = AverageMeter('Data', ':6.3f') 201 | losses = AverageMeter('Loss', ':.4e') 202 | top1 = AverageMeter('Acc@1', ':6.2f') 203 | top5 = AverageMeter('Acc@5', ':6.2f') 204 | 205 | # switch to train mode 206 | model.train() 207 | 208 | end = time.time() 209 | for batch_idx, (inputs, targets) in enumerate(train_loader): 210 | #measure data loading time 211 | data_time.update(time.time() - end) 212 | 213 | if cfg.gpu is not None: 214 | inputs = inputs.cuda(cfg.gpu, non_blocking=True) 215 | targets = 
targets.cuda(cfg.gpu, non_blocking=True) 216 | # inputs, targets = inputs.to('cuda'), targets.to('cuda') 217 | 218 | #compute output 219 | outputs = model(inputs) 220 | loss = criterion(outputs, targets) 221 | 222 | #measure acc and record loss 223 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) 224 | losses.update(loss.item(), inputs.size(0)) 225 | top1.update(acc1[0], inputs.size(0)) 226 | top5.update(acc5[0], inputs.size(0)) 227 | 228 | #compute gradient and do SGD step 229 | optimizer.zero_grad() 230 | loss.backward() 231 | optimizer.step() 232 | 233 | #measure elapsed time 234 | batch_time.update(time.time() - end) 235 | end = time.time() 236 | 237 | progress_bar(batch_idx, len(train_loader), 'Loss: %.3f | Acc1: %.3f%% Acc5: %.3f%% ' 238 | % (losses.avg, top1.avg, top5.avg)) 239 | 240 | if batch_idx % cfg.log_interval == 0: #every log_interval mini_batches... 241 | summary_writer.add_scalar('Loss/train', losses.avg, epoch * len(train_loader) + batch_idx) 242 | summary_writer.add_scalar('Accuracy/train', top1.avg, epoch * len(train_loader) + batch_idx) 243 | summary_writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], epoch * len(train_loader) + batch_idx) 244 | # for tag, value in model.named_parameters(): 245 | # tag = tag.replace('.', '/') 246 | # summary_writer.add_histogram(tag, value.detach(), global_step=epoch * len(train_loader) + batch_idx) 247 | # summary_writer.add_histogram(tag + '/grad', value.grad.detach(), global_step=epoch * len(train_loader) + batch_idx) 248 | 249 | 250 | 251 | 252 | def test(epoch, model, eval_loader, criterion, optimizer, summary_writer): 253 | # pass 254 | global best_acc 255 | batch_time = AverageMeter('Time', ':6.3f') 256 | losses = AverageMeter('Loss', ':.4e') 257 | top1 = AverageMeter('Acc@1', ':6.2f') 258 | top5 = AverageMeter('Acc@5', ':6.2f') 259 | 260 | # switch to evaluate mode 261 | model.eval() 262 | 263 | with torch.no_grad(): 264 | end = time.time() 265 | for batch_idx, (inputs, targets) in enumerate(eval_loader): 266 | if cfg.gpu is not None: 267 | inputs = inputs.cuda(cfg.gpu, non_blocking=True) 268 | targets = targets.cuda(cfg.gpu, non_blocking=True) 269 | 270 | #compute output 271 | outputs = model(inputs) 272 | loss = criterion(outputs, targets) 273 | 274 | #measure acc and record loss 275 | acc1, acc5 = accuracy(outputs, targets, topk=(1,5)) 276 | losses.update(loss.item(), inputs.size(0)) 277 | top1.update(acc1[0], inputs.size(0)) 278 | top5.update(acc5[0], inputs.size(0)) 279 | 280 | # measure elapsed time 281 | batch_time.update(time.time() - end) 282 | end = time.time() 283 | 284 | progress_bar(batch_idx, len(eval_loader), 'Loss: %.3f | Acc1: %.3f%% Acc5: %.3f%% ' 285 | % (losses.avg, top1.avg, top5.avg)) 286 | 287 | if batch_idx % cfg.log_interval == 0: # every log_interval mini_batches... 
288 |                 summary_writer.add_scalar('Loss/test', losses.avg, epoch * len(eval_loader) + batch_idx)
289 |                 summary_writer.add_scalar('Accuracy/test', top1.avg, epoch * len(eval_loader) + batch_idx)
290 | 
291 |     acc = top1.avg
292 |     if acc > best_acc:
293 |         print('Saving..')
294 |         state = {
295 |             'model_state_dict': model.state_dict(),
296 |             'optimizer_state_dict': optimizer.state_dict(),
297 |             'acc': acc,
298 |             'epoch': epoch,
299 |         }
300 |         torch.save(state, os.path.join(cfg.ckpt_dir, f'checkpoint.t7'))
301 |         best_acc = acc
302 | 
303 | if __name__ == '__main__':
304 |     main()
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jzz24/pytorch_quantization/0c2d93c8ce4f85dd2c34ea6f36c58d14db21bf8e/models/__init__.py
--------------------------------------------------------------------------------
/models/model_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jzz24/pytorch_quantization/0c2d93c8ce4f85dd2c34ea6f36c58d14db21bf8e/models/model_utils/__init__.py
--------------------------------------------------------------------------------
/models/model_utils/bn_fuse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import time
4 | import sys
5 | import numpy as np
6 | import torchvision
7 | import torch.nn.functional as F
8 | 
9 | 
10 | class DummyModule(nn.Module):
11 |     def __init__(self):
12 |         super(DummyModule, self).__init__()
13 | 
14 |     def forward(self, x):
15 |         # print("Dummy, Dummy.")
16 |         return x
17 | 
18 | def fuse(conv, bn):
19 |     # ******************* conv params ********************
20 |     w = conv.weight
21 | 
22 |     # ******************** BN params *********************
23 |     mean = bn.running_mean
24 |     var_sqrt = torch.sqrt(bn.running_var + bn.eps)
25 |     gamma = bn.weight
26 |     beta = bn.bias
27 | 
28 |     if conv.bias is not None:
29 |         b = conv.bias
30 |     else:
31 |         b = mean.new_zeros(mean.shape)
32 | 
33 |     w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
34 |     b = (b - mean) / var_sqrt * gamma + beta
35 | 
36 |     # fused_conv = Conv(in_channels=conv.in_channels,
37 |     #                   out_channels=conv.out_channels,
38 |     #                   kernel_size=conv.kernel_size,
39 |     #                   nbit_w=32,
40 |     #                   nbit_a=8,
41 |     #                   stride=conv.stride,
42 |     #                   padding=conv.padding,
43 |     #                   groups=conv.groups,
44 |     #                   bias=True)
45 |     fused_conv = nn.Conv2d(conv.in_channels,
46 |                            conv.out_channels,
47 |                            conv.kernel_size,
48 |                            conv.stride,
49 |                            conv.padding,
50 |                            groups=conv.groups, bias=True)  # pass groups through so grouped convs fuse correctly
51 |     fused_conv.weight = nn.Parameter(w)
52 |     fused_conv.bias = nn.Parameter(b)
53 |     return fused_conv
54 | 
55 | def fuse_module(m):
56 |     children = list(m.named_children())
57 |     c = None
58 |     cn = None
59 | 
60 |     for name, child in children:
61 |         if isinstance(child, nn.BatchNorm2d):
62 |             bc = fuse(c, child)
63 |             m._modules[cn] = bc
64 |             m._modules[name] = DummyModule()
65 |             c = None
66 |         elif isinstance(child, nn.Conv2d):
67 |             c = child
68 |             cn = name
69 |         else:
70 |             fuse_module(child)
71 | 
72 | 
73 | def test_net(m):
74 | 
75 |     p = torch.randn([1, 3, 224, 224])
76 |     import time
77 |     s = time.time()
78 |     o_output = m(p)
79 |     print("Original time: ", time.time() - s)
80 | 
81 |     fuse_module(m)
82 |     # print(m)
83 | 
84 |     s = time.time()
85 |     f_output = m(p)
86 |     print("Fused time: ", time.time() - s)
87 | 
88 |     print("Max abs diff: ", (o_output - f_output).abs().max().item())
89 | 
assert(o_output.argmax() == f_output.argmax()) 90 | # print(o_output[0][0].item(), f_output[0][0].item()) 91 | print("MSE diff: ", nn.MSELoss()(o_output, f_output).item()) 92 | 93 | def fuse_model(m): 94 | p = torch.rand([1, 3, 32, 32]) 95 | s = time.time() 96 | o_output = m(p) 97 | print("Original time: ", time.time() - s) 98 | 99 | fuse_module(m) 100 | 101 | s = time.time() 102 | f_output = m(p) 103 | print("Fused time: ", time.time() - s) 104 | return m 105 | 106 | 107 | def test(): 108 | print("============================") 109 | print("Module level test: ") 110 | m = torchvision.models.resnet18(True) 111 | m.eval() 112 | test_net(m) 113 | # fuse_model(m) 114 | 115 | 116 | if __name__ == '__main__': 117 | test() -------------------------------------------------------------------------------- /models/model_utils/quant_dorefa.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from torch.autograd import Function 6 | import torch.nn.functional as F 7 | 8 | 9 | 10 | class ScaleSigner(Function): 11 | """take a real value x, output sign(x)*E(|x|)""" 12 | @staticmethod 13 | def forward(ctx, input): 14 | return torch.sign(input) * torch.mean(torch.abs(input)) 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | return grad_output 19 | 20 | 21 | def scale_sign(input): 22 | return ScaleSigner.apply(input) 23 | 24 | 25 | class Quantizer(Function): 26 | @staticmethod 27 | def forward(ctx, input, nbit): 28 | scale = 2 ** nbit - 1 29 | return torch.round(input * scale) / scale 30 | 31 | @staticmethod 32 | def backward(ctx, grad_output): 33 | return grad_output, None 34 | 35 | 36 | def quantize(input, nbit): 37 | return Quantizer.apply(input, nbit) 38 | 39 | 40 | def dorefa_w(w, nbit_w): 41 | if nbit_w == 1: 42 | w = scale_sign(w) 43 | else: 44 | w = torch.tanh(w) 45 | w = w / (2 * torch.max(torch.abs(w))) + 0.5 46 | w = 2 * quantize(w, nbit_w) - 1 47 | 48 | return w 49 | 50 | 51 | def dorefa_a(input, nbit_a): 52 | return quantize(torch.clamp(0.1 * input, 0, 1), nbit_a) 53 | 54 | 55 | class QuanConv(nn.Conv2d): 56 | """docstring for QuanConv""" 57 | def __init__(self, in_channels, out_channels, kernel_size, quan_name_w='dorefa', quan_name_a='dorefa', nbit_w=1, 58 | nbit_a=1, stride=1, 59 | padding=0, dilation=1, groups=1, 60 | bias=True): 61 | super(QuanConv, self).__init__( 62 | in_channels, out_channels, kernel_size, stride, padding, dilation, 63 | groups, bias) 64 | self.nbit_w = nbit_w 65 | self.nbit_a = nbit_a 66 | name_w_dict = {'dorefa': dorefa_w} 67 | name_a_dict = {'dorefa': dorefa_a} 68 | self.quan_w = name_w_dict[quan_name_w] 69 | self.quan_a = name_a_dict[quan_name_a] 70 | 71 | # @weak_script_method 72 | def forward(self, input): 73 | if self.nbit_w < 32: 74 | w = self.quan_w(self.weight, self.nbit_w) 75 | else: 76 | w = self.weight 77 | 78 | if self.nbit_a < 32: 79 | x = self.quan_a(input, self.nbit_a) 80 | else: 81 | x = input 82 | # print('x unique',np.unique(x.detach().numpy()).shape) 83 | # print('w unique',np.unique(w.detach().numpy()).shape) 84 | 85 | output = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) 86 | 87 | return output 88 | 89 | class Linear_Q(nn.Linear): 90 | def __init__(self, in_features, out_features, bias=True, quan_name_w='dorefa', quan_name_a='dorefa', nbit_w=1, nbit_a=1): 91 | super(Linear_Q, self).__init__(in_features, out_features, bias) 92 | self.nbit_w = nbit_w 93 | self.nbit_a = nbit_a 94 | 
name_w_dict = {'dorefa': dorefa_w} 95 | name_a_dict = {'dorefa': dorefa_a} 96 | self.quan_w = name_w_dict[quan_name_w] 97 | self.quan_a = name_a_dict[quan_name_a] 98 | 99 | # @weak_script_method 100 | def forward(self, input): 101 | if self.nbit_w < 32: 102 | w = self.quan_w(self.weight, self.nbit_w) 103 | else: 104 | w = self.weight 105 | 106 | if self.nbit_a < 32: 107 | x = self.quan_a(input, self.nbit_a) 108 | else: 109 | x = input 110 | 111 | # print('x unique',np.unique(x.detach().numpy())) 112 | # print('w unique',np.unique(w.detach().numpy())) 113 | 114 | output = F.linear(x, w, self.bias) 115 | 116 | return output -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | # from .model_utils.quant_dorefa import QuanConv as Conv 6 | # from .model_utils.quant_dorefa import * 7 | # from .model_utils.bn_fuse import fuse_module 8 | 9 | from model_utils.quant_dorefa import QuanConv as Conv 10 | from model_utils.quant_dorefa import * 11 | from model_utils.bn_fuse import fuse_module 12 | 13 | import torch.nn.functional as F 14 | 15 | 16 | def conv3x3(in_planes, out_planes, wbit, abit, stride=1): 17 | """3x3 convolution with padding""" 18 | return Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False, nbit_w=wbit, nbit_a=abit) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, wbit, abit, stride=1): 22 | """1x1 convolution""" 23 | return Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, nbit_w=wbit, nbit_a=abit) 24 | 25 | 26 | def linear(in_featrues, out_features, wbit, abit): 27 | return Linear_Q(in_featrues, out_features, nbit_w=wbit, nbit_a=abit) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, in_planes, planes, wbit, abit, stride=1): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(in_planes, planes, wbit=wbit, abit=abit, stride=stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.conv2 = conv3x3(planes, planes, wbit=wbit, abit=abit, stride=1) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | 40 | self.shortcut = nn.Sequential() 41 | if stride != 1 or in_planes != self.expansion*planes: 42 | self.shortcut = nn.Sequential( 43 | conv1x1(in_planes, self.expansion*planes, wbit=wbit, abit=abit, stride=stride), 44 | nn.BatchNorm2d(self.expansion*planes) 45 | ) 46 | 47 | def forward(self, x): 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = self.bn2(self.conv2(out)) 50 | out += self.shortcut(x) 51 | out = F.relu(out) 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, in_planes, planes, wbit, abit, stride=1): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = conv1x1(in_planes, planes, wbit=wbit, abit=abit, stride=1) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = conv3x3(planes, planes, wbit=wbit, abit=abit, stride=stride) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv3 = conv1x1(planes, self.expansion*planes,wbit=wbit, abit=abit, stride=1) 65 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 66 | 67 | self.shortcut = nn.Sequential() 68 | if stride != 1 or in_planes != self.expansion*planes: 69 | self.shortcut = nn.Sequential( 70 | conv1x1(in_planes, self.expansion*planes,wbit=wbit,abit=abit,stride=stride), 71 | nn.BatchNorm2d(self.expansion*planes) 72 | ) 73 | 74 | def forward(self, x): 75 | out = 
F.relu(self.bn1(self.conv1(x))) 76 | out = F.relu(self.bn2(self.conv2(out))) 77 | out = self.bn3(self.conv3(out)) 78 | out += self.shortcut(x) 79 | out = F.relu(out) 80 | return out 81 | 82 | class ResNet(nn.Module): 83 | def __init__(self, block, num_blocks, wbit, abit, num_classes=10): 84 | super(ResNet, self).__init__() 85 | self.in_planes = 64 86 | 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) # do not quntize the first layer 88 | # self.conv1 = Conv(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 89 | self.bn1 = nn.BatchNorm2d(64) 90 | self.layer1 = self._make_layer(block, 64, num_blocks[0], wbit=wbit, abit=abit, stride=1) 91 | self.layer2 = self._make_layer(block, 128, num_blocks[1], wbit=wbit, abit=abit, stride=2) 92 | self.layer3 = self._make_layer(block, 256, num_blocks[2], wbit=wbit, abit=abit, stride=2) 93 | self.layer4 = self._make_layer(block, 512, num_blocks[3], wbit=wbit, abit=abit, stride=2) 94 | self.linear = linear(512*block.expansion, num_classes, wbit=8, abit=abit) 95 | # self.linear = nn.Linear(512 * block.expansion, num_classes) 96 | 97 | # weight initialization 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 101 | if m.bias is not None: 102 | nn.init.zeros_(m.bias) 103 | elif isinstance(m, nn.BatchNorm2d): 104 | nn.init.ones_(m.weight) 105 | nn.init.zeros_(m.bias) 106 | elif isinstance(m, nn.Linear): 107 | nn.init.normal_(m.weight, 0, 0.01) 108 | nn.init.zeros_(m.bias) 109 | 110 | 111 | def _make_layer(self, block, planes, num_blocks, wbit, abit, stride): 112 | strides = [stride] + [1]*(num_blocks-1) 113 | layers = [] 114 | for stride in strides: 115 | layers.append(block(self.in_planes, planes, wbit, abit, stride)) 116 | self.in_planes = planes * block.expansion 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | 121 | out = F.relu(self.bn1(self.conv1(x))) 122 | out = self.layer1(out) 123 | out = self.layer2(out) 124 | out = self.layer3(out) 125 | out = self.layer4(out) 126 | out = F.avg_pool2d(out, 4) 127 | out = out.view(out.size(0), -1) 128 | out = self.linear(out) 129 | return out 130 | 131 | def ResNet18(wbit, abit): 132 | return ResNet(BasicBlock, [2,2,2,2], wbit=wbit, abit=abit) 133 | 134 | def ResNet34(wbit, abit): 135 | return ResNet(BasicBlock, [3,4,6,3], wbit=wbit, abit=abit) 136 | 137 | def ResNet50(wbit, abit): 138 | return ResNet(Bottleneck, [3,4,6,3], wbit=wbit, abit=abit) 139 | 140 | def ResNet101(wbit, abit): 141 | return ResNet(Bottleneck, [3,4,23,3], wbit=wbit, abit=abit) 142 | 143 | def ResNet152(wbit, abit): 144 | return ResNet(Bottleneck, [3,8,36,3], wbit=wbit, abit=abit) 145 | 146 | 147 | def test(): 148 | net = ResNet18(wbit=32,abit=32) 149 | print (net) 150 | fuse_module(net) 151 | print (net) 152 | y = net(torch.ones(1,3,32,32)) 153 | print(y) 154 | 155 | if __name__ == '__main__': 156 | test() 157 | 158 | -------------------------------------------------------------------------------- /models/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from .model_utils.quant_dorefa import QuanConv as Conv 6 | from .model_utils.quant_dorefa import Linear_Q 7 | 8 | import torchvision.models 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 10 | 11 | # you need to download the models to ~/.torch/models 12 | # model_urls = { 13 | # 'resnet18': 
'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | # } 19 | models_dir = os.path.expanduser('~/.torch/models') 20 | model_name = { 21 | 'resnet18': 'resnet18-5c106cde.pth', 22 | 'resnet34': 'resnet34-333f7ec4.pth', 23 | 'resnet50': 'resnet50-19c8e357.pth', 24 | 'resnet101': 'resnet101-5d3b4d8f.pth', 25 | 'resnet152': 'resnet152-b121ed2d.pth', 26 | } 27 | 28 | 29 | # def conv3x3(in_planes, out_planes, stride=1): 30 | # """3x3 convolution with padding""" 31 | # return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 32 | 33 | def conv3x3(in_planes, out_planes, wbit, abit, stride=1): 34 | """3x3 convolution with padding""" 35 | return Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False, nbit_w=wbit, nbit_a=abit) 36 | 37 | 38 | def conv1x1(in_planes, out_planes, wbit, abit, stride=1): 39 | """1x1 convolution""" 40 | return Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, nbit_w=wbit, nbit_a=abit) 41 | 42 | 43 | def linear(in_featrues, out_features, wbit, abit): 44 | return Linear_Q(in_featrues, out_features, nbit_w=wbit, nbit_a=abit) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, wbit, abit, stride=1, downsample=None): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = conv3x3(inplanes, planes, wbit, abit, stride) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes, wbit, abit) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | residual = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | residual = self.downsample(x) 72 | 73 | out += residual 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | expansion = 4 81 | 82 | def __init__(self, inplanes, planes, wbit, abit, stride=1, downsample=None): 83 | super(Bottleneck, self).__init__() 84 | self.conv1 = conv1x1(inplanes, planes, wbit, abit) 85 | self.bn1 = nn.BatchNorm2d(planes) 86 | self.conv2 = conv3x3(planes, planes, wbit, abit, stride=stride) 87 | self.bn2 = nn.BatchNorm2d(planes) 88 | self.conv3 = conv1x1(planes, planes * 4, wbit, abit) 89 | self.bn3 = nn.BatchNorm2d(planes * 4) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.downsample = downsample 92 | self.stride = stride 93 | 94 | def forward(self, x): 95 | residual = x 96 | 97 | out = self.conv1(x) 98 | out = self.bn1(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv2(out) 102 | out = self.bn2(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv3(out) 106 | out = self.bn3(out) 107 | 108 | if self.downsample is not None: 109 | residual = self.downsample(x) 110 | 111 | out += residual 112 | out = self.relu(out) 113 | 114 | return out 115 | 116 | 117 | class ResNet(nn.Module): 118 | 119 | def __init__(self, block, layers, wbit, abit, num_classes=1000): 120 | super(ResNet, self).__init__() 121 | self.inplanes = 64 122 | self.conv1 = nn.Conv2d(3, 64, 
kernel_size=7, stride=2, padding=3, bias=False) 123 | self.bn1 = nn.BatchNorm2d(64) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0], wbit=wbit, abit=abit) 127 | self.layer2 = self._make_layer(block, 128, layers[1], wbit=wbit, abit=abit, stride=2) 128 | self.layer3 = self._make_layer(block, 256, layers[2], wbit=wbit, abit=abit, stride=2) 129 | self.layer4 = self._make_layer(block, 512, layers[3], wbit=wbit, abit=abit, stride=2) 130 | self.avgpool = nn.AvgPool2d(7, stride=1) 131 | self.fc = linear(512*block.expansion, num_classes, wbit=8, abit=abit) 132 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 133 | 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | m.weight.data.normal_(0, math.sqrt(2. / n)) 138 | elif isinstance(m, nn.BatchNorm2d): 139 | m.weight.data.fill_(1) 140 | m.bias.data.zero_() 141 | 142 | def _make_layer(self, block, planes, blocks, wbit, abit, stride=1): 143 | downsample = None 144 | if stride != 1 or self.inplanes != planes * block.expansion: 145 | downsample = nn.Sequential( 146 | conv1x1(self.inplanes, planes * block.expansion, wbit=wbit, abit=abit, stride=stride), 147 | nn.BatchNorm2d(planes * block.expansion), 148 | ) 149 | 150 | layers = [] 151 | layers.append(block(self.inplanes, planes, wbit, abit, stride, downsample)) 152 | self.inplanes = planes * block.expansion 153 | for i in range(1, blocks): 154 | layers.append(block(self.inplanes, planes, wbit, abit)) 155 | 156 | return nn.Sequential(*layers) 157 | 158 | def forward(self, x): 159 | x = self.conv1(x) 160 | x = self.bn1(x) 161 | x = self.relu(x) 162 | x = self.maxpool(x) 163 | 164 | x = self.layer1(x) 165 | x = self.layer2(x) 166 | x = self.layer3(x) 167 | x = self.layer4(x) 168 | 169 | x = self.avgpool(x) 170 | x = x.view(x.size(0), -1) 171 | x = self.fc(x) 172 | 173 | return x 174 | 175 | 176 | def resnet18(wbit, abit, pretrained=False, **kwargs): 177 | """Constructs a ResNet-18 model. 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = ResNet(BasicBlock, [2, 2, 2, 2], wbit=wbit, abit=abit, **kwargs) 182 | if pretrained: 183 | model.load_state_dict(torch.load(os.path.join(models_dir, model_name['resnet18']))) 184 | return model 185 | 186 | 187 | def resnet34(wbit, abit, pretrained=False, **kwargs): 188 | """Constructs a ResNet-34 model. 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | """ 192 | model = ResNet(BasicBlock, [3, 4, 6, 3], wbit=wbit, abit=abit, **kwargs) 193 | if pretrained: 194 | model.load_state_dict(torch.load(os.path.join(models_dir, model_name['resnet34']))) 195 | return model 196 | 197 | 198 | def resnet50(wbit, abit, pretrained=False, **kwargs): 199 | """Constructs a ResNet-50 model. 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(Bottleneck, [3, 4, 6, 3], wbit=wbit, abit=abit, **kwargs) 204 | if pretrained: 205 | model.load_state_dict(torch.load(os.path.join(models_dir, model_name['resnet50']))) 206 | return model 207 | 208 | 209 | def resnet101(wbit, abit, pretrained=False, **kwargs): 210 | """Constructs a ResNet-101 model. 
211 | Args: 212 | pretrained (bool): If True, returns a model pre-trained on ImageNet 213 | """ 214 | model = ResNet(Bottleneck, [3, 4, 23, 3], wbit=wbit, abit=abit, **kwargs) 215 | if pretrained: 216 | model.load_state_dict(torch.load(os.path.join(models_dir, model_name['resnet101']))) 217 | return model 218 | 219 | 220 | def resnet152(wbit, abit, pretrained=False, **kwargs): 221 | """Constructs a ResNet-152 model. 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | """ 225 | model = ResNet(Bottleneck, [3, 8, 36, 3], wbit=wbit, abit=abit, **kwargs) 226 | if pretrained: 227 | model.load_state_dict(torch.load(os.path.join(models_dir, model_name['resnet152']))) 228 | return model -------------------------------------------------------------------------------- /models/test_fused_quant_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from collections import OrderedDict 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | import argparse 12 | import time 13 | import warnings 14 | warnings.filterwarnings("ignore") 15 | 16 | from resnet_cifar import * 17 | from model_utils.quant_dorefa import QuanConv as Conv 18 | 19 | 20 | class DummyModule(nn.Module): 21 | def __init__(self): 22 | super(DummyModule, self).__init__() 23 | def forward(self, x): 24 | return x 25 | 26 | def fuse(conv, bn): 27 | 28 | global fuse_layer_idx ## control first layer fuse 29 | fuse_layer_idx += 1 30 | # *******************conv params******************** 31 | w = conv.weight 32 | 33 | # ********************BN params********************* 34 | mean = bn.running_mean 35 | var_sqrt = torch.sqrt(bn.running_var + bn.eps) 36 | gamma = bn.weight 37 | beta = bn.bias 38 | 39 | if conv.bias is not None: 40 | b = conv.bias 41 | else: 42 | b = mean.new_zeros(mean.shape) 43 | 44 | w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1]) 45 | b = (b - mean)/var_sqrt * gamma + beta 46 | 47 | if fuse_layer_idx == 1: 48 | print ('fuse first layer') 49 | fused_conv = nn.Conv2d(conv.in_channels, 50 | conv.out_channels, 51 | conv.kernel_size, 52 | conv.stride, 53 | conv.padding, 54 | groups=conv.groups, 55 | bias=True) 56 | else: 57 | fused_conv = Conv(in_channels=conv.in_channels, 58 | out_channels=conv.out_channels, 59 | kernel_size=conv.kernel_size, 60 | nbit_w=32, 61 | nbit_a=args.Abits, 62 | stride=conv.stride, 63 | padding=conv.padding, 64 | groups=conv.groups, 65 | bias=True) 66 | fused_conv.weight = nn.Parameter(w) 67 | fused_conv.bias = nn.Parameter(b) 68 | return fused_conv 69 | 70 | def fuse_modules(m): 71 | children = list(m.named_children()) 72 | conv, conv_name= None, None 73 | for layer_idx, (name, child) in enumerate(children): 74 | if isinstance(child, nn.BatchNorm2d): 75 | fused_conv = fuse(conv, child) 76 | print ('===> fusing') 77 | m._modules[conv_name] = fused_conv 78 | m._modules[name] = DummyModule() 79 | conv = None 80 | elif isinstance(child, nn.Conv2d): 81 | conv = child 82 | conv_name = name 83 | else: 84 | fuse_modules(child) 85 | 86 | def model_convert(): 87 | 88 | model = ResNet18(wbit=32,abit=args.Abits) ## modify model config to w=32 89 | model_state_dict = torch.load(args.baseline_model_dir + 'checkpoint.t7')['model_state_dict'] 90 | 91 | ## fix up nn.DataParallel module load to cpu model's 
bug
92 |     new_state_dict = OrderedDict()
93 |     for k, v in model_state_dict.items():
94 |         name = k[7:]  # remove `module.`
95 |         new_state_dict[name] = v
96 |     model.load_state_dict(new_state_dict)
97 | 
98 |     # ********************** convert to W quantized model *************************
99 |     conv_lay_idx = 0  ## usually do not quantize the first and last layers
100 |     for m in model.modules():
101 |         if isinstance(m, nn.Conv2d):
102 |             conv_lay_idx += 1
103 |             if conv_lay_idx >= 2:
104 |                 m.weight.data = dorefa_w(m.weight.data, nbit_w=args.Wbits)
105 |         elif isinstance(m, nn.Linear):
106 |             m.weight.data = dorefa_w(m.weight.data, nbit_w=8)
107 |             # m.weight.data = m.weight.data
108 |         else:
109 |             pass
110 | 
111 |     torch.save(model, args.baseline_model_dir + 'quan_model.pth')  # save entire model
112 |     torch.save(model.state_dict(), args.baseline_model_dir + 'quan_model_para.pth')  # save model state_dict
113 | 
114 | 
115 |     # ********************** convert to bn_fold W quantized model *************************
116 |     model.eval()
117 |     fuse_modules(model)
118 |     torch.save(model, args.baseline_model_dir + 'quan_bn_merged_model.pth')
119 |     torch.save(model.state_dict(), args.baseline_model_dir + 'quan_bn_merged_model_para.pth')
120 | 
121 | def test(model):
122 |     model.eval()
123 |     test_loss = 0
124 |     correct = 0
125 |     test_time = 0
126 |     with torch.no_grad():
127 |         for data, target in testloader:
128 |             data, target = Variable(data.cuda()), Variable(target.cuda())
129 | 
130 |             start_time = time.time()
131 |             output = model(data)
132 |             end_time = time.time()
133 | 
134 |             test_loss += criterion(output, target).data.item()
135 |             test_time += end_time - start_time
136 |             pred = output.data.max(1, keepdim=True)[1]
137 |             correct += pred.eq(target.data.view_as(pred)).cpu().sum()
138 | 
139 |     acc = 100.
139 |     acc = 100. * float(correct) / len(testloader.dataset)
140 |     test_loss /= len(testloader.dataset)
141 |     print('\nmodel: Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), Running Time: {:.2f} s'.format(
142 |         test_loss, correct, len(testloader.dataset), acc, test_time))
143 |     return
144 | 
145 | if __name__ == '__main__':
146 |     parser = argparse.ArgumentParser()
147 |     parser.add_argument('--cpu', action='store_true', help='set if only CPU is available')
148 |     parser.add_argument('--data', action='store', default='../data', help='dataset path')
149 |     parser.add_argument('--baseline_model_dir', type=str, default='../ckpt/resnet_8w8f_cifar/')
150 |     parser.add_argument('--eval_batch_size', type=int, default=50)
151 |     parser.add_argument('--num_workers', type=int, default=2)
152 |     parser.add_argument('--Wbits', type=int, default=8)
153 |     parser.add_argument('--Abits', type=int, default=8)
154 |     args = parser.parse_args()
155 | 
156 |     print('==> Options:', args)
157 |     print('==> Preparing data..')
158 | 
159 |     transform_test = transforms.Compose([
160 |         transforms.ToTensor(),
161 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
162 | 
163 |     testset = torchvision.datasets.CIFAR10(root=args.data, train=False, download=True, transform=transform_test)
164 |     testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, shuffle=False,
165 |                                              num_workers=args.num_workers)
166 | 
167 |     criterion = nn.CrossEntropyLoss()
168 | 
169 |     fuse_layer_idx = 0
170 | 
171 |     model_convert()
172 | 
173 |     # load entire model
174 |     quan_model = torch.load(args.baseline_model_dir + 'quan_model.pth')
175 |     quan_bn_merged_model = torch.load(args.baseline_model_dir + 'quan_bn_merged_model.pth')
176 | 
177 |     if not args.cpu:
178 |         quan_model.cuda()
179 |         quan_bn_merged_model.cuda()
180 |     test(quan_model)
181 |     test(quan_bn_merged_model)
182 | 
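A quick sanity check for the folding arithmetic in `fuse()` above: with BN in eval mode, the fused conv must reproduce the conv + BN output up to floating-point error. A minimal standalone sketch (assumes only `torch`; the randomized BN statistics are just there to make the check non-trivial):
```
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()
with torch.no_grad():  # give BN non-trivial parameters and running statistics
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-0.2, 0.2)
    bn.running_var.uniform_(0.5, 1.5)

# fold BN into the conv, mirroring fuse():
#   w' = w * gamma / sqrt(var + eps);  b' = (b - mean) * gamma / sqrt(var + eps) + beta  (b = 0 here)
var_sqrt = torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True).eval()
fused.weight.data = conv.weight.data * (bn.weight.data / var_sqrt).reshape(-1, 1, 1, 1)
fused.bias.data = -bn.running_mean / var_sqrt * bn.weight.data + bn.bias.data

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```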
--------------------------------------------------------------------------------
/utils/bar_show.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | import torch
6 | 
7 | import torch.nn as nn
8 | import torch.nn.init as init
9 | 
10 | def get_mean_and_std(dataset):
11 |     '''Compute the mean and std value of dataset.'''
12 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
13 |     mean = torch.zeros(3)
14 |     std = torch.zeros(3)
15 |     print('==> Computing mean and std..')
16 |     for inputs, targets in dataloader:
17 |         for i in range(3):
18 |             mean[i] += inputs[:, i, :, :].mean()
19 |             std[i] += inputs[:, i, :, :].std()
20 |     mean.div_(len(dataset))
21 |     std.div_(len(dataset))
22 |     return mean, std
23 | 
24 | 
25 | _, term_width = os.popen('stty size', 'r').read().split()
26 | term_width = int(term_width)
27 | 
28 | TOTAL_BAR_LENGTH = 45.
29 | last_time = time.time()
30 | begin_time = last_time
31 | 
32 | 
33 | def progress_bar(current, total, msg=None):
34 |     global last_time, begin_time
35 |     if current == 0:
36 |         begin_time = time.time()  # Reset for new bar.
37 | 
38 |     cur_len = int(TOTAL_BAR_LENGTH * current / total)
39 |     rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
40 | 
41 |     sys.stdout.write(' [')
42 |     for i in range(cur_len):
43 |         sys.stdout.write('=')
44 |     sys.stdout.write('>')
45 |     for i in range(rest_len):
46 |         sys.stdout.write('.')
47 |     sys.stdout.write(']')
48 | 
49 |     cur_time = time.time()
50 |     step_time = cur_time - last_time
51 |     last_time = cur_time
52 |     tot_time = cur_time - begin_time
53 | 
54 |     L = []
55 |     L.append(' Step: %s' % format_time(step_time))
56 |     L.append(' | Tot: %s' % format_time(tot_time))
57 |     if msg:
58 |         L.append(' | ' + msg)
59 | 
60 |     msg = ''.join(L)
61 |     sys.stdout.write(msg)
62 |     for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
63 |         sys.stdout.write(' ')
64 | 
65 |     # Go back to the center of the bar.
66 |     for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2):
67 |         sys.stdout.write('\b')
68 |     sys.stdout.write(' %d/%d ' % (current + 1, total))
69 | 
70 |     if current < total - 1:
71 |         sys.stdout.write('\r')
72 |     else:
73 |         sys.stdout.write('\n')
74 |     sys.stdout.flush()
75 | 
76 | 
77 | def format_time(seconds):
78 |     days = int(seconds / 3600 / 24)
79 |     seconds = seconds - days * 3600 * 24
80 |     hours = int(seconds / 3600)
81 |     seconds = seconds - hours * 3600
82 |     minutes = int(seconds / 60)
83 |     seconds = seconds - minutes * 60
84 |     secondsf = int(seconds)
85 |     seconds = seconds - secondsf
86 |     millis = int(seconds * 1000)
87 | 
88 |     f = ''
89 |     i = 1
90 |     if days > 0:
91 |         f += str(days) + 'D'
92 |         i += 1
93 |     if hours > 0 and i <= 2:
94 |         f += str(hours) + 'h'
95 |         i += 1
96 |     if minutes > 0 and i <= 2:
97 |         f += str(minutes) + 'm'
98 |         i += 1
99 |     if secondsf > 0 and i <= 2:
100 |         f += str(secondsf) + 's'
101 |         i += 1
102 |     if millis > 0 and i <= 2:
103 |         f += str(millis) + 'ms'
104 |         i += 1
105 |     if f == '':
106 |         f = '0ms'
107 |     return f
108 | 
109 | 
110 | def accuracy(output, target, topk=(1,)):
111 |     """Computes the accuracy over the k top predictions for the specified values of k"""
112 |     with torch.no_grad():
113 |         maxk = max(topk)
114 |         batch_size = target.size(0)
115 | 
116 |         _, pred = output.topk(maxk, 1, True, True)
117 |         pred = pred.t()
118 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
119 | 
120 |         res = []
121 |         for k in topk:
122 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape() is safe for non-contiguous slices
123 |             res.append(correct_k.mul_(100.0 / batch_size))
124 |         return res
125 | 
126 | 
127 | class AverageMeter(object):
128 |     """Computes and stores the average and current value"""
129 |     def __init__(self, name, fmt=':f'):
130 |         self.name = name
131 |         self.fmt = fmt
132 |         self.reset()
133 | 
134 |     def reset(self):
135 |         self.val = 0
136 |         self.avg = 0
137 |         self.sum = 0
138 |         self.count = 0
139 | 
140 |     def update(self, val, n=1):
141 |         self.val = val
142 |         self.sum += val * n
143 |         self.count += n
144 |         self.avg = self.sum / self.count
145 | 
146 |     def __str__(self):
147 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
148 |         return fmtstr.format(**self.__dict__)
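A short usage sketch for the metric helpers above (the random logits and labels are stand-ins, not repo data):
```
import torch
from utils.bar_show import accuracy, AverageMeter

top1 = AverageMeter('Acc@1', ':.2f')
logits = torch.randn(100, 10)           # stand-in outputs for a 10-class problem
targets = torch.randint(0, 10, (100,))  # stand-in labels
acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
top1.update(acc1.item(), logits.size(0))
print(top1)  # prints e.g. "Acc@1 9.00 (9.00)"
```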
--------------------------------------------------------------------------------
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch.utils.data
3 | import nvidia.dali.ops as ops
4 | import nvidia.dali.types as types
5 | import torchvision.datasets as datasets
6 | from nvidia.dali.pipeline import Pipeline
7 | import torchvision.transforms as transforms
8 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator, DALIGenericIterator
9 | 
10 | 
11 | 
12 | def cifar_transform(is_training=True):
13 |     if is_training:
14 |         transform_list = [transforms.RandomHorizontalFlip(),
15 |                           transforms.Pad(padding=4, padding_mode='reflect'),
16 |                           transforms.RandomCrop(32, padding=0),
17 |                           transforms.ToTensor(),
18 |                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]
19 |     else:
20 |         transform_list = [transforms.ToTensor(),
21 |                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]
22 | 
23 |     transform_list = transforms.Compose(transform_list)
24 |     return transform_list
25 | 
26 | 
27 | def imgnet_transform(is_training=True):
28 |     if is_training:
29 |         transform_list = transforms.Compose([transforms.RandomResizedCrop(224),
30 |                                              transforms.RandomHorizontalFlip(),
31 |                                              transforms.ColorJitter(brightness=0.5,
32 |                                                                     contrast=0.5,
33 |                                                                     saturation=0.3),
34 |                                              transforms.ToTensor(),
35 |                                              transforms.Normalize(mean=[0.485, 0.456, 0.406],
36 |                                                                   std=[0.229, 0.224, 0.225])])
37 |     else:
38 |         transform_list = transforms.Compose([transforms.Resize(256),
39 |                                              transforms.CenterCrop(224),
40 |                                              transforms.ToTensor(),
41 |                                              transforms.Normalize(mean=[0.485, 0.456, 0.406],
42 |                                                                   std=[0.229, 0.224, 0.225])])
43 |     return transform_list
44 | 
45 | class HybridTrainPipe(Pipeline):
46 |     def __init__(self, batch_size, num_threads, device_id, data_dir, crop, dali_cpu=False, local_rank=0, world_size=1):
47 |         super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
48 |         dali_device = "gpu"
49 |         self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, random_shuffle=True)
50 |         self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
51 |         self.res = ops.RandomResizedCrop(device="gpu", size=crop, random_area=[0.08, 1.25])
52 |         self.cmnp = ops.CropMirrorNormalize(device="gpu",
53 |                                             output_dtype=types.FLOAT,
54 |                                             output_layout=types.NCHW,
55 |                                             image_type=types.RGB,
56 |                                             mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
57 |                                             std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
58 |         self.coin = ops.CoinFlip(probability=0.5)
59 |         print('DALI "{0}" variant'.format(dali_device))
60 | 
61 |     def define_graph(self):
62 |         rng = self.coin()
63 |         self.jpegs, self.labels = self.input(name="Reader")
64 |         images = self.decode(self.jpegs)
65 |         images = self.res(images)
66 |         output = self.cmnp(images, mirror=rng)
67 |         return [output, self.labels]
68 | 
69 | 
70 | class HybridValPipe(Pipeline):
71 |     def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, local_rank=0, world_size=1):
72 |         super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
73 |         self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,
74 |                                     random_shuffle=False)
75 |         self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
76 |         self.res = ops.Resize(device="gpu", resize_shorter=size, interp_type=types.INTERP_TRIANGULAR)
77 |         self.cmnp = ops.CropMirrorNormalize(device="gpu",
78 |                                             output_dtype=types.FLOAT,
79 |                                             output_layout=types.NCHW,
80 |                                             crop=(crop, crop),
81 |                                             image_type=types.RGB,
82 |                                             mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
83 |                                             std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
84 | 
85 |     def define_graph(self):
86 |         self.jpegs, self.labels = self.input(name="Reader")
87 |         images = self.decode(self.jpegs)
88 |         images = self.res(images)
89 |         output = self.cmnp(images)
90 |         return [output, self.labels]
91 | 
92 | 
93 | def get_imagenet_iter_dali(type, image_dir, batch_size, num_threads, device_id, num_gpus, crop, val_size=256,
94 |                            world_size=1,
95 |                            local_rank=0):
96 |     if type == 'train':
97 |         pip_train = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank,
98 |                                     data_dir=image_dir + '/train',
99 |                                     crop=crop, world_size=world_size, local_rank=local_rank)
100 |         pip_train.build()
101 |         dali_iter_train = DALIClassificationIterator(pip_train, size=pip_train.epoch_size("Reader") // world_size, auto_reset=True)
102 |         return dali_iter_train
103 |     elif type == 'val':
104 |         pip_val = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank,
105 |                                 data_dir=image_dir + '/val',
106 |                                 crop=crop, size=val_size, world_size=world_size, local_rank=local_rank)
107 |         pip_val.build()
108 |         dali_iter_val = DALIClassificationIterator(pip_val, size=pip_val.epoch_size("Reader") // world_size, auto_reset=True)
109 |         return dali_iter_val
110 | 
111 | 
112 | def get_imagenet_iter_torch(type, image_dir, batch_size, num_threads, device_id, num_gpus, crop, val_size=256,
113 |                             world_size=1, local_rank=0):
114 |     if type == 'train':
115 |         transform = transforms.Compose([
116 |             transforms.RandomResizedCrop(crop, scale=(0.08, 1.25)),
117 |             transforms.RandomHorizontalFlip(),
118 |             transforms.ToTensor(),
119 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
120 |         ])
121 |         dataset = datasets.ImageFolder(image_dir + '/train', transform)
122 |         dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_threads,
123 |                                                  pin_memory=True)
124 |     else:
125 |         transform = transforms.Compose([
126 |             transforms.Resize(val_size),
127 |             transforms.CenterCrop(crop),
128 |             transforms.ToTensor(),
129 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
130 |         ])
131 |         dataset = datasets.ImageFolder(image_dir + '/val', transform)
132 |         dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_threads,
133 |                                                  pin_memory=True)
134 |     return dataloader
135 | 
--------------------------------------------------------------------------------
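For reference, a hypothetical smoke test of the two loader factories above. `/data/imagenet` is an illustrative path; the DALI iterator yields a list with one dict per GPU carrying `data` and `label` keys, while the torchvision loader yields ordinary `(images, labels)` tuples:
```
from utils.preprocess import get_imagenet_iter_dali, get_imagenet_iter_torch

train_loader = get_imagenet_iter_dali(type='train', image_dir='/data/imagenet', batch_size=256,
                                      num_threads=4, device_id=0, num_gpus=1, crop=224)
for data in train_loader:
    images = data[0]["data"]
    labels = data[0]["label"].squeeze().long().cuda(non_blocking=True)
    break  # one batch is enough for a smoke test

val_loader = get_imagenet_iter_torch(type='val', image_dir='/data/imagenet', batch_size=100,
                                     num_threads=4, device_id=0, num_gpus=1, crop=224)
images, labels = next(iter(val_loader))
```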