├── .gitattributes
├── LICENSE
├── README.md
├── alexnet_imagenet
│   ├── datasets
│   │   ├── __init__.py
│   │   └── myfolder.py
│   └── networks
│       ├── __init__.py
│       ├── main.py
│       ├── model_list
│       │   ├── __init__.py
│       │   └── alexnet.py
│       └── util.py
└── notes.pdf

/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.
   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[*UNDER CONSTRUCTION*]

# ABC-Net-Pytorch
My implementation of [ABC-Net](https://arxiv.org/abs/1711.11294). So far I have trained ABC-Net with the AlexNet architecture on the ImageNet dataset, but the model does NOT converge. I am considering switching to a smaller dataset that is easier to debug.

***IF YOU FIND ANY MISTAKES IN MY IMPLEMENTATION THAT CAUSE THE NON-CONVERGENCE, PLEASE LET ME KNOW. Thanks :)***

## TO-DO
- ABC-Net with the ResNet-18 architecture on the CIFAR-10 dataset

## Mismatches
Since some details are NOT specified in the paper, I made the following modifications:
- *the way to solve alpha*: an ordinary least-squares fit (`sklearn.linear_model.LinearRegression` in `networks/util.py`) is adopted; see the sketch below.
- *whether the STE attaches to each binary base OR to the full-precision weight*: the full-precision weight.
- *gradient of W (see Notes below)*: kept as in the original paper.
- *AlexNet architecture*: basically adopts the modifications from XNOR-Net, except that the kernel sizes are kept identical to the pretrained model.
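
As an illustration of the alpha-solving step, here is a minimal NumPy sketch on toy data. It approximates what `networks/util.py` does with `LinearRegression`; the random toy weights and variable names are mine, not part of the repo:

```python
import numpy as np

# toy stand-in for a flattened full-precision conv kernel W
W = np.random.randn(1000)

# build M binary bases by shifting W by its mean and std, with u_i evenly spaced in [-1, 1]
M = 3
mean, std = W.mean(), W.std()
B = np.stack([np.sign(W - mean + u * std) for u in np.linspace(-1.0, 1.0, M)])

# solve min_alpha ||W - B^T alpha||^2: a plain least-squares problem, no intercept
alpha, *_ = np.linalg.lstsq(B.T, W, rcond=None)

print(alpha)                                                 # the scaling factors alpha_i
print(np.linalg.norm(W - B.T @ alpha) / np.linalg.norm(W))   # relative reconstruction error
```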

## Notes
Some problems I found in this paper are written up in the [notes](https://github.com/cow8/ABC-Net-pytorch/raw/master/notes.pdf).
--------------------------------------------------------------------------------
/alexnet_imagenet/datasets/__init__.py:
--------------------------------------------------------------------------------
from .myfolder import ImageFolder

__all__ = ('ImageFolder',)
--------------------------------------------------------------------------------
/alexnet_imagenet/datasets/myfolder.py:
--------------------------------------------------------------------------------
import torch.utils.data as data

from PIL import Image
import os
import os.path


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def make_dataset_filetable(root, mapfile):
    # the mapfile lists one "<relative path> <label>" pair per line
    imgs = []
    with open(mapfile, 'r') as f:
        records = f.readlines()
        for record in records:
            path, label = record.strip().split(" ")
            path = os.path.join(root, path)
            imgs.append((path, int(label)))
    return imgs


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, target_transform=None, loader=pil_loader, mapfile=""):

        imgs = make_dataset_filetable(root, mapfile)

        self.root = root
        self.imgs = imgs
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.imgs)
--------------------------------------------------------------------------------
/alexnet_imagenet/networks/__init__.py:
--------------------------------------------------------------------------------
from . import util, model_list
--------------------------------------------------------------------------------
/alexnet_imagenet/networks/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil
import time
import sys
import gc
import platform
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms

cwd = os.getcwd()
sys.path.append(cwd + '/../')
sys.path.append(cwd + '/networks/')
import networks.model_list as model_list
import networks.util as util

import datasets as datasets

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--arch', '-a', metavar='ARCH', default='alexnet',
                    help='model architecture (default: alexnet)')
parser.add_argument('--data', metavar='DATA_PATH', default='./data/',
                    help='path to imagenet data (default: ./data/)')

parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--base_number', default=3, type=int,
                    metavar='N', help='number of binary bases (default: 3)')
parser.add_argument('--epochs', default=50, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.90, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
                    metavar='W', help='weight decay (default: 1e-5)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    default=False, help='use pre-trained model')
parser.add_argument('--nocuda', dest='nocuda', action='store_true',
                    help='run without CUDA')
best_prec1 = 0

# define global bin_op
bin_op = None

# define optimizer
optimizer = None


def main():
    global args, best_prec1
    args = parser.parse_args()

    # CUDA is assumed to be unavailable on Windows; otherwise respect --nocuda
    if platform.system() == "Windows":
        args.nocuda = True

    # create model
    if args.arch == 'alexnet':
        model = model_list.alexnet(pretrained=args.pretrained, base_number=args.base_number)
        input_size = 227
    else:
        raise Exception('Model not supported yet')

    model.features = torch.nn.DataParallel(model.features)
    if not args.nocuda:
        # set the seed
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        model.cuda()
        # define loss function (criterion) and optimizer
        criterion = nn.CrossEntropyLoss().cuda()
        # Set benchmark
        cudnn.benchmark = True
    else:
        criterion = nn.CrossEntropyLoss()

    global optimizer
    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                 weight_decay=args.weight_decay)
    # random initialization
    if not args.pretrained:
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                c = float(m.weight.data[0].nelement())
                m.weight.data = m.weight.data.normal_(0, 1.0 / c)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data = m.weight.data.zero_().add(1.0)
    else:
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data = m.weight.data.zero_().add(1.0)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # original saved file with DataParallel
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            print(checkpoint)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code

    # if you want to use the preprocessing used in Caffe:
    # transform = transforms.Compose([
    #     transforms.Resize((256, 256)),
    #     transforms.RandomResizedCrop(input_size),
    #     transforms.RandomHorizontalFlip(),
    #     transforms.ToTensor(),
    #     transforms.Lambda(lambda x: x * 255),
    #     transforms.Lambda(lambda x: torch.cat(reversed(torch.split(x, 1, 0)))),
    #     transforms.Lambda(lambda x: x - torch.Tensor([103.939, 116.779, 123.68]).view(3, 1, 1).expand(3, 227, 227))
    # ])
    # transform_val = transforms.Compose([
    #     transforms.Resize((256, 256)),
    #     transforms.CenterCrop(input_size),
    #     transforms.ToTensor(),
    #     transforms.Lambda(lambda x: x * 255),
    #     transforms.Lambda(lambda x: torch.cat(reversed(torch.split(x, 1, 0)))),
    #     transforms.Lambda(lambda x: x - torch.Tensor([103.939, 116.779, 123.68]).view(3, 1, 1).expand(3, 227, 227))
    # ])

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    transform_val = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    traindir = os.path.join(args.data, 'ILSVRC2012_img_train')
    valdir = os.path.join(args.data, 'ILSVRC2012_img_val')
    train_dataset = datasets.ImageFolder(traindir, transform, mapfile=os.path.join(args.data, "ImageNet12_train.txt"))
    val_dataset = datasets.ImageFolder(valdir, transform_val, mapfile=os.path.join(args.data, "ImageNet12_val.txt"))

    if not args.nocuda:
        # shuffle the training set every epoch
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers)

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers)

    print(model)

    # define the binarization operator
    global bin_op
    bin_op = util.BinOp(model)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if not args.nocuda:
            target = target.cuda(non_blocking=True)
            input_var = torch.autograd.Variable(input).cuda()
        else:
            input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # replace the full-precision weights with their binary approximations
        bin_op.binarization()

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and take an optimization step
        optimizer.zero_grad()
        loss.backward()

        # restore the full-precision weights, then rewrite their gradients (STE)
        bin_op.restore()
        bin_op.updateBinaryGradWeight()

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))

        # checkpoint within the epoch, because the training process is slow
        if i % 100 == 99:
            save_checkpoint({
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, False, filename="checkpoint_every_100_batches.pth.tar")
        gc.collect()


def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # binarize once for the whole validation pass
    bin_op.binarization()
    for i, (input, target) in enumerate(val_loader):

        if not args.nocuda:
            target = target.cuda(non_blocking=True)
            input_var = torch.autograd.Variable(input, volatile=True).cuda()
            target_var = torch.autograd.Variable(target, volatile=True)
        else:
            input_var = torch.autograd.Variable(input, volatile=True)
            target_var = torch.autograd.Variable(target, volatile=True)
        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))
    bin_op.restore()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 25 epochs"""
    lr = args.lr * (0.1 ** (epoch // 25))
    print('Learning rate:', lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # print(pred)
    # print(target)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/alexnet_imagenet/networks/model_list/__init__.py:
--------------------------------------------------------------------------------
from .alexnet import alexnet
--------------------------------------------------------------------------------
/alexnet_imagenet/networks/model_list/alexnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch


class ABCConv2d(nn.Module):
    """ABC-Net block: BN -> sum of `base_number` binary-base convolutions -> ReLU.

    Each base layer's weight is overwritten with alpha_i * B_i by util.BinOp
    before the forward pass.
    """

    def __init__(self, input_channels, output_channels,
                 kernel_size=-1, stride=-1, padding=-1, groups=1, dropout=0.0,
                 linear=False, base_number=3):
        super(ABCConv2d, self).__init__()
        assert base_number == 3 or base_number == 1, "only base_number == 1 or base_number == 3 is supported"
        self.layer_type = 'ABC_Conv2d'
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dropout_ratio = dropout
        self.base_number = base_number
        if dropout != 0:
            self.dropout = nn.Dropout(dropout)
        self.linear = linear
        if not self.linear:
            self.bn = nn.BatchNorm2d(input_channels, eps=1e-4, momentum=0.1, affine=True)
            if self.base_number == 1:
                self.bases_conv2d_1 = nn.Conv2d(input_channels, output_channels,
                                                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
            else:
                self.bases_conv2d_1 = nn.Conv2d(input_channels, output_channels,
                                                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
                self.bases_conv2d_2 = nn.Conv2d(input_channels, output_channels,
                                                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
                self.bases_conv2d_3 = nn.Conv2d(input_channels, output_channels,
                                                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
        else:
            self.bn = nn.BatchNorm1d(input_channels, eps=1e-4, momentum=0.1, affine=True)
            if self.base_number == 1:
                self.bases_linear_1 = nn.Linear(input_channels, output_channels)
            else:
                self.bases_linear_1 = nn.Linear(input_channels, output_channels)
                self.bases_linear_2 = nn.Linear(input_channels, output_channels)
                self.bases_linear_3 = nn.Linear(input_channels, output_channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.bn(x)
        # x = BinActive()(x)
        if self.dropout_ratio != 0:
            x = self.dropout(x)
        if self.base_number == 1:
            if not self.linear:
                x = self.bases_conv2d_1(x)
            else:
                x = self.bases_linear_1(x)
        else:
            # sum the base outputs: equivalent to conv(x, sum_i alpha_i * B_i)
            if not self.linear:
                x = self.bases_conv2d_1(x) + self.bases_conv2d_2(x) + self.bases_conv2d_3(x)
            else:
                x = self.bases_linear_1(x) + self.bases_linear_2(x) + self.bases_linear_3(x)
        x = self.relu(x)
        return x


BinConv2d = ABCConv2d


class AlexNet(nn.Module):

    def __init__(self, num_classes=1000, base_number=3):
        super(AlexNet, self).__init__()
        self.num_classes = num_classes
        self.base_number = base_number
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=0),
            nn.BatchNorm2d(64, eps=1e-4, momentum=0.1, affine=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            BinConv2d(64, 192, kernel_size=5, stride=1, padding=2, groups=1, base_number=self.base_number),
            nn.MaxPool2d(kernel_size=3, stride=2),
            BinConv2d(192, 384, kernel_size=3, stride=1, padding=1, base_number=self.base_number),
            BinConv2d(384, 256, kernel_size=3, stride=1, padding=1, groups=1, base_number=self.base_number),
            BinConv2d(256, 256, kernel_size=3, stride=1, padding=1, groups=1, base_number=self.base_number),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            BinConv2d(256 * 6 * 6, 4096, linear=True, base_number=self.base_number),
            BinConv2d(4096, 4096, dropout=0.1, linear=True, base_number=self.base_number),
            nn.BatchNorm1d(4096, eps=1e-3, momentum=0.1, affine=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x

    def my_model_loader(self, state_dict, strict=True):
        own_state = self.state_dict()
        # map the full-precision pretrained model onto ABC-Net
        # (only the first base of each block is initialized)
        load_map = {
            'features.0.weight': 'features.0.weight',
            'features.0.bias': 'features.0.bias',
            'features.4.bases_conv2d_1.weight': 'features.3.weight',
            'features.4.bases_conv2d_1.bias': 'features.3.bias',
            'features.6.bases_conv2d_1.weight': 'features.6.weight',
            'features.6.bases_conv2d_1.bias': 'features.6.bias',
            'features.7.bases_conv2d_1.weight': 'features.8.weight',
            'features.7.bases_conv2d_1.bias': 'features.8.bias',
            'features.8.bases_conv2d_1.weight': 'features.10.weight',
            'features.8.bases_conv2d_1.bias': 'features.10.bias',
            'classifier.0.bases_linear_1.weight': 'classifier.1.weight',
            'classifier.0.bases_linear_1.bias': 'classifier.1.bias',
            'classifier.1.bases_linear_1.weight': 'classifier.4.weight',
            'classifier.1.bases_linear_1.bias': 'classifier.4.bias',
            'classifier.4.weight': 'classifier.6.weight',
            'classifier.4.bias': 'classifier.6.bias',
        }

        for k, v in load_map.items():
            own_state[k].copy_(state_dict[v].data)


def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(**kwargs)
    if pretrained:
        model_path = 'model_list/alexnet_fp_pretrained.pth'
        pretrained_model = torch.load(model_path)
        model.my_model_loader(pretrained_model)
    return model
--------------------------------------------------------------------------------
/alexnet_imagenet/networks/util.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
import numpy
from sklearn.linear_model import LinearRegression
import platform


class BinOp():
    def __init__(self, model):
        self.base_number = model.base_number
        # count the number of Conv2d and Linear modules
        count_targets = 0
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                count_targets = count_targets + 1

        # binarize everything except the first conv and the last fc layer
        start_range = 1
        end_range = count_targets - 2
        self.bin_range = numpy.linspace(start_range,
                                        end_range, end_range - start_range + 1) \
            .astype('int').tolist()
        self.num_of_params = len(self.bin_range)
        self.saved_params = []
        self.target_params = []
        self.target_modules = []
        self.alphas = []
        index = -1
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                index = index + 1
                if index in self.bin_range:
                    tmp = m.weight.data.clone()
                    self.saved_params.append(tmp)
                    self.target_modules.append(m.weight)

        for index_conv in range(int(self.num_of_params / self.base_number)):
            self.alphas.append(torch.zeros(self.base_number))

    def binarization(self):
        # self.meancenterConvParams()
        self.clampConvParams()
        self.save_params()
        self.binarizeConvParams()

    def clampConvParams(self):
        # only the first base of each block stores the "real" fp weight W
        for index in range(int(self.num_of_params / self.base_number)):
            self.target_modules[index * self.base_number].data.clamp(-1.0, 1.0,
                                                                     out=self.target_modules[
                                                                         index * self.base_number].data)

    def save_params(self):
        for index in range(int(self.num_of_params / self.base_number)):
            self.saved_params[index * self.base_number].copy_(self.target_modules[index * self.base_number].data)

    def ABC_binarizeConvParams(self):
        for index_conv in range(int(self.num_of_params / self.base_number)):
            n_vec = self.target_modules[index_conv * self.base_number].data.nelement()
            k_size = self.target_modules[index_conv * self.base_number].data.size()

            # flatten the fp weight of the first base
            W = self.target_modules[index_conv * self.base_number].data.view(n_vec)

            # build the binary bases by shifting W by its mean and std
            W_neg_mean = W.mean(dim=0, keepdim=True).neg().expand(n_vec)
            W_std = W.std(dim=0, keepdim=True).expand(n_vec)
            if self.base_number == 1:
                B = W.add(W_neg_mean).sign().view(1, n_vec)
            if self.base_number == 3:
                # shift parameters u_i in {-1, 0, 1}
                t1 = W.add(W_neg_mean).add(W_std.mul(-1)).sign().view(1, n_vec)
                t2 = W.add(W_neg_mean).sign().view(1, n_vec)
                t3 = W.add(W_neg_mean).add(W_std).sign().view(1, n_vec)
                B = torch.cat((t1, t2, t3))
            # generic version for an arbitrary base_number:
            # for base in range(self.base_number):
            #     u_i = -1 + base * 2 / (self.base_number - 1)
            #     t = W.add(W_neg_mean).add(W_std.mul(u_i)).sign()
            #     if base == 0:
            #         B = t.view(1, n_vec)
            #     else:
            #         B = torch.cat((B, t.view(1, n_vec)))

            # solve min_alpha ||W - B^T * alpha||^2; no intercept, since W is
            # approximated as a pure linear combination of the bases
            LRM = LinearRegression(fit_intercept=False)
            LRM.fit(B.t(), W)
            # alpha = torch.from_numpy(LRM.coef_)
            if platform.system() == "Windows":
                alpha = torch.Tensor(LRM.coef_)
            else:
                alpha = torch.Tensor(LRM.coef_).cuda()

            self.alphas[index_conv].copy_(alpha)
            # write alpha_i * B_i into the weight of the i-th base
            for base in range(self.base_number):
                self.target_modules[index_conv * self.base_number + base].data.copy_(
                    B[base].mul(alpha[base]).view(k_size))

    def ABC_updateBinaryGradWeight(self):
        # original version:
        for index_conv in range(int(self.num_of_params / self.base_number)):
            if self.base_number == 1:
                pass
            if self.base_number == 3:
                # dW = sum_i alpha_i^2 * g_i, where g_i = dL/d(alpha_i*B_i) is the
                # gradient stored on the i-th base: dL/dB_i = alpha_i * g_i, and
                # weighting each base by alpha_i again (W ~ sum_i alpha_i*B_i)
                # gives the alpha_i^2 factor
                alpha_dB1 = self.target_modules[index_conv * self.base_number].grad.data. \
                    mul(self.alphas[index_conv][0] * self.alphas[index_conv][0])
                alpha_dB2 = self.target_modules[index_conv * self.base_number + 1].grad.data. \
                    mul(self.alphas[index_conv][1] * self.alphas[index_conv][1])
                alpha_dB3 = self.target_modules[index_conv * self.base_number + 2].grad.data. \
                    mul(self.alphas[index_conv][2] * self.alphas[index_conv][2])

                dW = alpha_dB1.add(alpha_dB2).add(alpha_dB3)
                # attach STE to a single base OR to the sum of them?
                W = self.target_modules[index_conv * self.base_number].data
                # STE: zero the gradient where the fp weight left [-1, 1]
                dW[W.lt(-1)] = 0
                dW[W.gt(1)] = 0
                # scale up the gradient (as in the XNOR-Net reference code)
                dW = dW.mul(1e+9)
                self.target_modules[index_conv * self.base_number].grad.data.copy_(dW)

    binarizeConvParams = ABC_binarizeConvParams
    updateBinaryGradWeight = ABC_updateBinaryGradWeight

    def restore(self):
        for index in range(int(self.num_of_params / self.base_number)):
            self.target_modules[index * self.base_number].data.copy_(self.saved_params[index * self.base_number])

    # def binarizeConvParams(self):
    #     for index in range(self.num_of_params):
    #
    #         n = self.target_modules[index].data[0].nelement()
    #         s = self.target_modules[index].data.size()
    #         if len(s) == 4:
    #             m = self.target_modules[index].data.norm(1, 3, keepdim=True) \
    #                 .sum(2, keepdim=True).sum(1, keepdim=True).div(n)
    #         elif len(s) == 2:
    #             m = self.target_modules[index].data.norm(1, 1, keepdim=True).div(n)
    #         self.target_modules[index].data.sign() \
    #             .mul(m.expand(s), out=self.target_modules[index].data)

    # def updateBinaryGradWeight(self):
    #     for index in range(self.num_of_params):
    #         weight = self.target_modules[index].data
    #         n = weight[0].nelement()
    #         s = weight.size()
    #         if len(s) == 4:
    #             m = weight.norm(1, 3, keepdim=True) \
    #                 .sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(s)
    #         elif len(s) == 2:
    #             m = weight.norm(1, 1, keepdim=True).div(n).expand(s)
    #         # m = alpha
    #
    #         m[weight.lt(-1.0)] = 0
    #         m[weight.gt(1.0)] = 0
    #         # m = alpha * r_1
    #
    #         m = m.mul(self.target_modules[index].grad.data)
    #         # m = alpha * r_1 * g_i
    #
    #         m_add = weight.sign().mul(self.target_modules[index].grad.data)
    #         # m_add = g_i * sign(w_i)
    #         if len(s) == 4:
    #             m_add = m_add.sum(3, keepdim=True).sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(s)
    #         elif len(s) == 2:
    #             m_add = m_add.sum(1, keepdim=True).div(n).expand(s)
    #         # in the notes, W_i stands for the i^th element of W (C*k*k)
    #
    #         m_add = m_add.mul(weight.sign())
    #         self.target_modules[index].grad.data = m.add(m_add).mul(1.0 - 1.0 / s[1]).mul(n)
    #         self.target_modules[index].grad.data = self.target_modules[index].grad.data.mul(1e+9)
    #
    # def meancenterConvParams(self):
    #     for index in range(int(self.num_of_params / self.base_number)):
    #         s = self.target_modules[index * self.base_number].data.size()
    #         negMean = self.target_modules[index * self.base_number].data.mean(1, keepdim=True). \
    #             mul(-1).expand_as(self.target_modules[index * self.base_number].data)
    #         self.target_modules[index * self.base_number].data = self.target_modules[index * self.base_number].data.add(negMean)
--------------------------------------------------------------------------------
/notes.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuyinheng/ABC-Net-pytorch/50231ee26589ab6d60140d052d99b8b5bfe44c4a/notes.pdf
--------------------------------------------------------------------------------