├── README.md ├── mobilenet ├── 1_step1 │ ├── reactnet.py │ ├── run.sh │ └── train.py └── 2_step2 │ ├── reactnet.py │ ├── run.sh │ └── train.py ├── resnet ├── 1_step1 │ ├── birealnet.py │ ├── run.sh │ └── train.py └── 2_step2 │ ├── birealnet.py │ ├── run.sh │ └── train.py └── utils ├── KD_loss.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # ReActNet 2 | 3 | This is the pytorch implementation of our paper ["ReActNet: Towards Precise Binary NeuralNetwork with Generalized Activation Functions"](https://arxiv.org/abs/2003.03488), published in ECCV 2020. 4 | 5 |
6 | 7 |
8 | 9 | In this paper, we propose to generalize the traditional Sign and PReLU functions to RSign and RPReLU, which enable explicit learning of the distribution reshape and shift at near-zero extra cost. By adding simple learnable bias, ReActNet achieves 69.4% top-1 accuracy on Imagenet dataset with both weights and activations being binary, a near ResNet-level accuracy. 10 | 11 | ## Citation 12 | 13 | If you find our code useful for your research, please consider citing: 14 | 15 | @inproceedings{liu2020reactnet, 16 | title={ReActNet: Towards Precise Binary Neural Network with Generalized Activation Functions}, 17 | author={Liu, Zechun and Shen, Zhiqiang and Savvides, Marios and Cheng, Kwang-Ting}, 18 | booktitle={European Conference on Computer Vision (ECCV)}, 19 | year={2020} 20 | } 21 | 22 | ## Run 23 | 24 | ### 1. Requirements: 25 | * python3, pytorch 1.4.0, torchvision 0.5.0 26 | 27 | ### 2. Data: 28 | * Download ImageNet dataset 29 | 30 | ### 3. Steps to run: 31 | (1) Step1: binarizing activations 32 | * Change directory to `./resnet/1_step1/` or `./mobilenet/1_step1/` 33 | * run `bash run.sh` 34 | 35 | (2) Step2: binarizing weights + activations 36 | * Change directory to `./resnet/2_step2/` or `./mobilenet/2_step2/` 37 | * run `bash run.sh` 38 | 39 | 40 | ## Models 41 | 42 | | Methods | Top1-Acc | FLOPs | Trained Model | 43 | | --- | --- | --- | --- | 44 | | XNOR-Net | 51.2% | 1.67 x 10^8 | - | 45 | | Bi-Real Net| 56.4% | 1.63 x 10^8 | - | 46 | | Real-to-Binary| 65.4% | 1.83 x 10^8 | - | 47 | | ReActNet (Bi-Real based) | 65.9% | 1.63 x 10^8 | [Model-ReAct-ResNet](https://hkustconnect-my.sharepoint.com/:u:/g/personal/zliubq_connect_ust_hk/EY9P7mxs-8BLkTlqMZif4s4BnNWcKbUnvqeA_CvN3c9q4w?e=IpUyF4) | 48 | | ReActNet-A | 69.5% | 0.87 x 10^8 | [Model-ReAct-MobileNet](https://hkustconnect-my.sharepoint.com/:u:/g/personal/zliubq_connect_ust_hk/EW1FVkAKN5dJg1ns_CcMtQoBJAy1Yxx-b7lpaTFjTJIUKw?e=oyebWy) | 49 | 50 | ## Contact 51 | 52 | Zechun Liu, HKUST (zliubq at connect.ust.hk) 53 | 54 | Zhiqiang Shen, CMU (zhiqians at andrew.cmu.edu) 55 | -------------------------------------------------------------------------------- /mobilenet/1_step1/reactnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | def conv1x1(in_planes, out_planes, stride=1): 16 | """1x1 convolution""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 18 | 19 | class firstconv3x3(nn.Module): 20 | def __init__(self, inp, oup, stride): 21 | super(firstconv3x3, self).__init__() 22 | 23 | self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False) 24 | self.bn1 = nn.BatchNorm2d(oup) 25 | 26 | def forward(self, x): 27 | 28 | out = self.conv1(x) 29 | out = self.bn1(out) 30 | 31 | return out 32 | 33 | class BinaryActivation(nn.Module): 34 | def __init__(self): 35 | super(BinaryActivation, self).__init__() 36 | 37 | def forward(self, x): 38 | out_forward = torch.sign(x) 39 | mask1 = x < -1 40 | mask2 = x < 0 41 | mask3 = x < 1 42 | out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32)) 43 | out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32)) 44 | out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32)) 45 | out = out_forward.detach() - out3.detach() + out3 46 | 47 | return out 48 | 49 | class LearnableBias(nn.Module): 50 | def __init__(self, out_chn): 51 | super(LearnableBias, self).__init__() 52 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True) 53 | 54 | def forward(self, x): 55 | out = x + self.bias.expand_as(x) 56 | return out 57 | 58 | class BasicBlock(nn.Module): 59 | def __init__(self, inplanes, planes, stride=1): 60 | super(BasicBlock, self).__init__() 61 | norm_layer = nn.BatchNorm2d 62 | 63 | self.move11 = LearnableBias(inplanes) 64 | self.binary_3x3= conv3x3(inplanes, inplanes, stride=stride) 65 | self.bn1 = norm_layer(inplanes) 66 | 67 | self.move12 = LearnableBias(inplanes) 68 | self.prelu1 = nn.PReLU(inplanes) 69 | self.move13 = LearnableBias(inplanes) 70 | 71 | self.move21 = LearnableBias(inplanes) 72 | 73 | if inplanes == planes: 74 | self.binary_pw = conv1x1(inplanes, planes) 75 | self.bn2 = norm_layer(planes) 76 | else: 77 | self.binary_pw_down1 = conv1x1(inplanes, inplanes) 78 | self.binary_pw_down2 = conv1x1(inplanes, inplanes) 79 | self.bn2_1 = norm_layer(inplanes) 80 | self.bn2_2 = norm_layer(inplanes) 81 | 82 | self.move22 = LearnableBias(planes) 83 | self.prelu2 = nn.PReLU(planes) 84 | self.move23 = LearnableBias(planes) 85 | 86 | self.binary_activation = BinaryActivation() 87 | self.stride = stride 88 | self.inplanes = inplanes 89 | self.planes = planes 90 | 91 | if self.inplanes != self.planes: 92 | self.pooling = nn.AvgPool2d(2,2) 93 | 94 | def forward(self, x): 95 | 96 | out1 = self.move11(x) 97 | 98 | out1 = self.binary_activation(out1) 99 | out1 = self.binary_3x3(out1) 100 | out1 = self.bn1(out1) 101 | 102 | if self.stride == 2: 103 | x = self.pooling(x) 104 | 105 | out1 = x + out1 106 | 107 | out1 = self.move12(out1) 108 | out1 = self.prelu1(out1) 109 | out1 = self.move13(out1) 110 | 111 | out2 = self.move21(out1) 112 | out2 = self.binary_activation(out2) 113 | 114 | if self.inplanes == self.planes: 115 | out2 = self.binary_pw(out2) 116 | out2 = self.bn2(out2) 117 | out2 += out1 118 | 119 | else: 120 | assert self.planes == self.inplanes * 2 121 | 122 | out2_1 = self.binary_pw_down1(out2) 123 | out2_2 = self.binary_pw_down2(out2) 124 | out2_1 = self.bn2_1(out2_1) 125 | out2_2 = self.bn2_2(out2_2) 126 | out2_1 += out1 127 | out2_2 += out1 128 | out2 = torch.cat([out2_1, out2_2], dim=1) 129 | 130 | out2 = self.move22(out2) 131 | out2 = self.prelu2(out2) 132 | out2 = self.move23(out2) 133 | 134 | return out2 135 | 136 | 137 | class reactnet(nn.Module): 138 | def __init__(self, num_classes=1000): 139 | super(reactnet, self).__init__() 140 | self.feature = nn.ModuleList() 141 | for i in range(len(stage_out_channel)): 142 | if i == 0: 143 | self.feature.append(firstconv3x3(3, stage_out_channel[i], 2)) 144 | elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64: 145 | self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2)) 146 | else: 147 | self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1)) 148 | self.pool1 = nn.AdaptiveAvgPool2d(1) 149 | self.fc = nn.Linear(1024, num_classes) 150 | 151 | def forward(self, x): 152 | for i, block in enumerate(self.feature): 153 | x = block(x) 154 | 155 | x = self.pool1(x) 156 | x = x.view(x.size(0), -1) 157 | x = self.fc(x) 158 | 159 | return x 160 | 161 | 162 | 163 | 164 | 165 | 166 | -------------------------------------------------------------------------------- /mobilenet/1_step1/run.sh: -------------------------------------------------------------------------------- 1 | clear 2 | mkdir log 3 | # 128 epoch setting: larger learning rate, similar performance to 256 epoch 4 | python3 train.py --data=/datasets/imagenet --batch_size=256 --learning_rate=1.25e-3 --epochs=128 --weight_decay=1e-5 | tee -a log/training.txt 5 | # 256 epoch setting: longer training, similar performance to 128 epoch 6 | # python3 train.py --data=/datasets/imagenet --batch_size=256 --learning_rate=5e-4 --epochs=256 --weight_decay=1e-5 | tee -a log/training.txt 7 | -------------------------------------------------------------------------------- /mobilenet/1_step1/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import numpy as np 5 | import time, datetime 6 | import torch 7 | import random 8 | import logging 9 | import argparse 10 | import torch.nn as nn 11 | import torch.utils 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.utils.data.distributed 15 | 16 | sys.path.append("../../") 17 | from utils.utils import * 18 | from utils import KD_loss 19 | from torchvision import datasets, transforms 20 | from torch.autograd import Variable 21 | from reactnet import reactnet 22 | import torchvision.models as models 23 | 24 | parser = argparse.ArgumentParser("birealnet18") 25 | parser.add_argument('--batch_size', type=int, default=512, help='batch size') 26 | parser.add_argument('--epochs', type=int, default=256, help='num of training epochs') 27 | parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate') 28 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 29 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') 30 | parser.add_argument('--save', type=str, default='./models', help='path for saving trained models') 31 | parser.add_argument('--data', metavar='DIR', help='path to dataset') 32 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing') 33 | parser.add_argument('--teacher', type=str, default='resnet34', help='path of ImageNet') 34 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | args = parser.parse_args() 37 | 38 | CLASSES = 1000 39 | 40 | if not os.path.exists('log'): 41 | os.mkdir('log') 42 | 43 | log_format = '%(asctime)s %(message)s' 44 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 45 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 46 | fh = logging.FileHandler(os.path.join('log/log.txt')) 47 | fh.setFormatter(logging.Formatter(log_format)) 48 | logging.getLogger().addHandler(fh) 49 | 50 | def main(): 51 | if not torch.cuda.is_available(): 52 | sys.exit(1) 53 | start_t = time.time() 54 | 55 | cudnn.benchmark = True 56 | cudnn.enabled=True 57 | logging.info("args = %s", args) 58 | 59 | # load model 60 | model_teacher = models.__dict__[args.teacher](pretrained=True) 61 | model_teacher = nn.DataParallel(model_teacher).cuda() 62 | for p in model_teacher.parameters(): 63 | p.requires_grad = False 64 | model_teacher.eval() 65 | 66 | model_student = reactnet() 67 | logging.info('student:') 68 | logging.info(model_student) 69 | model_student = nn.DataParallel(model_student).cuda() 70 | 71 | criterion = nn.CrossEntropyLoss() 72 | criterion = criterion.cuda() 73 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth) 74 | criterion_smooth = criterion_smooth.cuda() 75 | criterion_kd = KD_loss.DistributionLoss() 76 | 77 | all_parameters = model_student.parameters() 78 | weight_parameters = [] 79 | for pname, p in model_student.named_parameters(): 80 | if p.ndimension() == 4 or 'conv' in pname: 81 | weight_parameters.append(p) 82 | weight_parameters_id = list(map(id, weight_parameters)) 83 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters)) 84 | 85 | optimizer = torch.optim.Adam( 86 | [{'params' : other_parameters}, 87 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}], 88 | lr=args.learning_rate,) 89 | 90 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/args.epochs), last_epoch=-1) 91 | start_epoch = 0 92 | best_top1_acc= 0 93 | 94 | checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar') 95 | if os.path.exists(checkpoint_tar): 96 | logging.info('loading checkpoint {} ..........'.format(checkpoint_tar)) 97 | checkpoint = torch.load(checkpoint_tar) 98 | start_epoch = checkpoint['epoch'] + 1 99 | best_top1_acc = checkpoint['best_top1_acc'] 100 | model_student.load_state_dict(checkpoint['state_dict'], strict=False) 101 | logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch'])) 102 | 103 | # adjust the learning rate according to the checkpoint 104 | for epoch in range(start_epoch): 105 | scheduler.step() 106 | 107 | # load training data 108 | traindir = os.path.join(args.data, 'train') 109 | valdir = os.path.join(args.data, 'val') 110 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 111 | std=[0.229, 0.224, 0.225]) 112 | 113 | # data augmentation 114 | crop_scale = 0.08 115 | lighting_param = 0.1 116 | train_transforms = transforms.Compose([ 117 | transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)), 118 | Lighting(lighting_param), 119 | transforms.RandomHorizontalFlip(), 120 | transforms.ToTensor(), 121 | normalize]) 122 | 123 | train_dataset = datasets.ImageFolder( 124 | traindir, 125 | transform=train_transforms) 126 | 127 | train_loader = torch.utils.data.DataLoader( 128 | train_dataset, batch_size=args.batch_size, shuffle=True, 129 | num_workers=args.workers, pin_memory=True) 130 | 131 | # load validation data 132 | val_loader = torch.utils.data.DataLoader( 133 | datasets.ImageFolder(valdir, transforms.Compose([ 134 | transforms.Resize(256), 135 | transforms.CenterCrop(224), 136 | transforms.ToTensor(), 137 | normalize, 138 | ])), 139 | batch_size=args.batch_size, shuffle=False, 140 | num_workers=args.workers, pin_memory=True) 141 | 142 | # train the model 143 | epoch = start_epoch 144 | while epoch < args.epochs: 145 | train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler) 146 | valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args) 147 | 148 | is_best = False 149 | if valid_top1_acc > best_top1_acc: 150 | best_top1_acc = valid_top1_acc 151 | is_best = True 152 | 153 | save_checkpoint({ 154 | 'epoch': epoch, 155 | 'state_dict': model_student.state_dict(), 156 | 'best_top1_acc': best_top1_acc, 157 | 'optimizer' : optimizer.state_dict(), 158 | }, is_best, args.save) 159 | 160 | epoch += 1 161 | 162 | training_time = (time.time() - start_t) / 3600 163 | print('total training time = {} hours'.format(training_time)) 164 | 165 | 166 | def train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler): 167 | batch_time = AverageMeter('Time', ':6.3f') 168 | data_time = AverageMeter('Data', ':6.3f') 169 | losses = AverageMeter('Loss', ':.4e') 170 | top1 = AverageMeter('Acc@1', ':6.2f') 171 | top5 = AverageMeter('Acc@5', ':6.2f') 172 | 173 | progress = ProgressMeter( 174 | len(train_loader), 175 | [batch_time, data_time, losses, top1, top5], 176 | prefix="Epoch: [{}]".format(epoch)) 177 | 178 | model_student.train() 179 | model_teacher.eval() 180 | end = time.time() 181 | scheduler.step() 182 | 183 | for param_group in optimizer.param_groups: 184 | cur_lr = param_group['lr'] 185 | print('learning_rate:', cur_lr) 186 | 187 | for i, (images, target) in enumerate(train_loader): 188 | data_time.update(time.time() - end) 189 | images = images.cuda() 190 | target = target.cuda() 191 | 192 | # compute outputy 193 | logits_student = model_student(images) 194 | logits_teacher = model_teacher(images) 195 | loss = criterion(logits_student, logits_teacher) 196 | 197 | # measure accuracy and record loss 198 | prec1, prec5 = accuracy(logits_student, target, topk=(1, 5)) 199 | n = images.size(0) 200 | losses.update(loss.item(), n) #accumulated loss 201 | top1.update(prec1.item(), n) 202 | top5.update(prec5.item(), n) 203 | 204 | # compute gradient and do SGD step 205 | optimizer.zero_grad() 206 | loss.backward() 207 | optimizer.step() 208 | 209 | # measure elapsed time 210 | batch_time.update(time.time() - end) 211 | end = time.time() 212 | 213 | progress.display(i) 214 | 215 | return losses.avg, top1.avg, top5.avg 216 | 217 | def validate(epoch, val_loader, model, criterion, args): 218 | batch_time = AverageMeter('Time', ':6.3f') 219 | losses = AverageMeter('Loss', ':.4e') 220 | top1 = AverageMeter('Acc@1', ':6.2f') 221 | top5 = AverageMeter('Acc@5', ':6.2f') 222 | progress = ProgressMeter( 223 | len(val_loader), 224 | [batch_time, losses, top1, top5], 225 | prefix='Test: ') 226 | 227 | # switch to evaluation mode 228 | model.eval() 229 | with torch.no_grad(): 230 | end = time.time() 231 | for i, (images, target) in enumerate(val_loader): 232 | images = images.cuda() 233 | target = target.cuda() 234 | 235 | # compute output 236 | logits = model(images) 237 | loss = criterion(logits, target) 238 | 239 | # measure accuracy and record loss 240 | pred1, pred5 = accuracy(logits, target, topk=(1, 5)) 241 | n = images.size(0) 242 | losses.update(loss.item(), n) 243 | top1.update(pred1[0], n) 244 | top5.update(pred5[0], n) 245 | 246 | # measure elapsed time 247 | batch_time.update(time.time() - end) 248 | end = time.time() 249 | 250 | progress.display(i) 251 | 252 | print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}' 253 | .format(top1=top1, top5=top5)) 254 | 255 | return losses.avg, top1.avg, top5.avg 256 | 257 | 258 | if __name__ == '__main__': 259 | main() 260 | -------------------------------------------------------------------------------- /mobilenet/2_step2/reactnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | def conv1x1(in_planes, out_planes, stride=1): 16 | """1x1 convolution""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 18 | 19 | def binaryconv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1) 22 | 23 | 24 | def binaryconv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0) 27 | 28 | class firstconv3x3(nn.Module): 29 | def __init__(self, inp, oup, stride): 30 | super(firstconv3x3, self).__init__() 31 | 32 | self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(oup) 34 | 35 | def forward(self, x): 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | 40 | return out 41 | 42 | class BinaryActivation(nn.Module): 43 | def __init__(self): 44 | super(BinaryActivation, self).__init__() 45 | 46 | def forward(self, x): 47 | out_forward = torch.sign(x) 48 | mask1 = x < -1 49 | mask2 = x < 0 50 | mask3 = x < 1 51 | out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32)) 52 | out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32)) 53 | out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32)) 54 | out = out_forward.detach() - out3.detach() + out3 55 | 56 | return out 57 | 58 | class LearnableBias(nn.Module): 59 | def __init__(self, out_chn): 60 | super(LearnableBias, self).__init__() 61 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True) 62 | 63 | def forward(self, x): 64 | out = x + self.bias.expand_as(x) 65 | return out 66 | 67 | class HardBinaryConv(nn.Module): 68 | def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1): 69 | super(HardBinaryConv, self).__init__() 70 | self.stride = stride 71 | self.padding = padding 72 | self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size 73 | self.shape = (out_chn, in_chn, kernel_size, kernel_size) 74 | self.weights = nn.Parameter(torch.rand((self.number_of_weights,1)) * 0.001, requires_grad=True) 75 | 76 | def forward(self, x): 77 | real_weights = self.weights.view(self.shape) 78 | scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True) 79 | #print(scaling_factor, flush=True) 80 | scaling_factor = scaling_factor.detach() 81 | binary_weights_no_grad = scaling_factor * torch.sign(real_weights) 82 | cliped_weights = torch.clamp(real_weights, -1.0, 1.0) 83 | binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights 84 | #print(binary_weights, flush=True) 85 | y = F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding) 86 | 87 | return y 88 | 89 | class BasicBlock(nn.Module): 90 | def __init__(self, inplanes, planes, stride=1): 91 | super(BasicBlock, self).__init__() 92 | norm_layer = nn.BatchNorm2d 93 | 94 | self.move11 = LearnableBias(inplanes) 95 | self.binary_3x3= binaryconv3x3(inplanes, inplanes, stride=stride) 96 | self.bn1 = norm_layer(inplanes) 97 | 98 | self.move12 = LearnableBias(inplanes) 99 | self.prelu1 = nn.PReLU(inplanes) 100 | self.move13 = LearnableBias(inplanes) 101 | 102 | self.move21 = LearnableBias(inplanes) 103 | 104 | if inplanes == planes: 105 | self.binary_pw = binaryconv1x1(inplanes, planes) 106 | self.bn2 = norm_layer(planes) 107 | else: 108 | self.binary_pw_down1 = binaryconv1x1(inplanes, inplanes) 109 | self.binary_pw_down2 = binaryconv1x1(inplanes, inplanes) 110 | self.bn2_1 = norm_layer(inplanes) 111 | self.bn2_2 = norm_layer(inplanes) 112 | 113 | self.move22 = LearnableBias(planes) 114 | self.prelu2 = nn.PReLU(planes) 115 | self.move23 = LearnableBias(planes) 116 | 117 | self.binary_activation = BinaryActivation() 118 | self.stride = stride 119 | self.inplanes = inplanes 120 | self.planes = planes 121 | 122 | if self.inplanes != self.planes: 123 | self.pooling = nn.AvgPool2d(2,2) 124 | 125 | def forward(self, x): 126 | 127 | out1 = self.move11(x) 128 | 129 | out1 = self.binary_activation(out1) 130 | out1 = self.binary_3x3(out1) 131 | out1 = self.bn1(out1) 132 | 133 | if self.stride == 2: 134 | x = self.pooling(x) 135 | 136 | out1 = x + out1 137 | 138 | out1 = self.move12(out1) 139 | out1 = self.prelu1(out1) 140 | out1 = self.move13(out1) 141 | 142 | out2 = self.move21(out1) 143 | out2 = self.binary_activation(out2) 144 | 145 | if self.inplanes == self.planes: 146 | out2 = self.binary_pw(out2) 147 | out2 = self.bn2(out2) 148 | out2 += out1 149 | 150 | else: 151 | assert self.planes == self.inplanes * 2 152 | 153 | out2_1 = self.binary_pw_down1(out2) 154 | out2_2 = self.binary_pw_down2(out2) 155 | out2_1 = self.bn2_1(out2_1) 156 | out2_2 = self.bn2_2(out2_2) 157 | out2_1 += out1 158 | out2_2 += out1 159 | out2 = torch.cat([out2_1, out2_2], dim=1) 160 | 161 | out2 = self.move22(out2) 162 | out2 = self.prelu2(out2) 163 | out2 = self.move23(out2) 164 | 165 | return out2 166 | 167 | 168 | class reactnet(nn.Module): 169 | def __init__(self, num_classes=1000): 170 | super(reactnet, self).__init__() 171 | self.feature = nn.ModuleList() 172 | for i in range(len(stage_out_channel)): 173 | if i == 0: 174 | self.feature.append(firstconv3x3(3, stage_out_channel[i], 2)) 175 | elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64: 176 | self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2)) 177 | else: 178 | self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1)) 179 | self.pool1 = nn.AdaptiveAvgPool2d(1) 180 | self.fc = nn.Linear(1024, num_classes) 181 | 182 | def forward(self, x): 183 | for i, block in enumerate(self.feature): 184 | x = block(x) 185 | 186 | x = self.pool1(x) 187 | x = x.view(x.size(0), -1) 188 | x = self.fc(x) 189 | 190 | return x 191 | 192 | 193 | 194 | 195 | 196 | 197 | -------------------------------------------------------------------------------- /mobilenet/2_step2/run.sh: -------------------------------------------------------------------------------- 1 | clear 2 | mkdir models 3 | cp ../1_step1/models/checkpoint.pth.tar ./models/checkpoint_ba.pth.tar 4 | mkdir log 5 | # 128 epoch setting: larger learning rate, similar performance to 256 epoch 6 | python3 train.py --data=/datasets/imagenet --batch_size=256 --learning_rate=1.25e-3 --epochs=128 --weight_decay=0 | tee -a log/training.txt 7 | # 256 epoch setting: longer training, similar performance to 128 epoch 8 | # python3 train.py --data=/datasets/imagenet --batch_size=256 --learning_rate=5e-4 --epochs=256 --weight_decay=0 | tee -a log/training.txt 9 | -------------------------------------------------------------------------------- /mobilenet/2_step2/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import numpy as np 5 | import time, datetime 6 | import torch 7 | import random 8 | import logging 9 | import argparse 10 | import torch.nn as nn 11 | import torch.utils 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.utils.data.distributed 15 | 16 | sys.path.append("../../") 17 | from utils.utils import * 18 | from utils import KD_loss 19 | from torchvision import datasets, transforms 20 | from torch.autograd import Variable 21 | from reactnet import reactnet 22 | import torchvision.models as models 23 | 24 | parser = argparse.ArgumentParser("birealnet18") 25 | parser.add_argument('--batch_size', type=int, default=512, help='batch size') 26 | parser.add_argument('--epochs', type=int, default=256, help='num of training epochs') 27 | parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate') 28 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 29 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') 30 | parser.add_argument('--save', type=str, default='./models', help='path for saving trained models') 31 | parser.add_argument('--data', metavar='DIR', help='path to dataset') 32 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing') 33 | parser.add_argument('--teacher', type=str, default='resnet34', help='path of ImageNet') 34 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | args = parser.parse_args() 37 | 38 | CLASSES = 1000 39 | 40 | if not os.path.exists('log'): 41 | os.mkdir('log') 42 | 43 | log_format = '%(asctime)s %(message)s' 44 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 45 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 46 | fh = logging.FileHandler(os.path.join('log/log.txt')) 47 | fh.setFormatter(logging.Formatter(log_format)) 48 | logging.getLogger().addHandler(fh) 49 | 50 | def main(): 51 | if not torch.cuda.is_available(): 52 | sys.exit(1) 53 | start_t = time.time() 54 | 55 | cudnn.benchmark = True 56 | cudnn.enabled=True 57 | logging.info("args = %s", args) 58 | 59 | # load model 60 | model_teacher = models.__dict__[args.teacher](pretrained=True) 61 | model_teacher = nn.DataParallel(model_teacher).cuda() 62 | for p in model_teacher.parameters(): 63 | p.requires_grad = False 64 | model_teacher.eval() 65 | 66 | model_student = reactnet() 67 | logging.info('student:') 68 | logging.info(model_student) 69 | model_student = nn.DataParallel(model_student).cuda() 70 | 71 | criterion = nn.CrossEntropyLoss() 72 | criterion = criterion.cuda() 73 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth) 74 | criterion_smooth = criterion_smooth.cuda() 75 | criterion_kd = KD_loss.DistributionLoss() 76 | 77 | all_parameters = model_student.parameters() 78 | weight_parameters = [] 79 | for pname, p in model_student.named_parameters(): 80 | if p.ndimension() == 4 or 'conv' in pname: 81 | weight_parameters.append(p) 82 | weight_parameters_id = list(map(id, weight_parameters)) 83 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters)) 84 | 85 | optimizer = torch.optim.Adam( 86 | [{'params' : other_parameters}, 87 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}], 88 | lr=args.learning_rate,) 89 | 90 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/args.epochs), last_epoch=-1) 91 | start_epoch = 0 92 | best_top1_acc= 0 93 | 94 | checkpoint_tar = os.path.join(args.save, 'checkpoint_ba.pth.tar') 95 | checkpoint = torch.load(checkpoint_tar) 96 | model_student.load_state_dict(checkpoint['state_dict'], strict=False) 97 | 98 | checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar') 99 | if os.path.exists(checkpoint_tar): 100 | logging.info('loading checkpoint {} ..........'.format(checkpoint_tar)) 101 | checkpoint = torch.load(checkpoint_tar) 102 | start_epoch = checkpoint['epoch'] + 1 103 | best_top1_acc = checkpoint['best_top1_acc'] 104 | model_student.load_state_dict(checkpoint['state_dict'], strict=False) 105 | logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch'])) 106 | 107 | # adjust the learning rate according to the checkpoint 108 | for epoch in range(start_epoch): 109 | scheduler.step() 110 | 111 | # load training data 112 | traindir = os.path.join(args.data, 'train') 113 | valdir = os.path.join(args.data, 'val') 114 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 115 | std=[0.229, 0.224, 0.225]) 116 | 117 | # data augmentation 118 | crop_scale = 0.08 119 | lighting_param = 0.1 120 | train_transforms = transforms.Compose([ 121 | transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)), 122 | Lighting(lighting_param), 123 | transforms.RandomHorizontalFlip(), 124 | transforms.ToTensor(), 125 | normalize]) 126 | 127 | train_dataset = datasets.ImageFolder( 128 | traindir, 129 | transform=train_transforms) 130 | 131 | train_loader = torch.utils.data.DataLoader( 132 | train_dataset, batch_size=args.batch_size, shuffle=True, 133 | num_workers=args.workers, pin_memory=True) 134 | 135 | # load validation data 136 | val_loader = torch.utils.data.DataLoader( 137 | datasets.ImageFolder(valdir, transforms.Compose([ 138 | transforms.Resize(256), 139 | transforms.CenterCrop(224), 140 | transforms.ToTensor(), 141 | normalize, 142 | ])), 143 | batch_size=args.batch_size, shuffle=False, 144 | num_workers=args.workers, pin_memory=True) 145 | 146 | # train the model 147 | epoch = start_epoch 148 | while epoch < args.epochs: 149 | train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler) 150 | valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args) 151 | 152 | is_best = False 153 | if valid_top1_acc > best_top1_acc: 154 | best_top1_acc = valid_top1_acc 155 | is_best = True 156 | 157 | save_checkpoint({ 158 | 'epoch': epoch, 159 | 'state_dict': model_student.state_dict(), 160 | 'best_top1_acc': best_top1_acc, 161 | 'optimizer' : optimizer.state_dict(), 162 | }, is_best, args.save) 163 | 164 | epoch += 1 165 | 166 | training_time = (time.time() - start_t) / 3600 167 | print('total training time = {} hours'.format(training_time)) 168 | 169 | 170 | def train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler): 171 | batch_time = AverageMeter('Time', ':6.3f') 172 | data_time = AverageMeter('Data', ':6.3f') 173 | losses = AverageMeter('Loss', ':.4e') 174 | top1 = AverageMeter('Acc@1', ':6.2f') 175 | top5 = AverageMeter('Acc@5', ':6.2f') 176 | 177 | progress = ProgressMeter( 178 | len(train_loader), 179 | [batch_time, data_time, losses, top1, top5], 180 | prefix="Epoch: [{}]".format(epoch)) 181 | 182 | model_student.train() 183 | model_teacher.eval() 184 | end = time.time() 185 | scheduler.step() 186 | 187 | for param_group in optimizer.param_groups: 188 | cur_lr = param_group['lr'] 189 | print('learning_rate:', cur_lr) 190 | 191 | for i, (images, target) in enumerate(train_loader): 192 | data_time.update(time.time() - end) 193 | images = images.cuda() 194 | target = target.cuda() 195 | 196 | # compute outputy 197 | logits_student = model_student(images) 198 | logits_teacher = model_teacher(images) 199 | loss = criterion(logits_student, logits_teacher) 200 | 201 | # measure accuracy and record loss 202 | prec1, prec5 = accuracy(logits_student, target, topk=(1, 5)) 203 | n = images.size(0) 204 | losses.update(loss.item(), n) #accumulated loss 205 | top1.update(prec1.item(), n) 206 | top5.update(prec5.item(), n) 207 | 208 | # compute gradient and do SGD step 209 | optimizer.zero_grad() 210 | loss.backward() 211 | optimizer.step() 212 | 213 | # measure elapsed time 214 | batch_time.update(time.time() - end) 215 | end = time.time() 216 | 217 | progress.display(i) 218 | 219 | return losses.avg, top1.avg, top5.avg 220 | 221 | def validate(epoch, val_loader, model, criterion, args): 222 | batch_time = AverageMeter('Time', ':6.3f') 223 | losses = AverageMeter('Loss', ':.4e') 224 | top1 = AverageMeter('Acc@1', ':6.2f') 225 | top5 = AverageMeter('Acc@5', ':6.2f') 226 | progress = ProgressMeter( 227 | len(val_loader), 228 | [batch_time, losses, top1, top5], 229 | prefix='Test: ') 230 | 231 | # switch to evaluation mode 232 | model.eval() 233 | with torch.no_grad(): 234 | end = time.time() 235 | for i, (images, target) in enumerate(val_loader): 236 | images = images.cuda() 237 | target = target.cuda() 238 | 239 | # compute output 240 | logits = model(images) 241 | loss = criterion(logits, target) 242 | 243 | # measure accuracy and record loss 244 | pred1, pred5 = accuracy(logits, target, topk=(1, 5)) 245 | n = images.size(0) 246 | losses.update(loss.item(), n) 247 | top1.update(pred1[0], n) 248 | top5.update(pred5[0], n) 249 | 250 | # measure elapsed time 251 | batch_time.update(time.time() - end) 252 | end = time.time() 253 | 254 | progress.display(i) 255 | 256 | print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}' 257 | .format(top1=top1, top5=top5)) 258 | 259 | return losses.avg, top1.avg, top5.avg 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /resnet/1_step1/birealnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | 6 | 7 | __all__ = ['birealnet18', 'birealnet34'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | def conv1x1(in_planes, out_planes, stride=1): 17 | """1x1 convolution""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 19 | 20 | class BinaryActivation(nn.Module): 21 | def __init__(self): 22 | super(BinaryActivation, self).__init__() 23 | 24 | def forward(self, x): 25 | out_forward = torch.sign(x) 26 | #out_e1 = (x^2 + 2*x) 27 | #out_e2 = (-x^2 + 2*x) 28 | out_e_total = 0 29 | mask1 = x < -1 30 | mask2 = x < 0 31 | mask3 = x < 1 32 | out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32)) 33 | out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32)) 34 | out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32)) 35 | out = out_forward.detach() - out3.detach() + out3 36 | 37 | return out 38 | 39 | class LearnableBias(nn.Module): 40 | def __init__(self, out_chn): 41 | super(LearnableBias, self).__init__() 42 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True) 43 | 44 | def forward(self, x): 45 | out = x + self.bias.expand_as(x) 46 | return out 47 | 48 | class HardBinaryConv(nn.Module): 49 | def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1): 50 | super(HardBinaryConv, self).__init__() 51 | self.stride = stride 52 | self.padding = padding 53 | self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size 54 | self.shape = (out_chn, in_chn, kernel_size, kernel_size) 55 | #self.weight = nn.Parameter(torch.rand((self.number_of_weights,1)) * 0.001, requires_grad=True) 56 | self.weight = nn.Parameter(torch.rand((self.shape)) * 0.001, requires_grad=True) 57 | 58 | def forward(self, x): 59 | #real_weights = self.weights.view(self.shape) 60 | real_weights = self.weight 61 | scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True) 62 | #print(scaling_factor, flush=True) 63 | scaling_factor = scaling_factor.detach() 64 | binary_weights_no_grad = scaling_factor * torch.sign(real_weights) 65 | cliped_weights = torch.clamp(real_weights, -1.0, 1.0) 66 | binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights 67 | #print(binary_weights, flush=True) 68 | y = F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding) 69 | 70 | return y 71 | 72 | class BasicBlock(nn.Module): 73 | expansion = 1 74 | 75 | def __init__(self, inplanes, planes, stride=1, downsample=None): 76 | super(BasicBlock, self).__init__() 77 | 78 | self.move0 = LearnableBias(inplanes) 79 | self.binary_activation = BinaryActivation() 80 | self.binary_conv = conv3x3(inplanes, planes, stride=stride) 81 | self.bn1 = nn.BatchNorm2d(planes) 82 | self.move1 = LearnableBias(planes) 83 | self.prelu = nn.PReLU(planes) 84 | self.move2 = LearnableBias(planes) 85 | 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | out = self.move0(x) 93 | out = self.binary_activation(out) 94 | out = self.binary_conv(out) 95 | out = self.bn1(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.move1(out) 102 | out = self.prelu(out) 103 | out = self.move2(out) 104 | 105 | return out 106 | 107 | class BiRealNet(nn.Module): 108 | 109 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 110 | super(BiRealNet, self).__init__() 111 | self.inplanes = 64 112 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 113 | bias=False) 114 | self.bn1 = nn.BatchNorm2d(64) 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | self.layer1 = self._make_layer(block, 64, layers[0]) 117 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 118 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 119 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 120 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 121 | self.fc = nn.Linear(512 * block.expansion, num_classes) 122 | 123 | def _make_layer(self, block, planes, blocks, stride=1): 124 | downsample = None 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | nn.AvgPool2d(kernel_size=2, stride=stride), 128 | conv1x1(self.inplanes, planes * block.expansion), 129 | nn.BatchNorm2d(planes * block.expansion), 130 | ) 131 | 132 | layers = [] 133 | layers.append(block(self.inplanes, planes, stride, downsample)) 134 | self.inplanes = planes * block.expansion 135 | for _ in range(1, blocks): 136 | layers.append(block(self.inplanes, planes)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | x = self.conv1(x) 142 | x = self.bn1(x) 143 | x = self.maxpool(x) 144 | 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | 150 | x = self.avgpool(x) 151 | x = x.view(x.size(0), -1) 152 | x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def birealnet18(pretrained=False, **kwargs): 158 | """Constructs a BiRealNet-18 model. """ 159 | model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs) 160 | return model 161 | 162 | 163 | def birealnet34(pretrained=False, **kwargs): 164 | """Constructs a BiRealNet-34 model. """ 165 | model = BiRealNet(BasicBlock, [6, 8, 12, 6], **kwargs) 166 | return model 167 | 168 | -------------------------------------------------------------------------------- /resnet/1_step1/run.sh: -------------------------------------------------------------------------------- 1 | clear 2 | mkdir log 3 | # 128 epoch setting: larger learning rate, similar performance to 256 epoch 4 | python3 train.py --data=/datasets/imagenet --batch_size=512 --learning_rate=2.5e-3 --epochs=128 --weight_decay=1e-5 | tee -a log/training.txt 5 | # 256 epoch setting: longer training, similar performance to 128 epoch 6 | # python3 train.py --data=/datasets/imagenet --batch_size=512 --learning_rate=1e-3 --epochs=256 --weight_decay=1e-5 | tee -a log/training.txt 7 | -------------------------------------------------------------------------------- /resnet/1_step1/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import numpy as np 5 | import time, datetime 6 | import torch 7 | import random 8 | import logging 9 | import argparse 10 | import torch.nn as nn 11 | import torch.utils 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.utils.data.distributed 15 | 16 | sys.path.append("../../") 17 | from utils.utils import * 18 | from utils import KD_loss 19 | from torchvision import datasets, transforms 20 | from torch.autograd import Variable 21 | from birealnet import birealnet18 22 | import torchvision.models as models 23 | 24 | parser = argparse.ArgumentParser("birealnet18") 25 | parser.add_argument('--batch_size', type=int, default=512, help='batch size') 26 | parser.add_argument('--epochs', type=int, default=256, help='num of training epochs') 27 | parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate') 28 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 29 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') 30 | parser.add_argument('--save', type=str, default='./models', help='path for saving trained models') 31 | parser.add_argument('--data', metavar='DIR', help='path to dataset') 32 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing') 33 | parser.add_argument('--teacher', type=str, default='resnet34', help='path of ImageNet') 34 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | args = parser.parse_args() 37 | 38 | CLASSES = 1000 39 | 40 | if not os.path.exists('log'): 41 | os.mkdir('log') 42 | 43 | log_format = '%(asctime)s %(message)s' 44 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 45 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 46 | fh = logging.FileHandler(os.path.join('log/log.txt')) 47 | fh.setFormatter(logging.Formatter(log_format)) 48 | logging.getLogger().addHandler(fh) 49 | 50 | def main(): 51 | if not torch.cuda.is_available(): 52 | sys.exit(1) 53 | start_t = time.time() 54 | 55 | cudnn.benchmark = True 56 | cudnn.enabled=True 57 | logging.info("args = %s", args) 58 | 59 | # load model 60 | model_teacher = models.__dict__[args.teacher](pretrained=True) 61 | model_teacher = nn.DataParallel(model_teacher).cuda() 62 | for p in model_teacher.parameters(): 63 | p.requires_grad = False 64 | model_teacher.eval() 65 | 66 | model_student = birealnet18() 67 | logging.info('student:') 68 | logging.info(model_student) 69 | model_student = nn.DataParallel(model_student).cuda() 70 | 71 | criterion = nn.CrossEntropyLoss() 72 | criterion = criterion.cuda() 73 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth) 74 | criterion_smooth = criterion_smooth.cuda() 75 | criterion_kd = KD_loss.DistributionLoss() 76 | 77 | all_parameters = model_student.parameters() 78 | weight_parameters = [] 79 | for pname, p in model_student.named_parameters(): 80 | if p.ndimension() == 4 or 'conv' in pname: 81 | weight_parameters.append(p) 82 | weight_parameters_id = list(map(id, weight_parameters)) 83 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters)) 84 | 85 | optimizer = torch.optim.Adam( 86 | [{'params' : other_parameters}, 87 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}], 88 | lr=args.learning_rate,) 89 | 90 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/args.epochs), last_epoch=-1) 91 | start_epoch = 0 92 | best_top1_acc= 0 93 | 94 | checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar') 95 | if os.path.exists(checkpoint_tar): 96 | logging.info('loading checkpoint {} ..........'.format(checkpoint_tar)) 97 | checkpoint = torch.load(checkpoint_tar) 98 | start_epoch = checkpoint['epoch'] 99 | best_top1_acc = checkpoint['best_top1_acc'] 100 | model_student.load_state_dict(checkpoint['state_dict'], strict=False) 101 | logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch'])) 102 | 103 | # adjust the learning rate according to the checkpoint 104 | for epoch in range(start_epoch): 105 | scheduler.step() 106 | 107 | # load training data 108 | traindir = os.path.join(args.data, 'train') 109 | valdir = os.path.join(args.data, 'val') 110 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 111 | std=[0.229, 0.224, 0.225]) 112 | 113 | # data augmentation 114 | crop_scale = 0.08 115 | lighting_param = 0.1 116 | train_transforms = transforms.Compose([ 117 | transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)), 118 | Lighting(lighting_param), 119 | transforms.RandomHorizontalFlip(), 120 | transforms.ToTensor(), 121 | normalize]) 122 | 123 | train_dataset = datasets.ImageFolder( 124 | traindir, 125 | transform=train_transforms) 126 | 127 | train_loader = torch.utils.data.DataLoader( 128 | train_dataset, batch_size=args.batch_size, shuffle=True, 129 | num_workers=args.workers, pin_memory=True) 130 | 131 | # load validation data 132 | val_loader = torch.utils.data.DataLoader( 133 | datasets.ImageFolder(valdir, transforms.Compose([ 134 | transforms.Resize(256), 135 | transforms.CenterCrop(224), 136 | transforms.ToTensor(), 137 | normalize, 138 | ])), 139 | batch_size=args.batch_size, shuffle=False, 140 | num_workers=args.workers, pin_memory=True) 141 | 142 | # train the model 143 | epoch = start_epoch 144 | while epoch < args.epochs: 145 | train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler) 146 | valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args) 147 | 148 | is_best = False 149 | if valid_top1_acc > best_top1_acc: 150 | best_top1_acc = valid_top1_acc 151 | is_best = True 152 | 153 | save_checkpoint({ 154 | 'epoch': epoch, 155 | 'state_dict': model_student.state_dict(), 156 | 'best_top1_acc': best_top1_acc, 157 | 'optimizer' : optimizer.state_dict(), 158 | }, is_best, args.save) 159 | 160 | epoch += 1 161 | 162 | training_time = (time.time() - start_t) / 3600 163 | print('total training time = {} hours'.format(training_time)) 164 | 165 | 166 | def train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler): 167 | batch_time = AverageMeter('Time', ':6.3f') 168 | data_time = AverageMeter('Data', ':6.3f') 169 | losses = AverageMeter('Loss', ':.4e') 170 | top1 = AverageMeter('Acc@1', ':6.2f') 171 | top5 = AverageMeter('Acc@5', ':6.2f') 172 | 173 | progress = ProgressMeter( 174 | len(train_loader), 175 | [batch_time, data_time, losses, top1, top5], 176 | prefix="Epoch: [{}]".format(epoch)) 177 | 178 | model_student.train() 179 | model_teacher.eval() 180 | end = time.time() 181 | scheduler.step() 182 | 183 | for param_group in optimizer.param_groups: 184 | cur_lr = param_group['lr'] 185 | print('learning_rate:', cur_lr) 186 | 187 | for i, (images, target) in enumerate(train_loader): 188 | data_time.update(time.time() - end) 189 | images = images.cuda() 190 | target = target.cuda() 191 | 192 | # compute outputy 193 | logits_student = model_student(images) 194 | logits_teacher = model_teacher(images) 195 | loss = criterion(logits_student, logits_teacher) 196 | 197 | # measure accuracy and record loss 198 | prec1, prec5 = accuracy(logits_student, target, topk=(1, 5)) 199 | n = images.size(0) 200 | losses.update(loss.item(), n) #accumulated loss 201 | top1.update(prec1.item(), n) 202 | top5.update(prec5.item(), n) 203 | 204 | # compute gradient and do SGD step 205 | optimizer.zero_grad() 206 | loss.backward() 207 | optimizer.step() 208 | 209 | # measure elapsed time 210 | batch_time.update(time.time() - end) 211 | end = time.time() 212 | 213 | progress.display(i) 214 | 215 | return losses.avg, top1.avg, top5.avg 216 | 217 | def validate(epoch, val_loader, model, criterion, args): 218 | batch_time = AverageMeter('Time', ':6.3f') 219 | losses = AverageMeter('Loss', ':.4e') 220 | top1 = AverageMeter('Acc@1', ':6.2f') 221 | top5 = AverageMeter('Acc@5', ':6.2f') 222 | progress = ProgressMeter( 223 | len(val_loader), 224 | [batch_time, losses, top1, top5], 225 | prefix='Test: ') 226 | 227 | # switch to evaluation mode 228 | model.eval() 229 | with torch.no_grad(): 230 | end = time.time() 231 | for i, (images, target) in enumerate(val_loader): 232 | images = images.cuda() 233 | target = target.cuda() 234 | 235 | # compute output 236 | logits = model(images) 237 | loss = criterion(logits, target) 238 | 239 | # measure accuracy and record loss 240 | pred1, pred5 = accuracy(logits, target, topk=(1, 5)) 241 | n = images.size(0) 242 | losses.update(loss.item(), n) 243 | top1.update(pred1[0], n) 244 | top5.update(pred5[0], n) 245 | 246 | # measure elapsed time 247 | batch_time.update(time.time() - end) 248 | end = time.time() 249 | 250 | progress.display(i) 251 | 252 | print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}' 253 | .format(top1=top1, top5=top5)) 254 | 255 | return losses.avg, top1.avg, top5.avg 256 | 257 | 258 | if __name__ == '__main__': 259 | main() 260 | -------------------------------------------------------------------------------- /resnet/2_step2/birealnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | 6 | 7 | __all__ = ['birealnet18', 'birealnet34'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | def conv1x1(in_planes, out_planes, stride=1): 17 | """1x1 convolution""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 19 | 20 | class BinaryActivation(nn.Module): 21 | def __init__(self): 22 | super(BinaryActivation, self).__init__() 23 | 24 | def forward(self, x): 25 | out_forward = torch.sign(x) 26 | #out_e1 = (x^2 + 2*x) 27 | #out_e2 = (-x^2 + 2*x) 28 | out_e_total = 0 29 | mask1 = x < -1 30 | mask2 = x < 0 31 | mask3 = x < 1 32 | out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32)) 33 | out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32)) 34 | out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32)) 35 | out = out_forward.detach() - out3.detach() + out3 36 | 37 | return out 38 | 39 | class LearnableBias(nn.Module): 40 | def __init__(self, out_chn): 41 | super(LearnableBias, self).__init__() 42 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True) 43 | 44 | def forward(self, x): 45 | out = x + self.bias.expand_as(x) 46 | return out 47 | 48 | 49 | class HardBinaryConv(nn.Module): 50 | def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1): 51 | super(HardBinaryConv, self).__init__() 52 | self.stride = stride 53 | self.padding = padding 54 | self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size 55 | self.shape = (out_chn, in_chn, kernel_size, kernel_size) 56 | #self.weight = nn.Parameter(torch.rand((self.number_of_weights,1)) * 0.001, requires_grad=True) 57 | self.weight = nn.Parameter(torch.rand((self.shape)) * 0.001, requires_grad=True) 58 | 59 | def forward(self, x): 60 | #real_weights = self.weights.view(self.shape) 61 | real_weights = self.weight 62 | scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True) 63 | #print(scaling_factor, flush=True) 64 | scaling_factor = scaling_factor.detach() 65 | binary_weights_no_grad = scaling_factor * torch.sign(real_weights) 66 | cliped_weights = torch.clamp(real_weights, -1.0, 1.0) 67 | binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights 68 | #print(binary_weights, flush=True) 69 | y = F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding) 70 | 71 | return y 72 | 73 | class BasicBlock(nn.Module): 74 | expansion = 1 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None): 77 | super(BasicBlock, self).__init__() 78 | 79 | self.move0 = LearnableBias(inplanes) 80 | self.binary_activation = BinaryActivation() 81 | self.binary_conv = HardBinaryConv(inplanes, planes, stride=stride) 82 | self.bn1 = nn.BatchNorm2d(planes) 83 | self.move1 = LearnableBias(planes) 84 | self.prelu = nn.PReLU(planes) 85 | self.move2 = LearnableBias(planes) 86 | 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | residual = x 92 | 93 | out = self.move0(x) 94 | out = self.binary_activation(out) 95 | out = self.binary_conv(out) 96 | out = self.bn1(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.move1(out) 103 | out = self.prelu(out) 104 | out = self.move2(out) 105 | 106 | return out 107 | 108 | class BiRealNet(nn.Module): 109 | 110 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 111 | super(BiRealNet, self).__init__() 112 | self.inplanes = 64 113 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 114 | bias=False) 115 | self.bn1 = nn.BatchNorm2d(64) 116 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 117 | self.layer1 = self._make_layer(block, 64, layers[0]) 118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 121 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 122 | self.fc = nn.Linear(512 * block.expansion, num_classes) 123 | 124 | def _make_layer(self, block, planes, blocks, stride=1): 125 | downsample = None 126 | if stride != 1 or self.inplanes != planes * block.expansion: 127 | downsample = nn.Sequential( 128 | nn.AvgPool2d(kernel_size=2, stride=stride), 129 | conv1x1(self.inplanes, planes * block.expansion), 130 | nn.BatchNorm2d(planes * block.expansion), 131 | ) 132 | 133 | layers = [] 134 | layers.append(block(self.inplanes, planes, stride, downsample)) 135 | self.inplanes = planes * block.expansion 136 | for _ in range(1, blocks): 137 | layers.append(block(self.inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.maxpool(x) 145 | 146 | x = self.layer1(x) 147 | x = self.layer2(x) 148 | x = self.layer3(x) 149 | x = self.layer4(x) 150 | 151 | x = self.avgpool(x) 152 | x = x.view(x.size(0), -1) 153 | x = self.fc(x) 154 | 155 | return x 156 | 157 | 158 | def birealnet18(pretrained=False, **kwargs): 159 | """Constructs a BiRealNet-18 model. """ 160 | model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs) 161 | return model 162 | 163 | 164 | def birealnet34(pretrained=False, **kwargs): 165 | """Constructs a BiRealNet-34 model. """ 166 | model = BiRealNet(BasicBlock, [6, 8, 12, 6], **kwargs) 167 | return model 168 | 169 | -------------------------------------------------------------------------------- /resnet/2_step2/run.sh: -------------------------------------------------------------------------------- 1 | clear 2 | mkdir models 3 | cp ../1_step1/models/checkpoint.pth.tar ./models/checkpoint_ba.pth.tar 4 | mkdir log 5 | # 128 epoch setting: larger learning rate, similar performance to 256 epoch 6 | python3 train.py --data=/datasets/imagenet --batch_size=512 --learning_rate=2.5e-3 --epochs=128 --weight_decay=0 | tee -a log/training.txt 7 | # 256 epoch setting: longer training, similar performance to 128 epoch 8 | # python3 train.py --data=/datasets/imagenet --batch_size=512 --learning_rate=1e-3 --epochs=256 --weight_decay=0 | tee -a log/training.txt 9 | -------------------------------------------------------------------------------- /resnet/2_step2/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import numpy as np 5 | import time, datetime 6 | import torch 7 | import random 8 | import logging 9 | import argparse 10 | import torch.nn as nn 11 | import torch.utils 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.utils.data.distributed 15 | 16 | sys.path.append("../../") 17 | from utils.utils import * 18 | from utils import KD_loss 19 | from torchvision import datasets, transforms 20 | from torch.autograd import Variable 21 | from birealnet import birealnet18 22 | import torchvision.models as models 23 | 24 | parser = argparse.ArgumentParser("birealnet18") 25 | parser.add_argument('--batch_size', type=int, default=512, help='batch size') 26 | parser.add_argument('--epochs', type=int, default=256, help='num of training epochs') 27 | parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate') 28 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 29 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') 30 | parser.add_argument('--save', type=str, default='./models', help='path for saving trained models') 31 | parser.add_argument('--data', metavar='DIR', help='path to dataset') 32 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing') 33 | parser.add_argument('--teacher', type=str, default='resnet34', help='path of ImageNet') 34 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | args = parser.parse_args() 37 | 38 | CLASSES = 1000 39 | 40 | if not os.path.exists('log'): 41 | os.mkdir('log') 42 | 43 | log_format = '%(asctime)s %(message)s' 44 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 45 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 46 | fh = logging.FileHandler(os.path.join('log/log.txt')) 47 | fh.setFormatter(logging.Formatter(log_format)) 48 | logging.getLogger().addHandler(fh) 49 | 50 | def main(): 51 | if not torch.cuda.is_available(): 52 | sys.exit(1) 53 | start_t = time.time() 54 | 55 | cudnn.benchmark = True 56 | cudnn.enabled=True 57 | logging.info("args = %s", args) 58 | 59 | # load model 60 | model_teacher = models.__dict__[args.teacher](pretrained=True) 61 | model_teacher = nn.DataParallel(model_teacher).cuda() 62 | for p in model_teacher.parameters(): 63 | p.requires_grad = False 64 | model_teacher.eval() 65 | 66 | model_student = birealnet18() 67 | logging.info('student:') 68 | logging.info(model_student) 69 | model_student = nn.DataParallel(model_student).cuda() 70 | 71 | criterion = nn.CrossEntropyLoss() 72 | criterion = criterion.cuda() 73 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth) 74 | criterion_smooth = criterion_smooth.cuda() 75 | criterion_kd = KD_loss.DistributionLoss() 76 | 77 | all_parameters = model_student.parameters() 78 | weight_parameters = [] 79 | for pname, p in model_student.named_parameters(): 80 | if p.ndimension() == 4 or 'conv' in pname: 81 | weight_parameters.append(p) 82 | weight_parameters_id = list(map(id, weight_parameters)) 83 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters)) 84 | 85 | optimizer = torch.optim.Adam( 86 | [{'params' : other_parameters}, 87 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}], 88 | lr=args.learning_rate,) 89 | 90 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/args.epochs), last_epoch=-1) 91 | start_epoch = 0 92 | best_top1_acc= 0 93 | 94 | checkpoint_tar = os.path.join(args.save, 'checkpoint_ba.pth.tar') 95 | checkpoint = torch.load(checkpoint_tar) 96 | model_student.load_state_dict(checkpoint['state_dict'], strict=False) 97 | 98 | checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar') 99 | if os.path.exists(checkpoint_tar): 100 | logging.info('loading checkpoint {} ..........'.format(checkpoint_tar)) 101 | checkpoint = torch.load(checkpoint_tar) 102 | start_epoch = checkpoint['epoch'] 103 | best_top1_acc = checkpoint['best_top1_acc'] 104 | model_student.load_state_dict(checkpoint['state_dict'], strict=False) 105 | logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch'])) 106 | 107 | # adjust the learning rate according to the checkpoint 108 | for epoch in range(start_epoch): 109 | scheduler.step() 110 | 111 | # load training data 112 | traindir = os.path.join(args.data, 'train') 113 | valdir = os.path.join(args.data, 'val') 114 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 115 | std=[0.229, 0.224, 0.225]) 116 | 117 | # data augmentation 118 | crop_scale = 0.08 119 | lighting_param = 0.1 120 | train_transforms = transforms.Compose([ 121 | transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)), 122 | Lighting(lighting_param), 123 | transforms.RandomHorizontalFlip(), 124 | transforms.ToTensor(), 125 | normalize]) 126 | 127 | train_dataset = datasets.ImageFolder( 128 | traindir, 129 | transform=train_transforms) 130 | 131 | train_loader = torch.utils.data.DataLoader( 132 | train_dataset, batch_size=args.batch_size, shuffle=True, 133 | num_workers=args.workers, pin_memory=True) 134 | 135 | # load validation data 136 | val_loader = torch.utils.data.DataLoader( 137 | datasets.ImageFolder(valdir, transforms.Compose([ 138 | transforms.Resize(256), 139 | transforms.CenterCrop(224), 140 | transforms.ToTensor(), 141 | normalize, 142 | ])), 143 | batch_size=args.batch_size, shuffle=False, 144 | num_workers=args.workers, pin_memory=True) 145 | 146 | # train the model 147 | epoch = start_epoch 148 | while epoch < args.epochs: 149 | train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler) 150 | valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args) 151 | 152 | is_best = False 153 | if valid_top1_acc > best_top1_acc: 154 | best_top1_acc = valid_top1_acc 155 | is_best = True 156 | 157 | save_checkpoint({ 158 | 'epoch': epoch, 159 | 'state_dict': model_student.state_dict(), 160 | 'best_top1_acc': best_top1_acc, 161 | 'optimizer' : optimizer.state_dict(), 162 | }, is_best, args.save) 163 | 164 | epoch += 1 165 | 166 | training_time = (time.time() - start_t) / 3600 167 | print('total training time = {} hours'.format(training_time)) 168 | 169 | 170 | def train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler): 171 | batch_time = AverageMeter('Time', ':6.3f') 172 | data_time = AverageMeter('Data', ':6.3f') 173 | losses = AverageMeter('Loss', ':.4e') 174 | top1 = AverageMeter('Acc@1', ':6.2f') 175 | top5 = AverageMeter('Acc@5', ':6.2f') 176 | 177 | progress = ProgressMeter( 178 | len(train_loader), 179 | [batch_time, data_time, losses, top1, top5], 180 | prefix="Epoch: [{}]".format(epoch)) 181 | 182 | model_student.train() 183 | model_teacher.eval() 184 | end = time.time() 185 | scheduler.step() 186 | 187 | for param_group in optimizer.param_groups: 188 | cur_lr = param_group['lr'] 189 | print('learning_rate:', cur_lr) 190 | 191 | for i, (images, target) in enumerate(train_loader): 192 | data_time.update(time.time() - end) 193 | images = images.cuda() 194 | target = target.cuda() 195 | 196 | # compute outputy 197 | logits_student = model_student(images) 198 | logits_teacher = model_teacher(images) 199 | loss = criterion(logits_student, logits_teacher) 200 | 201 | # measure accuracy and record loss 202 | prec1, prec5 = accuracy(logits_student, target, topk=(1, 5)) 203 | n = images.size(0) 204 | losses.update(loss.item(), n) #accumulated loss 205 | top1.update(prec1.item(), n) 206 | top5.update(prec5.item(), n) 207 | 208 | # compute gradient and do SGD step 209 | optimizer.zero_grad() 210 | loss.backward() 211 | optimizer.step() 212 | 213 | # measure elapsed time 214 | batch_time.update(time.time() - end) 215 | end = time.time() 216 | 217 | progress.display(i) 218 | 219 | return losses.avg, top1.avg, top5.avg 220 | 221 | def validate(epoch, val_loader, model, criterion, args): 222 | batch_time = AverageMeter('Time', ':6.3f') 223 | losses = AverageMeter('Loss', ':.4e') 224 | top1 = AverageMeter('Acc@1', ':6.2f') 225 | top5 = AverageMeter('Acc@5', ':6.2f') 226 | progress = ProgressMeter( 227 | len(val_loader), 228 | [batch_time, losses, top1, top5], 229 | prefix='Test: ') 230 | 231 | # switch to evaluation mode 232 | model.eval() 233 | with torch.no_grad(): 234 | end = time.time() 235 | for i, (images, target) in enumerate(val_loader): 236 | images = images.cuda() 237 | target = target.cuda() 238 | 239 | # compute output 240 | logits = model(images) 241 | loss = criterion(logits, target) 242 | 243 | # measure accuracy and record loss 244 | pred1, pred5 = accuracy(logits, target, topk=(1, 5)) 245 | n = images.size(0) 246 | losses.update(loss.item(), n) 247 | top1.update(pred1[0], n) 248 | top5.update(pred5[0], n) 249 | 250 | # measure elapsed time 251 | batch_time.update(time.time() - end) 252 | end = time.time() 253 | 254 | progress.display(i) 255 | 256 | print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}' 257 | .format(top1=top1, top5=top5)) 258 | 259 | return losses.avg, top1.avg, top5.avg 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /utils/KD_loss.py: -------------------------------------------------------------------------------- 1 | # Code is modified from MEAL (https://arxiv.org/abs/1812.02425) and Label Refinery (https://arxiv.org/abs/1805.02641). 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.nn.modules import loss 6 | 7 | 8 | class DistributionLoss(loss._Loss): 9 | """The KL-Divergence loss for the binary student model and real teacher output. 10 | 11 | output must be a pair of (model_output, real_output), both NxC tensors. 12 | The rows of real_output must all add up to one (probability scores); 13 | however, model_output must be the pre-softmax output of the network.""" 14 | 15 | def forward(self, model_output, real_output): 16 | 17 | self.size_average = True 18 | 19 | # Target is ignored at training time. Loss is defined as KL divergence 20 | # between the model output and the refined labels. 21 | if real_output.requires_grad: 22 | raise ValueError("real network output should not require gradients.") 23 | 24 | model_output_log_prob = F.log_softmax(model_output, dim=1) 25 | real_output_soft = F.softmax(real_output, dim=1) 26 | del model_output, real_output 27 | 28 | # Loss is -dot(model_output_log_prob, real_output). Prepare tensors 29 | # for batch matrix multiplicatio 30 | real_output_soft = real_output_soft.unsqueeze(1) 31 | model_output_log_prob = model_output_log_prob.unsqueeze(2) 32 | 33 | # Compute the loss, and average/sum for the batch. 34 | cross_entropy_loss = -torch.bmm(real_output_soft, model_output_log_prob) 35 | if self.size_average: 36 | cross_entropy_loss = cross_entropy_loss.mean() 37 | else: 38 | cross_entropy_loss = cross_entropy_loss.sum() 39 | # Return a pair of (loss_output, model_output). Model output will be 40 | # used for top-1 and top-5 evaluation. 41 | # model_output_log_prob = model_output_log_prob.squeeze(2) 42 | return cross_entropy_loss 43 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import numpy as np 5 | import time, datetime 6 | import torch 7 | import random 8 | import logging 9 | import argparse 10 | import torch.nn as nn 11 | import torch.utils 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torch.backends.cudnn as cudnn 15 | from PIL import Image 16 | from torch.autograd import Variable 17 | 18 | #lighting data augmentation 19 | imagenet_pca = { 20 | 'eigval': np.asarray([0.2175, 0.0188, 0.0045]), 21 | 'eigvec': np.asarray([ 22 | [-0.5675, 0.7192, 0.4009], 23 | [-0.5808, -0.0045, -0.8140], 24 | [-0.5836, -0.6948, 0.4203], 25 | ]) 26 | } 27 | 28 | 29 | class Lighting(object): 30 | def __init__(self, alphastd, 31 | eigval=imagenet_pca['eigval'], 32 | eigvec=imagenet_pca['eigvec']): 33 | self.alphastd = alphastd 34 | assert eigval.shape == (3,) 35 | assert eigvec.shape == (3, 3) 36 | self.eigval = eigval 37 | self.eigvec = eigvec 38 | 39 | def __call__(self, img): 40 | if self.alphastd == 0.: 41 | return img 42 | rnd = np.random.randn(3) * self.alphastd 43 | rnd = rnd.astype('float32') 44 | v = rnd 45 | old_dtype = np.asarray(img).dtype 46 | v = v * self.eigval 47 | v = v.reshape((3, 1)) 48 | inc = np.dot(self.eigvec, v).reshape((3,)) 49 | img = np.add(img, inc) 50 | if old_dtype == np.uint8: 51 | img = np.clip(img, 0, 255) 52 | img = Image.fromarray(img.astype(old_dtype), 'RGB') 53 | return img 54 | 55 | def __repr__(self): 56 | return self.__class__.__name__ + '()' 57 | 58 | #label smooth 59 | class CrossEntropyLabelSmooth(nn.Module): 60 | 61 | def __init__(self, num_classes, epsilon): 62 | super(CrossEntropyLabelSmooth, self).__init__() 63 | self.num_classes = num_classes 64 | self.epsilon = epsilon 65 | self.logsoftmax = nn.LogSoftmax(dim=1) 66 | 67 | def forward(self, inputs, targets): 68 | log_probs = self.logsoftmax(inputs) 69 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 70 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 71 | loss = (-targets * log_probs).mean(0).sum() 72 | return loss 73 | 74 | 75 | class AverageMeter(object): 76 | """Computes and stores the average and current value""" 77 | def __init__(self, name, fmt=':f'): 78 | self.name = name 79 | self.fmt = fmt 80 | self.reset() 81 | 82 | def reset(self): 83 | self.val = 0 84 | self.avg = 0 85 | self.sum = 0 86 | self.count = 0 87 | 88 | def update(self, val, n=1): 89 | self.val = val 90 | self.sum += val * n 91 | self.count += n 92 | self.avg = self.sum / self.count 93 | 94 | def __str__(self): 95 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 96 | return fmtstr.format(**self.__dict__) 97 | 98 | 99 | class ProgressMeter(object): 100 | def __init__(self, num_batches, meters, prefix=""): 101 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 102 | self.meters = meters 103 | self.prefix = prefix 104 | 105 | def display(self, batch): 106 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 107 | entries += [str(meter) for meter in self.meters] 108 | print('\t'.join(entries)) 109 | 110 | def _get_batch_fmtstr(self, num_batches): 111 | num_digits = len(str(num_batches // 1)) 112 | fmt = '{:' + str(num_digits) + 'd}' 113 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 114 | 115 | 116 | def save_checkpoint(state, is_best, save): 117 | if not os.path.exists(save): 118 | os.makedirs(save) 119 | filename = os.path.join(save, 'checkpoint.pth.tar') 120 | torch.save(state, filename) 121 | if is_best: 122 | best_filename = os.path.join(save, 'model_best.pth.tar') 123 | shutil.copyfile(filename, best_filename) 124 | 125 | 126 | def adjust_learning_rate(optimizer, epoch, args): 127 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 128 | lr = args.lr * (0.1 ** (epoch // 30)) 129 | for param_group in optimizer.param_groups: 130 | param_group['lr'] = lr 131 | 132 | 133 | def accuracy(output, target, topk=(1,)): 134 | """Computes the accuracy over the k top predictions for the specified values of k""" 135 | with torch.no_grad(): 136 | maxk = max(topk) 137 | batch_size = target.size(0) 138 | 139 | _, pred = output.topk(maxk, 1, True, True) 140 | pred = pred.t() 141 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 142 | 143 | res = [] 144 | for k in topk: 145 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 146 | res.append(correct_k.mul_(100.0 / batch_size)) 147 | return res 148 | --------------------------------------------------------------------------------