├── README.md
├── mobilenet
│   ├── 1_step1
│   │   ├── reactnet.py
│   │   ├── run.sh
│   │   └── train.py
│   └── 2_step2
│       ├── reactnet.py
│       ├── run.sh
│       └── train.py
├── resnet
│   ├── 1_step1
│   │   ├── birealnet.py
│   │   ├── run.sh
│   │   └── train.py
│   └── 2_step2
│       ├── birealnet.py
│       ├── run.sh
│       └── train.py
└── utils
    ├── KD_loss.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # ReActNet
2 |
3 | This is the PyTorch implementation of our paper ["ReActNet: Towards Precise Binary Neural Network with Generalized Activation Functions"](https://arxiv.org/abs/2003.03488), published in ECCV 2020.
4 |
5 |
6 |
7 |
8 |
9 | In this paper, we propose to generalize the traditional Sign and PReLU functions to RSign and RPReLU, which enable explicit learning of the activation distribution reshape and shift at near-zero extra cost. By adding these simple learnable biases, ReActNet achieves 69.4% top-1 accuracy on the ImageNet dataset with both weights and activations binarized, approaching ResNet-level accuracy.
10 |
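The RSign / RPReLU pair can be written compactly as below. This is a minimal sketch: the class names `RSign` and `RPReLU` are ours, but the behavior matches the `LearnableBias` + sign / `nn.PReLU` composition used in `reactnet.py` (during training, the hard sign is additionally replaced by a piecewise-polynomial straight-through estimator, see `BinaryActivation`):

```python
import torch
import torch.nn as nn

class RSign(nn.Module):
    """Sign with a learnable per-channel shift: sign(x + b)."""
    def __init__(self, channels):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        return torch.sign(x + self.bias)

class RPReLU(nn.Module):
    """PReLU with learnable shifts before and after: prelu(x + b1) + b2."""
    def __init__(self, channels):
        super().__init__()
        self.bias1 = nn.Parameter(torch.zeros(1, channels, 1, 1))
        self.prelu = nn.PReLU(channels)
        self.bias2 = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        return self.prelu(x + self.bias1) + self.bias2
```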
11 | ## Citation
12 |
13 | If you find our code useful for your research, please consider citing:
14 |
15 |     @inproceedings{liu2020reactnet,
16 |       title={ReActNet: Towards Precise Binary Neural Network with Generalized Activation Functions},
17 |       author={Liu, Zechun and Shen, Zhiqiang and Savvides, Marios and Cheng, Kwang-Ting},
18 |       booktitle={European Conference on Computer Vision (ECCV)},
19 |       year={2020}
20 |     }
21 |
22 | ## Run
23 |
24 | ### 1. Requirements:
25 | * Python 3, PyTorch 1.4.0, torchvision 0.5.0
26 |
27 | ### 2. Data:
28 | * Download the ImageNet dataset; pass its root directory (containing the `train/` and `val/` folders) via `--data`
29 |
30 | ### 3. Steps to run:
31 | (1) Step 1: binarize activations
32 | * Change directory to `./resnet/1_step1/` or `./mobilenet/1_step1/`
33 | * Run `bash run.sh`
34 |
35 | (2) Step 2: binarize weights + activations, initialized from the Step 1 checkpoint
36 | * Change directory to `./resnet/2_step2/` or `./mobilenet/2_step2/`
37 | * Run `bash run.sh` (a combined example follows below)
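For example, the MobileNet-based model is trained end to end with the following commands (illustrative; the hyperparameters and the `--data` path live inside each `run.sh`):

```bash
cd mobilenet/1_step1 && bash run.sh    # Step 1: binary activations, real-valued weights
cd ../2_step2 && bash run.sh           # Step 2: binary weights + activations (run.sh copies the Step 1 checkpoint)
```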
38 |
39 |
40 | ## Models
41 |
42 | | Methods | Top-1 Acc | FLOPs | Trained Model |
43 | | --- | --- | --- | --- |
44 | | XNOR-Net | 51.2% | 1.67 x 10^8 | - |
45 | | Bi-Real Net| 56.4% | 1.63 x 10^8 | - |
46 | | Real-to-Binary| 65.4% | 1.83 x 10^8 | - |
47 | | ReActNet (Bi-Real based) | 65.9% | 1.63 x 10^8 | [Model-ReAct-ResNet](https://hkustconnect-my.sharepoint.com/:u:/g/personal/zliubq_connect_ust_hk/EY9P7mxs-8BLkTlqMZif4s4BnNWcKbUnvqeA_CvN3c9q4w?e=IpUyF4) |
48 | | ReActNet-A | 69.4% | 0.87 x 10^8 | [Model-ReAct-MobileNet](https://hkustconnect-my.sharepoint.com/:u:/g/personal/zliubq_connect_ust_hk/EW1FVkAKN5dJg1ns_CcMtQoBJAy1Yxx-b7lpaTFjTJIUKw?e=oyebWy) |
49 |
50 | ## Contact
51 |
52 | Zechun Liu, HKUST (zliubq at connect.ust.hk)
53 |
54 | Zhiqiang Shen, CMU (zhiqians at andrew.cmu.edu)
55 |
--------------------------------------------------------------------------------
/mobilenet/1_step1/reactnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2
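# Width schedule of the 14-block MobileNet-style backbone; in `reactnet` below,
# the stem conv is stride 2 and every width increase except 32 -> 64 is
# realized by a stride-2 block.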
8 |
9 | def conv3x3(in_planes, out_planes, stride=1):
10 | """3x3 convolution with padding"""
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
12 | padding=1, bias=False)
13 |
14 |
15 | def conv1x1(in_planes, out_planes, stride=1):
16 | """1x1 convolution"""
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
18 |
19 | class firstconv3x3(nn.Module):
20 | def __init__(self, inp, oup, stride):
21 | super(firstconv3x3, self).__init__()
22 |
23 | self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
24 | self.bn1 = nn.BatchNorm2d(oup)
25 |
26 | def forward(self, x):
27 |
28 | out = self.conv1(x)
29 | out = self.bn1(out)
30 |
31 | return out
32 |
33 | class BinaryActivation(nn.Module):
34 | def __init__(self):
35 | super(BinaryActivation, self).__init__()
36 |
37 | def forward(self, x):
38 | out_forward = torch.sign(x)  # hard sign used in the forward pass
39 | mask1 = x < -1
40 | mask2 = x < 0
41 | mask3 = x < 1
42 | out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32))   # -1 for x < -1, x^2 + 2x on [-1, 0)
43 | out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32))  # -x^2 + 2x on [0, 1)
44 | out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32))            # +1 for x >= 1
45 | out = out_forward.detach() - out3.detach() + out3  # straight-through: sign forward, polynomial gradient backward
46 |
47 | return out
48 |
49 | class LearnableBias(nn.Module):
50 | def __init__(self, out_chn):
51 | super(LearnableBias, self).__init__()
52 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True)
53 |
54 | def forward(self, x):
55 | out = x + self.bias.expand_as(x)
56 | return out
57 |
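# LearnableBias is the channel-wise shift that turns Sign into RSign and
# PReLU into RPReLU (the `move*` layers in BasicBlock below).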
58 | class BasicBlock(nn.Module):
59 | def __init__(self, inplanes, planes, stride=1):
60 | super(BasicBlock, self).__init__()
61 | norm_layer = nn.BatchNorm2d
62 |
63 | self.move11 = LearnableBias(inplanes)
64 | self.binary_3x3= conv3x3(inplanes, inplanes, stride=stride)
65 | self.bn1 = norm_layer(inplanes)
66 |
67 | self.move12 = LearnableBias(inplanes)
68 | self.prelu1 = nn.PReLU(inplanes)
69 | self.move13 = LearnableBias(inplanes)
70 |
71 | self.move21 = LearnableBias(inplanes)
72 |
73 | if inplanes == planes:
74 | self.binary_pw = conv1x1(inplanes, planes)
75 | self.bn2 = norm_layer(planes)
76 | else:
77 | self.binary_pw_down1 = conv1x1(inplanes, inplanes)
78 | self.binary_pw_down2 = conv1x1(inplanes, inplanes)
79 | self.bn2_1 = norm_layer(inplanes)
80 | self.bn2_2 = norm_layer(inplanes)
81 |
82 | self.move22 = LearnableBias(planes)
83 | self.prelu2 = nn.PReLU(planes)
84 | self.move23 = LearnableBias(planes)
85 |
86 | self.binary_activation = BinaryActivation()
87 | self.stride = stride
88 | self.inplanes = inplanes
89 | self.planes = planes
90 |
91 | if self.inplanes != self.planes:
92 | self.pooling = nn.AvgPool2d(2,2)
93 |
94 | def forward(self, x):
95 |
96 | out1 = self.move11(x)
97 |
98 | out1 = self.binary_activation(out1)
99 | out1 = self.binary_3x3(out1)
100 | out1 = self.bn1(out1)
101 |
102 | if self.stride == 2:
103 | x = self.pooling(x)
104 |
105 | out1 = x + out1
106 |
107 | out1 = self.move12(out1)
108 | out1 = self.prelu1(out1)
109 | out1 = self.move13(out1)
110 |
111 | out2 = self.move21(out1)
112 | out2 = self.binary_activation(out2)
113 |
114 | if self.inplanes == self.planes:
115 | out2 = self.binary_pw(out2)
116 | out2 = self.bn2(out2)
117 | out2 += out1
118 |
119 | else:
120 | assert self.planes == self.inplanes * 2
121 |
122 | out2_1 = self.binary_pw_down1(out2)
123 | out2_2 = self.binary_pw_down2(out2)
124 | out2_1 = self.bn2_1(out2_1)
125 | out2_2 = self.bn2_2(out2_2)
126 | out2_1 += out1
127 | out2_2 += out1
128 | out2 = torch.cat([out2_1, out2_2], dim=1)
129 |
130 | out2 = self.move22(out2)
131 | out2 = self.prelu2(out2)
132 | out2 = self.move23(out2)
133 |
134 | return out2
135 |
136 |
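# When the width doubles, BasicBlock runs two binary 1x1 convs in parallel,
# each with its own BatchNorm and residual, then concatenates them: every
# binary conv keeps equal input/output channels, so each one gets a
# shape-matched shortcut.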
137 | class reactnet(nn.Module):
138 | def __init__(self, num_classes=1000):
139 | super(reactnet, self).__init__()
140 | self.feature = nn.ModuleList()
141 | for i in range(len(stage_out_channel)):
142 | if i == 0:
143 | self.feature.append(firstconv3x3(3, stage_out_channel[i], 2))
144 | elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64:
145 | self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2))
146 | else:
147 | self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1))
148 | self.pool1 = nn.AdaptiveAvgPool2d(1)
149 | self.fc = nn.Linear(1024, num_classes)
150 |
151 | def forward(self, x):
152 | for i, block in enumerate(self.feature):
153 | x = block(x)
154 |
155 | x = self.pool1(x)
156 | x = x.view(x.size(0), -1)
157 | x = self.fc(x)
158 |
159 | return x
160 |
161 |
162 |
163 |
164 |
165 |
166 |
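# Minimal smoke test (illustrative; not part of the original training flow):
if __name__ == "__main__":
    net = reactnet(num_classes=1000)
    out = net(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 1000])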
--------------------------------------------------------------------------------
/mobilenet/1_step1/run.sh:
--------------------------------------------------------------------------------
1 | clear
2 | mkdir -p log
3 | # 128 epoch setting: larger learning rate, similar performance to 256 epoch
4 | python3 train.py --data=/datasets/imagenet --batch_size=256 --learning_rate=1.25e-3 --epochs=128 --weight_decay=1e-5 | tee -a log/training.txt
5 | # 256 epoch setting: longer training, similar performance to 128 epoch
6 | # python3 train.py --data=/datasets/imagenet --batch_size=256 --learning_rate=5e-4 --epochs=256 --weight_decay=1e-5 | tee -a log/training.txt
7 |
--------------------------------------------------------------------------------
/mobilenet/1_step1/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import numpy as np
5 | import time, datetime
6 | import torch
7 | import random
8 | import logging
9 | import argparse
10 | import torch.nn as nn
11 | import torch.utils
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.utils.data.distributed
15 |
16 | sys.path.append("../../")
17 | from utils.utils import *
18 | from utils import KD_loss
19 | from torchvision import datasets, transforms
20 | from torch.autograd import Variable
21 | from reactnet import reactnet
22 | import torchvision.models as models
23 |
24 | parser = argparse.ArgumentParser("birealnet18")
25 | parser.add_argument('--batch_size', type=int, default=512, help='batch size')
26 | parser.add_argument('--epochs', type=int, default=256, help='num of training epochs')
27 | parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate')
28 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
29 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
30 | parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')
31 | parser.add_argument('--data', metavar='DIR', help='path to dataset')
32 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
33 | parser.add_argument('--teacher', type=str, default='resnet34', help='teacher model architecture')
34 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N',
35 | help='number of data loading workers (default: 40)')
36 | args = parser.parse_args()
37 |
38 | CLASSES = 1000
39 |
40 | if not os.path.exists('log'):
41 | os.mkdir('log')
42 |
43 | log_format = '%(asctime)s %(message)s'
44 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
45 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
46 | fh = logging.FileHandler(os.path.join('log/log.txt'))
47 | fh.setFormatter(logging.Formatter(log_format))
48 | logging.getLogger().addHandler(fh)
49 |
50 | def main():
51 | if not torch.cuda.is_available():
52 | sys.exit(1)
53 | start_t = time.time()
54 |
55 | cudnn.benchmark = True
56 | cudnn.enabled=True
57 | logging.info("args = %s", args)
58 |
59 | # load model
60 | model_teacher = models.__dict__[args.teacher](pretrained=True)
61 | model_teacher = nn.DataParallel(model_teacher).cuda()
62 | for p in model_teacher.parameters():
63 | p.requires_grad = False
64 | model_teacher.eval()
65 |
66 | model_student = reactnet()
67 | logging.info('student:')
68 | logging.info(model_student)
69 | model_student = nn.DataParallel(model_student).cuda()
70 |
71 | criterion = nn.CrossEntropyLoss()
72 | criterion = criterion.cuda()
73 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
74 | criterion_smooth = criterion_smooth.cuda()
75 | criterion_kd = KD_loss.DistributionLoss()
76 |
77 | all_parameters = model_student.parameters()
78 | weight_parameters = []
79 | for pname, p in model_student.named_parameters():
80 | if p.ndimension() == 4 or 'conv' in pname:
81 | weight_parameters.append(p)
82 | weight_parameters_id = list(map(id, weight_parameters))
83 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters))
84 |
85 | optimizer = torch.optim.Adam(
86 | [{'params' : other_parameters},
87 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}],
88 | lr=args.learning_rate,)
89 |
90 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/args.epochs), last_epoch=-1)
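# Optimizer notes: weight decay applies only to the 4-D conv weights; BN,
# PReLU, and LearnableBias parameters stay undecayed (the first Adam group
# uses the default weight_decay=0). LambdaLR decays the learning rate
# linearly to zero over args.epochs, stepped once per epoch inside train().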
91 | start_epoch = 0
92 | best_top1_acc= 0
93 |
94 | checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
95 | if os.path.exists(checkpoint_tar):
96 | logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
97 | checkpoint = torch.load(checkpoint_tar)
98 | start_epoch = checkpoint['epoch'] + 1
99 | best_top1_acc = checkpoint['best_top1_acc']
100 | model_student.load_state_dict(checkpoint['state_dict'], strict=False)
101 | logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch']))
102 |
103 | # adjust the learning rate according to the checkpoint
104 | for epoch in range(start_epoch):
105 | scheduler.step()
106 |
107 | # load training data
108 | traindir = os.path.join(args.data, 'train')
109 | valdir = os.path.join(args.data, 'val')
110 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
111 | std=[0.229, 0.224, 0.225])
112 |
113 | # data augmentation
114 | crop_scale = 0.08
115 | lighting_param = 0.1
116 | train_transforms = transforms.Compose([
117 | transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
118 | Lighting(lighting_param),
119 | transforms.RandomHorizontalFlip(),
120 | transforms.ToTensor(),
121 | normalize])
122 |
123 | train_dataset = datasets.ImageFolder(
124 | traindir,
125 | transform=train_transforms)
126 |
127 | train_loader = torch.utils.data.DataLoader(
128 | train_dataset, batch_size=args.batch_size, shuffle=True,
129 | num_workers=args.workers, pin_memory=True)
130 |
131 | # load validation data
132 | val_loader = torch.utils.data.DataLoader(
133 | datasets.ImageFolder(valdir, transforms.Compose([
134 | transforms.Resize(256),
135 | transforms.CenterCrop(224),
136 | transforms.ToTensor(),
137 | normalize,
138 | ])),
139 | batch_size=args.batch_size, shuffle=False,
140 | num_workers=args.workers, pin_memory=True)
141 |
142 | # train the model
143 | epoch = start_epoch
144 | while epoch < args.epochs:
145 | train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler)
146 | valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args)
147 |
148 | is_best = False
149 | if valid_top1_acc > best_top1_acc:
150 | best_top1_acc = valid_top1_acc
151 | is_best = True
152 |
153 | save_checkpoint({
154 | 'epoch': epoch,
155 | 'state_dict': model_student.state_dict(),
156 | 'best_top1_acc': best_top1_acc,
157 | 'optimizer' : optimizer.state_dict(),
158 | }, is_best, args.save)
159 |
160 | epoch += 1
161 |
162 | training_time = (time.time() - start_t) / 3600
163 | print('total training time = {} hours'.format(training_time))
164 |
165 |
166 | def train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler):
167 | batch_time = AverageMeter('Time', ':6.3f')
168 | data_time = AverageMeter('Data', ':6.3f')
169 | losses = AverageMeter('Loss', ':.4e')
170 | top1 = AverageMeter('Acc@1', ':6.2f')
171 | top5 = AverageMeter('Acc@5', ':6.2f')
172 |
173 | progress = ProgressMeter(
174 | len(train_loader),
175 | [batch_time, data_time, losses, top1, top5],
176 | prefix="Epoch: [{}]".format(epoch))
177 |
178 | model_student.train()
179 | model_teacher.eval()
180 | end = time.time()
181 | scheduler.step()
182 |
183 | for param_group in optimizer.param_groups:
184 | cur_lr = param_group['lr']
185 | print('learning_rate:', cur_lr)
186 |
187 | for i, (images, target) in enumerate(train_loader):
188 | data_time.update(time.time() - end)
189 | images = images.cuda()
190 | target = target.cuda()
191 |
192 | # compute output
193 | logits_student = model_student(images)
194 | logits_teacher = model_teacher(images)
195 | loss = criterion(logits_student, logits_teacher)  # pure distillation: labels are used only for accuracy reporting
196 |
197 | # measure accuracy and record loss
198 | prec1, prec5 = accuracy(logits_student, target, topk=(1, 5))
199 | n = images.size(0)
200 | losses.update(loss.item(), n)  # record running average of the loss
201 | top1.update(prec1.item(), n)
202 | top5.update(prec5.item(), n)
203 |
204 | # compute gradients and take an optimizer step (Adam)
205 | optimizer.zero_grad()
206 | loss.backward()
207 | optimizer.step()
208 |
209 | # measure elapsed time
210 | batch_time.update(time.time() - end)
211 | end = time.time()
212 |
213 | progress.display(i)
214 |
215 | return losses.avg, top1.avg, top5.avg
216 |
217 | def validate(epoch, val_loader, model, criterion, args):
218 | batch_time = AverageMeter('Time', ':6.3f')
219 | losses = AverageMeter('Loss', ':.4e')
220 | top1 = AverageMeter('Acc@1', ':6.2f')
221 | top5 = AverageMeter('Acc@5', ':6.2f')
222 | progress = ProgressMeter(
223 | len(val_loader),
224 | [batch_time, losses, top1, top5],
225 | prefix='Test: ')
226 |
227 | # switch to evaluation mode
228 | model.eval()
229 | with torch.no_grad():
230 | end = time.time()
231 | for i, (images, target) in enumerate(val_loader):
232 | images = images.cuda()
233 | target = target.cuda()
234 |
235 | # compute output
236 | logits = model(images)
237 | loss = criterion(logits, target)
238 |
239 | # measure accuracy and record loss
240 | pred1, pred5 = accuracy(logits, target, topk=(1, 5))
241 | n = images.size(0)
242 | losses.update(loss.item(), n)
243 | top1.update(pred1[0], n)
244 | top5.update(pred5[0], n)
245 |
246 | # measure elapsed time
247 | batch_time.update(time.time() - end)
248 | end = time.time()
249 |
250 | progress.display(i)
251 |
252 | print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}'
253 | .format(top1=top1, top5=top5))
254 |
255 | return losses.avg, top1.avg, top5.avg
256 |
257 |
258 | if __name__ == '__main__':
259 | main()
260 |
--------------------------------------------------------------------------------
/mobilenet/2_step2/reactnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2
8 |
9 | def conv3x3(in_planes, out_planes, stride=1):
10 | """3x3 convolution with padding"""
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
12 | padding=1, bias=False)
13 |
14 |
15 | def conv1x1(in_planes, out_planes, stride=1):
16 | """1x1 convolution"""
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
18 |
19 | def binaryconv3x3(in_planes, out_planes, stride=1):
20 | """3x3 convolution with padding"""
21 | return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
22 |
23 |
24 | def binaryconv1x1(in_planes, out_planes, stride=1):
25 | """1x1 convolution"""
26 | return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
27 |
28 | class firstconv3x3(nn.Module):
29 | def __init__(self, inp, oup, stride):
30 | super(firstconv3x3, self).__init__()
31 |
32 | self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
33 | self.bn1 = nn.BatchNorm2d(oup)
34 |
35 | def forward(self, x):
36 |
37 | out = self.conv1(x)
38 | out = self.bn1(out)
39 |
40 | return out
41 |
42 | class BinaryActivation(nn.Module):
43 | def __init__(self):
44 | super(BinaryActivation, self).__init__()
45 |
46 | def forward(self, x):
47 | out_forward = torch.sign(x)
48 | mask1 = x < -1
49 | mask2 = x < 0
50 | mask3 = x < 1
51 | out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32))
52 | out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32))
53 | out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32))
54 | out = out_forward.detach() - out3.detach() + out3
55 |
56 | return out
57 |
58 | class LearnableBias(nn.Module):
59 | def __init__(self, out_chn):
60 | super(LearnableBias, self).__init__()
61 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True)
62 |
63 | def forward(self, x):
64 | out = x + self.bias.expand_as(x)
65 | return out
66 |
67 | class HardBinaryConv(nn.Module):
68 | def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1):
69 | super(HardBinaryConv, self).__init__()
70 | self.stride = stride
71 | self.padding = padding
72 | self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size
73 | self.shape = (out_chn, in_chn, kernel_size, kernel_size)
74 | self.weights = nn.Parameter(torch.rand((self.number_of_weights,1)) * 0.001, requires_grad=True)
75 |
76 | def forward(self, x):
77 | real_weights = self.weights.view(self.shape)
78 | scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True)
79 | #print(scaling_factor, flush=True)
80 | scaling_factor = scaling_factor.detach()
81 | binary_weights_no_grad = scaling_factor * torch.sign(real_weights)
82 | cliped_weights = torch.clamp(real_weights, -1.0, 1.0)
83 | binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights
84 | #print(binary_weights, flush=True)
85 | y = F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding)
86 |
87 | return y
88 |
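# HardBinaryConv forward weights: W_b = alpha * sign(W), with alpha = mean|W|
# per output filter (detached); gradients flow to clamp(W, -1, 1) through the
# straight-through construction above. Note the parameter is registered as a
# flat `weights` tensor, so the step-1 nn.Conv2d `weight` tensors are skipped
# by the strict=False load in train.py; step 2 inherits the BN, PReLU, bias,
# stem, and fc parameters.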
89 | class BasicBlock(nn.Module):
90 | def __init__(self, inplanes, planes, stride=1):
91 | super(BasicBlock, self).__init__()
92 | norm_layer = nn.BatchNorm2d
93 |
94 | self.move11 = LearnableBias(inplanes)
95 | self.binary_3x3= binaryconv3x3(inplanes, inplanes, stride=stride)
96 | self.bn1 = norm_layer(inplanes)
97 |
98 | self.move12 = LearnableBias(inplanes)
99 | self.prelu1 = nn.PReLU(inplanes)
100 | self.move13 = LearnableBias(inplanes)
101 |
102 | self.move21 = LearnableBias(inplanes)
103 |
104 | if inplanes == planes:
105 | self.binary_pw = binaryconv1x1(inplanes, planes)
106 | self.bn2 = norm_layer(planes)
107 | else:
108 | self.binary_pw_down1 = binaryconv1x1(inplanes, inplanes)
109 | self.binary_pw_down2 = binaryconv1x1(inplanes, inplanes)
110 | self.bn2_1 = norm_layer(inplanes)
111 | self.bn2_2 = norm_layer(inplanes)
112 |
113 | self.move22 = LearnableBias(planes)
114 | self.prelu2 = nn.PReLU(planes)
115 | self.move23 = LearnableBias(planes)
116 |
117 | self.binary_activation = BinaryActivation()
118 | self.stride = stride
119 | self.inplanes = inplanes
120 | self.planes = planes
121 |
122 | if self.inplanes != self.planes:
123 | self.pooling = nn.AvgPool2d(2,2)
124 |
125 | def forward(self, x):
126 |
127 | out1 = self.move11(x)
128 |
129 | out1 = self.binary_activation(out1)
130 | out1 = self.binary_3x3(out1)
131 | out1 = self.bn1(out1)
132 |
133 | if self.stride == 2:
134 | x = self.pooling(x)
135 |
136 | out1 = x + out1
137 |
138 | out1 = self.move12(out1)
139 | out1 = self.prelu1(out1)
140 | out1 = self.move13(out1)
141 |
142 | out2 = self.move21(out1)
143 | out2 = self.binary_activation(out2)
144 |
145 | if self.inplanes == self.planes:
146 | out2 = self.binary_pw(out2)
147 | out2 = self.bn2(out2)
148 | out2 += out1
149 |
150 | else:
151 | assert self.planes == self.inplanes * 2
152 |
153 | out2_1 = self.binary_pw_down1(out2)
154 | out2_2 = self.binary_pw_down2(out2)
155 | out2_1 = self.bn2_1(out2_1)
156 | out2_2 = self.bn2_2(out2_2)
157 | out2_1 += out1
158 | out2_2 += out1
159 | out2 = torch.cat([out2_1, out2_2], dim=1)
160 |
161 | out2 = self.move22(out2)
162 | out2 = self.prelu2(out2)
163 | out2 = self.move23(out2)
164 |
165 | return out2
166 |
167 |
168 | class reactnet(nn.Module):
169 | def __init__(self, num_classes=1000):
170 | super(reactnet, self).__init__()
171 | self.feature = nn.ModuleList()
172 | for i in range(len(stage_out_channel)):
173 | if i == 0:
174 | self.feature.append(firstconv3x3(3, stage_out_channel[i], 2))
175 | elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64:
176 | self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2))
177 | else:
178 | self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1))
179 | self.pool1 = nn.AdaptiveAvgPool2d(1)
180 | self.fc = nn.Linear(1024, num_classes)
181 |
182 | def forward(self, x):
183 | for i, block in enumerate(self.feature):
184 | x = block(x)
185 |
186 | x = self.pool1(x)
187 | x = x.view(x.size(0), -1)
188 | x = self.fc(x)
189 |
190 | return x
191 |
192 |
193 |
194 |
195 |
196 |
197 |
--------------------------------------------------------------------------------
/mobilenet/2_step2/run.sh:
--------------------------------------------------------------------------------
1 | clear
2 | mkdir -p models
3 | cp ../1_step1/models/checkpoint.pth.tar ./models/checkpoint_ba.pth.tar  # seed step 2 with the step-1 weights
4 | mkdir -p log
5 | # 128 epoch setting: larger learning rate, similar performance to 256 epoch
6 | python3 train.py --data=/datasets/imagenet --batch_size=256 --learning_rate=1.25e-3 --epochs=128 --weight_decay=0 | tee -a log/training.txt
7 | # 256 epoch setting: longer training, similar performance to 128 epoch
8 | # python3 train.py --data=/datasets/imagenet --batch_size=256 --learning_rate=5e-4 --epochs=256 --weight_decay=0 | tee -a log/training.txt
9 |
--------------------------------------------------------------------------------
/mobilenet/2_step2/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import numpy as np
5 | import time, datetime
6 | import torch
7 | import random
8 | import logging
9 | import argparse
10 | import torch.nn as nn
11 | import torch.utils
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.utils.data.distributed
15 |
16 | sys.path.append("../../")
17 | from utils.utils import *
18 | from utils import KD_loss
19 | from torchvision import datasets, transforms
20 | from torch.autograd import Variable
21 | from reactnet import reactnet
22 | import torchvision.models as models
23 |
24 | parser = argparse.ArgumentParser("birealnet18")
25 | parser.add_argument('--batch_size', type=int, default=512, help='batch size')
26 | parser.add_argument('--epochs', type=int, default=256, help='num of training epochs')
27 | parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate')
28 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
29 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
30 | parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')
31 | parser.add_argument('--data', metavar='DIR', help='path to dataset')
32 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
33 | parser.add_argument('--teacher', type=str, default='resnet34', help='teacher model architecture')
34 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N',
35 | help='number of data loading workers (default: 40)')
36 | args = parser.parse_args()
37 |
38 | CLASSES = 1000
39 |
40 | if not os.path.exists('log'):
41 | os.mkdir('log')
42 |
43 | log_format = '%(asctime)s %(message)s'
44 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
45 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
46 | fh = logging.FileHandler(os.path.join('log/log.txt'))
47 | fh.setFormatter(logging.Formatter(log_format))
48 | logging.getLogger().addHandler(fh)
49 |
50 | def main():
51 | if not torch.cuda.is_available():
52 | sys.exit(1)
53 | start_t = time.time()
54 |
55 | cudnn.benchmark = True
56 | cudnn.enabled=True
57 | logging.info("args = %s", args)
58 |
59 | # load model
60 | model_teacher = models.__dict__[args.teacher](pretrained=True)
61 | model_teacher = nn.DataParallel(model_teacher).cuda()
62 | for p in model_teacher.parameters():
63 | p.requires_grad = False
64 | model_teacher.eval()
65 |
66 | model_student = reactnet()
67 | logging.info('student:')
68 | logging.info(model_student)
69 | model_student = nn.DataParallel(model_student).cuda()
70 |
71 | criterion = nn.CrossEntropyLoss()
72 | criterion = criterion.cuda()
73 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
74 | criterion_smooth = criterion_smooth.cuda()
75 | criterion_kd = KD_loss.DistributionLoss()
76 |
77 | all_parameters = model_student.parameters()
78 | weight_parameters = []
79 | for pname, p in model_student.named_parameters():
80 | if p.ndimension() == 4 or 'conv' in pname:
81 | weight_parameters.append(p)
82 | weight_parameters_id = list(map(id, weight_parameters))
83 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters))
84 |
85 | optimizer = torch.optim.Adam(
86 | [{'params' : other_parameters},
87 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}],
88 | lr=args.learning_rate,)
89 |
90 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/args.epochs), last_epoch=-1)
91 | start_epoch = 0
92 | best_top1_acc= 0
93 |
94 | checkpoint_tar = os.path.join(args.save, 'checkpoint_ba.pth.tar')
95 | checkpoint = torch.load(checkpoint_tar)
96 | model_student.load_state_dict(checkpoint['state_dict'], strict=False)
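# checkpoint_ba.pth.tar is the step-1 (binary-activation) checkpoint copied
# in by run.sh; if a step-2 checkpoint.pth.tar also exists, it is loaded
# below and training resumes from it.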
97 |
98 | checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
99 | if os.path.exists(checkpoint_tar):
100 | logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
101 | checkpoint = torch.load(checkpoint_tar)
102 | start_epoch = checkpoint['epoch'] + 1
103 | best_top1_acc = checkpoint['best_top1_acc']
104 | model_student.load_state_dict(checkpoint['state_dict'], strict=False)
105 | logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch']))
106 |
107 | # adjust the learning rate according to the checkpoint
108 | for epoch in range(start_epoch):
109 | scheduler.step()
110 |
111 | # load training data
112 | traindir = os.path.join(args.data, 'train')
113 | valdir = os.path.join(args.data, 'val')
114 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
115 | std=[0.229, 0.224, 0.225])
116 |
117 | # data augmentation
118 | crop_scale = 0.08
119 | lighting_param = 0.1
120 | train_transforms = transforms.Compose([
121 | transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
122 | Lighting(lighting_param),
123 | transforms.RandomHorizontalFlip(),
124 | transforms.ToTensor(),
125 | normalize])
126 |
127 | train_dataset = datasets.ImageFolder(
128 | traindir,
129 | transform=train_transforms)
130 |
131 | train_loader = torch.utils.data.DataLoader(
132 | train_dataset, batch_size=args.batch_size, shuffle=True,
133 | num_workers=args.workers, pin_memory=True)
134 |
135 | # load validation data
136 | val_loader = torch.utils.data.DataLoader(
137 | datasets.ImageFolder(valdir, transforms.Compose([
138 | transforms.Resize(256),
139 | transforms.CenterCrop(224),
140 | transforms.ToTensor(),
141 | normalize,
142 | ])),
143 | batch_size=args.batch_size, shuffle=False,
144 | num_workers=args.workers, pin_memory=True)
145 |
146 | # train the model
147 | epoch = start_epoch
148 | while epoch < args.epochs:
149 | train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler)
150 | valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args)
151 |
152 | is_best = False
153 | if valid_top1_acc > best_top1_acc:
154 | best_top1_acc = valid_top1_acc
155 | is_best = True
156 |
157 | save_checkpoint({
158 | 'epoch': epoch,
159 | 'state_dict': model_student.state_dict(),
160 | 'best_top1_acc': best_top1_acc,
161 | 'optimizer' : optimizer.state_dict(),
162 | }, is_best, args.save)
163 |
164 | epoch += 1
165 |
166 | training_time = (time.time() - start_t) / 3600
167 | print('total training time = {} hours'.format(training_time))
168 |
169 |
170 | def train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler):
171 | batch_time = AverageMeter('Time', ':6.3f')
172 | data_time = AverageMeter('Data', ':6.3f')
173 | losses = AverageMeter('Loss', ':.4e')
174 | top1 = AverageMeter('Acc@1', ':6.2f')
175 | top5 = AverageMeter('Acc@5', ':6.2f')
176 |
177 | progress = ProgressMeter(
178 | len(train_loader),
179 | [batch_time, data_time, losses, top1, top5],
180 | prefix="Epoch: [{}]".format(epoch))
181 |
182 | model_student.train()
183 | model_teacher.eval()
184 | end = time.time()
185 | scheduler.step()
186 |
187 | for param_group in optimizer.param_groups:
188 | cur_lr = param_group['lr']
189 | print('learning_rate:', cur_lr)
190 |
191 | for i, (images, target) in enumerate(train_loader):
192 | data_time.update(time.time() - end)
193 | images = images.cuda()
194 | target = target.cuda()
195 |
196 | # compute output
197 | logits_student = model_student(images)
198 | logits_teacher = model_teacher(images)
199 | loss = criterion(logits_student, logits_teacher)  # pure distillation: labels are used only for accuracy reporting
200 |
201 | # measure accuracy and record loss
202 | prec1, prec5 = accuracy(logits_student, target, topk=(1, 5))
203 | n = images.size(0)
204 | losses.update(loss.item(), n)  # record running average of the loss
205 | top1.update(prec1.item(), n)
206 | top5.update(prec5.item(), n)
207 |
208 | # compute gradients and take an optimizer step (Adam)
209 | optimizer.zero_grad()
210 | loss.backward()
211 | optimizer.step()
212 |
213 | # measure elapsed time
214 | batch_time.update(time.time() - end)
215 | end = time.time()
216 |
217 | progress.display(i)
218 |
219 | return losses.avg, top1.avg, top5.avg
220 |
221 | def validate(epoch, val_loader, model, criterion, args):
222 | batch_time = AverageMeter('Time', ':6.3f')
223 | losses = AverageMeter('Loss', ':.4e')
224 | top1 = AverageMeter('Acc@1', ':6.2f')
225 | top5 = AverageMeter('Acc@5', ':6.2f')
226 | progress = ProgressMeter(
227 | len(val_loader),
228 | [batch_time, losses, top1, top5],
229 | prefix='Test: ')
230 |
231 | # switch to evaluation mode
232 | model.eval()
233 | with torch.no_grad():
234 | end = time.time()
235 | for i, (images, target) in enumerate(val_loader):
236 | images = images.cuda()
237 | target = target.cuda()
238 |
239 | # compute output
240 | logits = model(images)
241 | loss = criterion(logits, target)
242 |
243 | # measure accuracy and record loss
244 | pred1, pred5 = accuracy(logits, target, topk=(1, 5))
245 | n = images.size(0)
246 | losses.update(loss.item(), n)
247 | top1.update(pred1[0], n)
248 | top5.update(pred5[0], n)
249 |
250 | # measure elapsed time
251 | batch_time.update(time.time() - end)
252 | end = time.time()
253 |
254 | progress.display(i)
255 |
256 | print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}'
257 | .format(top1=top1, top5=top5))
258 |
259 | return losses.avg, top1.avg, top5.avg
260 |
261 |
262 | if __name__ == '__main__':
263 | main()
264 |
--------------------------------------------------------------------------------
/resnet/1_step1/birealnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | import torch.nn.functional as F
5 |
6 |
7 | __all__ = ['birealnet18', 'birealnet34']
8 |
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 |
15 |
16 | def conv1x1(in_planes, out_planes, stride=1):
17 | """1x1 convolution"""
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
19 |
20 | class BinaryActivation(nn.Module):
21 | def __init__(self):
22 | super(BinaryActivation, self).__init__()
23 |
24 | def forward(self, x):
25 | out_forward = torch.sign(x)
26 | # backward-pass surrogate: x^2 + 2x on [-1, 0), -x^2 + 2x on [0, 1),
27 | # clamped to -1 / +1 outside [-1, 1]; the forward pass uses the hard sign
28 | # above, and gradients come from this polynomial via the detach() trick below
29 | mask1 = x < -1
30 | mask2 = x < 0
31 | mask3 = x < 1
32 | out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32))
33 | out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32))
34 | out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32))
35 | out = out_forward.detach() - out3.detach() + out3
36 |
37 | return out
38 |
39 | class LearnableBias(nn.Module):
40 | def __init__(self, out_chn):
41 | super(LearnableBias, self).__init__()
42 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True)
43 |
44 | def forward(self, x):
45 | out = x + self.bias.expand_as(x)
46 | return out
47 |
48 | class HardBinaryConv(nn.Module):
49 | def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1):
50 | super(HardBinaryConv, self).__init__()
51 | self.stride = stride
52 | self.padding = padding
53 | self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size
54 | self.shape = (out_chn, in_chn, kernel_size, kernel_size)
55 | #self.weight = nn.Parameter(torch.rand((self.number_of_weights,1)) * 0.001, requires_grad=True)
56 | self.weight = nn.Parameter(torch.rand((self.shape)) * 0.001, requires_grad=True)
57 |
58 | def forward(self, x):
59 | #real_weights = self.weights.view(self.shape)
60 | real_weights = self.weight
61 | scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True)
62 | #print(scaling_factor, flush=True)
63 | scaling_factor = scaling_factor.detach()
64 | binary_weights_no_grad = scaling_factor * torch.sign(real_weights)
65 | cliped_weights = torch.clamp(real_weights, -1.0, 1.0)
66 | binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights
67 | #print(binary_weights, flush=True)
68 | y = F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding)
69 |
70 | return y
71 |
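# Note: step 1 binarizes activations only, so BasicBlock below uses the
# real-valued conv3x3. HardBinaryConv (used in step 2) registers its
# parameter as `weight` with the standard (out, in, k, k) shape, which lets
# the step-1 checkpoint load directly into the step-2 model.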
72 | class BasicBlock(nn.Module):
73 | expansion = 1
74 |
75 | def __init__(self, inplanes, planes, stride=1, downsample=None):
76 | super(BasicBlock, self).__init__()
77 |
78 | self.move0 = LearnableBias(inplanes)
79 | self.binary_activation = BinaryActivation()
80 | self.binary_conv = conv3x3(inplanes, planes, stride=stride)
81 | self.bn1 = nn.BatchNorm2d(planes)
82 | self.move1 = LearnableBias(planes)
83 | self.prelu = nn.PReLU(planes)
84 | self.move2 = LearnableBias(planes)
85 |
86 | self.downsample = downsample
87 | self.stride = stride
88 |
89 | def forward(self, x):
90 | residual = x
91 |
92 | out = self.move0(x)
93 | out = self.binary_activation(out)
94 | out = self.binary_conv(out)
95 | out = self.bn1(out)
96 |
97 | if self.downsample is not None:
98 | residual = self.downsample(x)
99 |
100 | out += residual
101 | out = self.move1(out)
102 | out = self.prelu(out)
103 | out = self.move2(out)
104 |
105 | return out
106 |
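# Block layout: learnable shift -> binarize -> 3x3 conv -> BN -> +residual,
# then shift -> PReLU -> shift (i.e., the RPReLU of the paper).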
107 | class BiRealNet(nn.Module):
108 |
109 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
110 | super(BiRealNet, self).__init__()
111 | self.inplanes = 64
112 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
113 | bias=False)
114 | self.bn1 = nn.BatchNorm2d(64)
115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
116 | self.layer1 = self._make_layer(block, 64, layers[0])
117 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
118 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
119 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
120 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
121 | self.fc = nn.Linear(512 * block.expansion, num_classes)
122 |
123 | def _make_layer(self, block, planes, blocks, stride=1):
124 | downsample = None
125 | if stride != 1 or self.inplanes != planes * block.expansion:
126 | downsample = nn.Sequential(
127 | nn.AvgPool2d(kernel_size=2, stride=stride),
128 | conv1x1(self.inplanes, planes * block.expansion),
129 | nn.BatchNorm2d(planes * block.expansion),
130 | )
131 |
132 | layers = []
133 | layers.append(block(self.inplanes, planes, stride, downsample))
134 | self.inplanes = planes * block.expansion
135 | for _ in range(1, blocks):
136 | layers.append(block(self.inplanes, planes))
137 |
138 | return nn.Sequential(*layers)
139 |
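# Shortcuts downsample with AvgPool2d + a real-valued 1x1 conv + BN
# (Bi-Real-Net style), keeping the residual path full-precision.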
140 | def forward(self, x):
141 | x = self.conv1(x)
142 | x = self.bn1(x)
143 | x = self.maxpool(x)
144 |
145 | x = self.layer1(x)
146 | x = self.layer2(x)
147 | x = self.layer3(x)
148 | x = self.layer4(x)
149 |
150 | x = self.avgpool(x)
151 | x = x.view(x.size(0), -1)
152 | x = self.fc(x)
153 |
154 | return x
155 |
156 |
157 | def birealnet18(pretrained=False, **kwargs):
158 | """Constructs a BiRealNet-18 model. """
159 | model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs)  # 16 single-conv blocks: same conv count as ResNet-18
160 | return model
161 |
162 |
163 | def birealnet34(pretrained=False, **kwargs):
164 | """Constructs a BiRealNet-34 model. """
165 | model = BiRealNet(BasicBlock, [6, 8, 12, 6], **kwargs)  # 32 single-conv blocks: same conv count as ResNet-34
166 | return model
167 |
168 |
--------------------------------------------------------------------------------
/resnet/1_step1/run.sh:
--------------------------------------------------------------------------------
1 | clear
2 | mkdir -p log
3 | # 128 epoch setting: larger learning rate, similar performance to 256 epoch
4 | python3 train.py --data=/datasets/imagenet --batch_size=512 --learning_rate=2.5e-3 --epochs=128 --weight_decay=1e-5 | tee -a log/training.txt
5 | # 256 epoch setting: longer training, similar performance to 128 epoch
6 | # python3 train.py --data=/datasets/imagenet --batch_size=512 --learning_rate=1e-3 --epochs=256 --weight_decay=1e-5 | tee -a log/training.txt
7 |
--------------------------------------------------------------------------------
/resnet/1_step1/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import numpy as np
5 | import time, datetime
6 | import torch
7 | import random
8 | import logging
9 | import argparse
10 | import torch.nn as nn
11 | import torch.utils
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.utils.data.distributed
15 |
16 | sys.path.append("../../")
17 | from utils.utils import *
18 | from utils import KD_loss
19 | from torchvision import datasets, transforms
20 | from torch.autograd import Variable
21 | from birealnet import birealnet18
22 | import torchvision.models as models
23 |
24 | parser = argparse.ArgumentParser("birealnet18")
25 | parser.add_argument('--batch_size', type=int, default=512, help='batch size')
26 | parser.add_argument('--epochs', type=int, default=256, help='num of training epochs')
27 | parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate')
28 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
29 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
30 | parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')
31 | parser.add_argument('--data', metavar='DIR', help='path to dataset')
32 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
33 | parser.add_argument('--teacher', type=str, default='resnet34', help='teacher model architecture')
34 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N',
35 | help='number of data loading workers (default: 40)')
36 | args = parser.parse_args()
37 |
38 | CLASSES = 1000
39 |
40 | if not os.path.exists('log'):
41 | os.mkdir('log')
42 |
43 | log_format = '%(asctime)s %(message)s'
44 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
45 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
46 | fh = logging.FileHandler(os.path.join('log/log.txt'))
47 | fh.setFormatter(logging.Formatter(log_format))
48 | logging.getLogger().addHandler(fh)
49 |
50 | def main():
51 | if not torch.cuda.is_available():
52 | sys.exit(1)
53 | start_t = time.time()
54 |
55 | cudnn.benchmark = True
56 | cudnn.enabled=True
57 | logging.info("args = %s", args)
58 |
59 | # load model
60 | model_teacher = models.__dict__[args.teacher](pretrained=True)
61 | model_teacher = nn.DataParallel(model_teacher).cuda()
62 | for p in model_teacher.parameters():
63 | p.requires_grad = False
64 | model_teacher.eval()
65 |
66 | model_student = birealnet18()
67 | logging.info('student:')
68 | logging.info(model_student)
69 | model_student = nn.DataParallel(model_student).cuda()
70 |
71 | criterion = nn.CrossEntropyLoss()
72 | criterion = criterion.cuda()
73 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
74 | criterion_smooth = criterion_smooth.cuda()
75 | criterion_kd = KD_loss.DistributionLoss()
76 |
77 | all_parameters = model_student.parameters()
78 | weight_parameters = []
79 | for pname, p in model_student.named_parameters():
80 | if p.ndimension() == 4 or 'conv' in pname:
81 | weight_parameters.append(p)
82 | weight_parameters_id = list(map(id, weight_parameters))
83 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters))
84 |
85 | optimizer = torch.optim.Adam(
86 | [{'params' : other_parameters},
87 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}],
88 | lr=args.learning_rate,)
89 |
90 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/args.epochs), last_epoch=-1)
91 | start_epoch = 0
92 | best_top1_acc= 0
93 |
94 | checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
95 | if os.path.exists(checkpoint_tar):
96 | logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
97 | checkpoint = torch.load(checkpoint_tar)
98 | start_epoch = checkpoint['epoch'] + 1  # resume from the next epoch (matches the mobilenet scripts)
99 | best_top1_acc = checkpoint['best_top1_acc']
100 | model_student.load_state_dict(checkpoint['state_dict'], strict=False)
101 | logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch']))
102 |
103 | # adjust the learning rate according to the checkpoint
104 | for epoch in range(start_epoch):
105 | scheduler.step()
106 |
107 | # load training data
108 | traindir = os.path.join(args.data, 'train')
109 | valdir = os.path.join(args.data, 'val')
110 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
111 | std=[0.229, 0.224, 0.225])
112 |
113 | # data augmentation
114 | crop_scale = 0.08
115 | lighting_param = 0.1
116 | train_transforms = transforms.Compose([
117 | transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
118 | Lighting(lighting_param),
119 | transforms.RandomHorizontalFlip(),
120 | transforms.ToTensor(),
121 | normalize])
122 |
123 | train_dataset = datasets.ImageFolder(
124 | traindir,
125 | transform=train_transforms)
126 |
127 | train_loader = torch.utils.data.DataLoader(
128 | train_dataset, batch_size=args.batch_size, shuffle=True,
129 | num_workers=args.workers, pin_memory=True)
130 |
131 | # load validation data
132 | val_loader = torch.utils.data.DataLoader(
133 | datasets.ImageFolder(valdir, transforms.Compose([
134 | transforms.Resize(256),
135 | transforms.CenterCrop(224),
136 | transforms.ToTensor(),
137 | normalize,
138 | ])),
139 | batch_size=args.batch_size, shuffle=False,
140 | num_workers=args.workers, pin_memory=True)
141 |
142 | # train the model
143 | epoch = start_epoch
144 | while epoch < args.epochs:
145 | train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler)
146 | valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args)
147 |
148 | is_best = False
149 | if valid_top1_acc > best_top1_acc:
150 | best_top1_acc = valid_top1_acc
151 | is_best = True
152 |
153 | save_checkpoint({
154 | 'epoch': epoch,
155 | 'state_dict': model_student.state_dict(),
156 | 'best_top1_acc': best_top1_acc,
157 | 'optimizer' : optimizer.state_dict(),
158 | }, is_best, args.save)
159 |
160 | epoch += 1
161 |
162 | training_time = (time.time() - start_t) / 3600
163 | print('total training time = {} hours'.format(training_time))
164 |
165 |
166 | def train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler):
167 | batch_time = AverageMeter('Time', ':6.3f')
168 | data_time = AverageMeter('Data', ':6.3f')
169 | losses = AverageMeter('Loss', ':.4e')
170 | top1 = AverageMeter('Acc@1', ':6.2f')
171 | top5 = AverageMeter('Acc@5', ':6.2f')
172 |
173 | progress = ProgressMeter(
174 | len(train_loader),
175 | [batch_time, data_time, losses, top1, top5],
176 | prefix="Epoch: [{}]".format(epoch))
177 |
178 | model_student.train()
179 | model_teacher.eval()
180 | end = time.time()
181 | scheduler.step()
182 |
183 | for param_group in optimizer.param_groups:
184 | cur_lr = param_group['lr']
185 | print('learning_rate:', cur_lr)
186 |
187 | for i, (images, target) in enumerate(train_loader):
188 | data_time.update(time.time() - end)
189 | images = images.cuda()
190 | target = target.cuda()
191 |
192 | # compute output
193 | logits_student = model_student(images)
194 | logits_teacher = model_teacher(images)
195 | loss = criterion(logits_student, logits_teacher)  # pure distillation: labels are used only for accuracy reporting
196 |
197 | # measure accuracy and record loss
198 | prec1, prec5 = accuracy(logits_student, target, topk=(1, 5))
199 | n = images.size(0)
200 | losses.update(loss.item(), n)  # record running average of the loss
201 | top1.update(prec1.item(), n)
202 | top5.update(prec5.item(), n)
203 |
204 | # compute gradients and take an optimizer step (Adam)
205 | optimizer.zero_grad()
206 | loss.backward()
207 | optimizer.step()
208 |
209 | # measure elapsed time
210 | batch_time.update(time.time() - end)
211 | end = time.time()
212 |
213 | progress.display(i)
214 |
215 | return losses.avg, top1.avg, top5.avg
216 |
217 | def validate(epoch, val_loader, model, criterion, args):
218 | batch_time = AverageMeter('Time', ':6.3f')
219 | losses = AverageMeter('Loss', ':.4e')
220 | top1 = AverageMeter('Acc@1', ':6.2f')
221 | top5 = AverageMeter('Acc@5', ':6.2f')
222 | progress = ProgressMeter(
223 | len(val_loader),
224 | [batch_time, losses, top1, top5],
225 | prefix='Test: ')
226 |
227 | # switch to evaluation mode
228 | model.eval()
229 | with torch.no_grad():
230 | end = time.time()
231 | for i, (images, target) in enumerate(val_loader):
232 | images = images.cuda()
233 | target = target.cuda()
234 |
235 | # compute output
236 | logits = model(images)
237 | loss = criterion(logits, target)
238 |
239 | # measure accuracy and record loss
240 | pred1, pred5 = accuracy(logits, target, topk=(1, 5))
241 | n = images.size(0)
242 | losses.update(loss.item(), n)
243 | top1.update(pred1[0], n)
244 | top5.update(pred5[0], n)
245 |
246 | # measure elapsed time
247 | batch_time.update(time.time() - end)
248 | end = time.time()
249 |
250 | progress.display(i)
251 |
252 | print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}'
253 | .format(top1=top1, top5=top5))
254 |
255 | return losses.avg, top1.avg, top5.avg
256 |
257 |
258 | if __name__ == '__main__':
259 | main()
260 |
--------------------------------------------------------------------------------
/resnet/2_step2/birealnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | import torch.nn.functional as F
5 |
6 |
7 | __all__ = ['birealnet18', 'birealnet34']
8 |
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 |
15 |
16 | def conv1x1(in_planes, out_planes, stride=1):
17 | """1x1 convolution"""
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
19 |
20 | class BinaryActivation(nn.Module):
21 | def __init__(self):
22 | super(BinaryActivation, self).__init__()
23 |
24 | def forward(self, x):
25 | out_forward = torch.sign(x)
26 | # backward-pass surrogate: x^2 + 2x on [-1, 0), -x^2 + 2x on [0, 1),
27 | # clamped to -1 / +1 outside [-1, 1]; the forward pass uses the hard sign
28 | # above, and gradients come from this polynomial via the detach() trick below
29 | mask1 = x < -1
30 | mask2 = x < 0
31 | mask3 = x < 1
32 | out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32))
33 | out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32))
34 | out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32))
35 | out = out_forward.detach() - out3.detach() + out3
36 |
37 | return out
38 |
39 | class LearnableBias(nn.Module):
40 | def __init__(self, out_chn):
41 | super(LearnableBias, self).__init__()
42 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True)
43 |
44 | def forward(self, x):
45 | out = x + self.bias.expand_as(x)
46 | return out
47 |
48 |
49 | class HardBinaryConv(nn.Module):
50 | def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1):
51 | super(HardBinaryConv, self).__init__()
52 | self.stride = stride
53 | self.padding = padding
54 | self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size
55 | self.shape = (out_chn, in_chn, kernel_size, kernel_size)
56 | #self.weight = nn.Parameter(torch.rand((self.number_of_weights,1)) * 0.001, requires_grad=True)
57 | self.weight = nn.Parameter(torch.rand((self.shape)) * 0.001, requires_grad=True)
58 |
59 | def forward(self, x):
60 | #real_weights = self.weights.view(self.shape)
61 | real_weights = self.weight
62 | scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True)
63 | #print(scaling_factor, flush=True)
64 | scaling_factor = scaling_factor.detach()
65 | binary_weights_no_grad = scaling_factor * torch.sign(real_weights)
66 | cliped_weights = torch.clamp(real_weights, -1.0, 1.0)
67 | binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights
68 | #print(binary_weights, flush=True)
69 | y = F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding)
70 |
71 | return y
72 |
73 | class BasicBlock(nn.Module):
74 | expansion = 1
75 |
76 | def __init__(self, inplanes, planes, stride=1, downsample=None):
77 | super(BasicBlock, self).__init__()
78 |
79 | self.move0 = LearnableBias(inplanes)
80 | self.binary_activation = BinaryActivation()
81 | self.binary_conv = HardBinaryConv(inplanes, planes, stride=stride)
82 | self.bn1 = nn.BatchNorm2d(planes)
83 | self.move1 = LearnableBias(planes)
84 | self.prelu = nn.PReLU(planes)
85 | self.move2 = LearnableBias(planes)
86 |
87 | self.downsample = downsample
88 | self.stride = stride
89 |
90 | def forward(self, x):
91 | residual = x
92 |
93 | out = self.move0(x)
94 | out = self.binary_activation(out)
95 | out = self.binary_conv(out)
96 | out = self.bn1(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 | out = self.move1(out)
103 | out = self.prelu(out)
104 | out = self.move2(out)
105 |
106 | return out
107 |
108 | class BiRealNet(nn.Module):
109 |
110 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
111 | super(BiRealNet, self).__init__()
112 | self.inplanes = 64
113 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
114 | bias=False)
115 | self.bn1 = nn.BatchNorm2d(64)
116 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
117 | self.layer1 = self._make_layer(block, 64, layers[0])
118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
121 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
122 | self.fc = nn.Linear(512 * block.expansion, num_classes)
123 |
124 | def _make_layer(self, block, planes, blocks, stride=1):
125 | downsample = None
126 | if stride != 1 or self.inplanes != planes * block.expansion:
127 | downsample = nn.Sequential(
128 | nn.AvgPool2d(kernel_size=2, stride=stride),
129 | conv1x1(self.inplanes, planes * block.expansion),
130 | nn.BatchNorm2d(planes * block.expansion),
131 | )
132 |
133 | layers = []
134 | layers.append(block(self.inplanes, planes, stride, downsample))
135 | self.inplanes = planes * block.expansion
136 | for _ in range(1, blocks):
137 | layers.append(block(self.inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def forward(self, x):
142 | x = self.conv1(x)
143 | x = self.bn1(x)
144 | x = self.maxpool(x)
145 |
146 | x = self.layer1(x)
147 | x = self.layer2(x)
148 | x = self.layer3(x)
149 | x = self.layer4(x)
150 |
151 | x = self.avgpool(x)
152 | x = x.view(x.size(0), -1)
153 | x = self.fc(x)
154 |
155 | return x
156 |
157 |
158 | def birealnet18(pretrained=False, **kwargs):
159 | """Constructs a BiRealNet-18 model. """
160 | model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs)  # 16 single-conv blocks: same conv count as ResNet-18
161 | return model
162 |
163 |
164 | def birealnet34(pretrained=False, **kwargs):
165 | """Constructs a BiRealNet-34 model. """
166 | model = BiRealNet(BasicBlock, [6, 8, 12, 6], **kwargs)  # 32 single-conv blocks: same conv count as ResNet-34
167 | return model
168 |
169 |
--------------------------------------------------------------------------------
/resnet/2_step2/run.sh:
--------------------------------------------------------------------------------
1 | clear
2 | mkdir -p models
3 | cp ../1_step1/models/checkpoint.pth.tar ./models/checkpoint_ba.pth.tar  # seed step 2 with the step-1 weights
4 | mkdir -p log
5 | # 128 epoch setting: larger learning rate, similar performance to 256 epoch
6 | python3 train.py --data=/datasets/imagenet --batch_size=512 --learning_rate=2.5e-3 --epochs=128 --weight_decay=0 | tee -a log/training.txt
7 | # 256 epoch setting: longer training, similar performance to 128 epoch
8 | # python3 train.py --data=/datasets/imagenet --batch_size=512 --learning_rate=1e-3 --epochs=256 --weight_decay=0 | tee -a log/training.txt
9 |
--------------------------------------------------------------------------------
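The two settings above trade training length against the initial learning rate under the linear decay that train.py's LambdaLR defines, lr_t = lr0 * (1 - t/T). A quick sketch of how the two schedules compare (values illustrative):

    def lr_at(epoch, lr0, total_epochs):
        # linear decay, as in train.py's LambdaLR
        return lr0 * (1.0 - epoch / total_epochs)

    print(lr_at(64, 2.5e-3, 128))   # halfway through the 128-epoch setting: 1.25e-3
    print(lr_at(128, 1e-3, 256))    # halfway through the 256-epoch setting: 5e-4
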
/resnet/2_step2/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import numpy as np
5 | import time, datetime
6 | import torch
7 | import random
8 | import logging
9 | import argparse
10 | import torch.nn as nn
11 | import torch.utils
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.utils.data.distributed
15 |
16 | sys.path.append("../../")
17 | from utils.utils import *
18 | from utils import KD_loss
19 | from torchvision import datasets, transforms
20 | from torch.autograd import Variable
21 | from birealnet import birealnet18
22 | import torchvision.models as models
23 |
24 | parser = argparse.ArgumentParser("birealnet18")
25 | parser.add_argument('--batch_size', type=int, default=512, help='batch size')
26 | parser.add_argument('--epochs', type=int, default=256, help='num of training epochs')
27 | parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate')
28 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
29 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
30 | parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')
31 | parser.add_argument('--data', metavar='DIR', help='path to dataset')
32 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
33 | parser.add_argument('--teacher', type=str, default='resnet34', help='teacher model architecture')
34 | parser.add_argument('-j', '--workers', default=40, type=int, metavar='N',
35 | help='number of data loading workers (default: 40)')
36 | args = parser.parse_args()
37 |
38 | CLASSES = 1000
39 |
40 | if not os.path.exists('log'):
41 | os.mkdir('log')
42 |
43 | log_format = '%(asctime)s %(message)s'
44 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
45 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
46 | fh = logging.FileHandler(os.path.join('log/log.txt'))
47 | fh.setFormatter(logging.Formatter(log_format))
48 | logging.getLogger().addHandler(fh)
49 |
50 | def main():
51 | if not torch.cuda.is_available():
52 | sys.exit('CUDA is required for training')
53 | start_t = time.time()
54 |
55 | cudnn.benchmark = True
56 | cudnn.enabled = True
57 | logging.info("args = %s", args)
58 |
59 | # load model
60 | model_teacher = models.__dict__[args.teacher](pretrained=True)
61 | model_teacher = nn.DataParallel(model_teacher).cuda()
62 | for p in model_teacher.parameters():
63 | p.requires_grad = False
64 | model_teacher.eval()
65 |
66 | model_student = birealnet18()
67 | logging.info('student:')
68 | logging.info(model_student)
69 | model_student = nn.DataParallel(model_student).cuda()
70 |
71 | criterion = nn.CrossEntropyLoss()
72 | criterion = criterion.cuda()
73 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
74 | criterion_smooth = criterion_smooth.cuda()
75 | criterion_kd = KD_loss.DistributionLoss()
76 |
77 | all_parameters = model_student.parameters()
78 | weight_parameters = []
79 | for pname, p in model_student.named_parameters():
80 | if p.ndimension() == 4 or 'conv' in pname:  # select conv weights for the weight-decay group
81 | weight_parameters.append(p)
82 | weight_parameters_id = list(map(id, weight_parameters))
83 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters))
84 |
85 | optimizer = torch.optim.Adam(
86 | [{'params' : other_parameters},
87 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}],
88 | lr=args.learning_rate,)
89 |
90 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/args.epochs), last_epoch=-1)
91 | start_epoch = 0
92 | best_top1_acc = 0
93 |
94 | checkpoint_tar = os.path.join(args.save, 'checkpoint_ba.pth.tar')  # step-1 weights copied in by run.sh
95 | checkpoint = torch.load(checkpoint_tar)
96 | model_student.load_state_dict(checkpoint['state_dict'], strict=False)
97 |
98 | checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
99 | if os.path.exists(checkpoint_tar):
100 | logging.info('loading checkpoint {} ...'.format(checkpoint_tar))
101 | checkpoint = torch.load(checkpoint_tar)
102 | start_epoch = checkpoint['epoch']
103 | best_top1_acc = checkpoint['best_top1_acc']
104 | model_student.load_state_dict(checkpoint['state_dict'], strict=False)
105 | logging.info("loaded checkpoint {} epoch = {}" .format(checkpoint_tar, checkpoint['epoch']))
106 |
107 | # adjust the learning rate according to the checkpoint
108 | for epoch in range(start_epoch):
109 | scheduler.step()
110 |
111 | # load training data
112 | traindir = os.path.join(args.data, 'train')
113 | valdir = os.path.join(args.data, 'val')
114 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
115 | std=[0.229, 0.224, 0.225])
116 |
117 | # data augmentation
118 | crop_scale = 0.08
119 | lighting_param = 0.1
120 | train_transforms = transforms.Compose([
121 | transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
122 | Lighting(lighting_param),
123 | transforms.RandomHorizontalFlip(),
124 | transforms.ToTensor(),
125 | normalize])
126 |
127 | train_dataset = datasets.ImageFolder(
128 | traindir,
129 | transform=train_transforms)
130 |
131 | train_loader = torch.utils.data.DataLoader(
132 | train_dataset, batch_size=args.batch_size, shuffle=True,
133 | num_workers=args.workers, pin_memory=True)
134 |
135 | # load validation data
136 | val_loader = torch.utils.data.DataLoader(
137 | datasets.ImageFolder(valdir, transforms.Compose([
138 | transforms.Resize(256),
139 | transforms.CenterCrop(224),
140 | transforms.ToTensor(),
141 | normalize,
142 | ])),
143 | batch_size=args.batch_size, shuffle=False,
144 | num_workers=args.workers, pin_memory=True)
145 |
146 | # train the model
147 | epoch = start_epoch
148 | while epoch < args.epochs:
149 | train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler)
150 | valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args)
151 |
152 | is_best = False
153 | if valid_top1_acc > best_top1_acc:
154 | best_top1_acc = valid_top1_acc
155 | is_best = True
156 |
157 | save_checkpoint({
158 | 'epoch': epoch,
159 | 'state_dict': model_student.state_dict(),
160 | 'best_top1_acc': best_top1_acc,
161 | 'optimizer' : optimizer.state_dict(),
162 | }, is_best, args.save)
163 |
164 | epoch += 1
165 |
166 | training_time = (time.time() - start_t) / 3600
167 | print('total training time = {} hours'.format(training_time))
168 |
169 |
170 | def train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler):
171 | batch_time = AverageMeter('Time', ':6.3f')
172 | data_time = AverageMeter('Data', ':6.3f')
173 | losses = AverageMeter('Loss', ':.4e')
174 | top1 = AverageMeter('Acc@1', ':6.2f')
175 | top5 = AverageMeter('Acc@5', ':6.2f')
176 |
177 | progress = ProgressMeter(
178 | len(train_loader),
179 | [batch_time, data_time, losses, top1, top5],
180 | prefix="Epoch: [{}]".format(epoch))
181 |
182 | model_student.train()
183 | model_teacher.eval()
184 | end = time.time()
185 | scheduler.step()  # decay the LR once per epoch (linear schedule defined in main)
186 |
187 | for param_group in optimizer.param_groups:
188 | cur_lr = param_group['lr']
189 | print('learning_rate:', cur_lr)
190 |
191 | for i, (images, target) in enumerate(train_loader):
192 | data_time.update(time.time() - end)
193 | images = images.cuda()
194 | target = target.cuda()
195 |
196 | # compute output
197 | logits_student = model_student(images)
198 | logits_teacher = model_teacher(images)
199 | loss = criterion(logits_student, logits_teacher)
200 |
201 | # measure accuracy and record loss
202 | prec1, prec5 = accuracy(logits_student, target, topk=(1, 5))
203 | n = images.size(0)
204 | losses.update(loss.item(), n)  # running average of the loss
205 | top1.update(prec1.item(), n)
206 | top5.update(prec5.item(), n)
207 |
208 | # compute gradient and take an optimizer (Adam) step
209 | optimizer.zero_grad()
210 | loss.backward()
211 | optimizer.step()
212 |
213 | # measure elapsed time
214 | batch_time.update(time.time() - end)
215 | end = time.time()
216 |
217 | progress.display(i)
218 |
219 | return losses.avg, top1.avg, top5.avg
220 |
221 | def validate(epoch, val_loader, model, criterion, args):
222 | batch_time = AverageMeter('Time', ':6.3f')
223 | losses = AverageMeter('Loss', ':.4e')
224 | top1 = AverageMeter('Acc@1', ':6.2f')
225 | top5 = AverageMeter('Acc@5', ':6.2f')
226 | progress = ProgressMeter(
227 | len(val_loader),
228 | [batch_time, losses, top1, top5],
229 | prefix='Test: ')
230 |
231 | # switch to evaluation mode
232 | model.eval()
233 | with torch.no_grad():
234 | end = time.time()
235 | for i, (images, target) in enumerate(val_loader):
236 | images = images.cuda()
237 | target = target.cuda()
238 |
239 | # compute output
240 | logits = model(images)
241 | loss = criterion(logits, target)
242 |
243 | # measure accuracy and record loss
244 | prec1, prec5 = accuracy(logits, target, topk=(1, 5))
245 | n = images.size(0)
246 | losses.update(loss.item(), n)
247 | top1.update(prec1.item(), n)
248 | top5.update(prec5.item(), n)
249 |
250 | # measure elapsed time
251 | batch_time.update(time.time() - end)
252 | end = time.time()
253 |
254 | progress.display(i)
255 |
256 | print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}'
257 | .format(top1=top1, top5=top5))
258 |
259 | return losses.avg, top1.avg, top5.avg
260 |
261 |
262 | if __name__ == '__main__':
263 | main()
264 |
--------------------------------------------------------------------------------
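A detail worth noting in train.py above: weight decay is applied only to the convolutional weight tensors, while BatchNorm, PReLU, and learnable-bias parameters are exempt. A self-contained sketch of the same two-group pattern (the toy model and the 1e-5 decay value are illustrative; the script's default decay is 0):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.PReLU(8))
    weight_params = [p for n, p in model.named_parameters()
                     if p.ndimension() == 4 or 'conv' in n]
    weight_ids = set(map(id, weight_params))
    other_params = [p for p in model.parameters() if id(p) not in weight_ids]
    optimizer = torch.optim.Adam(
        [{'params': other_params},                        # no weight decay
         {'params': weight_params, 'weight_decay': 1e-5}],
        lr=1e-3)
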
/utils/KD_loss.py:
--------------------------------------------------------------------------------
1 | # Code is modified from MEAL (https://arxiv.org/abs/1812.02425) and Label Refinery (https://arxiv.org/abs/1805.02641).
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch.nn.modules import loss
6 |
7 |
8 | class DistributionLoss(loss._Loss):
9 | """The KL-Divergence loss for the binary student model and real teacher output.
10 |
11 | output must be a pair of (model_output, real_output), both NxC tensors.
12 | The rows of real_output must all add up to one (probability scores);
13 | however, model_output must be the pre-softmax output of the network."""
14 |
15 | def forward(self, model_output, real_output):
16 |
17 | self.size_average = True
18 |
19 | # The loss is the KL divergence between the student's output and the
20 | # teacher's soft labels (equal to their cross-entropy up to a constant).
21 | if real_output.requires_grad:
22 | raise ValueError("real network output should not require gradients.")
23 |
24 | model_output_log_prob = F.log_softmax(model_output, dim=1)
25 | real_output_soft = F.softmax(real_output, dim=1)
26 | del model_output, real_output
27 |
28 | # Loss is -dot(model_output_log_prob, real_output). Prepare tensors
29 | # for batch matrix multiplication.
30 | real_output_soft = real_output_soft.unsqueeze(1)
31 | model_output_log_prob = model_output_log_prob.unsqueeze(2)
32 |
33 | # Compute the loss, and average/sum for the batch.
34 | cross_entropy_loss = -torch.bmm(real_output_soft, model_output_log_prob)
35 | if self.size_average:
36 | cross_entropy_loss = cross_entropy_loss.mean()
37 | else:
38 | cross_entropy_loss = cross_entropy_loss.sum()
39 | # Return the scalar loss; top-1/top-5 evaluation in train.py is done on
40 | # the raw student logits, so model_output is not returned here.
41 | # model_output_log_prob = model_output_log_prob.squeeze(2)
42 | return cross_entropy_loss
43 |
--------------------------------------------------------------------------------
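A hypothetical usage of `DistributionLoss` (the shapes and the import path, which assumes the repository root is on `sys.path`, are illustrative): the student logits carry gradients, while the teacher logits must not.

    import torch
    from utils.KD_loss import DistributionLoss

    criterion = DistributionLoss()
    student_logits = torch.randn(4, 1000, requires_grad=True)  # pre-softmax student output
    with torch.no_grad():
        teacher_logits = torch.randn(4, 1000)                  # must not require grad
    loss = criterion(student_logits, teacher_logits)           # scalar loss
    loss.backward()                                            # gradients reach student_logits
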
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import numpy as np
5 | import time, datetime
6 | import torch
7 | import random
8 | import logging
9 | import argparse
10 | import torch.nn as nn
11 | import torch.utils
12 | import torchvision.datasets as dset
13 | import torchvision.transforms as transforms
14 | import torch.backends.cudnn as cudnn
15 | from PIL import Image
16 | from torch.autograd import Variable
17 |
18 | # lighting data augmentation (PCA-based color jitter)
19 | imagenet_pca = {
20 | 'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
21 | 'eigvec': np.asarray([
22 | [-0.5675, 0.7192, 0.4009],
23 | [-0.5808, -0.0045, -0.8140],
24 | [-0.5836, -0.6948, 0.4203],
25 | ])
26 | }
27 |
28 |
29 | class Lighting(object):
30 | def __init__(self, alphastd,
31 | eigval=imagenet_pca['eigval'],
32 | eigvec=imagenet_pca['eigvec']):
33 | self.alphastd = alphastd
34 | assert eigval.shape == (3,)
35 | assert eigvec.shape == (3, 3)
36 | self.eigval = eigval
37 | self.eigvec = eigvec
38 |
39 | def __call__(self, img):
40 | if self.alphastd == 0.:
41 | return img
42 | rnd = np.random.randn(3) * self.alphastd
43 | rnd = rnd.astype('float32')
44 | v = rnd
45 | old_dtype = np.asarray(img).dtype
46 | v = v * self.eigval
47 | v = v.reshape((3, 1))
48 | inc = np.dot(self.eigvec, v).reshape((3,))
49 | img = np.add(img, inc)
50 | if old_dtype == np.uint8:
51 | img = np.clip(img, 0, 255)
52 | img = Image.fromarray(img.astype(old_dtype), 'RGB')
53 | return img
54 |
55 | def __repr__(self):
56 | return self.__class__.__name__ + '()'
57 |
58 | # label smoothing
59 | class CrossEntropyLabelSmooth(nn.Module):
60 |
61 | def __init__(self, num_classes, epsilon):
62 | super(CrossEntropyLabelSmooth, self).__init__()
63 | self.num_classes = num_classes
64 | self.epsilon = epsilon
65 | self.logsoftmax = nn.LogSoftmax(dim=1)
66 |
67 | def forward(self, inputs, targets):
68 | log_probs = self.logsoftmax(inputs)
69 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
70 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
71 | loss = (-targets * log_probs).mean(0).sum()
72 | return loss
73 |
74 |
75 | class AverageMeter(object):
76 | """Computes and stores the average and current value"""
77 | def __init__(self, name, fmt=':f'):
78 | self.name = name
79 | self.fmt = fmt
80 | self.reset()
81 |
82 | def reset(self):
83 | self.val = 0
84 | self.avg = 0
85 | self.sum = 0
86 | self.count = 0
87 |
88 | def update(self, val, n=1):
89 | self.val = val
90 | self.sum += val * n
91 | self.count += n
92 | self.avg = self.sum / self.count
93 |
94 | def __str__(self):
95 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
96 | return fmtstr.format(**self.__dict__)
97 |
98 |
99 | class ProgressMeter(object):
100 | def __init__(self, num_batches, meters, prefix=""):
101 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
102 | self.meters = meters
103 | self.prefix = prefix
104 |
105 | def display(self, batch):
106 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
107 | entries += [str(meter) for meter in self.meters]
108 | print('\t'.join(entries))
109 |
110 | def _get_batch_fmtstr(self, num_batches):
111 | num_digits = len(str(num_batches))
112 | fmt = '{:' + str(num_digits) + 'd}'
113 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
114 |
115 |
116 | def save_checkpoint(state, is_best, save):
117 | if not os.path.exists(save):
118 | os.makedirs(save)
119 | filename = os.path.join(save, 'checkpoint.pth.tar')
120 | torch.save(state, filename)
121 | if is_best:
122 | best_filename = os.path.join(save, 'model_best.pth.tar')
123 | shutil.copyfile(filename, best_filename)
124 |
125 |
126 | def adjust_learning_rate(optimizer, epoch, args):
127 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
128 | lr = args.lr * (0.1 ** (epoch // 30))
129 | for param_group in optimizer.param_groups:
130 | param_group['lr'] = lr
131 |
132 |
133 | def accuracy(output, target, topk=(1,)):
134 | """Computes the accuracy over the k top predictions for the specified values of k"""
135 | with torch.no_grad():
136 | maxk = max(topk)
137 | batch_size = target.size(0)
138 |
139 | _, pred = output.topk(maxk, 1, True, True)
140 | pred = pred.t()
141 | correct = pred.eq(target.view(1, -1).expand_as(pred))
142 |
143 | res = []
144 | for k in topk:
145 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: safe if the slice is non-contiguous
146 | res.append(correct_k.mul_(100.0 / batch_size))
147 | return res
148 |
--------------------------------------------------------------------------------
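A quick check of `accuracy` and `AverageMeter` from this file (a sketch; assumes they are imported from `utils.utils`):

    import torch
    from utils.utils import AverageMeter, accuracy

    logits = torch.tensor([[2.0, 1.0, 0.1],
                           [0.1, 3.0, 0.2]])
    target = torch.tensor([0, 1])                # both predictions are correct
    (top1,) = accuracy(logits, target, topk=(1,))
    meter = AverageMeter('Acc@1', ':6.2f')
    meter.update(top1.item(), n=target.size(0))
    print(meter)                                 # Acc@1 100.00 (100.00)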