├── README.md
├── finetune.py
├── model.py
├── model_ablation_1.py
├── model_ablation_2.py
├── new_layers.py
├── pretrain.py
├── read_data.py
└── utils.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Structured Binary Neural Networks for Image Recognition

created by [Bohan Zhuang](https://sites.google.com/view/bohanzhuang)


***If you use this code in your research, please cite our paper:***

```
@inproceedings{zhuang2019structured,
  title={Structured Binary Neural Networks for Accurate Image Classification and Semantic Segmentation},
  author={Zhuang, Bohan and Shen, Chunhua and Tan, Mingkui and Liu, Lingqiao and Reid, Ian},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={413--422},
  year={2019}
}
```

## This is a simple implementation for ImageNet classification.


pretrain.py: pretrain a full-precision model using Tanh() activations

finetune.py: finetune the binary model

model.py: define the model with softgates but without learnt scales. The pretrained model is resnet18_without_scales_with_softgates.pth.tar: [Google Drive](https://drive.google.com/file/d/1zQMUhl1WIp_v7kp4lcSiKQqS6RwJcNxO/view?usp=sharing)

model_ablation_1.py: define the model with softgates and learnt scales. The pretrained model is resnet18_with_scales_with_softgates.pth.tar: [Google Drive](https://drive.google.com/file/d/1dbrUJuHIuYnsCANHc4HxVEnHGsmCOCku/view?usp=sharing)

model_ablation_2.py: define the model without softgates but with learnt scales. The pretrained model is in [Google Drive](https://drive.google.com/file/d/1oe85OXmRvqxffZjKfPospMH73tzq_TuC/view?usp=sharing)

new_layers.py: define the necessary quantization functions

utils.py: define auxiliary functions


resnet34_without_softgates.pth.tar: pretrained model for ResNet-34 without softgates but with learnt scales, [Google Drive](https://drive.google.com/file/d/1hTciMyZTma2o23W7Bq9y0yLfDKFkNxHy/view?usp=sharing)

resnet34_with_softgates.pth.tar: pretrained model for ResNet-34 with softgates and learnt scales, [Google Drive](https://drive.google.com/file/d/13CXpXJ__1hPdIhO7NKTHW3ipbAnvQuqu/view?usp=sharing)


## Semantic segmentation and object detection

Please refer to [semantic segmentation](https://bitbucket.org/jingruixiaozhuang/group-net-semantic-segmentation/src/master/) and [object detection](https://bitbucket.org/jingruixiaozhuang/group-net-object-detection/src/master/).

## Copyright

Copyright (c) Bohan Zhuang. 2019

**This code is for non-commercial purposes only. For commercial purposes, please contact Chunhua Shen.**

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
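A minimal usage sketch (assuming an ImageNet folder layout at the paths hard-coded in the scripts; the flags shown are simply the argparse defaults defined in pretrain.py and finetune.py):

```
# Step 1: pretrain a full-precision model (bitW = bitA = 32) with Tanh() activations
python pretrain.py --batch_size 256 --learning_rate 0.05 --epochs 40

# Step 2: finetune the 1-bit model, initializing from the pretrained weights
python finetune.py --batch_size 256 --learning_rate 5e-4 --epochs 40
```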
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
import os
import sys
import math
import logging
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets

import utils
from utils import adjust_learning_rate, save_checkpoint
from model import resnet18


parser = argparse.ArgumentParser("ImageNet")
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--learning_rate', type=float, default=5e-4, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay')
parser.add_argument('--report_freq', type=int, default=100, help='report frequency')
parser.add_argument('--epochs', type=int, default=40, help='num of training epochs')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=10, help='gradient clipping')
parser.add_argument('--resume_train', action='store_true', default=False, help='resume training')
parser.add_argument('--resume_dir', type=str, default='./weights/checkpoint.pth.tar', help='checkpoint to resume from')
parser.add_argument('--load_epoch', type=int, default=30, help='epoch to load when resuming')
parser.add_argument('--weights_dir', type=str, default='./weights/', help='save weights directory')
parser.add_argument('--learning_step', type=int, nargs='+', default=[20, 30, 40], help='learning rate decay steps')


args = parser.parse_args()

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
if not os.path.exists(args.save):
    os.makedirs(args.save)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)


def main():

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Image Preprocessing
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,])

    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,])

    num_epochs = args.epochs
    batch_size = args.batch_size

    train_dataset = datasets.ImageFolder(root='/usr/local/data/imagenet/train/', transform=train_transform)
    test_dataset = datasets.ImageFolder(root='/usr/local/data/imagenet/val/', transform=test_transform)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True, num_workers=10, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=64,
                                              shuffle=False, num_workers=10, pin_memory=True)

    num_train = len(train_dataset)
    n_train_batches = math.floor(num_train / batch_size)

    criterion = nn.CrossEntropyLoss().cuda()
    # 1-bit weights and activations
    bitW = 1
    bitA = 1
    model = resnet18(bitW, bitA, pretrained=True)
    model = utils.dataparallel(model, 4)

    print("Compilation complete, starting training...")

    test_record = []
    train_record = []
    learning_rate = args.learning_rate
    epoch = 0
    step_idx = 0
    best_top1 = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    while epoch < num_epochs:

        epoch = epoch + 1
        # resume training
        if (args.resume_train) and (epoch == 1):
            checkpoint = torch.load(args.resume_dir)
            epoch = checkpoint['epoch']
            learning_rate = checkpoint['learning_rate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            step_idx = checkpoint['step_idx']
            model.load_state_dict(checkpoint['state_dict'])
            test_record = list(np.load(args.weights_dir + 'test_record.npy'))
            train_record = list(np.load(args.weights_dir + 'train_record.npy'))

        logging.info('epoch %d lr %e', epoch, learning_rate)

        # training
        train_acc_top1, train_acc_top5, train_obj = train(train_loader, model, criterion, optimizer, learning_rate)
        logging.info('train_acc %f', train_acc_top1)
        train_record.append([train_acc_top1, train_acc_top5])
        np.save(args.weights_dir + 'train_record.npy', train_record)

        # test
        test_acc_top1, test_acc_top5, test_obj = infer(test_loader, model, criterion)
        is_best = test_acc_top1 > best_top1
        if is_best:
            best_top1 = test_acc_top1

        logging.info('test_acc %f', test_acc_top1)
        test_record.append([test_acc_top1, test_acc_top5])
        np.save(args.weights_dir + 'test_record.npy', test_record)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_top1': best_top1,
            'step_idx': step_idx,
            'learning_rate': learning_rate,
        }, args, is_best)

        step_idx, learning_rate = utils.adjust_learning_rate(args, epoch, step_idx,
                                                             learning_rate)

        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate


def train(train_queue, model, criterion, optimizer, lr):

    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    model.train()

    for step, (input, target) in enumerate(train_queue):

        n = input.size(0)
        input = input.cuda()
        target = target.cuda()

        logits = model(input)
        loss = criterion(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, top5.avg, objs.avg


def infer(valid_queue, model, criterion):
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda()

            logits = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, top5.avg, objs.avg


if __name__ == '__main__':
    utils.create_folder(args)
    main()
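finetune.py (above) appends the per-epoch accuracies to `.npy` record files under `--weights_dir`; a small sketch for inspecting them afterwards (assumes the default `./weights/` directory):

```
import numpy as np

# each row is [top1, top5] for one epoch, as appended in main()
test_record = np.load('./weights/test_record.npy')
print('best top-1: %.3f' % test_record[:, 0].max())
```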
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from new_layers import self_conv, Q_A


def conv3x3(in_planes, out_planes, bitW, stride=1):
    "3x3 convolution with padding"
    return self_conv(in_planes, out_planes, bitW, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, num_bases, inplanes, planes, bitW, bitA, stride=1, downsample=None, add_gate=True):
        super(BasicBlock, self).__init__()
        self.bitW = bitW
        self.bitA = bitA
        self.num_bases = num_bases
        self.add_gate = add_gate
        self.relu = nn.ReLU()
        # one group of binary convolutions per base
        self.conv1 = nn.ModuleList([conv3x3(inplanes, planes, bitW, stride) for i in range(num_bases)])
        self.bn1 = nn.ModuleList([nn.BatchNorm2d(planes) for i in range(num_bases)])
        self.conv2 = nn.ModuleList([conv3x3(planes, planes, bitW) for i in range(num_bases)])
        self.bn2 = nn.ModuleList([nn.BatchNorm2d(planes) for i in range(num_bases)])
        self.downsample = downsample
        if add_gate:
            self.block_gate = nn.Parameter(torch.rand(1).cuda(), requires_grad=True)

    def quan_activations(self, x, bitA):
        if bitA == 32:
            return nn.Tanh()(x)
        else:
            return Q_A.apply(x)

    def forward(self, input_bases, input_mean):

        final_output = None
        output_bases = []

        if self.add_gate:

            gate = torch.sigmoid(self.block_gate)

            for base, conv1, conv2, bn1, bn2 in zip(input_bases, self.conv1, self.conv2, self.bn1, self.bn2):

                # soft gate: blend this base's activation with the mean over all bases
                x = gate * base + (1.0 - gate) * input_mean

                if self.downsample is not None:
                    x = self.quan_activations(x, self.bitA)
                    residual = self.downsample(x)
                else:
                    residual = x
                    x = self.quan_activations(x, self.bitA)

                out = conv1(x)
                out = self.relu(out)
                out = bn1(out)
                out += residual

                out_new = self.quan_activations(out, self.bitA)
                out_new = conv2(out_new)
                out_new = self.relu(out_new)
                out_new = bn2(out_new)
                out_new += out

                output_bases.append(out_new)

                if final_output is None:
                    final_output = out_new
                else:
                    final_output += out_new

        else:

            for conv1, conv2, bn1, bn2 in zip(self.conv1, self.conv2, self.bn1, self.bn2):

                if self.downsample is not None:
                    x = self.quan_activations(input_mean, self.bitA)
                    residual = self.downsample(x)
                else:
                    residual = input_mean
                    x = self.quan_activations(input_mean, self.bitA)

                out = conv1(x)
                out = self.relu(out)
                out = bn1(out)
                out += residual

                out_new = self.quan_activations(out, self.bitA)
                out_new = conv2(out_new)
                out_new = self.relu(out_new)
                out_new = bn2(out_new)
                out_new += out

                output_bases.append(out_new)

                if final_output is None:
                    final_output = out_new
                else:
                    final_output += out_new

        # the aggregate fed to the next block is the plain mean over the bases
        return output_bases, final_output / self.num_bases


class downsample_layer(nn.Module):
    def __init__(self, inplanes, planes, bitW, kernel_size=1, stride=1, bias=False):
        super(downsample_layer, self).__init__()
        self.conv = self_conv(inplanes, planes, bitW, kernel_size=kernel_size, stride=stride, bias=False)
        self.batch_norm = nn.BatchNorm2d(planes)

    def forward(self, x):
        x = self.conv(x)
        x = self.batch_norm(x)
        return x


class ResNet(nn.Module):

    def __init__(self, block, layers, bitW, bitA, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.num_bases = 5
        self.bitW = bitW
        self.bitA = bitA
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], add_gate=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # don't quantize the last layer
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1, add_gate=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = downsample_layer(self.inplanes, planes * block.expansion, self.bitW,
                                          kernel_size=1, stride=stride, bias=False)

        layers = nn.ModuleList([])
        layers.append(block(self.num_bases, self.inplanes, planes, self.bitW, self.bitA, stride, downsample, add_gate))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.num_bases, self.inplanes, planes, self.bitW, self.bitA))

        return layers

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.bn1(x)

        # each layer passes along both the per-base outputs and their mean
        sep_out = None
        sum_out = x
        for layer in self.layer1:
            sep_out, sum_out = layer(sep_out, sum_out)

        for layer in self.layer2:
            sep_out, sum_out = layer(sep_out, sum_out)

        for layer in self.layer3:
            sep_out, sum_out = layer(sep_out, sum_out)

        for layer in self.layer4:
            sep_out, sum_out = layer(sep_out, sum_out)

        out = self.avgpool(sum_out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out

def resnet18(bitW, bitA, pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, initializes from a local full-precision checkpoint.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], bitW, bitA, **kwargs)
    if pretrained:
        load_dict = torch.load('./full_precision_records/weights/model_best.pth.tar')['state_dict']
        model_dict = model.state_dict()
        model_keys = model_dict.keys()
        for name, param in load_dict.items():
            # strip the 'module.' prefix added by nn.DataParallel
            if name.replace('module.', '') in model_keys:
                model_dict[name.replace('module.', '')] = param
        model.load_state_dict(model_dict)
    return model


def resnet34(bitW, bitA, pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, initializes from a local full-precision checkpoint.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], bitW, bitA, **kwargs)
    if pretrained:
        load_dict = torch.load('./full_precision_records/weights/model_best.pth.tar')['state_dict']
        model_dict = model.state_dict()
        model_keys = model_dict.keys()
        for name, param in load_dict.items():
            if name.replace('module.', '') in model_keys:
                model_dict[name.replace('module.', '')] = param
        model.load_state_dict(model_dict)
    return model


def resnet50(bitW, bitA, pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Note: the Bottleneck block used by ResNet-50 is not defined in this
    repository, so this constructor is a stub.
    """
    raise NotImplementedError('Bottleneck blocks are not implemented.')
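A quick smoke test of the binary model defined above (a sketch; a CUDA device is required because the block gates are created directly on the GPU):

```
import torch
from model import resnet18

# 1-bit weights and activations; num_bases is fixed to 5 inside ResNet.__init__
model = resnet18(bitW=1, bitA=1, pretrained=False).cuda()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224).cuda())
print(logits.shape)  # torch.Size([2, 1000])
```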
--------------------------------------------------------------------------------
/model_ablation_1.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from new_layers import self_conv, Q_A


def conv3x3(in_planes, out_planes, bitW, stride=1):
    "3x3 convolution with padding"
    return self_conv(in_planes, out_planes, bitW, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, num_bases, inplanes, planes, bitW, bitA, stride=1, downsample=None, add_gate=True):
        super(BasicBlock, self).__init__()
        self.bitW = bitW
        self.bitA = bitA
        self.num_bases = num_bases
        self.add_gate = add_gate
        self.relu = nn.ReLU()
        self.conv1 = nn.ModuleList([conv3x3(inplanes, planes, bitW, stride) for i in range(num_bases)])
        self.bn1 = nn.ModuleList([nn.BatchNorm2d(planes) for i in range(num_bases)])
        self.conv2 = nn.ModuleList([conv3x3(planes, planes, bitW) for i in range(num_bases)])
        self.bn2 = nn.ModuleList([nn.BatchNorm2d(planes) for i in range(num_bases)])
        self.downsample = downsample
        # one learnt scale per base, applied when the base outputs are aggregated
        self.scales = nn.ParameterList([nn.Parameter(torch.rand(1), requires_grad=True) for i in range(num_bases)])
        if add_gate:
            self.block_gate = nn.Parameter(torch.rand(1), requires_grad=True)

    def quan_activations(self, x, bitA):
        if bitA == 32:
            return nn.Tanh()(x)
        else:
            return Q_A.apply(x)

    def forward(self, input_bases, input_mean):

        final_output = None
        output_bases = []

        if self.add_gate:

            gate = torch.sigmoid(self.block_gate)

            for base, conv1, conv2, bn1, bn2, scale in zip(input_bases, self.conv1, self.conv2, self.bn1, self.bn2, self.scales):

                # soft gate: blend this base's activation with the mean over all bases
                x = gate * base + (1.0 - gate) * input_mean

                if self.downsample is not None:
                    x = self.quan_activations(x, self.bitA)
                    residual = self.downsample(x)
                else:
                    residual = x
                    x = self.quan_activations(x, self.bitA)

                out = conv1(x)
                out = self.relu(out)
                out = bn1(out)
                out += residual

                out_new = self.quan_activations(out, self.bitA)
                out_new = conv2(out_new)
                out_new = self.relu(out_new)
                out_new = bn2(out_new)
                out_new += out

                output_bases.append(out_new)

                if final_output is None:
                    final_output = scale * out_new
                else:
                    final_output += scale * out_new

        else:

            if self.downsample is not None:
                x = self.quan_activations(input_mean, self.bitA)
                residual = self.downsample(x)
            else:
                residual = input_mean
                x = self.quan_activations(input_mean, self.bitA)

            for conv1, conv2, bn1, bn2, scale in zip(self.conv1, self.conv2, self.bn1, self.bn2, self.scales):

                out = conv1(x)
                out = self.relu(out)
                out = bn1(out)
                out += residual

                out_new = self.quan_activations(out, self.bitA)
                out_new = conv2(out_new)
                out_new = self.relu(out_new)
                out_new = bn2(out_new)
                out_new += out

                output_bases.append(out_new)

                if final_output is None:
                    final_output = scale * out_new
                else:
                    final_output += scale * out_new

        # unlike model.py, the aggregate is a scale-weighted sum rather than a plain mean
        return output_bases, final_output


class downsample_layer(nn.Module):
    def __init__(self, inplanes, planes, bitW, kernel_size=1, stride=1, bias=False):
        super(downsample_layer, self).__init__()
        self.conv = self_conv(inplanes, planes, bitW, kernel_size=kernel_size, stride=stride, bias=False)
        self.batch_norm = nn.BatchNorm2d(planes)

    def forward(self, x):
        x = self.conv(x)
        x = self.batch_norm(x)
        return x


class ResNet(nn.Module):

    def __init__(self, block, layers, bitW, bitA, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.num_bases = 5
        self.bitW = bitW
        self.bitA = bitA
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=True)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], add_gate=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # don't quantize the last layer
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1, add_gate=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = downsample_layer(self.inplanes, planes * block.expansion, self.bitW,
                                          kernel_size=1, stride=stride, bias=False)

        layers = nn.ModuleList([])
        layers.append(block(self.num_bases, self.inplanes, planes, self.bitW, self.bitA, stride, downsample, add_gate))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.num_bases, self.inplanes, planes, self.bitW, self.bitA))

        return layers

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.bn1(x)

        sep_out = None
        sum_out = x
        for layer in self.layer1:
            sep_out, sum_out = layer(sep_out, sum_out)

        for layer in self.layer2:
            sep_out, sum_out = layer(sep_out, sum_out)

        for layer in self.layer3:
            sep_out, sum_out = layer(sep_out, sum_out)

        for layer in self.layer4:
            sep_out, sum_out = layer(sep_out, sum_out)

        out = self.avgpool(sum_out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out


def resnet18(bitW, bitA, pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, initializes from a local full-precision checkpoint.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], bitW, bitA, **kwargs)
    if pretrained:
        load_dict = torch.load('./full_precision_records/model_best.pth.tar')['state_dict']
        model_dict = model.state_dict()
        model_keys = model_dict.keys()
        for name, param in load_dict.items():
            # strip the 'module.' prefix added by nn.DataParallel
            if name.replace('module.', '') in model_keys:
                model_dict[name.replace('module.', '')] = param
        model.load_state_dict(model_dict)
    return model


def resnet34(bitW, bitA, pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, initializes from a local full-precision checkpoint.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], bitW, bitA, **kwargs)
    if pretrained:
        # load the same full-precision checkpoint as resnet18 above
        load_dict = torch.load('./full_precision_records/model_best.pth.tar')['state_dict']
        model_dict = model.state_dict()
        model_keys = model_dict.keys()
        for name, param in load_dict.items():
            if name.replace('module.', '') in model_keys:
                model_dict[name.replace('module.', '')] = param
        model.load_state_dict(model_dict)
    return model


def resnet50(bitW, bitA, pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Note: the Bottleneck block used by ResNet-50 is not defined in this
    repository, so this constructor is a stub.
    """
    raise NotImplementedError('Bottleneck blocks are not implemented.')
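The soft gate used in `model.py` and `model_ablation_1.py` blends each base's own activation with the mean over all bases before quantization; the mixing, written out on dummy tensors:

```
import torch

block_gate = torch.rand(1)                # learnable scalar, one per block
base = torch.randn(2, 64, 56, 56)         # this base's previous output
input_mean = torch.randn(2, 64, 56, 56)   # average over all bases

g = torch.sigmoid(block_gate)
x = g * base + (1.0 - g) * input_mean     # input fed to the binary convolutions
```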
--------------------------------------------------------------------------------
/model_ablation_2.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from new_layers import self_conv, Q_A


def conv3x3(in_planes, out_planes, bitW, stride=1):
    "3x3 convolution with padding"
    return self_conv(in_planes, out_planes, bitW, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, num_bases, inplanes, planes, bitW, bitA, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bitW = bitW
        self.bitA = bitA
        self.num_bases = num_bases
        self.relu = nn.ReLU()
        self.conv1 = nn.ModuleList([conv3x3(inplanes, planes, bitW, stride) for i in range(num_bases)])
        self.bn1 = nn.ModuleList([nn.BatchNorm2d(planes) for i in range(num_bases)])
        self.conv2 = nn.ModuleList([conv3x3(planes, planes, bitW) for i in range(num_bases)])
        self.bn2 = nn.ModuleList([nn.BatchNorm2d(planes) for i in range(num_bases)])
        self.downsample = downsample
        # one learnt scale per base; no soft gates in this ablation
        self.scales = nn.ParameterList([nn.Parameter(torch.rand(1), requires_grad=True) for i in range(num_bases)])

    def quan_activations(self, x, bitA):
        if bitA == 32:
            return nn.Tanh()(x)
        else:
            return Q_A.apply(x)

    def forward(self, x):

        final_output = None
        if self.downsample is not None:
            x = self.quan_activations(x, self.bitA)
            residual = self.downsample(x)
        else:
            residual = x
            x = self.quan_activations(x, self.bitA)

        for conv1, conv2, bn1, bn2, scale in zip(self.conv1, self.conv2, self.bn1, self.bn2, self.scales):

            out = conv1(x)
            out = self.relu(out)
            out = bn1(out)
            out += residual

            out_new = self.quan_activations(out, self.bitA)
            out_new = conv2(out_new)
            out_new = self.relu(out_new)
            out_new = bn2(out_new)
            out_new += out

            if final_output is None:
                final_output = scale * out_new
            else:
                final_output += scale * out_new

        return final_output


class downsample_layer(nn.Module):
    def __init__(self, inplanes, planes, bitW, kernel_size=1, stride=1, bias=False):
        super(downsample_layer, self).__init__()
        self.conv = self_conv(inplanes, planes, bitW, kernel_size=kernel_size, stride=stride, bias=False)
        self.batch_norm = nn.BatchNorm2d(planes)

    def forward(self, x):
        x = self.conv(x)
        x = self.batch_norm(x)
        return x


class ResNet(nn.Module):

    def __init__(self, block, layers, bitW, bitA, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.num_bases = 5
        self.bitW = bitW
        self.bitA = bitA
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=True)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # don't quantize the last layer
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = downsample_layer(self.inplanes, planes * block.expansion, self.bitW,
                                          kernel_size=1, stride=stride, bias=False)

        layers = []
        layers.append(block(self.num_bases, self.inplanes, planes, self.bitW, self.bitA, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.num_bases, self.inplanes, planes, self.bitW, self.bitA))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.bn1(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        x4 = self.avgpool(x4)
        x4 = x4.view(x4.size(0), -1)
        x5 = self.fc(x4)

        return x5

def resnet18(bitW, bitA, pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, initializes from a local full-precision checkpoint.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], bitW, bitA, **kwargs)
    if pretrained:
        load_dict = torch.load('./full_precision_weights/model_best.pth.tar')['state_dict']
        model_dict = model.state_dict()
        model_keys = model_dict.keys()
        for name, param in load_dict.items():
            # strip the 'module.' prefix added by nn.DataParallel
            if name.replace('module.', '') in model_keys:
                model_dict[name.replace('module.', '')] = param
        model.load_state_dict(model_dict)
    return model


def resnet34(bitW, bitA, pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, initializes from a local full-precision checkpoint.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], bitW, bitA, **kwargs)
    if pretrained:
        # load the same full-precision checkpoint as resnet18 above
        load_dict = torch.load('./full_precision_weights/model_best.pth.tar')['state_dict']
        model_dict = model.state_dict()
        model_keys = model_dict.keys()
        for name, param in load_dict.items():
            if name.replace('module.', '') in model_keys:
                model_dict[name.replace('module.', '')] = param
        model.load_state_dict(model_dict)
    return model


def resnet50(bitW, bitA, pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Note: the Bottleneck block used by ResNet-50 is not defined in this
    repository, so this constructor is a stub.
    """
    raise NotImplementedError('Bottleneck blocks are not implemented.')
--------------------------------------------------------------------------------
/new_layers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class Q_A(torch.autograd.Function):  # DoReFa-style, but constrained to {-1, 1}
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # clip the gradient outside [-1, 1] (>= so the boundary x = 1 is clipped too)
        grad_input.masked_fill_(input >= 1.0, 0.0)
        grad_input.masked_fill_(input < -1.0, 0.0)
        # piecewise-polynomial straight-through estimator: d/dx ~ 2 - 2|x| inside [-1, 1]
        mask_pos = (input >= 0.0) & (input < 1.0)
        mask_neg = (input < 0.0) & (input >= -1.0)
        grad_input.masked_scatter_(mask_pos, input[mask_pos].mul_(-2.0).add_(2.0))
        grad_input.masked_scatter_(mask_neg, input[mask_neg].mul_(2.0).add_(2.0))
        return grad_input * grad_output


class Q_W(torch.autograd.Function):  # XNOR-Net-style, but the gradient uses the identity approximation
    @staticmethod
    def forward(ctx, x):
        # binarize with a per-tensor scaling factor: sign(x) * mean(|x|)
        return x.sign() * x.abs().mean()

    @staticmethod
    def backward(ctx, grad):
        return grad


def quantize_a(x):
    return Q_A.apply(x)


def quantize_w(x):
    return Q_W.apply(x)


def fw(x, bitW):
    if bitW == 32:
        return x
    return quantize_w(x)


def fa(x, bitA):
    if bitA == 32:
        return x
    return quantize_a(x)


def nonlinear(x):
    return torch.clamp(x, min=0.0, max=1.0)


class self_conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, bitW, kernel_size, stride=1, padding=0, bias=False):
        super(self_conv, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias)
        self.bitW = bitW
        self.padding = padding
        self.stride = stride

    def forward(self, input):
        if self.padding > 0:
            # pad with the constant 1 so the border matches the binary activations
            padding_shape = (self.padding, self.padding, self.padding, self.padding)
            input = F.pad(input, padding_shape, 'constant', 1)
        # the input is already padded above, so no padding is passed to conv2d
        output = F.conv2d(input, fw(self.weight, self.bitW), bias=self.bias, stride=self.stride,
                          dilation=self.dilation, groups=self.groups)
        return output


class new_conv(nn.Module):
    def __init__(self, num_bases, bitW, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super(new_conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias
        self.convs = nn.ModuleList([self_conv(self.in_channels, self.out_channels, bitW, self.kernel_size, self.stride, bias=self.bias) for i in range(num_bases)])
        self.scales = nn.ParameterList([nn.Parameter(torch.rand(1).cuda(), requires_grad=True) for i in range(num_bases)])

    def forward(self, input):
        output = None
        if self.padding > 0:
            padding_shape = (self.padding, self.padding, self.padding, self.padding)
            input = F.pad(input, padding_shape, 'constant', 1)  # pad with 1

        # scale-weighted sum of the binary bases
        for scale, module in zip(self.scales, self.convs):
            if output is None:
                output = scale * module(input)
            else:
                output += scale * module(input)
        return output


class clip_nonlinear(nn.Module):
    def __init__(self, bitA):
        super(clip_nonlinear, self).__init__()
        self.bitA = bitA

    def forward(self, input):
        return fa(input, self.bitA)
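A small check of the two quantizers (a sketch): `Q_A` binarizes activations to {-1, 1} with the piecewise-polynomial straight-through gradient, and `Q_W` binarizes weights XNOR-style with a per-tensor scale:

```
import torch
from new_layers import Q_A, Q_W

x = torch.tensor([-1.5, -0.5, 0.5, 1.5], requires_grad=True)
y = Q_A.apply(x)
print(y)             # tensor([-1., -1.,  1.,  1.])
y.sum().backward()
print(x.grad)        # tensor([0., 1., 1., 0.]) -> 2 - 2|x| inside [-1, 1], 0 outside

w = torch.tensor([[0.3, -0.7], [0.1, -0.9]])
print(Q_W.apply(w))  # sign(w) * mean(|w|): +/- 0.5 here
```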
--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
import os
import sys
import math
import logging
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets

import utils
from utils import adjust_learning_rate, save_checkpoint
from new_layers import self_conv
from model import resnet18


parser = argparse.ArgumentParser("ImageNet")
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.05, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay')
parser.add_argument('--report_freq', type=int, default=100, help='report frequency')
parser.add_argument('--epochs', type=int, default=40, help='num of training epochs')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=10, help='gradient clipping')
parser.add_argument('--resume_train', action='store_true', default=False, help='resume training')
parser.add_argument('--resume_dir', type=str, default='./weights/checkpoint.pth.tar', help='checkpoint to resume from')
parser.add_argument('--load_epoch', type=int, default=30, help='epoch to load when resuming')
parser.add_argument('--weights_dir', type=str, default='./weights/', help='save weights directory')
parser.add_argument('--learning_step', type=int, nargs='+', default=[25, 35, 40], help='learning rate decay steps')


args = parser.parse_args()

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
if not os.path.exists(args.save):
    os.makedirs(args.save)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)


def main():

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Image Preprocessing
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,])

    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,])

    num_epochs = args.epochs
    batch_size = args.batch_size

    train_dataset = datasets.ImageFolder(root='/fast/users/a1675776/data/imagenet/train/', transform=train_transform)
    test_dataset = datasets.ImageFolder(root='/fast/users/a1675776/data/imagenet/val/', transform=test_transform)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True, num_workers=10, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False, num_workers=10, pin_memory=True)

    num_train = len(train_dataset)
    n_train_batches = math.floor(num_train / batch_size)

    criterion = nn.CrossEntropyLoss().cuda()
    # full precision for pretraining: Tanh() is used in place of the binary quantizer
    bitW = 32
    bitA = 32
    model = resnet18(bitW, bitA)
    model = utils.dataparallel(model, 3)

    print("Compilation complete, starting training...")

    test_record = []
    train_record = []
    learning_rate = args.learning_rate
    epoch = 0
    step_idx = 0
    best_top1 = 0

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=args.weight_decay, momentum=args.momentum)

    # weight initialization
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, self_conv):
            torch.nn.init.xavier_uniform_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1.0)

    while epoch < num_epochs:

        epoch = epoch + 1
        # resume training
        if (args.resume_train) and (epoch == 1):
            checkpoint = torch.load(args.resume_dir)
            epoch = checkpoint['epoch']
            learning_rate = checkpoint['learning_rate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            step_idx = checkpoint['step_idx']
            model.load_state_dict(checkpoint['state_dict'])
            test_record = list(np.load(args.weights_dir + 'test_record.npy'))
            train_record = list(np.load(args.weights_dir + 'train_record.npy'))

        logging.info('epoch %d lr %e', epoch, learning_rate)

        # training
        train_acc_top1, train_acc_top5, train_obj = train(train_loader, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc_top1)
        train_record.append([train_acc_top1, train_acc_top5])
        np.save(args.weights_dir + 'train_record.npy', train_record)

        # test
        test_acc_top1, test_acc_top5, test_obj = infer(test_loader, model, criterion)
        is_best = test_acc_top1 > best_top1
        if is_best:
            best_top1 = test_acc_top1

        logging.info('test_acc %f', test_acc_top1)
        test_record.append([test_acc_top1, test_acc_top5])
        np.save(args.weights_dir + 'test_record.npy', test_record)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_top1': best_top1,
            'step_idx': step_idx,
            'learning_rate': learning_rate,
        }, args, is_best)

        step_idx, learning_rate = utils.adjust_learning_rate(args, epoch, step_idx,
                                                             learning_rate)

        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate


def train(train_queue, model, criterion, optimizer):

    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    model.train()

    for step, (input, target) in enumerate(train_queue):

        n = input.size(0)
        input = input.cuda()
        target = target.cuda()

        logits = model(input)
        loss = criterion(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, top5.avg, objs.avg


def infer(valid_queue, model, criterion):
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda()

            logits = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, top5.avg, objs.avg


if __name__ == '__main__':
    utils.create_folder(args)
    main()
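`read_data.MyDataset` (below) reads a whitespace-separated image list and resolves paths relative to the hard-coded `/data/val` root; a hedged example of the expected list format and usage (the list file name is hypothetical):

```
# flist.txt -- one "<relative_image_path> <integer_label>" per line, e.g.:
#   ILSVRC2012_val_00000001.JPEG 65
#   ILSVRC2012_val_00000002.JPEG 970
from torchvision import transforms
from read_data import MyDataset

dataset = MyDataset('flist.txt', transform=transforms.ToTensor())
img, target = dataset[0]
```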
--------------------------------------------------------------------------------
/read_data.py:
--------------------------------------------------------------------------------
import os

import torch.utils.data as data
from PIL import Image


def img_loader(path):
    return Image.open(path).convert('RGB')


def list_reader(flist):
    """Read an image list: one '<path> <label>' pair per line."""
    imlist = []
    with open(flist, 'r') as rf:
        for line in rf.readlines():
            impath, imlabel = line.split()
            imlist.append((impath, int(imlabel)))
    return imlist


class MyDataset(data.Dataset):
    def __init__(self, flist, transform=None):
        self.imlist = list_reader(flist)
        self.transform = transform

    def __getitem__(self, idx):
        impath, target = self.imlist[idx]
        img = img_loader(os.path.join('/data/val', impath))
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.imlist)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import pickle
import shutil

import torch
import torch.nn as nn
from torchvision import transforms


def data_transforms_cifar100():
    CIFAR_MEAN = [0.5071, 0.4867, 0.4408]
    CIFAR_STD = [0.2675, 0.2565, 0.2761]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform


def unpickle(file):
    # CIFAR-style pickle files; bytes encoding is needed under Python 3
    with open(file, 'rb') as fo:
        data_dict = pickle.load(fo, encoding='bytes')
    return data_dict


def unpack_data(config):
    train_file = config['train_file']
    test_file = config['test_file']
    train_data = unpickle(train_file)
    test_data = unpickle(test_file)
    return train_data, test_data


def adjust_learning_rate(args, epoch, step_idx, learning_rate):
    # decay the learning rate by 10x at each epoch listed in args.learning_step
    if step_idx < len(args.learning_step) and epoch == args.learning_step[step_idx]:
        learning_rate = learning_rate * 0.1
        step_idx += 1
    return step_idx, learning_rate


def create_folder(args):
    if not os.path.exists(args.weights_dir):
        os.makedirs(args.weights_dir)
        print("Created folder: " + args.weights_dir)


class AvgrageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.mkdir(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)


def dataparallel(model, ngpus, gpu0=0):
    if ngpus == 0:
        assert False, "only support gpu mode"
    gpu_list = list(range(gpu0, gpu0 + ngpus))
    assert torch.cuda.device_count() >= gpu0 + ngpus
    if ngpus > 1:
        if not isinstance(model, nn.DataParallel):
            model = nn.DataParallel(model, gpu_list).cuda()
    else:
        model = model.cuda()
    return model

def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, os.path.join(args.weights_dir, filename))
    if is_best:
        shutil.copy(os.path.join(args.weights_dir, filename),
                    os.path.join(args.weights_dir, 'model_best.pth.tar'))
--------------------------------------------------------------------------------
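To evaluate one of the released checkpoints linked in the README, the `'module.'` prefix that `nn.DataParallel` adds must be stripped, mirroring the loader in `model.py` (a sketch; it assumes the file stores a `'state_dict'` entry the way `save_checkpoint` writes it):

```
import torch
from model import resnet18

model = resnet18(bitW=1, bitA=1).cuda()
checkpoint = torch.load('resnet18_without_scales_with_softgates.pth.tar')
state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()
```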